深入学习Pytorch:第二章-模型使用

发布于:2025-04-06 ⋅ 阅读:(35) ⋅ 点赞:(0)

书接上文,在我们进行了训练生成了模型的情况下,如何使用生成的模型那?针对线性模型保存w,b就可以了,但是对相对复杂的模型Pytorch提供了模型的保存和加载机制。

模型保存

上一篇文章代码最后增加下面代码可以将模型保存

	# 保存模型
	torch.save(model.state_dict(), 'liner_model.pth')

model.state_dict() 函数是PyTorch中用于获取模型参数的函数,,它返回一个简单的Python字典对象,其中每一层与它的对应参数建立映射关系。这个字典包含了模型中所有可以训练的层的参数,如卷积层、线性层等的权重(weights)和偏置(bias)等。值得注意的是,只有那些参数可以训练的层才会被保存到模型的state_dict中。

模型本身就是一个压缩包。

下面的代码可以打印出模型中的数据

	print('Model.state_dict:')
	for param_tensor in model.state_dict():
	  # 打印 key value字典
	  print(param_tensor, '\t', model.state_dict()[param_tensor].size())

模型加载

加载模型进行预测

	import os
	os.environ['kmp_duplicate_lib_ok'] = "TRUE"
	import torch
	import torch.nn as nn
	import numpy as np

	# 2. 设置线性模型,nn是重要的模块
	input_size = 1
	output_size = 1
	model = nn.Linear(input_size, output_size)
	model.load_state_dict(torch.load('liner_model.pth'))

	# 5. 通过模型进行预测
	# 模型设置为评估状态
	model.eval()
	x_test = np.array([[6.5]], dtype=np.float32)
	# 禁用梯度计算
	with torch.no_grad():
	  predictions = model(torch.from_numpy(x_test))  # 使用模型进行预测
	  print(predictions.data.numpy())

可以基于已经训练的模型结合新的数据进行训练

模型格式

PyTorch中,.pt、.pth和.pth.tar都是用于保存训练好的模型的文件格式。

.pt文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。

.pth文件是PyTorch旧版本中使用的模型文件格式,它只保存了模型参数,没有保存模型结构和其他相关信息。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。

.pth.tar文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。

.safetensors格式

随着huggingface transformers的流行,越来越多的模型采用了.safetensors的文件存储格式。它有什么特点?与传统的torch.save保存的文件有何不同?

这是huggingface设计的一种新格式,大致就是以更加紧凑、跨框架的方式存储Dict[str, Tensor],主要存储的内容为tensor的名字(字符串)及内容(权重)。

safetensors格式

本质上就是一个JSON文件加上若干binary形式的buffer。对于tensor而言,它只存储了数据类型、形状、对应的数据区域起点和终点。

如何操作.safetensors格式

# 安装safetensors组件
## pip进行安装
pip install safetensors
## conda进行安装
conda install -c huggingface safetensors

使用一份简单的Python代码就可以将.safetensors文件转成内存中的Dict[str, Tensor]类型的对象:

from safetensors import safe_open

tensors = {}
with safe_open("mix4_v10.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

网站公告

今日签到

点亮在社区的每一天
去签到