【学习】对抗训练-WGAN
WGAN
- Wasserstein GAN(WGAN)详解:从优化目标到 Lipschitz 约束
- 一、Wasserstein 距离是啥?
- 二、WGAN 的目标函数形式
- 三、为什么要 1-Lipschitz 约束?
- 什么是 Lipschitz 连续?
- 四、WGAN 如何实现 Lipschitz 约束?
- 1. **WGAN with weight clipping(原始 WGAN)**
- 2. **WGAN-GP(Gradient Penalty)**
- 五、WGAN 的优势
- 六、总结
Wasserstein GAN(WGAN)详解:从优化目标到 Lipschitz 约束
在原始 GAN 中,由于使用了 JS 散度,训练过程常常出现 梯度消失 或 模式崩溃(mode collapse)。为了解决这些问题,Wasserstein GAN(WGAN)应运而生,它用更稳定、更有意义的分布距离 —— Wasserstein 距离 来替代 JS 散度。
一、Wasserstein 距离是啥?
Wasserstein-1 距离(又叫地球移动者距离,Earth Mover’s Distance)定义为:
W ( P r , P g ) = inf γ ∈ Π ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}[\|x - y\|] W(Pr,Pg)=γ∈Π(Pr,Pg)infE(x,y)∼γ[∥x−y∥]
- 直观解释:把一个堆土( P g P_g Pg)变成另一个堆土( P r P_r Pr)所需要“最小的工作量”;
- 优点:即使 P r P_r Pr 与 P g P_g Pg 的支持集没有重叠,Wasserstein 距离依然是有意义的(不像 JS 或 KL)。
二、WGAN 的目标函数形式
根据 Kantorovich-Rubinstein 对偶形式,WGAN 把距离 W ( P r , P g ) W(P_r, P_g) W(Pr,Pg) 转化为如下形式进行优化:
max f ∈ F E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] \max_{f \in \mathcal{F}} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)] f∈FmaxEx∼Pr[f(x)]−Ex∼Pg[f(x)]
其中 F \mathcal{F} F 是所有 1-Lipschitz 连续函数的集合。
在实际训练中,WGAN 把判别器 D D D 称为 critic(判别器不再输出概率,而是任意实数),优化目标变为:
min G max D ∈ Lip-1 E x ∼ P r [ D ( x ) ] − E z ∼ P z [ D ( G ( z ) ) ] \min_G \max_{D \in \text{Lip-1}} \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim P_z}[D(G(z))] GminD∈Lip-1maxEx∼Pr[D(x)]−Ez∼Pz[D(G(z))]
三、为什么要 1-Lipschitz 约束?
因为只有在 D D D 属于 1-Lipschitz 函数的前提下,才能保证:
E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( x ) ] = W ( P r , P g ) \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)] = W(P_r, P_g) Ex∼Pr[D(x)]−Ex∼Pg[D(x)]=W(Pr,Pg)
什么是 Lipschitz 连续?
一个函数 f f f 满足:
∣ f ( x ) − f ( y ) ∣ ≤ K ⋅ ∥ x − y ∥ |f(x) - f(y)| \leq K \cdot \|x - y\| ∣f(x)−f(y)∣≤K⋅∥x−y∥
则称它是 K K K-Lipschitz 连续函数。特别地,当 K = 1 K=1 K=1 时,就是我们 WGAN 所要求的 1-Lipschitz 函数。
四、WGAN 如何实现 Lipschitz 约束?
WGAN 提出了几种做法:
1. WGAN with weight clipping(原始 WGAN)
将判别器的参数强制约束在某一范围内,如 [ − 0.01 , 0.01 ] [-0.01, 0.01] [−0.01,0.01],即:
for p in D.parameters():
p.data.clamp_(-0.01, 0.01)
但这种方法会限制模型表达能力,容易导致训练不稳定。
2. WGAN-GP(Gradient Penalty)
改进做法,引入梯度惩罚项:
E x ^ ∼ P x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ \left( \|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1 \right)^2 \right] Ex^∼Px^[(∥∇x^D(x^)∥2−1)2]
其中 x ^ \hat{x} x^ 是在真实数据和生成数据之间插值得到的。
最终优化目标:
E x ∼ P r [ D ( x ) ] − E x ∼ P g [ D ( G ( z ) ) ] + λ ⋅ GP \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(G(z))] + \lambda \cdot \text{GP} Ex∼Pr[D(x)]−Ex∼Pg[D(G(z))]+λ⋅GP
WGAN-GP 是目前最广泛使用的 WGAN 版本。
五、WGAN 的优势
原始 GAN | WGAN |
---|---|
使用 JS 散度,训练易不稳定 | 使用 Wasserstein 距离,更稳定 |
判别器输出为概率值 | 判别器输出为实数 |
可能梯度消失 | 保持梯度非零,更容易训练 |
Mode collapse 严重 | 缓解 mode collapse |
六、总结
WGAN 提供了更合理的 GAN 训练方法,其核心在于:
- 用 Wasserstein 距离替代 JS;
- 强制判别器满足 1-Lipschitz 连续性;
- 使用梯度惩罚等方法约束判别器行为。