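Hand-Building LeNet-5 (a Baseline Model) for Traffic Sign Recognition
This article uses PyTorch to implement the classic LeNet-5 model from scratch, then trains and deploys it on a traffic sign recognition dataset. The complete code can be run directly.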
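I. Environment Setup
1. Install Python
- Download the installer from the official Python website. Choose Python 3.8 or later (3.8.10 recommended).
- During installation, check "Add Python to PATH".
2. Install CUDA (optional; only needed for GPU acceleration)
- Recommended combination: CUDA 11.8 + cuDNN 8.6.0
- Reference: a detailed CUDA + cuDNN installation and configuration tutorial
3. Set up a virtual environment
- Open Command Prompt (CMD) or PowerShell.
- Create and activate a virtual environment (once activated, the prompt is prefixed with (lenet_env)):
# Create the virtual environment
python -m venv lenet_env
# Activate it
.\lenet_env\Scripts\activate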
4. Install the core PyTorch libraries
- Pick the command that matches your hardware:
# GPU build (requires CUDA 11.8), recommended
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118/
# Or the CPU build
pip install torch torchvision torchaudio
5. Install supporting libraries
- Install the remaining packages:
pip install matplotlib numpy flask requests onnx onnxruntime
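6. Verify the installation
- Create a file named check_env.py with the following contents and run it:
import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
Expected output (example, for the GPU build):
PyTorch version: 2.3.1+cu118
CUDA available: True
Device count: 1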
7. Prepare the dataset
- Download the GTSRB dataset:
  - Training set: https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip
  - Test set: https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip
- Manually extract the archives into the following directory structure:
C:/
└─your_project/
  ├─data/
  │ ├─train/
  │ │ └─GTSRB/Final_Training/Images/...
  │ └─test/
  │   └─GTSRB/Final_Test/Images/...
  └─code/
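8. Troubleshooting
- CUDA is not available:
  - Check that the GPU driver is up to date
  - Make sure the installed PyTorch build matches your CUDA version
  - Run nvidia-smi to confirm that the GPU is detected
- Dataset path errors:
  - Use an absolute path (e.g. C:/your_project/data/train)
  - Make sure the extracted folder hierarchy is correct
- Out of memory:
  - Lower the batch_size parameter (for example, from 64 to 32)
  - Close other programs that are using GPU memory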
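For the "dataset path errors" item above, a quick check of the extracted folder hierarchy can save a failed training run. The following is a minimal sketch, not part of the original code: the helper name summarize_class_folders and the list of image suffixes are assumptions made for illustration.

```python
from pathlib import Path


def summarize_class_folders(root):
    """Return {class_folder_name: image_count} for the sub-folders of root.

    Hypothetical helper for illustration; the suffix list is an assumption.
    """
    root = Path(root)
    if not root.is_dir():
        # Show the resolved absolute path so a wrong working directory is obvious
        raise FileNotFoundError(f"Dataset directory not found: {root.resolve()}")
    image_suffixes = {".ppm", ".png", ".jpg", ".jpeg"}
    summary = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # Count only image files; other files such as CSV annotations are ignored
        summary[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in image_suffixes
        )
    return summary
```

Calling summarize_class_folders("C:/your_project/data/train/GTSRB/Final_Training/Images") should list one entry per class; for the 43-class model used in this article, 43 entries are expected. A single entry, or an empty result, usually means the path points one level too high or too low.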
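II. Dataset Processing
We use the German Traffic Sign Recognition Benchmark (GTSRB) dataset:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Preprocessing: resize to 32x32, convert to a tensor in [0, 1],
# then scale each channel to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the datasets. ImageFolder expects root to contain one sub-folder per class,
# so point it at the folder that holds the class directories (00000, 00001, ...).
train_set = datasets.ImageFolder(root='./data/train/GTSRB/Final_Training/Images', transform=transform)
# Note: the extracted test images are not split into class folders. Sort them into
# per-class sub-folders (named like the training classes) before loading them this way.
test_set = datasets.ImageFolder(root='./data/test/GTSRB/Final_Test/Images', transform=transform)

# Create the data loaders
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
print(f"Number of classes: {len(train_set.classes)}")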
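ImageFolder derives labels from sub-folder names, so a test set that ships as one flat folder plus a label file has to be sorted into class folders first. The sketch below shows one way to do that under stated assumptions: a semicolon-separated ground-truth CSV with Filename and ClassId columns is available (it may be distributed separately from the image archive), and the training class folders use five-digit zero-padded names. The function name and its parameters are hypothetical.

```python
import csv
import shutil
from pathlib import Path


def sort_into_class_folders(image_dir, csv_path, out_dir,
                            filename_col="Filename", label_col="ClassId",
                            delimiter=";"):
    """Copy each image listed in the CSV into out_dir/<5-digit class id>/.

    Column names, delimiter and folder naming are assumptions; adjust them
    to match the label file you actually have. Returns the number of copies.
    """
    image_dir, out_dir = Path(image_dir), Path(out_dir)
    copied = 0
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f, delimiter=delimiter):
            # Zero-pad the class id so folder names match the training set
            class_dir = out_dir / f"{int(row[label_col]):05d}"
            class_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy2(image_dir / row[filename_col], class_dir / row[filename_col])
            copied += 1
    return copied
```

After sorting, pass out_dir as the root of the test ImageFolder. Because ImageFolder assigns class indices by sorted folder name, using the same folder names for training and test keeps the label indices aligned between the two sets.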
III. Model Implementation
A PyTorch implementation of LeNet-5:
import torch
import torch.nn as nn

class LeNet5(nn.Module):
    def __init__(self, num_classes=43):
        super().__init__()
        # Feature extractor: two conv + ReLU + max-pool stages
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),  # 3 input channels (RGB) instead of the original 1
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        # Classifier: three fully connected layers ending in num_classes logits
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.classifier(x)
        return x

model = LeNet5()
print(model)
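The 16*5*5 input size of the first linear layer follows from the 32x32 input and the layer settings above. The arithmetic can be sketched as follows; the helper names are made up for this illustration, and the formulas assume no padding and a pooling stride equal to the kernel size, which matches the defaults used in the model.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride=None):
    """Spatial output size of max pooling; stride defaults to the kernel size."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


def lenet5_feature_size(input_size=32):
    """Number of values fed to the first linear layer for a square input."""
    s = conv_out(input_size, 5)  # 32 -> 28
    s = pool_out(s, 2)           # 28 -> 14
    s = conv_out(s, 5)           # 14 -> 10
    s = pool_out(s, 2)           # 10 -> 5
    return 16 * s * s            # 16 channels of 5x5 feature maps


print(lenet5_feature_size())  # 400, i.e. 16*5*5
```

If the Resize target in the preprocessing step is changed, this number changes too, and nn.Linear(16*5*5, 120) has to be updated to match.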
IV. Training
Training configuration and loop:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop
for epoch in range(20):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Evaluate on the test set after each epoch
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Epoch [{epoch+1}/20] Loss: {running_loss/len(train_loader):.4f} | Acc: {100*correct/total:.2f}%")

# Save the trained weights
torch.save(model.state_dict(), "lenet5_traffic_sign.pth")
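The model outputs raw logits with no softmax layer because nn.CrossEntropyLoss applies log-softmax internally and, by default, averages the per-sample losses over the batch. To make the loss values printed during training easier to interpret, here is a plain-Python sketch of the per-sample computation; the function is written for illustration and is not how PyTorch implements it.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# A model that is equally unsure about all 43 classes scores ln(43) per sample
print(round(cross_entropy([0.0] * 43, target=7), 2))  # 3.76
```

This gives a useful reference point: an untrained 43-class model should start near a loss of 3.76, and the printed loss should fall below that as training progresses. A loss that stays at that level suggests the model is not learning, for example because the labels or the data path are wrong.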
V. Model Deployment
5.1 Export to ONNX
# A dummy input fixes the expected input shape: one 3-channel 32x32 image
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(model, dummy_input, "lenet5.onnx",
                  input_names=["input"], output_names=["output"])
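After exporting, it is worth confirming that the ONNX file reproduces the PyTorch outputs. The sketch below uses onnxruntime, which was installed in the environment setup step; the helper names run_onnx and outputs_match and the 1e-4 tolerance are assumptions for illustration, while the tensor names "input" and "output" come from the export call above.

```python
import numpy as np


def run_onnx(onnx_path, batch):
    """Run one forward pass through an exported ONNX model on the CPU."""
    # Imported here so that outputs_match below can be used without onnxruntime
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # "input" and "output" are the names given to torch.onnx.export
    return session.run(["output"], {"input": batch.astype(np.float32)})[0]


def outputs_match(a, b, atol=1e-4):
    """True if two output arrays have the same shape and nearly equal values."""
    return a.shape == b.shape and bool(np.allclose(a, b, atol=atol))
```

A possible check: compute model(dummy_input) under torch.no_grad(), convert it with .cpu().numpy(), and compare it against run_onnx("lenet5.onnx", dummy_input.cpu().numpy()) using outputs_match. Since the export above does not declare dynamic axes, the ONNX model should be fed inputs of shape (1, 3, 32, 32).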
5.2 Serve the model with Flask
# Assumes model, device and train_set from the previous sections are defined in the same script
import torch
from torchvision import transforms
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)
# map_location lets weights saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load("lenet5_traffic_sign.pth", map_location=device))
model.eval()

def preprocess_image(image):
    # Must match the preprocessing used during training
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    return transform(image).unsqueeze(0)  # add the batch dimension

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    image = Image.open(file.stream).convert('RGB')
    tensor = preprocess_image(image).to(device)
    with torch.no_grad():
        outputs = model(tensor)
        _, predicted = torch.max(outputs, 1)
    # class_name is the name of the training folder for the predicted class
    return jsonify({'class_id': predicted.item(),
                    'class_name': train_set.classes[predicted.item()]})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
5.3 Test the API
Test with curl:
curl -X POST -F "file=@test_sign.jpg" http://localhost:5000/predict
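VI. Summary
In this article we covered:
- A PyTorch implementation of LeNet-5
- Loading and preprocessing a traffic sign dataset
- Training and evaluating the model
- A deployment path: ONNX export and a Flask inference service
The complete code requires the GTSRB dataset, which can be downloaded from the links in the "Prepare the dataset" step above. Using a GPU is recommended to speed up training.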