当前位置: 首页 > news >正文

[PyTorch][chapter-39][MoE][Mixtral of Experts - PyTorch][4]

前言:

    本文重点介绍如何用 PyTorch 实现 Transformer MoE 的模型部分。

 

主要有两种架构:

1. 在一个 Transformer 编码器内部使用多个专家(用 MoE 层替换 FFN)。

2. 以整个 Transformer 编码器作为专家。


目录: 

1:一个 Transformer 编码器内部有多个专家(FFN 作为 expert)

2:以整个 Transformer 编码器作为 expert


一: 一个 Transformer 编码器内部有多个专家。 

        把 Transformer 编码器层中的 FFN 替换为一个 MoE 层(由一个路由器和多个专家组成),

    其中每个 Expert 都是一个 FFN(Linear -> ReLU -> Linear)。

   

# -*- coding: utf-8 -*-
"""
Created on Mon Mar 24 11:45:03 2025

@author: chengxf2
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Position-wise feed-forward expert: Linear -> ReLU -> Linear.

    Mirrors the FFN sub-layer of a Transformer encoder layer, mapping the
    last dimension d_model -> hidden_dim -> d_model.
    """

    def __init__(self, d_model=512, hidden_dim=1024):
        super().__init__()
        # Layers are created in the same order as before (input, output),
        # so parameter initialisation draws from the RNG identically.
        self.input_projection = nn.Linear(d_model, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.activation(self.input_projection(x))
        return self.output_projection(hidden)

class Router(nn.Module):
    """Token router that scores every expert for each token.

    A single linear layer yields one logit per expert; a softmax over the
    last dimension turns those logits into routing probabilities.
    """

    def __init__(self, d_model, num_experts):
        super().__init__()
        self.layer = nn.Linear(d_model, num_experts)

    def forward(self, x):
        return self.layer(x).softmax(dim=-1)

class MoE(nn.Module):
    """Sparse top-k mixture of FFN experts.

    Each token is routed to its ``top_k`` highest-scoring experts and the
    selected experts' outputs are combined with the routing probabilities,
    renormalised over the selected experts. Only the tokens routed to an
    expert are fed to it, so an expert a token was not routed to contributes
    nothing to that token.

    Args:
        d_model: feature size of the input and of the output.
        num_experts: number of FFN experts.
        hidden_dim: hidden size of every expert.
        top_k: number of experts each token is sent to.
    """

    def __init__(self, d_model, num_experts, hidden_dim, top_k=2):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList([Expert(d_model, hidden_dim) for _ in range(num_experts)])
        self.router = Router(d_model, num_experts)
        self.top_k = top_k

    def forward(self, x):
        """Apply the mixture to ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``. Every leading
                dimension (e.g. ``(token_num, d_model)`` or
                ``(batch, seq_len, d_model)``) is treated as a set of
                independent tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        d_model = x.size(-1)
        # Flatten all leading dimensions into one token axis: (token_num, d_model).
        tokens = x.reshape(-1, d_model)

        routing_weights = self.router(tokens)  # (token_num, num_experts)
        topk_vals, topk_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalise so the selected experts' weights sum to 1 for each token.
        topk_vals_normalized = topk_vals / topk_vals.sum(dim=-1, keepdim=True)

        outputs = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            # (token, slot) pairs routed to expert i. topk indices are distinct
            # per token, so a token appears at most once in token_idx.
            token_idx, slot_idx = (topk_indices == i).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            # Run the expert only on its own tokens. Feeding zeroed rows for the
            # other tokens (and not masking their weights) would leak the
            # expert's bias terms into tokens that were never routed to it.
            expert_output = expert(tokens[token_idx])
            gate = topk_vals_normalized[token_idx, slot_idx].unsqueeze(-1)
            outputs = outputs.index_add(0, token_idx, expert_output * gate)
        return outputs.reshape(x.shape)
   
class TransformerEncoderLayerWithMoE(nn.Module):
    """Post-norm Transformer encoder layer whose FFN is replaced by an MoE.

    Sub-layers: multi-head self-attention -> add & LayerNorm -> MoE ->
    add & LayerNorm. Inputs and outputs are batch-first:
    ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model, nhead, num_experts, hidden_dim, dropout, top_k):
        super(TransformerEncoderLayerWithMoE, self).__init__()
        # batch_first=True: forward() unpacks src as (batch, seq_len, d_model).
        # With the default sequence-first layout, attention would mix
        # information across the batch axis instead of across positions.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # The MoE takes the place of the usual feed-forward sub-layer.
        self.moe = MoE(d_model, num_experts, hidden_dim, top_k)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src=None, src_mask=None, src_key_padding_mask=None):
        """Run the layer on ``src``.

        Args:
            src: ``(batch, seq_len, d_model)`` input.
            src_mask: forwarded to self-attention as ``attn_mask``.
            src_key_padding_mask: forwarded to self-attention as
                ``key_padding_mask``.

        Returns:
            ``(batch, seq_len, d_model)`` output.
        """
        attn_out = self.self_attn(src, src, src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout(attn_out))

        batch_size, seq_len, d_model = src.shape
        # The MoE routes individual tokens: flatten to (batch * seq_len, d_model).
        flat = src.reshape(-1, d_model)
        flat = self.norm2(flat + self.dropout(self.moe(flat)))
        return flat.reshape(batch_size, seq_len, d_model)

    #Step 3: Initialize the model, the loss and optimizer
# Demo hyper-parameters (num_layers and input_dim are not used below).
num_experts, d_model, nhead, hidden_dim = 8, 512, 8, 1024
dropout, num_layers = 0.1, 3
batch_size, seq_len, top_k = 2, 3, 2
input_dim = 512

# Random input shaped (batch_size, seq_len, d_model).
x = torch.randn(batch_size, seq_len, d_model)

# A single MoE encoder layer; it flattens tokens itself before routing.
model = TransformerEncoderLayerWithMoE(d_model, nhead, num_experts, hidden_dim, dropout, top_k)
output = model(x)
print(output.shape)
           
       
       

        


二: 整个 Transformer 编码器作为 expert

         在这种方法中,每个专家本身就是一个完整的 Transformer 编码器(带输入/输出投影层);路由器为每个 token 选出 top-k 个编码器专家,并对它们的输出加权求和。

     Step 1: Build an Expert Network

     Step 2: Build the Mixture of Experts

     Step 3: Initialize the model, the loss and optimizer

      Step 4: Train the model

# -*- coding: utf-8 -*-
"""
Created on Fri Mar 21 15:58:13 2025

@author: chengxf2
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Expert built around a full Transformer encoder stack.

    Pipeline: Linear(input_dim -> d_model) -> nn.TransformerEncoder made of
    ``num_layers`` default nn.TransformerEncoderLayer blocks ->
    Linear(d_model -> output_dim).
    """

    def __init__(self, d_model, nhead, num_layers, input_dim, output_dim):
        super(Expert, self).__init__()
        # Construction order (encoder, input projection, output projection) is
        # unchanged, so parameter initialisation draws from the RNG identically.
        layer_template = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.input_projection = nn.Linear(input_dim, d_model)
        self.output_projection = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Map ``(token_num, input_dim)`` to ``(token_num, output_dim)``.

        The encoder layers use the default sequence-first layout (S, N, E).
        Adding a leading axis makes the input ``(1, token_num, d_model)``:
        a length-1 sequence with ``token_num`` batch entries. Each token is
        therefore encoded independently; attention only sees the token itself.
        """
        hidden = self.input_projection(x)
        hidden = self.transformer_encoder(hidden[None])[0]
        return self.output_projection(hidden)
    
class MixtureOfExperts(nn.Module):
    """Dense (soft) mixture of Transformer-encoder experts.

    Every expert processes every input. A linear gate's softmax
    probabilities weight the expert outputs, which are then summed.

    Args:
        num_experts: number of experts.
        d_model, nhead, num_layers: configuration of each expert's encoder.
        input_size: feature size of each input row.
        output_size: feature size produced by each expert and by the mixture.
    """

    def __init__(self, num_experts, d_model, nhead, num_layers, input_size, output_size):
        super(MixtureOfExperts, self).__init__()
        self.experts = nn.ModuleList([Expert(d_model, nhead, num_layers, input_size, output_size) for _ in range(num_experts)])
        self.gates = nn.Linear(input_size, num_experts)

    def forward(self, x):
        """Map ``(token_num, input_size)`` to ``(token_num, output_size)``."""
        # Gate probabilities over the expert axis: (token_num, num_experts).
        weights = F.softmax(self.gates(x), dim=-1)
        # Expert outputs on a trailing expert axis: (token_num, output_size, num_experts).
        outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # (token_num, 1, num_experts): broadcast each expert's gate over its
        # whole output vector, then sum across experts. The previous
        # unsqueeze(2) aligned the expert axis with the output-feature axis.
        return (weights.unsqueeze(-2) * outputs).sum(dim=-1)
    

class Router(nn.Module):
    """Linear gate producing routing probabilities over experts.

    Args:
        input_dim: feature size of each input token.
        num_experts: number of experts to score.
    """

    def __init__(self, input_dim=512, num_experts=8):
        super(Router, self).__init__()
        self.layer = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        logits = self.layer(x)
        # Normalise over the expert axis (the last dimension).
        return logits.softmax(dim=-1)

class MoE(nn.Module):
    """Sparse top-k mixture of Transformer-encoder experts.

    Args:
        input_dim: feature size of each input token.
        output_dim: feature size produced by every expert and by the MoE.
        num_experts: number of experts.
        d_model, nhead, num_layers: configuration of each expert's encoder.
        top_k: number of experts each token is routed to.
    """

    def __init__(self, input_dim, output_dim, num_experts, d_model, nhead, num_layers, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList([Expert(d_model, nhead, num_layers, input_dim, output_dim) for _ in range(num_experts)])
        self.router = Router(input_dim, num_experts)
        self.output_dim = output_dim
        self.top_k = top_k

    def forward(self, x):
        """Route each row of ``x`` to its top-k experts.

        Args:
            x: ``(token_num, input_dim)`` tokens, already flattened from
                ``(batch, seq_len, input_dim)`` by the caller.

        Returns:
            ``(token_num, output_dim)``: for each token, the sum over its
            selected experts of (renormalised routing weight * expert output).
        """
        token_num = x.size(0)
        routing_weights = self.router(x)  # (token_num, num_experts)
        topk_vals, topk_indices = torch.topk(routing_weights, self.top_k, dim=1)
        # Renormalise so the selected experts' weights sum to 1 for each token.
        topk_vals_normalized = topk_vals / topk_vals.sum(dim=1, keepdim=True)
        outputs = torch.zeros(token_num, self.output_dim, device=x.device)

        for i, expert in enumerate(self.experts):
            # Token choice: (token, slot) pairs whose top-k list contains expert i.
            # topk indices are distinct per token, so each token appears at most once.
            token_idx, slot_idx = (topk_indices == i).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            # Truly sparse: the expert sees only its routed tokens. Running it on
            # zeroed rows (with unmasked weights) would add expert(0) * weight,
            # i.e. the bias terms, to tokens never routed to this expert.
            expert_outputs = expert(x[token_idx])
            gate = topk_vals_normalized[token_idx, slot_idx].unsqueeze(-1)
            weighted = (expert_outputs * gate).to(outputs.dtype)
            outputs = outputs.index_add(0, token_idx, weighted)
        return outputs
    
def train():
    """Smoke-test the sparse MoE on random tokens and print the output shape.

    The training loop at the end is only a sketch. It is kept as comments
    because no ``dataloader`` is defined here.
    """
    # Step 3: initialise the model (loss/optimizer appear in the sketch below).
    num_experts = 8
    nhead = 8
    num_layers = 3
    top_k = 2
    batch_size, seq_len, d_model = 2, 3, 512
    input_dim = 512

    # Random tokens flattened to (batch_size * seq_len, d_model) for the router.
    tokens = torch.randn(batch_size, seq_len, d_model).view(-1, d_model)

    # NOTE(review): output_dim is num_experts (8), exactly as in the original
    # positional call; confirm this is intended rather than e.g. d_model.
    model = MoE(input_dim, output_dim=num_experts, num_experts=num_experts,
                d_model=d_model, nhead=nhead, num_layers=num_layers, top_k=top_k)
    print(model(tokens).shape)

    # Step 4 sketch -- assumes a `dataloader` yielding (inputs, targets):
    #
    # criterion = nn.MSELoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    #
    # # Training loop
    # for epoch in range(100):
    #     for i, data in enumerate(dataloader):
    #         inputs, targets = data
    #         optimizer.zero_grad()
    #         outputs = model(inputs)
    #         loss = criterion(outputs, targets)
    #         loss.backward()
    #         optimizer.step()
    #
    #     print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


train()
        






 

Mixtral of experts | Mistral AI

相关文章:

  • Python----计算机视觉处理(Opencv:图像亮度变换)
  • 页面只有一个搜索框 如何按下enter键阻止页面的提交表单默认行为
  • vue3 如何清空 let arr = reactive([])
  • css 控制彩带图片从左到右逐渐显示有画出来的感觉
  • linux如何释放内存缓存
  • (windows)conda虚拟环境下open-webui安装与启动
  • 为什么后端接口返回数字类型1.00前端会取到1?
  • 【颠覆性缓存架构】Caffeine双引擎缓存实战:CPU和内存双优化,命中率提升到92%,内存减少75%
  • AI大白话(五):计算机视觉——AI是如何“看“世界的?
  • kotlin init执行顺序
  • 制作PaddleOCR/PaddleHub的Docker镜像
  • 解决 IntelliJ IDEA 方法断点导致程序无法运行的问题
  • 气象可视化卫星云图的方式:方法与架构详解
  • Python----计算机视觉处理(Opencv:霍夫变换)
  • Mysql中各种连接的区别
  • 父子组件传递数据和状态管理数据
  • PaddleHub-GPU镜像制作
  • 2025.03.23【前沿工具】| CellPhoneDB:基因网络分析与可视化的利器
  • 面试题分享-多线程顺序打印奇偶数
  • SpringBoot2集成Elasticsearch8(使用spring-boot-starter-data-elasticsearch)
  • 五一假期上海地铁部分线路将延时运营,这些调整请查收
  • 广东雷州农商行董事长、原行长同日被查
  • 夜读丨怀念那个写信的年代
  • 百岁太极拳大师、陈氏太极拳第十一代嫡宗传人陈全忠逝世
  • 清华成立人工智能医院,将构建“AI+医疗+教育+科研”闭环
  • 涨价应对关税变化是短期之策,跨境电商塑造新品牌开辟“新蓝海”