使用PyTorch实现MNIST数据集的GAN网络

发布于:2025-04-01 ⋅ 阅读:(17) ⋅ 点赞:(0)

使用PyTorch实现MNIST数据集的GAN网络

1. GAN网络简介

GAN(生成对抗网络)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)两个网络组成。生成器负责生成逼真的假数据,判别器负责区分真实数据和生成的假数据。通过两个网络的对抗训练,最终生成器能够生成高质量的图像。

2. 网络结构设计

2.1 生成器网络

生成器网络采用了从低维潜在空间到高维图像空间的渐进式生成结构。下面详细说明每一层的结构和维度变化:

class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        
        self.main = nn.Sequential(
            # 输入: [batch_size, latent_dim]
            # 将潜在向量映射到高维特征空间
            nn.Linear(latent_dim, 256 * 7 * 7),  # 输出: [batch_size, 12544]
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),
            
            # 重塑为卷积层输入格式
            nn.Flatten(0, -1),
            nn.Unflatten(0, (-1, 256, 7, 7)),  # 输出: [batch_size, 256, 7, 7]
            
            # 转置卷积层逐步上采样
            # 输入: [batch_size, 256, 7, 7]
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 输出: [batch_size, 128, 14, 14]
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # 输入: [batch_size, 128, 14, 14]
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 输出: [batch_size, 64, 28, 28]
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # 最终输出层
            # 输入: [batch_size, 64, 28, 28]
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),  # 输出: [batch_size, 1, 28, 28]
            nn.Tanh()
        )

生成器网络的关键设计要点:

  1. 使用线性层将潜在向量(100维)映射到高维特征空间(12544维)
  2. 通过转置卷积实现特征图尺寸的逐步放大:7x7 -> 14x14 -> 28x28
  3. 每层后使用批归一化和ReLU激活函数,其中:
    • BatchNorm1d用于线性层输出的一维特征
    • BatchNorm2d用于卷积层输出的二维特征图
  4. 最后使用Tanh激活函数将输出压缩到[-1,1]范围,与MNIST数据集归一化后的范围匹配
  5. 转置卷积的参数选择:
    • kernel_size=4:常用于上采样,提供足够的感受野
    • stride=2:实现特征图尺寸的翻倍
    • padding=1:保持特征图边缘信息

2.2 判别器网络

判别器网络采用典型的CNN分类器结构,通过逐层降维和特征提取来判断输入图像的真伪:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main = nn.Sequential(
            # 输入: [batch_size, 1, 28, 28]
            # 第一层卷积
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # 输出: [batch_size, 64, 14, 14]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            
            # 输入: [batch_size, 64, 14, 14]
            # 第二层卷积
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 输出: [batch_size, 128, 7, 7]
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            
            # 展平层和全连接层
            # 输入: [batch_size, 128, 7, 7]
            nn.Flatten(),  # 输出: [batch_size, 128 * 7 * 7]
            nn.Linear(128 * 7 * 7, 1),  # 输出: [batch_size, 1]
            nn.Sigmoid()
        )

判别器网络的关键设计要点:

  1. 使用卷积层逐步提取图像特征,实现空间降维:28x28 -> 14x14 -> 7x7
  2. 采用LeakyReLU激活函数(斜率为0.2)防止梯度消失:
    • 相比ReLU,负值输入时仍有较小梯度
    • 有助于生成器获得更好的梯度信息
  3. 使用Dropout层(概率0.3)防止过拟合:
    • 随机丢弃30%的神经元
    • 提高模型泛化能力
    • 特别重要,因为判别器容易过于强大
  4. 最后使用Sigmoid函数将输出压缩到[0,1]区间,表示真实图像的概率
  5. 卷积层参数设计:
    • kernel_size=4:与生成器对称
    • stride=2:实现特征图尺寸减半
    • padding=1:保持边缘信息

2.3 损失函数设计

GAN的损失函数基于最小最大博弈(Minimax Game)理论,数学表达式如下:

m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] min_G max_D V(D, G) = E_{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_z(z)}[\log(1 - D(G(z)))] minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

下面对式子中各项进行详细解释:

  • 极小极大博弈部分: m i n G max ⁡ D V ( D , G ) min_G \max_D V(D, G) minGmaxDV(D,G)这是一个极小极大博弈的表述形式。其中 G G G 代表生成器(Generator), D D D 代表判别器(Discriminator)。
  • 对于判别器 D D D,目标是最大化价值函数 V ( D , G ) V(D, G) V(D,G),即尽可能准确地区分真实数据和生成器生成的虚假数据;对于生成器 G G G,目标是最小化价值函数 V ( D , G ) V(D, G) V(D,G),也就是让生成的数据尽可能逼真,骗过判别器。
期望项一: E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] E_{x \sim p_{data}(x)}[\log D(x)] Expdata(x)[logD(x)]
  • E E E 表示期望(Expectation)。
  • x ∼ p d a t a ( x ) x \sim p_{data}(x) xpdata(x) 表示 x x x 是从真实数据分布 p d a t a ( x ) p_{data}(x) pdata(x) 中采样得到的样本。
  • D ( x ) D(x) D(x) 是判别器对真实样本 x x x的输出,这里使用 l o g D ( x ) log D(x) logD(x) 作为损失项,其目的是让判别器在面对真实数据时,输出的值尽可能接近 1(因为对数函数 l o g log log 在自变量接近 1 时取得较大的值),即判别器能够正确地将真实数据识别为真实的。
期望项二: E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] E_{z \sim p_z(z)}[\log(1 - D(G(z)))] Ezpz(z)[log(1D(G(z)))]
  • z ∼ p z ( z ) z \sim p_z(z) zpz(z)表示 z 是从噪声分布 p z ( z ) p_z(z) pz(z)(通常是简单的分布,如正态分布或均匀分布)中采样得到的噪声向量。
  • G ( z ) G(z) G(z) 是生成器将噪声向量 z 作为输入生成的样本。
  • D ( G ( z ) ) D(G(z)) D(G(z)) 是判别器对生成样本 (G(z)) 的输出。 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z))) 这一项对于生成器来说,希望 D ( G ( z ) ) D(G(z)) D(G(z)) 尽可能小(接近 0),这样 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z))) 就会取得较大的值(因为当 D ( G ( z ) ) D(G(z)) D(G(z)) 接近 0 时, 1 − D ( G ( z ) ) 1 - D(G(z)) 1D(G(z)) 接近 1,对数函数 l o g log log 在自变量接近 1 时取得较大的值),也就是让生成器生成的样本能够骗过判别器;对于判别器来说,希望 D ( G ( z ) ) D(G(z)) D(G(z)) 尽可能大(接近 1),这样 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z)))就会较小,从而使判别器能够正确识别出生成样本是虚假的。

整个目标函数通过这种对抗的方式,促使生成器不断提高生成数据的质量,判别器不断提高区分真实数据和虚假数据的能力 ,最终达到一种动态的平衡状态。

在实际训练中,我们将这个目标分解为两个部分:

  1. 判别器损失:
d_loss_real = criterion(output_real, label_real)  # 真实图像判别损失
d_loss_fake = criterion(output_fake, label_fake)  # 生成图像判别损失
d_loss = d_loss_real + d_loss_fake  # 总判别器损失
  1. 生成器损失:
g_loss = criterion(output_fake, label_real)  # 生成器希望生成的图像被判别为真

这种损失函数设计确保了:

  • 判别器努力提高对真实和生成图像的区分能力
  • 生成器努力生成能够欺骗判别器的逼真图像
  • 两个网络通过对抗训练不断改进

3. 训练过程分析

3.1 超参数设置与调优

latent_dim = 100  # 潜在空间维度
batch_size = 64   # 批次大小
num_epochs = 100  # 训练轮数
lr = 0.0002      # 学习率
beta1 = 0.5      # Adam优化器参数

超参数的选择理由和调优经验:

  1. 潜在空间维度(latent_dim):

    • 选择100维是经验值,可以提供足够的表达能力
    • 过小的维度会限制生成能力
    • 过大的维度会增加训练难度
  2. 批次大小(batch_size):

    • 64是较为常用的值,平衡了训练效率和内存占用
    • 较大的batch_size有助于BatchNorm层的统计
    • 较小的batch_size会增加样本多样性
  3. 学习率(lr):

    • 使用较小的学习率(0.0002)确保训练稳定
    • 可以考虑使用学习率衰减策略
    • 生成器和判别器使用相同的学习率
  4. Adam优化器参数:

    • beta1=0.5:降低动量,减少训练震荡
    • beta2=0.999(默认值):自适应调整学习率
    • weight_decay=0:GAN通常不使用权重衰减

3.2 训练步骤与优化策略

GAN的训练过程是一个动态博弈过程,需要精心设计训练策略:

  1. 判别器训练(每轮迭代):

    • 加载真实MNIST图像batch,标签设为1
    • 生成随机噪声,通过生成器生成假图像
    • 计算判别器在真实和假图像上的损失
    • 优化判别器参数(注意要先detach生成器输出)
  2. 生成器训练(每轮迭代):

    • 生成新的随机噪声
    • 生成假图像并送入判别器
    • 计算生成器损失,目标是让判别器输出1
    • 优化生成器参数
  3. 训练平衡策略:

    • 判别器和生成器交替训练
    • 可以适当增加判别器的训练频率
    • 监控两个网络的损失值,保持相对平衡
    • 避免判别器过强导致生成器梯度消失
  4. 训练稳定性提升:

    • 使用标签平滑:真实标签使用0.9而不是1.0
    • 添加噪声到判别器输入
    • 使用历史生成图像池
    • 定期保存模型检查点

关键代码实现:

# 训练判别器
d_optimizer.zero_grad()
label_real = torch.ones(batch_size, 1).to(device)
label_fake = torch.zeros(batch_size, 1).to(device)

# 真实图像的损失
output_real = discriminator(real_images)
d_loss_real = criterion(output_real, label_real)

# 生成假图像的损失
noise = torch.randn(batch_size, latent_dim).to(device)
fake_images = generator(noise)
output_fake = discriminator(fake_images.detach())
d_loss_fake = criterion(output_fake, label_fake)

# 更新判别器
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
d_optimizer.step()

# 训练生成器
g_optimizer.zero_grad()
output_fake = discriminator(fake_images)
g_loss = criterion(output_fake, label_real)
g_loss.backward()
g_optimizer.step()

4. 训练效果展示与分析

4.1 不同训练阶段的生成效果

在训练过程中,每10个epoch保存一次生成的图像样本,可以观察到生成器的学习进展:
在这里插入图片描述

  1. 初始阶段(1-20 epoch):

    • 生成的图像呈现随机噪声状态
    • 基本形状开始出现但非常模糊
    • 判别器损失较大,生成器损失波动明显
  2. 中期阶段(30-60 epoch):

    • 数字轮廓逐渐清晰
    • 笔画粗细开始分化
    • 生成器和判别器损失趋于平衡
  3. 后期阶段(70-100 epoch):

    • 生成图像质量显著提升
    • 数字笔画清晰连贯
    • 背景噪声明显减少
    • 不同数字类别的特征明显

4.2 生成质量评估

  1. 主观评估指标:

    • 图像清晰度
    • 数字结构完整性
    • 笔画连贯性
    • 风格多样性
  2. 常见生成问题及解决方案:

    • 模式崩溃:使用小批量判别或特征匹配
    • 训练不稳定:调整网络结构或优化器参数
    • 生成质量波动:使用滑动平均或模型集成

5. GAN的应用前景

5.1 计算机视觉领域

  1. 图像生成和编辑:

    • 高分辨率图像生成
    • 风格迁移
    • 图像修复和超分辨率
  2. 数据增强:

    • 扩充训练数据集
    • 生成罕见样本
    • 平衡数据分布

5.2 其他应用领域

  1. 医疗领域:

    • 医学图像合成
    • 疾病诊断辅助
    • 药物分子设计
  2. 工业设计:

    • 产品外观设计
    • 材料纹理生成
    • 3D模型生成
  3. 多媒体创作:

    • 音乐生成
    • 视频合成
    • 虚拟试衣

5.3 未来发展方向

  1. 技术改进:

    • 提高生成质量和稳定性
    • 降低计算资源需求
    • 增强可控性和可解释性
  2. 应用拓展:

    • 跨模态生成
    • 个性化定制
    • 实时生成系统
  3. 伦理考虑:

    • 版权保护
    • 隐私安全
    • 防范滥用

5. 总结

通过实现这个MNIST数据集上的GAN网络,我们可以看到:

  1. GAN的训练是一个复杂的平衡过程,需要精心设计网络结构和训练策略
  2. 生成器和判别器的结构设计对生成效果有重要影响
  3. 适当的超参数设置和训练技巧对于稳定训练至关重要
  4. 通过可视化生成结果,可以直观地评估训练效果

这个实现为理解和应用GAN提供了一个很好的起点,可以在此基础上尝试更复杂的数据集和网络结构。

6. 整体代码展示

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)

# 设置设备
device = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))

# 生成器网络
# 生成器的作用是将随机噪声转换为逼真的图像
# 网络结构采用转置卷积的方式,逐步将低维噪声向量升维到图像大小
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim  # 潜在空间维度,用于控制生成图像的多样性
        
        self.main = nn.Sequential(
            # 第一层:将潜在向量映射到高维特征空间
            # 输入shape: (batch_size, latent_dim)
            # 输出shape: (batch_size, 256*7*7)
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),  # 批归一化有助于训练稳定性
            nn.ReLU(True),  # 使用ReLU激活函数引入非线性
            
            # 重塑为卷积层的输入格式
            # 输入shape: (batch_size, 256*7*7)
            # 输出shape: (batch_size, 256, 7, 7)
            nn.Flatten(0, -1),
            nn.Unflatten(0, (-1, 256, 7, 7)),
            
            # 第一个转置卷积层
            # 输入shape: (batch_size, 256, 7, 7)
            # 输出shape: (batch_size, 128, 14, 14)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # 第二个转置卷积层
            # 输入shape: (batch_size, 128, 14, 14)
            # 输出shape: (batch_size, 64, 28, 28)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # 最后的卷积层
            # 输入shape: (batch_size, 64, 28, 28)
            # 输出shape: (batch_size, 1, 28, 28)
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.main(z)

# 判别器网络
# 判别器的作用是区分真实图像和生成器生成的假图像
# 网络结构使用卷积层逐步提取图像特征,最终输出二分类概率
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main = nn.Sequential(
            # 第一层卷积:提取基础图像特征
            # 输入shape: (batch_size, 1, 28, 28) (MNIST图像尺寸)
            # 输出shape: (batch_size, 64, 14, 14)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU避免梯度消失
            nn.Dropout2d(0.3),  # Dropout防止过拟合
            
            # 第二层卷积
            # 输入shape: (batch_size, 64, 14, 14)
            # 输出shape: (batch_size, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            
            # 展平层
            # 输入shape: (batch_size, 128, 7, 7)
            # 输出shape: (batch_size, 128 * 7 * 7)
            nn.Flatten(),
            
            # 全连接层
            # 输入shape: (batch_size, 128 * 7 * 7)
            # 输出shape: (batch_size, 1)
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.main(x)

# 超参数设置
# 这些参数的选择对GAN的训练稳定性和生成效果有重要影响
latent_dim = 100  # 潜在空间维度,较大的维度有助于生成更多样的图像
batch_size = 64   # 批次大小,平衡训练效率和内存占用
num_epochs = 100  # 训练轮数
lr = 0.0002      # 学习率,GAN训练需要较小的学习率以保持稳定
beta1 = 0.5      # Adam优化器的beta1参数,较小的值有助于训练稳定

# 数据加载和预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化网络
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# 损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# 训练函数
# GAN的训练过程是一个动态博弈过程,需要平衡生成器和判别器的训练
def train():
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            real_images = real_images.to(device)
            
            # 训练判别器
            # 判别器需要学会区分真实图像(标签为1)和生成图像(标签为0)
            d_optimizer.zero_grad()  # 清空判别器的梯度
            label_real = torch.ones(batch_size, 1).to(device)    # 真实图像的标签为1
            label_fake = torch.zeros(batch_size, 1).to(device)   # 生成图像的标签为0
            
            # 计算判别器对真实图像的损失
            output_real = discriminator(real_images)  # 判别器对真实图像的预测结果
            d_loss_real = criterion(output_real, label_real)  # 计算真实图像的二元交叉熵损失

            # 生成假图像并计算判别器的损失
            noise = torch.randn(batch_size, latent_dim).to(device)  # 生成随机噪声
            fake_images = generator(noise)  # 使用生成器生成假图像
            output_fake = discriminator(fake_images.detach())  # detach()防止梯度传递到生成器
            d_loss_fake = criterion(output_fake, label_fake)  # 计算假图像的二元交叉熵损失

            # 计算判别器的总损失并更新参数
            d_loss = d_loss_real + d_loss_fake  # 判别器总损失是真假图像损失之和
            d_loss.backward()  # 反向传播计算梯度
            d_optimizer.step()  # 更新判别器参数
            
            # 训练生成器
            # 生成器的目标是生成能够欺骗判别器的图像
            g_optimizer.zero_grad()  # 清空生成器的梯度
            output_fake = discriminator(fake_images)  # 判别器对生成图像的预测
            # 生成器的损失:希望判别器将生成的图像判断为真实图像
            g_loss = criterion(output_fake, label_real)  # 使用真实标签计算损失

            g_loss.backward()  # 反向传播计算梯度
            g_optimizer.step()  # 更新生成器参数
            
            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                      f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}')
        
        # 每个epoch保存生成的图像样本
        if (epoch + 1) % 10 == 0:
            save_fake_images(epoch + 1)

# 保存生成的图像
def save_fake_images(epoch):
    # 保存生成器在当前epoch生成的图像样本
    generator.eval()  # 将生成器设置为评估模式
    with torch.no_grad():  # 不计算梯度,节省内存
        noise = torch.randn(16, latent_dim).to(device)  # 生成16个随机噪声向量
        fake_images = generator(noise)  # 生成16张图像
        fake_images = fake_images.cpu().numpy()  # 将张量转换为NumPy数组
        
        plt.figure(figsize=(4, 4))
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            plt.imshow(fake_images[i, 0], cmap='gray')
            plt.axis('off')
        
        plt.savefig(f'fake_images_epoch_{epoch}.png')
        plt.close()
    generator.train()

if __name__ == '__main__':
    train()

7. 参考链接

  1. GAN相关论文:

  2. PyTorch资源:

  3. MNIST数据集:

  4. GAN进阶资源:


网站公告

今日签到

点亮在社区的每一天
去签到