The complete code is at the end of this article and can be run as-is.
1. Core Principles
A codebook is a discrete representation learning method whose core idea is to map a continuous feature space onto a discrete set of code vectors. Our implementation consists of three key components:
1.1 ViT Encoder
class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state
        patch_embeddings = outputs[:, 1:, :]  # drop the CLS token
        return self.proj(patch_embeddings)
- Uses a pretrained ViT-Base model to extract image features
- Drops the CLS token, keeping the 196 patch features (a 224x224 input split into 16x16 patches gives 14x14 = 196)
- A linear projection adapts the feature dimension to the codebook
1.2 Codebook Quantization Layer
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

    def quantize(self, z_flat):  # z_flat: [batch*num_patches, embedding_dim]
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z, e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1)
        return indices, self.codebook(indices)
- A learnable Embedding layer stores the discrete codebook
- Nearest-neighbour lookup via the expanded L2 distance
- Supports EMA updates (commented out in the full code; a minimal sketch follows)
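For reference, here is a minimal sketch of the EMA codebook update (the standard VQ-VAE formulation; `decay` and the 1e-5 smoothing constant are conventional defaults, not values fixed by this article):

import torch
import torch.nn.functional as F

@torch.no_grad()
def ema_update(codebook, ema_cluster_size, ema_w, z_flat, indices_flat, decay=0.99, eps=1e-5):
    # one-hot assignments of each feature vector to its nearest code [N, num_embeddings]
    encodings = F.one_hot(indices_flat, codebook.num_embeddings).float()
    # exponentially smoothed per-code usage counts and feature sums
    ema_cluster_size.mul_(decay).add_(encodings.sum(0), alpha=1 - decay)
    ema_w.mul_(decay).add_(encodings.t() @ z_flat, alpha=1 - decay)
    # Laplace smoothing keeps rarely used codes from dividing by zero
    n = ema_cluster_size.sum()
    cluster_size = (ema_cluster_size + eps) / (n + codebook.num_embeddings * eps) * n
    codebook.weight.data.copy_(ema_w / cluster_size.unsqueeze(1))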
1.3 ViT Decoder
class ViTDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 4, 2, 1),
            nn.ReLU(),
            ...  # further upsampling layers
            nn.Conv2d(48, 3, 1))
- Transposed convolutions upsample the feature map step by step
- The final output is a 224x224 image
- Mirrors the encoder, giving a symmetric structure (see the shape check below)
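As a quick sanity check, the sketch below traces the shapes through the decoding path using the dimensions from this article (quantized features of shape [batch, 196, 512], a 14x14 patch grid, four stride-2 upsampling stages):

import torch
import torch.nn as nn

# the same upsampling head as in the full code below
head = nn.Sequential(
    nn.ConvTranspose2d(768, 384, 4, 2, 1), nn.ReLU(),  # 14x14 -> 28x28
    nn.ConvTranspose2d(384, 192, 4, 2, 1), nn.ReLU(),  # 28x28 -> 56x56
    nn.ConvTranspose2d(192, 96, 4, 2, 1), nn.ReLU(),   # 56x56 -> 112x112
    nn.ConvTranspose2d(96, 48, 4, 2, 1), nn.ReLU(),    # 112x112 -> 224x224
    nn.Conv2d(48, 3, 1))                                # reduce to 3 channels

quantized = torch.randn(2, 196, 512)            # codebook output
x = nn.Linear(512, 768)(quantized)              # project back to the ViT width
x = x.permute(0, 2, 1).reshape(2, 768, 14, 14)  # restore the spatial grid
print(head(x).shape)                            # torch.Size([2, 3, 224, 224])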
2. Training Strategy
2.1 Multi-Objective Loss Function
total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
- MSE loss: pixel-level reconstruction error
- Perceptual loss: VGG16 feature matching
- Codebook loss: pulls the codebook vectors toward the encoder output
- Commitment loss: keeps the encoder output close to its assigned code (the detach placement is sketched below)
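The last two terms differ only in where the gradient is stopped, and because the argmin lookup is not differentiable, a straight-through estimator is needed so that the reconstruction losses can reach the encoder at all. A minimal sketch of this arithmetic, using the same variable names as the full code (`beta` stands for the commitment cost):

import torch
import torch.nn.functional as F

def vq_losses(z, quantized, beta=0.25):
    # codebook loss: gradient flows only into the codebook vectors
    codebook_loss = F.mse_loss(quantized, z.detach())
    # commitment loss: gradient flows only into the encoder output
    commitment_loss = beta * F.mse_loss(z, quantized.detach())
    # straight-through estimator: forward pass uses the quantized values,
    # backward pass copies gradients to the encoder output unchanged
    quantized_st = z + (quantized - z).detach()
    return codebook_loss, commitment_loss, quantized_st

The full training loop at the end of the article uses this detach placement.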
2.2 Optimization Techniques
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
- Per-group learning rates (the codebook gets a smaller rate than the encoder and decoder)
- EMA smoothing of the codebook (commented out in the full code)
- Mixed-precision training and dynamic learning-rate scheduling are straightforward extensions (sketched below)
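Mixed precision and LR scheduling are not included in the full code below; this is a minimal sketch of how they could be wired in, assuming the `opt`, `encoder`, `codebook`, `decoder`, `device`, and `trainloader` defined in this article:

import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)  # one cosine cycle over 100 epochs

for images, _ in trainloader:
    images = images.to(device)
    with torch.cuda.amp.autocast():  # run the forward pass in float16 where safe
        z = encoder(images)
        indices, quantized = codebook.quantize(z)
        recon = decoder(z + (quantized - z).detach())
        loss = F.mse_loss(recon, images)  # plus the remaining loss terms from section 2.1
    opt.zero_grad()
    scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
    scaler.step(opt)
    scaler.update()
scheduler.step()  # step once per epoch in the real loop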
3. Complete Training Pipeline
3.1 Data Preparation
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(...)
])
- CIFAR-10 dataset
- Random crop and horizontal flip for augmentation
- Batch size of 4 to fit the 224x224 inputs in GPU memory
3.2 Training Monitoring
# TensorBoard logging
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)
# console logging
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")
Complete Code
from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        # load the pretrained ViT-Base model
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # linear projection to adapt the output dimension to the codebook
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]
        patch_embeddings = outputs[:, 1:, :]  # drop the CLS token
        return self.proj(patch_embeddings)  # [batch, 196, 512]
class Codebook(nn.Module):
    def __init__(self, num_embeddings=16384, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)  # EMA updates can be added here (see the extended class below)

    def quantize(self, z):
        """
        Quantize the input feature vectors.
        Args:
            z: input features [batch, num_patches, embedding_dim]
        Returns:
            indices: nearest codebook indices [batch, num_patches]
            quantized: quantized features [batch, num_patches, embedding_dim]
        """
        # flatten to a 2D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1)  # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore the original shape
        quantized = self.codebook(indices)  # [batch, num_patches, dim]
        return indices, quantized
class ViTDecoder(nn.Module):
    def __init__(self, in_dim=512):
        super().__init__()
        # map the quantized features back to the ViT embedding width
        self.proj = nn.Linear(in_dim, 768)
        config = ViTConfig()
        self.transformer = ViTModel(config).encoder
        self.head = nn.Sequential(
            # 14x14 -> 28x28
            nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 28x28 -> 56x56
            nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 56x56 -> 112x112
            nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 112x112 -> 224x224
            nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # reduce to 3 channels
            nn.Conv2d(48, 3, kernel_size=1)
        )

    def forward(self, x):
        x = self.proj(x)  # [batch, 196, 768]
        x = self.transformer(x).last_hidden_state
        # restore the spatial layout; reshape (not view), since permute breaks contiguity
        x = x.permute(0, 2, 1).reshape(-1, 768, 14, 14)
        return self.head(x)  # [batch, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()
# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)
from torchvision import transforms
import torchvision
import torch.nn.functional as F
# data augmentation and preprocessing
transform_train = transforms.Compose([
    transforms.Resize(224),  # resize images to the model's input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
batch_size = 4  # small batch size to fit the 224x224 inputs in GPU memory
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16
# initialize TensorBoard
writer = SummaryWriter('runs/codebook_experiment')
# improved Codebook class (with optional EMA updates)
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)
        self.commitment_cost = commitment_cost
        self.decay = decay
        # EMA statistics as buffers, not parameters: they are updated manually, not by the optimizer
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('ema_w', torch.randn(num_embeddings, embedding_dim))

    def quantize(self, z):
        # flatten to a 2D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]
        # squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        # nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1)  # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore the original shape
        quantized = self.codebook(indices)  # [batch, num_patches, dim]
        # optional EMA update (uncomment to enable)
        # if self.training:
        #     with torch.no_grad():
        #         encodings = F.one_hot(indices.reshape(-1), self.codebook.num_embeddings).float()
        #         self.ema_cluster_size.mul_(self.decay).add_(encodings.sum(0), alpha=1 - self.decay)
        #         n = self.ema_cluster_size.sum()
        #         cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)
        #         dw = torch.matmul(encodings.t(), z_flat)
        #         self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        #         self.codebook.weight.data.copy_(self.ema_w / cluster_size.unsqueeze(1))
        return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize the components
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # for the perceptual loss
# optimizer with per-group settings
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4}  # smaller learning rate for the codebook
], lr=3e-4)
# training loop
for epoch in range(100):
    avg_loss = 0
    start_time = time.time()  # record the epoch start time
    for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):
        images = images.to(device)
        # forward pass
        z = encoder(images)
        indices, quantized = codebook.quantize(z)
        # straight-through estimator: argmin is not differentiable, so copy
        # gradients from the quantized features back to the encoder output
        quantized_st = z + (quantized - z).detach()
        recon = decoder(quantized_st)
        # multi-objective loss
        mse_loss = F.mse_loss(recon, images)
        # perceptual loss (VGG feature matching); target features carry no gradient
        with torch.no_grad():
            real_features = vgg(images)
        recon_features = vgg(recon)
        percep_loss = F.mse_loss(recon_features, real_features)
        # codebook losses: the detach placement decides which side receives the gradient
        codebook_loss = F.mse_loss(quantized, z.detach())  # updates the codebook vectors
        commitment_loss = codebook.commitment_cost * F.mse_loss(z, quantized.detach())  # stabilizes the encoder output
        # total loss
        total_loss = mse_loss + 0.1 * percep_loss + codebook_loss + commitment_loss
        # backward pass
        opt.zero_grad()
        total_loss.backward()
        opt.step()
        # bookkeeping
        avg_loss += total_loss.item()
        if batch_idx % 50 == 0:
            # TensorBoard logging
            writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)
            writer.add_scalars('Loss/components', {
                'mse': mse_loss.item(),
                'perceptual': percep_loss.item(),
                'codebook': codebook_loss.item(),
                'commitment': commitment_loss.item()
            }, epoch*len(trainloader)+batch_idx)
            # save reconstruction samples (detach so add_image can convert to numpy)
            comparison = torch.cat([images[:4], recon[:4]]).detach()
            grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)
            writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)
    # print epoch statistics
    avg_loss /= len(trainloader)
    print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f} ({time.time() - start_time:.1f}s)")
    # save a model checkpoint
    if (epoch+1) % 10 == 0:
        torch.save({
            'encoder': encoder.state_dict(),
            'codebook': codebook.state_dict(),
            'decoder': decoder.state_dict(),
            'opt': opt.state_dict()
        }, f'checkpoint_epoch{epoch+1}.pth')
writer.close()
Through this hands-on exercise we have implemented the complete pipeline from feature extraction to discrete representation learning. Codebook techniques apply broadly to image compression, generative modeling, and related areas; we hope readers build on this foundation to explore further.
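As a pointer for the compression use case, here is a minimal sketch of how a trained model could turn an image into discrete tokens and back, assuming the `encoder`, `codebook`, `decoder`, and `device` defined above (the file name is a placeholder):

import torch
from PIL import Image
from torchvision import transforms

to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),  # force a square input for the ViT
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

@torch.no_grad()
def encode_decode(img_path):
    encoder.eval(); decoder.eval()
    img = to_tensor(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
    z = encoder(img)                           # [1, 196, 512] continuous features
    indices, quantized = codebook.quantize(z)  # [1, 196] discrete token ids
    recon = decoder(quantized)                 # [1, 3, 224, 224] reconstruction
    return indices, recon

# each image is represented by just 196 integers indexing into the codebook
# indices, recon = encode_decode('example.jpg')  # 'example.jpg' is hypothetical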