我之前在2024-07-15实现过一版胶囊网络,但当时无论怎么训练,都无法达到Hinton论文里在MNIST扩展数据集上报告的99.23%准确率。
前两天心血来潮,又认真读了一遍Hinton论文,并严格按照论文要求重新复现:最终达到了论文报告的性能上限,同时训练速度也比之前那版快了大约五六倍。
1. 导入数据集
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Preprocessing: convert PIL images to tensors, then normalize with the
# standard MNIST channel statistics (mean 0.1307, std 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Download (if not cached under ./data) and load the MNIST train/test splits.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
def show_image(image, label):
    """Render one grayscale sample with its label as the plot title."""
    plt.imshow(image, cmap='gray')
    plt.title(f'Label: {label}')
    plt.show()
# Sanity check: display the first training sample. trainset[0] is an
# (image_tensor, label) pair; [0][0] selects the single grayscale channel.
show_image(trainset[0][0][0], trainset[0][1])
2. 模型代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
def squash(s):
    """Squashing non-linearity from Sabour et al. 2017 (Eq. 1).

    Scales each capsule vector so short vectors shrink toward zero length
    and long vectors approach unit length, preserving direction.

    :param s: tensor whose last dimension is the capsule vector dimension
    :return: tensor of the same shape with every capsule vector squashed
    """
    s_squared_norm = torch.sum(s * s, dim=-1, keepdim=True)
    # The epsilon keeps the division finite when a capsule vector is exactly
    # zero: the original `s / norm` produced NaNs in that case.
    norm = torch.sqrt(s_squared_norm + 1e-8)
    return (s_squared_norm / (1.0 + s_squared_norm)) * (s / norm)
def routing(u_hat, num_iterations):
    """Dynamic routing-by-agreement (Sabour et al. 2017, Procedure 1).

    :param u_hat: prediction vectors, shape (B, N, M, D) — N input (primary)
        capsules each predicting M output capsules of dimension D
    :param num_iterations: number of routing iterations (paper uses 3)
    :return: output capsule vectors v_j, shape (B, M, D)
    """
    batch_size, num_capsules_i, num_capsules_j, n_dim = u_hat.size()
    # Routing logits, one per (input capsule, output capsule) pair.
    b_ij = torch.zeros(batch_size, num_capsules_i, num_capsules_j, 1,
                       device=u_hat.device)  # (B, N, M, 1)
    for _ in range(num_iterations):
        # NOTE(review): the paper's Eq. 3 softmaxes over the *output*
        # capsules (dim=2); this implementation normalizes over the input
        # capsules (dim=1), a common variant — confirm which is intended.
        c_ij = F.softmax(b_ij, dim=1)  # (B, N, M, 1)
        s_j = torch.sum(c_ij * u_hat, dim=1, keepdim=True)  # (B, 1, M, D)
        v_j = squash(s_j)  # (B, 1, M, D)
        # Agreement between predictions and current outputs raises the logits.
        b_ij = b_ij + torch.sum(u_hat * v_j, dim=-1, keepdim=True)  # (B, N, M, 1)
    return v_j.squeeze(1)  # (B, M, D)
class PrimaryCaps(nn.Module):
    """Primary capsule layer: one convolution whose output channels are
    regrouped into capsule vectors and passed through squash.

    Defaults match Sabour et al. 2017: 32 capsule types of dimension 8
    produced by a 9x9 convolution with stride 2 over 256 input channels.
    """

    def __init__(self, in_channels=256, out_channels=32, capsule_dim=8, kernel_size=9, stride=2):
        super(PrimaryCaps, self).__init__()
        self.capsule_dim = capsule_dim
        self.out_channels = out_channels
        # A single conv produces all capsule types at once; its channel axis
        # (out_channels * capsule_dim) is split into capsules in forward().
        self.conv2 = nn.Conv2d(in_channels, out_channels * capsule_dim,
                               kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """(B, C, H, W) feature map -> (B, N, D) squashed capsule vectors,
        where N = out_channels * H' * W' and D = capsule_dim."""
        batch = x.size(0)
        features = self.conv2(x)             # (B, out*D, H', W')
        features = features.permute(0, 2, 3, 1)  # channels-last
        capsules = features.reshape(batch, -1, self.capsule_dim).contiguous()
        # Squash each capsule vector so its length lies in [0, 1).
        return squash(capsules)
class DigitCaps(nn.Module):
    """Digit (output) capsule layer with dynamic routing.

    Each of the `num_route_nodes` input capsules predicts every output
    capsule through a learned transformation matrix; dynamic routing then
    combines those predictions into the final digit capsules.
    """

    def __init__(self, num_capsules=10, num_route_nodes=1152, in_channels=8, out_channels=16, num_iterations=3):
        """
        :param num_capsules: number of output capsules (one per digit class)
        :param num_route_nodes: number of input (primary) capsules
        :param in_channels: dimension of each input capsule vector
        :param out_channels: dimension of each output capsule vector
        :param num_iterations: dynamic-routing iterations
        """
        super(DigitCaps, self).__init__()
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations
        # W holds one (in_channels x out_channels) matrix per
        # (route node, output capsule) pair.
        self.W = nn.Parameter(torch.randn(1, num_route_nodes, num_capsules, in_channels, out_channels))
        # A single matrix shared across all route nodes also works, with
        # only a small accuracy cost:
        # self.W = nn.Parameter(torch.randn(1, 1, num_capsules, in_channels, out_channels))

    def forward(self, x):
        """(B, N, D1) input capsules -> (B, M) output capsule lengths."""
        # Broadcast each input capsule against every output capsule's W.
        u = x.unsqueeze(-2).unsqueeze(-2)            # (B, N, 1, 1, D1)
        u_hat = torch.matmul(u, self.W).squeeze(-2)  # (B, N, M, D2)
        v = routing(u_hat, self.num_iterations)      # (B, M, D2)
        # The length of each output capsule vector encodes class confidence.
        return torch.norm(v, dim=-1)                 # (B, M)
class CapesuleNet(nn.Module):
    """Full CapsNet for MNIST: conv feature extractor -> primary capsules
    -> digit capsules, as in Sabour et al. 2017.

    NOTE(review): the class name keeps the original "Capesule" spelling
    because code outside this block instantiates it under that name.
    """

    def __init__(self, num_classes=10):
        super(CapesuleNet, self).__init__()
        # 9x9 convolution over the single grayscale input channel.
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps(num_capsules=num_classes)

    def forward(self, x):
        """(B, 1, 28, 28) images -> (B, num_classes) capsule lengths."""
        features = F.relu(self.conv1(x))
        primary = self.primary_capsules(features)
        return self.digit_capsules(primary)
class MarginLoss(nn.Module):
    """Margin loss from Sabour et al. 2017 (Eq. 4).

    L_k = T_k * max(0, m+ - ||v_k||)^2
          + lambd * (1 - T_k) * max(0, ||v_k|| - m-)^2

    :param m_plus: margin for present classes (paper: 0.9)
    :param m_minus: margin for absent classes (paper: 0.1)
    :param lambd: down-weighting of the absent-class term (paper: 0.5)
    """

    def __init__(self, m_plus=0.9, m_minus=0.1, lambd=0.5):
        super(MarginLoss, self).__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambd = lambd

    def forward(self, v, target):
        """
        :param v: (batch_size, num_classes) capsule lengths ||v_k||
        :param target: (batch_size,) integer class labels (NOT one-hot; the
            original docstring claimed one-hot but the code converts labels)
        :return: scalar mean margin loss over the batch
        """
        # Derive the class count and device from the input instead of the
        # module-level globals `num_classes` / `device` the original read,
        # which made this class unusable outside that one script.
        one_hot = torch.eye(v.size(1), device=v.device)[target.to(v.device).long()]
        present = torch.clamp(self.m_plus - v, min=0) ** 2
        absent = torch.clamp(v - self.m_minus, min=0) ** 2
        loss = one_hot * present + self.lambd * (1 - one_hot) * absent
        return torch.mean(torch.sum(loss, dim=1))
3. 训练代码
# Training hyperparameters.
num_epochs=10
num_classes=10  # digit classes = number of output capsules
batch_size=128
# Use the GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CapesuleNet().to(device)
# Margin loss as in the paper; plain cross-entropy also trains but is not
# what the paper evaluates.
margin_loss = MarginLoss()
# margin_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
@torch.no_grad()
def estimate_acc():
    """Return top-1 accuracy (percent) on the test split as {'val': pct}.

    Switches the model to eval mode for the pass and restores train mode
    before returning. Gradients are disabled by the decorator.
    """
    model.eval()
    num_correct = 0
    num_seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images).argmax(dim=1)
        num_seen += labels.size(0)
        num_correct += (predictions == labels).sum().item()
    model.train()
    return {'val': (num_correct / num_seen) * 100}
# Main training loop: one optimizer step per batch, running train accuracy
# on the progress bar, validation accuracy printed after every epoch.
for epoch in range(num_epochs):
    correct = 0
    total = 0
    with tqdm(total=len(train_loader), desc="epoch %d" % epoch) as pbar:
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = margin_loss(outputs, labels)
            loss.backward()
            optimizer.step()
            # Running top-1 training accuracy over the epoch so far.
            predicted = torch.argmax(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Update the progress bar with current loss and accuracy.
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Accuracy': f'{(correct / total) * 100:.4f}%(train)',
            })
            pbar.update(1)
    # Held-out evaluation after each epoch.
    acc = estimate_acc()
    print(f"Accuracy: {acc['val']:.4f}%(val)")
可以看到神经网络迅速地收敛,并在第7个epoch达到了论文里的性能上限99.23%!