当前位置: 首页 > news >正文

Multi-Query Attention (MQA) PyTorch 实现

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.scale = self.head_dim ** -0.5# 查询、键、值投影self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, embed_dim = x.shape# 投影q = self.q_proj(x)  # (batch, seq_len, embed_dim)k = self.k_proj(x)  # (batch, seq_len, head_dim)v = self.v_proj(x)  # (batch, seq_len, head_dim)# 重塑查询为多头q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# (batch, num_heads, seq_len, head_dim)# 键和值保持单头,扩展到多头维度k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale# (batch, num_heads, seq_len, seq_len)attn = F.softmax(scores, dim=-1)out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)# 合并多头out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)out = self.out_proj(out)  # (batch, seq_len, embed_dim)return out# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])

相关文章:

  • 2. ubuntu20.04 和VS Code实现 ros的输出 (C++,Python)
  • JAVA设计模式——(1)适配器模式
  • .gitignore 可能失效的原因
  • 在 Amazon Graviton 上运行大语言模型:CPU 推理性能实测与调优指南
  • XCVU13P-2FHGA2104I Xilinx Virtex UltraScale+ FPGA
  • 基于LSTM-AutoEncoder的心电信号时间序列数据异常检测(PyTorch版)
  • 简单代码应用
  • Linux(autoDL云服务器)mamba-ssm环境安装——一次成功!
  • 【计算机网络 | 第二篇】常见的通信协议(一)
  • 【HDFS入门】HDFS数据冗余与容错机制解析:如何保障大数据高可靠存储?
  • day29 学习笔记
  • 洛谷题目:P8624 [蓝桥杯 2015 省 AB] 垒骰子 题解 (本题简)
  • linux kernel irq相关函数详解
  • 系分架构论文《论高并发场景的架构设计和开发方法》
  • 股指期货跨期套利是如何赚取价差利润的?
  • Java实现将MarkDown保留文档内容及格式输出到浏览器页面
  • 基于控制台的小车导航游戏开发详解(C++实现)
  • 嘉立创原理图、PCB常见问题
  • 10.thinkphp的响应
  • MCP协议驱动的功能纳米材料设计及其在光催化甲烷偶联中的创新应用
  • C909飞机开启越南商业运营
  • 直播电商监管新规将公开征求意见,出重拳净化行业生态
  • 上海警方:男子拍摄女性视频后在网上配发诱导他人违法犯罪文字,被行拘
  • 地铁口被吐槽像棺材?杭州地铁公司回应:是一个标志性出入口
  • 美国佛罗里达州立大学发生枪击事件
  • 见微知沪|让民营企业与城市共成长,上海拿出“三件宝”