当前位置: 首页 > news >正文

DeepseekV3MLP 模块

目录

    • 代码
    • 代码解释
      • 导入和激活函数
      • 配置类
      • 初始化方法
      • 前向传播方法
      • 计算流程
    • 代码可视化

代码

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义激活函数字典
ACT2FN = {"relu": F.relu,"gelu": F.gelu,"silu": F.silu,"swish": lambda x: x * torch.sigmoid(x),"gelu_new": lambda x: 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))),
}# 简单的配置类
class ModelConfig:def __init__(self, hidden_size=768, intermediate_size=3072, hidden_act="gelu"):self.hidden_size = hidden_sizeself.intermediate_size = intermediate_sizeself.hidden_act = hidden_actclass DeepseekV3MLP(nn.Module):def __init__(self, config, hidden_size=None, intermediate_size=None):super().__init__()self.config = configself.hidden_size = config.hidden_size if hidden_size is None else hidden_sizeself.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_sizeself.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)self.act_fn = ACT2FN[config.hidden_act]def forward(self, x):# 打印输入张量的维度print(f"输入 x 的维度: {x.shape}")# 步骤1: 通过gate_proj线性层gate_output = self.gate_proj(x)print(f"gate_proj(x) 的维度: {gate_output.shape}")# 步骤2: 应用激活函数gate_activated = self.act_fn(gate_output)print(f"act_fn(gate_proj(x)) 的维度: {gate_activated.shape}")# 步骤3: 通过up_proj线性层up_output = self.up_proj(x)print(f"up_proj(x) 的维度: {up_output.shape}")# 步骤4: 元素级乘法gated_up = gate_activated * up_outputprint(f"act_fn(gate_proj(x)) * up_proj(x) 的维度: {gated_up.shape}")# 步骤5: 通过down_proj线性层down_proj = self.down_proj(gated_up)print(f"down_proj(act_fn(gate_proj(x)) * up_proj(x)) 的维度: {down_proj.shape}")return down_proj

代码解释

导入和激活函数

首先导入了必要的 PyTorch 库,并定义了一个激活函数字典 ACT2FN,包含了多种常用的激活函数:

  • relu: 修正线性单元激活函数
  • gelu: 高斯误差线性单元激活函数
  • silu: Sigmoid线性单元激活函数
  • swish: Swish激活函数 (x * sigmoid(x))
  • gelu_new: GELU激活函数的另一种实现方式

配置类

ModelConfig 类用于存储模型的基本配置参数:

  • hidden_size: 隐藏层大小,默认为768
  • intermediate_size: 中间层大小,默认为3072
  • hidden_act: 使用的激活函数类型,默认为"gelu"

初始化方法

def __init__(self, config, hidden_size=None, intermediate_size=None):

初始化方法接收配置对象和可选的隐藏层大小和中间层大小参数,创建了三个线性层:

  • gate_proj: 门控投影层,将输入从 hidden_size 映射到 intermediate_size
  • up_proj: 上投影层,同样将输入从 hidden_size 映射到 intermediate_size
  • down_proj: 下投影层,将中间结果从 intermediate_size 映射回 hidden_size

前向传播方法

前向传播方法实现了 SwiGLU 激活机制的变体,具体步骤如下:

  1. 输入张量 x 通过 gate_proj 线性层
  2. gate_proj 的输出应用激活函数
  3. 输入张量 x 通过 up_proj 线性层
  4. 将激活后的 gate_proj 输出与 up_proj 输出进行元素级乘法
  5. 将乘法结果通过 down_proj 线性层映射回原始维度

这种设计是 SwiGLU 激活的一种变体,通过门控机制增强了模型的表达能力。每一步都打印了张量的维度,便于调试和理解数据流。

计算流程

假设输入张量维度为 [batch_size, seq_length, hidden_size],例如 [2, 10, 768]:

  1. 通过 gate_proj 和 up_proj 后,维度变为 [2, 10, 3072]
  2. 激活函数和元素级乘法保持维度不变
  3. 最后通过 down_proj 将维度映射回 [2, 10, 768]

这种设计允许模型在中间层扩展维度以增加表达能力,然后再压缩回原始维度,是现代大型语言模型中常用的技术。

代码可视化

def main():# 创建配置config = ModelConfig(hidden_size=768, intermediate_size=3072, hidden_act="gelu")# 实例化模型model = DeepseekV3MLP(config)# 创建一个随机输入张量进行测试batch_size = 2seq_length = 10input_tensor = torch.rand(batch_size, seq_length, config.hidden_size)# 前向传播output = model(input_tensor)# 打印输入和输出的形状print(f"模型参数总数: {sum(p.numel() for p in model.parameters())}")if __name__ == "__main__":main()
输入 x 的维度: torch.Size([2, 10, 768])
gate_proj(x) 的维度: torch.Size([2, 10, 3072])
act_fn(gate_proj(x)) 的维度: torch.Size([2, 10, 3072])
up_proj(x) 的维度: torch.Size([2, 10, 3072])
act_fn(gate_proj(x)) * up_proj(x) 的维度: torch.Size([2, 10, 3072])
down_proj(act_fn(gate_proj(x)) * up_proj(x)) 的维度: torch.Size([2, 10, 768])
模型参数总数: 7077888

在这里插入图片描述

相关文章:

  • 快充协议芯片XSP04D支持使用一个Type-C与电脑传输数据和快充取电功能
  • 腾讯一面-软件开发实习-PC客户端开发方向
  • LX4-数据手册相关
  • CentOS 7进入救援模式——VirtualBox虚拟机
  • 23. git reset
  • unity TEngine学习4
  • 【Andorid备案获取keystore里面的公钥和SHA-1码等等】
  • 怎么发布、更新Python第三方库?以potx-cloud为例
  • PHP日志会对服务器产生哪些影响?
  • 基于DeepSeek/AI的资产测绘与威胁图谱构建
  • 华为VRP系统知识总结及案例试题
  • 【Python核心库实战指南】从数据处理到Web开发
  • TapData × 梦加速计划 | 与 AI 共舞,TapData 携 AI Ready 实时数据平台亮相加速营,企业数据基础设施现代化
  • DeepSeek赋能Nuclei:打造网络安全检测的“超级助手”
  • RHCE 练习二:通过 ssh 实现两台主机免密登录以及 nginx 服务通过多 IP 区分多网站
  • 图论-Floyd算法
  • aws服务--S3介绍使用代码集成
  • 【Vue】修饰符
  • 前端笔记-AJAX
  • 【自然语言处理与大模型】模型压缩技术之蒸馏
  • 普京:俄方积极对待任何和平倡议
  • 空山日落雨初收,来文徵明的画中听泉
  • 盗播热门影视剧、电影被追究刑事附带民事责任,最高法发声
  • 2025年上海车展后天开幕,所有进境展品已完成通关手续
  • 国家卫健委:无资质机构严禁开展产前筛查
  • 长三角议事厅·周报|服务业扩大开放:长三角六城联动新探索