import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from openpyxl.styles.builtins import output
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
def gen_img_plot(model, text_input):
    """Display a 4x4 grid of images generated by `model` from `text_input`.

    Only the first 16 samples are drawn. Pixel values are mapped from the
    generator's tanh range [-1, 1] back to [0, 1] for display.
    """
    samples = np.squeeze(model(text_input).detach().cpu().numpy()[:16])
    plt.figure(figsize=(4, 4))
    for idx, sample in enumerate(samples):
        plt.subplot(4, 4, idx + 1)
        # tanh output lives in [-1, 1]; shift/scale into [0, 1] for imshow
        plt.imshow((sample + 1) / 2)
        plt.axis('off')
    plt.show()
# MNIST is single-channel, so Normalize takes one mean and one std.
# BUG FIX: the original passed a third positional list, which landed in
# Normalize's `inplace` parameter (expects a bool) — removed.
_mnist_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # map pixels from [0, 1] to [-1, 1] to match the generator's tanh range
    transforms.Normalize([0.5], [0.5]),
])
dataset_train = datasets.MNIST(root='./DATA', train=True, download=False,
                               transform=_mnist_transform)
dataset_test = datasets.MNIST(root='./DATA', train=False, download=False,
                              transform=_mnist_transform)
train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=64, shuffle=False)
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> single-channel 28x28 image.

    The latent batch is reshaped to NCHW with 1x1 spatial size, then
    upsampled by transposed convolutions: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 28x28.
    Output values lie in [-1, 1] (tanh).
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        upsample_layers = [
            # input: [batch, latent_dim, 1, 1]
            nn.ConvTranspose2d(latent_dim, 32, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # [batch, 32, 4, 4]
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # [batch, 16, 8, 8]
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # [batch, 8, 16, 16]; padding=3 makes 16 -> (16-1)*2 - 2*3 + 4 = 28
            nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=3, bias=False),
            nn.Tanh(),
            # output: [batch, 1, 28, 28]
        ]
        self.model = nn.Sequential(*upsample_layers)

    def forward(self, z):
        """Decode a [batch, latent_dim] noise batch into images."""
        latent = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(latent)
class Discriminator(nn.Module):
    """DCGAN-style discriminator: 1x28x28 image -> real-image probability.

    Strided convolutions downsample 28 -> 14 -> 7 -> 4 -> 1; the final
    sigmoid yields a per-sample score in (0, 1), returned as shape [batch].
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        downsample_layers = [
            # input: [batch, 1, 28, 28]
            nn.Conv2d(1, 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 4, 14, 14]
            nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 8, 7, 7]
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            # [batch, 16, 4, 4]
            nn.Conv2d(16, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
            # [batch, 1, 1, 1]
        ]
        self.model = nn.Sequential(*downsample_layers)

    def forward(self, img):
        """Score a batch of images; returns a 1-D tensor of probabilities."""
        scores = self.model(img)
        return scores.view(-1, 1).squeeze(1)
# ---- models, optimizers, loss ------------------------------------------
generator = Generator()
discriminator = Discriminator()
# Discriminator learns at twice the generator's rate (common GAN heuristic).
G_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001)
D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
criterion = torch.nn.BCELoss()

num_epoch = 100
G_loss_save = []  # per-epoch mean generator loss (plain floats)
D_loss_save = []  # per-epoch mean discriminator loss (plain floats)

for epoch in range(num_epoch):
    G_epoch_loss = 0.0
    D_epoch_loss = 0.0
    count = len(train_loader)
    for i, (img, _) in enumerate(train_loader):
        size = img.size(0)
        # Sample latent noise and generate a batch of fake images.
        fake_img = torch.randn(size, 100)
        output_fake = generator(fake_img)

        # ---- discriminator step: push fake scores to 0, real scores to 1 ----
        # detach() blocks gradients from flowing into the generator here.
        fake_score = discriminator(output_fake.detach())
        D_fake_loss = criterion(fake_score, torch.zeros_like(fake_score))
        real_score = discriminator(img)
        D_real_loss = criterion(real_score, torch.ones_like(real_score))
        D_loss = D_fake_loss + D_real_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- generator step: make the updated discriminator output 1 ----
        fake_G_score = discriminator(output_fake)
        G_fake_loss = criterion(fake_G_score, torch.ones_like(fake_G_score))
        G_optimizer.zero_grad()
        G_fake_loss.backward()
        G_optimizer.step()

        # BUG FIX: accumulate plain floats via .item() — the original summed
        # tensors, so the history lists ended up holding tensors and kept
        # device memory alive across the whole run.
        G_epoch_loss += G_fake_loss.item()
        D_epoch_loss += D_loss.item()

    G_epoch_loss /= count
    D_epoch_loss /= count
    G_loss_save.append(G_epoch_loss)
    D_loss_save.append(D_epoch_loss)
    # epoch is 0-based; display 1-based so the last line reads [100/100]
    print('Epoch:[%d/%d] | G_loss:%.3f | D_loss:%.3f'
          % (epoch + 1, num_epoch, G_epoch_loss, D_epoch_loss))

# Visualize a grid of samples from a fresh noise batch after training.
text_input = torch.randn(64, 100)
gen_img_plot(generator, text_input)
# 训练50轮后效果如下 (results after 50 training epochs are shown below):