当前位置: 首页 > news >正文

深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

本文带你一步步理解 Transformer 中最核心的模块:多头注意力机制(Multi-Head Attention)。从原理到实现,配图 + 举例 + PyTorch 代码,一次性说清楚!


什么是 Multi-Head Attention?

简单说,多头注意力就是一种让模型在多个角度“看”一个序列的机制。

在自然语言中,一个词的含义往往依赖于上下文,比如:

“我把苹果给了她”

模型在处理“苹果”时,需要关注“我”“她”“给了”等词,多头注意力就是这样一种机制——从多个角度理解上下文关系。


Self-Attention 是什么?为什么还要多头?

在讲“多头”之前,咱们先回顾一下基础的 Self-Attention

Self-Attention(自注意力)机制的目标是:

让每个词都能“关注”整个句子里的其他词,融合上下文。

它的核心步骤是:

  1. 对每个词生成 Query、Key、Value 向量

  2. 用 Query 和所有 Key 做点积,算出每个词对其他词的关注度(打分)

  3. 用 Softmax 得到权重,对 Value 加权平均,生成当前词的新表示

这样做的好处是:词的语义表示不再是孤立的,而是上下文相关的。


Self-Attention vs Multi-Head Attention

但问题是——单头 Self-Attention 视角有限。就像一个老师只能从一种角度讲课。

于是,Multi-Head Attention 应运而生

特性Self-Attention(单头)Multi-Head Attention(多头)
输入映射矩阵一组 Q/K/V 线性变换多组 Q/K/V,每个头一组
学习角度单一视角多角度并行理解
表达能力有限更丰富、强大
结构简单并行多个头 + 合并输出

一句话总结:

Multi-Head Attention = 多个不同“视角”的 Self-Attention 并行处理 + 合并结果


 多头注意力:8个脑袋一起思考!

多头 = 多个“单头注意力”并行处理!

每个头使用不同的线性变换矩阵,所以能从不同视角处理数据:

  • 第1个头可能专注短依赖(like 动词和主语)

  • 第2个头可能专注实体关系(我 vs 她)

  • 第3个头可能关注时间顺序(“给了”前后)

  • ……共用同一个输入,学习到不同特征!

多头的步骤:

  1. 将输入向量(如512维)拆成多个头(比如8个,每个64维)

  2. 每个头独立进行 attention

  3. 所有头的输出拼接

  4. 再过一次线性变换,融合成最终输出


 PyTorch 实现(简洁版)

我们来看下 PyTorch 中的简化实现:

import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

举个例子:多头在实际模型中的作用

假设输入是句子:

"The animal didn't cross the street because it was too tired."

多头注意力的不同头可能会:

  • 🧠 头1:关注“animal”和“it”之间的指代关系;

  • 📐 头2:识别“because”和“tired”之间的因果联系;

  • 📚 头3:注意句子的结构层次……

所以说,多头注意力本质上是一个“并行注意力专家系统”!


 总结

项目解释
目的提升模型表达能力,从多个角度理解输入
核心机制将向量分头 → 每头独立 attention → 合并输出
技术关键view, transpose, matmul, softmax, 拼接线性层

推荐学习路径

  • 🔹 理解 Self-Attention 的点积公式

  • 🔹 搞懂 view, transpose 等张量操作

  • 🔹 看 Transformer 整体结构,关注每层作用

相关文章:

  • 研0大模型学习(第四、五天)
  • 武林秘籍之INSERT篇:一键插入,笑傲数据库
  • 数据分析处理库Pandas常用方法汇总
  • 极狐GitLab 项目和群组的导入导出速率限制如何设置?
  • 论文阅读--Orient Anything
  • spring注解@Transactional会回滚哪些异常
  • 供应链项目技术实现方案,供应链详细设计方案书,采购管理,财务管理(Word原件)
  • [Vue3]动态引入图片
  • L2-002 链表去重
  • MATLAB 控制系统设计与仿真 - 36
  • 使用 PySpark 批量清理 Hive 表历史分区
  • 在Qt中验证LDAP账户(Windows平台)
  • 【dataframe显示不全问题】打开一个行列超多的excel转成df之后行列显示不全
  • Android tinyalsa库函数剖析
  • 几款开源C#插件框架
  • 2025年山东燃气瓶装送气工考试真题练习
  • 单调队列模板cpp
  • Java漏洞原理与实战
  • RT-DETR源码学习bug记录
  • 51单片机实验七:EEPROM AT24C02 与单片机的通信实例
  • 原四川省农村信用社联合社党委副书记、监事长杨家卷被查
  • 超导电路新设计有望提升量子处理器速度
  • 新华每日电讯:上海“绿色大民生”撑起“春日大经济”
  • 夹缝中的责编看行业:长视频之殇,漫长周期
  • 新华保险一季度净赚58.82亿增19%,保费收入增28%
  • 民生访谈|规范放生活动、提升供水品质……上海将有这些举措