Grouped Query Attention (GQA): A PyTorch Implementation
Most GQA implementations I have come across online look rather convoluted and lack a clean, readable structure, so here is an implementation that aims to be easy to follow:
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # K/V projection width: num_groups * head_dim
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)       # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out


# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim) # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape) # Expected output: torch.Size([2, 10, 64])
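As a quick sanity check (my own sketch, not part of the original implementation): setting num_groups equal to num_heads gives every query head its own key/value head, which is just standard multi-head attention, while num_groups = 1 shares a single key/value head across all query heads, which is exactly MQA. The variable names mha_like and mqa_like below are only illustrative.

# Boundary cases of the same class (illustrative sketch):
# num_groups == num_heads -> standard multi-head attention (one K/V head per query head)
# num_groups == 1         -> multi-query attention, MQA (one K/V head shared by all query heads)
mha_like = GroupedQueryAttention(embed_dim, num_heads, num_groups=num_heads)
mqa_like = GroupedQueryAttention(embed_dim, num_heads, num_groups=1)
print(mha_like(x).shape)  # torch.Size([2, 10, 64])
print(mqa_like(x).shape)  # torch.Size([2, 10, 64])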
To really make sense of GQA, I suggest first looking at how MQA (Multi-Query Attention) is implemented; reading in that order makes everything flow much more naturally.
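For reference, here is a minimal MQA forward-pass sketch (my own illustration, with made-up shapes, not taken from any particular library): the queries keep num_heads separate heads, while a single key/value head is broadcast to all of them.

import torch
import torch.nn.functional as F

# Minimal MQA sketch (assumes the same (batch, heads, seq, head_dim) layout as above)
batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 8
q = torch.randn(batch_size, num_heads, seq_len, head_dim)  # one query per head
k = torch.randn(batch_size, 1, seq_len, head_dim)          # single shared key head
v = torch.randn(batch_size, 1, seq_len, head_dim)          # single shared value head

scale = head_dim ** -0.5
scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # K broadcasts over the head dimension
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)                            # V broadcasts over the head dimension
print(out.shape)  # torch.Size([2, 8, 10, 8])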
Once you understand MQA, the GQA implementation follows almost exactly the same idea; the only extra piece is a somewhat less common function, tensor.repeat_interleave. For that function, just follow the link to my related article; it is fairly easy to understand.
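If you do not want to jump to that article, here is a tiny standalone demo of what repeat_interleave does along a dimension (my own toy example, independent of the GQA code above):

import torch

x = torch.tensor([[1, 2],
                  [3, 4]])

# repeat_interleave repeats each slice along dim 0 consecutively
print(x.repeat_interleave(3, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])

# Contrast with tensor.repeat, which tiles the whole block instead
print(x.repeat(3, 1))
# tensor([[1, 2],
#         [3, 4],
#         [1, 2],
#         [3, 4],
#         [1, 2],
#         [3, 4]])

In the GQA forward pass above, repeat_interleave is what makes consecutive query heads share a key/value group: query head h ends up paired with group h // heads_per_group, whereas plain repeat would interleave the groups in the wrong order.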