第 7 期:DDPM 采样提速方案:从 DDPM 到 DDIM
本期关键词:采样加速、DDIM 推导、可控性提升、伪逆过程、代码实战
前情回顾:DDPM 的采样瓶颈
在前几期中,我们构建了一个完整的 DDPM 生成流程。但是你可能已经发现:
生成一张图像太慢了!!!
原因是:
DDPM 要在 T 个时间步中一步步地去噪,从 x_T → x_0
。而通常 T 至少为 1000,采样一次就意味着 1000 次前向推理,非常耗时!
目标:更快的采样方法!
本期,我们引入一种“非随机”的采样机制 —— DDIM(Denoising Diffusion Implicit Models)。
它能在 保留图像质量的同时,将采样步骤从 1000 步减少到几十步!
比如 T=1000 → 50
,加速 20 倍+
数学推导:DDIM 与 DDPM 的关系
DDPM 复习公式
我们知道在 DDPM 中,每一步的去噪过程是:
其中 z
是随机噪声。DDIM 做的事就是:
去掉这一步的随机性,将采样变为 确定性过程!
DDIM 推导核心公式
这里的 x_0
是模型预测的原始图像,通过 x_0 = (x_t - √(1 - ᾱ_t) * ε) / √(ᾱ_t)
得到。
直观理解:DDIM 是一种“伪逆”的过程,保留了模型预测的主导性。
PyTorch 实现 DDIM 推理过程
我们只需要修改之前的采样函数,引入 DDIM:
@torch.no_grad()
def ddim_sample(model, img_size=32, num_samples=16, ddim_steps=50, device='cuda'):model.eval()step_size = T // ddim_stepsx_t = torch.randn(num_samples, 3, img_size, img_size).to(device)for i in range(0, T, step_size):t = torch.full((num_samples,), T - 1 - i, device=device, dtype=torch.long)alpha = alphas_cumprod[t].to(device)[:, None, None, None]sqrt_alpha = torch.sqrt(alpha)sqrt_one_minus_alpha = torch.sqrt(1 - alpha)with torch.no_grad():epsilon = model(x_t, t.float())x_0_pred = (x_t - sqrt_one_minus_alpha * epsilon) / sqrt_alphax_0_pred = x_0_pred.clamp(-1, 1)next_t = torch.full((num_samples,), max(t[0] - step_size, 0), device=device, dtype=torch.long)next_alpha = alphas_cumprod[next_t].to(device)[:, None, None, None]x_t = torch.sqrt(next_alpha) * x_0_pred + torch.sqrt(1 - next_alpha) * epsilonreturn x_0_pred
生成样本可视化
samples = ddim_sample(model, num_samples=16, ddim_steps=50)
samples = (samples.clamp(-1, 1) + 1) / 2grid = torchvision.utils.make_grid(samples, nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title("DDIM Fast Sampling Result")
plt.show()
运行效果图示例:
DDPM vs DDIM 对比
项目 | DDPM | DDIM |
---|---|---|
是否随机 | ✅ 是 | ❌ 否 |
是否严格等价 | ✅ 是 | ❌ 不是(近似) |
是否可控(重建) | ❌ 否 | ✅ 是 |
采样速度 | 慢(1000步) | 快(<50步) |
图像质量 | 高 | 接近 DDPM |
✅ 总结
在本期中,我们完成了:
-
✅ DDIM 理论推导;
-
✅ DDIM PyTorch 实现;
-
✅ CIFAR-10 样本生成展示;
-
✅ 与 DDPM 的对比分析。
第 8 期预告:条件生成!
下一期我们将引入 类条件 DDPM,尝试生成某个指定类别的图像(如飞机、青蛙、猫等)!实现“我想生成第几类”的定向控制!