《昇思25天学习打卡营第9天|保存与加载》

发布于:2024-07-06 ⋅ 阅读:(21) ⋅ 点赞:(0)


今日所学:

在上一章节主要学习了如何调整超参数以进行网络模型训练。在这一过程中,我们通常会想要保存一些中间或最终的结果,以便进行后续的模型微调和推理部署。在本章节,我进一步学习了如何保存和加载模型。


一、构建与准备

首先因为我们已经预装了mindspore,如果还没有安装的可以参考:《昇思25天学习打卡营第2天|快速入门》
引用库和初步构建代码如下:

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

二、保存和加载模型权重

首先学习了保存和加载模型权重,其中保存模型使用了Mindspore框架的save_checkpoint接口,传入网络和指定的保存路径,代码如下:

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

然后学习了加载模型权重,先创建相同模型的实例,然后使用load_checkpoint和load_param_into_net方法来加载参数,代码如下:

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

三、保存和加载MindIR

MindSpore除了提供了Checkpoint功能外,还提供了一种统一的中间表示(Intermediate Representation,IR)用于云端(训练)和端侧(推理)。这意味着我们可以使用export接口直接将模型保存为MindIR格式。代码如下:

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

nn.GraphCell是专为图模式设计的。这意味着在使用MindSpore框架时,我们可以将已经保存的MindIR模型通过load接口轻松加载,并通过传入nn.GraphCell进行推理。但值得注意的是,为了进行这个过程,我们需要先定义输入Tensor以获取输入shape,因为MindIR保存了Checkpoint和模型结构,代码如下:

mindspore.set_context(mode=mindspore.GRAPH_MODE)
​
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

在这里插入图片描述

总结

在今天的学习中,我深入了解了如何在模型训练过程中保存和加载模型。我学习了如何利用MindSpore的save_checkpoint接口将模型保存下来,然后通过load_checkpoint和load_param_into_net方法将参数加载到模型中。此外,我还了解了MindSpore提供的统一的中间表示(Intermediate Representation,IR)功能,学习了如何将模型直接保存为MindIR格式,并在需要时加载这些模型进行推理。我还学习了如何使用nn.GraphCell,这是一种专为图模式设计的接口,可以便捷地加载保存的MindIR模型,并进行推理。总的来说,我了解了如何有效地保存训练好的模型,并在需要时加载它们进行后续的微调和推理部署,这对于深度学习的实践非常重要。


网站公告

今日签到

点亮在社区的每一天
去签到