VAE中的编码器(Encoder)详解
VAE中的编码器(Encoder)详解
变分自编码器(Variational Autoencoder, VAE)是一种强大的生成模型,其核心组件之一是编码器(Encoder)。在介绍了重参数化技巧之后,我们可以进一步探讨 VAE 编码器的具体结构和功能。本文将详细讲解 VAE 编码器的工作原理、其数学形式,以及它如何与 ELBO(证据下界)优化目标结合,助力模型训练。
VAE编码器的结构
在 VAE 中,编码器的作用是将输入数据 ( x x x ) 映射到一个潜变量分布 ( q φ ( z ∣ x ) q_φ(z|x) qφ(z∣x) ),从而捕捉数据的潜在表示。通常,我们假设 ( q φ ( z ∣ x ) q_φ(z|x) qφ(z∣x) ) 是一个高斯分布,其均值 ( μ μ μ ) 和方差 ( σ 2 σ^2 σ2 ) 由神经网络输出。具体的形式为:
(
μ
,
σ
2
)
=
EncoderNetwork
φ
(
x
)
(μ, σ^2) = \text{EncoderNetwork}_φ(x)
(μ,σ2)=EncoderNetworkφ(x)
q
φ
(
z
∣
x
)
=
N
(
z
∣
μ
φ
(
x
)
,
σ
φ
2
(
x
)
I
)
q_φ(z|x) = \mathcal{N}(z | μ_φ(x), σ^2_φ(x) I)
qφ(z∣x)=N(z∣μφ(x),σφ2(x)I)
这里的 ( EncoderNetwork φ ( ⋅ ) \text{EncoderNetwork}_φ(\cdot) EncoderNetworkφ(⋅) ) 是一个神经网络,参数化为 ( φ φ φ )。它接受输入 ( x x x )(例如一张图像),输出高斯分布的参数 ( μ φ ( x ) μ_φ(x) μφ(x) ) 和 ( σ φ 2 ( x ) σ^2_φ(x) σφ2(x) )。为了明确这些参数的依赖关系,我们记:
μ = μ φ ( x ) , σ 2 = σ φ 2 ( x ) μ = μ_φ(x), \quad σ^2 = σ^2_φ(x) μ=μφ(x),σ2=σφ2(x)
- ( μ φ ( x ) μ_φ(x) μφ(x) ) 是均值向量,表示潜变量 ( z z z ) 的期望位置。
- ( σ φ 2 ( x ) I σ^2_φ(x) I σφ2(x)I ) 是协方差矩阵,这里假设为对角矩阵,即各维度独立,方差由 ( σ φ 2 ( x ) σ^2_φ(x) σφ2(x) ) 决定。
这种表示强调了 ( μ μ μ ) 和 ( σ 2 σ^2 σ2 ) 是 ( x x x ) 的函数:不同的输入 ( x x x ) 会产生不同的分布参数,而 ( φ φ φ ) 则是控制神经网络权重和偏差的参数。
生成潜变量 ( z z z )
给定第 ( ℓ \ell ℓ ) 个训练样本 ( x ( ℓ ) x^{(\ell)} x(ℓ) ),编码器会输出对应的分布参数 ( μ φ ( x ( ℓ ) ) μ_φ(x^{(\ell)}) μφ(x(ℓ)) ) 和 ( σ φ 2 ( x ( ℓ ) ) σ^2_φ(x^{(\ell)}) σφ2(x(ℓ)) )。潜变量 ( z ( ℓ ) z^{(\ell)} z(ℓ) ) 从该分布中采样:
z ( ℓ ) ∼ N ( z ∣ μ φ ( x ( ℓ ) ) , σ φ 2 ( x ( ℓ ) ) I ) z^{(\ell)} \sim \mathcal{N}(z | μ_φ(x^{(\ell)}), σ^2_φ(x^{(\ell)}) I) z(ℓ)∼N(z∣μφ(x(ℓ)),σφ2(x(ℓ))I)
直接采样高斯分布在神经网络中并不方便,因此我们引入重参数化技巧。
重参数化技巧在编码器中的应用
重参数化技巧将高斯采样的随机性分解为确定性部分和独立噪声部分。对于高维高斯分布 ( z ∼ N ( μ , σ 2 I ) z \sim \mathcal{N}(μ, σ^2 I) z∼N(μ,σ2I) ),我们可以写为:
z = μ + σ ⊙ ϵ , ϵ ∼ N ( 0 , I ) z = μ + σ \odot ϵ, \quad ϵ \sim \mathcal{N}(0, I) z=μ+σ⊙ϵ,ϵ∼N(0,I)
其中 ( ⊙ \odot ⊙ ) 表示逐元素相乘,( ϵ ϵ ϵ ) 是标准正态分布的样本。因此,对于 ( z ( ℓ ) z^{(\ell)} z(ℓ) ),我们有:
z ( ℓ ) = μ φ ( x ( ℓ ) ) + σ φ ( x ( ℓ ) ) ⊙ ϵ , ϵ ∼ N ( 0 , I ) z^{(\ell)} = μ_φ(x^{(\ell)}) + σ_φ(x^{(\ell)}) \odot ϵ, \quad ϵ \sim \mathcal{N}(0, I) z(ℓ)=μφ(x(ℓ))+σφ(x(ℓ))⊙ϵ,ϵ∼N(0,I)
证明:一般协方差矩阵的情况
为了更通用地理解这一变换,我们考虑任意协方差矩阵 ( Σ Σ Σ ) 的情况。假设 ( z ∼ N ( μ , Σ ) z \sim \mathcal{N}(μ, Σ) z∼N(μ,Σ) ),可以通过以下方式采样:
z = μ + Σ 1 / 2 ϵ , ϵ ∼ N ( 0 , I ) z = μ + Σ^{1/2} ϵ, \quad ϵ \sim \mathcal{N}(0, I) z=μ+Σ1/2ϵ,ϵ∼N(0,I)
其中 ( Σ 1 / 2 Σ^{1/2} Σ1/2 ) 是 ( Σ Σ Σ ) 的“平方根”,可以通过特征分解或 Cholesky 分解得到。例如,若 ( Σ = U S U T Σ = U S U^T Σ=USUT )(特征分解),则 ( Σ 1 / 2 = U S 1 / 2 U T Σ^{1/2} = U S^{1/2} U^T Σ1/2=US1/2UT ),( S S S ) 是对角特征值矩阵。关于特征分解,可以参考笔者的另一篇博客:特征分解(Eigen decomposition)在深度学习中的应用与理解
验证其正确性:
-
期望:
E [ z ] = E [ μ + Σ 1 / 2 ϵ ] = μ + Σ 1 / 2 E [ ϵ ] = μ \mathbb{E}[z] = \mathbb{E}[μ + Σ^{1/2} ϵ] = μ + Σ^{1/2} \mathbb{E}[ϵ] = μ E[z]=E[μ+Σ1/2ϵ]=μ+Σ1/2E[ϵ]=μ
因为 ( E [ ϵ ] = 0 \mathbb{E}[ϵ] = 0 E[ϵ]=0 )。 -
协方差:
Cov ( z ) = E [ ( z − μ ) ( z − μ ) T ] = E [ Σ 1 / 2 ϵ ( Σ 1 / 2 ϵ ) T ] = Σ 1 / 2 E [ ϵ ϵ T ] ( Σ 1 / 2 ) T = Σ 1 / 2 I ( Σ 1 / 2 ) T = Σ \text{Cov}(z) = \mathbb{E}[(z - μ)(z - μ)^T] = \mathbb{E}[Σ^{1/2} ϵ (Σ^{1/2} ϵ)^T] = Σ^{1/2} \mathbb{E}[ϵ ϵ^T] (Σ^{1/2})^T = Σ^{1/2} I (Σ^{1/2})^T = Σ Cov(z)=E[(z−μ)(z−μ)T]=E[Σ1/2ϵ(Σ1/2ϵ)T]=Σ1/2E[ϵϵT](Σ1/2)T=Σ1/2I(Σ1/2)T=Σ
因为 ( E [ ϵ ϵ T ] = I \mathbb{E}[ϵ ϵ^T] = I E[ϵϵT]=I )。
当 ( Σ = σ 2 I Σ = σ^2 I Σ=σ2I ) 时,( Σ 1 / 2 = σ I Σ^{1/2} = σ I Σ1/2=σI ),退化为:
z = μ + σ ϵ z = μ + σ ϵ z=μ+σϵ
这正是 VAE 编码器中使用的形式。
编码器与ELBO的关系
VAE 的优化目标是最大化 ELBO,其表达式为:
ELBO ( x ) = E q φ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − D K L ( q φ ( z ∣ x ) ∥ p ( z ) ) \text{ELBO}(x) = \mathbb{E}_{q_φ(z|x)} [\log p_θ(x|z)] - D_{KL}(q_φ(z|x) \| p(z)) ELBO(x)=Eqφ(z∣x)[logpθ(x∣z)]−DKL(qφ(z∣x)∥p(z))
其中:
- 第一项是重构项,衡量生成数据的质量。
- 第二项是先验匹配项,使用 KL 散度 ( D K L ( q φ ( z ∣ x ) ∥ p ( z ) ) D_{KL}(q_φ(z|x) \| p(z)) DKL(qφ(z∣x)∥p(z))) 约束 ( q φ ( z ∣ x ) q_φ(z|x) qφ(z∣x) ) 接近先验分布 ( p ( z ) p(z) p(z) )(通常取 ( p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I) ))。
计算KL散度
对于两个高斯分布 ( N ( μ 0 , Σ 0 ) \mathcal{N}(μ_0, Σ_0) N(μ0,Σ0) ) 和 ( N ( μ 1 , Σ 1 ) \mathcal{N}(μ_1, Σ_1) N(μ1,Σ1) ),KL 散度的解析形式为:具体证明过程请参考笔者的另一篇博客:KL散度在高斯分布中的推导详解
D K L ( N ( μ 0 , Σ 0 ) ∥ N ( μ 1 , Σ 1 ) ) = 1 2 ( Tr ( Σ 1 − 1 Σ 0 ) − d + ( μ 1 − μ 0 ) T Σ 1 − 1 ( μ 1 − μ 0 ) + log det Σ 1 det Σ 0 ) D_{KL}(\mathcal{N}(μ_0, Σ_0) \| \mathcal{N}(μ_1, Σ_1)) = \frac{1}{2} \left( \text{Tr}(Σ_1^{-1} Σ_0) - d + (μ_1 - μ_0)^T Σ_1^{-1} (μ_1 - μ_0) + \log \frac{\det Σ_1}{\det Σ_0} \right) DKL(N(μ0,Σ0)∥N(μ1,Σ1))=21(Tr(Σ1−1Σ0)−d+(μ1−μ0)TΣ1−1(μ1−μ0)+logdetΣ0detΣ1)
在 VAE 中:
- ( q φ ( z ∣ x ) = N ( μ φ ( x ) , σ φ 2 ( x ) I ) q_φ(z|x) = \mathcal{N}(μ_φ(x), σ^2_φ(x) I) qφ(z∣x)=N(μφ(x),σφ2(x)I) ),即 ( μ 0 = μ φ ( x ) μ_0 = μ_φ(x) μ0=μφ(x) ),( Σ 0 = σ φ 2 ( x ) I Σ_0 = σ^2_φ(x) I Σ0=σφ2(x)I )。
- ( p ( z ) = N ( 0 , I ) p(z) = \mathcal{N}(0, I) p(z)=N(0,I) ),即 ( μ 1 = 0 μ_1 = 0 μ1=0 ),( Σ 1 = I Σ_1 = I Σ1=I )。
代入计算:
- ( Σ 1 − 1 = I − 1 = I Σ_1^{-1} = I^{-1} = I Σ1−1=I−1=I ),( Tr ( Σ 1 − 1 Σ 0 ) = Tr ( I ⋅ σ φ 2 ( x ) I ) = Tr ( σ φ 2 ( x ) I ) = σ φ 2 ( x ) ⋅ d \text{Tr}(Σ_1^{-1} Σ_0) = \text{Tr}(I \cdot σ^2_φ(x) I) = \text{Tr}(σ^2_φ(x) I) = σ^2_φ(x) \cdot d Tr(Σ1−1Σ0)=Tr(I⋅σφ2(x)I)=Tr(σφ2(x)I)=σφ2(x)⋅d )(因为迹是标量和)。
- ( ( μ 1 − μ 0 ) T Σ 1 − 1 ( μ 1 − μ 0 ) = ( 0 − μ φ ( x ) ) T I ( 0 − μ φ ( x ) ) = ∥ μ φ ( x ) ∥ 2 (μ_1 - μ_0)^T Σ_1^{-1} (μ_1 - μ_0) = (0 - μ_φ(x))^T I (0 - μ_φ(x)) = \|μ_φ(x)\|^2 (μ1−μ0)TΣ1−1(μ1−μ0)=(0−μφ(x))TI(0−μφ(x))=∥μφ(x)∥2 )。
- ( log det Σ 1 det Σ 0 = log det I det ( σ φ 2 ( x ) I ) = log 1 ( σ φ 2 ( x ) ) d = − d log σ φ 2 ( x ) = − 2 d log σ φ ( x ) \log \frac{\det Σ_1}{\det Σ_0} = \log \frac{\det I}{\det (σ^2_φ(x) I)} = \log \frac{1}{(σ^2_φ(x))^d} = -d \log σ^2_φ(x) = -2d \log σ_φ(x) logdetΣ0detΣ1=logdet(σφ2(x)I)detI=log(σφ2(x))d1=−dlogσφ2(x)=−2dlogσφ(x) )(假设 ( σ φ ( x ) > 0 σ_φ(x) > 0 σφ(x)>0 ))。
- ( − d -d −d ) 项保持不变。
因此:
D K L ( q φ ( z ∣ x ) ∥ p ( z ) ) = 1 2 ( σ φ 2 ( x ) d − d + ∥ μ φ ( x ) ∥ 2 − 2 d log σ φ ( x ) ) D_{KL}(q_φ(z|x) \| p(z)) = \frac{1}{2} \left( σ^2_φ(x) d - d + \|μ_φ(x)\|^2 - 2d \log σ_φ(x) \right) DKL(qφ(z∣x)∥p(z))=21(σφ2(x)d−d+∥μφ(x)∥2−2dlogσφ(x))
梯度计算
KL 散度的梯度关于 ( φ φ φ ) 为:
∇ φ D K L ( q φ ( z ∣ x ) ∥ p ( z ) ) = 1 2 ∇ φ ( σ φ 2 ( x ) d − d + ∥ μ φ ( x ) ∥ 2 − 2 d log σ φ ( x ) ) \nabla_φ D_{KL}(q_φ(z|x) \| p(z)) = \frac{1}{2} \nabla_φ \left( σ^2_φ(x) d - d + \|μ_φ(x)\|^2 - 2d \log σ_φ(x) \right) ∇φDKL(qφ(z∣x)∥p(z))=21∇φ(σφ2(x)d−d+∥μφ(x)∥2−2dlogσφ(x))
由于 ( μ φ ( x ) μ_φ(x) μφ(x) ) 和 ( σ φ ( x ) σ_φ(x) σφ(x) ) 是神经网络的输出,梯度没有闭合形式,但可以通过自动微分数值计算。关于 ( θ θ θ ) 的梯度为零,因为 KL 散度不依赖解码器参数 ( θ θ θ )。
总结
VAE 的编码器通过神经网络 ( EncoderNetwork φ ( x ) \text{EncoderNetwork}_φ(x) EncoderNetworkφ(x) ) 输出高斯分布参数 ( μ φ ( x ) μ_φ(x) μφ(x) ) 和 ( σ φ 2 ( x ) σ^2_φ(x) σφ2(x) ),结合重参数化技巧生成潜变量 ( z z z )。它不仅为后续解码提供了输入,还通过 KL 散度约束潜变量分布与先验的接近性。编码器的设计体现了 VAE 在概率建模与神经网络优化之间的巧妙结合,是理解 VAE 训练过程的关键一环。
希望这篇博客能帮助你深入理解 VAE 编码器!
参考
https://arxiv.org/pdf/2403.18103
后记
2025年3月3日20点00分于上海,在grok 3大模型辅助下完成。