PyTorch demo of a Variational State-Space Model (VSSM) on the MNIST dataset

Published: 2025-07-23
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Set the random seed so results are reproducible
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a directory for saved figures
os.makedirs('visualizations', exist_ok=True)

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# A moderate batch size is used here; a larger one may exhaust GPU memory
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
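
# Optional sanity check (illustrative, not part of the original demo): peek at
# one batch to confirm its shape: images are [128, 1, 28, 28], labels are [128].
_images, _labels = next(iter(train_loader))
print(f'Batch shapes: images {tuple(_images.shape)}, labels {tuple(_labels.shape)}')
del _images, _labels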

# Define a simplified VSSM model
class VSSM(nn.Module):
  def __init__(self, input_size=784, hidden_size=32, state_size=16, output_size=10):
    super(VSSM, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.state_size = state_size
    self.output_size = output_size
    
    # Encoder network - maps the input to a hidden representation
    self.encoder = nn.Sequential(
      nn.Linear(input_size, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, hidden_size),
      nn.ReLU()
    )
    
    # Variational inference heads - produce the mean and log-variance of the latent state
    self.fc_mu = nn.Linear(hidden_size, state_size)
    self.fc_logvar = nn.Linear(hidden_size, state_size)
    
    # State-transition network - predicts the next latent state
    self.transition = nn.Sequential(
      nn.Linear(state_size, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, state_size)
    )
    
    # Decoder network - reconstructs the input from the latent state
    self.decoder = nn.Sequential(
      nn.Linear(state_size, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, input_size)
    )
    
    # Classifier network - predicts the class from the latent state
    self.classifier = nn.Sequential(
      nn.Linear(state_size, hidden_size),
      nn.ReLU(),
      nn.Dropout(0.2),  # dropout to reduce overfitting
      nn.Linear(hidden_size, output_size)
    )
  
  def encode(self, x):
    # x: [batch_size, input_size]
    h = self.encoder(x)
    mu = self.fc_mu(h)      # [batch_size, state_size]
    logvar = self.fc_logvar(h)  # [batch_size, state_size]
    return mu, logvar
  
  def reparameterize(self, mu, logvar):
    # Reparameterization trick: sample the latent variable stochastically while keeping gradients
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std  # [batch_size, state_size]
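
  # For reference: with std = exp(0.5 * logvar), z = mu + std * eps where
  # eps ~ N(0, I) is distributed as N(mu, std^2), and the sample remains
  # differentiable with respect to mu and logvar.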
  
  def decode(self, z):
    # z: [batch_size, state_size]
    return self.decoder(z)  # [batch_size, input_size]
  
  def classify(self, z):
    # z: [batch_size, state_size]
    return self.classifier(z)  # [batch_size, output_size]
  
  def forward(self, x):
    # x: [batch_size, 1, 28, 28]
    batch_size = x.size(0)
    x_flat = x.view(batch_size, -1)  # [batch_size, 784]
    
    # Encode and sample the latent state
    mu, logvar = self.encode(x_flat)
    z = self.reparameterize(mu, logvar)
    
    # State transition
    z_next = self.transition(z)
    
    # Decode and classify
    recon_flat = self.decode(z_next)
    pred = self.classify(z)
    
    return recon_flat, pred, mu, logvar, z, x_flat
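
# Quick shape check (illustrative, not part of the original demo): push a
# dummy batch through an untrained VSSM to confirm the shapes noted above.
_m = VSSM()
_recon, _pred, _mu, _logvar, _z, _xf = _m(torch.randn(4, 1, 28, 28))
assert _recon.shape == (4, 784) and _pred.shape == (4, 10)
assert _mu.shape == _logvar.shape == _z.shape == (4, 16)
del _m, _recon, _pred, _mu, _logvar, _z, _xf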

# Define the VSSM loss function
def vssm_loss(recon_x, x, pred, target, mu, logvar, lambda_kl=0.1, lambda_cls=1.0):
  # Reconstruction loss - how far the reconstruction is from the original image
  recon_loss = F.mse_loss(recon_x, x.view(x.size(0), -1), reduction='sum')
  
  # KL divergence - how far the latent distribution is from a standard normal
  kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  
  # Classification loss
  cls_loss = F.cross_entropy(pred, target, reduction='sum')
  
  # Compute the total loss
  batch_size = x.size(0)
  total_loss = (recon_loss + lambda_kl * kl_loss + lambda_cls * cls_loss) / batch_size
  
  return total_loss, recon_loss.item()/batch_size, kl_loss.item()/batch_size, cls_loss.item()/batch_size
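
# Sanity check (illustrative): the closed-form KL term above should match the
# KL divergence between N(mu, sigma^2) and N(0, 1) from torch.distributions.
from torch.distributions import Normal, kl_divergence
_mu, _logvar = torch.randn(4, 16), torch.randn(4, 16)
_closed = -0.5 * torch.sum(1 + _logvar - _mu.pow(2) - _logvar.exp())
_std = (0.5 * _logvar).exp()
_kl = kl_divergence(Normal(_mu, _std), Normal(torch.zeros_like(_mu), torch.ones_like(_mu))).sum()
assert torch.allclose(_closed, _kl)
del _mu, _logvar, _std, _closed, _kl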

# Helper to plot the training and test loss curves
def pltLoss(train_losses, test_losses, epochs):
  plt.figure(figsize=(10, 5))
  plt.plot(range(1, epochs+1), train_losses, 'b-', label='Training Loss')
  plt.plot(range(1, epochs+1), test_losses, 'r-', label='Test Loss')
  plt.xlabel('Epochs')
  plt.ylabel('Loss')
  plt.title('Training and Test Loss')
  plt.legend()
  plt.grid(True)
  plt.tight_layout()
  plt.savefig('loss_curve.png')
  plt.close()

# Visualize a test sample and its prediction results
def plotTest(model, test_loader, device, epoch):
  model.eval()
  best_sample = None
  best_confidence = -1
  best_info = None
  
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      
      # Forward pass to obtain intermediate results
      recon_flat, pred, mu, logvar, z, x_flat = model(data)
      
      # Compute prediction confidence
      confidence = F.softmax(pred, dim=1).max(dim=1)[0]
      
      # Find the sample with the highest confidence in this batch
      max_idx = confidence.argmax().item()
      if confidence[max_idx] > best_confidence:
        best_confidence = confidence[max_idx].item()
        best_sample = {
          'input': data[max_idx].cpu(),
          'recon': recon_flat[max_idx].cpu().view(1, 28, 28),
          'target': target[max_idx].cpu().item(),
          'pred': pred[max_idx].argmax().cpu().item(),
          'confidence': best_confidence,
          'mu': mu[max_idx].cpu().numpy(),
          'logvar': logvar[max_idx].cpu().numpy(),
          'z': z[max_idx].cpu().numpy(),
          'pred_dist': F.softmax(pred[max_idx], dim=0).cpu().numpy()
        }
      
      # Free tensors that are no longer needed to save GPU memory
      del data, target, recon_flat, pred, mu, logvar, z, x_flat, confidence, max_idx
      torch.cuda.empty_cache()
  
  if best_sample is not None:
    # Build the visualization figure
    plt.figure(figsize=(12, 8))
    
    # 1. Original input image
    plt.subplot(2, 3, 1)
    plt.title(f'Input Image (True: {best_sample["target"]})')
    plt.imshow(best_sample['input'].squeeze().numpy(), cmap='gray')
    plt.axis('off')
    
    # 2. Reconstructed image
    plt.subplot(2, 3, 2)
    plt.title('Reconstructed Image')
    plt.imshow(best_sample['recon'].squeeze().numpy(), cmap='gray')
    plt.axis('off')
    
    # 3. Latent mean
    plt.subplot(2, 3, 3)
    plt.title('Latent Mean (μ)')
    plt.bar(range(len(best_sample['mu'])), best_sample['mu'])
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    
    # 4. Latent log-variance
    plt.subplot(2, 3, 4)
    plt.title('Latent Log Variance (log σ²)')
    plt.bar(range(len(best_sample['logvar'])), best_sample['logvar'])
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    
    # 5. Sampled latent variable
    plt.subplot(2, 3, 5)
    plt.title('Sampled Latent Variable (z)')
    plt.bar(range(len(best_sample['z'])), best_sample['z'])
    plt.xlabel('Dimension')
    plt.ylabel('Value')
    
    # 6. Predicted class distribution
    plt.subplot(2, 3, 6)
    plt.title(f'Prediction Distribution (Pred: {best_sample["pred"]}, Conf: {best_sample["confidence"]:.4f})')
    plt.bar(range(10), best_sample['pred_dist'])
    plt.xticks(range(10))
    plt.xlabel('Class')
    plt.ylabel('Probability')
    
    plt.tight_layout()
    plt.savefig(f'visualizations/epoch_{epoch}_best_sample.png')
    plt.close()

# Initialize the model, optimizer, and learning-rate scheduler
model = VSSM().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The `verbose` flag of ReduceLROnPlateau is deprecated in recent PyTorch versions, so it is omitted
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

# Training function
def train(model, train_loader, optimizer, epoch, device):
  model.train()
  train_loss = 0
  train_recon_loss = 0
  train_kl_loss = 0
  train_cls_loss = 0
  
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    
    optimizer.zero_grad()
    
    # Forward pass - unpack all six return values
    recon, pred, mu, logvar, z, x_flat = model(data)
    
    # Compute the loss
    loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
    
    # Backpropagation and optimizer step
    loss.backward()
    optimizer.step()
    
    # Accumulate losses
    train_loss += loss.item()
    train_recon_loss += recon_loss
    train_kl_loss += kl_loss
    train_cls_loss += cls_loss
    
    # Free tensors that are no longer needed to save GPU memory
    # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss
    # torch.cuda.empty_cache()
    
    # Print training progress
    if batch_idx % 100 == 0:
      print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
            f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
  
  # Compute average losses
  avg_loss = train_loss / len(train_loader)
  avg_recon_loss = train_recon_loss / len(train_loader)
  avg_kl_loss = train_kl_loss / len(train_loader)
  avg_cls_loss = train_cls_loss / len(train_loader)
  
  print(f'Epoch: {epoch} Average training loss: {avg_loss:.4f} '
        f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
  
  return avg_loss

# Test function
def test(model, test_loader, device):
  model.eval()
  test_loss = 0
  test_recon_loss = 0
  test_kl_loss = 0
  test_cls_loss = 0
  correct = 0
  
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      
      # Forward pass - unpack all six return values
      recon, pred, mu, logvar, z, x_flat = model(data)
      
      # Compute the loss
      loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
      
      # Accumulate losses
      test_loss += loss.item()
      test_recon_loss += recon_loss
      test_kl_loss += kl_loss
      test_cls_loss += cls_loss
      
      # Compute classification accuracy
      pred_class = pred.argmax(dim=1, keepdim=True)
      correct += pred_class.eq(target.view_as(pred_class)).sum().item()
      
      # # Free tensors that are no longer needed to save GPU memory
      # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss, pred_class
      # torch.cuda.empty_cache()
  
  # Compute average losses and accuracy
  avg_loss = test_loss / len(test_loader)
  avg_recon_loss = test_recon_loss / len(test_loader)
  avg_kl_loss = test_kl_loss / len(test_loader)
  avg_cls_loss = test_cls_loss / len(test_loader)
  accuracy = 100. * correct / len(test_loader.dataset)
  
  print(f'Average test loss: {avg_loss:.4f} '
        f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
  print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
  
  return avg_loss, accuracy

# Main training loop
epochs = 10
train_losses = []
test_losses = []
best_accuracy = 0.0

for epoch in range(1, epochs + 1):
  print(f'\nEpoch {epoch}/{epochs}')
  
  # Train for one epoch
  train_loss = train(model, train_loader, optimizer, epoch, device)
  train_losses.append(train_loss)
  
  # Evaluate on the test set
  test_loss, accuracy = test(model, test_loader, device)
  test_losses.append(test_loss)
  
  # Visualize the best test sample
  plotTest(model, test_loader, device, epoch)
  
  # Learning-rate adjustment
  scheduler.step(test_loss)
  
  # Save the best model
  if accuracy > best_accuracy:
    best_accuracy = accuracy
    torch.save(model.state_dict(), 'best_model.pth')
    print(f'Best model saved with accuracy: {accuracy:.2f}%')
  
  # Plot the loss curves
  pltLoss(train_losses, test_losses, epoch)
  
  # Release cached GPU memory
  torch.cuda.empty_cache()

print(f'\nTraining completed. Best accuracy: {best_accuracy:.2f}%')  
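
# To reuse the saved weights later (a minimal sketch; 'best_model.pth' is the
# checkpoint written by the loop above):
# model = VSSM().to(device)
# model.load_state_dict(torch.load('best_model.pth', map_location=device))
# model.eval()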
