保存 CheckPoint 格式文件,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。model = network()
mindspore.save_checkpoint(model, "model.ckpt")
可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。
- save_checkpoint_steps表示每隔多少个step保存一次。
- keep_checkpoint_max表示最多保留CheckPoint文件的数量。
- prefix表示生成CheckPoint文件的前缀名。
- directory表示存放文件的目录。
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
model.train(epoch_num, dataset, callbacks=ckpoint_cb)
要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint
和load_param_into_net
方法加载参数。 model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
param_not_load
是未被加载的参数列表,为空时代表所有参数均加载成功。[]