书接上文,在我们进行了训练生成了模型的情况下,如何使用生成的模型那?针对线性模型保存w,b就可以了,但是对相对复杂的模型Pytorch提供了模型的保存和加载机制。
模型保存
上一篇文章代码最后增加下面代码可以将模型保存
# 保存模型
torch.save(model.state_dict(), 'liner_model.pth')
model.state_dict() 函数是PyTorch中用于获取模型参数的函数,,它返回一个简单的Python字典对象,其中每一层与它的对应参数建立映射关系。这个字典包含了模型中所有可以训练的层的参数,如卷积层、线性层等的权重(weights)和偏置(bias)等。值得注意的是,只有那些参数可以训练的层才会被保存到模型的state_dict中。
模型本身就是一个压缩包。
下面的代码可以打印出模型中的数据
print('Model.state_dict:')
for param_tensor in model.state_dict():
# 打印 key value字典
print(param_tensor, '\t', model.state_dict()[param_tensor].size())
模型加载
加载模型进行预测
import os
os.environ['kmp_duplicate_lib_ok'] = "TRUE"
import torch
import torch.nn as nn
import numpy as np
# 2. 设置线性模型,nn是重要的模块
input_size = 1
output_size = 1
model = nn.Linear(input_size, output_size)
model.load_state_dict(torch.load('liner_model.pth'))
# 5. 通过模型进行预测
# 模型设置为评估状态
model.eval()
x_test = np.array([[6.5]], dtype=np.float32)
# 禁用梯度计算
with torch.no_grad():
predictions = model(torch.from_numpy(x_test)) # 使用模型进行预测
print(predictions.data.numpy())
可以基于已经训练的模型结合新的数据进行训练
模型格式
PyTorch中,.pt、.pth和.pth.tar都是用于保存训练好的模型的文件格式。
.pt文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。
.pth文件是PyTorch旧版本中使用的模型文件格式,它只保存了模型参数,没有保存模型结构和其他相关信息。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。
.pth.tar文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。
.safetensors格式
随着huggingface transformers的流行,越来越多的模型采用了.safetensors的文件存储格式。它有什么特点?与传统的torch.save保存的文件有何不同?
这是huggingface设计的一种新格式,大致就是以更加紧凑、跨框架的方式存储Dict[str, Tensor],主要存储的内容为tensor的名字(字符串)及内容(权重)。
本质上就是一个JSON文件加上若干binary形式的buffer。对于tensor而言,它只存储了数据类型、形状、对应的数据区域起点和终点。
如何操作.safetensors格式
# 安装safetensors组件
## pip进行安装
pip install safetensors
## conda进行安装
conda install -c huggingface safetensors
使用一份简单的Python代码就可以将.safetensors文件转成内存中的Dict[str, Tensor]类型的对象:
from safetensors import safe_open
tensors = {}
with safe_open("mix4_v10.safetensors", framework="pt", device=0) as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)