生成式人工智能实战 | 条件生成对抗网络
0. 前言
生成对抗网络 (Generative Adversarial Network, GAN) 是近年来深度学习领域最具突破性的技术之一,能够生成逼真的图像、音频甚至文本。然而,传统的 GAN
生成过程是随机的,无法控制生成内容的具体属性。条件生成对抗网络 (Conditional GAN
, CGAN
) 通过引入类别标签等条件信息,使生成过程变得可控,极大拓展了 GAN
的应用场景。本节将深入解析 CGAN
的技术原理,并使用 PyTorch
在 CIFAR-10
数据集上实现一个完整的 CGAN
模型。
1. 条件生成对抗网络
1.1 GAN 基础回顾
生成对抗网络 (Generative Adversarial Network, GAN) 由生成器 (Generator
) 和判别器 (Discriminator
) 组成,生成器将将随机噪声转换为逼真的数据样本,而判别器区分真实样本和生成样本。两者通过对抗训练共同提升,最终目标是生成器能产生以假乱真的样本。
1.2 cGAN 核心思想
标准 GAN
的生成过程是无条件的,无法控制生成样本的类别。条件 GAN
(conditional GAN
, cGAN
) 通过在生成器和判别器中引入额外的条件信息(如类别标签),实现了对生成过程的控制:
- 生成器输入:噪声 z z z + 条件信息 c c c
- 判别器输入:数据 x x x + 条件信息 c c c
cGAN
的主要优点是能够选择生成数据的某些属性,使其更加灵活,适用于那些需要根据特定输入参数来定向或条件化输出的场景。总之,cGAN
是基本 GAN
架构的一种扩展,使得基于条件输入能够有针对性地生成合成数据。
2. cGAN 网络架构
2.1 数学原理
cGAN
的目标函数可以表示为:
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ∣ c ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ∣ c ) ) ) ] min_G max_D V(D,G) = E_{x\sim p_{data}(x)}[logD(x|c)] + E_{z\sim p_z(z)}[log(1-D(G(z|c)))] minGmaxDV(D,G)=Ex∼pdata(x)[logD(x∣c)]+Ez∼pz(z)[log(1−D(G(z∣c)))]
其中, D ( x ∣ y ) D(x|y) D(x∣y) 表示判别器在给定条件 y y y 下判断 x x x 为真实样本的概率, G ( z ∣ y ) G(z|y) G(z∣y) 表示生成器在给定条件 y y y 下从噪声 z z z 生成的样本, p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据分布, p z ( z ) p_z(z) pz(z) 是噪声分布。
2.2 网络架构
典型的 cGAN
架构包含以下关键组件,条件信息嵌入用于将类别标签转换为嵌入向量,与噪声向量拼接;生成器网络通常使用转置卷积层逐步上采样噪声向量;判别器网络使用卷积层逐步下采样输入图像。
3. 实现 cGAN
3.1 环境准备与数据加载
(1) 首先导入必要的库并设置设备:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
# 检查GPU可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
(2) 定义数据预处理:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将像素值归一化到[-1,1]
])
(3) 下载并加载 CIFAR-10
训练集:
train_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
# CIFAR-10类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
(4) 创建数据加载器:
batch_size = 128
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
3.2 模型构建
(1) 定义生成器网络结构,输入为噪声向量 z z z (100
维) + 类别标签 (10
维独热编码),输出为 3x32x32
的彩色图像:
class Generator(nn.Module):
def __init__(self, n_classes=10):
super(Generator, self).__init__()
# 标签嵌入层,将类别标签转换为特征向量
self.label_emb = nn.Embedding(n_classes, n_classes)
# 定义生成器网络结构
self.model = nn.Sequential(
# 输入: (噪声100维 + 标签10维) -> 输出: 256x4x4
nn.Linear(100 + n_classes, 256 * 4 * 4),
nn.BatchNorm1d(256 * 4 * 4),
nn.LeakyReLU(0.2, inplace=True),
# 重塑为256x4x4的特征图
nn.Unflatten(1, (256, 4, 4)),
# 上采样层1: 256x4x4 -> 128x8x8
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# 上采样层2: 128x8x8 -> 64x16x16
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
# 上采样层3: 64x16x16 -> 3x32x32
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
nn.Tanh() # 输出值在[-1,1]之间,与预处理一致
)
def forward(self, z, labels):
# 将标签转换为嵌入向量
c = self.label_emb(labels)
# 拼接噪声和标签嵌入
x = torch.cat([z, c], dim=1)
# 通过生成器网络
img = self.model(x)
return img
(2) 定义判别器网络结构,输入为 3x32x32
图像 + 类别标签( 10
维独热编码),输出为判别结果(真/假):
class Discriminator(nn.Module):
def __init__(self, n_classes=10):
super(Discriminator, self).__init__()
# 标签嵌入层
self.label_emb = nn.Embedding(n_classes, n_classes)
# 定义判别器网络结构
self.model = nn.Sequential(
# 输入: 3x32x32 -> 输出: 64x16x16
nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
# 64x16x16 -> 128x8x8
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# 128x8x8 -> 256x4x4
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# 展平特征图
nn.Flatten()
)
# 最终判别层
self.adv_layer = nn.Sequential(
nn.Linear(256 * 4 * 4 + n_classes, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
# 提取图像特征
img_features = self.model(img)
# 将标签转换为嵌入向量
c = self.label_emb(labels)
# 拼接图像特征和标签嵌入
x = torch.cat([img_features, c], dim=1)
# 通过判别层
validity = self.adv_layer(x)
return validity
(3) 初始化模型,并定义损失函数与优化器:
# 初始化生成器和判别器
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 定义损失函数和优化器
adversarial_loss = nn.BCELoss()
# 优化器
lr = 0.0002
beta1 = 0.5
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# 定义真实和假的标签
real_label = 1.
fake_label = 0.
3.3 模型训练
(1) 定义训练超参数:
# 训练参数
n_epochs = 100
sample_interval = 400 # 每隔多少batch保存一次生成样本
n_classes = 10
latent_dim = 100
# 用于保存生成样本的固定噪声和固定标签
fixed_noise = torch.randn(10, latent_dim, device=device)
fixed_labels = torch.arange(0, n_classes, device=device).long()
(2) 定义训练循环,训练完成后保存模型:
# 训练循环
for epoch in range(n_epochs):
for i, (imgs, labels) in enumerate(tqdm(train_loader)):
batch_size = imgs.shape[0]
# 配置真实图像和标签
real_imgs = imgs.to(device)
real_labels = labels.to(device)
optimizer_D.zero_grad()
# 真实图像的损失
validity_real = discriminator(real_imgs, real_labels)
d_real_loss = adversarial_loss(validity_real, torch.full((batch_size, 1), real_label, device=device))
d_real_loss.backward()
optimizer_D.step()
# 生成图像的损失
optimizer_D.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)
fake_imgs = generator(z, gen_labels)
validity_fake = discriminator(fake_imgs.detach(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, torch.full((batch_size, 1), fake_label, device=device))
d_fake_loss.backward()
optimizer_D.step()
# 总判别器损失
d_loss = (d_real_loss + d_fake_loss) / 2
optimizer_G.zero_grad()
# 生成器希望生成的图像被判别为真
validity = discriminator(fake_imgs, gen_labels)
g_loss = adversarial_loss(validity, torch.full((batch_size, 1), real_label, device=device))
g_loss.backward()
optimizer_G.step()
# 打印训练状态
if i % 100 == 0:
print(
f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] "
f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
)
# 保存生成样本
if i % sample_interval == 0:
with torch.no_grad():
fake = generator(fixed_noise, fixed_labels)
# 保存图像或显示
save_image(fake.data, f"images/{epoch}_{i}.png", nrow=5, normalize=True)
# 保存模型
torch.save(generator.state_dict(), f"cgan_generator.pth")
torch.save(discriminator.state_dict(), f"cgan_discriminator.pth")
(3) 模型训练完成后,生成样本可视化:
def show_generated_samples(n_row=10):
# 加载模型
generator.load_state_dict(torch.load(f"cgan_generator.pth"))
generator.eval()
# 生成样本
with torch.no_grad():
z = torch.randn(n_row, latent_dim, device=device)
labels = torch.arange(0, n_row, device=device).long()
samples = generator(z, labels).cpu()
# 反归一化
samples = samples * 0.5 + 0.5
# 创建图像网格
fig, axes = plt.subplots(1, n_row, figsize=(20, 2))
for i, ax in enumerate(axes):
ax.imshow(np.transpose(samples[i], (1, 2, 0)))
ax.set_title(classes[i])
ax.axis('off')
plt.show()
show_generated_samples()