[Paper Notes] Generative Adversarial Networks (GAN)


GAN

In 2014, Ian Goodfellow et al. proposed Generative Adversarial Networks (GANs). The arrival of GANs was a landmark: even though mainstream image and video generation today belongs to diffusion models, the ideas behind GANs are still well worth understanding.

The core idea of a GAN is to train two models, a generator and a discriminator. The generator tries to produce fake data that confuses the discriminator, so that it cannot tell real data from fake; the discriminator tries to separate real data from fake data as well as it can. The process is illustrated below:

(Figure: gan-example)

The generator and the discriminator are locked in an adversarial game in which both keep improving. A weakness of GANs is that training is unstable, so after the original paper a stream of GAN-related work followed, improving the loss function and training procedure or adopting more advanced architectures, to strengthen generation and stabilize training. The core idea, however, remains unchanged.

Model Architecture

The architecture of a GAN is shown below:

(Figure: gan, overall architecture)

The generator and the discriminator are two independent models. The original GAN uses a multilayer perceptron (MLP) for both; many architectural improvements followed, e.g. DCGAN replaces the MLPs with convolutional neural networks.
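As an illustration of that direction, below is a small DCGAN-flavored discriminator for 3×32×32 images. The layer sizes here are my own choice for illustration and are not taken from the DCGAN paper:

```python
import torch
from torch import nn

class ConvDiscriminator(nn.Module):
    """DCGAN-style discriminator: strided convolutions replace affine layers."""
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),    # 32x32 -> 16x16
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8 -> 4x4
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1), nn.Sigmoid())

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.model(images)
```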

Discriminator

The discriminator is essentially a binary classifier that distinguishes real data from the fake data produced by the generator. Its output is a scalar in the range 0 to 1, interpreted as the probability that the input is real. The discriminator has two data sources, real data and fake data. When training the discriminator, the generator's parameters are held fixed: compute the binary classification loss, backpropagate, and update only the discriminator's parameters, as shown below.

(Figure: gan-disc, discriminator training)
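In PyTorch, one discriminator update looks like the following self-contained sketch (the tiny `D`, `G`, and random "data" here are placeholders so the snippet runs on its own; the full training script appears later in this post). The `detach()` call is what keeps the generator frozen:

```python
import torch
from torch import nn

# Toy stand-ins with illustrative shapes, just to make the step runnable.
D = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
criterion = nn.BCELoss()
optimizer_D = torch.optim.Adam(D.parameters(), lr=2e-4)

real_images = torch.randn(64, 784)   # pretend mini-batch of real data
z = torch.randn(64, 100)             # random noise fed to the generator
real_labels = torch.ones(64, 1)
fake_labels = torch.zeros(64, 1)

# Real samples are labeled 1, generated ones 0; detach() blocks gradients into G.
d_loss = criterion(D(real_images), real_labels) \
       + criterion(D(G(z).detach()), fake_labels)
optimizer_D.zero_grad()
d_loss.backward()                    # gradients reach D's parameters only
optimizer_D.step()
```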

Generator

The generator produces fake data meant to confuse the discriminator. It takes random noise as input; the noise can be sampled from a uniform distribution, a normal distribution, and so on, and can even be an image. The generator's job is to transform the noise distribution into the real data distribution. When training the generator, the discriminator's parameters are held fixed, and the gradients flowing back through the discriminator are used to update the generator.

(Figure: generator training)
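The matching generator update, again as a self-contained sketch with placeholder modules: only `optimizer_G` takes a step, but the loss gradients flow back *through* the frozen discriminator into the generator:

```python
import torch
from torch import nn

# Toy stand-ins mirroring the discriminator example above.
D = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=2e-4)

z = torch.randn(64, 100)
real_labels = torch.ones(64, 1)      # the generator wants D to say "real"

g_loss = criterion(D(G(z)), real_labels)
optimizer_G.zero_grad()
g_loss.backward()                    # gradients flow through D into G
optimizer_G.step()                   # only G's parameters are updated
```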

Loss Function

GAN uses a minimax loss, written as:

$$\min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_z(z)}[\log(1-D(G(z)))]$$

Here $V(D,G)$ denotes the value function, $x$ is a sample drawn from the real data distribution, and $z$ is a noise sample drawn from the prior $p_z(z)$, which the generator maps to a fake sample $G(z)$.

The minimax loss is essentially a binary cross-entropy (BCE) loss and can be decomposed into a discriminator loss and a generator loss.
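To make the correspondence explicit, recall the binary cross-entropy between a prediction $\hat{y} \in (0,1)$ and a label $y \in \{0,1\}$:

$$\mathrm{BCE}(\hat{y},y) = -y\log\hat{y} - (1-y)\log(1-\hat{y})$$

Plugging in $\hat{y}=D(x)$ with label $y=1$ for real samples, and $\hat{y}=D(G(z))$ with label $y=0$ for generated samples, then taking expectations, recovers exactly the two terms of $V(D,G)$ (up to sign).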

When training the discriminator, the generator's parameters stay fixed, so from the discriminator's point of view $G(z)$ can be treated as a constant. The discriminator's loss is:

$$L_D=-E_{x\sim p_{data}(x)}[\log D(x)]-E_{z\sim p_z(z)}[\log(1-D(G(z)))]$$
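As a quick sanity check (my own snippet, not from the post's repo), `nn.BCELoss` with all-ones and all-zeros labels reproduces this $L_D$ computed by hand:

```python
import torch
from torch import nn

d_real = torch.rand(8, 1) * 0.98 + 0.01   # stand-in for D(x) scores in (0, 1)
d_fake = torch.rand(8, 1) * 0.98 + 0.01   # stand-in for D(G(z)) scores in (0, 1)

bce = nn.BCELoss()
loss_bce = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
loss_manual = -torch.log(d_real).mean() - torch.log(1 - d_fake).mean()
print(torch.allclose(loss_bce, loss_manual))  # True
```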

When training the generator, the discriminator's parameters stay fixed, so from the generator's point of view the first term of the value function is a constant and drops out when taking gradients. The generator's loss (in the rewritten, non-saturating form) is:

$$L_G=-E_{z\sim p_z(z)}[\log D(G(z))]$$

An intuitive reading of $L_G$: we want the generator's fakes to be indistinguishable to the discriminator, i.e. the discriminator's output probability on them should approach 1, whose logarithm approaches 0. Since the discriminator's output lies between 0 and 1, its logarithm is negative, so the goal becomes maximizing the log-probability, or equivalently minimizing the negative log-probability; because optimization usually proceeds by gradient descent, we choose the latter.
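The practical reason for this rewritten form, rather than minimizing the raw minimax term $E_z[\log(1-D(G(z)))]$, shows up in the gradients early in training, when the discriminator easily rejects fakes and $D(G(z)) \approx 0$:

$$\frac{\partial}{\partial D}\log(1-D) = -\frac{1}{1-D} \;\to\; -1, \qquad \frac{\partial}{\partial D}\bigl(-\log D\bigr) = -\frac{1}{D} \;\to\; -\infty \quad (D \to 0)$$

The minimax term saturates (its gradient stays bounded), while the rewritten loss still provides a strong learning signal, which is why it is often called the non-saturating loss.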

The GAN paper includes a figure illustrating the training process. Suppose the random noise $z$ is sampled from a one-dimensional uniform distribution and the real data follow a Gaussian distribution. The black dotted line shows the real data distribution, the blue dashed line shows the discriminator's output, and the green solid line shows the distribution of the generator's outputs. As training proceeds, the generator's distribution moves ever closer to the real data distribution, and the discriminator finds it harder and harder to tell real from fake. In the ideal case the generator has learned the real data distribution exactly, the discriminator can no longer discriminate at all, and it outputs a probability of 50% everywhere, the flat line in panel (d).

(Figure: gan-process, training dynamics from the GAN paper)
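The 50% plateau in panel (d) follows from a result in the paper: for a fixed generator with output distribution $p_g$, the optimal discriminator is

$$D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_g(x)},$$

which equals $1/2$ everywhere once $p_g = p_{data}$.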

GAN Training Procedure and a PyTorch Implementation

The training algorithm from the original GAN paper is shown below:

(Figure: train, Algorithm 1 from the GAN paper)

Note: the generator loss in the algorithm is not the rewritten form given earlier, but the two are equivalent. In practice, the authors adopt the rewritten form because they believe it trains more stably. In reality, neither is all that stable. :)
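In code, the two generator-loss variants differ by a single line. A sketch with a placeholder tensor (`d_on_fake` stands in for $D(G(z))$):

```python
import torch

d_on_fake = torch.rand(64, 1) * 0.98 + 0.01   # stand-in for D(G(z)), kept inside (0, 1)

g_loss_minimax = torch.log(1 - d_on_fake).mean()   # form minimized in the paper's algorithm
g_loss_nonsat = -torch.log(d_on_fake).mean()       # rewritten, non-saturating form
```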

Below is a PyTorch implementation of a GAN in which both the generator and the discriminator are MLPs, trained on the MNIST dataset (FashionMNIST and CIFAR10 are also supported). The full code is available at: vanilla-gan

```python
import os
from argparse import Namespace, ArgumentParser
import torch
from torch import nn, Tensor
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader


class Discriminator(nn.Module):
    """
    Discriminator in GAN.
    
    Model Architecture: [affine - leaky relu - dropout] x 3 - affine - sigmoid
    """
    def __init__(self, image_shape: tuple[int, int, int]) -> None:
        super(Discriminator, self).__init__()
        C, H, W = image_shape
        image_size = C * H * W
        self.model = nn.Sequential(
            nn.Linear(image_size, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, images: Tensor) -> Tensor:
        images = images.view(images.size(0), -1)
        return self.model(images)


class Generator(nn.Module):
    """
    Generator in GAN.

    Model Architecture: [affine - batchnorm - relu] x 4 - affine - tanh
    """
    def __init__(self, image_shape: tuple[int, int, int], latent_dim: int) -> None:
        super(Generator, self).__init__()
        C, H, W = image_shape
        image_size = C * H * W
        self.image_shape = image_shape
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
            nn.Linear(1024, image_size), nn.Tanh())
    
    def forward(self, z: Tensor) -> Tensor:
        images: Tensor = self.model(z)
        return images.view(-1, *self.image_shape)

# Image processing.
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# Device configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def denormalize(x: Tensor) -> Tensor:
    out = (x + 1) / 2
    return out.clamp(0, 1)


def get_args() -> Namespace:
    """Get commandline arguments."""
    parser = ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate for Adam optimizer')
    parser.add_argument('--beta1', type=float, default=0.5, help='first momentum term for Adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term for Adam')
    parser.add_argument('--batch_size', type=int, default=64, help='size of a mini-batch')
    parser.add_argument('--num_epochs', type=int, default=100, help='training epochs')
    parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
    parser.add_argument('--dataset', type=str, default='MNIST', help='training dataset (MNIST | FashionMNIST | CIFAR10)')
    parser.add_argument('--sample_dir', type=str, default='samples', help='directory of image samples')
    parser.add_argument('--interval', type=int, default=1, help='epoch interval between image samples')
    parser.add_argument('--logdir', type=str, default='runs', help='directory of running log')
    parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='directory for saving model checkpoints')
    parser.add_argument('--seed', type=int, default=10213, help='random seed')
    return parser.parse_args()


def setup(args: Namespace) -> None:
    torch.manual_seed(args.seed)
    # Create directory if not exists.
    if not os.path.exists(os.path.join(args.sample_dir, args.dataset)):
        os.makedirs(os.path.join(args.sample_dir, args.dataset))
    if not os.path.exists(os.path.join(args.ckpt_dir, args.dataset)):
        os.makedirs(os.path.join(args.ckpt_dir, args.dataset))


def get_data_loader(args: Namespace) -> DataLoader:
    """Get data loader."""
    if args.dataset == 'MNIST':
        data = datasets.MNIST(root='../data', train=True, download=True, transform=transform_mnist)
    elif args.dataset == 'FashionMNIST':
        data = datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform_mnist)
    elif args.dataset == 'CIFAR10':
        data = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_cifar)
    else:
        raise ValueError(f'Unknown dataset: {args.dataset}, supported datasets: MNIST | FashionMNIST | CIFAR10')
    return DataLoader(dataset=data, batch_size=args.batch_size, num_workers=4, shuffle=True)


def train(args: Namespace, 
          G: Generator, D: Discriminator, 
          data_loader: DataLoader) -> None:
    """Train Generator and Discriminator.

    Args:
        args(Namespace): arguments.
        G(Generator): Generator in GAN.
        D(Discriminator): Discriminator in GAN.
        data_loader(DataLoader): training data loader.
    """
    writer = SummaryWriter(os.path.join(args.logdir, args.dataset))

    # generate fixed noise for sampling.
    fixed_noise = torch.rand(64, args.latent_dim).to(device)

    # Loss and optimizer.
    criterion = nn.BCELoss().to(device)
    optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

    # Start training.
    for epoch in range(args.num_epochs):
        total_d_loss = total_g_loss = 0
        for images, _ in data_loader:
            m = images.size(0)
            images: Tensor = images.to(device)
            images = images.view(m, -1)
            # Create real and fake labels.
            real_labels = torch.ones(m, 1).to(device)
            fake_labels = torch.zeros(m, 1).to(device)
            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #

            # Forward pass
            outputs = D(images)
            d_loss_real: Tensor = criterion(outputs, real_labels)
            
            z = torch.rand(m, args.latent_dim).to(device)
            fake_images: Tensor = G(z).detach()
            outputs = D(fake_images)
            d_loss_fake: Tensor = criterion(outputs, fake_labels)

            # Backward pass
            d_loss: Tensor = d_loss_real + d_loss_fake
            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()
            total_d_loss += d_loss.item()  # .item(): accumulate a float, not an autograd graph

            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #
            
            # Forward pass
            z = torch.rand(images.size(0), args.latent_dim).to(device)
            fake_images: Tensor = G(z)
            outputs = D(fake_images)
            
            # Backward pass
            g_loss: Tensor = criterion(outputs, real_labels)
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()
            total_g_loss += g_loss.item()  # .item(): accumulate a float, not an autograd graph
        print(f'''
=====================================
Epoch: [{epoch + 1}/{args.num_epochs}]
Discriminator Loss: {total_d_loss / len(data_loader):.4f}
Generator Loss: {total_g_loss / len(data_loader):.4f}
=====================================''')
        # Log Discriminator and Generator loss.
        writer.add_scalar('Discriminator Loss', total_d_loss / len(data_loader), epoch + 1)
        writer.add_scalar('Generator Loss', total_g_loss / len(data_loader), epoch + 1)
        # Sample from the fixed noise without tracking gradients.
        with torch.no_grad():
            fake_images: Tensor = G(fixed_noise)
        img_grid = make_grid(denormalize(fake_images), nrow=8, padding=2)
        writer.add_image('Fake Images', img_grid, epoch + 1)
        if (epoch + 1) % args.interval == 0:
            save_image(img_grid, os.path.join(args.sample_dir, args.dataset, f'fake_images_{epoch + 1}.png'))
    # Save the model checkpoints.
    torch.save(G.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'G.ckpt'))
    torch.save(D.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'D.ckpt'))


def main() -> None:
    args = get_args()
    setup(args)
    image_shape = (1, 28, 28) if args.dataset in ('MNIST', 'FashionMNIST') else (3, 32, 32)
    data_loader = get_data_loader(args)
    # Generator and Discriminator.
    G = Generator(image_shape=image_shape, latent_dim=args.latent_dim).to(device)
    D = Discriminator(image_shape=image_shape).to(device)
    train(args, G, D, data_loader)


if __name__ == '__main__':
    main()
```
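Once training finishes, the saved checkpoint can be reloaded to draw fresh samples. A minimal sketch, assuming `Generator` and `denormalize` from the script above are in scope and using the script's default paths and uniform noise:

```python
import torch
from torchvision.utils import save_image

G = Generator(image_shape=(1, 28, 28), latent_dim=100)
G.load_state_dict(torch.load('checkpoints/MNIST/G.ckpt', map_location='cpu'))
G.eval()  # use BatchNorm running statistics at inference time

with torch.no_grad():
    z = torch.rand(64, 100)          # same uniform noise distribution as in training
    samples = denormalize(G(z))
save_image(samples, 'samples.png', nrow=8)
```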

References

[1] I. Goodfellow et al., “Generative Adversarial Nets,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2014. Accessed: Sep. 12, 2024. [Online]. Available: https://papers.nips.cc/paper_files/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html

[2] eriklindernoren, “PyTorch-GAN,” GitHub, 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN

[3] 李沐 (Mu Li), “GAN论文逐段精读【论文精读】,” Bilibili, 2021. [Online]. Available: https://www.bilibili.com/video/BV1rb4y187vD/