12.FFN基于位置的前馈网络
从入门AI到手写Transformer-12.FFN基于位置的前馈网络
- 12.FFN基于位置的前馈网络
- 代码
整理自视频 老袁不说话 。
12.FFN基于位置的前馈网络
之间讲了残差连接,规范化。
这节是FFN基于位置的前馈网络,其实就是MLP,两层线性连接层。
输入 X : [ b s , n , d ] X:[bs,n,d] X:[bs,n,d]
第一层: Y = X W 1 + B 1 W 1 : [ d , d m ] Y=XW_1+B_1\quad W1:[d,d_m] Y=XW1+B1W1:[d,dm],代入 Y Y Y, Y = R e L U ( X ) Y=ReLU(X) Y=ReLU(X),输出 Y : [ b s , n , d m ] Y:[bs,n,d_m] Y:[bs,n,dm]
第二层: Y = X W 2 + B 2 W 1 : [ d m , d ] Y=XW_2+B_2\quad W1:[d_m,d] Y=XW2+B2W1:[dm,d],输出 Y : [ b s , n , d ] Y:[bs,n,d] Y:[bs,n,d]
代码
import torch
from torch import nnclass FFN(nn.Module):# dm=4*ddef __init__(self,d,dm,*args,**kwargs)->None:super(FFN,self).__init__(*args,**kwargs)self.dense1=nn.Linear(d,dm) # weight:[dm,d] bias:[dm]self.relu=nn.ReLU()self.dense2=nn.Linear(dm,d) # weight:[d,dm] bias:[d]def forward(self,X):Y=self.dense1(X)Y=self.relu(Y)Y=self.dense2(Y)return YX=torch.randn(3,5,10)
ffn=FFN(10,40)
o=ffn(X)
print(o.shape)