当前位置: 首页 > news >正文

【NLP】25.python实现点积注意力,加性注意力,Decoder(解码器)与 Attention

1. 点积注意力(Dot-Product Attention)

点积注意力是最简单的注意力机制之一,其基本思想是通过计算查询(query)和键(key)之间的点积来得到相似度,进而为每个值(value)分配一个权重。具体步骤如下:

  • 计算相似度:将查询向量和键向量的点积作为它们的相似度。
  • 归一化:对相似度进行softmax归一化,得到注意力权重。
  • 加权求和:根据计算出来的注意力权重,对值(value)进行加权求和,得到上下文向量。
代码实现:
class DotProductAttention(nn.Module):
    def __init__(self, hidden_size):
        super(DotProductAttention, self).__init__()
        self.out_size = hidden_size * 2

    def forward(self, query, keys):
        # 计算查询和键的点积相似度
        scores = (query * keys).sum(-1)
        scores = scores.unsqueeze(1)

        # 对相似度进行归一化
        weights = F.softmax(scores, dim=-1)

        # 加权求和得到上下文向量
        context = torch.bmm(weights, keys)
        return context, weights

在这里:

  • query 是解码器的隐藏状态。
  • keys 是编码器的输出。
  • scores 是查询和键的点积。
  • weights 是对 scores 进行 softmax 归一化后的注意力权重。
  • context 是加权求和后的上下文向量,表示当前时刻的注意力上下文。

2. 加性注意力(Additive Attention)

加性注意力是另一种常见的注意力机制,它通过将查询和键分别通过一个线性变换后加和,再通过一个非线性激活函数(如tanh)来计算相似度。

具体步骤如下:

  • 计算相似度:查询和键经过线性变换后加和,经过tanh激活函数,再通过一个线性变换输出最终的相似度。
  • 归一化:对相似度进行softmax归一化,得到注意力权重。
  • 加权求和:根据注意力权重加权求和得到上下文向量。
代码实现:
class AdditiveAttention(nn.Module):
    def __init__(self, hidden_size):
        super(AdditiveAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)
        self.out_size = hidden_size * 2

    def forward(self, query, keys):
        # 计算查询和键的加性相似度
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        # 对相似度进行归一化
        weights = F.softmax(scores, dim=-1)

        # 加权求和得到上下文向量
        context = torch.bmm(weights, keys)
        return context, weights

在这里:

  • query 是解码器的隐藏状态。
  • keys 是编码器的输出。
  • scores 是加性计算后的相似度。
  • weights 是对 scores 进行 softmax 归一化后的注意力权重。
  • context 是加权求和后的上下文向量。

3. Decoder(解码器)与 Attention

解码器(Decoder)将注意力机制集成到其计算过程中。解码器的输入包括编码器的输出(即 encoder_outputs)和编码器的最后一个隐藏状态(即 encoder_hidden)。在每一步,解码器都计算当前的上下文向量,并根据这个上下文向量生成新的输出。

代码实现:
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, attention_type="none", dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = get_attention_module(attention_type, hidden_size)
        self.gru = nn.GRU(self.attention.out_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

在这段代码中,AttnDecoderRNN 包含了注意力机制,并且在每个时间步长上,根据当前的输入和上下文(从编码器传来的输出和隐藏状态)生成下一个预测值。

相关文章:

  • 六、adb通过Wifi连接
  • cut命令:剪切
  • LeetCode[18]四数之和
  • 江顺科技应收账款期后回款比率大降:现金流急剧减少,研发费用率下滑
  • Unity中计算闭合路径内部的所有点位
  • Kubenetes-基于kubespray 部署集群
  • 鸿蒙开发-编译器使用
  • 如何 在 Cesium 中选取特定经纬度区域,特定视角 ,渲染成图片
  • 什么叫“架构”
  • 交通运输部4项网络与数据安全标准发布
  • Bash脚本编写基础指南
  • 对接印度尼西亚股票数据源API
  • Linux ELF文件格式
  • 【笔记ing】AI大模型-03深度学习基础理论
  • 深入剖析C++中 String 类的模拟实现
  • Java实现快速排序算法
  • Java 数据库访问工具 dbVisitor 的技术解析与同类工具比较
  • Kimi-VL 解读:高效 MoE 视觉语言模型VLM,兼顾长上下文与高分辨率
  • MySQL——学习InnoDB(1)
  • LabVIEW配电器自动测试系统
  • 明查|俄罗斯征兵部门突袭澡堂抓捕壮丁?
  • 上海古籍书店重新开卷,在这里淘旧书获新知
  • 林诗栋4比1战胜梁靖崑,晋级世界杯男单决赛将和雨果争冠
  • 寺庙餐饮,被年轻人追捧成新顶流
  • 明查|美军“杜鲁门”号航空母舰遭胡塞武装打击已退役?
  • 广西东兰官方通报“村民求雨耕种”:摆拍,恶意炒作