import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# Use a CJK-capable font so the Chinese plot labels render, and keep the
# minus sign displayable under that font.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False
# Prefer GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# Training-time augmentation pipeline; the Normalize constants are the
# widely used CIFAR-10 per-channel mean/std.
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# Test pipeline: no augmentation, same normalization as training.
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# CIFAR-10 is downloaded to ./data on first run.
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
transform=test_transform
)
# Shuffle only the training set; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
class ChannelAttention(nn.Module):
    """Channel attention (CBAM variant).

    Global average- and max-pooled channel descriptors are pushed through a
    shared bottleneck MLP; the sigmoid of their sum gates each channel.
    """

    def __init__(self, in_channels: int, ratio: int = 16):
        """
        Args:
            in_channels: number of input feature channels
            ratio: bottleneck reduction ratio (default 16)
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer MLP applied to both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each channel of ``x`` (B, C, H, W) by its attention weight."""
        batch, channels = x.shape[0], x.shape[1]
        pooled_avg = self.fc(self.avg_pool(x).view(batch, channels))
        pooled_max = self.fc(self.max_pool(x).view(batch, channels))
        weights = self.sigmoid(pooled_avg + pooled_max)
        return x * weights.view(batch, channels, 1, 1)
class SpatialAttention(nn.Module):
    """Spatial attention: gate every location using pooled channel statistics."""

    def __init__(self, kernel_size: int = 7):
        """
        Args:
            kernel_size: conv kernel size (default 7); padding preserves H/W.
        """
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Weight ``x`` (B, C, H, W) by a learned (B, 1, H, W) spatial mask."""
        # Channel-wise mean and max form a 2-channel summary of the map.
        channel_avg = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        summary = torch.cat((channel_avg, channel_max), dim=1)
        mask = self.sigmoid(self.conv(summary))
        return x * mask
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel then spatial attention."""

    def __init__(self, in_channels: int, ratio: int = 16, kernel_size: int = 7):
        """
        Args:
            in_channels: channels of the incoming feature map
            ratio: channel-attention reduction ratio (default 16)
            kernel_size: spatial-attention conv kernel size (default 7)
        """
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel attention first, then spatial attention, to ``x``."""
        return self.spatial_attn(self.channel_attn(x))
class CBAM_CNN(nn.Module):
    """Three-stage CNN for CIFAR-10 with a CBAM block after every stage."""

    @staticmethod
    def _stage(in_ch: int, out_ch: int) -> nn.Sequential:
        # Conv -> BN -> ReLU -> 2x2 max-pool; each stage halves H and W.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def __init__(self):
        super().__init__()
        self.conv_block1 = self._stage(3, 32)
        self.cbam1 = CBAM(in_channels=32)
        self.conv_block2 = self._stage(32, 64)
        self.cbam2 = CBAM(in_channels=64)
        self.conv_block3 = self._stage(64, 128)
        self.cbam3 = CBAM(in_channels=128)
        # A 32x32 input pooled three times leaves a 4x4 grid of 128 channels.
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images (B, 3, 32, 32) to class logits (B, 10)."""
        for stage, attn in (
            (self.conv_block1, self.cbam1),
            (self.conv_block2, self.cbam2),
            (self.conv_block3, self.cbam3),
        ):
            x = attn(stage(x))
        return self.fc_layers(x.flatten(1))
# Instantiate the model on the selected device.
model = CBAM_CNN().to(device)
# Cross-entropy loss over the 10 CIFAR classes.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate after 3 epochs without test-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    device: torch.device,
    epochs: int
) -> float:
    """Main training loop with per-epoch evaluation and final plotting.

    Args:
        model: model to train (already moved to ``device``)
        train_loader: training data loader
        test_loader: test data loader
        criterion: loss function
        optimizer: optimizer stepping ``model``'s parameters
        scheduler: plateau LR scheduler, stepped with the epoch test loss
        device: computation device
        epochs: number of training epochs
    Returns:
        Test accuracy (percent) of the final epoch (0.0 if ``epochs < 1``).
    """
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    all_iter_losses = []  # loss of every single batch, for the iteration plot
    iter_indices = []
    epoch_test_acc = 0.0  # defined up front so epochs == 0 cannot NameError
    for epoch in range(epochs):
        # BUG FIX: train mode must be re-entered every epoch. The original
        # called model.train() once before the loop, so after the first
        # evaluation pass (which sets model.eval()) the model trained with
        # Dropout disabled and BatchNorm statistics frozen for all
        # remaining epochs.
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            # Global 1-based iteration index across all epochs.
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            running_loss += iter_loss
            _, predicted = output.max(1)
            total_train += target.size(0)
            correct_train += predicted.eq(target).sum().item()
            if (batch_idx + 1) % 100 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                print(f"Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} "
                      f"| 单Batch损失: {iter_loss:.4f} | 平均损失: {avg_loss:.4f}")
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct_train / total_train
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)
        # Evaluation pass: eval mode + no_grad for correct stats and speed.
        model.eval()
        test_loss = 0.0
        correct_test = 0
        total_test = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)
        # ReduceLROnPlateau expects the monitored metric (test loss here).
        scheduler.step(epoch_test_loss)
        print(f"Epoch {epoch+1}/{epochs} 完成 | "
              f"Train Acc: {epoch_train_acc:.2f}% | Test Acc: {epoch_test_acc:.2f}%")
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    return epoch_test_acc
def plot_iter_losses(losses: list, indices: list) -> None:
    """Plot the training loss recorded at every single iteration."""
    fig = plt.figure(figsize=(10, 4))
    ax = fig.gca()
    ax.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    ax.set_xlabel('Iteration (Batch序号)')
    ax.set_ylabel('Loss值')
    ax.set_title('训练过程中每个Batch的损失变化')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()
def plot_epoch_metrics(
    train_acc: list,
    test_acc: list,
    train_loss: list,
    test_loss: list
) -> None:
    """Draw side-by-side epoch-level accuracy and loss curves."""
    xs = range(1, len(train_acc) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    # Left panel: accuracy curves.
    acc_ax.plot(xs, train_acc, 'b-', label='Train Accuracy')
    acc_ax.plot(xs, test_acc, 'r-', label='Test Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('训练与测试准确率对比')
    acc_ax.legend()
    acc_ax.grid(True)
    # Right panel: loss curves.
    loss_ax.plot(xs, train_loss, 'b-', label='Train Loss')
    loss_ax.plot(xs, test_loss, 'r-', label='Test Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss值')
    loss_ax.set_title('训练与测试损失对比')
    loss_ax.legend()
    loss_ax.grid(True)
    fig.tight_layout()
    plt.show()
epochs = 50  # total number of training epochs
print("开始训练带CBAM的CNN模型...")
# Run the full training loop; returns the final epoch's test accuracy.
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
# @浙大疏锦行  (author watermark; was a bare decorator line, which is a SyntaxError)