生成对抗网络(GAN)原理
介绍
示例代码
生成对抗网络(Generative Adversarial Network,GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种深度生成模型。它通过两个神经网络之间的博弈(对抗)过程,学习数据的生成分布,从而生成以假乱真的数据(如图像、语音等)。GAN 是近年来生成模型领域的重要突破,广泛应用于图像生成、风格迁移、图像修复等任务中。
一、GAN 的基本结构
GAN 主要由两个部分组成:
1. 生成器(Generator,记作 G)
- 目标:生成尽可能真实的数据,欺骗判别器。
- 输入:随机噪声向量(一般从正态分布或均匀分布中采样)
- 输出:“伪造”的样本,尽可能与真实样本相似。
2. 判别器(Discriminator,记作 D)
- 目标:判断输入数据是真实的样本还是生成器生成的伪造样本。
- 输入:真实样本或生成样本
- 输出:一个概率值,表示输入是“真实”的概率。
二、对抗过程(博弈思想)
GAN 的训练过程是一个零和博弈(min-max game):
- 生成器试图最小化判别器对生成样本的识别能力;
- 判别器试图最大化识别真实样本与生成样本的能力。
这个过程可以表示为一个最优化问题:
min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中:
- p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据的分布;
- p z ( z ) p_z(z) pz(z) 是生成器输入噪声的分布(如高斯分布);
- D ( x ) D(x) D(x) 是判别器输出 x 是真实数据的概率;
- G ( z ) G(z) G(z) 是生成器输出的伪造样本。
三、训练过程
固定生成器 G,训练判别器 D:
- 给 D 一部分真实样本(标签为 1);
- 给 D 一部分 G 生成的样本(标签为 0);
- 通过交叉熵损失训练 D,使其能区分真假样本。
固定判别器 D,训练生成器 G:
- 通过 G 生成假样本;
- D 会判断其为假;
- G 的目标是欺骗 D,即最大化 D ( G ( z ) ) D(G(z)) D(G(z)),让 D 判错;
- 通常优化的是 log D ( G ( z ) ) \log D(G(z)) logD(G(z)) 的反函数,例如 log ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1−D(G(z))) 或更稳定的变体(如使用 feature matching 或 Wasserstein loss)。
交替训练 D 和 G,直到生成器生成的样本无法被判别器区分为假(判别器输出接近 0.5)。
四、存在的问题与改进方向
1. 模式崩溃(Mode Collapse)
生成器只学会生成一小部分模式样本,导致多样性丢失。
2. 训练不稳定
D 和 G 的能力不均衡、学习率不合适等因素可能导致 GAN 训练震荡或失败。
3. 衡量指标困难
GAN 的损失函数不能很好地反映生成质量,因此通常使用 FID、IS 等指标辅助评估。
五、GAN 的改进与变种
为了克服原始 GAN 的不足,研究人员提出了许多变种:
名称 | 简介 |
---|---|
DCGAN | 使用卷积神经网络的 GAN,适合图像数据生成 |
WGAN | Wasserstein GAN,引入 Wasserstein 距离,解决训练不稳定问题 |
WGAN-GP | 在 WGAN 基础上加上梯度惩罚项,提高训练稳定性 |
CGAN | 条件 GAN,可以控制生成样本的类别(如生成特定数字) |
CycleGAN | 用于图像风格转换(如马<->斑马、夏天<->冬天) |
StyleGAN | 高质量人脸图像生成的里程碑,支持精细控制生成风格 |
BigGAN | 更大规模、更高分辨率的图像生成模型 |
六、GAN 的应用
- 图像生成(人脸、卡通、合成图像)
- 图像风格转换(如照片转油画)
- 图像修复与超分辨率重建
- 数据增强(医学图像、遥感等领域)
- 音频合成、文本生成(结合 RNN、Transformer)
- 深度伪造(Deepfake)
下面是一个使用 PyTorch 实现的 DCGAN(Deep Convolutional GAN)完整示例,用于在 MNIST 或 CIFAR-10 数据集上生成图像。与普通 GAN 相比,DCGAN 使用卷积神经网络结构,更适合图像处理,生成效果更清晰、真实。
✅ DCGAN 与普通 GAN 的区别
模块 | 普通 GAN | DCGAN |
---|---|---|
生成器结构 | 全连接层(MLP) | 转置卷积(ConvTranspose2d) |
判别器结构 | 全连接层 | 卷积层(Conv2d) |
激活函数 | ReLU + Sigmoid | LeakyReLU + Sigmoid / No sigmoid |
正则化 | 无 / BatchNorm1d | BatchNorm2d |
✅ DCGAN 示例(基于 MNIST)
🔧 安装依赖
pip install torch torchvision matplotlib
🧠 DCGAN 架构代码(Generator + Discriminator)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
# 超参数
img_size = 64
channels = 1 # 改为 3 可用于 CIFAR-10
latent_dim = 100
batch_size = 128
lr = 0.0002
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建目录
os.makedirs("dcgan_images", exist_ok=True)
# 数据预处理(MNIST 被 resize 成 64x64 )
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataloader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True
)
# 生成器(使用转置卷积)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
# 输入是 latent_dim 向量,输出 1024
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, z):
z = z.view(z.size(0), latent_dim, 1, 1)
return self.model(z)
# 判别器(使用卷积)
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(channels, 128, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x).view(-1, 1).squeeze(1)
# 初始化模型
G = Generator().to(device)
D = Discriminator().to(device)
# 损失和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# 训练 DCGAN
for epoch in range(epochs):
for i, (real_imgs, _) in enumerate(dataloader):
real_imgs = real_imgs.to(device)
b_size = real_imgs.size(0)
# 标签
valid = torch.ones(b_size, device=device)
fake = torch.zeros(b_size, device=device)
# ========== 训练判别器 ==========
optimizer_D.zero_grad()
real_loss = criterion(D(real_imgs), valid)
z = torch.randn(b_size, latent_dim, device=device)
gen_imgs = G(z)
fake_loss = criterion(D(gen_imgs.detach()), fake)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
# ========== 训练生成器 ==========
optimizer_G.zero_grad()
g_loss = criterion(D(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()
if i % 100 == 0:
print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
# 保存生成图像
with torch.no_grad():
z = torch.randn(64, latent_dim, device=device)
gen_imgs = G(z)
grid = make_grid(gen_imgs, nrow=8, normalize=True)
save_image(grid, f"dcgan_images/{epoch:03d}.png")
print("DCGAN 训练完成,图像保存在 dcgan_images 文件夹中。")
🧪 使用说明
若想改用 彩色图像(如 CIFAR-10),需:
channels = 3
- 使用
datasets.CIFAR10
替代 MNIST - 修改
transforms.Normalize([0.5]*3, [0.5]*3)