当前位置: 首页 > news >正文

【AI模型学习】Swin Transformer——优雅的模型

文章目录

    • 一、论文背景:从 ViT 到 Swin
      • 1.1 ViT 的局限
      • 1.2 Swin Transformer的目标
    • 二、模型
      • 2.1 Patch Partition 和 Patch Embedding
      • 2.2 Window Multi-head Self Attention (W-MSA)
      • 2.3 Shifted Window Multi-head Self Attention (SW-MSA)
      • 2.3 Patch Merging
      • 2.4 金字塔结构
    • 三、细节
      • 3.1 关于复杂度
      • 3.2 Mask 的具体实现细节
      • 3.3 下游任务
    • 四、实验
      • 4.1 主实验结果汇总
      • 4.2 消融实验

一、论文背景:从 ViT 到 Swin

Transformer 模型在 NLP 领域大获成功之后,ViT(Vision Transformer) 率先把它引入了视觉任务,用 patch 的方式处理图像并直接送入 Transformer 编码器,在大规模数据下表现优异。但它也带来了一些严重的问题:

1.1 ViT 的局限

  1. 计算复杂度高
    全局自注意力(Self-Attention)的计算复杂度是 O ( N 2 ) O(N^2) O(N2),其中 N N N 是图像 patch 的数量。图像越大,patch 越多,计算量激增。

  2. 缺乏局部性 inductive bias
    CNN 有平移不变性(Translation Equivariance)和局部感知能力,而原始 ViT 完全依赖大数据学习这些特性,效率低。

  3. 不适配金字塔结构
    CNN 的经典架构具有“金字塔式”的多尺度特性,有利于分割、检测等密集预测任务;ViT 原始版本处理的是固定大小的 patch,结构过于平坦,不易扩展。


1.2 Swin Transformer的目标

Swin Transformer:

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
—— 发表在 ICCV 2021,作者来自微软亚洲研究院。

  1. 提出一种计算友好的局部自注意力机制(Window-based Attention),代替全局 Attention。
  2. 构造具有层次化结构的 Transformer,使其具备 CNN 那种金字塔式的表达能力。
  3. 通过滑窗策略(Shifted Windows)增强跨窗口的信息交流,避免窗口孤立。
  4. 最终能在多个任务中泛化,如分类、检测、分割等,成为通用视觉骨干。

Swin Transformer 试图解决 Vision Transformer 在计算效率、本地性建模、多尺度表达上的缺陷,用 滑动窗口 + 层次化设计 将 Transformer 带入实际视觉任务场景。

文章要学习的点挺多。

  • Patch Partition 和 Patch Embedding
  • Window Multi-head Self Attention (W-MSA)
  • Shifted Window Multi-head Self Attention (SW-MSA)
  • Patch Merging 的金字塔构建方式

论文链接(arXiv)如下:
https://arxiv.org/abs/2103.14030

PDF 下载地址(直接下载):
https://arxiv.org/pdf/2103.14030.pdf


二、模型

2.1 Patch Partition 和 Patch Embedding

输入图像的维度是:

(B, H, W, C)
  • B = batch size
  • H, W = 图像的高和宽(比如 224)
  • C = 图像通道数(RGB → 3)

Swin 使用 4 × 4 大小的 patch,把图像分块,每个 patch 当成一个 token。

  1. 划分 patch(Patch Partition)

    把整张图像切成 4 × 4 的小块,每个 patch 包含:

    4(height) × 4(width) × 3(channel) = 48 个数
    
  2. 计算 patch 数量

    每张图像被切成:

    (H / 4) × (W / 4) = 56 × 56 
    
  3. reshape 操作

    把输入变换为:

    (B, 56, 56, 4, 4, 3) → reshape → (B, 56, 56, 48)
    
  4. Linear Projection(Patch Embedding)

    用一个线性层把每个 patch 从 48 维投影到更高维,比如 96:

    (B, 56, 56, 48) → Linear → (B, 56, 56, 96)
    

输出:

(B, 56, 56, 96)
  • 表示每张图像被编码为 56×56 个 token(patch)
  • 每个 token 是一个 96 维的向量

patch之间"拍瘪":

(B, 3136, 96)  # 方便送入 Transformer 做 Attention

2.2 Window Multi-head Self Attention (W-MSA)

为什么不用“全局 Self-Attention”?

传统 Transformer 的自注意力是全局的,也就是说,每个 token 都和其他所有 token 做注意力计算。

如果我们现在的输入是:

(B, 3136, 96)   # 即 56×56 个 patch,每个 patch 是 96 维向量

那么计算注意力的代价是:

O(N² × D) = O(3136² × 96)  # 太大太大了,显存爆炸 

解决办法:Window-based Attention(W-MSA)

核心思想:

不在整幅图上计算 attention,只在每个小窗口内进行 self-attention

具体流程

假设输入是:

(B, 56, 56, 96)
  1. 划分成小窗口(默认窗口大小是 7 × 7)
    我们把每张图像划分成不重叠的窗口:

    • 每个窗口:7×7 = 49 个 patch
    • 每张图:56 / 7 = 8 行 × 8 列 → 一共 64 个窗口
  2. reshape 成窗口批处理格式
    我们将每个窗口提取出来,构造成批处理格式:

    (B, 56, 56, 96)
    → reshape → (B * 64, 7 * 7, 96) = (B * 64, 49, 96)
    
    • 每个窗口有 49 个 token(patch)
    • 每个 token 是 96 维
    • 一共 B×64 个窗口
  3. 在每个窗口中做 Multi-head Self-Attention
    对每个窗口内部做标准的 MSA:

    (B * 64, 49, 96)
    ↓ W-MSA(每个窗口 attention)
    → (B * 64, 49, 96)
    
  4. 把所有窗口还原回原图结构
    恢复原始的空间布局:

    (B * 64, 49, 96)
    → reshape → (B, 56, 56, 96)
    

总结流程

Input:      (B, 56, 56, 96)
↓ 划分 7×7 窗口
→ reshape: (B*64, 49, 96)
↓ 每个窗口内做 MSA
→ (B*64, 49, 96)
↓ reshape 回原图形状
Output:     (B, 56, 56, 96)
项目内容
Attention 类型每个窗口内部做
计算复杂度从 O(N²) 降为 O(M × K²),其中 M 是窗口个数,K 是窗口大小
优点大幅减少计算量,适配高分辨率图像
缺点窗口之间无法直接通信(后面用 SW-MSA 弥补!)

2.3 Shifted Window Multi-head Self Attention (SW-MSA)

Swin Transformer 的王牌设计

目的:解决 W-MSA 的“信息孤岛”问题

W-MSA 中,我们把图像划分成不重叠的小窗口,每个窗口自己做注意力,虽然节省了计算量,但也带来了一个问题:

每个窗口之间完全没交流,只能看到自己这一块。

也就是说,窗口 A 里的 token 不知道窗口 B 发生了什么,这对捕捉全局关系是致命的限制。

SW-MSA 的核心想法:

在下一层,把窗口整体“滑动”一半(shift),让新的窗口能覆盖多个旧窗口,打破原有边界,形成跨窗口信息流动!

在这里插入图片描述

具体流程

我们接着上一步,输入仍是:

(B, 56, 56, 96)
  1. 平移窗口(Shift)

    将整张 feature map 在空间维度上平移:

    向下平移 3 行、向右平移 3 列(因为窗口大小是 7,所以 7 // 2 = 3)
    

    这时候,原来相邻但是在不同窗口的 patch 被划入了同一个窗口中。

    注意:平移过程中会打破原来的窗口对齐方式,有些位置会被“移出边界”,需要后续处理。(在mask中会讲到)

  2. 遮挡 Mask(mask padding region)

    由于平移后窗口不再对齐,需要对超出边界的地方进行“掩码”(mask padding),避免注意力看到无效区域。

    论文中的做法是构造一个 attention mask,屏蔽跨 block 的 attention。(之后再讲)

  3. 对 Shift 后的新窗口做 W-MSA

    将新窗口也划为 7 × 7,即使现在不是对齐划分的,也一样处理:

    (B, 56, 56, 96)
    ↓ shift → 平移后的特征图
    ↓ window partition
    → reshape → (B * 64, 49, 96)
    ↓ attention with mask
    → (B * 64, 49, 96)
    ↓ reshape → (B, 56, 56, 96)
    
  4. shift back(还原位移)

    Attention 完成后,把平移回来的特征图“再平移回原位置”,确保位置信息不乱套。

总结

Input:       (B, 56, 56, 96)
↓ shift      → 位移(往右下滑 3)
↓ window partition + mask attention
↓ W-MSA      → 每个滑动窗口内部做 MSA
↓ shift back → 恢复原始位置
Output:      (B, 56, 56, 96)

对比:W-MSA vs SW-MSA

特性W-MSASW-MSA
窗口划分固定、规则、无重叠滑动、打破边界
是否跨窗口交流
是否需要掩码不需要需要 attention mask
是否恢复位移不需要需要 shift back

两者是交替使用的!

每个 Swin Block 都是:

W-MSA → SW-MSA → W-MSA → SW-MSA → ...

这样就能:

  1. 降低计算量(小窗口)
  2. 建立全局联系(滑动窗口)

2.3 Patch Merging

Patch Merging 是 Stage 之间的下采样过程,和 CNN 的 pool/stride 类似。
觉得自己讲有点抽象,干脆找了一篇CSDN的博客。
ViT中的上采样和下采样——patch merge
假设当前输入维度是:

(B, H, W, C)

先把图片一分为4,如何分的,可以看上面的文章。图片宽高都减半,将4张子图在通道维拼接。

(B, H, W, C) ->  (B, H/2, W/2, 4C)

CV一般“惯例”都是池化的时候,高宽减半,通道翻倍。我们这里翻了4倍,所以做一个全连接。

(B, H/2, W/2, 4C) -> (B, H/2, W/2, 2C)

2.4 金字塔结构

在这里插入图片描述

Swin Transformer 不像 ViT 那样一直保持固定大小的 feature,而是像 CNN 一样,每过一个 Stage:

  • 减小空间尺寸(下采样)
  • 增加通道数(增强表达)

每个 SwinBlock 包含两大部分:

  1. Window-based Multi-head Self-Attention(W-MSA SW-MSA)模块
  2. MLP(前馈神经网络)
SwinBlock:
├── LN
├── W-MSA / SW-MSA
├── Dropout + Residual
├── LN
├── MLP (Linear → GELU → Linear)
├── Dropout + Residual

SwinBlock之间通过Patch Merging进行“池化”

Stage 流程总览

每一层结构如下:

[Stage n]
↓ SwinBlock × L
↓ Patch Merging

举个完整流程(Swin-Tiny):

Stage 1:Input       : (B, 56, 56, 96)SwinBlock×2 : (B, 56, 56, 96)PatchMerge  : (B, 28, 28, 192)Stage 2:SwinBlock×2 : (B, 28, 28, 192)PatchMerge  : (B, 14, 14, 384)Stage 3:SwinBlock×6 : (B, 14, 14, 384)PatchMerge  : (B, 7, 7, 768)Stage 4:SwinBlock×2 : (B, 7, 7, 768)

最后接上:

↓ Global Average Pooling → (B, 768)
↓ Linear classifier      → (B, num_classes)

三、细节

3.1 关于复杂度

这里可以直接看李沐老师的讲论文系列
Swin Transformer:Hierarchical Vision Transformer using Shifted Windows

Self-Attention 的复杂度

项目复杂度表达式说明
Q/K/V 投影 3 N D 2 3ND^2 3ND2三个线性映射
Q K ⊤ QK^\top QK N 2 D N^2D N2D所有 head 共享序列长度 N
Softmax × V N 2 D N^2D N2D同样按每个 head 分别处理
输出线性投影 N D 2 ND^2 ND2拼接后投影回 D 维
总计 4 N D 2 + 2 N 2 D 4ND^2 + 2N^2D 4ND2+2N2D主要瓶颈是 N 2 D N^2D N2D(注意力部分)

ViT 的 Attention 复杂度(全局)

ViT 默认每张图像切成 16 × 16 16 \times 16 16×16 patch,输入大小为 224 × 224 224 \times 224 224×224

  • 每张图像的 token 数:
    N = (224 / 16)^2 = 14 × 14 = 196
    
  • 假设 token embedding 维度 D = 768 D = 768 D=768

代入公式:

C ViT = 4 N D 2 + 2 N 2 D = 4 ⋅ 196 ⋅ 76 8 2 + 2 ⋅ 19 6 2 ⋅ 768 ≈ 451 M + 59 M ≈ 510 M \mathcal{C}_\text{ViT} = 4ND^2 + 2N^2D = 4 \cdot 196 \cdot 768^2 + 2 \cdot 196^2 \cdot 768 ≈ 451M + 59M ≈ 510M CViT=4ND2+2N2D=41967682+21962768451M+59M510M

Swin 的 Attention 复杂度(局部窗口)

Swin 使用 4 × 4 4 \times 4 4×4 patch → 得到 56 × 56 = 3136 56 \times 56 = 3136 56×56=3136 个 patch
窗口大小是 7 × 7 = 49 7 \times 7 = 49 7×7=49

于是我们可以这么计算:

  • 每个窗口: K = 49 K = 49 K=49 个 token
  • 窗口个数: M = 56 ⋅ 56 / ( 7 ⋅ 7 ) = 3136 / 49 = 64 M = 56 \cdot 56 / (7 \cdot 7) = 3136 / 49 = 64 M=5656/(77)=3136/49=64
  • 每个窗口做 attention,其复杂度是:
    C window = 4 K d 2 + 2 K 2 d \mathcal{C}_\text{window} = 4Kd^2 + 2K^2d Cwindow=4Kd2+2K2d
    其中 d = D / h d = D / h d=D/h,假设 h = 12 h=12 h=12 d = 64 d = 64 d=64

那我们代入:

  • K = 49 K = 49 K=49
  • d = 64 d = 64 d=64
  • 每窗口复杂度: 4 ⋅ 49 ⋅ 6 4 2 + 2 ⋅ 4 9 2 ⋅ 64 = 802 , 816 + 307 , 328 = 1 , 110 , 144 4 \cdot 49 \cdot 64^2 + 2 \cdot 49^2 \cdot 64 = 802,816 + 307,328 = 1,110,144 449642+249264=802,816+307,328=1,110,144
  • 总复杂度 = 窗口数 × 每窗口复杂度:
    C Swin = 64 ⋅ 1.11 M ≈ 71 M \mathcal{C}_\text{Swin} = 64 \cdot 1.11M ≈71M CSwin=641.11M71M

3.2 Mask 的具体实现细节

我们要做的,是在 SW-MSA 中生成一个掩码张量,用于在 softmax 前屏蔽非法 attention。这个 mask 形状是:

(1, num_windows, window_size*window_size, window_size*window_size)
  • 每个窗口一个 mask
  • 每个窗口内的 token 对之间有一个标志:是同一区域还是不同区域
  1. 给每个 patch 分配区域编号

    我们构造一个虚拟的 H × W 的二维矩阵 img_mask,每一块 window_size × window_size 区域赋一个唯一的整数 ID,标识它的原始区域编号:

    img_mask = torch.zeros((1, H, W, 1))  # shape: (1, H, W, 1)
    cnt = 0
    for h in range(0, H, window_size):for w in range(0, W, window_size):img_mask[:, h:h+window_size, w:w+window_size, :] = cntcnt += 1
    

    这样每个 7x7 区块都有唯一编号。

  2. img_mask 做 shift(滑动窗口)

    shifted_mask = torch.roll(img_mask, shifts=(-shift_size, -shift_size), dims=(1, 2))
    

    比如 shift_size = 3 时,相当于向上左滑动了 3 个像素。

  3. 再按窗口划分,并 reshape 成 attention 的输入格式

    我们把 shifted 后的 mask 分成窗口,每个窗口是一个 ( w i n d o w _ s i z e × w i n d o w _ s i z e ) (window\_size \times window\_size) (window_size×window_size) 区域:

    mask_windows = window_partition(shifted_mask, window_size)  # shape: (num_windows, window_size, window_size, 1)
    mask_windows = mask_windows.view(-1, window_size * window_size)
    

    现在每个窗口内是一个 shape 为 (49,) 的向量,每个值是区域 ID。

  4. 构造 attention mask(核心)

    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, 0)
    
    • attn_mask[i, j] 表示第 i i i 个 token 和第 j j j 个 token 的差值
      • 如果来自不同原始区域 → 不等 → 设为 -inf
      • 如果来自同一区域 → 相等 → 设为 0

    最终你就得到一个形状为:

    (num_windows, 49, 49)
    

    的 attention mask,在后续计算中会加到注意力 logits 上:

    attn_scores = (Q @ K^T) / sqrt(d)
    attn_scores += attn_mask
    attn_probs = softmax(attn_scores)
    
步骤操作目的
1构造区域编号矩阵 img_mask标记不同窗口的 token
2img_mask 做 shift模拟 SW-MSA 的窗口
3按窗口切块 → reshape获取窗口内的 token 区域分布
4构造差值矩阵 + masked_fill设置 cross-window 的 attention 为 -inf
  • 掩码是 与输入内容无关的结构掩码,可以在模型初始化时就构建好(只依赖窗口大小与输入尺寸)
  • torch.roll 是非常高效的窗口滑动方式
  • 使用 .masked_fill() 是 PyTorch 中处理 Attention Mask 的标准做法
  • 注意:在 Swin 中窗口是 batch-wise共享的,所以 mask 只构建一次即可复用

3.3 下游任务

一、Swin-T 如何接入各种类型的下游任务?

Swin-T 作为视觉 Transformer,可以看作 CNN 的替代 backbone,输出 层次化的特征图,非常适合各种任务中模块化插入。

  1. 分类任务(Image Classification)

    • Stage 4 的输出 (B, 7, 7, 768)
    • 做一次 Global Average Pooling
    • 接一个 Linear 分类头
  2. 检测任务(Object Detection)

    目标检测通常需要多层次特征,因此 Swin 的多 stage 输出非常合适:

    将 Swin 的各个 Stage 输出喂入 FPN(Feature Pyramid Network)

    Swin stages → [C1, C2, C3, C4]↓FPN↓Detection Head (e.g., RetinaNet, Faster R-CNN)
    

    其中:

    StageOutput Shape用于替代哪层
    C156 × 56 × 96ResNet C2
    C228 × 28 × 192ResNet C3
    C314 × 14 × 384ResNet C4
    C47 × 7 × 768ResNet C5
  3. 语义分割(Semantic Segmentation)

    语义分割也需要多尺度语义信息,Swin 可以直接连接到分割头,如:

    输出四层 [C1, C2, C3, C4],拼接 or 上采样对齐后,接 Segmentation Head

    Swin Features → UPerNet / FPN Head → Pixel-wise 分类
    
  4. 实例分割(Instance Segmentation)

    和目标检测类似,通常基于 Mask R-CNN 等结构,只需将 Swin 替代 ResNet:

    Swin Backbone → FPN → ROIAlign → BBox + Mask Head
    
  5. 视频理解(Action Recognition)

    Swin 可以扩展到时序维度 → 变为 Swin-Video Transformer:

    • 在时间轴上引入 attention(Swin-ViT 3D)
    • 用于视频分类、时序分割等任务

二、Swin 的预训练方式

  1. ImageNet 预训练(分类任务通用)

  2. Masked Image Modeling(自监督)

  3. Detection / Segmentation 预训练(针对 COCO / ADE20K)

  4. 视频预训练(Swin-Video)


四、实验

4.1 主实验结果汇总

模型ParamsFLOPsTop-1 Acc (%)COCO box APCOCO mask APADE20K mIoU
Swin-T29M4.5G81.350.543.744.5
Swin-S50M8.7G83.051.844.747.6
Swin-B88M15.4G83.551.945.048.1
Swin-B+38488M47.1G84.5---
Swin-L+384*197M103.9G85.253.546.253.5
DeiT-B86M17.6G81.844.940.943.3
ResNet-5025M4.1G76.239.235.436.7
ResNeXt-10183M32.0G79.344.039.840.2
  • *表示使用更大分辨率、更多数据(例如 Swin-L+384 使用 ImageNet-22K + finetune)
  • FLOPs 以 ImageNet-1K 输入大小计算
  • 所有检测/分割任务均在 COCO / ADE20K 上 finetune

4.2 消融实验

改动项Top-1 (%)COCO Box APADE20K mIoU说明
full Swin-T baseline81.350.544.5原始模型
no shifted windows80.548.441.6不使用 SW-MSA(只有 W-MSA)
no patch merging80.847.942.8所有 stage 的分辨率保持不变
no relative position bias81.049.443.4去除相对位置编码
use absolute pos embedding81.149.643.3替换成绝对位置编码
single-scale training only80.948.743.0不使用 multi-scale 训练
replace LayerNorm → BatchNorm80.747.842.0用 BN 替换 LN
only 1 stage (no hierarchy)78.142.639.1所有 block 用相同分辨率

亮点

模块对性能的提升贡献
Shifted Window (SW-MSA)关键设计,提升 2%+
Patch Merging(金字塔)提升多尺度建模能力
相对位置偏置比绝对位置编码更有效
LayerNorm 更适配 Transformer替换成 BN 会显著退化

相关文章:

  • 图像预处理-直方图均衡化
  • WebRTC服务器Coturn服务器的管理平台功能
  • 再次理解 瓦瑟斯坦距离(Wasserstein Distance)
  • 【C语言】初阶算法相关习题(一)
  • Docker 部署 Redis 缓存服务
  • 安宝特案例 | 某知名日系汽车制造厂,借助AR实现智慧化转型
  • 安宝特分享|AR智能装备赋能企业效率跃升
  • BEVDepth: Acquisition of Reliable Depth for Multi-View 3D Object Detection
  • leetcode 二分查找
  • 神经网络 “疑难杂症” 破解指南:梯度消失与爆炸全攻略(六)
  • 信奥赛CSP-J复赛集训(DP专题)(19):P3399 丝绸之路
  • Trent硬件工程师培训完整135讲
  • Windows 下 Git 入门指南:从安装、配置 SSH 到加速 GitHub 下载
  • gradle可用的下载地址(免费)
  • 研发效率破局之道阅读总结(3)工程优化
  • 【Lua】Lua 入门知识点总结
  • 使用 acme.sh 自动更新 SSL 证书的指南
  • 【MySQL】005.MySQL表的约束(上)
  • WPS Office安卓版云文档同步速度与PDF转换体验测评
  • 突破AI检测边界:对抗技术与学术伦理的终极博弈
  • 爱奇艺要转型做微剧?龚宇:是误解,微剧是增量业务,要提高投资回报效益
  • 解密帛书两千年文化传承,《帛书传奇》央视今晚开播
  • 西安雁塔区委书记王征拟任市领导班子副职,曾从浙江跨省调任陕西
  • 哈佛大学就联邦经费遭冻结起诉特朗普政府
  • 商务部24日下午将举行发布会,介绍近期商务领域重点工作情况
  • 甘肃古浪县发生3.0级地震,未接到人员伤亡和财产损失报告