PyTorch Deep Learning in Practice (27): Diffusion Models and Image Generation


I. Diffusion Model Principles

1. Core Idea

Diffusion models learn the data distribution by gradually adding and then removing noise. They consist of two processes: a forward (diffusion) process that corrupts an image with Gaussian noise over many small steps, and a learned reverse (denoising) process that removes the noise step by step, so that new images can be generated starting from pure noise.
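
In the standard DDPM formulation (Ho et al., 2020), the forward process and its one-step closed form are:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t\mathbf{I}\right), \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$$

where $\alpha_t = 1 - \beta_t$, $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$, and $\epsilon \sim \mathcal{N}(0, \mathbf{I})$. The reverse process $p_\theta(x_{t-1} \mid x_t)$ is parameterized by a network $\epsilon_\theta(x_t, t)$ trained to predict that noise.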

2. Training Objective (Simplified Loss Function)

The model is trained by minimizing the mean squared error between the injected noise and the network's noise prediction:

$$\mathcal{L}_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\rVert^2\right]$$
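
A minimal sketch of this objective in PyTorch (illustrative only; `model` stands for the noise-prediction network, and `sqrt_ac` / `sqrt_1m_ac` for the precomputed √ᾱ_t and √(1−ᾱ_t) schedule tensors that section III builds):

import torch
import torch.nn.functional as F

def simple_diffusion_loss(model, x0, sqrt_ac, sqrt_1m_ac, timesteps):
    # Sample a random timestep per image, form x_t in closed form,
    # and regress the injected Gaussian noise.
    t = torch.randint(0, timesteps, (x0.size(0),), device=x0.device)
    noise = torch.randn_like(x0)
    xt = sqrt_ac[t].view(-1, 1, 1, 1) * x0 + sqrt_1m_ac[t].view(-1, 1, 1, 1) * noise
    return F.mse_loss(model(xt, t), noise)
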
II. Implementing a Diffusion Model (with PyTorch)

Taking MNIST handwritten digit generation as an example, the diffusion model is implemented in four steps:

  1. Define the noise schedule: a linear or cosine noise plan

  2. Build the UNet model: predict the noise added at each step

  3. Implement the diffusion process: forward noising and reverse denoising

  4. Train and generate: produce images from pure noise


III. Code Implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# ================== Configuration ==================
class DiffusionConfig:
    image_size = 28              # MNIST image size
    batch_size = 64              # Batch size
    num_epochs = 100             # Number of training epochs
    timesteps = 1000             # Number of diffusion steps
    beta_start = 1e-4            # Initial noise coefficient (beta)
    beta_end = 0.02              # Final noise coefficient (beta)
    lr = 2e-4                    # Learning rate
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== UNet model ==================
class UNet(nn.Module):
    """Simplified UNet-style encoder-decoder that predicts the noise
    added at timestep t. Unlike a full UNet, it has no skip connections."""
    def __init__(self):
        super().__init__()
        # Encoder: 1 -> 32 channels, then downsample 28x28 -> 14x14
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        # Middle block: 32 -> 64 channels at 14x14 resolution
        self.mid = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        )
        
        # Decoder: upsample 14x14 -> 28x28 and project back to 1 channel
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1)
        )
        
        # Timestep embedding: one learned 64-d vector per timestep,
        # projected down to the encoder's channel count (32)
        self.time_embed = nn.Embedding(DiffusionConfig.timesteps, 64)
        self.time_proj = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )
    
    def forward(self, x, t):
        # Embed the timestep and reshape for broadcasting over feature maps
        t_embed = self.time_embed(t)
        t_embed = self.time_proj(t_embed)
        t_embed = t_embed.view(-1, 32, 1, 1)
        
        # Encode the image
        x = self.encoder(x)
        
        # Inject time information, then run the middle block
        x = x + t_embed
        x = self.mid(x)
        
        # Decode back to image space
        x = self.decoder(x)
        return x

# ================== Diffusion utilities ==================
def linear_beta_schedule(timesteps, beta_start, beta_end):
    # Noise variance increases linearly from beta_start to beta_end
    return torch.linspace(beta_start, beta_end, timesteps)

def forward_diffusion(x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # Closed-form forward process: x_t = sqrt(a_bar_t)*x_0 + sqrt(1 - a_bar_t)*noise
    noise = torch.randn_like(x0)
    # Move the schedule tensors to the input's device before indexing
    sqrt_alpha = sqrt_alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alpha = sqrt_one_minus_alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    xt = sqrt_alpha * x0 + sqrt_one_minus_alpha * noise
    return xt, noise

# ================== Training system ==================
class DiffusionTrainer:
    def __init__(self):
        # Data loading
        transform = transforms.Compose([transforms.ToTensor()])
        self.dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        self.dataloader = DataLoader(self.dataset, batch_size=DiffusionConfig.batch_size, shuffle=True)
        
        # Initialize model and optimizer
        self.model = UNet().to(DiffusionConfig.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=DiffusionConfig.lr)
        
        # Precompute the noise schedule and its cumulative products
        self.betas = linear_beta_schedule(DiffusionConfig.timesteps, DiffusionConfig.beta_start, DiffusionConfig.beta_end).to(DiffusionConfig.device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
    
    def train(self):
        for epoch in range(DiffusionConfig.num_epochs):
            for batch, (images, _) in enumerate(self.dataloader):
                images = images.to(DiffusionConfig.device)
                batch_size = images.size(0)
                
                # Sample a random timestep for each image
                t = torch.randint(0, DiffusionConfig.timesteps, (batch_size,), device=DiffusionConfig.device)
                
                # Forward diffusion: add noise
                xt, noise = forward_diffusion(images, t, self.sqrt_alphas_cumprod, self.sqrt_one_minus_alphas_cumprod)
                
                # Predict the noise
                pred_noise = self.model(xt, t)
                
                # Compute the loss
                loss = nn.MSELoss()(pred_noise, noise)
                
                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                if batch % 100 == 0:
                    print(f"Epoch {epoch+1} | Batch {batch} | Loss: {loss.item():.4f}")
    
    @torch.no_grad()  # sampling needs no gradients; without this, a graph over all 1000 steps is kept in memory
    def generate(self, num_samples=16):
        # Start from pure Gaussian noise and denoise step by step (DDPM ancestral sampling)
        x = torch.randn(num_samples, 1, 28, 28, device=DiffusionConfig.device)
        for t in reversed(range(DiffusionConfig.timesteps)):
            t_tensor = torch.full((num_samples,), t, device=DiffusionConfig.device)
            pred_noise = self.model(x, t_tensor)
            alpha_t = self.alphas[t]
            beta_t = self.betas[t]
            # Add fresh noise at every step except the last (t == 0)
            if t > 0:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)
            # DDPM reverse update: remove the predicted noise, then re-inject sqrt(beta_t)*z
            x = (x - beta_t * pred_noise / torch.sqrt(1 - self.alphas_cumprod[t])) / torch.sqrt(alpha_t)
            x += torch.sqrt(beta_t) * noise
        return x

if __name__ == "__main__":
    trainer = DiffusionTrainer()
    print("Start training the diffusion model...")
    trainer.train()
    print("Training finished, generating sample images...")
    generated_images = trainer.generate()
    
    # Visualize the generated samples
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        # Clamp to [0, 1] since samples can slightly overshoot the pixel range
        plt.imshow(generated_images[i].cpu().squeeze().clamp(0, 1), cmap='gray')
        plt.axis('off')
    plt.savefig('generated_images.png')
    plt.show()

IV. Key Code Walkthrough

  1. Noise schedule: linear_beta_schedule produces β_t rising linearly from 1e-4 to 0.02; the cumulative products ᾱ_t = ∏(1 − β_s) are precomputed so that x_t can be drawn in a single closed-form step.

  2. Time conditioning: the model looks up a learned embedding for each timestep, projects it to 32 channels, and broadcast-adds it to the encoder features. Note this is a simplified encoder-decoder; a full UNet would also carry skip connections.

  3. Forward diffusion: forward_diffusion returns both the noised image x_t and the noise ε itself, which serves as the regression target.

  4. Sampling: generate implements the DDPM reverse update x_{t−1} = (x_t − β_t/√(1−ᾱ_t)·ε_θ(x_t, t))/√α_t + √β_t·z, adding fresh noise z at every step except the final one.
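
As a quick sanity check, one can print the signal coefficient √ᾱ_t at a few timesteps (a small standalone sketch reusing the same linear schedule as the trainer); it should decay from ≈1 toward ≈0, confirming that late timesteps are nearly pure noise:

import torch

betas = torch.linspace(1e-4, 0.02, 1000)            # same linear schedule as DiffusionConfig
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
for t in [0, 250, 500, 750, 999]:
    print(f"t={t:4d}  sqrt(a_bar_t)={alphas_cumprod[t].sqrt():.4f}")
# Prints values near 1.0 at t=0, falling to well under 0.01 by t=999
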
V. Sample Training Output

Start training the diffusion model...
Epoch 1 | Batch 0 | Loss: 1.0092
Epoch 1 | Batch 100 | Loss: 0.6777
Epoch 1 | Batch 200 | Loss: 0.4815
Epoch 1 | Batch 300 | Loss: 0.4252
Epoch 1 | Batch 400 | Loss: 0.3800
Epoch 1 | Batch 500 | Loss: 0.3209
Epoch 1 | Batch 600 | Loss: 0.3728
Epoch 1 | Batch 700 | Loss: 0.3235
Epoch 1 | Batch 800 | Loss: 0.2384
Epoch 1 | Batch 900 | Loss: 0.2338
Epoch 2 | Batch 0 | Loss: 0.2652
Epoch 2 | Batch 100 | Loss: 0.2383
Epoch 2 | Batch 200 | Loss: 0.2384
Epoch 2 | Batch 300 | Loss: 0.2533
Epoch 2 | Batch 400 | Loss: 0.2102
Epoch 2 | Batch 500 | Loss: 0.2475
Epoch 2 | Batch 600 | Loss: 0.2152
Epoch 2 | Batch 700 | Loss: 0.2016
Epoch 2 | Batch 800 | Loss: 0.2483
Epoch 2 | Batch 900 | Loss: 0.2043
Epoch 3 | Batch 0 | Loss: 0.1791
Epoch 3 | Batch 100 | Loss: 0.1749
Epoch 3 | Batch 200 | Loss: 0.1671
Epoch 3 | Batch 300 | Loss: 0.2204
Epoch 3 | Batch 400 | Loss: 0.1716
Epoch 3 | Batch 500 | Loss: 0.1707
Epoch 3 | Batch 600 | Loss: 0.1578
Epoch 3 | Batch 700 | Loss: 0.1583
Epoch 3 | Batch 800 | Loss: 0.1771
Epoch 3 | Batch 900 | Loss: 0.1452
Epoch 4 | Batch 0 | Loss: 0.1641
Epoch 4 | Batch 100 | Loss: 0.1811
Epoch 4 | Batch 200 | Loss: 0.1840
Epoch 4 | Batch 300 | Loss: 0.1479
Epoch 4 | Batch 400 | Loss: 0.1603
Epoch 4 | Batch 500 | Loss: 0.1199
Epoch 4 | Batch 600 | Loss: 0.1268
Epoch 4 | Batch 700 | Loss: 0.1435
Epoch 4 | Batch 800 | Loss: 0.1435
Epoch 4 | Batch 900 | Loss: 0.1182
Epoch 5 | Batch 0 | Loss: 0.1348
Epoch 5 | Batch 100 | Loss: 0.1621
Epoch 5 | Batch 200 | Loss: 0.1395
Epoch 5 | Batch 300 | Loss: 0.1433
Epoch 5 | Batch 400 | Loss: 0.1345
Epoch 5 | Batch 500 | Loss: 0.1585
Epoch 5 | Batch 600 | Loss: 0.1284
Epoch 5 | Batch 700 | Loss: 0.1516
Epoch 5 | Batch 800 | Loss: 0.1576
Epoch 5 | Batch 900 | Loss: 0.1025
Epoch 6 | Batch 0 | Loss: 0.1276
Epoch 6 | Batch 100 | Loss: 0.1517
Epoch 6 | Batch 200 | Loss: 0.1494
Epoch 6 | Batch 300 | Loss: 0.1365
Epoch 6 | Batch 400 | Loss: 0.1123
Epoch 6 | Batch 500 | Loss: 0.1200
Epoch 6 | Batch 600 | Loss: 0.1064
Epoch 6 | Batch 700 | Loss: 0.1122
Epoch 6 | Batch 800 | Loss: 0.1074
Epoch 6 | Batch 900 | Loss: 0.1047
Epoch 7 | Batch 0 | Loss: 0.1221
Epoch 7 | Batch 100 | Loss: 0.1275
Epoch 7 | Batch 200 | Loss: 0.1192
Epoch 7 | Batch 300 | Loss: 0.1244
Epoch 7 | Batch 400 | Loss: 0.1094
Epoch 7 | Batch 500 | Loss: 0.0971
Epoch 7 | Batch 600 | Loss: 0.1223
Epoch 7 | Batch 700 | Loss: 0.1337
Epoch 7 | Batch 800 | Loss: 0.1086
Epoch 7 | Batch 900 | Loss: 0.1013
Epoch 8 | Batch 0 | Loss: 0.1212
Epoch 8 | Batch 100 | Loss: 0.1041
Epoch 8 | Batch 200 | Loss: 0.0922
Epoch 8 | Batch 300 | Loss: 0.1079
Epoch 8 | Batch 400 | Loss: 0.1168
Epoch 8 | Batch 500 | Loss: 0.1020
Epoch 8 | Batch 600 | Loss: 0.1096
Epoch 8 | Batch 700 | Loss: 0.1301
Epoch 8 | Batch 800 | Loss: 0.1171
Epoch 8 | Batch 900 | Loss: 0.0942
Epoch 9 | Batch 0 | Loss: 0.1196
Epoch 9 | Batch 100 | Loss: 0.0962
Epoch 9 | Batch 200 | Loss: 0.0923
Epoch 9 | Batch 300 | Loss: 0.1296
Epoch 9 | Batch 400 | Loss: 0.0929
Epoch 9 | Batch 500 | Loss: 0.1088
Epoch 9 | Batch 600 | Loss: 0.0825
Epoch 9 | Batch 700 | Loss: 0.1044
Epoch 9 | Batch 800 | Loss: 0.0999
Epoch 9 | Batch 900 | Loss: 0.1232
Epoch 10 | Batch 0 | Loss: 0.1136
Epoch 10 | Batch 100 | Loss: 0.0980
Epoch 10 | Batch 200 | Loss: 0.0938
Epoch 10 | Batch 300 | Loss: 0.0960
Epoch 10 | Batch 400 | Loss: 0.1349
Epoch 10 | Batch 500 | Loss: 0.0962
Epoch 10 | Batch 600 | Loss: 0.1089
Epoch 10 | Batch 700 | Loss: 0.0783
Epoch 10 | Batch 800 | Loss: 0.1010
Epoch 10 | Batch 900 | Loss: 0.1262
...
Epoch 100 | Batch 0 | Loss: 0.0415
Epoch 100 | Batch 100 | Loss: 0.0342
Epoch 100 | Batch 200 | Loss: 0.0302
Epoch 100 | Batch 300 | Loss: 0.0402
Epoch 100 | Batch 400 | Loss: 0.0497
Epoch 100 | Batch 500 | Loss: 0.0316
Epoch 100 | Batch 600 | Loss: 0.0379
Epoch 100 | Batch 700 | Loss: 0.0344
Epoch 100 | Batch 800 | Loss: 0.0433
Epoch 100 | Batch 900 | Loss: 0.0403
Training finished, generating sample images...

Generated image samples (MNIST handwritten digits) are saved to generated_images.png by the script above.

In the next article, we will dive into contrastive learning and self-supervised representation learning, showing how to build powerful visual representations from unlabeled data and analyzing their core advantages in tasks such as image retrieval and classification!


Notes

  1. Training a diffusion model takes considerable time (a GPU is strongly recommended).

  2. In practice, the noise schedule can be changed (e.g., to a cosine schedule) to improve sample quality; see the sketch below.
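
A minimal sketch of the cosine schedule from Nichol & Dhariwal (2021), which can be swapped in for linear_beta_schedule in DiffusionTrainer.__init__ (the clamp bounds are a common implementation choice for numerical stability at the end of the schedule):

import math
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    # a_bar(t) follows a squared-cosine curve, destroying signal more
    # gently at the start and end than the linear schedule does
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(1e-4, 0.999)

The rest of the pipeline (forward diffusion, training, and sampling) is unchanged, since it only consumes the betas and their cumulative products.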