生成式人工智能实战 | 条件生成对抗网络(conditional Generative Adversarial Network, cGAN)

发布于:2025-07-09 ⋅ 阅读:(18) ⋅ 点赞:(0)

0. 前言

生成对抗网络 (Generative Adversarial Network, GAN) 是近年来深度学习领域最具突破性的技术之一,能够生成逼真的图像、音频甚至文本。然而,传统的 GAN 生成过程是随机的,无法控制生成内容的具体属性。条件生成对抗网络 (Conditional GAN, CGAN) 通过引入类别标签等条件信息,使生成过程变得可控,极大拓展了 GAN 的应用场景。本节将深入解析 CGAN 的技术原理,并使用 PyTorchCIFAR-10 数据集上实现一个完整的 CGAN 模型。

1. 条件生成对抗网络

1.1 GAN 基础回顾

生成对抗网络 (Generative Adversarial Network, GAN) 由生成器 (Generator) 和判别器 (Discriminator) 组成,生成器将将随机噪声转换为逼真的数据样本,而判别器区分真实样本和生成样本。两者通过对抗训练共同提升,最终目标是生成器能产生以假乱真的样本。

1.2 cGAN 核心思想

标准 GAN 的生成过程是无条件的,无法控制生成样本的类别。条件 GAN (conditional GAN, cGAN) 通过在生成器和判别器中引入额外的条件信息(如类别标签),实现了对生成过程的控制:

  • 生成器输入:噪声 z z z + 条件信息 c c c
  • 判别器输入:数据 x x x + 条件信息 c c c

cGAN 的主要优点是能够选择生成数据的某些属性,使其更加灵活,适用于那些需要根据特定输入参数来定向或条件化输出的场景。总之,cGAN 是基本 GAN 架构的一种扩展,使得基于条件输入能够有针对性地生成合成数据。

2. cGAN 网络架构

2.1 数学原理

cGAN 的目标函数可以表示为:
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ∣ c ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ∣ c ) ) ) ] min_G max_D V(D,G) = E_{x\sim p_{data}(x)}[logD(x|c)] + E_{z\sim p_z(z)}[log(1-D(G(z|c)))] minGmaxDV(D,G)=Expdata(x)[logD(xc)]+Ezpz(z)[log(1D(G(zc)))]
其中, D ( x ∣ y ) D(x|y) D(xy) 表示判别器在给定条件 y y y 下判断 x x x 为真实样本的概率, G ( z ∣ y ) G(z|y) G(zy) 表示生成器在给定条件 y y y 下从噪声 z z z 生成的样本, p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据分布, p z ( z ) p_z(z) pz(z) 是噪声分布。

2.2 网络架构

典型的 cGAN 架构包含以下关键组件,条件信息嵌入用于将类别标签转换为嵌入向量,与噪声向量拼接;生成器网络通常使用转置卷积层逐步上采样噪声向量;判别器网络使用卷积层逐步下采样输入图像。

3. 实现 cGAN

3.1 环境准备与数据加载

(1) 首先导入必要的库并设置设备:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 检查GPU可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

(2) 定义数据预处理:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将像素值归一化到[-1,1]
])

(3) 下载并加载 CIFAR-10 训练集:

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

# CIFAR-10类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

(4) 创建数据加载器:

batch_size = 128
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True
)

3.2 模型构建

(1) 定义生成器网络结构,输入为噪声向量 z z z (100 维) + 类别标签 (10 维独热编码),输出为 3x32x32 的彩色图像:

class Generator(nn.Module):
    def __init__(self, n_classes=10):
        super(Generator, self).__init__()
        
        # 标签嵌入层,将类别标签转换为特征向量
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # 定义生成器网络结构
        self.model = nn.Sequential(
            # 输入: (噪声100维 + 标签10维) -> 输出: 256x4x4
            nn.Linear(100 + n_classes, 256 * 4 * 4),
            nn.BatchNorm1d(256 * 4 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 重塑为256x4x4的特征图
            nn.Unflatten(1, (256, 4, 4)),
            
            # 上采样层1: 256x4x4 -> 128x8x8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 上采样层2: 128x8x8 -> 64x16x16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 上采样层3: 64x16x16 -> 3x32x32
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # 输出值在[-1,1]之间,与预处理一致
        )
    
    def forward(self, z, labels):
        # 将标签转换为嵌入向量
        c = self.label_emb(labels)
        
        # 拼接噪声和标签嵌入
        x = torch.cat([z, c], dim=1)
        
        # 通过生成器网络
        img = self.model(x)
        return img

(2) 定义判别器网络结构,输入为 3x32x32 图像 + 类别标签( 10 维独热编码),输出为判别结果(真/假):

class Discriminator(nn.Module):
    def __init__(self, n_classes=10):
        super(Discriminator, self).__init__()
        
        # 标签嵌入层
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # 定义判别器网络结构
        self.model = nn.Sequential(
            # 输入: 3x32x32 -> 输出: 64x16x16
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 64x16x16 -> 128x8x8
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 128x8x8 -> 256x4x4
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 展平特征图
            nn.Flatten()
        )
        
        # 最终判别层
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * 4 * 4 + n_classes, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        # 提取图像特征
        img_features = self.model(img)
        
        # 将标签转换为嵌入向量
        c = self.label_emb(labels)
        
        # 拼接图像特征和标签嵌入
        x = torch.cat([img_features, c], dim=1)
        
        # 通过判别层
        validity = self.adv_layer(x)
        return validity

(3) 初始化模型,并定义损失函数与优化器:

# 初始化生成器和判别器
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 定义损失函数和优化器
adversarial_loss = nn.BCELoss()

# 优化器
lr = 0.0002
beta1 = 0.5
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# 定义真实和假的标签
real_label = 1.
fake_label = 0.

3.3 模型训练

(1) 定义训练超参数:

# 训练参数
n_epochs = 100
sample_interval = 400  # 每隔多少batch保存一次生成样本
n_classes = 10
latent_dim = 100

# 用于保存生成样本的固定噪声和固定标签
fixed_noise = torch.randn(10, latent_dim, device=device)
fixed_labels = torch.arange(0, n_classes, device=device).long()

(2) 定义训练循环,训练完成后保存模型:

# 训练循环
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(tqdm(train_loader)):
        batch_size = imgs.shape[0]
        
        # 配置真实图像和标签
        real_imgs = imgs.to(device)
        real_labels = labels.to(device)
        
        optimizer_D.zero_grad()
        
        # 真实图像的损失
        validity_real = discriminator(real_imgs, real_labels)
        d_real_loss = adversarial_loss(validity_real, torch.full((batch_size, 1), real_label, device=device))
        d_real_loss.backward()
        optimizer_D.step()
        # 生成图像的损失
        optimizer_D.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, n_classes, (batch_size,), device=device)
        
        fake_imgs = generator(z, gen_labels)
        validity_fake = discriminator(fake_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, torch.full((batch_size, 1), fake_label, device=device))
        d_fake_loss.backward()
        optimizer_D.step()
        # 总判别器损失
        d_loss = (d_real_loss + d_fake_loss) / 2
        
        optimizer_G.zero_grad()
        
        # 生成器希望生成的图像被判别为真
        validity = discriminator(fake_imgs, gen_labels)
        g_loss = adversarial_loss(validity, torch.full((batch_size, 1), real_label, device=device))
        
        g_loss.backward()
        optimizer_G.step()
        
        # 打印训练状态
        if i % 100 == 0:
            print(
                f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] "
                f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
            )
        
        # 保存生成样本
        if i % sample_interval == 0:
            with torch.no_grad():
                fake = generator(fixed_noise, fixed_labels)
                # 保存图像或显示
                save_image(fake.data, f"images/{epoch}_{i}.png", nrow=5, normalize=True)
    
# 保存模型
torch.save(generator.state_dict(), f"cgan_generator.pth")
torch.save(discriminator.state_dict(), f"cgan_discriminator.pth")

(3) 模型训练完成后,生成样本可视化:

def show_generated_samples(n_row=10):
    # 加载模型
    generator.load_state_dict(torch.load(f"cgan_generator.pth"))
    generator.eval()
    
    # 生成样本
    with torch.no_grad():
        z = torch.randn(n_row, latent_dim, device=device)
        labels = torch.arange(0, n_row, device=device).long()
        samples = generator(z, labels).cpu()
    
    # 反归一化
    samples = samples * 0.5 + 0.5
    
    # 创建图像网格
    fig, axes = plt.subplots(1, n_row, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(np.transpose(samples[i], (1, 2, 0)))
        ax.set_title(classes[i])
        ax.axis('off')
    plt.show()

show_generated_samples()

生成结果


网站公告

今日签到

点亮在社区的每一天
去签到