当前位置: 首页 > news >正文

Pytorch中layernorm实现详解

      平时我们在编写神经网络时,经常会用到layernorm这个函数来加快网络的收敛速度。那layernorm到底在哪个维度上进行归一化的呢? 

一、问题描述

    首先借用知乎上的一张图,原文写的也非常好,大家有空可以去阅读一下,链接放在参考文献里了。如左图所示,假设现在输入的维度是(bs,seq_len, embedding),其中bs代表batch_size, seq_len代表序列长度 ,embedding表示嵌入大小。

    那在layernorm时,我们是对(seq_len, embedding)这个矩阵取均值和方差(上图);还是只对embedding这个维度取均值和方差呢(下图)?前者会得到bs个均值和方差,而后者会得到bs * seq_len 个均值和方差。下面我们进行编程验证。

二、编程实现

import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("用pytorch的layer_norm所得结果\n", layer_norm(embedding))

print("自己编写layer_norm所得结果")
eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

结果:

用pytorch的layer_norm所得结果
 tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])
自己编写layer_norm所得结果
mean:  torch.Size([2, 3, 1])
y_custom:  tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])

结果的相等的。可以看到,我们在取均值和方差时,是对最后一个维度取的。所以我们会得到 (N,C)个均值与方差。假设二是正确的。 

而实际上这种实现方法和Instance Norm是相同的

from torch.nn import InstanceNorm2d
instance_norm = InstanceNorm2d(3, affine=False)
x = torch.randn(2, 3, 4)
output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入
print(output.reshape(2,3,4))

layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)
print(layer_norm(x))

结果:

tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])
tensor([[[ 0.1293, -1.0034,  1.5760, -0.7018],
         [-1.3981, -0.4828,  1.0876,  0.7933],
         [-1.7034,  0.8545,  0.4876,  0.3612]],

        [[-1.4750,  1.2212, -0.2607,  0.5144],
         [ 0.7017, -0.8350,  1.2502, -1.1169],
         [-1.7273,  0.6965,  0.5147,  0.5161]]])

三、参考文献

(45 封私信 / 80 条消息) 为什么Transformer要用LayerNorm? - 知乎 (zhihu.com)https://www.zhihu.com/question/487766088/answer/2644783144

相关文章:

  • C语言基础(函数)
  • 正则魔法:解码 return /^\d+$/.test(text) ? text : ‘0‘ 的秘密
  • 【笔记】深度学习模型训练的 GPU 内存优化之旅:重计算篇
  • 2025最新电脑IP地址修改方法:Win系统详细步骤
  • springboot使用163发送自定义html格式的邮件
  • 为什么TCP需要三次握手?一次不行吗?
  • 【Servlet 容器和 Spring 容器的关系】
  • 人工智能之数学基础:线性方程组
  • mysql-innodb存储引擎主键索引叶子结点数据结构(非单纯的双向链表)
  • PyCharm安装redis,python安装redis,PyCharm使用失败问题
  • WPF 布局舍入(WPF 边框模糊 或 像素错位 的问题)
  • Datawhale coze-ai-assistant 笔记4
  • 16 预编译指令
  • 再学:ERC20-Permit2、SafeERC20方法 详解ERC721,如何铸造一个NFT以及IPFS的作用
  • 进程控制~
  • 【宇宙回响】从Canvas到MySQL:飞机大战的全栈交响曲【附演示视频与源码】
  • 普通鼠标的500连击的工具来了!!!
  • 【MySQL】MySQL登录,访问,退出操作
  • 微软Data Formulator:用AI重塑数据可视化的未来
  • 突破时空边界:Java实时流处理中窗口操作与时间语义的深度重构
  • 上海浦东单价超10万楼盘228套房源开盘当天售罄,4月已有三个新盘“日光”
  • 楼下电瓶车起火老夫妻逃生时被烧伤,消防解析躲火避烟注意事项
  • 伊朗港口爆炸事件已致195人受伤
  • 临沂文旅集团被诉侵权,原告每年三百余起类案
  • 青海一只人工繁育秃鹫雏鸟破壳后脱险成活,有望填补国内空白
  • 人民论坛:是民生小事,也是融合大势