1. 环境准备与依赖导入
首先,代码需要导入了实现模型、数据处理和训练所需的各个包:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
- torch 与 torch.nn:PyTorch 的核心库,用于张量操作和定义神经网络层。
- DataLoader、random_split:用于批量加载数据和将数据集分割为训练、验证集。
- MNIST 与 ToTensor(来自 torchvision):用于下载 MNIST 数据集和将图像转换为张量格式。转换后,像素值会归一化到 [0, 1] 范围。
- pytorch_lightning(pl):PyTorch Lightning 提供了更高层次的封装,将训练循环、验证、测试等流程简化,从而让你更多地专注于模型的核心逻辑。
2. 定义 LightningModule 子类(LitModel)
PyTorch Lightning 的核心是继承 pl.LightningModule
,在这里我们通过子类化创建了一个简单的模型类 LitModel
。
2.1. __init__
方法
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.net = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
self.loss_fn = nn.CrossEntropyLoss()
- Flatten 层:将输入的 28×28 图像展平成一维向量(大小为 784),便于后续全连接层处理。
- Sequential 模块:
- 第一层:nn.Linear(28*28, 128) 将 784 维的输入映射到 128 维。
- ReLU 激活:引入非线性,使模型能拟合复杂函数。
- 第二层:nn.Linear(128, 10) 将 128 维特征映射到 10 个输出,对应 MNIST 中的 10 个类别。
- 损失函数:这里使用交叉熵损失(
nn.CrossEntropyLoss
),适用于多分类问题。
2.2. forward
方法
def forward(self, x):
x = self.flatten(x)
return self.net(x)
- 作用:定义前向传播逻辑,将输入数据依次经过 flatten 层和神经网络,返回最终的 logits。
- 使用场景:在训练、验证以及测试过程中,Lightning 会调用该方法得到模型的预测结果。
2.3. training_step
方法
def training_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("train_loss", loss) # 自动记录训练损失
return loss
- 步骤解析:
- 从
batch
中获取输入x
和标签y
。 - 调用
forward
方法计算预测结果pred
。 - 通过损失函数计算预测结果与真实标签之间的误差。
- 使用
self.log
自动记录训练损失,便于后续监控和日志分析。 - 返回 loss,Lightning 会根据此 loss 自动执行反向传播和优化更新。
- 从
2.4. validation_step
方法
def validation_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("val_loss", loss) # 自动记录验证损失
return loss
- 作用:与训练步骤类似,只不过不需要进行反向传播。记录验证损失可以帮助我们监控模型在未见数据上的表现。
2.4. test_step
方法
def validation_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("test_loss", loss) # 自动记录测试损失
return loss
- 作用:与训练步骤类似,只不过不需要进行反向传播。记录测试损失可以帮助我们监控模型在未见数据上的表现。
2.5. configure_optimizers
方法
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
- 作用:配置训练所需的优化器。这里选择了 Adam 优化器,并设置学习率为 0.001。
- 细节说明:返回的优化器对象会被 Lightning 自动调用,完成参数更新。
3. 定义 LightningDataModule 子类(MNISTDataModule)
使用 LightningDataModule 能够使数据预处理、划分和加载更加模块化,便于在多个训练阶段(训练、验证、测试)中复用同一数据处理流程。
3.1. __init__
方法
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
- 作用:初始化数据模块,并将批处理大小(batch size)作为参数保存,方便后续创建 DataLoader。
3.2. prepare_data
方法
def prepare_data(self):
# 下载数据集
MNIST(root="data", train=True, download=True)
MNIST(root="data", train=False, download=True)
- 作用:仅在一开始执行一次,用于下载数据集。
- 注意:在分布式训练时,Lightning 会确保此方法只运行一次,避免重复下载。
3.3. setup
方法
def setup(self, stage=None):
# 数据预处理和划分
transform = ToTensor()
mnist_full = MNIST(root="data", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
self.mnist_test = MNIST(root="data", train=False, transform=transform)
- 作用:加载并处理数据,划分数据集。
- 转换:使用
ToTensor()
将图像转换为张量。 - 数据划分:利用
random_split
将训练集划分为 55000 个训练样本和 5000 个验证样本。 - 同时加载测试数据集,用于最后模型的测试。
- 转换:使用
3.4. DataLoader 方法
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
- 作用:
- train_dataloader:返回训练数据的 DataLoader,设置
shuffle=True
确保数据在每个 epoch 中被打乱。 - val_dataloader:返回验证数据的 DataLoader,不需要打乱数据。
- test_dataloader:返回测试数据的 DataLoader,用于最终的模型评估。
- train_dataloader:返回训练数据的 DataLoader,设置
4. 训练和测试流程
在 if __name__ == "__main__":
块中,完成了模型与数据模块的实例化,并利用 Lightning 提供的 Trainer 完成训练和测试。
4.1. 实例化数据模块和模型
dm = MNISTDataModule(batch_size=32)
model = LitModel()
- 数据模块:初始化时设置 batch_size 为 32。
- 模型:创建自定义的 LightningModule 实例,包含网络结构、损失函数和优化器配置。
4.2. 创建 Trainer 并配置 GPU 加速
trainer = pl.Trainer(
max_epochs=3, # 训练 3 个 epoch
accelerator="gpu", # 指定使用 GPU
devices=[0],
)
- max_epochs:设定训练周期数为 3。
- accelerator 和 devices:
- 指定使用 GPU 加速,
devices=[0]
表示使用第 0 个 GPU(如果有多个 GPU,可根据需求调整)。 - 如果没有 GPU,该参数可以改为 “cpu” 或直接省略,Lightning 会自动适配。
- 指定使用 GPU 加速,
4.3. 训练和测试模型
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
- trainer.fit:开始训练过程,自动调用
training_step
和validation_step
,并利用 DataModule 提供的各个 DataLoader。 - trainer.test:在训练结束后,利用测试集评估模型性能。Lightning 会调用模型中的
test_step
(如果有定义)或者复用validation_step
来计算测试指标。
5. 总结
通过这份代码,你可以学到以下关键点:
- 模块化设计:利用 LightningModule 将模型、训练逻辑、验证逻辑和优化器配置集中在一个类中;而 LightningDataModule 则统一管理数据的下载、预处理、划分和加载。
- 简化训练流程:Lightning 的 Trainer 自动处理反向传播、参数更新、日志记录等细节,让你可以专注于模型设计和数据处理。
- GPU 加速支持:通过简单配置,即可利用 GPU 进行高效训练。
- 扩展性强:如果需要添加更多功能(如自定义回调、更多日志记录指标等),Lightning 提供了丰富的接口供你扩展。
这种清晰分离模型与数据逻辑的设计,不仅使代码结构更清晰,也方便在不同场景下复用和扩展。希望这个教程能帮助你更好地理解 PyTorch Lightning 的使用方法,并在项目中灵活应用这种高效的训练流程。
完整代码
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
# 1. 定义 LightningModule 子类
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.net = nn.Sequential(
nn.Linear(28*28, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
x = self.flatten(x)
return self.net(x)
def training_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("train_loss", loss) # 自动记录训练损失
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("val_loss", loss) # 自动记录验证损失
return loss
def test_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
self.log("test_loss", loss) # 自动记录测试损失
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# 2. 准备数据模块
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
def prepare_data(self):
# 下载数据集
MNIST(root="data", train=True, download=True)
MNIST(root="data", train=False, download=True)
def setup(self, stage=None):
# 数据预处理和划分
transform = ToTensor()
mnist_full = MNIST(root="data", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
self.mnist_test = MNIST(root="data", train=False, transform=transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
# 3. 训练模型
if __name__ == "__main__":
# 初始化数据模块和模型
dm = MNISTDataModule(batch_size=32)
model = LitModel()
# 创建训练器并训练
trainer = pl.Trainer(
max_epochs=3, # 训练3个epoch
accelerator="gpu", # 选择GPU
devices=[0],
)
trainer.fit(model, datamodule=dm)
# 测试模型
trainer.test(model, datamodule=dm)