当前位置: 首页 > news >正文

第 4 期:DDPM中的损失函数——为什么只预测噪声?

—— 从变分下界到噪声预测

回顾:我们到底在做什么?

在第 3 期中,我们介绍了扩散模型的逆过程建模。简而言之,目标是通过神经网络学习从噪声 x_t​ 中恢复图像 x_0,并且我们通过预测噪声 ϵ来完成这个任务。

今天,我们将深入解析为什么我们仅仅关注噪声预测,以及如何通过损失函数来指导模型的训练。更重要的是,我们将从变分推导的角度,揭示这一做法的理论基础。

损失函数的直观理解

在DDPM中,损失函数的核心是均方误差(MSE)。它是计算网络预测的噪声 ϵ_θ(x_t,t)与真实噪声 ϵ 之间差异的度量。

损失函数表达式

假设给定原图像x_0,我们通过正向扩散过程获得图像在第t步的版本x_t:

其中,ϵ 是标准正态分布噪声。

逆过程目标是通过神经网络 ϵ_θ(x_t,t) 来估计 ϵ。损失函数则是通过计算网络输出与真实噪声之间的均方误差来进行优化:

 

这个损失函数为什么如此简单且有效?接下来,我们从变分下界(ELBO)的角度来理解。

为什么预测噪声而不是直接预测图像?

1. 简化建模问题

考虑到我们每次只需要预测一个特定的时间步(t)的噪声,网络的目标变得更加明确且简单。直接预测图像 x_0 的像素值,则意味着模型需要从噪声中恢复整个图像结构,这在高维空间中是一个非常复杂的问题。

然而,预测噪声本质上是一个去噪过程,这个过程相对更加容易拟合和收敛。

2. 稳定性和收敛性

在扩散模型中,噪声是添加到每个像素上的随机扰动。通过学习从噪声中恢复出原始图像的噪声成分,网络本质上是在学习图像的细节,而不是整个图像结构。因此,通过减少噪声的预测误差,模型能够更加稳定地训练。

从变分下界(ELBO)看损失函数的推导

🧠 变分推导的基础

假设我们有一个潜在的生成过程 p_θ(x_(0:T)),其由正向过程和逆向过程组成。在最大似然学习中,我们希望最大化数据分布 p_θ(x_0):

这个积分通常无法直接计算,因此我们通过变分下界(ELBO)来近似:

其中,KL表示Kullback-Leibler散度,用于衡量两个分布的差异。通过对 KL 散度的优化,我们可以逼近最优的逆向过程。 

具体推导

在DDPM中,逆过程建模为:

为了简化,假设 Σ_θ(x_t,t)为固定值(通常为单位矩阵)。因此,模型仅需预测均值 μ_θ(x_t,t)。

通过对KL散度进行优化,我们得到最终的损失函数:

这个损失函数实际上就是模型预测噪声 ϵ_θ(x_t,t)与真实噪声 ϵ\epsilonϵ 之间的均方误差。 

代码演示:噪声预测与损失函数

为了更好地理解这个过程,我们来实现一个简单的训练循环,展示如何通过损失函数来训练模型。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 定义网络
class DenoiseModel(nn.Module):def __init__(self):super(DenoiseModel, self).__init__()self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1)def forward(self, x):x = self.relu(self.conv1(x))return self.conv2(x)# 损失函数
def noise_loss(x_t, noise, model):predicted_noise = model(x_t)  # 预测噪声return nn.MSELoss()(predicted_noise, noise)  # 计算损失# 示例训练过程
def train(model, dataloader, optimizer):model.train()for x_0, _ in dataloader:t = torch.randint(0, T, (x_0.size(0),), device="cuda")noise = torch.randn_like(x_0)x_t = q_sample(x_0, t, noise)  # 通过正向扩散生成带噪图像optimizer.zero_grad()loss = noise_loss(x_t, noise, model)  # 计算噪声预测损失loss.backward()optimizer.step()print(f"Loss: {loss.item():.4f}")

可视化:损失函数与训练效果

在训练过程中,损失函数会随着时间逐渐下降。我们可以使用以下代码可视化训练过程中预测噪声与真实噪声的差异。

def plot_loss_progression(losses):plt.plot(losses)plt.xlabel('Iterations')plt.ylabel('Loss')plt.title('Training Loss Progression')plt.show()

 

 

小结

关键点内容
损失函数通过均方误差(MSE)计算预测噪声和真实噪声之间的差异
目标学习从噪声中恢复原始图像细节
变分下界通过优化KL散度来推导损失函数
网络设计仅需预测噪声 ϵ 而非图像 x_0​

 

下一期预告(第 5 期):

我们将进入实战部分,介绍如何训练第一个DDPM模型。我们将基于MNIST数据集,带领大家一步步实现从数据加载到训练的完整流程,并展示生成的效果!

 

相关文章:

  • 守护进程及gdb调试(新手简略版)
  • 数控铣床自动上下料机械手控制装置设计
  • python豆包语音合成并播放
  • keil5软件配置以及使用技巧
  • Aladdin显卡多任务运行教程
  • 大模型应用_AutoGPT
  • 软件测试之接口测试详解
  • Linux CentOS 更改MySQL数据库目录位置
  • 生态篇|多总线融合与网关设计
  • 函数与数组---------C语言经典题目(1)
  • Vue 3 计算属性与侦听器深度解析:优雅处理响应式数据
  • ​​Nginx快速入门-3:工作流程和模块化
  • day1-小白学习JAVA(mac版)---(jdk安装和环境变量配置)
  • 【Reading Notes】(8.2)Favorite Articles from 2025 February
  • ModbusTCP 转 Profinet 主站网关
  • 从零实现Git安装、使用
  • FairMOT与MCFairMOT算法对比
  • 【web服务_负载均衡Nginx】二、Nginx 核心技术之负载均衡与反向代理
  • c++ static 和 extern 关键字
  • 论文阅读笔记——Mixtral of Experts
  • 大学2025丨浙大哲学院院长王俊:文科的价值不在于直接创造GDP
  • 美国土安全部长餐厅遇窃,重要证件被盗走
  • 一条水脉串起七个特色区块,上海嘉定发布2025年新城行动方案
  • 商务部:中国加快推进服务业扩大开放综合试点为世界注入更多确定性
  • 视觉周刊|第五届中国国际消费品博览会展现全球合作新格局
  • 为溶血性疾病治疗提供新靶点,专家团队在《细胞》发文