使用生成对抗网络(GAN)进行人脸老化生成的Python示例

发布于:2025-03-17 ⋅ 阅读:(18) ⋅ 点赞:(0)

以下是一个使用生成对抗网络(GAN)进行人脸老化生成的Python示例,我们将使用PyTorch库来实现。GAN由生成器和判别器两部分组成,生成器尝试生成逼真的老化人脸图像,判别器则尝试区分生成的图像和真实的老化人脸图像。

步骤概述

  1. 数据准备:准备包含不同年龄段人脸的数据集。
  2. 定义生成器和判别器:构建生成器和判别器的神经网络模型。
  3. 训练GAN:交替训练生成器和判别器。
  4. 生成老化人脸图像:使用训练好的生成器生成老化人脸图像。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt

# 定义生成器
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_dim=784):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

# 超参数设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 3e-4
z_dim = 100
img_dim = 28 * 28
batch_size = 32
num_epochs = 50

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 这里假设使用MNIST数据集作为示例,实际应用中需要使用人脸数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化生成器和判别器
gen = Generator(z_dim, img_dim).to(device)
disc = Discriminator(img_dim).to(device)

# 定义优化器和损失函数
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
criterion = nn.BCELoss()

# 训练GAN
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ### 训练判别器
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### 训练生成器
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")

# 生成老化人脸图像(这里只是简单示例,实际需要更复杂的处理)
num_samples = 16
noise = torch.randn(num_samples, z_dim).to(device)
generated_images = gen(noise).cpu().detach().view(num_samples, 28, 28).numpy()

# 可视化生成的图像
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
axes = axes.flatten()
for i in range(num_samples):
    axes[i].imshow(generated_images[i], cmap='gray')
    axes[i].axis('off')
plt.show()

代码解释

  1. 数据准备:使用torchvision库加载MNIST数据集作为示例,实际应用中需要使用包含不同年龄段人脸的数据集。
  2. 生成器和判别器
    • Generator:将随机噪声向量映射到图像空间。
    • Discriminator:判断输入图像是真实的还是生成的。
  3. 训练过程
    • 交替训练判别器和生成器。
    • 判别器的目标是最大化区分真实图像和生成图像的能力。
    • 生成器的目标是生成能够欺骗判别器的图像。
  4. 生成图像:使用训练好的生成器生成老化人脸图像,并进行可视化。

注意事项

  • 此示例使用MNIST数据集作为演示,实际应用中需要使用包含不同年龄段人脸的数据集,如CACD、IMDB-WIKI等。
  • 可以根据实际情况调整超参数和网络结构,以获得更好的生成效果。