简单CNN训练Kaggle上的Fashion MNIST项目。
项目链接:Fashion MNIST | Kaggle
1.导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
df = pd.read_csv('./archive (1)/fashion-mnist_train.csv') # 文件名根据你下载的实际情况改一下
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
def set_seed(seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True # 可复现
torch.backends.cudnn.benchmark = False # 禁止自动优化
set_seed(42)
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
2.数据预处理
from torch.utils.data import TensorDataset, DataLoader
labels = df.iloc[:, 0].values # 第一列是标签
images = df.iloc[:, 1:].values.astype('float32') # 后784列是像素值
# 3. 归一化到 [0,1]
images = images / 255.0
# 4. 转换为张量并 reshape 为 [batch_size, 1, 28, 28]
images = torch.tensor(images).reshape(-1, 1, 28, 28)
labels = torch.tensor(labels, dtype=torch.long)
# 5. 标准化:将 [0,1] 区间映射到 [-1,1] 区间
images = (images - 0.5) / 0.5 # 这等价于 Normalize(mean=0.5, std=0.5)
# 6. 封装为 TensorDataset
train_dataset = TensorDataset(images, labels)
# 7. 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 1. 导入测试集
test_df = pd.read_csv('./archive (1)/fashion-mnist_test.csv') # 替换成你的实际路径
# 2. 拆分标签和图像数据
test_labels = test_df.iloc[:, 0].values
test_images = test_df.iloc[:, 1:].values.astype('float32')
# 3. 归一化到 [0, 1]
test_images = test_images / 255.0
# 4. 转换为 tensor 并 reshape 为 [batch_size, 1, 28, 28]
test_images = torch.tensor(test_images).reshape(-1, 1, 28, 28)
test_labels = torch.tensor(test_labels, dtype=torch.long)
# 5. 标准化:将 [0,1] 映射到 [-1,1]
test_images = (test_images - 0.5) / 0.5
# 6. 封装为 TensorDataset
test_dataset = TensorDataset(test_images, test_labels)
# 7. 创建 DataLoader
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
3.数据加载器
test_dataset = TensorDataset(test_images, test_labels)
# 7. 创建 DataLoader
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4.模型的定义
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 输入图像为 1 通道,28×28
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 输出 32×28×28
self.pool1 = nn.MaxPool2d(2, 2) # 输出 32×14×14
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 输出 64×14×14
self.pool2 = nn.MaxPool2d(2, 2) # 输出 64×7×7
self.fc1 = nn.Linear(64 * 7 * 7, 128) # Flatten 后全连接
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(128, 10) # 10个类别输出
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x))) # [batch, 32, 14, 14]
x = self.pool2(F.relu(self.conv2(x))) # [batch, 64, 7, 7]
x = x.view(-1, 64 * 7 * 7) # flatten
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
# 初始化模型
model = CNN()
model = model.to(device) # 将模型移至GPU(如果可用)
5.损失函数与优化器选择
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
# 引入学习率调度器,在训练过程中动态调整学习率--训练初期使用较大的 LR 快速降低损失,训练后期使用较小的 LR 更精细地逼近全局最优解。
# 在每个 epoch 结束后,需要手动调用调度器来更新学习率,可以在训练过程中调用 scheduler.step()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, # 指定要控制的优化器(这里是Adam)
mode='min', # 监测的指标是"最小化"(如损失函数)
patience=3, # 如果连续3个epoch指标没有改善,才降低LR
factor=0.5 # 降低LR的比例(新LR = 旧LR × 0.5)
)
6.训练过程
# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):
model.train() # 设置为训练模式
# 记录每个 iteration 的损失
all_iter_losses = [] # 存储所有 batch 的损失
iter_indices = [] # 存储 iteration 序号
# 记录每个 epoch 的准确率和损失
train_acc_history = []
test_acc_history = []
train_loss_history = []
test_loss_history = []
for epoch in range(epochs):
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device) # 移至GPU
optimizer.zero_grad() # 梯度清零
output = model(data) # 前向传播
loss = criterion(output, target) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
# 记录当前 iteration 的损失
iter_loss = loss.item()
all_iter_losses.append(iter_loss)
iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
# 统计准确率和损失
running_loss += iter_loss
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
# 每100个批次打印一次训练信息
if (batch_idx + 1) % 100 == 0:
print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
# 计算当前epoch的平均训练损失和准确率
epoch_train_loss = running_loss / len(train_loader)
epoch_train_acc = 100. * correct / total
train_acc_history.append(epoch_train_acc)
train_loss_history.append(epoch_train_loss)
# 测试阶段
model.eval() # 设置为评估模式
test_loss = 0
correct_test = 0
total_test = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
_, predicted = output.max(1)
total_test += target.size(0)
correct_test += predicted.eq(target).sum().item()
epoch_test_loss = test_loss / len(test_loader)
epoch_test_acc = 100. * correct_test / total_test
test_acc_history.append(epoch_test_acc)
test_loss_history.append(epoch_test_loss)
# 更新学习率调度器
scheduler.step(epoch_test_loss)
print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')
# 绘制所有 iteration 的损失曲线
plot_iter_losses(all_iter_losses, iter_indices)
# 绘制每个 epoch 的准确率和损失曲线
plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
return epoch_test_acc # 返回最终测试准确率
# 6. 绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):
plt.figure(figsize=(10, 4))
plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
plt.xlabel('Iteration(Batch序号)')
plt.ylabel('损失值')
plt.title('每个 Iteration 的训练损失')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
# 7. 绘制每个 epoch 的准确率和损失曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
epochs = range(1, len(train_acc) + 1)
plt.figure(figsize=(12, 4))
# 绘制准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs, train_acc, 'b-', label='训练准确率')
plt.plot(epochs, test_acc, 'r-', label='测试准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率 (%)')
plt.title('训练和测试准确率')
plt.legend()
plt.grid(True)
# 绘制损失曲线
plt.subplot(1, 2, 2)
plt.plot(epochs, train_loss, 'b-', label='训练损失')
plt.plot(epochs, test_loss, 'r-', label='测试损失')
plt.xlabel('Epoch')
plt.ylabel('损失值')
plt.title('训练和测试损失')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
# 8. 执行训练和测试
epochs = 20 # 增加训练轮次以获得更好效果
print("开始使用CNN训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
训练结果:
开始使用CNN训练模型...
Epoch: 1/20 | Batch: 100/938 | 单Batch损失: 0.6728 | 累计平均损失: 1.1205
Epoch: 1/20 | Batch: 200/938 | 单Batch损失: 0.6159 | 累计平均损失: 0.8879
Epoch: 1/20 | Batch: 300/938 | 单Batch损失: 0.3473 | 累计平均损失: 0.7917
Epoch: 1/20 | Batch: 400/938 | 单Batch损失: 0.6818 | 累计平均损失: 0.7279
Epoch: 1/20 | Batch: 500/938 | 单Batch损失: 0.4126 | 累计平均损失: 0.6816
Epoch: 1/20 | Batch: 600/938 | 单Batch损失: 0.4884 | 累计平均损失: 0.6500
Epoch: 1/20 | Batch: 700/938 | 单Batch损失: 0.4775 | 累计平均损失: 0.6224
Epoch: 1/20 | Batch: 800/938 | 单Batch损失: 0.3434 | 累计平均损失: 0.5997
Epoch: 1/20 | Batch: 900/938 | 单Batch损失: 0.4812 | 累计平均损失: 0.5829
Epoch 1/20 完成 | 训练准确率: 79.24% | 测试准确率: 87.73%
Epoch: 2/20 | Batch: 100/938 | 单Batch损失: 0.3459 | 累计平均损失: 0.3237
Epoch: 2/20 | Batch: 200/938 | 单Batch损失: 0.2269 | 累计平均损失: 0.3208
Epoch: 2/20 | Batch: 300/938 | 单Batch损失: 0.2620 | 累计平均损失: 0.3134
Epoch: 2/20 | Batch: 400/938 | 单Batch损失: 0.4117 | 累计平均损失: 0.3060
Epoch: 2/20 | Batch: 500/938 | 单Batch损失: 0.2391 | 累计平均损失: 0.3037
Epoch: 2/20 | Batch: 600/938 | 单Batch损失: 0.2740 | 累计平均损失: 0.2991
Epoch: 2/20 | Batch: 700/938 | 单Batch损失: 0.3822 | 累计平均损失: 0.2972
Epoch: 2/20 | Batch: 800/938 | 单Batch损失: 0.2515 | 累计平均损失: 0.2942
Epoch: 2/20 | Batch: 900/938 | 单Batch损失: 0.3838 | 累计平均损失: 0.2917
Epoch 2/20 完成 | 训练准确率: 89.15% | 测试准确率: 89.65%
Epoch: 3/20 | Batch: 100/938 | 单Batch损失: 0.2323 | 累计平均损失: 0.2451
Epoch: 3/20 | Batch: 200/938 | 单Batch损失: 0.3070 | 累计平均损失: 0.2477
Epoch: 3/20 | Batch: 300/938 | 单Batch损失: 0.2418 | 累计平均损失: 0.2418
Epoch: 3/20 | Batch: 400/938 | 单Batch损失: 0.1767 | 累计平均损失: 0.2437
Epoch: 3/20 | Batch: 500/938 | 单Batch损失: 0.0857 | 累计平均损失: 0.2440
Epoch: 3/20 | Batch: 600/938 | 单Batch损失: 0.1620 | 累计平均损失: 0.2443
Epoch: 3/20 | Batch: 700/938 | 单Batch损失: 0.2602 | 累计平均损失: 0.2430
Epoch: 3/20 | Batch: 800/938 | 单Batch损失: 0.4102 | 累计平均损失: 0.2431
Epoch: 3/20 | Batch: 900/938 | 单Batch损失: 0.1834 | 累计平均损失: 0.2424
Epoch 3/20 完成 | 训练准确率: 90.98% | 测试准确率: 91.45%
Epoch: 4/20 | Batch: 100/938 | 单Batch损失: 0.1462 | 累计平均损失: 0.2029
Epoch: 4/20 | Batch: 200/938 | 单Batch损失: 0.1863 | 累计平均损失: 0.2117
Epoch: 4/20 | Batch: 300/938 | 单Batch损失: 0.2021 | 累计平均损失: 0.2089
Epoch: 4/20 | Batch: 400/938 | 单Batch损失: 0.1214 | 累计平均损失: 0.2105
Epoch: 4/20 | Batch: 500/938 | 单Batch损失: 0.2102 | 累计平均损失: 0.2108
Epoch: 4/20 | Batch: 600/938 | 单Batch损失: 0.2666 | 累计平均损失: 0.2118
Epoch: 4/20 | Batch: 700/938 | 单Batch损失: 0.2241 | 累计平均损失: 0.2131
Epoch: 4/20 | Batch: 800/938 | 单Batch损失: 0.2433 | 累计平均损失: 0.2122
Epoch: 4/20 | Batch: 900/938 | 单Batch损失: 0.1631 | 累计平均损失: 0.2130
Epoch 4/20 完成 | 训练准确率: 92.12% | 测试准确率: 91.63%
Epoch: 5/20 | Batch: 100/938 | 单Batch损失: 0.1170 | 累计平均损失: 0.1761
Epoch: 5/20 | Batch: 200/938 | 单Batch损失: 0.2268 | 累计平均损失: 0.1819
Epoch: 5/20 | Batch: 300/938 | 单Batch损失: 0.1428 | 累计平均损失: 0.1800
Epoch: 5/20 | Batch: 400/938 | 单Batch损失: 0.0922 | 累计平均损失: 0.1840
Epoch: 5/20 | Batch: 500/938 | 单Batch损失: 0.3292 | 累计平均损失: 0.1822
Epoch: 5/20 | Batch: 600/938 | 单Batch损失: 0.1645 | 累计平均损失: 0.1836
Epoch: 5/20 | Batch: 700/938 | 单Batch损失: 0.1229 | 累计平均损失: 0.1842
Epoch: 5/20 | Batch: 800/938 | 单Batch损失: 0.0530 | 累计平均损失: 0.1842
Epoch: 5/20 | Batch: 900/938 | 单Batch损失: 0.1329 | 累计平均损失: 0.1832
Epoch 5/20 完成 | 训练准确率: 93.22% | 测试准确率: 92.15%
Epoch: 6/20 | Batch: 100/938 | 单Batch损失: 0.0954 | 累计平均损失: 0.1555
Epoch: 6/20 | Batch: 200/938 | 单Batch损失: 0.1088 | 累计平均损失: 0.1517
Epoch: 6/20 | Batch: 300/938 | 单Batch损失: 0.1287 | 累计平均损失: 0.1560
Epoch: 6/20 | Batch: 400/938 | 单Batch损失: 0.1425 | 累计平均损失: 0.1573
Epoch: 6/20 | Batch: 500/938 | 单Batch损失: 0.2129 | 累计平均损失: 0.1597
Epoch: 6/20 | Batch: 600/938 | 单Batch损失: 0.2167 | 累计平均损失: 0.1584
Epoch: 6/20 | Batch: 700/938 | 单Batch损失: 0.2693 | 累计平均损失: 0.1582
Epoch: 6/20 | Batch: 800/938 | 单Batch损失: 0.2473 | 累计平均损失: 0.1594
Epoch: 6/20 | Batch: 900/938 | 单Batch损失: 0.1082 | 累计平均损失: 0.1605
Epoch 6/20 完成 | 训练准确率: 94.03% | 测试准确率: 92.64%
Epoch: 7/20 | Batch: 100/938 | 单Batch损失: 0.1022 | 累计平均损失: 0.1202
Epoch: 7/20 | Batch: 200/938 | 单Batch损失: 0.1118 | 累计平均损失: 0.1268
Epoch: 7/20 | Batch: 300/938 | 单Batch损失: 0.0536 | 累计平均损失: 0.1255
Epoch: 7/20 | Batch: 400/938 | 单Batch损失: 0.0925 | 累计平均损失: 0.1261
Epoch: 7/20 | Batch: 500/938 | 单Batch损失: 0.1524 | 累计平均损失: 0.1302
Epoch: 7/20 | Batch: 600/938 | 单Batch损失: 0.1403 | 累计平均损失: 0.1320
Epoch: 7/20 | Batch: 700/938 | 单Batch损失: 0.1818 | 累计平均损失: 0.1335
Epoch: 7/20 | Batch: 800/938 | 单Batch损失: 0.1135 | 累计平均损失: 0.1354
Epoch: 7/20 | Batch: 900/938 | 单Batch损失: 0.1085 | 累计平均损失: 0.1370
Epoch 7/20 完成 | 训练准确率: 94.93% | 测试准确率: 91.92%
Epoch: 8/20 | Batch: 100/938 | 单Batch损失: 0.1835 | 累计平均损失: 0.1144
Epoch: 8/20 | Batch: 200/938 | 单Batch损失: 0.1692 | 累计平均损失: 0.1177
Epoch: 8/20 | Batch: 300/938 | 单Batch损失: 0.1051 | 累计平均损失: 0.1171
Epoch: 8/20 | Batch: 400/938 | 单Batch损失: 0.0776 | 累计平均损失: 0.1200
Epoch: 8/20 | Batch: 500/938 | 单Batch损失: 0.0990 | 累计平均损失: 0.1202
Epoch: 8/20 | Batch: 600/938 | 单Batch损失: 0.0701 | 累计平均损失: 0.1203
Epoch: 8/20 | Batch: 700/938 | 单Batch损失: 0.0546 | 累计平均损失: 0.1206
Epoch: 8/20 | Batch: 800/938 | 单Batch损失: 0.0199 | 累计平均损失: 0.1207
Epoch: 8/20 | Batch: 900/938 | 单Batch损失: 0.0319 | 累计平均损失: 0.1207
Epoch 8/20 完成 | 训练准确率: 95.43% | 测试准确率: 92.38%
Epoch: 9/20 | Batch: 100/938 | 单Batch损失: 0.0767 | 累计平均损失: 0.0956
Epoch: 9/20 | Batch: 200/938 | 单Batch损失: 0.0777 | 累计平均损失: 0.0979
Epoch: 9/20 | Batch: 300/938 | 单Batch损失: 0.0683 | 累计平均损失: 0.0998
Epoch: 9/20 | Batch: 400/938 | 单Batch损失: 0.1044 | 累计平均损失: 0.0995
Epoch: 9/20 | Batch: 500/938 | 单Batch损失: 0.0808 | 累计平均损失: 0.0990
Epoch: 9/20 | Batch: 600/938 | 单Batch损失: 0.0298 | 累计平均损失: 0.1023
Epoch: 9/20 | Batch: 700/938 | 单Batch损失: 0.0805 | 累计平均损失: 0.1033
Epoch: 9/20 | Batch: 800/938 | 单Batch损失: 0.1438 | 累计平均损失: 0.1038
Epoch: 9/20 | Batch: 900/938 | 单Batch损失: 0.1212 | 累计平均损失: 0.1034
Epoch 9/20 完成 | 训练准确率: 96.04% | 测试准确率: 92.55%
Epoch: 10/20 | Batch: 100/938 | 单Batch损失: 0.1138 | 累计平均损失: 0.0861
Epoch: 10/20 | Batch: 200/938 | 单Batch损失: 0.0486 | 累计平均损失: 0.0835
Epoch: 10/20 | Batch: 300/938 | 单Batch损失: 0.0891 | 累计平均损失: 0.0855
Epoch: 10/20 | Batch: 400/938 | 单Batch损失: 0.0323 | 累计平均损失: 0.0858
Epoch: 10/20 | Batch: 500/938 | 单Batch损失: 0.1320 | 累计平均损失: 0.0862
Epoch: 10/20 | Batch: 600/938 | 单Batch损失: 0.1182 | 累计平均损失: 0.0864
Epoch: 10/20 | Batch: 700/938 | 单Batch损失: 0.1017 | 累计平均损失: 0.0868
Epoch: 10/20 | Batch: 800/938 | 单Batch损失: 0.0458 | 累计平均损失: 0.0878
Epoch: 10/20 | Batch: 900/938 | 单Batch损失: 0.1041 | 累计平均损失: 0.0889
Epoch 10/20 完成 | 训练准确率: 96.75% | 测试准确率: 92.44%
Epoch: 11/20 | Batch: 100/938 | 单Batch损失: 0.0352 | 累计平均损失: 0.0519
Epoch: 11/20 | Batch: 200/938 | 单Batch损失: 0.0173 | 累计平均损失: 0.0510
Epoch: 11/20 | Batch: 300/938 | 单Batch损失: 0.0482 | 累计平均损失: 0.0516
Epoch: 11/20 | Batch: 400/938 | 单Batch损失: 0.0953 | 累计平均损失: 0.0538
Epoch: 11/20 | Batch: 500/938 | 单Batch损失: 0.0324 | 累计平均损失: 0.0532
Epoch: 11/20 | Batch: 600/938 | 单Batch损失: 0.1732 | 累计平均损失: 0.0544
Epoch: 11/20 | Batch: 700/938 | 单Batch损失: 0.0624 | 累计平均损失: 0.0544
Epoch: 11/20 | Batch: 800/938 | 单Batch损失: 0.0128 | 累计平均损失: 0.0543
Epoch: 11/20 | Batch: 900/938 | 单Batch损失: 0.0752 | 累计平均损失: 0.0542
Epoch 11/20 完成 | 训练准确率: 98.03% | 测试准确率: 92.77%
Epoch: 12/20 | Batch: 100/938 | 单Batch损失: 0.0514 | 累计平均损失: 0.0360
Epoch: 12/20 | Batch: 200/938 | 单Batch损失: 0.0159 | 累计平均损失: 0.0395
Epoch: 12/20 | Batch: 300/938 | 单Batch损失: 0.0505 | 累计平均损失: 0.0410
Epoch: 12/20 | Batch: 400/938 | 单Batch损失: 0.0058 | 累计平均损失: 0.0417
Epoch: 12/20 | Batch: 500/938 | 单Batch损失: 0.0430 | 累计平均损失: 0.0422
Epoch: 12/20 | Batch: 600/938 | 单Batch损失: 0.0582 | 累计平均损失: 0.0433
Epoch: 12/20 | Batch: 700/938 | 单Batch损失: 0.0393 | 累计平均损失: 0.0440
Epoch: 12/20 | Batch: 800/938 | 单Batch损失: 0.0209 | 累计平均损失: 0.0441
Epoch: 12/20 | Batch: 900/938 | 单Batch损失: 0.0722 | 累计平均损失: 0.0453
Epoch 12/20 完成 | 训练准确率: 98.50% | 测试准确率: 92.97%
Epoch: 13/20 | Batch: 100/938 | 单Batch损失: 0.0180 | 累计平均损失: 0.0360
Epoch: 13/20 | Batch: 200/938 | 单Batch损失: 0.0124 | 累计平均损失: 0.0340
Epoch: 13/20 | Batch: 300/938 | 单Batch损失: 0.0146 | 累计平均损失: 0.0332
Epoch: 13/20 | Batch: 400/938 | 单Batch损失: 0.0827 | 累计平均损失: 0.0342
Epoch: 13/20 | Batch: 500/938 | 单Batch损失: 0.0221 | 累计平均损失: 0.0350
Epoch: 13/20 | Batch: 600/938 | 单Batch损失: 0.0248 | 累计平均损失: 0.0359
Epoch: 13/20 | Batch: 700/938 | 单Batch损失: 0.0079 | 累计平均损失: 0.0364
Epoch: 13/20 | Batch: 800/938 | 单Batch损失: 0.0323 | 累计平均损失: 0.0374
Epoch: 13/20 | Batch: 900/938 | 单Batch损失: 0.0458 | 累计平均损失: 0.0377
Epoch 13/20 完成 | 训练准确率: 98.76% | 测试准确率: 92.98%
Epoch: 14/20 | Batch: 100/938 | 单Batch损失: 0.0505 | 累计平均损失: 0.0259
Epoch: 14/20 | Batch: 200/938 | 单Batch损失: 0.0191 | 累计平均损失: 0.0274
Epoch: 14/20 | Batch: 300/938 | 单Batch损失: 0.0290 | 累计平均损失: 0.0299
Epoch: 14/20 | Batch: 400/938 | 单Batch损失: 0.0241 | 累计平均损失: 0.0288
Epoch: 14/20 | Batch: 500/938 | 单Batch损失: 0.0182 | 累计平均损失: 0.0292
Epoch: 14/20 | Batch: 600/938 | 单Batch损失: 0.0101 | 累计平均损失: 0.0291
Epoch: 14/20 | Batch: 700/938 | 单Batch损失: 0.0338 | 累计平均损失: 0.0293
Epoch: 14/20 | Batch: 800/938 | 单Batch损失: 0.0335 | 累计平均损失: 0.0299
Epoch: 14/20 | Batch: 900/938 | 单Batch损失: 0.0180 | 累计平均损失: 0.0299
Epoch 14/20 完成 | 训练准确率: 99.00% | 测试准确率: 92.87%
Epoch: 15/20 | Batch: 100/938 | 单Batch损失: 0.0114 | 累计平均损失: 0.0188
Epoch: 15/20 | Batch: 200/938 | 单Batch损失: 0.0085 | 累计平均损失: 0.0181
Epoch: 15/20 | Batch: 300/938 | 单Batch损失: 0.0123 | 累计平均损失: 0.0183
Epoch: 15/20 | Batch: 400/938 | 单Batch损失: 0.0186 | 累计平均损失: 0.0181
Epoch: 15/20 | Batch: 500/938 | 单Batch损失: 0.0176 | 累计平均损失: 0.0181
Epoch: 15/20 | Batch: 600/938 | 单Batch损失: 0.0111 | 累计平均损失: 0.0177
Epoch: 15/20 | Batch: 700/938 | 单Batch损失: 0.0140 | 累计平均损失: 0.0177
Epoch: 15/20 | Batch: 800/938 | 单Batch损失: 0.0185 | 累计平均损失: 0.0179
Epoch: 15/20 | Batch: 900/938 | 单Batch损失: 0.0706 | 累计平均损失: 0.0181
Epoch 15/20 完成 | 训练准确率: 99.56% | 测试准确率: 93.08%
Epoch: 16/20 | Batch: 100/938 | 单Batch损失: 0.0055 | 累计平均损失: 0.0155
Epoch: 16/20 | Batch: 200/938 | 单Batch损失: 0.0063 | 累计平均损失: 0.0138
Epoch: 16/20 | Batch: 300/938 | 单Batch损失: 0.0151 | 累计平均损失: 0.0138
Epoch: 16/20 | Batch: 400/938 | 单Batch损失: 0.0098 | 累计平均损失: 0.0138
Epoch: 16/20 | Batch: 500/938 | 单Batch损失: 0.0181 | 累计平均损失: 0.0134
Epoch: 16/20 | Batch: 600/938 | 单Batch损失: 0.0523 | 累计平均损失: 0.0134
Epoch: 16/20 | Batch: 700/938 | 单Batch损失: 0.0055 | 累计平均损失: 0.0138
Epoch: 16/20 | Batch: 800/938 | 单Batch损失: 0.0338 | 累计平均损失: 0.0138
Epoch: 16/20 | Batch: 900/938 | 单Batch损失: 0.0281 | 累计平均损失: 0.0141
Epoch 16/20 完成 | 训练准确率: 99.65% | 测试准确率: 92.94%
Epoch: 17/20 | Batch: 100/938 | 单Batch损失: 0.0020 | 累计平均损失: 0.0099
Epoch: 17/20 | Batch: 200/938 | 单Batch损失: 0.0095 | 累计平均损失: 0.0096
Epoch: 17/20 | Batch: 300/938 | 单Batch损失: 0.0084 | 累计平均损失: 0.0102
Epoch: 17/20 | Batch: 400/938 | 单Batch损失: 0.0085 | 累计平均损失: 0.0104
Epoch: 17/20 | Batch: 500/938 | 单Batch损失: 0.0065 | 累计平均损失: 0.0108
Epoch: 17/20 | Batch: 600/938 | 单Batch损失: 0.0144 | 累计平均损失: 0.0112
Epoch: 17/20 | Batch: 700/938 | 单Batch损失: 0.0039 | 累计平均损失: 0.0114
Epoch: 17/20 | Batch: 800/938 | 单Batch损失: 0.0279 | 累计平均损失: 0.0114
Epoch: 17/20 | Batch: 900/938 | 单Batch损失: 0.0112 | 累计平均损失: 0.0115
Epoch 17/20 完成 | 训练准确率: 99.77% | 测试准确率: 92.95%
Epoch: 18/20 | Batch: 100/938 | 单Batch损失: 0.0050 | 累计平均损失: 0.0081
Epoch: 18/20 | Batch: 200/938 | 单Batch损失: 0.0038 | 累计平均损失: 0.0084
Epoch: 18/20 | Batch: 300/938 | 单Batch损失: 0.0034 | 累计平均损失: 0.0086
Epoch: 18/20 | Batch: 400/938 | 单Batch损失: 0.0021 | 累计平均损失: 0.0085
Epoch: 18/20 | Batch: 500/938 | 单Batch损失: 0.0069 | 累计平均损失: 0.0085
Epoch: 18/20 | Batch: 600/938 | 单Batch损失: 0.0025 | 累计平均损失: 0.0086
Epoch: 18/20 | Batch: 700/938 | 单Batch损失: 0.0236 | 累计平均损失: 0.0093
Epoch: 18/20 | Batch: 800/938 | 单Batch损失: 0.0420 | 累计平均损失: 0.0099
Epoch: 18/20 | Batch: 900/938 | 单Batch损失: 0.0159 | 累计平均损失: 0.0099
Epoch 18/20 完成 | 训练准确率: 99.75% | 测试准确率: 92.90%
Epoch: 19/20 | Batch: 100/938 | 单Batch损失: 0.0061 | 累计平均损失: 0.0061
Epoch: 19/20 | Batch: 200/938 | 单Batch损失: 0.0009 | 累计平均损失: 0.0058
Epoch: 19/20 | Batch: 300/938 | 单Batch损失: 0.0060 | 累计平均损失: 0.0057
Epoch: 19/20 | Batch: 400/938 | 单Batch损失: 0.0053 | 累计平均损失: 0.0056
Epoch: 19/20 | Batch: 500/938 | 单Batch损失: 0.0047 | 累计平均损失: 0.0057
Epoch: 19/20 | Batch: 600/938 | 单Batch损失: 0.0161 | 累计平均损失: 0.0056
Epoch: 19/20 | Batch: 700/938 | 单Batch损失: 0.0152 | 累计平均损失: 0.0058
Epoch: 19/20 | Batch: 800/938 | 单Batch损失: 0.0049 | 累计平均损失: 0.0058
Epoch: 19/20 | Batch: 900/938 | 单Batch损失: 0.0066 | 累计平均损失: 0.0058
Epoch 19/20 完成 | 训练准确率: 99.93% | 测试准确率: 93.04%
Epoch: 20/20 | Batch: 100/938 | 单Batch损失: 0.0017 | 累计平均损失: 0.0043
Epoch: 20/20 | Batch: 200/938 | 单Batch损失: 0.0049 | 累计平均损失: 0.0042
Epoch: 20/20 | Batch: 300/938 | 单Batch损失: 0.0044 | 累计平均损失: 0.0043
Epoch: 20/20 | Batch: 400/938 | 单Batch损失: 0.0044 | 累计平均损失: 0.0044
Epoch: 20/20 | Batch: 500/938 | 单Batch损失: 0.0079 | 累计平均损失: 0.0048
Epoch: 20/20 | Batch: 600/938 | 单Batch损失: 0.0108 | 累计平均损失: 0.0047
Epoch: 20/20 | Batch: 700/938 | 单Batch损失: 0.0038 | 累计平均损失: 0.0047
Epoch: 20/20 | Batch: 800/938 | 单Batch损失: 0.0095 | 累计平均损失: 0.0047
Epoch: 20/20 | Batch: 900/938 | 单Batch损失: 0.0076 | 累计平均损失: 0.0047
Epoch 20/20 完成 | 训练准确率: 99.95% | 测试准确率: 92.99%
训练完成!最终测试准确率: 92.99%
6.Grad-Cam
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
class GradCAM:
def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.gradients = None
self.activations = None
self._register_hooks()
def _register_hooks(self):
def forward_hook(module, input, output):
self.activations = output.detach()
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
self.target_layer.register_forward_hook(forward_hook)
self.target_layer.register_backward_hook(backward_hook) # 老钩子,兼容性好
def generate_cam(self, input_tensor, target_class=None):
self.model.eval()
output = self.model(input_tensor)
if target_class is None:
target_class = output.argmax(dim=1).item()
self.model.zero_grad()
one_hot = torch.zeros_like(output)
one_hot[0, target_class] = 1
output.backward(gradient=one_hot)
gradients = self.gradients # [1, channels, h, w]
activations = self.activations # [1, channels, h, w]
weights = gradients.mean(dim=(2, 3), keepdim=True) # [1, channels, 1, 1]
cam = (weights * activations).sum(dim=1, keepdim=True) # [1, 1, h, w]
cam = F.relu(cam)
cam = F.interpolate(cam, size=(28, 28), mode='bilinear', align_corners=False)
cam = cam - cam.min()
cam_max = cam.max()
if cam_max > 0:
cam = cam / cam_max
return cam.squeeze().cpu().numpy()
def tensor_to_np(tensor):
img = tensor.cpu().numpy().squeeze(0)
img = img * 0.5 + 0.5 # 假设输入是[-1,1]归一化,转成[0,1]
return np.clip(img, 0, 1)
# 使用示例(假设你已有 model, test_dataset, device)
os.makedirs("grad_cam_per_class", exist_ok=True)
class_names = ['T恤or上衣', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']
grad_cam = GradCAM(model, model.conv2) # 选最后一个卷积层 conv2
shown_classes = set()
for idx in range(len(test_dataset)):
image, label = test_dataset[idx]
label_id = label.item()
if label_id in shown_classes:
continue
input_tensor = image.unsqueeze(0).to(device)
heatmap = grad_cam.generate_cam(input_tensor, target_class=label_id)
img = tensor_to_np(image)
heatmap_color = plt.cm.jet(heatmap)[:, :, :3] # colormap expects [0,1] float
superimposed_img = heatmap_color * 0.4 + np.repeat(img[..., np.newaxis], 3, axis=2) * 0.6
plt.figure(figsize=(10, 3))
plt.subplot(1, 3, 1)
plt.imshow(img, cmap='gray')
plt.title(f'原始图像: {class_names[label_id]}')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title('Grad-CAM 热力图')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(superimposed_img)
plt.title('叠加效果')
plt.axis('off')
plt.tight_layout()
plt.savefig(f"grad_cam_per_class/class_{label_id}_{class_names[label_id]}.png")
plt.close()
shown_classes.add(label_id)
if len(shown_classes) == 10:
break
print("每类Grad-CAM图像已保存至 grad_cam_per_class/ 文件夹。")