这篇文章我们就来利用之前学到的知识,基于 MNIST 数据集 生成手写数字图像。本节课的核心代码和之前文章的内容基本重合,所以这里不做过于详细地解释,主要还是通过这样一个例子来说明扩散模型如何生成图片的。
1 数据集加载
在这部分,我们将开始准备数据。我们导入了所有后续训练和生成所需的库和工具。
# 导入标准库和依赖
import glob
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
# 优化器和数据加载工具
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
# 用于张量重排的 einops
from einops.layers.torch import Rearrange
# 可视化工具
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.utils import save_image, make_grid
# 用户定义的工具函数
from utils import other_utils
from utils import ddpm_utils
from utils import UNet_utils
# 设置计算设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
我们使用的数据是 MNIST,它的结构与之前用过的 FashionMNIST 非常相似,因此我们可以复用之前的数据处理逻辑。这里我们没有进行随机水平翻转操作,因为数字通常不是为了反过来看。
# 加载 MNIST 数据集
def load_MNIST(data_transform, train=True):
return torchvision.datasets.MNIST(
"./data/",
download=True,
train=train,
transform=data_transform,
)
# 对 MNIST 进行预处理、拼接训练和测试集,并返回 DataLoader
def load_transformed_MNIST(img_size, batch_size):
data_transforms = [
transforms.Resize((img_size, img_size)),
transforms.ToTensor(), # 将图像像素缩放到 [0,1]
]
data_transform = transforms.Compose(data_transforms)
train_set = load_MNIST(data_transform, train=True)
test_set = load_MNIST(data_transform, train=False)
data = torch.utils.data.ConcatDataset([train_set, test_set])
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
return data, dataloader
接下来,我们定义图像的基本参数,包括图像大小、通道数、批量大小等,然后加载数据。
# 定义图像大小、通道数、批量大小和类别数
IMG_SIZE = 28
IMG_CH = 1
BATCH_SIZE = 128
N_CLASSES = 10
# 加载 MNIST 数据和 DataLoader
data, dataloader = load_transformed_MNIST(IMG_SIZE, BATCH_SIZE)
# 再次定义设备(保证兼容性)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2 构建扩散过程
这一部分中,我们将初始化前向扩散过程所需的变量,并定义扩散过程的数学公式。
# 定义扩散步数与 beta 的初始值和终值
nrows = 10
ncols = 15
T = nrows * ncols
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)
下面我们计算前向过程中的关键变量: α ˉ \bar{\alpha} αˉ 及其平方根形式,还有用于逆扩散的变量。
# 计算每一步的 alpha、alpha 累乘、均值系数、标准差系数等
a = 1.0 - B
a_bar = torch.cumprod(a, dim=0)
sqrt_a_bar = torch.sqrt(a_bar) # 均值系数
sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar) # 标准差系数
# 逆扩散过程中使用的变量
sqrt_a_inv = torch.sqrt(1 / a)
pred_noise_coeff = (1 - a) / torch.sqrt(1 - a_bar) # 预测噪声的系数
下面是 q
函数的定义,它将输入图像添加噪声变成扩散图像。
# 定义前向扩散函数 q(x_0, t)
def q(x_0, t):
t = t.int()
noise = torch.randn_like(x_0)
sqrt_a_bar_t = sqrt_a_bar[t, None, None, None]
sqrt_one_minus_a_bar_t = sqrt_one_minus_a_bar[t, None, None, None]
x_t = sqrt_a_bar_t * x_0 + sqrt_one_minus_a_bar_t * noise
return x_t, noise
下面这段代码会可视化每一个时间步 t
之后图像是如何变得越来越模糊、越来越像纯噪声的。
# 可视化图像如何逐渐扩散为噪声
plt.figure(figsize=(8, 8))
x_0 = data[0][0].to(device)
xs = []
for t in range(T):
t_tenser = torch.Tensor([t]).type(torch.int64)
x_t, _ = q(x_0, t_tenser)
img = torch.squeeze(x_t).cpu()
xs.append(img)
ax = plt.subplot(nrows, ncols, t + 1)
ax.axis('off')
other_utils.show_tensor_image(x_t)
输出:
接下来我们实现逆扩散函数 reverse_q
,用于将噪声图像一步步还原成原始图像。
# 定义逆扩散过程 reverse_q
@torch.no_grad()
def reverse_q(x_t, t, e_t):
t = t.int()
pred_noise_coeff_t = pred_noise_coeff[t]
sqrt_a_inv_t = sqrt_a_inv[t]
u_t = sqrt_a_inv_t * (x_t - pred_noise_coeff_t * e_t)
if t[0] == 0: # 扩散结束,返回还原图像
return u_t
else:
B_t = B[t - 1] # 添加上一个时间步的噪声
new_noise = torch.randn_like(x_t)
return u_t + torch.sqrt(B_t) * new_noise
3 构建 U-Net 网络
在这一部分,我们将定义用于预测噪声的 U-Net 架构。
# 定义 U-Net 模型
class UNet(nn.Module):
def __init__(
self, T, img_ch, img_size, down_chs=(64, 64, 128), t_embed_dim=8, c_embed_dim=10
):
super().__init__()
self.T = T
up_chs = down_chs[::-1] # 反转通道用于上采样
latent_image_size = img_size // 4 # 下采样两次后的尺寸
small_group_size = 8
big_group_size = 32
# 初始卷积层
self.down0 = ResidualConvBlock(img_ch, down_chs[0], small_group_size)
# 下采样模块
self.down1 = DownBlock(down_chs[0], down_chs[1], big_group_size)
self.down2 = DownBlock(down_chs[1], down_chs[2], big_group_size)
self.to_vec = nn.Sequential(nn.Flatten(), nn.GELU())
# 嵌入模块(时间、类别等)
self.dense_emb = nn.Sequential(
nn.Linear(down_chs[2] * latent_image_size**2, down_chs[1]),
nn.ReLU(),
nn.Linear(down_chs[1], down_chs[1]),
nn.ReLU(),
nn.Linear(down_chs[1], down_chs[2] * latent_image_size**2),
nn.ReLU(),
)
self.sinusoidaltime = SinusoidalPositionEmbedBlock(t_embed_dim)
self.t_emb1 = EmbedBlock(t_embed_dim, up_chs[0])
self.t_emb2 = EmbedBlock(t_embed_dim, up_chs[1])
self.c_embed1 = EmbedBlock(c_embed_dim, up_chs[0])
self.c_embed2 = EmbedBlock(c_embed_dim, up_chs[1])
# 上采样模块
self.up0 = nn.Sequential(
nn.Unflatten(1, (up_chs[0], latent_image_size, latent_image_size)),
GELUConvBlock(up_chs[0], up_chs[0], big_group_size),
)
self.up1 = UpBlock(up_chs[0], up_chs[1], big_group_size)
self.up2 = UpBlock(up_chs[1], up_chs[2], big_group_size)
# 最终输出卷积层
self.out = nn.Sequential(
nn.Conv2d(2 * up_chs[-1], up_chs[-1], 3, 1, 1),
nn.GroupNorm(small_group_size, up_chs[-1]),
nn.ReLU(),
nn.Conv2d(up_chs[-1], img_ch, 3, 1, 1),
)
def forward(self, x, t, c, c_mask):
# 编码过程
down0 = self.down0(x)
down1 = self.down1(down0)
down2 = self.down2(down1)
latent_vec = self.to_vec(down2)
latent_vec = self.dense_emb(latent_vec)
# 位置编码和条件编码
t = t.float() / self.T
t = self.sinusoidaltime(t)
t_emb1 = self.t_emb1(t)
t_emb2 = self.t_emb2(t)
c = c * c_mask
c_emb1 = self.c_embed1(c)
c_emb2 = self.c_embed2(c)
# 解码过程
up0 = self.up0(latent_vec)
up1 = self.up1(c_emb1 * up0 + t_emb1, down2)
up2 = self.up2(c_emb2 * up1 + t_emb2, down1)
return self.out(torch.cat((up2, down0), 1))
下面是网络结构中用到的各个功能模块:
# 下采样模块:包含两个卷积和一个重排池化
class DownBlock(nn.Module):
def __init__(self, in_chs, out_chs, group_size):
super(DownBlock, self).__init__()
layers = [
GELUConvBlock(in_chs, out_chs, group_size),
GELUConvBlock(out_chs, out_chs, group_size),
RearrangePoolBlock(out_chs, group_size),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
# 嵌入模块:用于时间和条件信息编码
class EmbedBlock(nn.Module):
def __init__(self, input_dim, emb_dim):
super(EmbedBlock, self).__init__()
self.input_dim = input_dim
layers = [
nn.Linear(input_dim, emb_dim),
nn.GELU(),
nn.Linear(emb_dim, emb_dim),
nn.Unflatten(1, (emb_dim, 1, 1)),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
x = x.view(-1, self.input_dim)
return self.model(x)
# 卷积模块,使用 GELU 激活
class GELUConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, group_size):
super().__init__()
layers = [
nn.Conv2d(in_ch, out_ch, 3, 1, 1),
nn.GroupNorm(group_size, out_ch),
nn.GELU(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
# 重排+卷积池化模块,用于压缩特征图
class RearrangePoolBlock(nn.Module):
def __init__(self, in_chs, group_size):
super().__init__()
self.rearrange = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
self.conv = GELUConvBlock(4 * in_chs, in_chs, group_size)
def forward(self, x):
x = self.rearrange(x)
return self.conv(x)
# 残差卷积模块
class ResidualConvBlock(nn.Module):
def __init__(self, in_chs, out_chs, group_size):
super().__init__()
self.conv1 = GELUConvBlock(in_chs, out_chs, group_size)
self.conv2 = GELUConvBlock(out_chs, out_chs, group_size)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
out = x1 + x2
return out
# 正弦位置编码,用于时间步嵌入
class SinusoidalPositionEmbedBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
# 上采样模块,含多个卷积块
class UpBlock(nn.Module):
def __init__(self, in_chs, out_chs, group_size):
super(UpBlock, self).__init__()
layers = [
nn.ConvTranspose2d(2 * in_chs, out_chs, 2, 2),
GELUConvBlock(out_chs, out_chs, group_size),
GELUConvBlock(out_chs, out_chs, group_size),
GELUConvBlock(out_chs, out_chs, group_size),
GELUConvBlock(out_chs, out_chs, group_size),
]
self.model = nn.Sequential(*layers)
def forward(self, x, skip):
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
网络结构已经准备好了,接下来我们来定义模型并统计参数数量:
# 实例化模型,并使用 torch.compile 优化运行
model = UNet(
T, IMG_CH, IMG_SIZE, down_chs=(64, 64, 128), t_embed_dim=8, c_embed_dim=N_CLASSES
)
print("Num params: ", sum(p.numel() for p in model.parameters())) # 输出2547457
model = torch.compile(model.to(device))
4 模型训练
首先我们定义一个classifier-free guidance
函数,这个在前面的文章详细介绍过。在你之前没加入类别嵌入的时候,模型本来就是 纯无条件的,完全靠图像噪声和时间步来恢复图像。现在你加入类别嵌入后,模型变成了条件模型。为了保留无条件的能力,就用这个随机屏蔽的技巧让模型两种模式都学到。
# 随机丢弃条件编码(实现 classifier-free guidance)
def get_context_mask(c, drop_prob):
c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device)
c_mask = torch.torch.bernoulli(torch.ones_like(c_hot).float() - drop_prob).to(device)
return c_hot, c_mask
然后我们定义一个损失函数,模型需要预测噪声,而我们将预测值和真实噪声之间的均方误差作为损失:
# 计算损失函数:真实噪声与预测噪声之间的 MSE
def get_loss(model, x_0, t, *model_args):
x_noisy, noise = q(x_0, t)
noise_pred = model(x_noisy, t, *model_args)
return F.mse_loss(noise, noise_pred)
然后是采样函数,用于训练中周期性生成图像,以观察训练进度:
# 从纯噪声逐步采样并生成图像,观察训练效果
def sample_images(model, img_ch, img_size, ncols, *model_args, axis_on=False):
x_t = torch.randn((1, img_ch, img_size, img_size), device=device)
plt.figure(figsize=(8, 8))
hidden_rows = T / ncols
plot_number = 1
# 从 T 倒推回 0,逐步去噪
for i in range(0, T)[::-1]:
t = torch.full((1,), i, device=device).float()
e_t = model(x_t, t, *model_args) # 预测噪声
x_t = reverse_q(x_t, t, e_t)
if i % hidden_rows == 0:
ax = plt.subplot(1, ncols+1, plot_number)
if not axis_on:
ax.axis('off')
other_utils.show_tensor_image(x_t.detach().cpu())
plot_number += 1
plt.show()
最后,我们开始训练模型:
# 训练模型主循环
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 5
preview_c = 0 # 每轮预览的数字类别
model.train()
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
c_drop_prob = 0.1 # 有 10% 概率丢弃条件
optimizer.zero_grad()
# 生成随机时间步
t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()
x = batch[0].to(device)
c_hot, c_mask = get_context_mask(batch[1].to(device), c_drop_prob)
# 计算损失并更新参数
loss = get_loss(model, x, t, c_hot, c_mask)
loss.backward()
optimizer.step()
# 每 100 步显示一次预览
if epoch % 1 == 0 and step % 100 == 0:
print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {preview_c}")
c_drop_prob = 0 # 预览时不丢条件
c_hot, c_mask = get_context_mask(torch.Tensor([preview_c]), c_drop_prob)
sample_images(model, IMG_CH, IMG_SIZE, ncols, c_hot, c_mask)
preview_c = (preview_c + 1) % N_CLASSES
部分输出如下:
5 图像采样
这是训练完成后生成图像的关键步骤。我们实现一个 Classifier-Free Diffusion Guidance
的变体,它通过加权控制有条件和无条件的预测输出之间的差异来提升图像生成的可控性。
# 基于 CFG 指导的采样函数
@torch.no_grad()
def sample_w(model, c, w):
input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)
n_samples = len(c)
w = torch.tensor([w]).float()
w = w[:, None, None, None].to(device)
x_t = torch.randn(n_samples, *input_size).to(device)
# 为每个 w 创建一份条件编码
c = c.repeat(len(w), 1)
# 扩展 batch,包含保留和丢弃条件的两个版本
c = c.repeat(2, 1)
c_mask = torch.ones_like(c).to(device)
c_mask[n_samples:] = 0.0 # 后一半去掉条件
for i in range(0, T)[::-1]:
t = torch.tensor([i]).to(device)
t = t.repeat(n_samples, 1, 1, 1)
x_t = x_t.repeat(2, 1, 1, 1)
t = t.repeat(2, 1, 1, 1)
# 计算加权噪声
e_t = model(x_t, t, c, c_mask)
e_t_keep_c = e_t[:n_samples]
e_t_drop_c = e_t[n_samples:]
e_t = (1 + w) * e_t_keep_c - w * e_t_drop_c
x_t = x_t[:n_samples]
t = t[:n_samples]
x_t = reverse_q(x_t, t, e_t)
return x_t
下面这段代码用于可视化不同条件下生成的数字图像:
# 测试采样效果(可多次尝试不同 w)
model.eval()
w = 0.5 # 可调节该参数以提升图像识别性
c = torch.arange(N_CLASSES).to(device)
c_drop_prob = 0
c_hot, c_mask = get_context_mask(c, c_drop_prob)
x_0 = sample_w(model, c_hot, w)
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES))
输出:
我们需要确保生成图像的形状是 [10, 1, 28, 28]
:
# 检查图像形状是否正确
x_0.shape
传统的神经网络会有一个测试集用于评估模型表现,但在生成式 AI 中并不总是如此。因为生成的图像好坏往往是主观判断的,是否过拟合,也取决于开发者是否接受。
6 总结
至此,我们完成了《基于扩散模型的生成式AI实战》系列的最后一篇文章。本篇的目标是巩固之前的知识,并以 MNIST 数据集为例,完整实现一个能“生成手写数字”的扩散模型。
我们没有重新讲解太多概念,而是通过“动手实践”的方式,回顾了整个扩散过程的核心机制:
- 如何构建前向扩散与逆向还原过程;
- 如何利用 U-Net 架构进行噪声预测;
- 如何引入时间与条件嵌入;
- 如何训练模型预测噪声;
- 最后使用
Classifier-Free Guidance
技术增强采样效果。
如果你想进一步深入扩散模型,可以尝试更复杂的数据集(如 CIFAR-10
、CelebA
),引入更多控制条件(如文本或风格),或探索更强大的模型如 Stable Diffusion
和 ControlNet
。无论是入门学习还是项目实战,相信这一系列内容都为你打下了坚实基础。