GAN
2014 年,Ian Goodfellow 等人提出生成对抗网络(Generative Adversarial Networks),GAN 的出现是划时代的,虽然目前主流的图像/视频生成模型是扩散模型(Diffusion Models)的天下,但是我们仍然有必要了解 GAN 的思想。
GAN 的核心思想是训练两个模型,分别为生成器(Generator)和辨别器(Discriminator),生成器的目标是生成虚假的数据,尽可能混淆辨别器,使其无法判别真实数据和虚假数据,而辨别器的目标则是尽可能将真实数据和虚假数据区分开来。这个过程如下图所示:
生成器和辨别器处于一个对抗的过程,它们的能力不断地提升。GAN 的一个缺点在于它的训练过程不稳定,因此在 GAN 出来后,跟 GAN 相关的论文层出不穷,包括改进 GAN 的损失函数、训练方式,或者采用更先进的模型结构,使 GAN 的生成能力更强,同时使其训练过程更加稳定,但是 GAN 的核心思想是不变的。
模型结构
GAN 的结构如下图所示:
GAN 的生成器和辨别器是两个独立的模型,在原始 GAN 中采用的生成器和辨别器都是多层感知机(Multi Layer Perceptron),后来出现了许多模型结构的改进,例如 DCGAN 将 MLP 替换为卷积神经网络。
辨别器
辨别器本质上是一个分类器,用于区分真实数据和由生成器生成的虚假数据,输出是一个 0-1 范围的标量,表示为真实数据的概率值。辨别器有两个数据来源:真实和虚假数据,训练辨别器的过程中,保持生成器的参数不变,利用二分类损失计算梯度,执行反向传播更新辨别器的参数,过程如下。
生成器
生成器用于生成虚假数据,尽可能混淆辨别器,生成器接受一个随机噪声(Random Noise),随机噪声的采样可以来自于均匀分布、正态分布等等,甚至可以是一张图片。生成器的作用就是将随机噪声分布转换为真实数据的分布,在生成器训练的过程中,保持辨别器的参数不变,利用辨别器的梯度来更新生成器。
损失函数
GAN 采用了 minimax 损失,其数学表达式如下:
min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_z(z)}[\log(1-D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中, V ( D , G ) V(D,G) V(D,G) 表示价值函数, x x x 为真实数据采样的样本, z z z 为生成器生成的样本。
minimax 损失本质上是一个二分类损失(Binary Cross Entropy),可以拆解为辨别器损失和生成器损失。
在训练辨别器的过程中,生成器参数保持不变,因此对于辨别器而言,\(G(z)\) 可以视为常数,其损失函数为:
L D = − E x ∼ p d a t a ( x ) [ log D ( x ) ] − E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] L_D=-E_{x\sim p_{data}(x)}[\log D(x)]-E_{z\sim p_z(z)}[\log(1-D(G(z)))] LD=−Ex∼pdata(x)[logD(x)]−Ez∼pz(z)[log(1−D(G(z)))]
在训练生成器的过程中,辨别器参数保持不变,因此对于辨别器而言,价值函数的第一项为常数,在求导时忽略不计,因此生成器的损失函数为:
L G = − E z ∼ p z ( z ) [ log ( D ( G ( z ) ) ) ] L_G=-E_{z\sim p_z(z)}[\log(D(G(z)))] LG=−Ez∼pz(z)[log(D(G(z)))]
对于上述两个损失函数一个直观的理解是,对于 L G L_G LG 而言,我们希望生成器生成的假数据使判别器无法区分,即希望判别器输出的概率接近于 1,取对数后即接近于 0,由于判别器的输出在于 0 - 1 之间,因此取 log 后为负数,即转变为最大化对数概率,或最小化负对数概率,由于优化的过程通常是梯度下降的过程,因此选择后者。
在 GAN 的论文中,给出了一张用于阐述 GAN 的训练过程的图。假设随机噪声 z z z 采样自一维均匀分布,真实数据分布为标准正态分布。图中的黑色点线表示真实数据分布,蓝色虚线表示辨别器输出的概率分布,绿色实线表示生成器输出的概率分布。随着 GAN 的不断训练,生成器生成的数据分布逐渐接近于真实数据分布,辨别器越来越难以区分真实数据和假数据,因此在理想情况下,生成器完全学习到了真实数据分布,辨别器再也无法进行区分,因此输出的概率都为 50%,也就是图(d) 所示的直线。
GAN 的训练过程以及 PyTorch 实现
以下是原始 GAN 论文中的训练算法:
注意:这里生成器的损失函数并不是前面重写的形式,但是它们两个是等价的,在实际中,作者采用前面重写的形式,因为他们认为这样训练更加稳定。实际的情况是都不那么稳定:)。
下面是一个 GAN 的 PyTorch 实现例子,生成器和辨别器均采用 MLP,在数据集 MNIST 上进行训练的代码,具体代码可见:vanilla-gan。
import os
from argparse import Namespace, ArgumentParser
import torch
from torch import nn, Tensor
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
class Discriminator(nn.Module):
"""
Disrcminator in GAN.
Model Architecture: [affine - leaky relu - dropout] x 3 - affine - sigmoid
"""
def __init__(self, image_shape: tuple[int, int, int]) -> None:
super(Discriminator, self).__init__()
C, H, W = image_shape
image_size = C * H * W
self.model = nn.Sequential(
nn.Linear(image_size, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Dropout(0.3),
nn.Linear(128, 1), nn.Sigmoid())
def forward(self, images: Tensor) -> Tensor:
images = images.view(images.size(0), -1)
return self.model(images)
class Generator(nn.Module):
"""
Generator in GAN.
Model Architecture: [affine - batchnorm - relu] x 4 - affine - tanh
"""
def __init__(self, image_shape: tuple[int, int, int], latent_dim: int) -> None:
super(Generator, self).__init__()
C, H, W = image_shape
image_size = C * H * W
self.image_shape = image_shape
self.model = nn.Sequential(
nn.Linear(latent_dim, 128), nn.BatchNorm1d(128), nn.ReLU(),
nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(),
nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(),
nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(),
nn.Linear(1024, image_size), nn.Tanh())
def forward(self, z: Tensor) -> Tensor:
images: Tensor = self.model(z)
return images.view(-1, *self.image_shape)
# Image processing.
transform_mnist = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5), std=(0.5))])
transform_cifar = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
# Device configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def denormalize(x: Tensor) -> Tensor:
out = (x + 1) / 2
return out.clamp(0, 1)
def get_args() -> Namespace:
"""Get commandline arguments."""
parser = ArgumentParser()
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate for Adam optimizer')
parser.add_argument('--beta1', type=float, default=0.5, help='first momentum term for Adam')
parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term for Adam')
parser.add_argument('--batch_size', type=int, default=64, help='size of a mini-batch')
parser.add_argument('--num_epochs', type=int, default=100, help='training epochs')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--dataset', type=str, default='MNIST', help='training dataset(MNIST | FashionMNIST | CIFAR10)')
parser.add_argument('--sample_dir', type=str, default='samples', help='directory of image samples')
parser.add_argument('--interval', type=int, default=1, help='epoch interval between image samples')
parser.add_argument('--logdir', type=str, default='runs', help='directory of running log')
parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='directory for saving model checkpoints')
parser.add_argument('--seed', type=str, default=10213, help='random seed')
return parser.parse_args()
def setup(args: Namespace) -> None:
torch.manual_seed(args.seed)
# Create directory if not exists.
if not os.path.exists(os.path.join(args.sample_dir, args.dataset)):
os.makedirs(os.path.join(args.sample_dir, args.dataset))
if not os.path.exists(os.path.join(args.ckpt_dir, args.dataset)):
os.makedirs(os.path.join(args.ckpt_dir, args.dataset))
def get_data_loader(args: Namespace) -> DataLoader:
"""Get data loader."""
if args.dataset == 'MNIST':
data = datasets.MNIST(root='../data', train=True, download=True, transform=transform_mnist)
elif args.dataset == 'FashionMNIST':
data = datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform_mnist)
elif args.dataset == 'CIFAR10':
data = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_cifar)
else:
raise ValueError(f'Unkown dataset: {args.dataset}, support dataset: MNIST | FashionMNIST | CIFAR10')
return DataLoader(dataset=data, batch_size=args.batch_size, num_workers=4, shuffle=True)
def train(args: Namespace,
G: Generator, D: Discriminator,
data_loader: DataLoader) -> None:
"""Train Generator and Discriminator.
Args:
args(Namespace): arguments.
G(Generator): Generator in GAN.
D(Discriminator): Discriminator in GAN.
"""
writer = SummaryWriter(os.path.join(args.logdir, args.dataset))
# generate fixed noise for sampling.
fixed_noise = torch.rand(64, args.latent_dim).to(device)
# Loss and optimizer.
criterion = nn.BCELoss().to(device)
optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
# Start training.
for epoch in range(args.num_epochs):
total_d_loss = total_g_loss = 0
for images, _ in data_loader:
m = images.size(0)
images: Tensor = images.to(device)
images = images.view(m, -1)
# Create real and fake labels.
real_labels = torch.ones(m, 1).to(device)
fake_labels = torch.zeros(m, 1).to(device)
# ================================================================== #
# Train the discriminator #
# ================================================================== #
# Forward pass
outputs = D(images)
d_loss_real: Tensor = criterion(outputs, real_labels)
z = torch.rand(m, args.latent_dim).to(device)
fake_images: Tensor = G(z).detach()
outputs = D(fake_images)
d_loss_fake: Tensor = criterion(outputs, fake_labels)
# Backward pass
d_loss: Tensor = d_loss_real + d_loss_fake
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
total_d_loss += d_loss
# ================================================================== #
# Train the generator #
# ================================================================== #
# Forward pass
z = torch.rand(images.size(0), args.latent_dim).to(device)
fake_images: Tensor = G(z)
outputs = D(fake_images)
# Backward pass
g_loss: Tensor = criterion(outputs, real_labels)
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
total_g_loss += g_loss
print(f'''
=====================================
Epoch: [{epoch + 1}/{args.num_epochs}]
Discriminator Loss: {total_d_loss / len(data_loader):.4f}
Generator Loss: {total_g_loss / len(data_loader):.4f}
=====================================''')
# Log Discriminator and Generator loss.
writer.add_scalar('Discriminator Loss', total_d_loss / len(data_loader), epoch + 1)
writer.add_scalar('Generator Loss', total_g_loss / len(data_loader), epoch + 1)
fake_images: Tensor = G(fixed_noise)
img_grid = make_grid(denormalize(fake_images), nrow=8, padding=2)
writer.add_image('Fake Images', img_grid, epoch + 1)
if (epoch + 1) % args.interval == 0:
save_image(img_grid, os.path.join(args.sample_dir, args.dataset, f'fake_images_{epoch + 1}.png'))
# Save the model checkpoints.
torch.save(G.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'G.ckpt'))
torch.save(D.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'D.ckpt'))
def main() -> None:
args = get_args()
setup(args)
image_shape = (1, 28, 28) if args.dataset in ('MNIST', 'FashionMNIST') else (3, 32, 32)
data_loader = get_data_loader(args)
# Generator and Discrminator.
G = Generator(image_shape=image_shape, latent_dim=args.latent_dim).to(device)
D = Discriminator(image_shape=image_shape).to(device)
train(args, G, D, data_loader)
if __name__ == '__main__':
main()
参考
[1] I. Goodfellow et al., “Generative Adversarial Nets,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2014. Accessed: Sep. 12, 2024. [Online]. Available: https://papers.nips.cc/paper_files/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html
[2] eriklindernoren. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN
[3] 李沐. “GAN论文逐段精读【论文精读】”. Bilibili 2021. [Online]. Available: https://www.bilibili.com/video/BV1rb4y187vD/?spm_id_from=333.1387.collection.video_card.click&vd_source=c8a32a5a667964d5f1068d38d6182813
n. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN