ViT 模型讲解
文章目录
- 一、模型的诞生背景
- 1.1 背景
- 1.2 ViT 的提出(2020年)
- 二、模型架构
- 2.1 patch
- 2.2 模型结构
- 2.2.1 数据 shape 变化
- 2.2.2 代码示例
- 2.2.3 模型结构图
- 2.3 关于空间信息
- 三、实验
- 3.1 主要实验
- 3.2 消融实验
- 四、先验问题
- 4.1 归纳偏置
- 4.2 先验or大数据???
- 五、补充
- 5.1 二维的位置编码
- 5.2 挖坑
Vision Transformer(ViT)是一个开创性的计算机视觉模型,它首次成功地将Transformer架构引入到视觉领域,打破了长期以来以卷积神经网络(CNN)为主导的局面。
一、模型的诞生背景
1.1 背景
-
Transformer 在 NLP 的成功
Transformer 架构由 Vaswani 等人在 2017 年提出(论文《Attention is All You Need》),它在 NLP 中(比如 BERT、GPT)表现非常出色。 -
CV 领域的主流仍是 CNN
视觉领域长期由 CNN 主导(如 ResNet、EfficientNet 等),这些模型通过局部卷积提取空间信息。 -
挑战:Transformer 在图像上的应用
图像不像文本那样是离散的序列,Transformer 需要序列化输入,因此直接用 Transformer 处理图像并不直观。
1.2 ViT 的提出(2020年)
论文:“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”
作者:Google Research
意义:首次证明了,在足够的数据和算力支持下,Transformer 可以在图像分类任务中超过最好的 CNN。
二、模型架构
2.1 patch
假设输入是一批图像,维度为:
(B, H, W, C)
- B = batch size
- H = 高(比如 224)
- W = 宽(比如 224)
- C = 通道数(比如 RGB → 3)
我们想让 Transformer 接收这个图像,但问题:Transformer 要求输入是序列: 每个 token 有一个 embedding,形状应该是:
(B, N, D)
- N = 序列长度(token 数)
- D = 每个 token 的 embedding 维度
那我们能不能把图像 reshape 成这种形状呢?
尝试直接 reshape:把每个像素当作一个 token?
(B, H, W, C) → reshape → (B, H*W, C)
比如说:
B = 32
H = W = 224
C = 3
→ (32, 224*224, 3) = (32, 50176, 3)
问题:序列长度 N = 50,176,太长了!
- Transformer 的注意力是 O(N²),即
O(50176²) ≈ 2.5 billion
,显存、计算量巨大! - 实际上,BERT 这样的 NLP 模型一般处理的序列长度才 512 左右。
这就是 ViT 的关键创新:
把图像划分为 Patch,每个 Patch 当成一个 token。
我们将图像划分为大小为 P x P
的 patch,比如 16 x 16
。
-
则每个 patch 的像素数为:
P*P*C = 16*16*3 = 768
-
每张图像总共包含 patch 数:
N = (H / P) * (W / P) = (224 / 16)^2 = 14 * 14 = 196
-
所以 reshape 后的数据维度是:
(B, 14, 14, 16, 16, 3) -> (B, 14*14, 16*16*3)
2.2 模型结构
设定一个典型 ViT 的输入参数如下:
- 图像大小:
224 x 224 x 3
(高×宽×通道)- Patch 大小:
16 x 16
- 输出维度(embedding dim):
768
- Patch 数量:
(224 / 16)^2 = 14 x 14 = 196
- Transformer 层数:
12
- Head 数:
12
- MLP Head 输出类别数:
1000
(如 ImageNet)
2.2.1 数据 shape 变化
数据 shape 流程(标准 ViT)
Input Image: (3, 224, 224)
↓
1. Split into Patches (16x16) → 每个 patch 展平成向量:
得到 (196, 3*16*16) = (196, 768)
↓
2. Linear Projection of Patches → Embedding:
每个 patch 向量映射为 (196, 768)
↓
3. 加入 CLS token:
+ 1 个 token,shape 变为 (197, 768)
↓
4. 加入 Position Embedding:
shape 仍是 (197, 768)
↓
5. 输入 Transformer Encoder(L 层):
每层输出:仍是 (197, 768)
Attention → FFN → Add & Norm
↓
6. 取 CLS token 表示:
(1, 768)
↓
7. 分类头(MLP Head):
Linear(768 → 1000)
得到 (1, 1000)
↓
8. softmax 输出:
1000 类别概率
架构图(图示)
┌─────────────────────┐
Image │ Input Image │
(3, 224, 224) │ (3 x 224 x 224) │
└────────┬────────────┘
↓
┌────────────────────────┐
│ Split into Patches │
│ 16x16 patches │
│ 196 patches total │
└────────┬───────────────┘
↓
┌─────────────────────────────┐
│ Flatten + Linear Projection │
│ Each patch → (1 x 768) │
│ → (196 x 768) │
└────────┬────────────────────┘
↓
┌──────────────────────────┐
│ Add CLS Token (1 x 768) │
│ → (197 x 768) │
└────────┬─────────────────┘
↓
┌───────────────────────────────┐
│ Add Positional Embedding │
│ → (197 x 768) │
└────────┬──────────────────────┘
↓
┌────────────────────────────┐
│ L x Transformer Encoder │
│ (Self-Attn + FFN) │
│ Output: (197 x 768) │
└────────┬───────────────────┘
↓
┌──────────────────────────────┐
│ Take [CLS] token (1 x 768) │
└────────┬─────────────────────┘
↓
┌─────────────────────────────┐
│ MLP Head → Class logits │
│ Linear: 768 → 1000 │
└────────┬────────────────────┘
↓
Output: (1 x 1000)
2.2.2 代码示例
PyTorch 实现的示例
# 假设输入图像 batch 为 (B, 3, 224, 224)
B = 32
patch_size = 16
embed_dim = 768
num_patches = (224 // patch_size) ** 2 # 14 * 14 = 196
# 1. Patch embedding
x = torch.randn(B, 3, 224, 224)
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(B, 3, -1, patch_size, patch_size)
patches = patches.permute(0, 2, 1, 3, 4).contiguous().view(B, num_patches, -1) # (B, 196, 768)
# 2. Linear projection + Add cls + pos embedding → (B, 197, 768)
# 3. Transformer blocks → 输出 shape 不变 (B, 197, 768)
# 4. MLP Head → (B, 1000)
2.2.3 模型结构图
2.3 关于空间信息
ViT 把每个 patch flatten 后再当作 token 输入,是否会丢失空间信息?模型还能理解图像的空间结构吗?
我们来逐步拆解这个问题。
原图 → (B, 224, 224, 3)
patch → 划成 (B, 14, 14, 16, 16, 3)
flatten → (B, 196, 768) ← 每个 patch 被完全拉平
也就是说:
- 每个
16x16
patch 被视为一个 768 维向量,送进 Transformer。 - 在 patch 内部,像素之间的 2D 空间关系已经丢失了(被 flatten)。
这其实是一个 trade-off:局部 vs 全局
CNN 是局部建模:
- 用卷积核滑动 → 保留局部空间结构
- 多层堆叠 → 扩展感受野
ViT 是全局建模:
- 每个 patch 被视作一个整体 token,不知道 patch 内部的位置信息
- 但可以通过 self-attention 在patch 之间建模全局关系
ViT 是如何弥补空间结构缺失的?
-
位置编码(Positional Embedding)
虽然 patch 本身是被 flatten 的,但 ViT 给每个 patch 加了一个位置向量:
patch_embeddings + positional_embeddings
- 这些 positional embeddings 是可学习参数,形状为
(num_patches, embed_dim)
,比如(196, 768)
- 它告诉模型:第几个 patch 是图像的哪一部分(左上、右下、中间…)
- 这些 positional embeddings 是可学习参数,形状为
-
Self-Attention 建模 patch 间的关系
虽然每个 patch 内部的空间结构被 flatten 掉了,但 ViT 有全局的 self-attention:
- 每个 patch 可以“看到”所有其他 patch
- 自注意力机制可以自动学习到:哪些 patch 是连续的、有相似内容的、有上下文关系的
patch之间有位置编码,那patch内部的空间信息就真的没了吗?
基本上……是的,ViT 在标准设计中:
完全忽略了 patch 内部像素的空间结构(除了间接靠大量数据让模型学到某些模式)
这就是为什么后续很多改进型 ViT 试图引入更精细的空间建模,比如:
-
CNN + ViT 结合
比如CvT
,ConViT
:在 embedding 之前加一层卷积提取局部特征。 -
局部注意力 / 层次结构
比如Swin Transformer
:只在局部窗口内做注意力,保留 patch 内空间结构。 -
使用 Patch Token Hierarchy
把图像分层次进行划分,像 CNN 的 downsampling 一样逐层抽象。 -
位置编码增强
有些 ViT 用坐标编码(x, y)拼进去,或者使用相对位置编码(像 NLP 中那样)。
问题 | 是否存在 | 弥补方式 |
---|---|---|
Patch 内部 flatten 导致空间信息丢失 | 存在 | 无法恢复(在标准 ViT 中) |
Patch 之间的空间结构 | 被建模 | 位置编码 + attention |
模型理解图像结构的能力 | 部分靠 attention 建模 | 更依赖大规模预训练数据 |
三、实验
3.1 主要实验
实验 | 变量 | 设置 | 结果/结论 |
---|---|---|---|
ViT vs ResNet | 模型结构 | ViT-B/16 vs ResNet152x4 | 在 ImageNet-21k 上 ViT > ResNet,参数更少,性能更高 |
不同预训练数据 | 数据规模 | ImageNet (1M) vs IN-21k (14M) vs JFT-300M (300M) | ViT 在大数据集上才训练得好,小数据集会过拟合 |
小样本迁移能力 | Fine-tune 到小数据集 | CIFAR, Flowers, VTAB | ViT 表现良好,迁移性强 |
与其他模型迁移对比 | 方法对比 | ResNet vs ViT | ViT 表现更好,但需要预训练支持 |
3.2 消融实验
- Patch Size 影响
- Patch size: 8x8、16x16、32x32
- 越小越好(序列变长,表达力强)
- 但计算量 ↑,内存占用 ↑
- 16x16 是最佳平衡
结论:小 patch 能提升性能,但要付出代价。
- 位置编码
- 加 vs 不加
- 绝对位置 vs 相对位置
- 插值位置编码支持不同分辨率
结论:必须加位置编码,且可插值适配不同输入大小。
- MLP Head 设计
- Linear head vs 多层 MLP
- 结果几乎一样
结论:简单的 linear classifier 足以。
- LayerNorm 放置位置
- Pre-LN(LN 在 Attention 前) vs Post-LN
- Pre-LN 更稳定,训练更好收敛
结论:Pre-LN 更适合 ViT。
- Class Token 的作用
- 类似 BERT 的 [CLS]
- 每层 attention 都能访问,用于最终分类
结论:必须有,模型依赖它来提取全局特征。
- 是否需要强数据增强 / 自监督
- 加了 CutMix、RandAug 等 → 效果提升小
- ViT 不如 CNN 那么依赖数据增强
结论:数据增强对 ViT 不敏感,但对 CNN 很关键。
- 是否使用卷积替代 patch embedding
- 用卷积提 patch embedding 没有显著提升
结论:线性投影就可以,无需 CNN 特征提取头。
关键有效设计
组件 | 作用 | 是否关键 |
---|---|---|
Patch embedding | 降低序列长度 | 1 |
Positional Encoding | 建模空间顺序 | 1 |
Class Token | 输出图像全局表征 | 1 |
Self-Attention | 全局建模 | 1 |
Pre-LN 架构 | 稳定训练 | 1 |
大规模数据预训练 | 避免过拟合 | 111 |
可有可无的点
组件 | 作用 | 结论 |
---|---|---|
强数据增强 | 抗扰动、泛化 | 提升小,可有可无 |
MLP Head | 增加判别能力 | 线性 head 足够 |
卷积前置层 | 提取局部特征 | 变化不大 |
四、先验问题
4.1 归纳偏置
“先验(prior)”在深度学习里,其实就是“归纳偏置(inductive bias)”的一个通俗说法。
归纳偏置(Inductive Bias):指的是一个学习算法在数据不足时,如何做出合理泛化的偏好或假设。
简化理解:
当我看不到所有数据的时候,我默认事情是这样的 —— 这就是我的归纳偏置。
举几个通俗例子:
模型 | 归纳偏置 / 先验 |
---|---|
CNN | 局部连接、权重共享、平移不变性(空间先验) |
RNN/LSTM | 时间顺序相关性 |
GNN | 图结构邻居影响中心点 |
Transformer | 无特别强的归纳偏置(完全数据驱动) |
先验 = 归纳偏置 = 模型对现实世界的一种结构性假设
4.2 先验or大数据???
ViT 完全基于 Transformer,没有 CNN 的结构先验:
- 不知道局部区域重要;
- 不知道平移后的图还是同一个图;
- 所以 ViT 必须通过大量数据自己“悟出这些结构”;
- 而 CNN 天生就有这些“归纳偏置”植入其中。
ViT 证明了一件事:在足够大、足够多样的数据下,模型可以学出先验,从而超越“手工嵌入先验”的模型。
在 ViT 的实验中:
数据集 | 模型表现 |
---|---|
ImageNet-1k | ViT 差于 CNN |
ImageNet-21k | ViT ≈ CNN |
JFT-300M | ViT > CNN(显著) |
这清楚地说明了:
- 在小数据(ImageNet-1k)下,CNN 的归纳偏置(如局部性、平移不变)提供了巨大优势;
- 在大数据(JFT-300M)下,ViT 能通过学习自动获得这些归纳偏置(甚至更多),从而反超 CNN;
- 也就是说:数据能“弥补”先验的缺失,甚至最终超越它。
所以,ViT 是一个“用海量数据和大模型学到归纳偏置”的典范。
但这并不是说“先验没用了”!
ViT 不是在否定先验的价值,而是在提出一个新的平衡:
归纳偏置(先验)和数据之间,是可以相互替代的,但各有代价。
| 有强归纳偏置(CNN) | 少数据、高效训练、泛化强,但灵活性差 |
| 无归纳偏置(ViT) | 灵活、能力上限高,但数据、算力需求大 |
所以:
- CNN 很适合“小数据集、低资源”的任务;
- ViT 更适合“数据丰富、资源充足”的大规模学习;
- ViT 的出现是因为 Google 拥有海量数据和TPU,大部分人是玩不起纯 ViT 的。
ViT 后续的发展说明了:先验和数据,是可以结合的
很多 ViT 变体其实都在 “加回先验”:
- Swin Transformer:加入了局部窗口(locality),有点像滑动卷积核;
- CvT、LeViT:在前几层用卷积替代patch embedding;
- DeiT + distillation:用CNN当teacher,把先验“蒸馏”给ViT;
- Token2Token ViT:加了局部token聚合结构;
- ConvNeXt:是个有Transformer设计灵感的CNN,重新思考CNN架构;
这些模型的核心目标:
找到“结构先验”和“灵活模型”之间最优的折中点。
五、补充
5.1 二维的位置编码
论文中虽然使用的是一维的位置编码(形状是 ( N \times D )),但其内部其实是用 二维网格的位置(行、列)来生成的。这种方式可以分两种做法:
方法一:直接学习一个二维位置编码表
- 假设有一个 ( h \times w ) 的patch grid,例如 ( 14 \times 14 );
- 可以学习一个形状为 ( h \times w \times D ) 的可训练tensor;
- 然后 reshape 为 ( N \times D )(即 ( 196 \times 768 ));
- 加到 patch embedding 上。
方法二:分离式编码(Horizontal + Vertical)
这是某些 ViT 变种(比如 Axial Attention、Swin Transformer、BEiT 等)采用的方式。
- 学两个 embedding:行方向的位置编码
h × D/2
和列方向的位置编码w × D/2
; - 对于每个位置 (i,j),位置编码是两个向量之和或拼接:
PE i , j = RowPE i ∣ ∣ ColPE j \text{PE}_{i,j} = \text{RowPE}_i || \text{ColPE}_j PEi,j=RowPEi∣∣ColPEj - 再 reshape 成一维序列。
这种方式的好处是更结构化、可扩展、可插值(这在fine-tuning时很有用)。
5.2 挖坑
众所周知,ViT挖了很多坑。
众所周知,CV界是出了名的卷。坑多,填坑的也多。
略微总结:
坑号 | 挖坑主题 | 被谁填了? |
---|---|---|
1 | 去CNN化 | Swin、PVT、CvT、LeViT、ConvNeXt |
2 | 位置编码不泛化 | 相对PE、Swin移动窗口、Focal、Twins |
3 | 局部性完全缺失 | T2T、DeiT distill、RegionViT、MobileViT |
4 | 训练资源极高 | DeiT、MAE、DINO、BEiT、自监督系列 |
5 | 没有多尺度层次 | PVT、Swin、SegFormer、Pix2Seq |
6 | 归纳偏置之争 | 整个 community 都在参与 |
7 | 图像语言统一建模范式 | CLIP、BLIP、GPT-4V、Flamingo |