当前位置: 首页 > news >正文

多头注意力

Multi-Head Attention

  • 论文地址

    https://arxiv.org/pdf/1706.03762

多头注意力介绍

  • 多头注意力是Transformer模型的关键创新,通过并行执行多个独立的注意力计算单元,使模型能够同时关注来自不同表示子空间的信息。每个注意力头学习不同的语义特征,最后通过线性变换将多头的输出组合为最终结果。

    image-20250424185647015

    当n_heads=1时,多头注意力等价于标准缩放点积注意力。多头设计通过建立多个独立的"观察视角",使模型能够捕获更丰富的上下文信息。

数学公式

  • 多头注意力通过以下三个步骤实现:

    1. 线性投影:对Q/K/V进行线性变换
    2. 并行注意力计算:在多个子空间中计算缩放点积注意力
    3. 输出融合:拼接多头结果并进行线性变换

    公式表达为:
    MultiHead ( Q , K , V ) = Concat ( h e a d 1 , . . . , h e a d h ) W O where  h e a d i = Attention ( Q W i Q , K W i K , V W i V ) \text{MultiHead}(Q,K,V) = \text{Concat}(head_1,...,head_h)W^O \\ \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV)
    其中:

    • W i Q ∈ R s e q _ l e n × d _ k W_i^Q \in \mathbb{R}^{seq\_len \times d\_k} WiQRseq_len×d_k:第i个头的查询投影矩阵
    • W i K ∈ R s e q _ l e n × d _ k W_i^K \in \mathbb{R}^{seq\_len \times d\_k} WiKRseq_len×d_k:第i个头的键投影矩阵
    • W i V ∈ R s e q _ l e n × d _ k W_i^V \in \mathbb{R}^{seq\_len \times d\_k} WiVRseq_len×d_k:第i个头的值投影矩阵
    • W O ∈ R s e q _ l e n × d _ m o d e l W^O \in \mathbb{R}^{seq\_len \times d\_model} WORseq_len×d_model:输出投影矩阵

    s e q _ l e n seq\_len seq_len 为序列长度, d _ m o d e l d\_model d_model 为embedding的维度, d _ k d\_k d_k 为每个注意力头的维度,假设头数量为n_heads ,且n_heads能被d_model给整数,则
    d _ k = d _ m o d e l n _ h e a d s d\_k = \frac{d\_model}{n\_heads} d_k=n_headsd_model

代码实现

  • 需要导入之前写的计算注意力分数的函数

    https://blog.csdn.net/hbkybkzw/article/details/147462845

    import torch
    from torch import nndef calculate_attention(query, key, value, mask=None):d_k = key.shape[-1]att_scaled = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5  # 缩放点积,注意区分这里的 d_k 是多头注意力中的“d_model”if mask is not None:att_scaled = att_scaled.masked_fill_(mask=mask, value=-1e9)  # 或者value= -inf,这样在softmax的时候,掩码的部分就为0了att_softmax = torch.softmax(input=att_scaled, dim=-1)  # softmaxatt_score = torch.matmul(att_softmax, value)  # 注意力分数return att_score
    
  • 多头注意力模块

    class MultiHeadAttention(nn.Module):def __init__(self, n_heads, d_model, dropout_prob=0.1):super(MultiHeadAttention, self).__init__()assert d_model % n_heads == 0  # 需要确保d_model一定要能被注意力头的数量整除# 定义四个全连接层(Q/K/V/Output)self.q_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.k_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.v_linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.linear = nn.Linear(in_features=d_model, out_features=d_model, bias=False)self.dropout = nn.Dropout(dropout_prob)self.n_heads = n_headsself.d_k = d_model // n_headsself.d_model = d_modeldef forward(self, q, k, v, mask=None):"""前向传播过程输入形状:[batch_size, seq_len, d_model]输出形状:[batch_size, seq_len, d_model]"""# 步骤1:线性投影并分割多头q = self.q_linear(q)q = q.reshape(q.shape[0], -1, self.n_heads, self.d_k)  # [batch_size,seq_len,n_heads,d_k]q = q.transpose(1, 2)  # [batch_size, n_heads, seq_len, d_k]# 等价于# q = self.q_linear(q).reshape(q.shape[0],-1,self.n_heads,self.d_k).transpose(1,2)k = self.k_linear(k)k = k.reshape(k.shape[0], -1, self.n_heads, self.d_k)k = k.transpose(1, 2)v = self.v_linear(v)v = v.reshape(v.shape[0], -1, self.n_heads, self.d_k)v = v.transpose(1, 2)# 步骤2:计算注意力(每个头独立计算)out = calculate_attention(q, k, v, mask)  # [batch_size,seq_len,n_heads,d_k]# 步骤3:拼接多头结果并进行输出投影out = out.transpose(1, 2)  # [batch_size,n_heads,seq_len,d_k]out = out.reshape(out.shape[0], -1, self.d_model)  # [batch_size,seq_len,d_model]out = self.linear(out)out = self.dropout(out)return out
  • 关键操作解析

    维度变换流程

    操作步骤张量形状变化示例
    输入数据[batch_size, seq_len, d_model]
    步骤1-线性投影(保持维度)[batch_size, seq_len, d_model]
    步骤1-分割多头(n_heads)[batch_size, seq_len, n_heads, d_k]
    步骤1-维度转置(交换头与序列轴)[batch_size, n_heads , seq_len, d_k]
    步骤2-计算注意力(各头独立)[batch_size, n_heads , seq_len, d_k]
    步骤3-维度转置(交换头与序列轴)[batch_size , seq_len, n_heads ,d_k]
    步骤3-拼接多头(恢复原始维度)[batch_size, seq_len, d_model]

使用示例

  • 测试代码

    if __name__ == "__main__":# 实例化模块:8头注意力,512维模型,20% dropoutmultihead_attention = MultiHeadAttention(n_heads=8, d_model=512, dropout_prob=0.2)# 模拟输入:batch_size=4,序列长度100,维度512x = torch.randn(4, 100, 512)# 前向传播(自注意力模式)output = multihead_attention(x, x, x)print("输入形状:", x.shape)        # torch.Size([4, 100, 512])print("输出形状:", output.shape)   # torch.Size([4, 100, 512])print("输出范数:", torch.norm(output))  # 约1.2-1.8(取决于初始化)
    

相关文章:

  • 海思ISP调试记录
  • 使用vue2开发一个在线旅游预订平台-前端静态网站项目练习
  • 第16章:MCP服务端项目开发实战:对话系统
  • 优考试V4.20机构版【可注册】
  • 类模板 (Class Templates)
  • 负环-P3385-P2136
  • 【Microsoft Store 中的软件推荐】
  • 4月24号
  • 乡村治理数字化平台:信息技术赋能乡村振兴的深度探索
  • 西甲001:奥萨苏纳VS塞维利亚
  • 手机访问电脑端Nginx服务器配置方式
  • 【力扣刷题|第五天作业】二分查找-寻找旋转排序数组中的最小值 II
  • frome time import * 与 import time
  • 多物理场仿真底层算法到AI智能仿真革命
  • Win11右键显示全部内容
  • 架构-软件架构设计
  • Windows下QT打包后程序运行后弹出CMD命令窗口的问题解决方法
  • 某大型电解铝厂电解系统谐波治理装置改造沃伦森电气
  • 深度学习3.7 softmax回归的简洁实现
  • 基于大模型的食管平滑肌瘤全周期预测与诊疗方案研究
  • 万能险新规落地:保险期限不得低于五年,明确万能险销售“负面清单”
  • 四川:全省统一取消普通住宅和非普通住宅标准
  • 美称中美芬太尼问题谈判但中方不够真诚,外交部回应
  • 牛市早报|商务部:目前中美之间未进行任何经贸谈判
  • 技术派|“会飞的手榴弹”:微型无人机将深刻改变单兵作战方式
  • 上海黄浦一季度实到外资总量全市第二,同比增速领先全市