多头注意力(Multi‑Head Attention)
1. 多头注意力(Multi‑Head Attention)原理
设输入序列表示为矩阵 X ∈ R B × L × d model X\in\mathbb{R}^{B\times L\times d_{\text{model}}} X∈RB×L×dmodel,其中
- B B B:批大小(batch size),
- L L L:序列长度(sequence length),
- d model d_{\text{model}} dmodel:模型隐层维度(model dimension)。
多头注意力基于对缩放点乘注意力的并行化扩展,引入了 h h h 个注意力头(heads),每个头在不同子空间中学习不同的表示。
1.1 线性映射与切分
我们首先为每个头定义三组可学习权重:
W i Q ∈ R d model × d k , W i K ∈ R d model × d k , W i V ∈ R d model × d v , i = 1 , … , h W_i^Q \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}}\times d_v}, \quad i=1,\dots,h WiQ∈Rdmodel×dk,WiK∈Rdmodel×dk,WiV∈Rdmodel×dv,i=1,…,h
其中
- h h h:头数(number of heads),
- d k d_k dk:每个头中 Query/Key 的维度(key/query dimension),
- d v d_v dv:每个头中 Value 的维度(value dimension),
- 通常 d model = h × d k d_{\text{model}}=h\times d_k dmodel=h×dk 且取 d v = d k d_v = d_k dv=dk。
对输入 X X X 进行投影,得到第 i i i 个头的查询、键、值:
Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X\,W_i^Q,\quad K_i = X\,W_i^K,\quad V_i = X\,W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中
- Q i ∈ R B × L × d k Q_i \in \mathbb{R}^{B\times L\times d_k} Qi∈RB×L×dk,
- K i ∈ R B × L × d k K_i \in \mathbb{R}^{B\times L\times d_k} Ki∈RB×L×dk,
- V i ∈ R B × L × d v V_i \in \mathbb{R}^{B\times L\times d_v} Vi∈RB×L×dv。
1.2 缩放点乘注意力
对第 i i i 个头,分别对所有位置做点积注意力:
- 打分矩阵
S c o r e i = Q i K i ⊤ ∈ R B × L × L \mathrm{Score}_i = Q_i\,K_i^\top \quad\in\mathbb{R}^{B\times L\times L} Scorei=QiKi⊤∈RB×L×L - 缩放
S c o r e ~ i = S c o r e i d k \widetilde{\mathrm{Score}}_i = \frac{\mathrm{Score}_i}{\sqrt{d_k}} Score i=dkScorei - Softmax 归一化
A i = s o f t m a x ( S c o r e ~ i ) ∈ R B × L × L A_i = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}_i\bigr) \quad\in\mathbb{R}^{B\times L\times L} Ai=softmax(Score i)∈RB×L×L - 加权求和
h e a d i = A i V i ∈ R B × L × d v \mathrm{head}_i = A_i\,V_i \quad\in\mathbb{R}^{B\times L\times d_v} headi=AiVi∈RB×L×dv
1.3 拼接与线性变换
将所有头的输出在最后一维拼接,再做一次线性投影:
C o n c a t = [ h e a d 1 , … , h e a d h ] ∈ R B × L × ( h d v ) \mathrm{Concat} = \bigl[\mathrm{head}_1,\dots,\mathrm{head}_h\bigr] \quad\in\mathbb{R}^{B\times L\times (h\,d_v)} Concat=[head1,…,headh]∈RB×L×(hdv)
定义输出权重矩阵
W O ∈ R ( h d v ) × d model W^O\in\mathbb{R}^{(h\,d_v)\times d_{\text{model}}} WO∈R(hdv)×dmodel
最终输出:
M u l t i H e a d ( X ) = C o n c a t W O ∈ R B × L × d model \mathrm{MultiHead}(X) = \mathrm{Concat}\;W^O \quad\in\mathbb{R}^{B\times L\times d_{\text{model}}} MultiHead(X)=ConcatWO∈RB×L×dmodel
2. PyTorch 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model: int, h: int):"""d_model: 模型维度 d_modelh: 注意力头数 h"""super().__init__()assert d_model % h == 0, "d_model 必须能被 h 整除"self.d_model = d_model # d_modelself.h = h # hself.d_k = d_model // h # d_k = d_model / hself.d_v = self.d_k # d_v 通常等于 d_k# 投影矩阵 W_i^Q, W_i^K, W_i^V,实际上合并为一个大矩阵后在 forward 再切分self.w_q = nn.Linear(d_model, d_model) # 同时生成 h 个 Q 投影self.w_k = nn.Linear(d_model, d_model) # 同时生成 h 个 K 投影self.w_v = nn.Linear(d_model, d_model) # 同时生成 h 个 V 投影# 输出线性变换 W^Oself.w_o = nn.Linear(d_model, d_model)def forward(self, X: torch.Tensor, mask: torch.Tensor = None):"""X: 输入张量,形状 (B, L, d_model)mask: 可选掩码,形状 (B, 1, L, L) 或 (B, L, L)"""B, L, _ = X.size()# 1. 线性映射到 Q, K, V,然后切分 h 头# 先得到 (B, L, h*d_k),再 view/transpose 为 (B, h, L, d_k)Q = self.w_q(X).view(B, L, self.h, self.d_k).transpose(1, 2)K = self.w_k(X).view(B, L, self.h, self.d_k).transpose(1, 2)V = self.w_v(X).view(B, L, self.h, self.d_k).transpose(1, 2)# 此时 Q, K, V 形状均为 (B, h, L, d_k)# 2. 计算点积注意力# scores = Q @ K^T -> (B, h, L, L)scores = torch.matmul(Q, K.transpose(-2, -1)) # 缩放:除以 sqrt(d_k)scores = scores / math.sqrt(self.d_k)# 可选掩码:将被屏蔽位置设为 -inf if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))# Softmax 归一化 -> (B, h, L, L)A = F.softmax(scores, dim=-1)# 加权求和 -> head_i 形状 (B, h, L, d_k)heads = torch.matmul(A, V)# 3. 拼接 h 个头:transpose 回 (B, L, h, d_k) 再 reshapeconcat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)# concat 形状 (B, L, h*d_k) == (B, L, d_model)# 4. 最后一次线性变换 W^Ooutput = self.w_o(concat) # -> (B, L, d_model)return output, A # 返回输出及注意力权重 A