模型训练验证
损失函数和优化器
loss_function = nn.CrossEntropyLoss() # 损失函数
optimizer = Adam(model.parameters()) # 优化器,优化参数
模型优化
获得模型所有的可训练参数(比如每一层的权重、偏置),设置优化器类型,自动调整学习步长(自适应学习率),后续训练更新参数。
# 雇佣Adam教练,让他管理模型参数
optimizer = Adam(model.parameters(), lr=0.001) # lr是初始学习率
# 1. optimizer.zero_grad() # 清空上一轮的成绩单
# 2. loss.backward() # 计算每个参数要改进的方向(梯度)
# 3. optimizer.step() # 参数调整
训练函数
def train():
loss = 0
accuracy = 0
model.train()
for x, y in train_loader: # 获得每个batch数据
x, y = x.to(device), y.to(device)
output = model(x) # 得到预测label
optimizer.zero_grad() # 梯度清零
batch_loss = loss_function(output, y) # 计算batch误差
batch_loss.backward() # 计算误差梯度
optimizer.step() # 调整模型参数
loss += batch_loss.item()
accuracy += get_batch_accuracy(output, y, train_N)
print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
验证函数
def validate():
loss = 0
accuracy = 0
model.eval() # 评估模式,关闭随机性等增加稳定性
with torch.no_grad(): # 禁用梯度,提高效率
for x, y in valid_loader:
x, y = x.to(device), y.to(device)
output = model(x)
# 不用进行梯度计算、参数调整
loss += loss_function(output, y).item()
accuracy += get_batch_accuracy(output, y, valid_N)
print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
模型保存
.pth 文件是PyTorch模型的“存档文件”,保存了所有必要信息。加载后,模型即可直接运行,无需重新训练!
# 保存整个模型(结构 + 参数)
torch.save(model, 'model.pth')
.pth 文件可以用https://netron.app/查看