Training a Codebook from Scratch: ViT-Based Image Reconstruction in Practice

Published 2025-04-04

The complete code is included at the end of the post and can be run as-is.


1. Core Principles

A codebook is a discrete representation learning technique whose core idea is to map a continuous feature space onto a discrete set of codebook vectors. Our implementation consists of three key components:

1.1 ViT Encoder

class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.proj = nn.Linear(768, codebook_dim)
        
    def forward(self, x):
        outputs = self.vit(x).last_hidden_state
        patch_embeddings = outputs[:, 1:, :]  # drop the CLS token
        return self.proj(patch_embeddings)
  • A pretrained ViT-Base model extracts the image features
  • The CLS token is dropped, keeping the 196 patch features
  • A linear projection maps the 768-dim features to the codebook dimension

1.2 Codebook Quantization Layer

class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        
    def quantize(self, z):
        # flatten patches into a 2-D matrix [batch*num_patches, dim]
        z_flat = z.reshape(-1, z.shape[-1])
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        distances = (z_flat.pow(2).sum(1, keepdim=True)
                     - 2 * z_flat @ self.codebook.weight.t()
                     + self.codebook.weight.pow(2).sum(1))
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1).reshape(z.shape[:2])
        return indices, self.codebook(indices)
  • A learnable Embedding layer stores the discrete codebook
  • Nearest-neighbour lookup is performed via squared L2 distances
  • EMA updates are supported (left commented out in the complete code)
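
In other words, each patch feature z is replaced by its nearest codebook entry: z_q = e_k with k = argmin_i ||z − e_i||_2 over the K codebook vectors, which is exactly what the distance computation and argmin above implement.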

1.3 ViT Decoder

class ViTDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 4, 2, 1),
            nn.ReLU(),
            ... # more upsampling layers (see the complete code)
            nn.Conv2d(48, 3, 1))
  • Transposed convolutions upsample the feature map step by step
  • The final output is a 224x224 image
  • Together with the encoder this forms a roughly symmetric structure
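
Before training it is worth sanity-checking the shapes flowing through the three components. This is a quick sketch that mirrors the commented-out check in the complete code, using the full class definitions given at the end of the post:

encoder = ViTEncoder()
codebook = Codebook()
decoder = ViTDecoder()

x = torch.randn(1, 3, 224, 224)            # one dummy image
z = encoder(x)                             # [1, 196, 512] patch features
indices, quantized = codebook.quantize(z)  # [1, 196] and [1, 196, 512]
recon = decoder(quantized)                 # [1, 3, 224, 224]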

2. Training Strategy

2.1 Multi-Objective Loss Function

total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
  • MSE loss: pixel-level reconstruction error
  • Perceptual loss: VGG16 feature matching
  • Codebook loss: pulls the codebook vectors toward the encoder outputs
  • Commitment loss: keeps the encoder outputs close to their codebook entries
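
Concretely, the two codebook-related terms follow the standard VQ-VAE recipe and rely on stop-gradients (.detach()). A minimal sketch using the same variable names as the training loop at the end of the post:

indices, quantized = codebook.quantize(z)                   # z: encoder output [batch, 196, 512]
codebook_loss = F.mse_loss(quantized, z.detach())           # gradient reaches only the codebook vectors
commitment_loss = 0.25 * F.mse_loss(z, quantized.detach())  # gradient reaches only the encoder

Because the decoder receives the quantized vectors directly (no straight-through estimator is used here), the MSE and perceptual losses train the decoder and the codebook, while the encoder is driven solely by the commitment term.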

2.2 Optimization Tricks

opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
  • Per-group learning rates (a smaller rate for the codebook)
  • EMA smoothing of the codebook vectors (left commented out in the complete code)
  • Optional mixed-precision training
  • Optional dynamic learning-rate adjustment (both sketched below)
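
Mixed precision and learning-rate scheduling are not wired into the complete code below. A minimal sketch of how they could be added with the standard torch.cuda.amp and CosineAnnealingLR utilities, assuming the encoder, codebook, decoder and opt defined later in the post:

from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR

scaler = GradScaler()
scheduler = CosineAnnealingLR(opt, T_max=100)   # anneal over the 100 training epochs

for epoch in range(100):
    for images, _ in trainloader:
        images = images.to(device)
        opt.zero_grad()
        with autocast():                        # run the forward pass in mixed precision
            z = encoder(images)
            _, quantized = codebook.quantize(z)
            recon = decoder(quantized)
            loss = F.mse_loss(recon, images)    # plus the other loss terms as above
        scaler.scale(loss).backward()           # scale the loss to avoid fp16 underflow
        scaler.step(opt)
        scaler.update()
    scheduler.step()                            # decay the learning rate once per epoch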

3. Complete Training Pipeline

3.1 Data Preparation

transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(...)
])
  • CIFAR-10 dataset
  • Random crop and horizontal flip augmentation
  • Batch size of 4 to fit GPU memory

3.2 Training Monitoring

# TensorBoard logging
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)

# console logging
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

Complete Code

from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        # load the pretrained ViT-Base model
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # linear projection to match the codebook dimension
        self.proj = nn.Linear(768, codebook_dim)
        
    def forward(self, x):
        outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]
        patch_embeddings = outputs[:, 1:, :]     # drop the CLS token
        return self.proj(patch_embeddings)       # [batch, 196, 512]

class Codebook(nn.Module):
    def __init__(self, num_embeddings=16384, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)  # EMA updates could be added here (see the improved class below)
        
    def quantize(self, z):
        """
        量化输入特征向量
        参数:
            z: 输入特征 [batch, num_patches, embedding_dim]
        返回:
            indices: 最近邻码本索引 [batch, num_patches]
            quantized: 量化后的特征 [batch, num_patches, embedding_dim]
        """
        # flatten to a 2-D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]
        
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1)  # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore [batch, num_patches]
        quantized = self.codebook(indices)  # [batch, num_patches, dim]
        
        return indices, quantized
class ViTDecoder(nn.Module):
    def __init__(self, in_dim=512):
        super().__init__()
        # map codebook features back to the ViT patch-embedding dimension
        self.proj = nn.Linear(in_dim, 768)
        config = ViTConfig()
        config.is_decoder = True  # configure the ViT blocks as a decoder stack
        self.transformer = ViTModel(config).encoder  
        self.head = nn.Sequential(
            # 14x14 -> 28x28
            nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            
            # 28x28 -> 56x56
            nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            
            # 56x56 -> 112x112 
            nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            
            # 112x112 -> 224x224
            nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            
            # final 1x1 conv to 3 channels
            nn.Conv2d(48, 3, kernel_size=1)
        )
        
    def forward(self, x):
        x = self.proj(x)  # [batch, 196, 768]
        x = self.transformer(x).last_hidden_state
        x = x.permute(0, 2, 1).reshape(-1, 768, 14, 14)  # restore the 14x14 spatial layout
        return self.head(x)  # output [batch, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()

# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)

from torchvision import transforms
import torchvision
import torch.nn.functional as F
# data augmentation and preprocessing
transform_train = transforms.Compose([
    transforms.Resize(224),  # resize to the ViT input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

batch_size = 4  # small batch to fit GPU memory; increase if memory allows
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16

# initialize TensorBoard
writer = SummaryWriter('runs/codebook_experiment')

# improved Codebook class (with optional EMA update)
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        nn.init.normal_(self.ema_w)
        
    def quantize(self, z):
        # flatten to a 2-D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]
        
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1)  # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore [batch, num_patches]
        quantized = self.codebook(indices)  # [batch, num_patches, dim]
        
        # optional EMA codebook update (uncomment to enable)
        # if self.training:
        #     with torch.no_grad():
        #         encodings = F.one_hot(indices.reshape(-1), self.codebook.num_embeddings).float()
        #         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)
        #         n = torch.sum(self.ema_cluster_size)
        #         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)
        #
        #         dw = torch.matmul(encodings.t(), z_flat)
        #         self.ema_w.data = self.ema_w.data * self.decay + (1 - self.decay) * dw
        #
        #         self.codebook.weight.data = self.ema_w.data / self.ema_cluster_size.unsqueeze(1)
        return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize components
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # frozen VGG features for the perceptual loss
for p in vgg.parameters():
    p.requires_grad_(False)

# optimizer with separate parameter groups
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4}  # smaller learning rate for the codebook
], lr=3e-4)

# training loop
for epoch in range(100):
    avg_loss = 0
    start_time = time.time()  # record the epoch start time
    for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):
        images = images.to(device)
        
        # forward pass
        z = encoder(images)
        indices, quantized = codebook.quantize(z)
        recon = decoder(quantized)
        
        # multi-objective loss
        mse_loss = F.mse_loss(recon, images)
        
        # perceptual loss (VGG feature matching)
        with torch.no_grad():
            real_features = vgg(images)
        recon_features = vgg(recon)
        percep_loss = F.mse_loss(recon_features, real_features)
        
        # codebook-related losses (VQ-VAE style)
        codebook_loss = F.mse_loss(quantized, z.detach())    # pulls codebook vectors toward encoder outputs
        commitment_loss = codebook.commitment_cost * F.mse_loss(z, quantized.detach())  # keeps encoder outputs close to the codebook
        
        # total loss
        total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
        
        # backward pass
        opt.zero_grad()
        total_loss.backward()
        opt.step()
        
        # logging
        avg_loss += total_loss.item()
        if batch_idx % 50 == 0:
            # log scalars to TensorBoard
            writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)
            writer.add_scalars('Loss/components', {
                'mse': mse_loss.item(),
                'perceptual': percep_loss.item(),
                'codebook': codebook_loss.item(),
                'commitment': commitment_loss.item()
            }, epoch*len(trainloader)+batch_idx)
            
            # log reconstruction samples
            comparison = torch.cat([images[:4], recon[:4]])
            grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)
            writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)
    
    # print epoch statistics
    avg_loss /= len(trainloader)
    print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f} ({time.time() - start_time:.1f}s)")
    
    # save a model checkpoint every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save({
            'encoder': encoder.state_dict(),
            'codebook': codebook.state_dict(),
            'decoder': decoder.state_dict(),
            'opt': opt.state_dict()
        }, f'checkpoint_epoch{epoch+1}.pth')

writer.close()
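
To reuse a saved checkpoint for reconstruction later, the state dicts can be restored into freshly built modules. A minimal sketch, assuming a file such as checkpoint_epoch10.pth produced by the loop above and the testloader defined earlier:

ckpt = torch.load('checkpoint_epoch10.pth', map_location=device)
encoder.load_state_dict(ckpt['encoder'])
codebook.load_state_dict(ckpt['codebook'])
decoder.load_state_dict(ckpt['decoder'])

encoder.eval(); codebook.eval(); decoder.eval()
with torch.no_grad():
    images, _ = next(iter(testloader))
    images = images[:4].to(device)
    z = encoder(images)
    _, quantized = codebook.quantize(z)
    recon = decoder(quantized)
print(recon.shape)  # torch.Size([4, 3, 224, 224])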


With this walkthrough we have implemented the complete pipeline from feature extraction to discrete representation learning. Codebook-based methods are widely used in image compression, generative modeling, and related areas; readers are encouraged to build on this code and explore further.