I. Why Choose PyTorch Lightning?
Lightning addresses four pain points of industrial-grade development:
- Code discipline: enforces a modular separation of model, data, and training logic
- Scalability: scales to distributed training on 100+ GPUs without code changes
- Reproducibility: built-in seeding and hyperparameter tracking (see the seed_everything sketch after this list)
- Production readiness: first-class support for TPU training and model export
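To make the reproducibility point concrete, here is a minimal sketch (not part of the original listing) of Lightning's built-in global seeding:

import pytorch_lightning as pl
# Seeds Python's random, NumPy, and torch (CPU and CUDA) in one call;
# workers=True also seeds DataLoader worker processes for deterministic loading.
pl.seed_everything(42, workers=True)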
II. Environment Setup and Basic Concepts
# Install the core library plus extension components
pip install pytorch-lightning lightning-bolts torchmetrics wandb optuna
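The examples below use 2.x-style Trainer arguments (e.g. precision="16-mixed"), so it is worth confirming what actually got installed; a quick check:

import pytorch_lightning as pl
import torchmetrics
# The Trainer flags used later assume pytorch-lightning 2.x.
print(pl.__version__, torchmetrics.__version__)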
III. MNIST Classification Hands-On: From PyTorch to Lightning
1. The Plain PyTorch Implementation (for comparison)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Data preparation
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST("./data", download=True, train=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# Model definition
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        return self.net(x.view(-1, 28*28))

# Training loop: device placement, logging, and checkpointing are all manual
model = Net()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
    for batch in train_loader:
        x, y = batch
        preds = model(x)
        loss = criterion(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
2. The Lightning Version
import pytorch_lightning as pl
from torchmetrics import Accuracy

class LitMNIST(pl.LightningModule):
    def __init__(self, hidden_size=512, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()  # record hyperparameters in self.hparams and in checkpoints
        self.model = nn.Sequential(
            nn.Linear(28*28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )
        self.metric = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x.view(-1, 28*28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def prepare_data(self):
        # called once per node; a safe place for downloads
        datasets.MNIST("./data", download=True)

    def train_dataloader(self):
        return DataLoader(
            datasets.MNIST("./data", train=True, transform=transforms.ToTensor()),
            batch_size=128,
            num_workers=4
        )
# Launch training
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices="auto",
    enable_progress_bar=True
)
model = LitMNIST()
trainer.fit(model)
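The callbacks in the next section monitor val_loss and val_acc, which LitMNIST as written never logs. Here is a minimal sketch of the validation hooks you would add to make those metrics available (standard LightningModule API; the subclass name is hypothetical and not part of the original listing):

class LitMNISTWithVal(LitMNIST):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        # log the metrics that EarlyStopping / ModelCheckpoint will monitor below
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.metric(logits, y), prog_bar=True)

    def val_dataloader(self):
        return DataLoader(
            datasets.MNIST("./data", train=False, transform=transforms.ToTensor()),
            batch_size=128,
            num_workers=4
        )

Instantiate this subclass (or fold the two methods into LitMNIST directly) before using the production callbacks that follow.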
IV. Industrial-Grade Features
1. Production Essentials
trainer = pl.Trainer(
    callbacks=[
        # stop training when the monitored metric stops improving
        pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
        pl.callbacks.ModelCheckpoint(
            dirpath="./checkpoints",
            filename="best_model_{epoch}_{val_acc:.2f}",
            monitor="val_acc",
            mode="max"
        )
    ],
    logger=pl.loggers.WandbLogger(project="MNIST"),
    precision="16-mixed",        # mixed-precision training (PL 2.x syntax)
    gradient_clip_val=0.5,       # gradient clipping
    accumulate_grad_batches=4,   # gradient accumulation
)
# Note: val_loss / val_acc must be logged in a validation_step (see the sketch above).
2. Distributed Training (no code changes required)
# Launch multi-GPU training (available devices are detected automatically)
trainer = pl.Trainer(
    devices=4,
    strategy="ddp",  # DistributedDataParallel; find_unused_parameters defaults to False in PL 2.x
    accelerator="gpu"
)
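If you need explicit control over the DDP options, the same configuration can be expressed with the strategy object instead of the shorthand string; a sketch:

from pytorch_lightning.strategies import DDPStrategy

trainer = pl.Trainer(
    devices=4,
    accelerator="gpu",
    # equivalent to strategy="ddp", but with the flag spelled out
    strategy=DDPStrategy(find_unused_parameters=False)
)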
3. Hyperparameter Optimization (Optuna Integration)
import optuna

def objective(trial):
    model = LitMNIST(
        hidden_size=trial.suggest_categorical("hidden_size", [256, 512, 1024]),
        learning_rate=trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    )
    trainer = pl.Trainer(max_epochs=10, enable_checkpointing=False)
    trainer.fit(model)
    # requires a validation_step that logs "val_acc" (see the sketch in section III)
    return trainer.callback_metrics["val_acc"].item()

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print("Best hyperparameters:", study.best_params)
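Optuna can also prune unpromising trials early instead of running every trial to completion. A sketch using PyTorchLightningPruningCallback (note: recent Optuna releases ship this in the separate optuna-integration package, an assumption about your installed version; the function name below is a hypothetical variant of objective):

from optuna.integration import PyTorchLightningPruningCallback

def objective_with_pruning(trial):
    model = LitMNIST(
        hidden_size=trial.suggest_categorical("hidden_size", [256, 512, 1024]),
        learning_rate=trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    )
    trainer = pl.Trainer(
        max_epochs=10,
        enable_checkpointing=False,
        # reports "val_acc" to Optuna after each validation pass and stops hopeless trials
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")]
    )
    trainer.fit(model)
    return trainer.callback_metrics["val_acc"].item()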
V. Model Deployment and Monitoring
1. TorchScript Export
# Export the LightningModule's forward pass as TorchScript
script = model.to_torchscript()
torch.jit.save(script, "mnist_model.pt")
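To verify the export, the scripted model can be loaded and run without any Lightning dependency; a quick sanity check:

loaded = torch.jit.load("mnist_model.pt")
loaded.eval()
with torch.no_grad():
    # forward() flattens its input to 28*28 features, so a dummy MNIST-shaped batch works
    out = loaded(torch.randn(1, 1, 28, 28))
print(out.shape)  # torch.Size([1, 10])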
2. Production Monitoring
class ProductionMonitor(pl.Callback):
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        if batch_idx % 100 == 0 and torch.cuda.is_available():
            memory = torch.cuda.max_memory_allocated() // 1024**2
            print(f"GPU memory used: {memory}MB")

# Hook into Prometheus monitoring
import prometheus_client
metrics = {"train_loss": prometheus_client.Gauge("train_loss", "Training loss")}
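The Gauge above does nothing until it is updated and exposed over HTTP. A minimal sketch wiring it into a callback (assumptions: port 8000 is free, and the loss is read back from trainer.callback_metrics; the class name is hypothetical):

class PrometheusCallback(pl.Callback):
    def __init__(self):
        # expose metrics at http://localhost:8000/metrics
        prometheus_client.start_http_server(8000)

    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        loss = trainer.callback_metrics.get("train_loss")
        if loss is not None:
            metrics["train_loss"].set(loss.item())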
VI. Debugging Tips
1. Fast Dev-Run Mode
# Smoke-test the data/model pipeline by running a single batch through each loop
trainer = pl.Trainer(fast_dev_run=True)
2. Profiling
# Generate a training performance report
trainer = pl.Trainer(
    profiler="simple",  # or "advanced" / "pytorch"
    benchmark=True      # enable cudnn autotuning for fixed input sizes
)
VII. FAQ
Q1: How do I resume an interrupted training run?
# In PL 2.x, pass the checkpoint path to fit() (resume_from_checkpoint was removed)
trainer = pl.Trainer()
trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
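Resuming via fit() restores optimizer and trainer state. If you only need the weights (e.g. for inference), load_from_checkpoint is the standard alternative:

# Rebuild the module; hyperparameters are restored thanks to save_hyperparameters()
model = LitMNIST.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()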
Q2: How do I handle custom datasets?
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir

    def setup(self, stage=None):
        # CustomDataset stands in for your own torch.utils.data.Dataset implementation
        self.train_dataset = CustomDataset(self.data_dir, train=True)
        self.val_dataset = CustomDataset(self.data_dir, train=False)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32)
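Once defined, the datamodule is passed straight to the Trainer; Lightning calls setup() and the dataloader hooks for you (the path below is a placeholder):

dm = CustomDataModule("./my_data")
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, datamodule=dm)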
Q3: How do I customize the training step?
def training_step(self, batch, batch_idx):
    x, y = batch
    # implement your custom logic here
    ...
    self.log_dict({"loss": loss, "acc": acc})
    return loss
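For reference, a filled-in version of that skeleton, reusing LitMNIST's layers and Accuracy metric from section III (one possible implementation, not the only custom logic you might write):

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = nn.functional.cross_entropy(logits, y)
    acc = self.metric(logits, y)  # the torchmetrics Accuracy defined in __init__
    self.log_dict({"loss": loss, "acc": acc}, prog_bar=True)
    return loss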