当前位置: 首页 > news >正文

论文阅读笔记——Generating Long Sequences with Sparse Transformers

Sparse Transformer 论文
解决了 Transformer 在长序列建模时的计算开销和内存过大的问题。
可视化了一个 128 层自注意力在 CIFAR-10 的数据集上学习到的注意力模式,发现:1)稀疏性普遍存在:大多数层在多数数据点上表现出稀疏注意力;2)例外:部分层想要捕捉全局依赖关系。Transformer 的注意力机制呈现了和卷积模型类似的归纳偏置,即浅层的网络倾向于提取纹理信息,深层的网络倾向于提取语义信息。

分解自注意力(Factorized self-attention)

Local 自注意力只关注自身相邻的,其余设为 0,类似于卷积;Atrous 自注意力是跳着计算,类似膨胀卷积;一种简单思路是交替使用 Local 自注意力和 Atrous 自注意力。但 OpenAI 并没有这么做,而是将二者合为一。

在这里插入图片描述
由于 Transformer 的最复杂的计算是 Q K T QK^T QKT,稀疏注意力是让设置好的像素点参与注意力的计算。由此,引入了连接模式的变量 S = { S 1 , … … , S n } S=\{S_1,……,S_n\} S={S1,……,Sn}。其中 S i S_i Si 是在预测第 i 个时间片的索引,是一个由 0 和 1 组成的二维矩阵。
Attend ⁡ ( X , S ) = ( a ( x i , S i ) ) i ∈ { 1 , … , n } ( 2 ) a ( x i , S i ) = softmax ⁡ ( ( W q x i ) K S i T d ) V S i ( 3 ) K S i = ( W k x j ) j ∈ S i V S i = ( W v x j ) j ∈ S i ( 4 ) \begin{aligned} \operatorname{Attend}(X, S) = \left(a(\mathbf{x}_i, S_i)\right)_{i \in \{1, \ldots, n\}} \quad (2) \\a(\mathbf{x}_i, S_i) = \operatorname{softmax}\left(\frac{(W_q \mathbf{x}_i) K_{S_i}^T}{\sqrt{d}}\right) V_{S_i} \quad (3) \\K_{S_i} = \left(W_k \mathbf{x}_j\right)_{j \in S_i} \quad V_{S_i} = \left(W_v \mathbf{x}_j\right)_{j \in S_i} \quad (4) \end{aligned} Attend(X,S)=(a(xi,Si))i{1,,n}(2)a(xi,Si)=softmax(d (Wqxi)KSiT)VSi(3)KSi=(Wkxj)jSiVSi=(Wvxj)jSi(4)
其中 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 是计算 Query,Key,Value 三个向量的权值矩阵。稀疏 Transformer 通过让链接模式作用到 K T K^T KT 上,从而降低 Q K T QK^T QKT 的复杂度

跨步注意力(Stride Attention) 由两种形式的连接模式组成。假设步长 l l l,行注意力是当前时间片的前 l l l 个时间片的值为 1,其余为 0;列注意力是每隔 l l l 个时间片段值为 1, 其余为 0。行注意力和列注意力的表达式如下,复杂度均为 O ( n ) O(\sqrt{n}) O(n )
A i ( 1 ) = { t , t + 1 , t + 2 , … … , i } , w h e r e t = m a x ( 0 , i − l ) A i ( 2 ) = { j : ( i − j ) m o d    l = 0 } \begin{aligned} A_i^{(1)}=\{t,t+1,t+2,……,i\},where\quad t = max(0,i-l) \\A_i^{(2)}=\{j:(i-j)\mod l =0\} \end{aligned} Ai(1)={t,t+1,t+2,……,i},wheret=max(0,il)Ai(2)={j:(ij)modl=0}
固定注意力(Fixed Attention) 也有行注意力和列注意力组成:
A i ( 1 ) = { j : ( [ j / l ] = [ i / l ] ) } A i ( 2 ) = { j : j m o d    l ∈ { t , t + 1 , … … , l } } \begin{aligned} A_i^{(1)}=\{j:([j/l]=[i/l])\} \\A_i^{(2)}=\{j:j\mod l \in\{t,t+1,……,l\}\} \end{aligned} Ai(1)={j:([j/l]=[i/l])}Ai(2)={j:jmodl{t,t+1,……,l}}
将以上注意力核融入网络中:

  • 每个残差块使用不同的注意力类型 a t t e n t i o n ( X ) = W p ⋅ a t t e n d ( X , A ( r m o d    p ) ) attention(X)=W_p·attend(X,A^{(r \mod p)}) attention(X)=Wpattend(X,A(rmodp)) 其中 r 是当前残差块的缩影,p 是注意力核的类别数;
  • 每个注意力头计算所有类型注意力核,合并他们的结果 a t t e n t i o n ( X ) = W p ⋅ a t t e n d ( X , ∪ m = 1 p A ( m ) attention(X)=W_p·attend(X,\cup_{m=1}^p A^{(m)} attention(X)=Wpattend(X,m=1pA(m)
  • 对于多头注意力,每个头选择一个注意力核,合并结果 a t t e n t i o n ( X ) = W p ( a t t e n d ( X , A ) ( i ) ) i ∈ { 1 , … … , n h } attention(X)=W_p(attend(X,A)^{(i)})_{i\in\{1,……,n_h\}} attention(X)=Wp(attend(X,A)(i))i{1,……,nh} 其中 n h n_h nh 组不同注意力核并行计算,然后在特征维度拼接。

多层 Transformer 训练

在这里插入图片描述

作者使用了在 ResNet v2 中提出的激活前置的残差模块,一个 N N N 层的网络可以表示为:
H 0 = e m b e d ( X , W e ) H k = H k − 1 + r e s b l o c k ( H k − 1 ) y = s o f t m a x ( n o r m ( H N ) W o u t ) \begin{aligned} H_0=embed(X,W_e) \\H_k=H_{k-1}+resblock(H_{k-1}) \\y=softmax(norm(H_N)W_{out}) \end{aligned} H0=embed(X,We)Hk=Hk1+resblock(Hk1)y=softmax(norm(HN)Wout)
其中 embed 是可学习的嵌入层: e m b e d ( X , W e ) = ( x i W e + ∑ j = 1 n e m b o i ( j ) W j ) embed(X,W_e)=\left(\boldsymbol{x}_iW_e+\sum_{j=1}^{n_{emb}}\boldsymbol{o}_i^{(j)}W_j\right) embed(X,We)=(xiWe+j=1nemboi(j)Wj) 其中 n e m b n_{emb} nemb 的值为 d d a t a d_{data} ddata d a t t n d_{attn} dattn x i \boldsymbol{x}_i xi 是序列中第 i 个元素的 one-hot 编码, o i ( j ) \boldsymbol{o}_i^{(j)} oi(j) x i \boldsymbol{x}_i xi 在第 j j j 维特征上的 one-hot 编码。
resblock(h) 由一个注意力模块和一个前馈神经网络组成:
a ( H ) = dropout( attention  ( n o r m ( H ) ) ) b ( H ) = d r o p o u t ( f f ( n o r m ( H + a ( H ) ) ) ) resblock ⁡ ( H ) = a ( H ) + b ( H ) \begin{gathered} a(H)=\text{dropout( attention }(\mathrm{norm}(H))) \\ b(H)=\mathrm{dropout}(\mathrm{ff}(\mathrm{norm}(H+a(H)))) \\ \operatorname{resblock}(H)=a(H)+b(H) \end{gathered} a(H)=dropout( attention (norm(H)))b(H)=dropout(ff(norm(H+a(H))))resblock(H)=a(H)+b(H)

梯度检查点

一个以时间换空间的一个策略,在反向传播的过程中,不是保存所有节点的参数值,而是只保留部分关键节点的值,然后通过这些关键节点反向推出其他节点的值。这样虽然引入了额外的节点参数的计算工作,但是大大节约了显存,从而使得训练更长的序列成为可能。

实验结果

在这里插入图片描述
在这里插入图片描述

相关文章:

  • Before After:SQL整容级优化
  • 学习八股的随机思考
  • Scratch037-(钢琴)
  • 数据库9(实验过程中补充学习)
  • 负氧离子是怎样产生的?
  • 百度网盘安卓版下载速度与储存体验分析
  • 2025年机电一体化、机器人与人工智能国际学术会议(MRAI 2025)
  • 解决在linux下运行rust/tauri项目出现窗口有内容,但是渲染出来成纯黑问题
  • 多语言编写的图片爬虫教程
  • Jmeter接口性能测试方案
  • (一)机器人仿真平台pybullet基础学习(操作记录)
  • yolov11设置n、m、s、l、x对应的模型大小
  • 服务器风扇故障导致过热问题的解决方案
  • 力扣面试150题—旋转图像和矩阵置零
  • Alembic 和 fbx存储结构和存储动画对比
  • 48、Spring Boot 详细讲义(五)
  • 最新扣子实战教程,利用扣子平台通过在线表格记录,批量生图,再也不要一条条的粘贴提示词了
  • 如何查看网页或任意文档中的颜色数值
  • 如何用DeepSeek大模型提升MySQL DBA工作效率?实战案例解析
  • 英飞凌TLE9891 +TLE5501 有感油泵FOC控制方案
  • 吕治国执掌全国唯一的热带海洋大学,曾从教育部“空降”海南
  • 收藏家尤伦斯辞世,曾是中国当代艺术的推手与收藏者
  • 一季度全社会用电量同比增长2.5%,3月增速显著回升
  • 42岁北京大学科学技术与医学史系副教授陈昊逝世
  • 许志强评《伐木》|伯恩哈德的文人共和国
  • TP-LINK4.36亿元竞得上海青浦徐泾办公地块,需引入全球领先的总部型企业