# demo_GAN

# 发布于:2024-10-17 ⋅ 阅读:(12) ⋅ 点赞:(0)  (blog-page metadata: published 2024-10-17, views: 12, likes: 0)
# 导入PyTorch库,这是一个用于深度学习的开源库
import torch
# 导入PyTorch的神经网络模块(nn),用于定义神经网络结构
import torch.nn as nn
# 导入PyTorch的函数式模块(functional),提供了一些常用的激活函数和损失函数等
import torch.nn.functional as F
# 导入PyTorch的优化器模块(optim),用于定义优化算法,如梯度下降等
import torch.optim as optim
# 从PyTorch的数据加载器模块中导入DataLoader和TensorDataset类,用于加载和处理数据集
from torch.utils.data import DataLoader, TensorDataset
# 从torchvision库的实用工具模块中导入save_image函数,用于保存生成的图像
from torchvision.utils import save_image
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
# 导入os模块,用于处理文件和目录操作
import os
import matplotlib.pyplot as plt


# Self-attention module over the spatial positions of a feature map.
class SelfAttention(nn.Module):
    """Spatial self-attention with a learnable residual gate.

    Computes ``gamma * attention(x) + x``. ``gamma`` is a learnable scalar
    initialised to 0, so at initialisation the block returns its input
    unchanged.

    Args:
        in_dim: number of input (and output) channels.
        reduction: channel reduction factor for the query/key projections
            (default 8, the value previously hard-coded). The projected width
            is clamped to at least 1 so that ``in_dim < reduction`` does not
            produce zero-channel query/key convolutions.
    """

    def __init__(self, in_dim, reduction=8):
        super(SelfAttention, self).__init__()
        qk_dim = max(1, in_dim // reduction)
        self.query = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.key = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply attention to an ``(N, C, H, W)`` tensor; returns the same shape."""
        # PyTorch image tensors are (N, C, H, W).
        batch_size, C, height, width = x.size()
        positions = height * width
        # (N, HW, C') @ (N, C', HW) -> (N, HW, HW): pairwise position similarities.
        query = self.query(x).view(batch_size, -1, positions).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, positions)
        attention = F.softmax(torch.bmm(query, key), dim=-1)
        # Each output position is an attention-weighted sum of the value vectors.
        value = self.value(x).view(batch_size, -1, positions)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)
        return self.gamma * out + x

# Generator model: a PyTorch nn.Module mapping (noise, label) vectors to RGB images.
class Generator(nn.Module):
    """Conditional generator producing 3x128x128 images with values in [-1, 1].

    Args:
        noise_dim: length of each input noise vector.
        label_dim: length of each label vector concatenated to the noise.
    """

    def __init__(self, noise_dim, label_dim):
        super(Generator, self).__init__()
        self.label_dim = label_dim

        # Input projection: Linear([noise, label] -> 1024*2*2) -> BatchNorm1d -> LeakyReLU(0.2),
        # later reshaped into a 1024-channel 2x2 feature map.
        seed_features = 1024 * 2 * 2
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + label_dim, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Six upsampling stages, each ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm2d -> LeakyReLU,
        # doubling the spatial size: 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.
        # Each entry is (in_channels, out_channels, LeakyReLU negative slope).
        upsample_stages = (
            (1024, 512, 0.1),
            (512, 256, 0.1),
            (256, 128, 0.1),
            (128, 64, 0.2),
            (64, 32, 0.2),
            (32, 16, 0.2),
        )
        layers = []
        for in_ch, out_ch, slope in upsample_stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(slope, inplace=True))
        # Output head: size-preserving 3x3 transposed conv to RGB, squashed to [-1, 1] by Tanh.
        layers.append(nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())
        self.deconv_layers = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images from ``noise`` (N, noise_dim) and ``labels`` (N, label_dim).

        In this script ``labels`` are one-hot vectors. Returns a
        (N, 3, 128, 128) tensor.
        """
        conditioned = torch.cat([noise, labels], dim=1)
        seed = self.fc(conditioned).view(-1, 1024, 2, 2)
        return self.deconv_layers(seed)


class Discriminator(nn.Module):
    """Multi-scale convolutional discriminator returning a real/fake probability.

    Every convolution and the final linear layer are wrapped in spectral
    normalisation; there is no batch normalisation. For a 128x128 input the
    four stride-2 stages reduce the feature map to 8x8, and the final pooling
    step produces the fixed 4x4 grid the linear head expects. Because that
    pooling is adaptive, other input sizes are accepted as well.

    Args:
        input_channels: number of channels in the input images (3 for RGB).
    """

    def __init__(self, input_channels):
        super(Discriminator, self).__init__()

        # (1) Parallel 3x3 and 5x5 stride-2 convolutions; their 64 + 64 outputs are concatenated.
        self.conv1_3x3 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.1, inplace=True)
        )
        self.conv1_5x5 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=5, stride=2, padding=2)),
            nn.LeakyReLU(0.1, inplace=True)
        )

        # (2) Spectral-norm conv + LeakyReLU, stride 2 (dilation rate 1).
        self.conv2 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1)),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # (3) Spectral-norm conv + LeakyReLU, stride 2, followed by self-attention.
        self.conv3 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(256)
        )

        # (4) Parallel 3x3, 5x5 and 7x7 stride-2 convolutions; 3 x 512 = 1536 channels after concatenation.
        self.conv4_3x3 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4_5x5 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4_7x7 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=256, out_channels=512, kernel_size=7, stride=2, padding=3)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # (5) Dilated (rate 3) 3x3 conv, stride 1; padding 3 keeps the spatial size unchanged.
        self.conv5 = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels=1536, out_channels=1024, kernel_size=3, stride=1, padding=3, dilation=3)),
            nn.LeakyReLU(0.1, inplace=True)
        )

        # (6) Fully connected head on the pooled 1024 x 4 x 4 features, followed by a sigmoid.
        self.fc = nn.utils.spectral_norm(nn.Linear(1024 * 4 * 4, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return an ``(N, 1)`` tensor of probabilities that each image in ``x`` is real."""
        x = torch.cat((self.conv1_3x3(x), self.conv1_5x5(x)), dim=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = torch.cat((self.conv4_3x3(x), self.conv4_5x5(x), self.conv4_7x7(x)), dim=1)
        x = self.conv5(x)
        # Pool to a fixed 4x4 grid so the flattened size always matches self.fc.
        # For 128x128 inputs conv5 yields 8x8, where this equals 2x2 average pooling.
        x = F.adaptive_avg_pool2d(x, (4, 4))
        x = torch.flatten(x, 1)
        return self.sigmoid(self.fc(x))

# Hyperparameters
noise_dim = 100  # length of the generator's input noise vector
label_dim = 58  # number of classes, i.e. length of the one-hot label vector given to G
batch_size = 64  # mini-batch size
learning_rate = 0.0001  # base learning rate
num_epochs = 500  # number of training epochs
output_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/MMSGAN"  # where generated images are saved

# Make sure the output directory exists (exist_ok avoids a check-then-create race).
os.makedirs(output_dir, exist_ok=True)

# Build the generator and the discriminator on the GPU.
# NOTE(review): 'cuda' is hard-coded here and elsewhere in the script, so it needs a CUDA device.
G = Generator(noise_dim=noise_dim, label_dim=label_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')


# TrafficSignDataset: loads (image, integer label) pairs listed in a labels file.
class TrafficSignDataset(Dataset):
    """Image dataset described by a whitespace-separated labels file.

    Each non-blank line of ``labels_file`` has the form
    ``<image file name> <integer label>``. Image file names are resolved
    relative to ``root_dir`` and may themselves contain spaces, because the
    label is taken from the last whitespace-separated token.
    """

    def __init__(self, root_dir, labels_file, transform=None, encoding=None):
        """Parse ``labels_file`` into parallel lists of image paths and labels.

        Args:
            root_dir: directory that image file names are relative to.
            labels_file: path of the labels file.
            transform: optional callable applied to each loaded PIL image.
            encoding: text encoding of the labels file; ``None`` (the default)
                uses the platform default, as ``open`` does.

        Raises:
            ValueError: if a non-blank line has no label or its label is not an integer.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        with open(labels_file, 'r', encoding=encoding) as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. extra newlines at the end of the file
                # Split on the LAST whitespace run so image names containing spaces still parse.
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{labels_file}:{line_no}: expected '<image name> <label>', got {line!r}")
                img_name, label = parts
                self.image_paths.append(os.path.join(root_dir, img_name))
                self.labels.append(int(label))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB and transformed if a transform was given."""
        img_path = self.image_paths[idx]
        # The context manager releases the file handle deterministically;
        # convert() loads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Loss function and optimizers
criterion = nn.BCELoss()  # binary cross-entropy on probabilities (D ends in a Sigmoid)
# The generator uses a 4x larger learning rate than the discriminator, plus L2 weight decay.
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate * 4, betas=(0.5, 0.999), weight_decay=1e-4)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Coefficient of the inverse-time learning-rate decay
#     lr_epoch = lr_0 / (1 + decay * epoch),
# to be applied once per epoch inside the training loop.
decay = 0.0001

# Image preprocessing (no data augmentation is applied)
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # match the 128x128 resolution of the generated images
    # ToTensor converts a PIL image (H, W, C) with uint8 values 0-255 into a float
    # tensor of shape (C, H, W) with values in [0, 1], as PyTorch models expect.
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

# Dataset and data loader
root_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct"
labels_file = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct/labels.txt"  # labels file path
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
# The generator's BatchNorm1d raises in training mode on a batch of a single sample,
# so drop the final batch exactly when it would contain only one image.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=len(dataset) % batch_size == 1)


# Sample a batch of random integer class labels together with their one-hot encodings.
def create_labels(batch_size, label_dim, device='cuda'):
    """Draw uniform random class labels and their one-hot encodings.

    Args:
        batch_size: number of labels to draw.
        label_dim: number of classes; labels are drawn from ``[0, label_dim)``.
        device: device the returned tensors are moved to (default ``'cuda'``,
            the previously hard-coded value).

    Returns:
        ``(labels, labels_one_hot)``: an int64 tensor of shape ``(batch_size,)``
        and a float32 tensor of shape ``(batch_size, label_dim)`` with exactly
        one 1 per row.
    """
    # Sample with the CPU generator (as before) so seeded runs keep the same label stream.
    labels = torch.randint(0, label_dim, (batch_size,))
    labels_one_hot = F.one_hot(labels, num_classes=label_dim).float()
    return labels.to(device), labels_one_hot.to(device)


def train():
    """Release unused memory held by PyTorch's CUDA caching allocator.

    NOTE(review): despite its name this performs no training, and nothing in
    the visible file calls it; the training loop runs at module level.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()


# Per-iteration loss histories of the discriminator and the generator, plotted after training.
d_losses = []
g_losses = []

# Training loop
for epoch in range(num_epochs):
    # Inverse-time learning-rate decay, lr = lr_0 / (1 + decay * epoch), computed from the
    # epoch index so the rate actually changes over training; G keeps its 4x base rate.
    lr_scale = 1.0 / (1.0 + decay * epoch)
    for param_group in optimizer_G.param_groups:
        param_group['lr'] = learning_rate * 4 * lr_scale
    for param_group in optimizer_D.param_groups:
        param_group['lr'] = learning_rate * lr_scale

    # NOTE(review): the dataset labels are discarded and D is unconditional, so D never
    # sees the class labels G is conditioned on - confirm this is intended.
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to('cuda')
        n_images = real_images.size(0)

        # 1. Train the discriminator.
        # Real images: smoothed targets in (0.7, 0.8] instead of 1.
        real_outputs = D(real_images)
        noise_real = torch.rand_like(real_outputs) * -0.1
        real_loss = criterion(real_outputs, torch.full_like(real_outputs, 0.8) + noise_real)

        # Generated images, conditioned on random one-hot labels.
        noise = torch.randn(n_images, noise_dim).to('cuda')
        fake_labels, fake_labels_one_hot = create_labels(n_images, label_dim)
        fake_images = G(noise, fake_labels_one_hot)
        fake_outputs = D(fake_images.detach())  # detach: the D update must not backprop into G
        # Fake targets with random noise added: values in [0.2, 0.3).
        noise_fake = torch.rand_like(fake_outputs) * 0.1
        fake_loss = criterion(fake_outputs, torch.full_like(fake_outputs, 0.2) + noise_fake)

        # Total discriminator loss
        d_loss = real_loss + fake_loss
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # 2. Train the generator: push D's output on the fakes towards 1.
        # This backward also fills D's .grad buffers; optimizer_D.zero_grad() clears
        # them before D's next update.
        fake_outputs = D(fake_images)
        g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # .item() synchronises with the device; call it once per loss and reuse the value.
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        d_losses.append(d_loss_value)
        g_losses.append(g_loss_value)

        # Print the losses
        if i % 50 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
                  f"D Loss: {d_loss_value:.4f}, G Loss: {g_loss_value:.4f}")

        # Periodically save up to 30 generated images. The step is part of the file name
        # so several saves within the same epoch do not overwrite each other.
        if i % 200 == 0:
            for idx in range(min(30, fake_images.size(0))):
                save_image(fake_images[idx],
                           os.path.join(output_dir, f"epoch_{epoch + 1}_step_{i + 1}_image_{idx + 1}.png"),
                           normalize=True, value_range=(-1, 1))  # map G's Tanh range [-1, 1] to [0, 1]

print("训练完成并保存生成图像。")

# Plot the generator and discriminator loss curves
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss', color='blue')
plt.plot(g_losses, label='Generator Loss', color='red')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss During Training')
plt.legend()
plt.grid()
plt.show()