当前位置: 首页 > news >正文

推荐系统(二十四):Embedding层的参数是如何在模型训练过程中学习的?

近来有不少读者私信我关于嵌入层(Embedding层)参数在模型训练过程中如何学习的问题。虽然之前已经在不少文章介绍过 Embedding,但是为了读者更好地理解,笔者将通过本文详细解读嵌入层(Embedding Layer)的参数如何更新的,尤其是在反向传播过程中,为什么输入层的参数会被更新,而通常反向传播更新的是神经网络的权重参数,而不是输入数据,很多读者可能会混淆输入数据和嵌入层的权重参数。

一、嵌入层参数更新的本质机制解析

嵌入层(Embedding Layer)的参数更新逻辑与传统神经网络层的参数更新有本质区别,核心在于嵌入层的权重矩阵本身就是模型参数,而非输入数据的静态特征。以下从技术原理和实现细节两个层面详细说明:

1.参数定位:权重矩阵是模型的一部分

  • 权重矩阵的角色
    嵌入层的核心是一个可学习的权重矩阵,以商品类目为例,类目对应的 embedding 矩阵形状为 (类目词汇表大小, 嵌入维度)。例如,当类目词汇表包含 1000 个类目且嵌入维度为 300 时,该矩阵为 1000×300 的可训练参数。
    关键区别:输入数据是离散的索引(如类目ID=5),而嵌入层的权重矩阵是模型的一部分,与全连接层的权重性质完全相同。

  • 查表操作的实质
    前向传播时,输入索引(如[2,5])通过查表操作(类似weight[[2,5]])提取对应行向量。此过程看似是“输入处理”,但查表操作本身包含参数访问,因此反向传播时梯度会传递到权重矩阵。

2.反向传播的梯度流动路径

  • 梯度计算示例
    假设输入索引为 i,嵌入向量为 E[i],后续网络层对损失函数的梯度为 dL/dE[i]。通过链式法则:梯度 dL/dE[i] 直接作用于权重矩阵的第 i 行;优化器根据梯度更新矩阵中所有被使用的行向量。读者可能还是有疑惑:Embedding 向量本身不是神经网络的参数,怎么通过梯度更新呢?这里可以简单理解为——原始 embedding 向量为 [1,1,1,1,1,…], 即全 1 向量,而在反向传播中更新的 embedding 向量是一个权重向量 [w1,w2,w3,w4,…],如此一来,在反向传播中不断更新的 embedding 向量 [w1,w2,w3,w4,…] 最终就是真正的 embedding 向量,而虚拟的 [1,1,1,1,1,…] 全为 1,没有存在的必要,当然本身就不存在。所以,实际(embedding 参数) [w1,w2,w3,w4,…]本身就是模型参数的一部分,与原始的输入(比如类目code、索引)是无关的。

  • 参数更新范围
    与全连接层不同,嵌入层仅更新被实际使用的行向量。例如,若某批次仅使用索引 2 和 5,则只更新矩阵的第 2 行和第 5 行,未使用的行(如索引3)梯度为 0。这种稀疏更新特性使其适合处理大规模词表。

3.与“输入数据不更新”原则的兼容性

  • 输入数据与参数的分离
    输入数据:始终是固定的离散索引(如商品类目ID=5),训练过程中不会被修改;
    参数矩阵:根据输入数据的选择性索引,动态更新对应行向量;

  • 类比说明
    将嵌入层类比为“可学习的字典”:字典内容(权重矩阵)由模型拥有,训练目标是优化字典中的词义解释
    输入数据(索引)只是查询字典的钥匙,钥匙本身不改变。

4.PyTorch实现验证

通过代码可直观地观察参数更新过程:

import torch
import torch.nn as nn# 初始化嵌入层(类目词表大小=5,嵌入维度=3)
embedding = nn.Embedding(5, 3)
print("初始权重矩阵:\n", embedding.weight.data)# 模拟训练步骤
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)
input_indices = torch.LongTensor([2])  # 选择索引2# 前向传播
output = embedding(input_indices)
loss = output.sum()  # 假设损失函数为输出求和
loss.backward()# 更新参数前查看梯度
print("更新前梯度:\n", embedding.weight.grad)
optimizer.step()# 更新后查看权重变化(仅索引2的行被更新)
print("更新后权重矩阵:\n", embedding.weight.data)

输出:

初始权重矩阵:tensor([[ 0.6638, -1.2098,  0.4363],[ 0.3482,  0.8564, -0.1783],[ 1.2376, -0.5921,  0.9815],  # ← 索引2的初始值[-0.1156,  0.3279,  0.8942],[-0.4392,  0.6653,  0.2191]])更新前梯度:tensor([[0., 0., 0.],[0., 0., 0.],[1., 1., 1.],  # ← 仅索引2的梯度非零[0., 0., 0.],[0., 0., 0.]])更新后权重矩阵:tensor([[ 0.6638, -1.2098,  0.4363],[ 0.3482,  0.8564, -0.1783],[ 1.1376, -0.6921,  0.8815],  # ← 索引2的行值改变(lr=0.1)[-0.1156,  0.3279,  0.8942],[-0.4392,  0.6653,  0.2191]])

5.与全连接层的对比在这里插入图片描述


二、如何计算模型参数量

以 Wide&Deep 模型为例,推荐系统参数量计算需分别分析其 Wide 部分(线性模型)和 Deep 部分(深度神经网络)的结构特征。以下是具体计算方法及实现要点:

1.Wide 部分参数量计算

Wide 部分采用广义线性模型,核心参数由特征交叉项和偏置项构成:
Params wide = ( d w i d e + 1 ) \text{Params}_{\text{wide}} = (d_{wide} + 1) Paramswide=(dwide+1)

  • 特征交叉项:若输入特征维度为 d(包含原始特征和交叉特征),则权重向量维度为 d。
  • 偏置项:1 个标量参数。
  • 示例:若使用用户性别(2 种取值)与商品类别(10 种取值)的交叉特征,则交叉维度为 2×10=20,总参数量为 20+1=21。

2.Deep 部分参数量计算

Deep 部分由嵌入层(Embedding)和全连接层(Dense Layer)构成,计算需分模块:

2.1 嵌入层(Embedding Layer)

Embedding 层的参数量计算公式如下:
Params embed = ∑ i = 1 k ( vocab_size i × e i ) \text{Params}_{\text{embed}} = \sum_{i=1}^k (\text{vocab\_size}_i \times e_i) Paramsembed=i=1k(vocab_sizei×ei)

其中,k 为类别特征数量,vocab_size_i 为第 i 个特征的词表大小,e_i 为对应嵌入维度。

2.2 全连接层参数(DNN Layer)

L 为隐藏层数,n_l 为第 l 层神经元数,n_l − 1 为上一层输出维度。例如,输入维度为100,隐藏层为 100→64→32,则参数量为:
Params dnn = ∑ l = 1 L ( n l − 1 × n l + n l ) = ( 100 × 64 + 64 ) + ( 64 × 32 + 32 ) (实例展开) \begin{aligned} \text{Params}_{\text{dnn}} &= \sum_{l=1}^L (n_{l-1} \times n_l + n_l) \\ &= (100 \times 64 + 64) + (64 \times 32 + 32) \quad \text{(实例展开)} \end{aligned} Paramsdnn=l=1L(nl1×nl+nl)=(100×64+64)+(64×32+32)(实例展开)

2.3 联合输出层参数

联合输出层将 Wide 和 Deep 的输出拼接后通过 Sigmoid 函数:
Params out = ( d w i d e + d d e e p ) ∗ 1 + 1 \text{Params}_{\text{out}} = (d_{wide} + d_{deep})*1+1 Paramsout=(dwide+ddeep)1+1

d_deep 为 Deep 部分最后一层输出维度,+1 为偏置项.

2.4 总参数量

总参数量为三部分之和:
T o t a l P a r a m s = P a r a m s w i d e + P a r a m s e m b e d + P a r a m s d n n + P a r a m s o u t p u t \text Total Params=Params_{wide} + Params_{embed} + Params_{dnn} + Params_{output} TotalParams=Paramswide+Paramsembed+Paramsdnn+Paramsoutput

相关文章:

  • 分糖果——牛客
  • Ragflow、Dify、FastGPT、COZE核心差异对比与Ragflow的深度文档理解能力​​和​​全流程优化设计
  • 文件系统常见函数
  • 2022 年 9 月青少年软编等考 C 语言七级真题解析
  • 根据定义给出json_schema:
  • 【Python】每隔一段时间自动清除网站上cookies的方法
  • 使用 Streamlit 打造一个简单的照片墙应用
  • 极狐GitLab 的压缩和合并是什么?
  • sglang部署DeepSeek-R1-Distill-Qwen-7B
  • fpga系列 HDL:跨时钟域同步 脉冲展宽同步 Pulse Synchronization
  • 四神-华夏大地的守护神
  • 今天开始着手准备PAT(乙级)
  • 第一节:核心概念高频题-Vue3响应式原理与Vue2的区别
  • MYSQL之表的操作
  • 在面试中被问到spring是什么?
  • Kubernetes Multus CNI详细剖析
  • 渗透测试中的信息收集:从入门到精通
  • 爬虫学习总结
  • 滑动窗口算法(一)
  • Transformer起源-Attention Is All You Need
  • 央行上海总部:受益于过境免签政策,上海市外卡刷卡支付交易量稳步增长
  • 神十九乘组视频祝福第十个中国航天日,展望中华民族登月梦圆
  • 金发科技去年净利增160%,机器人等新领域催生材料新需求
  • 聚焦客户真实需求,平安人寿重磅推出“添平安”保险+服务解决方案
  • 新质生产力的宜昌解法:抢滩“高智绿”新赛道,化工产品一克卖数千元
  • 路面突陷大坑致车毁人亡,家属称不知谁来管,长治当地回应