实战:用 PyTorch 复现一个 3 层全连接网络,训练 MNIST,达到 95%+ 准确率

发布于:2025-08-12 ⋅ 阅读:(15) ⋅ 点赞:(0)

1. 使用 Anaconda 创建一个新环境,安装 Python 以及与你的显卡(CUDA 版本)相匹配的 PyTorch;下面的代码还用到了 torchvision 和 tqdm,也需要一并安装。

2. 在 PyCharm(2025.1.3.1)中绑定上一步创建的 Conda 环境(具体步骤可参考 CSDN 博客《PyCharm(2025.1.3.1)绑定 Conda 环境》)。

3. 编写训练脚本(首次运行时会自动把 MNIST 数据集下载到 ./data 目录):

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# ---- Training hyper-parameters ----
BATCH_SIZE = 64   # how many images the model sees per optimization step
EPOCHS = 10       # how many full passes are made over the training set
LR = 1e-3         # learning rate handed to the Adam optimizer

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Preprocessing: ToTensor turns each image into a float tensor scaled to [0, 1];
# Normalize then standardizes it with the commonly quoted MNIST pixel mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Both splits are stored under ./data and downloaded on first use.
train_set = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_set = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# A single MNIST sample has shape (1, 28, 28): one grayscale channel, 28 x 28 pixels.
# DataLoader stacks BATCH_SIZE samples into a 4-D batch of shape
# (BATCH_SIZE, 1, 28, 28); the final batch may be smaller.
# shuffle=True reshuffles the 60 000 training images at the start of every epoch,
# while the test set is read in a fixed order (shuffle=False is the default).
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# Network: flatten the image into a vector -> 128 neurons -> 64 neurons -> 10 class scores.
class Net(nn.Module):
    """Three-layer fully connected classifier for 28x28 single-channel images.

    Layout: Flatten -> Linear(784, 128) -> ReLU -> Linear(128, 64) -> ReLU
    -> Linear(64, 10). The output is 10 raw scores (logits), one per digit.
    No softmax is applied here; the training script feeds these logits
    directly into nn.CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()
        # Keep every layer inside one Sequential stored as `self.net`, in this exact
        # order, so parameter names (net.1.*, net.3.*, net.5.*) match saved checkpoints.
        layers = [
            nn.Flatten(),          # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),     # one logit per digit class
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (B, 10) logits for a batch of images ``x``."""
        logits = self.net(x)
        return logits

model = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()   # expects raw logits plus integer class labels
optimizer = optim.Adam(model.parameters(), lr=LR)

# ---- Training: EPOCHS full passes over train_loader ----
for epoch in range(1, EPOCHS + 1):
    model.train()
    progress = tqdm(train_loader, desc=f"Epoch {epoch}")
    for images, labels in progress:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()              # drop gradients left over from the previous step
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        progress.set_postfix(loss=loss.item())   # show the latest batch loss on the bar

# ---- Evaluation on the held-out test set (no autograd bookkeeping) ----
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        predictions = model(images).argmax(1)    # index of the highest logit = predicted digit
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

4. 运行脚本:每个 epoch 会显示一个带当前 loss 的进度条,全部 10 个 epoch 训练结束后会打印测试集准确率(Test Accuracy)。