生成对抗网络——pytorch与paddle实现生成对抗网络

发布于:2025-03-17 ⋅ 阅读:(18) ⋅ 点赞:(0)

生成对抗网络——pytorch与paddle实现生成对抗网络

本文将深入探讨生成对抗网络的理论基础,并通过PyTorch和PaddlePaddle两个深度学习框架来展示如何实现生成对抗网络模型。我们将首先介绍生成对抗网络的基本概念,这些理论基础是理解和实现生成对抗网络的基础。通过PyTorch和PaddlePaddle的代码示例,我们将展示如何设计、训练和评估一个生成对抗网络模型,从而让读者能够直观地理解并掌握这两种框架在计算机视觉问题中的应用。

本文部分为torch框架以及部分理论分析,paddle框架对应代码可见生成对抗网络paddle

import torch
print("pytorch version:",torch.__version__)
pytorch version: 2.4.1+cu118

生成对抗网络原理

神经网络常被用于信息的缩减、提取及概括,这一点在图像分类神经网络分类器得以体现。具体来说,如MNIST数据集,其对应分类器具有784个输入值,而其输出值仅为10个,显著少于输入值的数量。如果将神经网络的输入与输出进行反转,理论上应能实现与数据“缩减”相反的效果,即能够将少量的数据扩展为更大量的数据。通过这种方式,我们可以生成图像数据。生成对抗的过程就是通过生成器(Generator)和判别器(Discriminator)之间的对抗训练,使得生成器能够生成与真实数据难以区分的新数据。

生成对抗网络(Generative Adversarial Networks,简称GAN)是一种无监督的深度学习模型,由Ian J. Goodfellow等人在2014年提出。GAN的主要原理是通过生成器(Generator)和判别器(Discriminator)之间的对抗训练,使得生成器能够生成与真实数据难以区分的新数据。
GAN由两个主要部分组成:

  1. 生成器(Generator)

    • 功能:负责生成新的数据样本。
    • 结构:通常是一个深度神经网络,输入为低维向量(如随机噪声),输出为高维向量(如图片、文本或语音)。
    • 训练目标:生成尽可能真实的数据,以欺骗判别器。
  2. 判别器(Discriminator)

    • 功能:负责区分输入的数据是真实数据还是由生成器生成的假数据。
    • 结构:同样是一个深度神经网络,输入为高维向量(如图片、文本或语音),输出为一个标量,表示输入数据的真实性概率。
    • 训练目标:尽可能准确地区分真实数据和生成数据。

GAN的训练过程是一个对抗训练的过程,生成器和判别器交替进行训练:

  1. 训练判别器

    • 从真实数据集中采样一批真实样本。
    • 从随机噪声中通过生成器生成一批假样本。
    • 将真实样本和假样本混合后输入判别器进行训练,使判别器能够准确地区分真实样本和假样本。
  2. 训练生成器

    • 固定判别器的参数,从随机噪声中生成一批假样本。
    • 将这批假样本输入判别器,并调整生成器的参数,使得判别器将这批假样本误判为真实样本的概率最大化。

GAN的优化目标可以表示为一个极小极大值问题:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, V ( D , G ) V(D, G) V(D,G)是价值函数, D ( x ) D(x) D(x)表示判别器对真实样本 x x x的判断概率, G ( z ) G(z) G(z)表示生成器根据随机噪声 z z z生成的假样本, p data ( x ) p_{\text{data}}(x) pdata(x)是真实数据的分布, p z ( z ) p_z(z) pz(z)是随机噪声的分布。

  • 对于判别器 D D D,其目标是最大化 V ( D , G ) V(D, G) V(D,G),即提高区分真实样本和假样本的能力。
  • 对于生成器 G G G,其目标是最小化 max ⁡ D V ( D , G ) \max_D V(D, G) maxDV(D,G),即生成能够欺骗判别器的假样本。

在理想情况下,随着训练的进行,生成器生成的数据分布将逐渐接近真实数据分布,判别器将无法准确区分两者。此时,GAN达到了一个纳什平衡点(Nash equilibrium),生成器和判别器的性能都不再提升。

GAN由于其强大的生成能力,被广泛应用于图像生成、文本生成、语音合成、数据增强等领域。例如,在图像生成方面,GAN可以生成高质量的图片,甚至可以实现风格迁移、图像修复等功能。

综上所述,生成对抗网络通过生成器和判别器之间的对抗训练,实现了从随机噪声中生成与真实数据难以区分的新数据的能力。其原理简单而深刻,为深度学习领域带来了全新的突破。

接下来,我们设计一个简单的GAN模型,用于生成手写数字图像。我们将使用MNIST数据集作为训练数据。首先定义判别器和生成器模型。

import torch  
import torch.nn as nn  
import numpy as np  
import torch.nn.functional as F  

IMAGE_SHAPE = (1, 28, 28)  # MNIST图像的形状
# 超参数设置  
Z_dim = 100  
h_dim = 128  

# 定义生成器  
class Generator(nn.Module):  
    def __init__(self):  
        super(Generator, self).__init__()  
        self.main = nn.Sequential(  
            nn.Linear(Z_dim, 256),  
            nn.LeakyReLU(0.2, inplace=True),  
            nn.BatchNorm1d(256, momentum=0.8),  
            nn.Linear(256, 512),  
            nn.LeakyReLU(0.2, inplace=True),  
            nn.BatchNorm1d(512, momentum=0.8),  
            nn.Linear(512, 1024),  
            nn.LeakyReLU(0.2, inplace=True),  
            nn.BatchNorm1d(1024, momentum=0.8),  
            nn.Linear(1024, np.prod(IMAGE_SHAPE)),  
            nn.Tanh(),  
            # Reshape操作在forward中实现  
        )  
  
    def forward(self, x):  
        x = self.main(x)  
        x = x.view(-1, 28, 28)
        return x   
  
# 定义判别器  
class Discriminator(nn.Module):  
    def __init__(self):  
        super(Discriminator, self).__init__()  
        # Flatten操作在PyTorch中通常不使用单独的层,而是在forward中直接用view或reshape实现  
        self.fc1 = nn.Linear(np.prod(IMAGE_SHAPE), 512)  # np.prod计算IMAGE_SHAPE中所有元素的乘积  
        self.fc2 = nn.Linear(512, 256)  
        self.fc3 = nn.Linear(256, 1)  
          
    def forward(self, x):  
        x = x.view(-1, np.prod(IMAGE_SHAPE))  # 将输入x展平为一维向量  
        x = F.leaky_relu(self.fc1(x), 0.2)  
        x = F.leaky_relu(self.fc2(x), 0.2)  
        x = torch.sigmoid(self.fc3(x))  
        return x 

接下来,我们构建MNIST数据集和数据加载器。

from torch.utils.data import DataLoader  
from torchvision import datasets, transforms  

transform = transforms.Compose([  
    transforms.ToTensor(),  
    transforms.Normalize(mean=(0.5,), std=(0.5,))  
])  
mb_size = 64  

# 构建数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# 构建数据加载器  
dataloader = DataLoader(dataset=train_dataset, batch_size=mb_size, shuffle=True)  

接下来,我们定义训练过程,包括判别器和生成器的训练,并进行训练,保存生成的图像。(注意,下方代码运行时间较长,普通1060笔记本端显卡需约300min,读者可根据个人设备情况,调整batchsize或者训练轮次,下方代码要有较好效果num_epochs至少得150次)

import torch.optim as optim  
import matplotlib.pyplot as plt  
import matplotlib.gridspec as gridspec  
import os 

# 超参数设置  
mb_size = 64  
lr = 1e-3  
num_epochs = 1000
output_dir = 'out'  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 初始化模型  
G = Generator().to(device)  
D = Discriminator().to(device) 

# 优化器  
G_solver = optim.Adam(G.parameters(), lr=lr)  
D_solver = optim.Adam(D.parameters(), lr=lr)  
  
# 标签  
ones_label = torch.ones(mb_size, 1).to(device)
zeros_label = torch.zeros(mb_size, 1).to(device)

pre_G_loss = 1e5 # 用于记录生成器的损失
  
# 训练过程  
for it in range(num_epochs):  
    for X, _ in dataloader:  
        if X.shape[0] != mb_size:
            continue
        X = X.to(device)
        # 判别器训练  
        z = torch.randn(mb_size, Z_dim).to(device)
        G_sample = G(z)  
        D_real = D(X)  
        D_fake = D(G_sample.detach())  # 避免梯度传递到生成器  
  
        D_loss_real = nn.BCELoss()(D_real, ones_label)  
        D_loss_fake = nn.BCELoss()(D_fake, zeros_label)  
        D_loss = D_loss_real + D_loss_fake  
  
        D_solver.zero_grad()  
        D_loss.backward()  
        D_solver.step()  
  
        # 生成器训练  
        z = torch.randn(mb_size, Z_dim).to(device)
        G_sample = G(z)  
        D_fake = D(G_sample)  

        G_loss = nn.BCELoss()(D_fake, ones_label)  

        G_solver.zero_grad()  
        G_loss.backward()  
        G_solver.step()  
  
    # 打印和保存图像  
    if it % 50 == 0:  
        print(f'Iter-{it}; D_loss: {D_loss.item()}; G_loss: {G_loss.item()}')  
  
        samples = G(torch.randn(16, Z_dim).to(device)).detach().cpu().numpy()  
  
        fig = plt.figure(figsize=(4, 4))  
        gs = gridspec.GridSpec(4, 4)  
        gs.update(wspace=0.05, hspace=0.05)  
  
        for i, sample in enumerate(samples):  
            ax = plt.subplot(gs[i])  
            plt.axis('off')  
            ax.set_xticklabels([])  
            ax.set_yticklabels([])  
            ax.set_aspect('equal')  
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')  
  
        if not os.path.exists(output_dir):  
            os.makedirs(output_dir)  
  
        plt.savefig(f'{output_dir}/{str(it // 50).zfill(3)}.png', bbox_inches='tight')  
        plt.close(fig)
Iter-0; D_loss: 0.011108401231467724; G_loss: 17.891992568969727
Iter-50; D_loss: 0.8838750123977661; G_loss: 2.4737629890441895
Iter-100; D_loss: 0.762171745300293; G_loss: 1.7359180450439453
Iter-150; D_loss: 1.1558961868286133; G_loss: 2.003592014312744
Iter-200; D_loss: 1.1423180103302002; G_loss: 1.6627423763275146
Iter-250; D_loss: 1.053521990776062; G_loss: 1.7526977062225342
Iter-300; D_loss: 0.7899768948554993; G_loss: 1.5552549362182617
Iter-350; D_loss: 1.0171101093292236; G_loss: 2.3687257766723633
Iter-400; D_loss: 0.8682758212089539; G_loss: 1.4403650760650635
Iter-450; D_loss: 0.8783559799194336; G_loss: 2.5538430213928223
Iter-500; D_loss: 0.7236818075180054; G_loss: 2.3121109008789062
Iter-550; D_loss: 0.636387050151825; G_loss: 2.1528429985046387
Iter-600; D_loss: 0.5927819013595581; G_loss: 1.7480260133743286
Iter-650; D_loss: 0.7258435487747192; G_loss: 1.3098269701004028
Iter-700; D_loss: 0.6658307313919067; G_loss: 1.9678267240524292
Iter-750; D_loss: 0.6192352175712585; G_loss: 1.9035545587539673
Iter-800; D_loss: 0.6187129020690918; G_loss: 1.9537832736968994
Iter-850; D_loss: 0.7183228135108948; G_loss: 2.339176893234253
Iter-900; D_loss: 0.6317700147628784; G_loss: 2.2820653915405273
Iter-950; D_loss: 0.6620879769325256; G_loss: 2.7352004051208496

让我们看看生成器生成的图像。

import matplotlib.pyplot as plt  
import os

# 指定保存图像的文件夹路径
folder_path = 'out'

# 获取文件夹中所有的.png文件
images = [f for f in os.listdir(folder_path) if f.endswith('.png')]

# 按文件名排序,确保图像按顺序显示
images = sorted(images)
fig = plt.figure(figsize=(10, 10))
for i, image in enumerate(images):
    image_path = os.path.join(folder_path, image)
    plt.subplot(4, 5, i+1)
    plt.imshow(plt.imread(image_path))
    plt.axis('off')

在这里插入图片描述

可以看到,随着训练轮次的增加,生成器生成的图像质量逐渐提高,逐渐接近真实图像。此时我们便得到了一个能够生成手写数字的生成式网络。

下面我们详细介绍一下训练过程。生成对抗网络(GAN)的训练过程是一个迭代的过程,其中包括同时训练两个模型:生成器(Generator)和判别器(Discriminator)。

首先初始化模型和优化器

  • 生成器(G):负责从随机噪声生成逼真的数据。
  • 判别器(D):负责区分真实数据和生成的数据。
  • 优化器:使用Adam优化器,分别为生成器和判别器设置学习率(lr),用于更新模型参数。

准备标签

  • 标签:为真实数据和生成数据准备标签。ones_label用于真实数据,zeros_label用于生成数据。

训练循环

首先进行判别器训练,包括:

  • 生成随机噪声:从标准正态分布中采样噪声z
  • 生成假数据:通过生成器生成假数据G_sample
  • 判别真实数据:判别器对真实数据X进行判别,输出D_real
  • 判别假数据:判别器对生成器生成的假数据进行判别,输出D_fake。这里使用detach()避免梯度传递到生成器。
  • 计算损失:使用二元交叉熵损失(BCELoss)计算真实数据和假数据的损失,并将它们相加得到总损失D_loss
  • 更新判别器:通过反向传播计算梯度,并使用优化器更新判别器的参数。

生成器训练过程包括:

  • 重新生成随机噪声:再次从标准正态分布中采样噪声z
  • 重新生成假数据:通过生成器生成新的假数据G_sample
  • 判别假数据:判别器对新生成的假数据进行判别,输出D_fake
  • 计算损失:使用二元交叉熵损失计算生成器的损失G_loss,目标是让判别器将生成的数据误判为真。
  • 更新生成器:通过反向传播计算梯度,并使用优化器更新生成器的参数。

GAN的训练过程是一个零和博弈,生成器和判别器相互竞争、相互优化:

  • 判别器试图区分真实数据和生成数据。
  • 生成器试图生成足以欺骗判别器的数据。

通过反复迭代这个过程,生成器逐渐学会生成越来越逼真的数据,而判别器也逐渐提高区分真假数据的能力。最终,在理想情况下,生成器可以生成几乎与真实数据分布一致的数据,而判别器则无法区分它们。

  • 判别假数据:判别器对新生成的假数据进行判别,输出D_fake
  • 计算损失:使用二元交叉熵损失计算生成器的损失G_loss,目标是让判别器将生成的数据误判为真。
  • 更新生成器:通过反向传播计算梯度,并使用优化器更新生成器的参数。

GAN的训练过程是一个零和博弈,生成器和判别器相互竞争、相互优化:

  • 判别器试图区分真实数据和生成数据。
  • 生成器试图生成足以欺骗判别器的数据。

通过反复迭代这个过程,生成器逐渐学会生成越来越逼真的数据,而判别器也逐渐提高区分真假数据的能力。最终,在理想情况下,生成器可以生成几乎与真实数据分布一致的数据,而判别器则无法区分它们。