
Self-Attention, Multi-Head Self-Attention, and Padding Masks in Python

Background

[Transformer Series (2)] Attention, Self-Attention, Multi-Head Attention, Channel Attention, and Spatial Attention: A Detailed Walkthrough

Self-Attention
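Given an input x of shape (batch_size, seq_len, input_dim), self-attention projects x into queries Q, keys K, and values V with three linear layers and computes softmax(QK^T / sqrt(d)) V, where d is the feature dimension. Positions where the padding mask is 0 have their scores set to -inf before the softmax, so they receive zero attention weight. The script below also includes a pad_sequences helper that batches variable-length sequences to a common length and builds that mask.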

import torch
import torch.nn as nn

# Self-attention
class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product scores: (batch_size, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(input_dim, dtype=torch.float))
        if mask is not None:
            # (batch_size, seq_len) -> (batch_size, 1, seq_len) to broadcast over query positions
            mask = mask.unsqueeze(1)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        attended_values = torch.matmul(attn_scores, v)
        return attended_values

# Pad a list of (seq_len, input_dim) tensors to a common length and build the mask
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
    # 1 for real tokens, 0 for padding
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)
    return padded, mask.long()

if __name__ == '__main__':
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    target_seq_len = 10
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    selfattention = SelfAttention(input_dim)
    attention = selfattention(padded_x, mask)
    print(attention)
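As a sanity check, the same computation can be reproduced with PyTorch's built-in F.scaled_dot_product_attention (available since PyTorch 2.0), which applies the same 1/sqrt(d) scaling. The snippet below is a minimal sketch that reuses the SelfAttention and pad_sequences definitions above; the boolean attn_mask marks key positions that may be attended to:

import torch
import torch.nn.functional as F

# Minimal sketch, assumes PyTorch >= 2.0 and the SelfAttention / pad_sequences
# definitions from the script above.
input_dim = 128
padded_x, mask = pad_sequences([torch.randn(3, input_dim), torch.randn(5, input_dim)], 10)
sa = SelfAttention(input_dim)
with torch.no_grad():
    q, k, v = sa.query(padded_x), sa.key(padded_x), sa.value(padded_x)
    # attn_mask: (batch_size, 1, seq_len), True = this key position may be attended to
    ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool().unsqueeze(1))
    ours = sa(padded_x, mask)
print(torch.allclose(ours, ref, atol=1e-5))  # expected: True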

Multi-Head Self-Attention
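Multi-head attention splits the projected queries, keys, and values into num_heads heads of head_dim = input_dim // num_heads dimensions each, runs scaled dot-product attention for every head in parallel, and concatenates the per-head outputs back into an input_dim-dimensional vector. The padding mask is expanded to (batch_size, 1, 1, seq_len) so it broadcasts across heads and query positions. Note that input_dim must be divisible by num_heads.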

import torch
import torch.nn as nn

# Multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape
        # Split the projections into heads:
        # after transpose(1, 2) the shape is (batch_size, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: (batch_size, num_heads, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32))
        # Apply the padding mask
        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) for broadcasting
            mask = mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        # Weighted sum, then merge the heads back into (batch_size, seq_len, input_dim)
        attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(
            batch_size, seq_len, input_dim)
        return attended_values

# Pad a list of (seq_len, input_dim) tensors to a common length and build the mask
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
    # 1 for real tokens, 0 for padding
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)
    return padded, mask.long()

if __name__ == '__main__':
    heads = 2
    seq_len_1 = 3
    seq_len_2 = 5
    input_dim = 128
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    target_seq_len = 10
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    multiheadattention = MultiHeadSelfAttention(input_dim, heads)
    attention = multiheadattention(padded_x, mask)
    print(attention)
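In practice, torch.nn.MultiheadAttention implements the same mechanism plus a learned output projection, so its outputs will not match the class above numerically, but the masking convention is worth noting: key_padding_mask expects True at the positions to ignore, i.e. the inverse of the mask built here. A minimal sketch, reusing pad_sequences from the script above:

import torch
import torch.nn as nn

# Minimal sketch: nn.MultiheadAttention adds an output projection, so it is not
# numerically identical to MultiHeadSelfAttention above; shapes and masking match.
input_dim, heads = 128, 2
padded_x, mask = pad_sequences([torch.randn(3, input_dim), torch.randn(5, input_dim)], 10)
mha = nn.MultiheadAttention(input_dim, heads, batch_first=True)
out, attn = mha(padded_x, padded_x, padded_x,
                key_padding_mask=(mask == 0))  # True = padded position, ignored
print(out.shape)  # torch.Size([2, 10, 128])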
