当前位置: 首页 > news >正文

【NeurlPS 2024】MAR:无矢量量化的自回归图像生成

1 前言

​ 本期内容,我们将MAR(掩码自回归模型)。与传统的视觉自回归需要把图像离散化成一个个token的方式不同。MAR构建的是一个连续型的自回归模型。
视频:MAR:无矢量量化的自回归图像生成

2 引入

​ 对于传统的自回归模型而言,比如VQGAN+GPT2,一般来说是需要先用VQGAN把图像量化成1D的token。然后再使用仅解码器的Transformer进行图像生成,比如GPT2。而本篇论文提出了一个问题——“图像自回归的离散化是否是必须的?

​ 论文认为这并不是必须的,因此,他们便想去构建一个连续型的视觉自回归模型。如何去构建呢?在我们的印象中,什么样的东西是直接对连续型数据直接建模的呢?用脚想一下,扩散模型不就是吗?让我们先来回顾一下传统的离散视觉自回归模型

3 离散视觉自回归模型

​ 对于离散视觉自回归模型,往往我们会把图像离散化成一个个的token,比如我们离散化成n个token—— { x 1 , x 2 , ⋯ , x n } \{x^1,x^2,\cdots,x^n\} {x1,x2,,xn}。里面的每一个token可以表示一个个的整数。一般来说,视觉自回归使用的是的仅解码器架构,所以注意力使用的是因果注意力,我们可以把这些token的关系分解为
p ( x 1 , x 2 , ⋯ , x n ) = ∏ i = 1 n p ( x i ∣ x 1 , ⋯ , x i − 1 ) p(x^1,x^2,\cdots,x^n)=\prod\limits_{i=1}^np(x^i|x^1,\cdots,x^{i-1}) p(x1,x2,,xn)=i=1np(xix1,,xi1)
​ 通俗来讲,就是当前时刻的状态,仅依赖于过去的所有时刻,而与未来无关。

​ 为了与论文保持一致的结果,接下来我们定义 x x x为我们下一个要预测的token。记词汇表的大小为K(表示 x x x的取值只有K种,即 0 ≤ x ≤ K 0\le x\le K 0xK)。自回归模型进行生成的时候,会先生成一个连续型的D维向量,即 z ∈ R D z\in \mathbb{R}^D zRD。接着,会使用一个投影矩阵 W ∈ R K × D W\in \mathbb{R}^{K\times D} WRK×D z z z投影到K个类别大小,然后使用使用softmax将其转换为概率,我们可以把该过程用概率去表示 p ( x ∣ z ) = softmax ( W z ) p(x|z)=\text{softmax}(Wz) p(xz)=softmax(Wz)

​ 所以目前来说,我们认为下一个要预测的token x x x ,可以通过建模概率分布 p ( x ∣ z ) p(x|z) p(xz)来获得,离散自回归是将其表示为 softmax ( W z ) \text{softmax}(Wz) softmax(Wz)。那么问题就来了,我们是否可以用其他的方法去获得这个概率分布呢?

​ 当然可以!我们之前所学习过的生成模型,不就是能够学习概率分布吗。假设我们能够学习出 p ( x ∣ z ) p(x|z) p(xz),那么当我们提供 z z z之后,便可获得 x x x,也就是生成 x x x。如果我们用其他生成模型去建模 p ( x ∣ z ) p(x|z) p(xz),比如扩散模型,那么也就没有离散化的硬性要求了,因为扩散模型本身就可以处理连续型的数据。

4 Diffusion Loss

​ 我们用扩散模型去拟合 p ( x ∣ z ) p(x|z) p(xz)这个分布,拟合生成之后,给定一个 z z z,我们就可以获得 x x x了。

训练:

在这里插入图片描述

​ 即图中所示,先获取自回归模型的 z z z,把他作为条件,然后利用 z z z x x x去训练扩散模型。即
L ( z , x ) = E ϵ , t [ ∥ ϵ − ϵ θ ( x t ∣ t , z ) ∥ 2 ] (1) \mathcal{L}(z,x) = \mathbb{E}_{\epsilon,t}\left[\Vert \epsilon - \epsilon_{\theta}(x_t|t,z) \Vert^2 \right]\tag{1} L(z,x)=Eϵ,t[ϵϵθ(xtt,z)2](1)
​ 其中, ϵ \epsilon ϵ是一个高斯白噪声, x t = α ˉ t x + 1 − α ˉ t ϵ x_t = \sqrt{\bar\alpha_t}x+\sqrt{1-\bar\alpha_t}\epsilon xt=αˉt x+1αˉt ϵ α ˉ t \bar\alpha_t αˉt是噪声调度,t是噪声调度的时间步。如图中所示,扩散模型的网络结构只是一个小型的MLP网络, ϵ θ ( x t ∣ t , z ) \epsilon_{\theta}(x_t|t,z) ϵθ(xtt,z)表示将 x t x_t xt作为网络的主输入,而 t t t z z z作为旁系条件输入。里面的 z z z是由自回归网络 f f f生成的,即 z = f ( ⋅ ) z=f(\cdot) z=f()

​ 注意到期望不仅对 ϵ \epsilon ϵ求,还对时刻t求了。为什么呢?论文提到,他们设计的去噪网络的结构很小,因此他们对于给定一个编码向量 z z z,会采样多个t。这有助于提高损失函数的利用率,而无需重新计算z。在训练得时候,每张图像抽样4次。

采样:
x t − 1 = 1 α ˉ t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t ∣ t , z ) ) + σ t δ x_{t-1} = \frac{1}{\sqrt{\bar\alpha_t}}\left( x_t -\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t|t,z) \right)+\sigma_t\delta xt1=αˉt 1(xt1αˉt 1αtϵθ(xtt,z))+σtδ
​ 其中, δ \delta δ是高斯白噪声, σ t \sigma_t σt是时间步t的噪声水平。初始化 x T ∼ N ( 0 , I ) x_T\sim N(0,I) xTN(0,I),然后条件去噪得到 x 0 x_0 x0,即得到 x 0 ∼ p ( x ∣ z ) x_0\sim p(x|z) x0p(xz)。去噪的时候,会引入一个温度系数 T \mathcal{T} T,用于缩放 σ t δ \sigma_t\delta σtδ。直观上,就是通过调整方差来控制样本多样性。

5 Diffusion Loss for Autoregressive Models

​ 给定序列图像划分成一堆token { x 1 , x 2 , ⋯ , x n } \{x^1,x^2,\cdots,x^n\} {x1,x2,,xn},我们的目标为
p ( x 1 , ⋯ , x n ) = ∏ i = 1 n p ( x i ∣ x 1 , ⋯ , x i − 1 ) p(x^1,\cdots,x^n)=\prod\limits_{i=1}^np(x^i|x^1,\cdots,x^{i-1}) p(x1,,xn)=i=1np(xix1,,xi1)
​ 请注意,这里的 x i x^i xi是连续的,我们记自回归网络为 f f f。在预测 x i x^i xi的时候,它依赖于 x 1 x^1 x1 x i − 1 x^{i-1} xi1时刻的状态,所以,我们把他们送给自回归网络 f f f以获得 z i : z i = f ( x 1 , … , x i − 1 ) z^i:z^i=f(x^1,\dots,x^{i-1}) zi:zi=f(x1,,xi1)。然后,利用扩散模型计算 p ( x i ∣ z i ) p(x^i|z^i) p(xizi)来生成 x i x^i xi

6 双向注意力可以执行自回归

​ 传统观念认为,只有因果注意力才能够执行自回归生成图像。而本文设计了一个双向注意力的自回归模型。此模型是基于MAE(掩码自编码器),运行流程难以描述出来,见视频当中。

在这里插入图片描述

7 MAR(掩码自回归模型)

​ 掩码自回归模型能够不仅可以预测下一个token,还可以直接预测下一个token集合。
p ( x 1 , ⋯ , x n ) = p ( X 1 , ⋯ , X K ) = ∏ k K p ( X k ∣ X 1 , ⋯ , X k − 1 ) p(x^1,\cdots,x^n)=p(X^1,\cdots,X^K)=\prod\limits_{k}^K p (X^k|X^1,\cdots,X^{k-1}) p(x1,,xn)=p(X1,,XK)=kKp(XkX1,,Xk1)
​ 其中, X k = { x i , x i + 1 , ⋯ , x j } X^k=\{x^i,x^{i+1},\cdots,x^j\} Xk={xi,xi+1,,xj},表示第k步要预测的token集合。见图

在这里插入图片描述

8 结束

​ 好了,本期内容到此为止,如有问题,还望指出,阿里嘎多!

在这里插入图片描述

相关文章:

  • 5G融合消息PaaS项目深度解析 - Java架构师面试实战
  • Adruino:人机界面及接口技术
  • 【数据结构与算法】从完全二叉树到堆再到优先队列
  • 【Redis——通用命令】
  • 【Linux应用】交叉编译环境配置,以及最简单粗暴的环境移植(直接从目标板上复制)
  • goweb-signup注册功能实现
  • xVerify:推理模型评估的革新利器,重塑LLM答案验证格局?
  • 《TCP/IP详解 卷1:协议》之第七、八章:Ping Traceroute
  • 【Web应用服务器_Tomcat】二、Tomcat 核心配置与集群搭建
  • 【高频考点精讲】第三方库安全审计:如何避免引入带漏洞的npm包
  • 机器学习之一:机械式学习
  • CentOS 如何使用截图工具截取命令行操作的图片?
  • 计算机网络 | 应用层(1)--应用层协议原理
  • 数据结构和算法(八)--2-3查找树
  • 【学习笔记】Stata
  • hot100—5.盛水最多的容器
  • 一些常见的资源池管理、分布式管理和负载均衡的监控工具
  • 物联网安全运营概览
  • Spring Boot 应用运行指南
  • C++武功秘籍 | 入门知识点
  • 证据公布!菲律宾6人非法登上铁线礁活动
  • 中国公民在日本被机动车碾压身亡,我使馆发布提醒
  • 东风着陆场做好各项搜救准备,迎接神舟十九号航天员天外归来
  • 杭州6宗涉宅用地收金125.76亿元,萧山区地块楼面价冲破5万元/平米
  • 外交部:印度香客赴中国西藏神山圣湖朝圣将于今年夏季恢复
  • 四川苍溪县教育局通报“工作人员辱骂举报学生”:停职检查