Mamba2模型的实现
深入探索Mamba模型架构与应用 - 商品搜索 - 京东
DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东
Mamba2在原有的Mamba模型的基础上融入了注意力机制,这一创新性的改进赋予了模型对远程信息的关注能力。通过引入注意力机制,Mamba2不仅保留了原Mamba模型对序列信息敏感的优点,还能够有效地捕获并处理长距离依赖关系。这一变革性的增强使得Mamba2在处理复杂任务时更加灵活和全面,大大提高了模型的性能和应用范围。
结构化状态空间对偶性是在核心部分添加注意力机制,即将原有的SSM替换成带有注意力组件的新型架构,在具体实现上,我们可以参照GLM架构的注意力实现,首先完成其中的注意力机制,代码如下:
def segsum(x: Tensor, device: Device = None) -> Tensor: """Stable segment sum calculation. `exp(segsum(A))` 生成一个1-半可分矩阵,等同于一个标量SSM(Scalar SSM,可能是指某种特定的半可分矩阵)""" # 获取输入Tensor x的最后一个维度的大小,通常代表时间序列的长度 T = x.size(-1) # 使用repeat函数扩展x的维度,使其在最后一个维度上增加一个与T相同大小的维度e # 这实际上是为后续的矩阵操作做准备,生成一个二维的矩阵,其中每一行都是原始x的复制 x = repeat(x, "... d -> ... d e", e=T)# 创建一个下三角矩阵,其中对角线下方的元素为1(True),其余为0(False)# 这个矩阵将用作后续操作的掩码mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) # 使用上面创建的掩码,将x中上三角部分的元素替换为0 x = x.masked_fill(~mask, 0) # 沿着倒数第二个维度(即新扩展的维度e)计算累积和 x_segsum = torch.cumsum(x, dim=-2) # 创建一个新的下三角矩阵,但这次包括对角线元素 mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) # 使用新的掩码,将x_segsum中上三角部分的元素替换为负无穷大 # 这样做可能是为了在后续的计算中忽略这些值,或者使它们在softmax等操作中变得非常小 x_segsum = x_segsum.masked_fill(~mask, -torch.inf) # 返回计算后的分段和Tensor return x_segsum
这里简单地完成了注意力计算,即对输入的序列内容进行注意力建模,而mask的存在可以使得模型在计算时只关注前面步骤中的Token而不会“窥视”未来的内容。
SSD的存在是在原有的SSM架构上添加了注意力机制,即通过注意力机制对输入的数据进行全局建模,代码如下:
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): """结构化状态空间对偶性(SSD) - Mamba-2的核心 这与博客文章中的最小SSD代码几乎完全相同参数 x: 输入数据,形状为 (batch, seqlen, n_heads, d_head) A: 形状为 (batch, seqlen, n_heads) 的参数 B: 形状为 (batch, seqlen, n_heads, d_state) 的参数 C: 形状为 (batch, seqlen, n_heads, d_state) 的参数 返回 y: 输出数据,形状为 (batch, seqlen, n_heads, d_head) 来源 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/ 2. 给定的GitHub链接 """ # 将数据重新排列成块,以便于分块处理 x, A, B, C = [rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)] # 对A进行重排,并计算其累积和 A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) # 1. 计算每个块内的输出(对角块) L = torch.exp(segsum(A, device=device)) # 计算稳定的分段和,并取指数 # 使用einsum进行高效计算 Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) # 2. 计算每个块内的状态(B项,用于低秩分解的非对角块) decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) # 计算衰减状态 # 计算状态 states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) # 3. 计算块间的SSM递推关系,以在块边界产生正确的SSM状态 if initial_states is None: # 如果没有提供初始状态,则使用零初始化 initial_states = torch.zeros_like(states[:, :1]) # 连接初始状态和计算出的状态 states = torch.cat([initial_states, states], dim=1) decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) # 计算块间的衰减 # 更新状态 new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) # 分离出最终状态和其余状态 states, final_state = new_states[:, :-1], new_states[:, -1] # 4. 计算每个块的状态到输出的转换 state_decay_out = torch.exp(A_cumsum) # 计算输出衰减 Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) # 计算非对角块的输出 # 将对角块和非对角块的输出相加,得到最终输出 Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") # 重新排列输出形状 return Y, final_state # 返回输出和最终状态
这段代码实现了结构化状态空间对偶性(Structured State Space Duality,SSD)算法,它是Mamba2模型的核心组成部分。通过分块处理、计算分段和以及使用einsum进行高效计算等步骤,实现了对输入数据的转换和输出。
Mamba2是一个基于结构化状态空间对偶性的深度学习模型,旨在处理序列数据,如自然语言处理或时间序列分析中的任务。在具体实现上,Mamba2的实现架构与Mamba类似。
- 输入投影层(self.in_proj):首先,输入数据u(形状为(batch, seqlen, d_model)),通过一个线性层(nn.Linear),该层将输入映射到一个更高维的空间,以便模型能够捕捉到更丰富的特征。这个映射后的表示被分割成多个部分,包括z(用于归一化)、xBC(包含状态信息和候选输出)和dt(时间步长偏差),这些部分将在后续处理中发挥不同作用。
- 一维卷积层(self.conv1d):xBC部分通过一个一维卷积层,该层用于捕捉序列中的局部依赖关系。卷积层使用分组卷积,每个通道独立处理,以减少参数数量和计算复杂度。卷积后的输出经过激活函数(如SiLU)以增加非线性,然后再次分割成x(状态表示)、B(状态转换矩阵的一部分)和C(输出转换矩阵的一部分)。
- SSD计算:模型的核心是结构化状态空间对偶性的计算,它结合了自注意力机制和状态空间模型的优势。通过ssd函数,模型利用x、A(对数衰减率,通过self.A_log参数计算得出)、B、C和dt来计算输出y和SSM状态。SSD允许模型以高效的方式处理长序列,同时捕捉到序列中的长距离依赖关系。
- 残差连接和归一化:计算出的输出y通过一个残差连接(residual connection)[yx1] 与原始状态表示x进行加权求和(通过self.D参数控制权重),以增加模型的表达能力和稳定性。然后,这个结果通过RMSNorm归一化层进行归一化处理,以加快训练速度和改善模型的泛化能力。
- 输出投影层(self.out_proj):归一化后的输出通过一个线性层(nn.Linear)投影回原始输入数据的维度(d_model),以生成最终的输出表示y。
- 隐藏状态和推理步骤:在推理过程中,模型使用InferenceCache对象来存储和更新隐藏状态(包括卷积状态和SSM状态)。通过step方法,模型可以逐步处理输入序列中的每个时间步,并在每个步骤中更新隐藏状态和生成输出。这种方式使得Mamba2在推理时的时间复杂度与序列长度成线性关系,相比传统的自注意力模型具有更高的效率。
其代码实现如下:
class Mamba2(nn.Module):def _ _init_ _(self, args: Mamba2Config, device: Device = None):super()._ _init_ _()self.args = argsself.device = device# Order: (z, x, B, C, dt)d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheadsself.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)conv_dim = args.d_inner + 2 * args.d_stateself.conv1d = nn.Conv1d(in_channels=conv_dim,out_channels=conv_dim,kernel_size=args.d_conv,groups=conv_dim,padding=args.d_conv - 1,device=device,)self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))self.D = nn.Parameter(torch.empty(args.nheads, device=device))self.norm = RMSNorm(args.d_inner, device=device)self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)def forward(self, u: Tensor, h: InferenceCache | None = None):if h:return self.step(u, h)A = -torch.exp(self.A_log) # (nheads,)zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)z, xBC, dt = torch.split(zxbcdt,[self.args.d_inner,self.args.d_inner + 2 * self.args.d_state,self.args.nheads,],dim=-1,)dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)# Pad or truncate xBC seqlen to d_convconv_state = F.pad(rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0))xBC = silu(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]) # (batch, seqlen, d_inner + 2 * d_state))x, B, C = torch.split(xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1)x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)y, ssm_state = ssd(x * dt.unsqueeze(-1),A * dt,rearrange(B, "b l n -> b l 1 n"),rearrange(C, "b l n -> b l 1 n"),self.args.chunk_size,device=self.device,)y = y + x * self.D.unsqueeze(-1)y = rearrange(y, "b l h p -> b l (h p)")y = self.norm(y, z)y = self.out_proj(y)h = InferenceCache(conv_state, ssm_state)return y, hdef step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:"""单步推理函数,根据当前输入u和隐藏状态h计算输出和更新后的隐藏状态""" assert u.shape[1] == 1, "Only one token can be decoded per inference step"zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)z, xBC, dt = torch.split(zxbcdt,[self.args.d_inner,self.args.d_inner + 2 * self.args.d_state,self.args.nheads,],dim=-1,)# Advance convolution inputh.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))h.conv_state[:, :, -1] = xBC# Convolution stepxBC = torch.sum(h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)xBC += self.conv1d.biasxBC = silu(xBC)x, B, C = torch.split(xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1)A = -torch.exp(self.A_log) # (nheads,)# SSM stepdt = F.softplus(dt + self.dt_bias) # (batch, nheads)dA = torch.exp(dt * A) # (batch, nheads)x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)y = y + rearrange(self.D, "h -> h 1") * xy = rearrange(y, "b h p -> b (h p)")y = self.norm(y, z)y = self.out_proj(y)return y.unsqueeze(1), h

在推理模式下,Mamba2模型支持通过step函数进行单步推理。这使得模型在处理长序列时能够逐个时间步地生成输出,从而显著减少内存占用并提高推理速度。在step函数中,隐藏状态(包括卷积状态和SSM状态)在每个推理步骤中都会根据当前输入进行更新,并用于生成下一个输出。
Mamba2模型通过结合SSD的高效计算、一维卷积的局部特征提取能力以及残差连接和归一化的稳定性,实现了对序列数据的快速而准确的处理。