当前位置: 首页 > news >正文

ViT 模型讲解

文章目录

    • 一、模型的诞生背景
      • 1.1 背景
      • 1.2 ViT 的提出(2020年)
    • 二、模型架构
      • 2.1 patch
      • 2.2 模型结构
        • 2.2.1 数据 shape 变化
        • 2.2.2 代码示例
        • 2.2.3 模型结构图
      • 2.3 关于空间信息
    • 三、实验
      • 3.1 主要实验
      • 3.2 消融实验
    • 四、先验问题
      • 4.1 归纳偏置
      • 4.2 先验or大数据???
    • 五、补充
      • 5.1 二维的位置编码
      • 5.2 挖坑


Vision Transformer(ViT)是一个开创性的计算机视觉模型,它首次成功地将Transformer架构引入到视觉领域,打破了长期以来以卷积神经网络(CNN)为主导的局面。


一、模型的诞生背景

1.1 背景

  1. Transformer 在 NLP 的成功
    Transformer 架构由 Vaswani 等人在 2017 年提出(论文《Attention is All You Need》),它在 NLP 中(比如 BERT、GPT)表现非常出色。

  2. CV 领域的主流仍是 CNN
    视觉领域长期由 CNN 主导(如 ResNet、EfficientNet 等),这些模型通过局部卷积提取空间信息。

  3. 挑战:Transformer 在图像上的应用
    图像不像文本那样是离散的序列,Transformer 需要序列化输入,因此直接用 Transformer 处理图像并不直观。

1.2 ViT 的提出(2020年)

论文:“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”
作者:Google Research
意义:首次证明了,在足够的数据和算力支持下,Transformer 可以在图像分类任务中超过最好的 CNN


二、模型架构

2.1 patch

假设输入是一批图像,维度为:

(B, H, W, C)
  • B = batch size
  • H = 高(比如 224)
  • W = 宽(比如 224)
  • C = 通道数(比如 RGB → 3)

我们想让 Transformer 接收这个图像,但问题:Transformer 要求输入是序列: 每个 token 有一个 embedding,形状应该是:

(B, N, D)
  • N = 序列长度(token 数)
  • D = 每个 token 的 embedding 维度

那我们能不能把图像 reshape 成这种形状呢?
尝试直接 reshape:把每个像素当作一个 token?

(B, H, W, C) → reshape → (B, H*W, C)

比如说:

B = 32
H = W = 224
C = 3(32, 224*224, 3) = (32, 50176, 3)

问题:序列长度 N = 50,176,太长了!

  • Transformer 的注意力是 O(N²),即 O(50176²) ≈ 2.5 billion,显存、计算量巨大!
  • 实际上,BERT 这样的 NLP 模型一般处理的序列长度才 512 左右。

这就是 ViT 的关键创新:

把图像划分为 Patch,每个 Patch 当成一个 token。

我们将图像划分为大小为 P x P 的 patch,比如 16 x 16

  • 则每个 patch 的像素数为:P*P*C = 16*16*3 = 768

  • 每张图像总共包含 patch 数:

    N = (H / P) * (W / P) = (224 / 16)^2 = 14 * 14 = 196
    
  • 所以 reshape 后的数据维度是:

    (B, 14, 14, 16, 16, 3) -> (B, 14*14, 16*16*3)
    

2.2 模型结构

设定一个典型 ViT 的输入参数如下:

  • 图像大小:224 x 224 x 3(高×宽×通道)
  • Patch 大小:16 x 16
  • 输出维度(embedding dim):768
  • Patch 数量:(224 / 16)^2 = 14 x 14 = 196
  • Transformer 层数:12
  • Head 数:12
  • MLP Head 输出类别数:1000(如 ImageNet)
2.2.1 数据 shape 变化

数据 shape 流程(标准 ViT)

Input Image: (3, 224, 224)
 ↓
1. Split into Patches (16x16) → 每个 patch 展平成向量:
    得到 (196, 3*16*16) = (196, 768)
 ↓
2. Linear Projection of Patches → Embedding:
    每个 patch 向量映射为 (196, 768)
 ↓
3. 加入 CLS token:
    + 1 个 token,shape 变为 (197, 768)
 ↓
4. 加入 Position Embedding:
    shape 仍是 (197, 768)
 ↓
5. 输入 Transformer Encoder(L 层):
    每层输出:仍是 (197, 768)
    Attention → FFN → Add & Norm
 ↓
6. 取 CLS token 表示:
    (1, 768)
 ↓
7. 分类头(MLP Head):
    Linear(768 → 1000)
    得到 (1, 1000)
 ↓
8. softmax 输出:
    1000 类别概率

架构图(图示)

                   ┌─────────────────────┐
     Image         │  Input Image        │
  (3, 224, 224)    │   (3 x 224 x 224)   │
                   └────────┬────────────┘
                            ↓
                ┌────────────────────────┐
                │ Split into Patches     │
                │   16x16 patches        │
                │   196 patches total    │
                └────────┬───────────────┘
                         ↓
             ┌─────────────────────────────┐
             │ Flatten + Linear Projection │
             │ Each patch → (1 x 768)      │
             │ → (196 x 768)               │
             └────────┬────────────────────┘
                      ↓
            ┌──────────────────────────┐
            │ Add CLS Token (1 x 768)  │
            │ → (197 x 768)            │
            └────────┬─────────────────┘
                     ↓
         ┌───────────────────────────────┐
         │ Add Positional Embedding      │
         │ → (197 x 768)                 │
         └────────┬──────────────────────┘
                  ↓
        ┌────────────────────────────┐
        │ L x Transformer Encoder    │
        │ (Self-Attn + FFN)          │
        │ Output: (197 x 768)        │
        └────────┬───────────────────┘
                 ↓
        ┌──────────────────────────────┐
        │ Take [CLS] token (1 x 768)   │
        └────────┬─────────────────────┘
                 ↓
        ┌─────────────────────────────┐
        │ MLP Head → Class logits     │
        │ Linear: 768 → 1000          │
        └────────┬────────────────────┘
                 ↓
             Output: (1 x 1000)
2.2.2 代码示例

PyTorch 实现的示例

# 假设输入图像 batch 为 (B, 3, 224, 224)
B = 32
patch_size = 16
embed_dim = 768
num_patches = (224 // patch_size) ** 2  # 14 * 14 = 196

# 1. Patch embedding
x = torch.randn(B, 3, 224, 224)
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(B, 3, -1, patch_size, patch_size)
patches = patches.permute(0, 2, 1, 3, 4).contiguous().view(B, num_patches, -1)  # (B, 196, 768)

# 2. Linear projection + Add cls + pos embedding → (B, 197, 768)
# 3. Transformer blocks → 输出 shape 不变 (B, 197, 768)
# 4. MLP Head → (B, 1000)
2.2.3 模型结构图

在这里插入图片描述


2.3 关于空间信息

ViT 把每个 patch flatten 后再当作 token 输入,是否会丢失空间信息?模型还能理解图像的空间结构吗?

我们来逐步拆解这个问题。

原图 → (B, 224, 224, 3)
patch → 划成 (B, 14, 14, 16, 16, 3)
flatten → (B, 196, 768)   ← 每个 patch 被完全拉平

也就是说:

  • 每个 16x16 patch 被视为一个 768 维向量,送进 Transformer。
  • 在 patch 内部,像素之间的 2D 空间关系已经丢失了(被 flatten)。

这其实是一个 trade-off:局部 vs 全局

CNN 是局部建模:

  • 用卷积核滑动 → 保留局部空间结构
  • 多层堆叠 → 扩展感受野

ViT 是全局建模:

  • 每个 patch 被视作一个整体 token,不知道 patch 内部的位置信息
  • 但可以通过 self-attention 在patch 之间建模全局关系

ViT 是如何弥补空间结构缺失的?

  1. 位置编码(Positional Embedding)

    虽然 patch 本身是被 flatten 的,但 ViT 给每个 patch 加了一个位置向量:

    patch_embeddings + positional_embeddings
    
    • 这些 positional embeddings 是可学习参数,形状为 (num_patches, embed_dim),比如 (196, 768)
    • 它告诉模型:第几个 patch 是图像的哪一部分(左上、右下、中间…)
  2. Self-Attention 建模 patch 间的关系

    虽然每个 patch 内部的空间结构被 flatten 掉了,但 ViT 有全局的 self-attention:

    • 每个 patch 可以“看到”所有其他 patch
    • 自注意力机制可以自动学习到:哪些 patch 是连续的、有相似内容的、有上下文关系的

patch之间有位置编码,那patch内部的空间信息就真的没了吗?

基本上……是的,ViT 在标准设计中:
完全忽略了 patch 内部像素的空间结构(除了间接靠大量数据让模型学到某些模式)

这就是为什么后续很多改进型 ViT 试图引入更精细的空间建模,比如:

  1. CNN + ViT 结合
    比如 CvT, ConViT:在 embedding 之前加一层卷积提取局部特征。

  2. 局部注意力 / 层次结构
    比如 Swin Transformer:只在局部窗口内做注意力,保留 patch 内空间结构。

  3. 使用 Patch Token Hierarchy
    把图像分层次进行划分,像 CNN 的 downsampling 一样逐层抽象。

  4. 位置编码增强
    有些 ViT 用坐标编码(x, y)拼进去,或者使用相对位置编码(像 NLP 中那样)。

问题是否存在弥补方式
Patch 内部 flatten 导致空间信息丢失存在无法恢复(在标准 ViT 中)
Patch 之间的空间结构被建模位置编码 + attention
模型理解图像结构的能力部分靠 attention 建模更依赖大规模预训练数据

三、实验

3.1 主要实验

实验变量设置结果/结论
ViT vs ResNet模型结构ViT-B/16 vs ResNet152x4在 ImageNet-21k 上 ViT > ResNet,参数更少,性能更高
不同预训练数据数据规模ImageNet (1M) vs IN-21k (14M) vs JFT-300M (300M)ViT 在大数据集上才训练得好,小数据集会过拟合
小样本迁移能力Fine-tune 到小数据集CIFAR, Flowers, VTABViT 表现良好,迁移性强
与其他模型迁移对比方法对比ResNet vs ViTViT 表现更好,但需要预训练支持

3.2 消融实验

  1. Patch Size 影响
  • Patch size: 8x8、16x16、32x32
  • 越小越好(序列变长,表达力强)
  • 但计算量 ↑,内存占用 ↑
  • 16x16 是最佳平衡

结论:小 patch 能提升性能,但要付出代价。

  1. 位置编码
  • 加 vs 不加
  • 绝对位置 vs 相对位置
  • 插值位置编码支持不同分辨率

结论:必须加位置编码,且可插值适配不同输入大小。

  1. MLP Head 设计
  • Linear head vs 多层 MLP
  • 结果几乎一样

结论:简单的 linear classifier 足以。

  1. LayerNorm 放置位置
  • Pre-LN(LN 在 Attention 前) vs Post-LN
  • Pre-LN 更稳定,训练更好收敛

结论:Pre-LN 更适合 ViT。

  1. Class Token 的作用
  • 类似 BERT 的 [CLS]
  • 每层 attention 都能访问,用于最终分类

结论:必须有,模型依赖它来提取全局特征。

  1. 是否需要强数据增强 / 自监督
  • 加了 CutMix、RandAug 等 → 效果提升小
  • ViT 不如 CNN 那么依赖数据增强

结论:数据增强对 ViT 不敏感,但对 CNN 很关键。

  1. 是否使用卷积替代 patch embedding
  • 用卷积提 patch embedding 没有显著提升

结论:线性投影就可以,无需 CNN 特征提取头。

关键有效设计

组件作用是否关键
Patch embedding降低序列长度1
Positional Encoding建模空间顺序1
Class Token输出图像全局表征1
Self-Attention全局建模1
Pre-LN 架构稳定训练1
大规模数据预训练避免过拟合111

可有可无的点

组件作用结论
强数据增强抗扰动、泛化提升小,可有可无
MLP Head增加判别能力线性 head 足够
卷积前置层提取局部特征变化不大

四、先验问题

4.1 归纳偏置

先验(prior)”在深度学习里,其实就是“归纳偏置(inductive bias)”的一个通俗说法。

归纳偏置(Inductive Bias):指的是一个学习算法在数据不足时,如何做出合理泛化的偏好或假设

简化理解:

当我看不到所有数据的时候,我默认事情是这样的 —— 这就是我的归纳偏置。

举几个通俗例子:

模型归纳偏置 / 先验
CNN局部连接、权重共享、平移不变性(空间先验)
RNN/LSTM时间顺序相关性
GNN图结构邻居影响中心点
Transformer无特别强的归纳偏置(完全数据驱动)

先验 = 归纳偏置 = 模型对现实世界的一种结构性假设


4.2 先验or大数据???

ViT 完全基于 Transformer,没有 CNN 的结构先验:

  • 不知道局部区域重要;
  • 不知道平移后的图还是同一个图;
  • 所以 ViT 必须通过大量数据自己“悟出这些结构”;
  • 而 CNN 天生就有这些“归纳偏置”植入其中。

ViT 证明了一件事:在足够大、足够多样的数据下,模型可以学出先验,从而超越“手工嵌入先验”的模型。

在 ViT 的实验中:

数据集模型表现
ImageNet-1kViT 差于 CNN
ImageNet-21kViT ≈ CNN
JFT-300MViT > CNN(显著)

这清楚地说明了:

  • 在小数据(ImageNet-1k)下,CNN 的归纳偏置(如局部性、平移不变)提供了巨大优势
  • 在大数据(JFT-300M)下,ViT 能通过学习自动获得这些归纳偏置(甚至更多),从而反超 CNN;
  • 也就是说:数据能“弥补”先验的缺失,甚至最终超越它。

所以,ViT 是一个“用海量数据和大模型学到归纳偏置”的典范。

但这并不是说“先验没用了”!

ViT 不是在否定先验的价值,而是在提出一个新的平衡:

归纳偏置(先验)和数据之间,是可以相互替代的,但各有代价。

| 有强归纳偏置(CNN) | 少数据、高效训练、泛化强,但灵活性差 |
| 无归纳偏置(ViT) | 灵活、能力上限高,但数据、算力需求大 |

所以:

  • CNN 很适合“小数据集、低资源”的任务;
  • ViT 更适合“数据丰富、资源充足”的大规模学习;
  • ViT 的出现是因为 Google 拥有海量数据和TPU,大部分人是玩不起纯 ViT 的。

ViT 后续的发展说明了:先验和数据,是可以结合的

很多 ViT 变体其实都在 “加回先验”:

  • Swin Transformer:加入了局部窗口(locality),有点像滑动卷积核;
  • CvT、LeViT:在前几层用卷积替代patch embedding;
  • DeiT + distillation:用CNN当teacher,把先验“蒸馏”给ViT;
  • Token2Token ViT:加了局部token聚合结构;
  • ConvNeXt:是个有Transformer设计灵感的CNN,重新思考CNN架构;

这些模型的核心目标:

找到“结构先验”和“灵活模型”之间最优的折中点。


五、补充

5.1 二维的位置编码

论文中虽然使用的是一维的位置编码(形状是 ( N \times D )),但其内部其实是用 二维网格的位置(行、列)来生成的。这种方式可以分两种做法:

方法一:直接学习一个二维位置编码表

  1. 假设有一个 ( h \times w ) 的patch grid,例如 ( 14 \times 14 );
  2. 可以学习一个形状为 ( h \times w \times D ) 的可训练tensor;
  3. 然后 reshape 为 ( N \times D )(即 ( 196 \times 768 ));
  4. 加到 patch embedding 上。

方法二:分离式编码(Horizontal + Vertical)

这是某些 ViT 变种(比如 Axial Attention、Swin Transformer、BEiT 等)采用的方式。

  • 学两个 embedding:行方向的位置编码 h × D/2 和列方向的位置编码 w × D/2
  • 对于每个位置 (i,j),位置编码是两个向量之和或拼接:
    PE i , j = RowPE i ∣ ∣ ColPE j \text{PE}_{i,j} = \text{RowPE}_i || \text{ColPE}_j PEi,j=RowPEi∣∣ColPEj
  • 再 reshape 成一维序列。

这种方式的好处是更结构化、可扩展、可插值(这在fine-tuning时很有用)。


5.2 挖坑

众所周知,ViT挖了很多坑。
众所周知,CV界是出了名的卷。坑多,填坑的也多。
略微总结:

坑号挖坑主题被谁填了?
1去CNN化Swin、PVT、CvT、LeViT、ConvNeXt
2位置编码不泛化相对PE、Swin移动窗口、Focal、Twins
3局部性完全缺失T2T、DeiT distill、RegionViT、MobileViT
4训练资源极高DeiT、MAE、DINO、BEiT、自监督系列
5没有多尺度层次PVT、Swin、SegFormer、Pix2Seq
6归纳偏置之争整个 community 都在参与
7图像语言统一建模范式CLIP、BLIP、GPT-4V、Flamingo

相关文章:

  • 【Java八股】
  • 3.2.2.2 Spring Boot配置视图控制器
  • 机器学习项目三:颜色检测
  • Java老鼠迷宫(递归)---案例来自韩顺平老师讲Java
  • Neo4j GDS-11-neo4j GDS 库中相似度算法实现
  • 鸿蒙开发-ArkUi控件使用
  • 重学Redis:Redis常用数据类型+存储结构(源码篇)
  • 5.5 GitHub数据秒级分析核心揭秘:三层提示工程架构设计解析
  • 日志文件爆满_配置使用logback_只保留3天日志文件_每天定时生成一个日志文件---SpringCloud工作笔记206
  • 如何制定有效的风险应对计划
  • C++ std::string_view介绍及性能提升分析
  • android面试情景题详解:android如何处理断网、网络切换或低速网络情况下的业务连续性
  • 关于SENSOR 720P/1080P 静电保护方案
  • Python静态方法和类方法详解
  • 在断网的时候,websocket 一直在CLOSING 状态
  • 如何制定合理的项目预算
  • Docker详细使用
  • Windows 系统如何使用Redis 服务
  • 什么是分布式声波传感
  • 性能炸裂的数据可视化分析工具:DataEase!
  • 普京呼吁乌方响应和平倡议,称将分析民用设施停火提议
  • 商务部24日下午将举行发布会,介绍近期商务领域重点工作情况
  • 全国登记在册民营企业超过5700万户,占企业总量92.3%
  • 9厘米,25克!最小最轻的无线陆空两栖机器人来了
  • 对话地铁读书人|科研服务者岳先生:地铁适合浅阅读
  • 寻找“香奈儿”代工厂