第G1周:生成对抗网络(GAN)入门

发布于:2024-06-29 ⋅ 阅读:(16) ⋅ 点赞:(0)

本文为🔗365天深度学习训练营中的学习记录博客
🍖 原作者:K同学啊 | 接辅导、项目定制
🚀 文章来源:K同学的学习圈子深度学习第J6周:ResNeXt-50实战解析

一.理论

生成对抗网络(Generative Adversarial Networks,GAN)是近年来深度学习领域的一个热点方向,GAN并不指代某一个具体的神经网络,而是指一类基于博弈思想而设计的神经网络。

GAN由两个分别被称为生成器(Generator)和判别器(Discriminator)的神经网络组成。

其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中与真实样本非常相似的人工样本;

判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。

生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。

GANS中,判别器 D对于输入的样本 x,输出一个[0,1]之间的概率数值 D(x)。x可能是来自于原始数据集中的真实样本 x,也可能是来自于生成器 G的人工样本 G(z)。通常约定,概率值 D(x)越接近于1就代表此样本为真实样本的可能性更大;反之概率值越小则此样本为伪造样本的可能性越大。也就是说,这里的判别器是一个二分类的神经网络分类器,目的不是判定输入数据的原始类别,而是区分输入样本的真伪。可以注意到,不管在生成器还是判别器中,样本的类别信息都没有用到,也表明 GAN 是一个无监督的学习过程。

如图1所示,GAN是由两个模型组成的:生成模型G和判别模型D。首先第一代生成模型1G的输入是随机噪声z,然后生成模型会生成一张初级照片,训练一代判别模型1D另其进行二分类操作,将生成的图片判别为0,而真实图片判别为1;为了欺瞒一代鉴别器,于是一代生成模型开始优化,然后它进阶成了二代,当它生成的数据成功欺瞒1D时,鉴别模型也会优化更新,进而升级为2D,按照同样的过程也会不断更新出N代的G和D。

二、代码

import argparse
import os
import numpy as np
import torch.cuda
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import  DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
 
os.makedirs("./images/",exist_ok = True) #记录训练过程的图片效果
os.makedirs("./save/",exist_ok = True) #训练完成时模型保存的位置
os.makedirs("./datasets/mnist",exist_ok = True) #下载数据集存放的位置
 
#超参数配置
n_epochs = 50
batch_size = 64
Ir = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 2
latent_dim = 100
img_size = 28
channels = 1
sample_interval = 500
 
img_shape = (channels,img_size,img_size) #图像尺寸(1,28,28)
img_area = np.prod(img_shape) #图像像素面积784
 
cuda = True if torch.cuda.is_available() else False
print(cuda)
mnist = datasets.MNIST(
    root = './datasets/',train = True,download =True,transform = transforms.Compose(
     [transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]),)
dataoader=DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)
#鉴别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1),
            nn.Sigmoid(),
        )
 
    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        return validity
##生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
 
        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace = True))
            return layers
 
        self.model = nn.Sequential(
            *block(latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
 
    def forward(self,z):
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0),*img_shape)
        return imgs
generator = Generator()
discriminator = Discriminator()
 
criterion = torch.nn.BCELoss()
 
optimizer_G = torch.optim.Adam(generator.parameters(), lr=Ir, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = Ir,betas = (b1,b2))
 
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
for epoch in range(n_epochs):
    for i,(imgs,_) in enumerate(dataloader):
 
        imgs = imgs.view(imgs.size(0),-1)
        real_img = Variable(imgs).cuda()
        real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
 
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
 
        z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out
 
        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
 
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()
        fake_img = generator(z)
        fake_out = discriminator(fake_img)
        loss_G = criterion(fake_out, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
 
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i ,len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
 
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow = 5, normalize = True)
 
#模型保存
torch.save(generator.state_dict(),'./save/generator.pth')
torch.save(discriminator.state_dict(),'./save/discriminator.pth')


网站公告

今日签到

点亮在社区的每一天
去签到