本文演示了PyTorch中张量(Tensor)和模型参数的保存与加载方法,并提供完整的代码示例及输出结果,帮助读者快速掌握数据持久化的核心操作。
1. 保存和加载单个张量
通过torch.save
和torch.load
可以直接保存和读取张量。
import torch
# 创建并保存张量
x = torch.arange(4)
torch.save(x, 'x-file')
# 加载张量
x2 = torch.load('x-file')
print(x2) # 输出:tensor([0, 1, 2, 3])
输出结果:
tensor([0, 1, 2, 3])
2. 保存和加载张量列表
可以将多个张量存储为列表,并一次性加载。
# 创建两个张量并保存为列表
y = torch.zeros(4)
torch.save([x, y], 'x-files')
# 加载列表
x2, y2 = torch.load('x-files')
print((x2, y2))
输出结果:
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
3. 保存和加载字典
通过字典可以更灵活地管理多个张量。
# 创建字典并保存
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)
输出结果:
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
4. 定义神经网络模型
以下是一个简单的全连接神经网络示例:
from torch import nn
from torch.nn import functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256) # 隐藏层
self.output = nn.Linear(256, 10) # 输出层
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
# 实例化模型并进行前向传播
net = Model()
x = torch.rand(size=(2, 20))
y = net(x)
print(y)
输出结果(因随机初始化可能不同):
tensor([[-0.0711, 0.1161, -0.1113, ..., 0.0787],
[-0.0151, 0.0275, -0.1652, ..., 0.0109]], grad_fn=<AddmmBackward0>)
5. 保存模型参数
使用state_dict
保存模型参数:
torch.save(net.state_dict(), 'net.params')
6. 加载模型参数并验证
加载参数到新模型实例,并验证一致性:
# 创建新模型并加载参数
clone = Model()
clone.load_state_dict(torch.load('net.params'))
clone.eval() # 设置为评估模式(关闭Dropout/BatchNorm等)
# 比较输出结果
Y_clone = clone(x)
print(Y_clone == y)
输出结果:
tensor([[True, True, ..., True],
[True, True, ..., True]])
总结
张量读写:直接使用
torch.save
和torch.load
,支持列表和字典。模型参数保存:通过
state_dict
保存模型状态,加载时需重新实例化模型。验证一致性:加载参数后,输出与原模型一致表明操作成功。
通过本文的代码示例,读者可以快速掌握PyTorch中数据和模型参数的持久化方法,为模型训练和部署提供便利。