简介
简介:在训练数据样本之前首先利用VAE来推断潜在空间中不同类的分布,用于后续的训练,并使用它来初始化GAN。与ACGAN和BAGAN不同的是,提出的GIEGAN有一个分类器结构,这个分类器主要判断生成的图像或者样本图像属于哪个类,而鉴别器仅判断图像是来自于生成器还是真实样本。
论文题目:Generator Information Enhancement Generative Adversarial Networks for Alleviating Data Imbalance Problems(生成器信息增强生成对抗网络缓解数据不平衡问题)
会议:2022 7th International Conference on Intelligent Computing and Signal Processing (ICSP)
摘要:在现实数据集中,大量的类分布是不平衡的,严重影响了深度学习方法的分类能力。数据增强可以通过为少数类创建数据来重新平衡数据集来解决此类问题。然而,如何从给定的有限训练数据的真实分布中生成数据仍然是一个挑战。为了克服这一挑战,我们提出了一种新颖有效的合成过采样方法,生成器信息增强生成对抗网络(GIE-GAN)。其生成信息在两个方面得到增强。首先,生成器可以通过使用变分自编码器进行初始化来获得所有类的公共知识。其次,在训练过程中,在GAN中加入一个能够为生成器提供准确分类结果的独立分类器,进一步增强生成信息。在基准数据集上的实验结果表明,与竞争对手的方法相比,我们的方法可以生成高质量的图像,并获得更好的深度学习分类器性能,特别是在高失衡比的情况下。
模型结构
VAE结构介绍
VAE使用编码器将图像X编码为潜在向量z和解码器,将向量Z解码为与X一样相似的图像。 潜在矢量通过编码器预测的高斯分布。 在这项工作中,VAE培训采用了L2重建损失和KL损失。
通过编码器可以得到均值与对数方差,而解码器之后可以得到重建的图像。
class VAEEncoder(nn.Module):
def __init__(self, latent_dim):
super(VAEEncoder, self).__init__()
self.fc1 = nn.Linear(1024, 256)
self.fc21 = nn.Linear(256, latent_dim) # 均值
self.fc22 = nn.Linear(256, latent_dim) # 对数方差
def forward(self, x):
h1 = torch.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
class VAEDecoder(nn.Module):
def __init__(self, latent_dim):
super(VAEDecoder, self).__init__()
self.fc1 = nn.Linear(latent_dim, 256)
self.fc2 = nn.Linear(256, 1024)
def forward(self, z):
h1 = torch.relu(self.fc1(z))
return torch.tanh(self.fc2(h1))
生成器架构
生成器的架构考虑DCGAN的模型结构。在代码中是如下安排:
class Generator(nn.Module):
def __init__(self, latent_dim, num_classes):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(num_classes, latent_dim)
self.model = nn.Sequential(
nn.ConvTranspose2d(latent_dim * 2, 512, kernel_size=4, stride=1, padding=0),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, kernel_size&#