当前位置: 首页 > news >正文

Mamba2模型的实现

深入探索Mamba模型架构与应用 - 商品搜索 - 京东

 DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东

Mamba2在原有的Mamba模型的基础上融入了注意力机制,这一创新性的改进赋予了模型对远程信息的关注能力。通过引入注意力机制,Mamba2不仅保留了原Mamba模型对序列信息敏感的优点,还能够有效地捕获并处理长距离依赖关系。这一变革性的增强使得Mamba2在处理复杂任务时更加灵活和全面,大大提高了模型的性能和应用范围。

12.1.1  Mamba2核心组件SSD详解

结构化状态空间对偶性是在核心部分添加注意力机制,即将原有的SSM替换成带有注意力组件的新型架构,在具体实现上,我们可以参照GLM架构的注意力实现,首先完成其中的注意力机制,代码如下:

def segsum(x: Tensor, device: Device = None) -> Tensor:  """Stable segment sum calculation.  `exp(segsum(A))` 生成一个1-半可分矩阵,等同于一个标量SSM(Scalar SSM,可能是指某种特定的半可分矩阵)"""  # 获取输入Tensor x的最后一个维度的大小,通常代表时间序列的长度  T = x.size(-1)  # 使用repeat函数扩展x的维度,使其在最后一个维度上增加一个与T相同大小的维度e  # 这实际上是为后续的矩阵操作做准备,生成一个二维的矩阵,其中每一行都是原始x的复制  x = repeat(x, "... d -> ... d e", e=T)# 创建一个下三角矩阵,其中对角线下方的元素为1(True),其余为0(False)# 这个矩阵将用作后续操作的掩码mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)  # 使用上面创建的掩码,将x中上三角部分的元素替换为0  x = x.masked_fill(~mask, 0)  # 沿着倒数第二个维度(即新扩展的维度e)计算累积和  x_segsum = torch.cumsum(x, dim=-2)  # 创建一个新的下三角矩阵,但这次包括对角线元素  mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)  # 使用新的掩码,将x_segsum中上三角部分的元素替换为负无穷大  # 这样做可能是为了在后续的计算中忽略这些值,或者使它们在softmax等操作中变得非常小  x_segsum = x_segsum.masked_fill(~mask, -torch.inf)  # 返回计算后的分段和Tensor  return x_segsum

这里简单地完成了注意力计算,即对输入的序列内容进行注意力建模,而mask的存在可以使得模型在计算时只关注前面步骤中的Token而不会“窥视”未来的内容。

SSD的存在是在原有的SSM架构上添加了注意力机制,即通过注意力机制对输入的数据进行全局建模,代码如下:

def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):  """结构化状态空间对偶性(SSD) - Mamba-2的核心  这与博客文章中的最小SSD代码几乎完全相同参数  x: 输入数据,形状为 (batch, seqlen, n_heads, d_head)  A: 形状为 (batch, seqlen, n_heads) 的参数  B: 形状为 (batch, seqlen, n_heads, d_state) 的参数  C: 形状为 (batch, seqlen, n_heads, d_state) 的参数  返回  y: 输出数据,形状为 (batch, seqlen, n_heads, d_head)  来源  1. https://tridao.me/blog/2024/mamba2-part3-algorithm/  2. 给定的GitHub链接  """  # 将数据重新排列成块,以便于分块处理  x, A, B, C = [rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)]  # 对A进行重排,并计算其累积和  A = rearrange(A, "b c l h -> b h c l")  A_cumsum = torch.cumsum(A, dim=-1)  # 1. 计算每个块内的输出(对角块)  L = torch.exp(segsum(A, device=device))  # 计算稳定的分段和,并取指数 # 使用einsum进行高效计算  Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)  # 2. 计算每个块内的状态(B项,用于低秩分解的非对角块)  decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)  # 计算衰减状态  # 计算状态  states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)  # 3. 计算块间的SSM递推关系,以在块边界产生正确的SSM状态  if initial_states is None:  # 如果没有提供初始状态,则使用零初始化  initial_states = torch.zeros_like(states[:, :1])  # 连接初始状态和计算出的状态  states = torch.cat([initial_states, states], dim=1)  decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))  # 计算块间的衰减  # 更新状态  new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)  # 分离出最终状态和其余状态  states, final_state = new_states[:, :-1], new_states[:, -1]  # 4. 计算每个块的状态到输出的转换  state_decay_out = torch.exp(A_cumsum)  # 计算输出衰减  Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)  # 计算非对角块的输出  # 将对角块和非对角块的输出相加,得到最终输出  Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") # 重新排列输出形状  return Y, final_state  # 返回输出和最终状态

这段代码实现了结构化状态空间对偶性(Structured State Space Duality,SSD)算法,它是Mamba2模型的核心组成部分。通过分块处理、计算分段和以及使用einsum进行高效计算等步骤,实现了对输入数据的转换和输出。

12.1.2  基于SSD的Mamba2模型

Mamba2是一个基于结构化状态空间对偶性的深度学习模型,旨在处理序列数据,如自然语言处理或时间序列分析中的任务。在具体实现上,Mamba2的实现架构与Mamba类似。

  1. 输入投影层(self.in_proj):首先,输入数据u(形状为(batch, seqlen, d_model)),通过一个线性层(nn.Linear),该层将输入映射到一个更高维的空间,以便模型能够捕捉到更丰富的特征。这个映射后的表示被分割成多个部分,包括z(用于归一化)、xBC(包含状态信息和候选输出)和dt(时间步长偏差),这些部分将在后续处理中发挥不同作用。
  2. 一维卷积层(self.conv1d):xBC部分通过一个一维卷积层,该层用于捕捉序列中的局部依赖关系。卷积层使用分组卷积,每个通道独立处理,以减少参数数量和计算复杂度。卷积后的输出经过激活函数(如SiLU)以增加非线性,然后再次分割成x(状态表示)、B(状态转换矩阵的一部分)和C(输出转换矩阵的一部分)。
  3. SSD计算:模型的核心是结构化状态空间对偶性的计算,它结合了自注意力机制和状态空间模型的优势。通过ssd函数,模型利用xA(对数衰减率,通过self.A_log参数计算得出)、BCdt来计算输出y和SSM状态。SSD允许模型以高效的方式处理长序列,同时捕捉到序列中的长距离依赖关系。
  4. 残差连接和归一化:计算出的输出y通过一个残差连接(residual connection)[yx1] 与原始状态表示x进行加权求和(通过self.D参数控制权重),以增加模型的表达能力和稳定性。然后,这个结果通过RMSNorm归一化层进行归一化处理,以加快训练速度和改善模型的泛化能力。
  5. 输出投影层(self.out_proj):归一化后的输出通过一个线性层(nn.Linear)投影回原始输入数据的维度(d_model),以生成最终的输出表示y
  6. 隐藏状态和推理步骤:在推理过程中,模型使用InferenceCache对象来存储和更新隐藏状态(包括卷积状态和SSM状态)。通过step方法,模型可以逐步处理输入序列中的每个时间步,并在每个步骤中更新隐藏状态和生成输出。这种方式使得Mamba2在推理时的时间复杂度与序列长度成线性关系,相比传统的自注意力模型具有更高的效率。

其代码实现如下:

class Mamba2(nn.Module):def _ _init_ _(self, args: Mamba2Config, device: Device = None):super()._ _init_ _()self.args = argsself.device = device# Order: (z, x, B, C, dt)d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheadsself.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)conv_dim = args.d_inner + 2 * args.d_stateself.conv1d = nn.Conv1d(in_channels=conv_dim,out_channels=conv_dim,kernel_size=args.d_conv,groups=conv_dim,padding=args.d_conv - 1,device=device,)self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))self.D = nn.Parameter(torch.empty(args.nheads, device=device))self.norm = RMSNorm(args.d_inner, device=device)self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)def forward(self, u: Tensor, h: InferenceCache | None = None):if h:return self.step(u, h)A = -torch.exp(self.A_log)  # (nheads,)zxbcdt = self.in_proj(u)  # (batch, seqlen, d_in_proj)z, xBC, dt = torch.split(zxbcdt,[self.args.d_inner,self.args.d_inner + 2 * self.args.d_state,self.args.nheads,],dim=-1,)dt = F.softplus(dt + self.dt_bias)  # (batch, seqlen, nheads)# Pad or truncate xBC seqlen to d_convconv_state = F.pad(rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0))xBC = silu(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :])  # (batch, seqlen, d_inner + 2 * d_state))x, B, C = torch.split(xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1)x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)y, ssm_state = ssd(x * dt.unsqueeze(-1),A * dt,rearrange(B, "b l n -> b l 1 n"),rearrange(C, "b l n -> b l 1 n"),self.args.chunk_size,device=self.device,)y = y + x * self.D.unsqueeze(-1)y = rearrange(y, "b l h p -> b l (h p)")y = self.norm(y, z)y = self.out_proj(y)h = InferenceCache(conv_state, ssm_state)return y, hdef step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:"""单步推理函数,根据当前输入u和隐藏状态h计算输出和更新后的隐藏状态"""  assert u.shape[1] == 1, "Only one token can be decoded per inference step"zxbcdt = self.in_proj(u.squeeze(1))  # (batch, d_in_proj)z, xBC, dt = torch.split(zxbcdt,[self.args.d_inner,self.args.d_inner + 2 * self.args.d_state,self.args.nheads,],dim=-1,)# Advance convolution inputh.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))h.conv_state[:, :, -1] = xBC# Convolution stepxBC = torch.sum(h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)xBC += self.conv1d.biasxBC = silu(xBC)x, B, C = torch.split(xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1)A = -torch.exp(self.A_log)  # (nheads,)# SSM stepdt = F.softplus(dt + self.dt_bias)  # (batch, nheads)dA = torch.exp(dt * A)  # (batch, nheads)x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)y = y + rearrange(self.D, "h -> h 1") * xy = rearrange(y, "b h p -> b (h p)")y = self.norm(y, z)y = self.out_proj(y)return y.unsqueeze(1), h

在推理模式下,Mamba2模型支持通过step函数进行单步推理。这使得模型在处理长序列时能够逐个时间步地生成输出,从而显著减少内存占用并提高推理速度。在step函数中,隐藏状态(包括卷积状态和SSM状态)在每个推理步骤中都会根据当前输入进行更新,并用于生成下一个输出。

Mamba2模型通过结合SSD的高效计算、一维卷积的局部特征提取能力以及残差连接和归一化的稳定性,实现了对序列数据的快速而准确的处理。

相关文章:

  • 《系统架构 - Java 企业应用架构中的完整层级划分》
  • 大学之大:韩国科学技术研究院2025.4.28
  • 聊一聊接口自动化测试的稳定性如何保障
  • 探秘Transformer系列之(31)--- Medusa
  • 嵌入式RTOS实战:uC/OS-III最新版移植指南(附项目源码)
  • DAY9-USF4.0技术文档笔记
  • 学习笔记:Qlib 量化投资平台框架 — MAIN COMPONENTS (Part I)
  • PHP经验笔记
  • 【C++教程】三目运算符
  • Vue3中Hooks与普通函数的区别
  • 高效的CMS能帮助你快速建站。
  • 微机控制电液伺服钢轨滚动疲劳试验机
  • 喜马拉雅卖身腾讯音乐:在线音频独立时代的终结
  • shell(3)
  • 软件评测师考点重点知识
  • NdrpPointerUnmarshallInternal函数分析之pStubMsg--pAllocAllNodesContext的由来
  • vmare pro安装报错用户在命令行上发出了EULAS_AGREED=1,表示不接受许可协议的错误解决方法
  • MCP:如何通过模型控制推理助力AI模型实现“深度思考”?
  • timerfd定时器时间轮定时器
  • 机器学习:【抛掷硬币的贝叶斯后验概率】
  • 法治日报调查直播间“杀熟”乱象:熟客越买越贵,举证难维权不易
  • 淮安四韵·名城新章: 网络名人领略“运河之都”魅力
  • 影子调查丨起底“三无”拖拉机产销链:出口掩内销,监管如虚设
  • 俄联邦安全局:俄军高级官员汽车爆炸案嫌疑人已被捕
  • 共话城市自然之美,“微观黄浦”自媒体网络大V沙龙首场活动举行
  • 三亚一景区发生游客溺亡事件,官方通报:排除他杀