
Grouped Query Attention (GQA): A PyTorch Implementation

Most GQA implementations I have seen online look oddly convoluted and lack a clean feel, so here is a readable way to implement GQA:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # K/V width: num_groups * head_dim
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)       # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out


# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])
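
As a quick check of the two special cases (an illustrative addition, not part of the original snippet): setting num_groups equal to num_heads reduces the class to standard multi-head attention, while num_groups = 1 reduces it to multi-query attention.

mha_like = GroupedQueryAttention(embed_dim, num_heads, num_groups=num_heads)  # one K/V head per query head (MHA)
mqa_like = GroupedQueryAttention(embed_dim, num_heads, num_groups=1)          # one K/V head shared by all query heads (MQA)
print(mha_like(x).shape, mqa_like(x).shape)  # both torch.Size([2, 10, 64])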

To make sense of GQA, it helps to first understand how MQA (Multi-Query Attention) is implemented; reading the two in sequence makes GQA much easier to follow. A minimal MQA sketch is given below.
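
Here is a minimal MQA sketch in the same style as the GQA class above (an illustrative sketch, not code from the article it references): all query heads share a single key/value head, so K and V can simply broadcast across the head dimension.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)      # one query per head
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # single shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # single shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # k and v keep a single head; unsqueeze so they broadcast over the head dimension
        k = self.k_proj(x).unsqueeze(1)  # (batch_size, 1, seq_len, head_dim)
        v = self.v_proj(x).unsqueeze(1)  # (batch_size, 1, seq_len, head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # broadcasts to (batch_size, num_heads, seq_len, head_dim)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(out)


mqa = MultiQueryAttention(embed_dim=64, num_heads=8)
print(mqa(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])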

Once MQA is clear, the idea behind GQA is almost identical; the only extra ingredient is the somewhat less common function tensor.repeat_interleave. The author covers that function in a separate article (linked in the original post), and it is quite easy to follow; a toy example is also shown below.
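
A toy example of what repeat_interleave does in the forward pass above (illustrative only): each group's K/V slice is repeated consecutively along the head dimension, so consecutive query heads end up sharing the same group's keys and values.

import torch

# 2 K/V groups, 2 query heads per group -> 4 query heads in total
kv = torch.tensor([10.0, 20.0]).view(1, 2, 1, 1)  # (batch=1, num_groups=2, seq_len=1, head_dim=1)
expanded = kv.repeat_interleave(2, dim=1)          # (batch=1, num_heads=4, seq_len=1, head_dim=1)
print(expanded.flatten().tolist())                 # [10.0, 10.0, 20.0, 20.0]

Heads 0 and 1 share group 0, heads 2 and 3 share group 1. By contrast, tensor.repeat would tile the values as [10.0, 20.0, 10.0, 20.0], i.e. a different head-to-group assignment.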
