【AI模型学习】Swin Transformer——优雅的模型
文章目录
- 一、论文背景:从 ViT 到 Swin
- 1.1 ViT 的局限
- 1.2 Swin Transformer的目标
- 二、模型
- 2.1 Patch Partition 和 Patch Embedding
- 2.2 Window Multi-head Self Attention (W-MSA)
- 2.3 Shifted Window Multi-head Self Attention (SW-MSA)
- 2.3 Patch Merging
- 2.4 金字塔结构
- 三、细节
- 3.1 关于复杂度
- 3.2 Mask 的具体实现细节
- 3.3 下游任务
- 四、实验
- 4.1 主实验结果汇总
- 4.2 消融实验
一、论文背景:从 ViT 到 Swin
Transformer 模型在 NLP 领域大获成功之后,ViT(Vision Transformer) 率先把它引入了视觉任务,用 patch 的方式处理图像并直接送入 Transformer 编码器,在大规模数据下表现优异。但它也带来了一些严重的问题:
1.1 ViT 的局限
-
计算复杂度高:
全局自注意力(Self-Attention)的计算复杂度是 O ( N 2 ) O(N^2) O(N2),其中 N N N 是图像 patch 的数量。图像越大,patch 越多,计算量激增。 -
缺乏局部性 inductive bias:
CNN 有平移不变性(Translation Equivariance)和局部感知能力,而原始 ViT 完全依赖大数据学习这些特性,效率低。 -
不适配金字塔结构:
CNN 的经典架构具有“金字塔式”的多尺度特性,有利于分割、检测等密集预测任务;ViT 原始版本处理的是固定大小的 patch,结构过于平坦,不易扩展。
1.2 Swin Transformer的目标
Swin Transformer:
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
—— 发表在 ICCV 2021,作者来自微软亚洲研究院。
- 提出一种计算友好的局部自注意力机制(Window-based Attention),代替全局 Attention。
- 构造具有层次化结构的 Transformer,使其具备 CNN 那种金字塔式的表达能力。
- 通过滑窗策略(Shifted Windows)增强跨窗口的信息交流,避免窗口孤立。
- 最终能在多个任务中泛化,如分类、检测、分割等,成为通用视觉骨干。
Swin Transformer 试图解决 Vision Transformer 在计算效率、本地性建模、多尺度表达上的缺陷,用 滑动窗口 + 层次化设计 将 Transformer 带入实际视觉任务场景。
文章要学习的点挺多。
- Patch Partition 和 Patch Embedding
- Window Multi-head Self Attention (W-MSA)
- Shifted Window Multi-head Self Attention (SW-MSA)
- Patch Merging 的金字塔构建方式
论文链接(arXiv)如下:
https://arxiv.org/abs/2103.14030
PDF 下载地址(直接下载):
https://arxiv.org/pdf/2103.14030.pdf
二、模型
2.1 Patch Partition 和 Patch Embedding
输入图像的维度是:
(B, H, W, C)
- B = batch size
- H, W = 图像的高和宽(比如 224)
- C = 图像通道数(RGB → 3)
Swin 使用 4 × 4
大小的 patch,把图像分块,每个 patch 当成一个 token。
-
划分 patch(Patch Partition)
把整张图像切成
4 × 4
的小块,每个 patch 包含:4(height) × 4(width) × 3(channel) = 48 个数
-
计算 patch 数量
每张图像被切成:
(H / 4) × (W / 4) = 56 × 56
-
reshape 操作
把输入变换为:
(B, 56, 56, 4, 4, 3) → reshape → (B, 56, 56, 48)
-
Linear Projection(Patch Embedding)
用一个线性层把每个 patch 从 48 维投影到更高维,比如 96:
(B, 56, 56, 48) → Linear → (B, 56, 56, 96)
输出:
(B, 56, 56, 96)
- 表示每张图像被编码为 56×56 个 token(patch)
- 每个 token 是一个 96 维的向量
patch之间"拍瘪":
(B, 3136, 96) # 方便送入 Transformer 做 Attention
2.2 Window Multi-head Self Attention (W-MSA)
为什么不用“全局 Self-Attention”?
传统 Transformer 的自注意力是全局的,也就是说,每个 token 都和其他所有 token 做注意力计算。
如果我们现在的输入是:
(B, 3136, 96) # 即 56×56 个 patch,每个 patch 是 96 维向量
那么计算注意力的代价是:
O(N² × D) = O(3136² × 96) # 太大太大了,显存爆炸
解决办法:Window-based Attention(W-MSA)
核心思想:
不在整幅图上计算 attention,只在每个小窗口内进行 self-attention
具体流程
假设输入是:
(B, 56, 56, 96)
-
划分成小窗口(默认窗口大小是 7 × 7)
我们把每张图像划分成不重叠的窗口:- 每个窗口:7×7 = 49 个 patch
- 每张图:56 / 7 = 8 行 × 8 列 → 一共 64 个窗口
-
reshape 成窗口批处理格式
我们将每个窗口提取出来,构造成批处理格式:(B, 56, 56, 96) → reshape → (B * 64, 7 * 7, 96) = (B * 64, 49, 96)
- 每个窗口有 49 个 token(patch)
- 每个 token 是 96 维
- 一共 B×64 个窗口
-
在每个窗口中做 Multi-head Self-Attention
对每个窗口内部做标准的 MSA:(B * 64, 49, 96) ↓ W-MSA(每个窗口 attention) → (B * 64, 49, 96)
-
把所有窗口还原回原图结构
恢复原始的空间布局:(B * 64, 49, 96) → reshape → (B, 56, 56, 96)
总结流程
Input: (B, 56, 56, 96)
↓ 划分 7×7 窗口
→ reshape: (B*64, 49, 96)
↓ 每个窗口内做 MSA
→ (B*64, 49, 96)
↓ reshape 回原图形状
Output: (B, 56, 56, 96)
项目 | 内容 |
---|---|
Attention 类型 | 每个窗口内部做 |
计算复杂度 | 从 O(N²) 降为 O(M × K²),其中 M 是窗口个数,K 是窗口大小 |
优点 | 大幅减少计算量,适配高分辨率图像 |
缺点 | 窗口之间无法直接通信(后面用 SW-MSA 弥补!) |
2.3 Shifted Window Multi-head Self Attention (SW-MSA)
Swin Transformer 的王牌设计
目的:解决 W-MSA 的“信息孤岛”问题
在 W-MSA 中,我们把图像划分成不重叠的小窗口,每个窗口自己做注意力,虽然节省了计算量,但也带来了一个问题:
每个窗口之间完全没交流,只能看到自己这一块。
也就是说,窗口 A 里的 token 不知道窗口 B 发生了什么,这对捕捉全局关系是致命的限制。
SW-MSA 的核心想法:
在下一层,把窗口整体“滑动”一半(shift),让新的窗口能覆盖多个旧窗口,打破原有边界,形成跨窗口信息流动!
具体流程
我们接着上一步,输入仍是:
(B, 56, 56, 96)
-
平移窗口(Shift)
将整张 feature map 在空间维度上平移:
向下平移 3 行、向右平移 3 列(因为窗口大小是 7,所以 7 // 2 = 3)
这时候,原来相邻但是在不同窗口的 patch 被划入了同一个窗口中。
注意:平移过程中会打破原来的窗口对齐方式,有些位置会被“移出边界”,需要后续处理。(在mask中会讲到)
-
遮挡 Mask(mask padding region)
由于平移后窗口不再对齐,需要对超出边界的地方进行“掩码”(mask padding),避免注意力看到无效区域。
论文中的做法是构造一个 attention mask,屏蔽跨 block 的 attention。(之后再讲)
-
对 Shift 后的新窗口做 W-MSA
将新窗口也划为
7 × 7
,即使现在不是对齐划分的,也一样处理:(B, 56, 56, 96) ↓ shift → 平移后的特征图 ↓ window partition → reshape → (B * 64, 49, 96) ↓ attention with mask → (B * 64, 49, 96) ↓ reshape → (B, 56, 56, 96)
-
shift back(还原位移)
Attention 完成后,把平移回来的特征图“再平移回原位置”,确保位置信息不乱套。
总结
Input: (B, 56, 56, 96)
↓ shift → 位移(往右下滑 3)
↓ window partition + mask attention
↓ W-MSA → 每个滑动窗口内部做 MSA
↓ shift back → 恢复原始位置
Output: (B, 56, 56, 96)
对比:W-MSA vs SW-MSA
特性 | W-MSA | SW-MSA |
---|---|---|
窗口划分 | 固定、规则、无重叠 | 滑动、打破边界 |
是否跨窗口交流 | 否 | 是 |
是否需要掩码 | 不需要 | 需要 attention mask |
是否恢复位移 | 不需要 | 需要 shift back |
两者是交替使用的!
每个 Swin Block 都是:
W-MSA → SW-MSA → W-MSA → SW-MSA → ...
这样就能:
- 降低计算量(小窗口)
- 建立全局联系(滑动窗口)
2.3 Patch Merging
Patch Merging 是 Stage 之间的下采样过程,和 CNN 的 pool/stride 类似。
觉得自己讲有点抽象,干脆找了一篇CSDN的博客。
ViT中的上采样和下采样——patch merge
假设当前输入维度是:
(B, H, W, C)
先把图片一分为4,如何分的,可以看上面的文章。图片宽高都减半,将4张子图在通道维拼接。
(B, H, W, C) -> (B, H/2, W/2, 4C)
CV一般“惯例”都是池化的时候,高宽减半,通道翻倍。我们这里翻了4倍,所以做一个全连接。
(B, H/2, W/2, 4C) -> (B, H/2, W/2, 2C)
2.4 金字塔结构
Swin Transformer 不像 ViT 那样一直保持固定大小的 feature,而是像 CNN 一样,每过一个 Stage:
- 减小空间尺寸(下采样)
- 增加通道数(增强表达)
每个 SwinBlock 包含两大部分:
- Window-based Multi-head Self-Attention(W-MSA SW-MSA)模块
- MLP(前馈神经网络)
SwinBlock:
├── LN
├── W-MSA / SW-MSA
├── Dropout + Residual
├── LN
├── MLP (Linear → GELU → Linear)
├── Dropout + Residual
SwinBlock之间通过Patch Merging进行“池化”
Stage 流程总览
每一层结构如下:
[Stage n]
↓ SwinBlock × L
↓ Patch Merging
举个完整流程(Swin-Tiny):
Stage 1:Input : (B, 56, 56, 96)SwinBlock×2 : (B, 56, 56, 96)PatchMerge : (B, 28, 28, 192)Stage 2:SwinBlock×2 : (B, 28, 28, 192)PatchMerge : (B, 14, 14, 384)Stage 3:SwinBlock×6 : (B, 14, 14, 384)PatchMerge : (B, 7, 7, 768)Stage 4:SwinBlock×2 : (B, 7, 7, 768)
最后接上:
↓ Global Average Pooling → (B, 768)
↓ Linear classifier → (B, num_classes)
三、细节
3.1 关于复杂度
这里可以直接看李沐老师的讲论文系列
Swin Transformer:Hierarchical Vision Transformer using Shifted Windows
Self-Attention 的复杂度
项目 | 复杂度表达式 | 说明 |
---|---|---|
Q/K/V 投影 | 3 N D 2 3ND^2 3ND2 | 三个线性映射 |
Q K ⊤ QK^\top QK⊤ | N 2 D N^2D N2D | 所有 head 共享序列长度 N |
Softmax × V | N 2 D N^2D N2D | 同样按每个 head 分别处理 |
输出线性投影 | N D 2 ND^2 ND2 | 拼接后投影回 D 维 |
总计 | 4 N D 2 + 2 N 2 D 4ND^2 + 2N^2D 4ND2+2N2D | 主要瓶颈是 N 2 D N^2D N2D(注意力部分) |
ViT 的 Attention 复杂度(全局)
ViT 默认每张图像切成 16 × 16 16 \times 16 16×16 patch,输入大小为 224 × 224 224 \times 224 224×224:
- 每张图像的 token 数:
N = (224 / 16)^2 = 14 × 14 = 196
- 假设 token embedding 维度 D = 768 D = 768 D=768
代入公式:
C ViT = 4 N D 2 + 2 N 2 D = 4 ⋅ 196 ⋅ 76 8 2 + 2 ⋅ 19 6 2 ⋅ 768 ≈ 451 M + 59 M ≈ 510 M \mathcal{C}_\text{ViT} = 4ND^2 + 2N^2D = 4 \cdot 196 \cdot 768^2 + 2 \cdot 196^2 \cdot 768 ≈ 451M + 59M ≈ 510M CViT=4ND2+2N2D=4⋅196⋅7682+2⋅1962⋅768≈451M+59M≈510M
Swin 的 Attention 复杂度(局部窗口)
Swin 使用 4 × 4 4 \times 4 4×4 patch → 得到 56 × 56 = 3136 56 \times 56 = 3136 56×56=3136 个 patch
窗口大小是 7 × 7 = 49 7 \times 7 = 49 7×7=49
于是我们可以这么计算:
- 每个窗口: K = 49 K = 49 K=49 个 token
- 窗口个数: M = 56 ⋅ 56 / ( 7 ⋅ 7 ) = 3136 / 49 = 64 M = 56 \cdot 56 / (7 \cdot 7) = 3136 / 49 = 64 M=56⋅56/(7⋅7)=3136/49=64
- 每个窗口做 attention,其复杂度是:
C window = 4 K d 2 + 2 K 2 d \mathcal{C}_\text{window} = 4Kd^2 + 2K^2d Cwindow=4Kd2+2K2d
其中 d = D / h d = D / h d=D/h,假设 h = 12 h=12 h=12 → d = 64 d = 64 d=64
那我们代入:
- K = 49 K = 49 K=49
- d = 64 d = 64 d=64
- 每窗口复杂度: 4 ⋅ 49 ⋅ 6 4 2 + 2 ⋅ 4 9 2 ⋅ 64 = 802 , 816 + 307 , 328 = 1 , 110 , 144 4 \cdot 49 \cdot 64^2 + 2 \cdot 49^2 \cdot 64 = 802,816 + 307,328 = 1,110,144 4⋅49⋅642+2⋅492⋅64=802,816+307,328=1,110,144
- 总复杂度 = 窗口数 × 每窗口复杂度:
C Swin = 64 ⋅ 1.11 M ≈ 71 M \mathcal{C}_\text{Swin} = 64 \cdot 1.11M ≈71M CSwin=64⋅1.11M≈71M
3.2 Mask 的具体实现细节
我们要做的,是在 SW-MSA 中生成一个掩码张量,用于在 softmax 前屏蔽非法 attention。这个 mask 形状是:
(1, num_windows, window_size*window_size, window_size*window_size)
- 每个窗口一个 mask
- 每个窗口内的 token 对之间有一个标志:是同一区域还是不同区域
-
给每个 patch 分配区域编号
我们构造一个虚拟的
H × W
的二维矩阵img_mask
,每一块window_size × window_size
区域赋一个唯一的整数 ID,标识它的原始区域编号:img_mask = torch.zeros((1, H, W, 1)) # shape: (1, H, W, 1) cnt = 0 for h in range(0, H, window_size):for w in range(0, W, window_size):img_mask[:, h:h+window_size, w:w+window_size, :] = cntcnt += 1
这样每个
7x7
区块都有唯一编号。 -
对
img_mask
做 shift(滑动窗口)shifted_mask = torch.roll(img_mask, shifts=(-shift_size, -shift_size), dims=(1, 2))
比如
shift_size = 3
时,相当于向上左滑动了 3 个像素。 -
再按窗口划分,并 reshape 成 attention 的输入格式
我们把 shifted 后的 mask 分成窗口,每个窗口是一个 ( w i n d o w _ s i z e × w i n d o w _ s i z e ) (window\_size \times window\_size) (window_size×window_size) 区域:
mask_windows = window_partition(shifted_mask, window_size) # shape: (num_windows, window_size, window_size, 1) mask_windows = mask_windows.view(-1, window_size * window_size)
现在每个窗口内是一个 shape 为
(49,)
的向量,每个值是区域 ID。 -
构造 attention mask(核心)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, 0)
attn_mask[i, j]
表示第 i i i 个 token 和第 j j j 个 token 的差值- 如果来自不同原始区域 → 不等 → 设为
-inf
- 如果来自同一区域 → 相等 → 设为
0
- 如果来自不同原始区域 → 不等 → 设为
最终你就得到一个形状为:
(num_windows, 49, 49)
的 attention mask,在后续计算中会加到注意力 logits 上:
attn_scores = (Q @ K^T) / sqrt(d) attn_scores += attn_mask attn_probs = softmax(attn_scores)
步骤 | 操作 | 目的 |
---|---|---|
1 | 构造区域编号矩阵 img_mask | 标记不同窗口的 token |
2 | 对 img_mask 做 shift | 模拟 SW-MSA 的窗口 |
3 | 按窗口切块 → reshape | 获取窗口内的 token 区域分布 |
4 | 构造差值矩阵 + masked_fill | 设置 cross-window 的 attention 为 -inf |
- 掩码是 与输入内容无关的结构掩码,可以在模型初始化时就构建好(只依赖窗口大小与输入尺寸)
torch.roll
是非常高效的窗口滑动方式- 使用
.masked_fill()
是 PyTorch 中处理 Attention Mask 的标准做法 - 注意:在 Swin 中窗口是 batch-wise共享的,所以 mask 只构建一次即可复用
3.3 下游任务
一、Swin-T 如何接入各种类型的下游任务?
Swin-T 作为视觉 Transformer,可以看作 CNN 的替代 backbone,输出 层次化的特征图,非常适合各种任务中模块化插入。
-
分类任务(Image Classification)
- 取 Stage 4 的输出
(B, 7, 7, 768)
- 做一次 Global Average Pooling
- 接一个
Linear
分类头
- 取 Stage 4 的输出
-
检测任务(Object Detection)
目标检测通常需要多层次特征,因此 Swin 的多 stage 输出非常合适:
将 Swin 的各个 Stage 输出喂入 FPN(Feature Pyramid Network):
Swin stages → [C1, C2, C3, C4]↓FPN↓Detection Head (e.g., RetinaNet, Faster R-CNN)
其中:
Stage Output Shape 用于替代哪层 C1 56 × 56 × 96 ResNet C2 C2 28 × 28 × 192 ResNet C3 C3 14 × 14 × 384 ResNet C4 C4 7 × 7 × 768 ResNet C5 -
语义分割(Semantic Segmentation)
语义分割也需要多尺度语义信息,Swin 可以直接连接到分割头,如:
输出四层
[C1, C2, C3, C4]
,拼接 or 上采样对齐后,接Segmentation Head
Swin Features → UPerNet / FPN Head → Pixel-wise 分类
-
实例分割(Instance Segmentation)
和目标检测类似,通常基于
Mask R-CNN
等结构,只需将 Swin 替代 ResNet:Swin Backbone → FPN → ROIAlign → BBox + Mask Head
-
视频理解(Action Recognition)
Swin 可以扩展到时序维度 → 变为 Swin-Video Transformer:
- 在时间轴上引入 attention(Swin-ViT 3D)
- 用于视频分类、时序分割等任务
二、Swin 的预训练方式
-
ImageNet 预训练(分类任务通用)
-
Masked Image Modeling(自监督)
-
Detection / Segmentation 预训练(针对 COCO / ADE20K)
-
视频预训练(Swin-Video)
四、实验
4.1 主实验结果汇总
模型 | Params | FLOPs | Top-1 Acc (%) | COCO box AP | COCO mask AP | ADE20K mIoU |
---|---|---|---|---|---|---|
Swin-T | 29M | 4.5G | 81.3 | 50.5 | 43.7 | 44.5 |
Swin-S | 50M | 8.7G | 83.0 | 51.8 | 44.7 | 47.6 |
Swin-B | 88M | 15.4G | 83.5 | 51.9 | 45.0 | 48.1 |
Swin-B+384 | 88M | 47.1G | 84.5 | - | - | - |
Swin-L+384* | 197M | 103.9G | 85.2 | 53.5 | 46.2 | 53.5 |
DeiT-B | 86M | 17.6G | 81.8 | 44.9 | 40.9 | 43.3 |
ResNet-50 | 25M | 4.1G | 76.2 | 39.2 | 35.4 | 36.7 |
ResNeXt-101 | 83M | 32.0G | 79.3 | 44.0 | 39.8 | 40.2 |
- *表示使用更大分辨率、更多数据(例如 Swin-L+384 使用 ImageNet-22K + finetune)
- FLOPs 以
ImageNet-1K
输入大小计算- 所有检测/分割任务均在 COCO / ADE20K 上 finetune
4.2 消融实验
改动项 | Top-1 (%) | COCO Box AP | ADE20K mIoU | 说明 |
---|---|---|---|---|
full Swin-T baseline | 81.3 | 50.5 | 44.5 | 原始模型 |
no shifted windows | 80.5 | 48.4 | 41.6 | 不使用 SW-MSA(只有 W-MSA) |
no patch merging | 80.8 | 47.9 | 42.8 | 所有 stage 的分辨率保持不变 |
no relative position bias | 81.0 | 49.4 | 43.4 | 去除相对位置编码 |
use absolute pos embedding | 81.1 | 49.6 | 43.3 | 替换成绝对位置编码 |
single-scale training only | 80.9 | 48.7 | 43.0 | 不使用 multi-scale 训练 |
replace LayerNorm → BatchNorm | 80.7 | 47.8 | 42.0 | 用 BN 替换 LN |
only 1 stage (no hierarchy) | 78.1 | 42.6 | 39.1 | 所有 block 用相同分辨率 |
亮点
模块 | 对性能的提升贡献 |
---|---|
Shifted Window (SW-MSA) | 关键设计,提升 2%+ |
Patch Merging(金字塔) | 提升多尺度建模能力 |
相对位置偏置 | 比绝对位置编码更有效 |
LayerNorm 更适配 Transformer | 替换成 BN 会显著退化 |