【DLI】Generative AI with Diffusion Models通关秘籍

发布于:2025-04-05 ⋅ 阅读:(12) ⋅ 点赞:(0)

Generative AI with Diffusion Models,加载时间在20分钟左右,耐心等待。
在这里插入图片描述
在这里插入图片描述

6.2TODO

在这里插入图片描述

这里是在设置扩散模型的参数,代码里的FIXME部分需要根据上下文进行替换。以下是各个FIXME的替换说明:
1.a_bar 是 a 的累积乘积,在 PyTorch 里可以用 torch.cumprod 实现。
2.sqrt_a_bar、sqrt_one_minus_a_bar 和 sqrt_a_inv 都是对输入张量求平方根,可使用 torch.sqrt 实现。
3.pred_noise_coeff 中的 FIXME(1 - a_bar) 同样是求平方根,用 torch.sqrt 即可。
以下是替换后的代码:

nrows = 10
ncols = 15

T = nrows * ncols
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)

a = 1.0 - B
a_bar = torch.cumprod(a, dim=0)
sqrt_a_bar = torch.sqrt(a_bar)  # Mean Coefficient
sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar)  # St. Dev. Coefficient

# Reverse diffusion variables
sqrt_a_inv = torch.sqrt(1 / a)
pred_noise_coeff = (1 - a) / torch.sqrt(1 - a_bar)  # Predicted Noise Coefficient

在扩散模型里,正向扩散过程 q 函数是按照如下公式把原始图像 x_0 逐步添加噪声变成 x_t 的
在这里插入图片描述
FIXME 部分应该分别用 sqrt_a_bar_t 和 sqrt_one_minus_a_bar_t 来替换。
在这个 q 函数中,按照扩散模型的正向过程公式,把原始图像 x_0 和随机噪声 noise 按一定比例组合,从而得到加噪后的图像 x_t。

def q(x_0, t):
    t = t.int()
    noise = torch.randn_like(x_0)
    sqrt_a_bar_t = sqrt_a_bar[t, None, None, None]
    sqrt_one_minus_a_bar_t = sqrt_one_minus_a_bar[t, None, None, None]

    x_t = sqrt_a_bar_t * x_0 + sqrt_one_minus_a_bar_t * noise
    return x_t, noise

在反向扩散过程中,我们要根据当前的潜在图像,当前时间步 , 以及预测的噪声 来恢复上一个时间步的图像。在这里插入图片描述
在这个 reverse_q 函数中,我们根据反向扩散过程的公式,从当前的潜在图像和预测的噪声中恢复上一个时间步的图像。如果当前时间步为 0,则表示反向扩散过程完成。否则,我们会添加一些噪声以模拟扩散过程。下面是对代码中 FIXME 部分的分析与替换:

@torch.no_grad()
def reverse_q(x_t, t, e_t):
    t = t.int()
    pred_noise_coeff_t = pred_noise_coeff[t]
    sqrt_a_inv_t = sqrt_a_inv[t]
    u_t = sqrt_a_inv_t * (x_t - pred_noise_coeff_t * e_t)
    if t[0] == 0:  # All t values should be the same
        return u_t  # Reverse diffusion complete!
    else:
        B_t = B[t - 1]  # Apply noise from the previous timestep
        new_noise = torch.randn_like(x_t)
        return u_t + torch.sqrt(B_t) * new_noise

在这里插入图片描述

6.3TODO

在这里插入图片描述

每个类的功能来添加正确模块名 依次改写FIXME 即可:

DownBlock进行下采样操作,包含卷积和池化相关的块
EmbedBlock将输入进行线性变换和激活
GELUConvBlock使用了卷积、组归一化和 GELU 激活函数,通常是一个卷积块
RearrangePoolBlock使用了 Rearrange 进行张量重排和卷积操作
ResidualConvBlock使用了两个卷积块并进行了残差连接
SinusoidalPositionEmbedBlock实现了正弦位置嵌入的功能
UpBlock上采样操作,包含转置卷积和卷积块

6.4TODO

在这个 get_context_mask 函数里,其目的是随机丢弃上下文信息。要实现随机丢弃,通常会使用 torch.bernoulli 函数。torch.bernoulli 函数会依据给定的概率来生成一个二进制掩码张量,其中每个元素为 1 的概率就是传入的概率值。
在这个函数中,我们希望以 drop_prob 的概率丢弃上下文,所以每个元素保留的概率是 1 - drop_prob。因此,FIXME 处应该填入 bernoulli。

def get_context_mask(c, drop_prob):
    c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device)
    c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device)
    return c_hot, c_mask

代码解释:
c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device):将输入的 c 转换为独热编码向量,并且移动到指定的设备(如 GPU)上。
c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device):生成一个与 c_hot 形状相同的二进制掩码张量,每个元素以 1 - drop_prob 的概率为 1,以 drop_prob 的概率为 0。
return c_hot, c_mask:返回独热编码向量和二进制掩码张量。
这样,你就可以使用这个函数来随机丢弃上下文信息了。

在这里插入图片描述

在扩散模型里,通常采用均方误差损失(Mean Squared Error Loss,MSE)来衡量预测噪声 noise_pred 和实际添加的噪声 noise 之间的差异。因为均方误差能够很好地衡量两个向量之间的平均平方误差,这对于扩散模型中预测噪声的准确性评估是很合适的。
在 PyTorch 中,nn.functional.mse_loss 函数可用于计算均方误差损失。所以 FIXME 处应填入 mse_loss。

def get_loss(model, x_0, t, *model_args):
    x_noisy, noise = q(x_0, t)
    noise_pred = model(x_noisy, t/T, *model_args)
    return F.mse_loss(noise, noise_pred)

代码解释
x_noisy, noise = q(x_0, t):调用 q 函数给原始图像 x_0 添加噪声,得到加噪后的图像 x_noisy 以及实际添加的噪声 noise。
noise_pred = model(x_noisy, t/T, *model_args):把加噪后的图像 x_noisy 和归一化后的时间步 t/T 输入到模型 model 中,得到模型预测的噪声 noise_pred。
return F.mse_loss(noise, noise_pred):使用 F.mse_loss 函数计算实际噪声 noise 和预测噪声 noise_pred 之间的均方误差损失并返回。
通过使用均方误差损失,模型能够学习到如何更准确地预测添加到图像中的噪声,从而在反向扩散过程中更好地恢复原始图像。

下一个 TODO

  1. c_drop_prob 的设置
    c_drop_prob 是上下文丢弃概率,一般在训练过程中会采用线性衰减策略,也就是在训练初期以较高概率丢弃上下文,随着训练的推进逐渐降低丢弃概率。在代码中,我们可以简单地将其设置为一个随着训练轮数逐渐降低的值。
  2. get_context_mask 函数的输入
    get_context_mask 函数需要一个上下文标签作为输入,在代码里这个标签应该从 batch 中获取。通常假设 batch 的第二个元素为上下文标签。

optimizer = Adam(model.parameters(), lr=0.001)
epochs = 5
preview_c = 0

model.train()
for epoch in range(epochs):
    # 线性衰减上下文丢弃概率
    c_drop_prob = max(0.1, 1 - epoch / epochs)  #这里我调整了顺序
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()
        x = batch[0].to(device)
        # 假设 batch 的第二个元素是上下文标签
        c = batch[1].to(device)
        c_hot, c_mask = get_context_mask(c, c_drop_prob)
        loss = get_loss(model, x, t, c_hot, c_mask)
        loss.backward()
        optimizer.step()

        if epoch % 1 == 0 and step % 100 == 0:
            print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {preview_c}")
            c_drop_prob = 0  # Do not drop context for preview
            c_hot, c_mask = get_context_mask(torch.Tensor([preview_c]).to(device), c_drop_prob)
            sample_images(model, IMG_CH, IMG_SIZE, ncols, c_hot, c_mask)
            preview_c = (preview_c + 1) % N_CLASSES

代码解释
c_drop_prob 的设置:运用线性衰减策略,在训练初期 c_drop_prob 为 0.9,随着训练的推进逐渐降低到 0.1。
get_context_mask 函数的输入:假设 batch 的第二个元素是上下文标签,将其传入 get_context_mask 函数。
训练过程:在每个训练步骤中,先将梯度清零,接着计算损失,再进行反向传播和参数更新。每训练 100 个步骤,就打印一次损失信息并进行一次样本生成。
通过这些修改,代码就能正常运行,从而开始训练模型。
在这里插入图片描述

6.5TODO

在扩散模型的采样过程中,为了给扩散过程添加权重,一般会根据给定的权重 w 对保留上下文的预测噪声 e_t_keep_c 和丢弃上下文的预测噪声 e_t_drop_c 进行加权组合。在这里插入图片描述
在代码中,FIXME 处应该根据上述公式进行计算,将 e_t_keep_c 和 e_t_drop_c 按照权重 w 进行组合。具体的代码如下:

def sample_w(model, c, w):
    input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)
    n_samples = len(c)
    w = torch.tensor([w]).float()
    w = w[:, None, None, None].to(device)  # Make w broadcastable
    x_t = torch.randn(n_samples, *input_size).to(device)

    # One c for each w
    c = c.repeat(len(w), 1)

    # Double the batch
    c = c.repeat(2, 1)

    # Don't drop context at test time
    c_mask = torch.ones_like(c).to(device)
    c_mask[n_samples:] = 0.0

    x_t_store = []
    for i in range(0, T)[::-1]:
        # Duplicate t for each sample
        t = torch.tensor([i]).to(device)
        t = t.repeat(n_samples, 1, 1, 1)

        # Double the batch
        x_t = x_t.repeat(2, 1, 1, 1)
        t = t.repeat(2, 1, 1, 1)

        # Find weighted noise
        e_t = model(x_t, t, c, c_mask)
        e_t_keep_c = e_t[:n_samples]
        e_t_drop_c = e_t[n_samples:]
        e_t = w * e_t_keep_c + (1 - w) * e_t_drop_c

        # Deduplicate batch for reverse diffusion
        x_t = x_t[:n_samples]
        t = t[:n_samples]
        x_t = reverse_q(x_t, t, e_t)

    return x_t

## TODO

在扩散模型里,权重 w 可用于控制上下文信息在生成过程中的影响程度。w 值越接近 1,生成结果就越依赖上下文信息;w 值越接近 0,生成结果受上下文信息的影响就越小。若要让生成的数字能够被持续识别,你可以试着增大 w 的值,以此增强上下文信息对生成过程的影响。
下面是修改后的代码,你可以调整 w 的值来观察生成结果:

model.eval()
w = 5.0  # 可以尝试不同的值,通常大于 1 能增强上下文的影响
c = torch.arange(N_CLASSES).to(device)
c_drop_prob = 0 
c_hot, c_mask = get_context_mask(c, c_drop_prob)

x_0 = sample_w(model, c_hot, w)
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES))

代码解释
w = 5.0:把 w 的值设为 5.0,你可以根据实际情况调整这个值。通常,当 w 大于 1 时,上下文信息的影响会得到增强,这样生成的数字可能会更易于识别。
x_0 = sample_w(model, c_hot, w):调用 sample_w 函数生成图像,将 w 作为参数传入。
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES)):把生成的图像转换为可视化的形式。
你可以多次运行这段代码,并且调整 w 的值,直到生成的数字能够被稳定识别。

至此结束。
在这里插入图片描述

完整代码都在图片里