当前位置: 首页 > news >正文

山东大学软件学院ai导论实验之生成对抗网络

目录

实验目的

实验代码

实验内容

实验结果


实验目的

基于Pytorch搭建一个生成对抗网络,使用MNIST数据集。

实验代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os

# 设置环境变量
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 创建保存生成图像的文件夹
output_path = r"xxxxxxxxxxxxxxxxxx"
os.makedirs(output_path, exist_ok=True)


# 生成器网络
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.network(z)
        return img.view(img.size(0), 1, 28, 28)


# 判别器网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.network(img.view(img.size(0), -1))


def generate_and_save_images(generator, test_input, epoch, img_path):
    with torch.no_grad():
        generated_images = generator(test_input).cpu().numpy()

    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for i, ax in enumerate(axes.flat):
        # 将图像从形状 (1, 28, 28) 转换为 (28, 28),去除通道维度
        ax.imshow(np.squeeze(generated_images[i]), cmap='gray')
        ax.axis('off')

    img_filename = os.path.join(img_path, f"generated_epoch_{epoch}.png")
    plt.tight_layout()
    plt.savefig(img_filename)
    plt.close()


# 设置设备(使用GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
lr = 0.0001
batch_size = 128
latent_dim = 100
epochs = 2000

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./MNIST_data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 测试数据:随机噪声作为输入
test_data = torch.randn(batch_size, latent_dim).to(device)

# 初始化生成器和判别器,并定义损失函数和优化器
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# 记录损失
D_losses = []
G_losses = []

# 训练过程
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)

        # 判别器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # 计算损失
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 生成器训练
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # 记录损失
        D_losses.append(d_loss.item())
        G_losses.append(g_loss.item())

        # 打印每2000个步骤的迭代信息
        if (epoch * len(train_loader) + i) % 2000 == 0:
            print(f"Iter: {epoch * len(train_loader) + i}")
            print(f"D_loss: {d_loss.item():.4f}")
            print(f"G_loss: {g_loss.item():.4f}")
    # 每个epoch保存生成的图像
    generate_and_save_images(generator, test_data, epoch, output_path)

    # 保存生成器和判别器的模型
    torch.save(generator.state_dict(), "Generator_mnist.pth")
    torch.save(discriminator.state_dict(), "Discriminator_mnist.pth")

# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator Loss')
plt.plot(G_losses, label='Generator Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.savefig('loss_curve.png')  # 保存图像
plt.show()  # 显示图像

实验内容

1. 数据集加载

与前几次实验一样,本实验仍然使用MNIST数据集作为输入数据集通过torchvision库进行加载并标准化处理,使得图像像素值在[-1, 1]范围内,以适应生成对抗网络的训练要求。

2. 生成器与判别器网络

生成器:生成器网络的任务是生成伪造的图像,以欺骗判别器。输入是一个随机噪声向量(latent vector),输出是一个28x28像素的图像。生成器使用多个全连接层,每个层后面都跟着一个LeakyReLU激活函数,最终输出通过Tanh激活函数确保生成的图像像素值在[-1, 1]范围内。

判别器:判别器网络的任务是区分输入的图像是“真实的”还是“伪造的”。它将图像输入后,通过多个全连接层,最后输出一个介于0和1之间的值,表示图像的真实性。

3. 训练过程

判别器训练:判别器的目标是最大化其准确性,即正确分类真实和伪造的图像。在每次训练中,先计算真实图像的损失,然后计算生成图像的损失,最后将两个损失加权平均得到判别器的总损失。

生成器训练:生成器的目标是最小化判别器对其生成图像的判断错误率。即通过调整其权重,使得生成的图像越来越像真实图像,以此欺骗判别器。生成器的损失函数是判别器对生成图像的输出,标签为“真实”(即1)。

模型优化:使用Adam优化器分别优化生成器和判别器的参数。学习率为0.0001。

  1. 改变隐藏层数

生成器的结构由原来的4个隐藏层缩减为2个隐藏层:

5.生成图像并保存

在每个epoch结束时,使用生成器生成一些图像,并将图像保存为PNG格式文件。每个epoch的图像被保存到指定的文件夹中,以便可视化生成图像的变化。

6. 绘制损失曲线

训练过程中记录并绘制判别器和生成器的损失曲线,以便观察模型的训练进展。

实验结果

迭代得到的训练结果为:

改变隐藏层数得到的部分结果为:

刚开始生成的初始图像为:

运行一段时间后,得到的图像为:

可以明显的看到,随着迭代不断增加,数字越来越清晰,数字识别成功

损失曲线为:

初始:

慢慢的趋于平稳:

相关文章:

  • 【Python爬虫(71)】用Python爬虫解锁教育数据的奥秘
  • obj离线加载(vue+threejs)+apk方式浏览
  • DDNS-GO 动态域名解析
  • 基于YOLO11深度学习的医学X光骨折检测与语音提示系统【python源码+Pyqt5界面+数据集+训练代码】
  • 基于SpringBoot的“洪涝灾害应急信息管理系统”的设计与实现(源码+数据库+文档+PPT)
  • 【Java】I/O 流篇 —— 转换流与序列化流
  • 5分钟学习-什么事前端HTML文件
  • Python 网络爬虫实战全解析:案例驱动的技术探索
  • Linux-IPC-消息队列
  • Java 大视界 -- Java 大数据在智能物流路径规划与车辆调度中的创新应用(102)
  • C# Unity 唐老狮 No.2 模拟面试题
  • 36. Spring Boot 2.1.3.RELEASE 中实现监控信息可视化并添加邮件报警功能
  • 信息系统的安全防护
  • 神经网络 - 激活函数(Sigmoid 型函数)
  • 剑指 Offer II 032. 有效的变位词
  • flask 是如何分发请求的?
  • 机试准备第三天
  • 关于CanvasRenderer.SyncTransform触发调用的机制
  • 04 路由表的IP分组传输过程
  • 【deepseek解决不了的问题】vue2响应式数据在视图改变后被无感置空
  • 事关稳就业稳经济,10张海报看懂这场发布会的政策信号
  • 贸促会答澎湃:5月22日将举办2025年贸易投资促进峰会
  • 香港警务处高级助理处长叶云龙升任警务处副处长(行动)
  • 江西省宁都县政协原二级调研员谢亦礼被查
  • 新华时评:坚定不移办好自己的事,着力抓好“四稳”
  • “谁羽争锋”全国新闻界羽毛球团体邀请赛在厦门开赛