当前位置: 首页 > news >正文

Einsum(Einstein summation convention)

Einsum(Einstein summation convention)

笔记来源:
Permute和Reshape嫌麻烦?einsum来帮忙!

The Einstein summation convention is a notational shorthand used in tensor calculus, particularly in the fields of physics and mathematics, to simplify the representation of sums over indices in tensor equations. This convention is widely used in general relativity and other areas involving tensors.

爱因斯坦求和约定(Einstein summation convention)的一种函数,它通过指定的索引规则执行张量操作。爱因斯坦求和表示法使得高维张量运算更加直观灵活,比如可以用于更复杂的张量运算,而不仅限于简单的矩阵乘法

numpy.einsum(subscripts,*operands,out=None,dtype=None,order='K',casting='safe',optimize=False)
torch.einsum(equation,*operands)
tf.einsum(equation,*inputs,**kwargs)

1.1 什么是einsum?运算规则是什么?

在求和公式中,某些下标在等式两边都有出现(例如下标 i i i)而有些下标(被求和的维度)只出现在一侧(例如下标 j j j


所以即便是省略求和符号也不会产生歧义,即我们仍然知道哪个维度被求和了

这种运算与变量是什么并没有关系,因此上式可以进一步简化

1.在不同输入之间重复出现的索引表示沿着这一维度进行乘法(例如 k k k
2.只出现在输入中的索引表示在这一维度上求和(例如输出有 i , j i,j i,j,也就是说 k k k只出现在输入中)

C = torch.einsum('ik,kj->ij',A,B)
# 箭头和箭头右侧的可以省略
C = torch.einsum('ik,kj',A,B)
#等价于 
C = torch.matmul(A, B)

3.输出中维度的顺序可以是任意的(例如 i j ij ij j i ji ji
这里 C j i C_{ji} Cji就是 C i j C_{ij} Cij的转置

C = torch.einsum('ik,kj->ji',A,B)
# 省略号可以用于broadcasting,也就是忽略不关心的维度,只对最后两个维度进行计算
C = torch.einsum('...ik,...kj->ji',A,B)

1.2 Einsum怎么用?

向量外积



C = torch.einsum('i,j->ij',a,b)

提取对角元素


a = torch.einsum('kk->k',A)

1.3 Einum在多头注意力的应用

原版本

qkv:torch.Tensor self.qkv(x) #B,patches,3*dim
qkv = qkv.reshape(B, patches, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2,0,3,1,4)
q:torch.Tensor = qkv[0]
k:torch.Tensor = qkv[1]
v:torch.Tensor = qkv[2]
k_t = k.transpose(-2,-1)
attn = torch.softmax(q k_t self.scale,dim=-1)
attn = self.attn_drop(attn)
wa = attn @ v
wa = wa.transpose(1,2)
wa = wa.flatten(2)

使用Einum简化

q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)
attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale
attn = attn.softmax(dim=-1)
x = torch.einsum('bijk,bikc -bijc',attn,v)
x = rearrange(x,'b i jc->b j (i c)')

(1)使用EINOPS库中的rearrange操作

q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)


(2)q乘k转置除以缩放比例

attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale


(3)softmax得到attention数值

attn = attn.softmax(dim=-1)


(4)attention值对v加权

x = torch.einsum('bijk,bikc -bijc',attn,v)


(5)将x的维度还原为输入的形式

x = rearrange(x,'b i jc->b j (i c)')

[ B , h e a d , N , C / / h e a d ] − > [ B , N , C ] [B,head,N,C//head]->[B,N,C] [B,head,N,C//head]>[B,N,C]

1.4 Einsum优缺点

优点:

  1. 一次调用、一个函数完成多个操作
  2. 有时比多个Permute和Transpose操作组合的可读性高
  3. 可以避免生成中间变量

缺点:
求和表达式复杂时耗费内存,导致性能问题

相关文章:

  • 30天pandas挑战
  • 面试准备-6
  • 【Qt】qt发布Release版本,打包.exe可执行文件
  • 如何打造高校实验室教学管理系统?Java SpringBoot助力,MySQL存储优化,2025届必备设计指南
  • 手写登录页面,unique_ptr智能指针
  • 项目实战 ---- 商用落地视频搜索系统(7)---预处理二次优化
  • 海事行政执法证照片要求及尺寸格式修改方法
  • 虚幻中的c++(持续更新)
  • JVM 垃圾回收机制:GC
  • 计算机毕业设计 | SpringBoot+vue 游戏商城 steam网站管理系统(附源码)
  • 浅谈Unity协程的工作机制
  • 模版的价值工程
  • 内推|京东|后端开发|运维|算法...|北京 更多岗位扫内推码了解,直接投递,跟踪进度
  • CSS学习11--版心和布局流程以及几种分布的例子
  • 【C++二分查找 拆位法】2411. 按位或最大的最小子数组长度
  • Java | Leetcode Java题解之第390题消除游戏
  • Windows自动化应用程序已启动/未启动,有进程无进程情况-拽起应用程序
  • Percona 开源监控方案 PMM 详解
  • 爆改YOLOv8|利用图像分割网络UNetV2改进yolov8主干-即插即用
  • Modbus-RTU协议
  • 玉渊谭天丨中方减少美国农产品进口后,舟山港陆续出现巴西大豆船
  • 申花四连胜领跑中超联赛,下轮榜首大战对蓉城将是硬仗考验
  • 印巴在克什米尔实控线附近小规模交火,巴防长发出“全面战争”警告
  • 刘非履新浙江省委常委、杭州市委书记,曾在吉湘云多省任职
  • 我国首次实现地月距离尺度的卫星激光测距
  • 瑞士外长答澎湃:瑞中都愿升级自贸协定,关税战没有任何好处