人脸图像生成(DCGAN)

发布于:2025-07-13 ⋅ 阅读:(17) ⋅ 点赞:(0)

- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) 中的学习记录博客**
- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

深度卷积对抗网络(Deep Convolutional Generative Adversarial Networks)是一种深度学习模型,由生成器(Generator)和判别器(Discriminator)两个神经网络组成。DCGAN 结合了卷积神经网络和生成对抗网络的思想,用于生成逼真的图像。

一. 理论基础

1.DCGAN原理

深度卷积对抗网络是生成对抗网络的一种模型改进,其将卷积运算的思想引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。DCGAN模型有以下特点:

  • 判别器模型使用了卷积步长取代了空间池化,生成器模型中使用了反卷积操作扩大数据维度。
  • 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其他层上都使用了Batch Normalization, 原因是Batch Normalization 可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  • 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  • 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其他层中均使用了ReLU激活函数;在判别器上使用了Leaky ReLU激活函数。

图中所示了一种常见的DCGAN结构。主要包含了一个生成网络G 和一个判别网络 D,生成网络G 负责生成图像,它接受一个随机的噪声z,通过该噪声生成图像,将生成的图像记为G(z),判别网络D 负责判断一张图是否为真实,它的输入是x,代表一张图像,输出D(x)表示x为真实图像的概率。

实际上判别网络D是对数据的来源进行一个判别:究竟这个数据是来自真是的数据分布Pd(x)判别为“1”,还是来自于一个生成网络G所产生的一个数据分布Pg(z)(判别为“0”)。所以在整个训练过程中,生成网络G的目标是生成可以以假乱真的图像G(z),当判别网络D无法区分,即D(G(z))=0.5时,便得到了一个生成网络G用来生产图像扩充数据集。

二.前期准备

1.导入第三方库

import torch,random,os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.autograd import Variable

manualSeed = 999
print("random seed:",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True)

2.设置超参数

dataroot = "/content/drive/MyDrive/GAN_Dataset"
batch_size = 128 #训练过程中的批次大小
image_size = 64 #图像的尺寸(宽度和高度)
nz = 100 # z潜在的向量大小(生成器输入的尺寸)
ngf = 64 # 生成器中的特征图大小
ndf = 64 #判别器中的特征图大小
num_epochs = 50 #训练的总论数
lr = 0.0002 #学习率
beta1=0.5 #adam 优化器的beta1超参数

3.导入数据

dataset = dset.ImageFolder(root=dataroot,transform = transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]))
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=5)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],padding=2,normalize=True).cpu(),(1,2,0)))

三.定义模型

1.初始化权重

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    nn.init.normal_(m.weight.data,0.0,0.02)
  elif classname.find('BatchNorm')!=-1:
    nn.init.normal_(m.weight.data,1.0,0.02)
    nn.init.constant_(m.bias.data,0)

2.定义生成器

class Generator(nn.Module):
  def __init__(self):
    super(Generator,self).__init__()
    self.main = nn.Sequential(
        nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
        nn.BatchNorm2d(ngf*8),
        nn.ReLU(True),
        #输出尺寸:(ngf*8)x4x4
        nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*4),
        nn.ReLU(True),
        #输出尺寸:(ngf*4)x8x8
        nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*2),
        nn.ReLU(True),
        #输出尺寸:(ngf*2)x16x16
        nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        #输出尺寸:(ngf)x32x32
        nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),
        nn.Tanh()
        #输出尺寸:3x64x64
    )
  
  def forward(self,input):
    return self.main(input)
#创建生成器
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

3.定义鉴别器

class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    self.main = nn.Sequential(
        nn.Conv2d(3,ndf,4,2,1,bias=False),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf)x32x32
        nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*2),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*2)x16x16
        nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*4),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*4)x8x8
        nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*8),
        nn.LeakyReLU(0.2,inplace=True),
        #输出尺寸:(ndf*8)x4x4
        nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
    )

  def forward(self,input):
    return self.main(input)
#创建判别器模型
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)

四:训练模型

1.定义训练参数

criterion = nn.BCELoss()
fixed_noise = torch.randn(64,nz,1,1,device = device)

real_label =1.
fake_label =0.

optimizerD = optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG = optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))

2.训练模型

下面的训练代码是一个典型的GAN训练循环。在训练过程中,首先更新判别器网络,然后更新生成器网络。在每个epoch的每个batch中,会进行以下操作:

  • 更新判别器网络:通过训练真实图像样本和生成图像样本,最大化判别器的损失。具体步骤如下:

    • 对于真实图像样本,计算判别器对真实图像样本的输出和真实标签之间的损失,然后进行反向传播计算梯度。
    • 对于生成的图像样本,计算判别器对生成图像样本的输出和假标签之间的损失,然后进行反向传播计算梯度。
    • 将真实图像样本的损失和生成图像样本的损失相加得到判别器的总损失,并更新判别器的参数。
  • 更新生成器网络:通过最大化生成器的损失,迫使生成器产生更逼真的图像样本。具体步骤如下:

    • 使用生成器生成一批假图像样本。
    • 将生成图像样本输入判别器,计算判别器对生成图像样本的输出和真实标签之间的损失,并进行反向传播计算生成器的梯度。
    • 更新生成器的参数。
  • 输出训练统计信息:每隔一定的步数,输出当前训练的epoch、batch以及判别器和生成器的损失值等信息。

  • 保存损失值:将生成器和判别器的损失值存储到相应的列表中,以便后续绘图和分析。

  • 检查生成器的性能:每隔一定的步数或者在训练结束时,通过将固定的噪声输入生成器,生成一批图像样本,并保存到img_list列表中。这样可以观察生成器在训练过程中生成的图像质量的变化。

  • 更新迭代次数:每完成一个batch的训练,将迭代次数iters加1。

总体来说,这段代码实现了GAN的训练过程,通过交替更新判别器和生成器的参数,目标是使生成器生成逼真的图像样本,同时判别器能够准确区分真实图像样本和生成图像样本。

img_list =[]
G_losses=[]
D_losses=[]
iters=0
print("start training")

for epoch in range(num_epochs):
  for i,data in enumerate(dataloader,0):
    netD.zero_grad()
    real_cpu = data[0].to(device)
    b_size = real_cpu.size(0)
    label = torch.full((b_size,),real_label,dtype=torch.float,device=device)

    output = netD(real_cpu).view(-1)
    errD_real = criterion(output,label)
    errD_real.backward()
    D_x = output.mean().item()

    #使用生成图像样本训练
    noise = torch.randn(b_size,nz,1,1,device=device)
    fake = netG(noise)
    label.fill_(fake_label)
    output = netD(fake.detach()).view(-1)
    errD_fake = criterion(output,label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    errD = errD_real + errD_fake
    optimizerD.step()

    #更新生成器网络
    netG.zero_grad()
    label.fill_(real_label)
    output = netD(fake).view(-1)
    errG = criterion(output,label)
    errG.backward()
    D_G_z2 = output.mean().item()
    optimizerG.step()

    if i % 400 == 0:
      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
    
    G_losses.append(errG.item())
    D_losses.append(errD.item())

    if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
      with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
      img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
    iters += 1

3.可视化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

real_batch = next(iter(dataloader))
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()


网站公告

今日签到

点亮在社区的每一天
去签到