当前位置: 首页 > news >正文

DreamDiffusion的mae_for_eeg.py网络架构

DreamDiffusion的mae_for_eeg.py网络架构图,包含更细致的模块说明和数据处理流程:

graph TD%% 输入层A[原始EEG信号<br>(批次=2, 通道=128, 长度=512)] --> B[1D块嵌入层]%% 块嵌入处理B -->|"1D卷积<br>(核大小=4, 步长=4)"| C[块序列<br>(2, 128, 1024)]%% 掩码处理C --> D[随机块掩码<br>丢弃75%块<br>可选focus_range局部增强]D --> E[编码器主干<br>24层Transformer]%% 编码器输出E --> F[潜在表示<br>(2, 129, 1024)<br>含CLS标记]%% 双解码路径F --> G[EEG重建解码器]F --> I[图像特征解码器*<br>*可选模块]%% 输出层G --> H[重建EEG信号<br>(2, 128, 512)]I --> J[图像特征图<br>(2, 512, 28×28)]%% 辅助模块F --> K[分类头] & L[映射头]K --> M[分类结果<br>40类]L --> N[低维特征<br>768维]%% 样式定义classDef input fill:#f9f,stroke:#333;classDef conv fill:#9cf,stroke:#333;classDef transformer fill:#c9f,stroke:#333;classDef output fill:#f96,stroke:#333;classDef aux fill:#6f9,stroke:#333;class A input;class B,C conv;class D,E,F,G,I transformer;class H,J output;class K,L,M,N aux;

详细架构说明

  1. 输入层

    • 接收形状为(2,128,512)的EEG信号
    • 2=批次大小,128=通道数,512=时间步长
  2. 1D块嵌入层(PatchEmbed1D)

    nn.Conv1d(in_chans=128, embed_dim=1024, kernel_size=4, stride=4)
    
    • 将512长度分割为128个块,每块大小4
    • 通过卷积将每个块投影到1024维
  3. 掩码处理(关键创新点)

    • 随机丢弃75%的块
    • 支持focus_range参数强化特定区域的掩码率
    • 保留的块与可学习掩码标记拼接
  4. Transformer编码器

    • 包含24个标准Transformer块
    • 每块含:
      Block(embed_dim=1024, num_heads=16, mlp_ratio=4)
      
    • 输出附加CLS标记的全局特征
  5. 双路解码器设计

    • EEG重建路径
      • 8层轻量Transformer
      • 最终全连接层还原原始维度
      • 损失函数:掩码块的MSE损失
    • 图像生成路径
      • 2层Transformer+转置卷积
      • 输出28×28图像特征
      • 损失函数:tanh空间的MSE
  6. 辅助任务头

    • 分类头
      nn.Sequential(nn.Conv1d(1281),nn.Linear(102440)
      )
      
    • 映射头
      nn.Sequential(nn.Conv1d(1281),nn.Linear(1024768)
      )
      

数据流动示例

前向传播
图像损失
重建损失
总损失
分块
EEG输入
掩码
编码
解码重建
图像生成
反向传播

关键超参数

模块参数名典型值作用说明
块嵌入patch_size4控制信号分割粒度
编码器depth24特征提取能力
注意力头num_heads16多注意力机制维度
图像解码器img_recon_weight1.0图像重建损失权重
掩码策略focus_rate0.5局部区域掩码增强强度

此架构特别适用于:

  • 脑电信号的特征学习
  • 跨模态(EEG→图像)生成任务
  • 自监督预训练

如果需要更详细的某部分实现细节(如Transformer块内部结构),可以进一步展开说明。

相关文章:

  • 基于maven-jar-plugin打造一款自动识别主类的maven打包插件
  • [Spring]SSM整合
  • 游戏引擎学习第238天:让 OpenGL 使用我们的屏幕坐标
  • 基于Redis实现RAG架构的技术解析与实践指南
  • idea中运行groovy程序报错
  • 【perf】perf工具的使用生成火焰图
  • 基于 OpenCV 的图像与视频处理
  • Kubernetes(k8s)学习笔记(二)--k8s 集群安装
  • React+TS编写轮播图
  • 计算机视觉cv入门之Haarcascade的基本使用方法(人脸识别为例)
  • 【后端】【Django】Django 模型中的 `clean()` 方法详解:数据校验的最后防线
  • 【人工智能】推荐开源企业级OCR大模型InternVL3
  • css3新特性第四章(渐变)
  • 【条形码识别改名工具】如何批量识别图片条形码,并以条码内容批量重命名,基于WPF和Zxing的开发总结
  • 【iOS】alloc init new底层原理
  • 嵌入式---零点漂移(Zero Drift)
  • 网络设备基础运维全攻略:华为/思科核心操作与巡检指南
  • IDEA多环节实现优雅配置
  • IDEA在Git提交时添加.ignore忽略文件,解决为什么Git中有时候使用.gitignore也无法忽略一些文件
  • 国际数据加密算法(IDEA)详解
  • 体坛联播|皇马补时绝杀毕尔巴鄂,利物浦最快下轮即可夺冠
  • 女子伸腿阻止高铁关门被拘,央媒:严格依规公开处理以儆效尤
  • 上海崇明“人鸟争食”何解?检察机关推动各方寻找最优解
  • 杨国荣丨阐释学的内涵与意义——张江《阐释学五辨》序
  • 为博眼球竟编造一女孩被活埋,公安机关公布10起谣言案件
  • 安徽省合肥市人大常委会原副主任杜平太接受审查调查