第 7 期:DDPM 采样提速方案:从 DDPM 到 DDIM

发布于:2025-04-20 ⋅ 阅读:(18) ⋅ 点赞:(0)

本期关键词:采样加速、DDIM 推导、可控性提升、伪逆过程、代码实战

前情回顾:DDPM 的采样瓶颈

在前几期中,我们构建了一个完整的 DDPM 生成流程。但是你可能已经发现:

生成一张图像太慢了!!!

原因是:
DDPM 要在 T 个时间步中一步步地去噪,从 x_T → x_0。而通常 T 至少为 1000,采样一次就意味着 1000 次前向推理,非常耗时!

目标:更快的采样方法!

本期,我们引入一种“非随机”的采样机制 —— DDIM(Denoising Diffusion Implicit Models)

它能在 保留图像质量的同时,将采样步骤从 1000 步减少到几十步!
比如 T=1000 → 50,加速 20 倍+

数学推导:DDIM 与 DDPM 的关系

DDPM 复习公式

我们知道在 DDPM 中,每一步的去噪过程是:

其中 z 是随机噪声。DDIM 做的事就是:

去掉这一步的随机性,将采样变为 确定性过程

DDIM 推导核心公式

这里的 x_0 是模型预测的原始图像,通过 x_0 = (x_t - √(1 - ᾱ_t) * ε) / √(ᾱ_t) 得到。

直观理解:DDIM 是一种“伪逆”的过程,保留了模型预测的主导性。

 

PyTorch 实现 DDIM 推理过程

我们只需要修改之前的采样函数,引入 DDIM:

@torch.no_grad()
def ddim_sample(model, img_size=32, num_samples=16, ddim_steps=50, device='cuda'):
    model.eval()
    step_size = T // ddim_steps
    x_t = torch.randn(num_samples, 3, img_size, img_size).to(device)

    for i in range(0, T, step_size):
        t = torch.full((num_samples,), T - 1 - i, device=device, dtype=torch.long)
        alpha = alphas_cumprod[t].to(device)[:, None, None, None]
        sqrt_alpha = torch.sqrt(alpha)
        sqrt_one_minus_alpha = torch.sqrt(1 - alpha)

        with torch.no_grad():
            epsilon = model(x_t, t.float())
            x_0_pred = (x_t - sqrt_one_minus_alpha * epsilon) / sqrt_alpha
            x_0_pred = x_0_pred.clamp(-1, 1)

        next_t = torch.full((num_samples,), max(t[0] - step_size, 0), device=device, dtype=torch.long)
        next_alpha = alphas_cumprod[next_t].to(device)[:, None, None, None]
        x_t = torch.sqrt(next_alpha) * x_0_pred + torch.sqrt(1 - next_alpha) * epsilon

    return x_0_pred

生成样本可视化

samples = ddim_sample(model, num_samples=16, ddim_steps=50)
samples = (samples.clamp(-1, 1) + 1) / 2

grid = torchvision.utils.make_grid(samples, nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title("DDIM Fast Sampling Result")
plt.show()

 运行效果图示例:

DDPM vs DDIM 对比

项目 DDPM DDIM
是否随机 ✅ 是 ❌ 否
是否严格等价 ✅ 是 ❌ 不是(近似)
是否可控(重建) ❌ 否 ✅ 是
采样速度 慢(1000步) 快(<50步)
图像质量 接近 DDPM

✅ 总结

在本期中,我们完成了:

  • ✅ DDIM 理论推导;

  • ✅ DDIM PyTorch 实现;

  • ✅ CIFAR-10 样本生成展示;

  • ✅ 与 DDPM 的对比分析。

第 8 期预告:条件生成!

下一期我们将引入 类条件 DDPM,尝试生成某个指定类别的图像(如飞机、青蛙、猫等)!实现“我想生成第几类”的定向控制!

 


网站公告

今日签到

点亮在社区的每一天
去签到