深度学习篇---断点重训&模型部署文件

发布于:2025-03-27 ⋅ 阅读:(31) ⋅ 点赞:(0)


前言

PaddlePaddle 框架中,断点重训(恢复训练)和 模型部署 需要保存不同类型的文件


一、断点重训(Checkpoint)文件

断点重训需要保存训练过程中的完整状态,包括 模型参数优化器状态学习率调度器状态 以及 训练进度信息(如当前 epoch、迭代步数等)
PaddlePaddle 动态图(推荐)和静态图模式下保存的文件略有不同,但核心文件后缀如下:

1. 动态图(DyGraph)模式

.pdparams 文件

保存模型的 可学习参数(如权重、偏置),通过 model.state_dict() 生成。

paddle.save(model.state_dict(), "model.pdparams")  # 仅保存模型参数

.pdopt 文件

保存 优化器的状态(如动量、梯度历史等),通过 optimizer.state_dict() 生成。

paddle.save(optimizer.state_dict(), "optimizer.pdopt")  # 保存优化器状态

.pdscaler 文件

如果使用了混合精度训练(paddle.amp.GradScaler),保存梯度缩放器的状态

paddle.save(scaler.state_dict(), "scaler.pdscaler")

.pdmeta 或 .pkl 文件

保存其他元信息(如当前 epoch、迭代步数、损失值等),需用户自定义保存。

checkpoint = {
    "epoch": 10,
    "step": 1000,
    "loss": 0.02,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict()
}
paddle.save(checkpoint, "checkpoint_epoch10.pdparams")  # 自定义后缀

2. 静态图(Static Graph)模式

.pdparams 和 .pdopt 文件

与动态图类似,分别保存模型参数和优化器状态

.ckpt 文件

检查点文件(如 model.ckpt-0),通常通过保存所有持久化变量(包括模型参数和优化器状态)。

3. 恢复训练

恢复训练时需 同时加载模型参数、优化器状态和元信息:

# 加载模型参数和优化器状态
model_state_dict = paddle.load("model.pdparams")
optimizer_state_dict = paddle.load("optimizer.pdopt")
model.set_state_dict(model_state_dict)
optimizer.set_state_dict(optimizer_state_dict)

# 加载元信息(如 epoch、step)
checkpoint = paddle.load("checkpoint_epoch10.pdparams")
current_epoch = checkpoint["epoch"]

二、模型部署文件

部署模型时需要将模型结构和参数固化生成推理专用的文件。PaddlePaddle 支持两种部署格式:

1. 动态图部署文件

使用 paddle.jit.save() 导出为 静态图推理模型(推荐)

.pdmodel

存储模型的 静态图结构(计算图定义),用于推理时加载模型结构

.pdiparams

存储模型的 参数值,与 .pdmodel 配合使用。

示例代码

model.eval()  # 切换为评估模式
input_spec = [paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype="float32")]
paddle.jit.save(model, "deploy_model", input_spec=input_spec)  # 生成 deploy_model.pdmodel 和 deploy_model.pdiparams

2. Paddle Inference 部署

部署时使用 paddle.inference 库加载 .pdmodel 和 .pdiparams

config = paddle.inference.Config("deploy_model.pdmodel", "deploy_model.pdiparams")
predictor = paddle.inference.create_predictor(config)

三、核心区别总结

用途 文件类型 动态图(DyGraph) 静态图(Static Graph)
断点重训 模型参数 .pdparams .pdparams 或 .ckpt
优化器状态 .pdopt .pdopt 或 .ckpt
元信息 自定义(如 .pdmeta 或 .pkl) 自定义
模型部署 模型结构 .pdmodel model
模型参数 .pdiparams params

四、关键注意事项

断点重训

  1. 必须同时保存 模型参数、优化器状态、训练进度元信息,缺一不可。
  2. 混合精度训练时需额外保存 .pdscaler 文件

模型部署

  1. 使用 paddle.jit.save() 导出前需切换模型为评估模式(model.eval())。
  2. 静态图部署需指定输入张量的 InputSpec,确保计算图固定。

文件管理

  1. 建议将**断点文件(.pdparams、.pdopt)部署文件(.pdmodel、.pdiparams)**分目录存储,避免混淆。

通过合理管理这些文件,可以高效实现训练中断恢复模型快速部署!