从头开始掌握扩散概率模型
理论与数学
扩散模型通过一系列时间步长 T (x₀,xₜ) 逐渐降低图像中的信息量来运作。每一步都会添加少量高斯噪声,最终将图像转换为纯随机噪声,类似于正态分布的样本,这被称为前向过程。从 xₜ₋₁ 到 xₜ 的过渡遵循这种噪声添加机制。
为了扭转这一局面,需要训练神经网络逐步消除噪声。训练完成后,模型可以从正态分布中提取的随机噪声开始。它会迭代地对输入进行去噪,每次迭代都会去除一些噪声,直到最终得到与原始分布相似的清晰图像。
这种方法在概念上与变分自编码器 (VAE)类似。在 VAE 中,图像被编码为高斯分布的均值和方差,然后解码器通过从该分布中采样来重建图像。类似地,扩散模型的去噪过程将随机噪声转换回相干图像,类似于 VAE 中的重建阶段。
扩散是指分子从高浓度区域向低浓度区域移动。从统计学意义上讲,扩散过程是一个随机马尔可夫过程,其特征是连续的样本路径。随机性意味着存在随机性,而马尔可夫则表明未来状态仅取决于当前状态——了解过去状态不会提供更多信息。连续性意味着过程平稳演进,没有突然的跳跃。
在统计学中,扩散描述了在同一域内,复杂分布向更简单分布(通常是先验分布)的变换。如果满足某些条件,对来自任何分布的样本反复应用过渡核,最终都会得到来自该更简单先验分布的样本。在扩散模型中,输入图像代表一个复杂分布,它会逐渐转换为一个简单的正态分布。

这个函数背后发生了什么?
假设系数为⍺和ꞵ,两者相互独立。我们将从左侧给出的分布开始,对图像进行采样并绘制直方图。然后,我们将应用一次过渡步骤(⍺=0.5,ꞵ=0.1),这将导致值发生变化,从而导致直方图也发生变化。
通过应用该方程一次,我们观察到从原始状态逐渐过渡到更随机的状态。当 ⍺ = 0 且 ꞵ = 1 时,我们立即获得高斯分布。然而,目标是通过增量变化逐步实现这种转变,这定义了扩散过程。如果 ⍺ > 1,方差将不受控制地增加,从而阻止收敛到期望分布。因此,对原始值进行一些衰减至关重要。合适的选择是⍺ = 0.999,略低于 1。
接下来,必须逐步引入噪声项 ꞵ。较高的 ꞵ 值会导致分布发生突变,而这正是我们想要避免的。相反,变换应该缓慢进行,以确保随着原始分布的退化,它越来越接近正态分布。最终,我们得到了最优值 ⍺ = √0.99 和 ꞵ = √0.01。
从数学上讲,将项 xₜ₋₁ 和 xₜ₋₂ 代入方程式可以揭示过程中的重复模式,从而强化转变的渐进性。
对于所有时间步长,我们最终得到一个最终方程,其中只有第一项包含 x₀。当 T 足够大时,由于乘积中的所有值都小于 1,因此第一项趋近于零。其余项服从高斯分布,均值为 0,但方差不同。
由于这些项是独立的,它们可以组合成一个高斯分布,整体均值仍然为 0,方差等于各个方差之和。具体来说,最后一项的方差为 β,倒数第二项的方差为 β(1-β),依此类推。这形成了一个几何级数 (GP),其中第一项为 β,公比为 1-β。因此,T 项的和可以表示为:β(1-(1-β)ₜ) / (1-(1-β)) ~ β/β = 1
我们本质上是将原始分布结构被破坏的程度与我们引入的噪声量联系起来。对于一维分布,扩散过程被离散化为有限步骤。这种方法可以扩展到 aw×h,其中前向过程的输出为 aw×h,每个像素类似于均值为 0、方差为 1 的高斯分布样本。
在实践中,我们并非在每一步都使用恒定的噪声方差,而是采用一个调度方案。作者提出了一个线性调度方案,该方案会随时间逐渐增加噪声方差。这种策略合乎逻辑,因为在逆向过程开始时,模型需要学习进行较大的调整。随着图像接近清晰,模型需要进行更小、更精细的调整。
此调度确保方差从输入分布平滑地缩放到目标高斯分布。在噪声方差固定的情况下,分布方差的降低在早期非常显著,并在大约 500 步后达到接近 1。然而,当遵循建议的调度时,方差的降低在大部分时间步长内更加平缓且一致,从而实现更平滑的过渡。
为了获得 t=1000 时 X 的值,我们需要应用转换 1000 次来遍历整个马尔可夫链,这是低效的。

我们需要进行所有这些数学运算,仅仅是为了给输入图像添加噪声,其逆过程也是一个具有高斯转移概率的马尔可夫链。我们无法直接计算它,因为为此我们需要计算整个数据分布,但我们可以用分布 P 来近似它,该分布 P 是一个高斯分布,其均值和方差是我们感兴趣的学习参数。
如果我们不能计算逆分布,我们如何训练一个近似它的模型?
在 VAE 中我们也遇到过类似的情况,即在给定 Z 的情况下我们不知道 X 的真实分布,但我们学习通过神经网络来近似它。因此,我们希望学习 P(X | Z),以便能够生成尽可能接近训练数据分布的图像。
假设 P 服从高斯分布,第一项表示生成的输出与真实图像之间的重建损失。第二项表示先验分布(均值为零、方差为单位的高斯分布)与编码器预测的分布之间的 KL 散度。这个概念可以扩展到扩散过程。
在这种情况下,我们不是直接过渡到Zₜ ,而是逐步通过一系列潜在变量从 x₁ 过渡到 xₜ 。这里,q(z ∣ x) 被 q(xₜ ∣ xₜ₋₁) 取代,后者是固定的,未经学习。与之前一样,目标是最大化观测数据的似然值。
期望中的项,q 是正向过程,P 是反向过程的近似值。由于该链是马尔可夫链,因此 q(xₜ ∣ xₜ₋₁) = q(xₜ ∣ xₜ₋₁, x₀)。
我们来看看分母,应用贝叶斯定理可以得到:
其余项与下界方程相加,我们可以将这个方程分解为三个对数项的和。记住,所有这些都是在 Q 的期望下进行的。
第一项类似于 KL 散度先验项,但在这里,因为我们使用了扩散,所以 q 是固定的,并且根据理论,最终的 q(xₜ ∣ x₀) 实际上将非常接近正态分布,所以这是无参数的,我们不会费心去优化它。第二项是在给定 x₁ 的情况下对输入 x₀ 的重构,最后一项是所有量之和,这些量都是 KL 散度。并且由于我们想要最大化下限,因此我们希望最小化所有这些尺度项,该公式的优点在于最后一项涉及两个相同形式的量,这仅仅要求近似去噪过渡分布非常接近以 x₀ 为条件的真实去噪过渡分布。
我们又回到了同样的问题,因为我们不知道反向分布,所以我们该如何前进?
我们现在有了 q(xₜ₋₁ ∣ xₜ),这在概念上更容易计算。通过应用贝叶斯定理,我们可以将其分解。这个过程中的每个项都是高斯的,我们已经推导出必要的表达式。第一个项代表前向过程,虽然它以 x₀ 为条件,但由于前向过程遵循马尔可夫链,因此它不受影响。
另外两个项可以用之前建立的递归式来表示,这样我们就可以在任意时间步 t 从 x₀ 过渡到噪声图像。由于这些项服从高斯分布,因此可以将其表示为指数形式。主要目标是计算 xₜ₋₁。为了实现这一点,我们可以将表达式重写为完全平方,从而推导出高斯分布的方程。通过这种方法,我们将 x²ₜ₋₁ 项、xₜ₋₁ 项以及所有与 xₜ₋₁ 无关的项分离出来。
到目前为止,我们已经将正态分布表示为指数形式,并使用代数求和将其简化。现在,我们将重点关注最后一个与 xₜ₋₁ 无关的项。进一步简化后,我们发现该项可以分解为一个表达式的平方乘以 2xₜ₋₁。
这样,我们就可以把逆分布的整个方程写成高斯分布,其中均值和方差分别为高斯分布。最后一项需要计算 q(xₜ₋₁ | xₜ₋₁, x₀)。

我们可以看到,平均值是 xₜ 和 x₀ 的加权平均值。如果我们计算 x₀ 的权重,我们会发现它在较高的时间步长下非常低,而随着我们接近逆过程的终点,这个权重会增加。从图中我们可以看到,权重在很长一段时间内为 0,但当我们绘制对数值时,我们可以看到它确实随着时间的推移而增加。
对于逆分布过程,我们必须对其进行近似,因为在生成过程中,我们实际上没有 x₀,但我们知道逆分布是高斯分布,所以我们也可以将近似值作为高斯分布。我们需要做的就是学习均值和方差。作者所做的第一件事就是将方差修正为与真实值去噪步骤完全相同。
似然函数,所有项都是 q 的期望,使其成为 KL 散度,当我们使用高斯的 KL 散度公式时,因为这里两个分布具有完全相同的方差,它最终将是均值之间的差的平方除以方差的两倍,即这个量,因此,由于我们的目标是最大化可能性,我们现在需要减少差异,因为我们的基本事实去噪步骤具有均值。
虽然论文中的损失有所不同,在某些噪声方面有所不同,但这里我们找到了其他的东西。为什么呢?

回顾→在似然项中,我们忽略了第一个项,而专注于最后一个求和项。我们发现它可以写成真实噪声与某个模型使用 xₜ 作为输入生成的噪声预测之间的缩放平方差,实际上我们也提供了时间步长。作者实际上完全忽略了缩放,并通过实验发现,仅基于噪声的平方差训练模型就足够了。
通过让模型近似 x₀ 的噪声以获得 x₁,似然函数中的第二项被包裹在这个损失之下。
训练
在训练过程中,我们首先从数据集中采样一张图像,并均匀地选择一个时间步长 t。接下来,我们从正态分布中采样随机噪声。利用 xₜ 的 x₀ 和 ϵ 方程,我们可以在扩散过程中获得时间步长 t 处图像的噪声版本。然后,我们使用原始图像、采样噪声、时间步长和噪声方案来计算累积乘积项。该噪声图像将通过神经网络,我们使用损失函数训练网络,以最小化预测噪声和实际噪声之间的差异。通过多步训练,我们覆盖了所有时间步长,并有效地优化了求和项的每个组成部分。
对于图像生成,我们遵循相反的过程,从神经网络学习到的去噪步长分布 P 中迭代采样。为了生成图像,我们首先从正态分布中随机抽取一个样本作为步长 t 的初始图像。然后,将其传入我们训练好的模型来预测噪声。需要说明的是,我们的近似去噪分布由时间步长 t 的噪声图像的平均值和预测噪声定义。
一旦有了预测噪声,我们就可以使用均值图像和方差(固定为与前向传播过程中相同的方差)从分布中采样一幅图像。采样后的图像将成为 xₜ₋₁。我们不断重复此过程,直到得到原始图像 x₀。唯一的区别在于,为了从 x1 到 x₀,我们直接返回均值图像。
相关资料+学习资料包↓(或看我个人简介处)