G1周打卡——GAN入门

发布于:2025-06-13 ⋅ 阅读:(23) ⋅ 点赞:(0)

一、定义超参数

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

## 创建文件夹
os.makedirs("./images/", exist_ok=True)         # 记录训练过程的图片效果
os.makedirs("./save/", exist_ok=True)           # 训练完成时模型保存的位置
os.makedirs("./datasets/mnist", exist_ok=True)  # 下载数据集存放的位置

## 超参数配置
n_epochs  = 50
batch_size= 64
lr        = 0.0002
b1        = 0.5
b2        = 0.999
n_cpu     = 2
latent_dim= 100
img_size  = 28
channels  = 1
sample_interval=500

# 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

# 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

二、下载数据

# mnist数据集下载
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), 
)

三、配置数据

# 配置数据到加载器
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

四、定义鉴别器

# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 输入特征数为784,输出为512
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(512, 256),              # 输入特征数为512,输出为256
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(256, 1),                # 输入特征数为256,输出为1
            nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
        validity = self.model(img_flat)      # 通过鉴别器网络
        return validity       

五、训练模型并保存

## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()

## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()

## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():
    generator     = generator.cuda()
    discriminator = discriminator.cuda()
    criterion     = criterion.cuda()

## 进行多个epoch的训练
for epoch in range(n_epochs):                   # epoch:50
    for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)
        
        ## =============================训练判别器==================
        ## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)
        imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0

        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)            # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
        real_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)
        fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。 
        fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的loss
        fake_scores = fake_out                                              ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0
        loss_D.backward()                   # 将误差反向传播
        optimizer_D.step()                  # 更新参数

        ## -----------------
        ##  Train Generator
        ## 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        ## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        ## 反向传播更新的参数是生成网络里面的参数,
        ## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        ## -----------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声
        fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)                                    ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()                                             ## 梯度归0
        loss_G.backward()                                                   ## 进行反向传播
        optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数

        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

## 保存模型
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

总结:

一、代码结构总览

该代码主要包含以下部分:

1. 定义超参数
2. 下载并配置 MNIST 数据集
3. 构建判别器 Discriminator
4. 构建生成器 Generator
5. 训练模型:交替训练判别器 D 和生成器 G
6. 保存模型和训练过程中的图像

二、关键技术点说明

 1. 超参数设定
  • 设置了训练轮数(n_epochs=50)、批次大小(batch_size=64)、学习率(lr=0.0002)、Adam 优化器参数(b1=0.5, b2=0.999)等。
  • 使用标准 GAN 的推荐参数组合,适合入门级训练任务。
 2. 数据准备
  • 下载并使用 MNIST 数据集;
  • 对图像进行归一化处理,使其像素值范围为 [-1, 1];
  • 使用 DataLoader 加载数据,支持批量训练。
 3. 判别器设计
  • 使用多层全连接网络,将输入图像(28x28=784维)映射为一个 [0,1] 区间的概率值;
  • 激活函数采用 LeakyReLU,避免神经元死亡问题;
  • 最后一层使用 Sigmoid 函数输出判断真假的概率。
4. 生成器设计(未展示)
  • 接收一个 100 维的随机噪声向量;
  • 输出一张 28x28 的图像;
  • 网络结构通常也由多个全连接层构成,并使用 ReLU 或 Tanh 激活函数。
5. 损失函数与优化器
  • 使用 二分类交叉熵损失(BCELoss)
  • 分别使用 Adam 优化器对生成器和判别器进行优化;
  • 在训练生成器时固定判别器,在训练判别器时同时考虑真图与假图。
6. 训练过程
  • 交替训练生成器和判别器:
    • 判别器的目标是正确区分真假图像;
    • 生成器的目标是生成能让判别器误判为真的图像;
  • 每隔一定迭代步数打印当前损失信息;
  • 保存训练过程中的生成图像以观察生成效果;
  • 最终保存训练好的模型参数供后续使用。

三、GAN 原理简要回顾

GAN 是一种由两个神经网络组成的对抗系统:

  • 生成器(Generator):接收随机噪声,生成图像;
  • 判别器(Discriminator):判断输入图像是真实数据还是生成器生成的假数据;
  • 两者通过博弈不断优化,最终目标是让生成器生成接近真实分布的数据。

四、优点与适用场景

 优点
  • 结构清晰,适合初学者理解 GAN 的基本原理;
  • 使用标准优化器和损失函数,训练稳定;
  • 有可视化输出,便于监控训练效果。
 适用场景
  • 图像生成;
  • 数据增强;
  • 风格迁移;
  • 各类无监督生成任务。