PyTorch Lightning 的简单使用示例

发布于:2025-03-04 ⋅ 阅读:(13) ⋅ 点赞:(0)

1. 环境准备与依赖导入

首先,代码需要导入了实现模型、数据处理和训练所需的各个包:

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
  • torch 与 torch.nn:PyTorch 的核心库,用于张量操作和定义神经网络层。
  • DataLoader、random_split:用于批量加载数据和将数据集分割为训练、验证集。
  • MNIST 与 ToTensor(来自 torchvision):用于下载 MNIST 数据集和将图像转换为张量格式。转换后,像素值会归一化到 [0, 1] 范围。
  • pytorch_lightning(pl):PyTorch Lightning 提供了更高层次的封装,将训练循环、验证、测试等流程简化,从而让你更多地专注于模型的核心逻辑。

2. 定义 LightningModule 子类(LitModel)

PyTorch Lightning 的核心是继承 pl.LightningModule,在这里我们通过子类化创建了一个简单的模型类 LitModel

2.1. __init__ 方法

def __init__(self):
    super().__init__()
    self.flatten = nn.Flatten()
    self.net = nn.Sequential(
        nn.Linear(28*28, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
    self.loss_fn = nn.CrossEntropyLoss()
  • Flatten 层:将输入的 28×28 图像展平成一维向量(大小为 784),便于后续全连接层处理。
  • Sequential 模块
    • 第一层:nn.Linear(28*28, 128) 将 784 维的输入映射到 128 维。
    • ReLU 激活:引入非线性,使模型能拟合复杂函数。
    • 第二层:nn.Linear(128, 10) 将 128 维特征映射到 10 个输出,对应 MNIST 中的 10 个类别。
  • 损失函数:这里使用交叉熵损失(nn.CrossEntropyLoss),适用于多分类问题。

2.2. forward 方法

def forward(self, x):
    x = self.flatten(x)
    return self.net(x)
  • 作用:定义前向传播逻辑,将输入数据依次经过 flatten 层和神经网络,返回最终的 logits。
  • 使用场景:在训练、验证以及测试过程中,Lightning 会调用该方法得到模型的预测结果。

2.3. training_step 方法

def training_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    loss = self.loss_fn(pred, y)
    self.log("train_loss", loss)  # 自动记录训练损失
    return loss
  • 步骤解析
    • batch 中获取输入 x 和标签 y
    • 调用 forward 方法计算预测结果 pred
    • 通过损失函数计算预测结果与真实标签之间的误差。
    • 使用 self.log 自动记录训练损失,便于后续监控和日志分析。
    • 返回 loss,Lightning 会根据此 loss 自动执行反向传播和优化更新。

2.4. validation_step 方法

def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    loss = self.loss_fn(pred, y)
    self.log("val_loss", loss)    # 自动记录验证损失
    return loss
  • 作用:与训练步骤类似,只不过不需要进行反向传播。记录验证损失可以帮助我们监控模型在未见数据上的表现。

2.4. test_step 方法

def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    loss = self.loss_fn(pred, y)
    self.log("test_loss", loss)    # 自动记录测试损失
    return loss
  • 作用:与训练步骤类似,只不过不需要进行反向传播。记录测试损失可以帮助我们监控模型在未见数据上的表现。

2.5. configure_optimizers 方法

def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=1e-3)
  • 作用:配置训练所需的优化器。这里选择了 Adam 优化器,并设置学习率为 0.001。
  • 细节说明:返回的优化器对象会被 Lightning 自动调用,完成参数更新。

3. 定义 LightningDataModule 子类(MNISTDataModule)

使用 LightningDataModule 能够使数据预处理、划分和加载更加模块化,便于在多个训练阶段(训练、验证、测试)中复用同一数据处理流程。

3.1. __init__ 方法

def __init__(self, batch_size=32):
    super().__init__()
    self.batch_size = batch_size
  • 作用:初始化数据模块,并将批处理大小(batch size)作为参数保存,方便后续创建 DataLoader。

3.2. prepare_data 方法

def prepare_data(self):
    # 下载数据集
    MNIST(root="data", train=True, download=True)
    MNIST(root="data", train=False, download=True)
  • 作用:仅在一开始执行一次,用于下载数据集。
  • 注意:在分布式训练时,Lightning 会确保此方法只运行一次,避免重复下载。

3.3. setup 方法

def setup(self, stage=None):
    # 数据预处理和划分
    transform = ToTensor()
    mnist_full = MNIST(root="data", train=True, transform=transform)
    self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
    self.mnist_test = MNIST(root="data", train=False, transform=transform)
  • 作用:加载并处理数据,划分数据集。
    • 转换:使用 ToTensor() 将图像转换为张量。
    • 数据划分:利用 random_split 将训练集划分为 55000 个训练样本和 5000 个验证样本。
    • 同时加载测试数据集,用于最后模型的测试。

3.4. DataLoader 方法

def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=self.batch_size)

def test_dataloader(self):
    return DataLoader(self.mnist_test, batch_size=self.batch_size)
  • 作用
    • train_dataloader:返回训练数据的 DataLoader,设置 shuffle=True 确保数据在每个 epoch 中被打乱。
    • val_dataloader:返回验证数据的 DataLoader,不需要打乱数据。
    • test_dataloader:返回测试数据的 DataLoader,用于最终的模型评估。

4. 训练和测试流程

if __name__ == "__main__": 块中,完成了模型与数据模块的实例化,并利用 Lightning 提供的 Trainer 完成训练和测试。

4.1. 实例化数据模块和模型

dm = MNISTDataModule(batch_size=32)
model = LitModel()
  • 数据模块:初始化时设置 batch_size 为 32。
  • 模型:创建自定义的 LightningModule 实例,包含网络结构、损失函数和优化器配置。

4.2. 创建 Trainer 并配置 GPU 加速

trainer = pl.Trainer(
    max_epochs=3,          # 训练 3 个 epoch
    accelerator="gpu",     # 指定使用 GPU
    devices=[0],
)
  • max_epochs:设定训练周期数为 3。
  • accelerator 和 devices
    • 指定使用 GPU 加速,devices=[0] 表示使用第 0 个 GPU(如果有多个 GPU,可根据需求调整)。
    • 如果没有 GPU,该参数可以改为 “cpu” 或直接省略,Lightning 会自动适配。

4.3. 训练和测试模型

trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
  • trainer.fit:开始训练过程,自动调用 training_stepvalidation_step,并利用 DataModule 提供的各个 DataLoader。
  • trainer.test:在训练结束后,利用测试集评估模型性能。Lightning 会调用模型中的 test_step(如果有定义)或者复用 validation_step 来计算测试指标。

5. 总结

通过这份代码,你可以学到以下关键点:

  • 模块化设计:利用 LightningModule 将模型、训练逻辑、验证逻辑和优化器配置集中在一个类中;而 LightningDataModule 则统一管理数据的下载、预处理、划分和加载。
  • 简化训练流程:Lightning 的 Trainer 自动处理反向传播、参数更新、日志记录等细节,让你可以专注于模型设计和数据处理。
  • GPU 加速支持:通过简单配置,即可利用 GPU 进行高效训练。
  • 扩展性强:如果需要添加更多功能(如自定义回调、更多日志记录指标等),Lightning 提供了丰富的接口供你扩展。

这种清晰分离模型与数据逻辑的设计,不仅使代码结构更清晰,也方便在不同场景下复用和扩展。希望这个教程能帮助你更好地理解 PyTorch Lightning 的使用方法,并在项目中灵活应用这种高效的训练流程。


完整代码

import torch  
from torch import nn  
from torch.utils.data import DataLoader, random_split  
from torchvision.datasets import MNIST  
from torchvision.transforms import ToTensor  
import pytorch_lightning as pl  
  
# 1. 定义 LightningModule 子类  
class LitModel(pl.LightningModule):  
    def __init__(self):  
        super().__init__()  
        self.flatten = nn.Flatten()  
        self.net = nn.Sequential(  
            nn.Linear(28*28, 128),  
            nn.ReLU(),  
            nn.Linear(128, 10)  
        )  
        self.loss_fn = nn.CrossEntropyLoss()  
  
    def forward(self, x):  
        x = self.flatten(x)  
        return self.net(x)  
  
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        pred = self(x)  
        loss = self.loss_fn(pred, y)  
        self.log("train_loss", loss)  # 自动记录训练损失  
        return loss  
  
    def validation_step(self, batch, batch_idx):  
        x, y = batch  
        pred = self(x)  
        loss = self.loss_fn(pred, y)  
        self.log("val_loss", loss)    # 自动记录验证损失  
        return loss  
  
    def test_step(self, batch, batch_idx):  
        x, y = batch  
        pred = self(x)  
        loss = self.loss_fn(pred, y)  
        self.log("test_loss", loss)   # 自动记录测试损失  
        return loss  
  
    def configure_optimizers(self):  
        return torch.optim.Adam(self.parameters(), lr=1e-3)  
  
# 2. 准备数据模块  
class MNISTDataModule(pl.LightningDataModule):  
    def __init__(self, batch_size=32):  
        super().__init__()  
        self.batch_size = batch_size  
  
    def prepare_data(self):  
        # 下载数据集  
        MNIST(root="data", train=True, download=True)  
        MNIST(root="data", train=False, download=True)  
  
    def setup(self, stage=None):  
        # 数据预处理和划分  
        transform = ToTensor()  
        mnist_full = MNIST(root="data", train=True, transform=transform)  
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])  
        self.mnist_test = MNIST(root="data", train=False, transform=transform)  
  
    def train_dataloader(self):  
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)  
  
    def val_dataloader(self):  
        return DataLoader(self.mnist_val, batch_size=self.batch_size)  
  
    def test_dataloader(self):  
        return DataLoader(self.mnist_test, batch_size=self.batch_size)  
  
# 3. 训练模型  
if __name__ == "__main__":  
    # 初始化数据模块和模型  
    dm = MNISTDataModule(batch_size=32)  
    model = LitModel()  
  
    # 创建训练器并训练  
    trainer = pl.Trainer(  
        max_epochs=3,          # 训练3个epoch  
        accelerator="gpu",   # 选择GPU  
        devices=[0],  
    )  
    trainer.fit(model, datamodule=dm)  
  
    # 测试模型  
    trainer.test(model, datamodule=dm)