第三十一篇 AI的“能力考”:模型评估、保存与加载的艺术【总结前面3】

发布于:2025-08-01 ⋅ 阅读:(15) ⋅ 点赞:(0)

前言:从“学习”到“应用”的最后一步

在上一章,我们亲手搭建并驱动了一个AI的“思考引擎”,它正在努力学习识别手写数字。但模型学得怎么样?它的“识字”能力达到了什么水平?更重要的是,辛辛苦苦训练出的这个AI“大脑”,如何才能保存下来,以便未来可以投入实际应用,或者分享给他人呢?
使用模型注意事项

这些问题,正是我们今天将要解决的。本章将带领你完成AI学习的最后闭环:模型评估(检验成果)、模型保存(持久化智慧)和模型加载与推理(让智慧重焕光彩,投入实战)。

第一章:AI的“成绩单”——模型评估

在训练过程中,我们关注损失的下降。但损失下降并不代表模型真的“学得好”,我们还需要用独立的测试集来评估它的泛化能力
AI成绩单

1.1 评估模式:model.eval()与torch.no_grad()的智慧

在评估模型时,我们必须切换到评估模式。这是为了:

model.eval():告诉模型现在是评估阶段,关闭nn.Dropout(防止随机丢弃神经元)和nn.BatchNorm(停止更新均值和方差,使用训练时的统计量)等层。这确保了模型在评估时的行为是确定的、可重复的。

torch.no_grad():创建一个上下文管理器,告诉PyTorch在这个代码块内部,不要计算梯度。
为什么? 评估阶段我们不需要进行反向传播来更新参数,计算梯度是多余的,这会浪费计算资源和内存。

好处:加速评估过程,减少内存占用。

1.2 分类任务的“及格线”:准确率(Accuracy)

对于分类任务,最直观、最常用的评估指标就是准确率(Accuracy)。

准确率 = (正确预测的样本数 / 总样本数) * 100%

1.3 亲手评估你的AI模型准确率

目标:对上一章训练好的MNIST分类器进行准确率评估,并可视化部分预测结果。

前置:你需要确保上一章simple_mnist_classifier_full.py已经运行过,并且其训练好的模型权重文件simple_mnist_mlp.pth已经保存在mnist_results/目录下

代码展示

# case_10_3_model_evaluation.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# --- 0. 定义模型结构和加载权重 (与训练时保持一致) ---
# 这部分代码必须和训练模型的 SimpleMLPC Classifier 类定义完全相同
class SimpleMLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleMLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# --- 加载训练时定义的超参数 ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = 'mnist_results/simple_mnist_mlp.pth' # 上一章保存的模型权重路径


def evaluate_model(model, test_loader, device):
    model.eval() # 设置模型为评估模式
    correct = 0
    total = 0
    # 用于记录错误预测,以便可视化
    wrong_predictions = [] 
    
    with torch.no_grad(): # 在评估时,禁用梯度计算
        for data, target in test_loader:
            data, target = data.to(device), target.to(device) # 将数据移动到设备
            outputs = model(data) # 前向传播,获取模型预测的Logits
            
            # 从Logits中找到预测概率最高的类别
            # torch.max(outputs.data, 1) 返回每一行最大值及其索引。1表示在维度1上求最大值
            _, predicted = torch.max(outputs.data, 1) # _ 是最大值,我们只关心索引(predicted类别)
            
            total += target.size(0) # 累加当前批次的样本总数
            # 比较预测类别和真实类别,并累加正确预测的数量
            correct += (predicted == target).sum().item() 
            
            # 记录错误的预测 (用于可视化)
            incorrect_mask = (predicted != target)
            for i in range(len(incorrect_mask)):
                if incorrect_mask[i]:
                    wrong_predictions.append((data[i].cpu(), target[i].cpu(), predicted[i].cpu()))

    accuracy = 100 * correct / total
    print(f'在 {total} 张测试图片上的准确率: {accuracy:.2f}%')
    return accuracy, wrong_predictions

def visualize_predictions(test_loader, model, device, num_display=10):
    model.eval()
    
    # 随机选择一个Batch用于可视化
    data_iter = iter(test_loader)
    data, labels = next(data_iter)
    data, labels = data.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
    
    plt.figure(figsize=(12, 6))
    plt.suptitle("模型预测结果示例 (绿色为正确,红色为错误)", fontsize=16)
    for i in range(num_display):
        plt.subplot(2, 5, i + 1)
        # .squeeze() 移除单维度通道 (例如 [1, 28, 28] -> [28, 28])
        plt.imshow(data[i].cpu().squeeze(), cmap='gray')
        
        is_correct = (predicted[i] == labels[i]).item()
        color = 'green' if is_correct else 'red'
        
        plt.title(f"Pred: {predicted[i].item()}\nTrue: {labels[i].item()}", color=color)
        plt.axis('off')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig('mnist_results/sample_predictions.png')
    plt.show()

# --- 主执行流程 ---
if __name__ == '__main__':
    # 确保mnist_results目录存在
    os.makedirs('mnist_results', exist_ok=True)

    # 加载测试数据集
    transform = transforms.ToTensor()
    test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
    test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False) # 批次大小可以设小一些

    # 实例化模型
    model_for_eval = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)
    
    # 加载预训练权重 (这是关键!确保你有这个文件)
    try:
        model_for_eval.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        print(f"✅ 模型权重已从 '{MODEL_PATH}' 成功加载。")
    except FileNotFoundError:
        print(f"❌ 错误:未找到模型权重文件 '{MODEL_PATH}'。请先运行上一章代码训练并保存模型!")
        exit() # 如果模型文件不存在,无法继续

    print("\n--- 开始评估模型 ---")
    accuracy, _ = evaluate_model(model_for_eval, test_loader, DEVICE)
    print(f"\n模型最终准确率为: {accuracy:.2f}%")

    print("\n--- 可视化部分预测结果 ---")
    visualize_predictions(test_loader, model_for_eval, DEVICE, num_display=10)
    print(f"部分预测结果图已保存到: mnist_results/sample_predictions.png")

代码解读与见证奇迹】

运行这段代码,你将看到模型在MNIST测试集上的准确率,通常会达到90%以上。
可视化部分,你会看到一系列原始的数字图片,上面标注着模型的预测结果和真实标签。正确预测的标题是绿色的,错误预测是红色的,让你直观感受到AI的“识字”能力!

这证明了我们亲手搭建的MLP模型,经过简单的训练,已经具备了对未见过数据进行识别的泛化能力。

第二章:AI“智慧”的持久化——模型保存

学习如何将辛苦训练好的AI模型“存档”到硬盘,并深入对比torch.save(基于Pickle)和safetensors两种主流保存方式的安全与效率差异。
模型持久化

2.1 PyTorch的“传统艺能”:torch.save与state_dict

这是PyTorch原生的保存方法。最推荐的做法是只保存模型的state_dict
model.state_dict():它返回一个Python字典,包含了模型所有可学习参数(权重和偏置)的副本。这就像模型的“灵魂”或“基因组”。

优点:文件小,只包含数据,不包含代码逻辑,加载时需明确模型结构,相对安全。

缺点:底层使用Python的pickle协议,这会带来潜在的安全风险

2.2 新王登基:.safetensors的安全与高效

为了解决pickle的安全问题,Hugging Face社区推出了**.safetensors**格式。

核心优势:

  1. 绝对安全:只存储Tensor的原始二进制数据和JSON格式的元数据(形状、类型),不包含任何可

2.执行代码。safetensors.torch.load_file()不会执行任意代码。

  1. 加载极快:特别是在部分加载(分片)或跨语言加载时,其效率远超pickle。

  2. 跨平台/框架:设计上就考虑了不同AI框架(PyTorch, TensorFlow, JAX)的兼容性。
    推荐:在分享模型或加载未知来源的模型时,优先使用.safetensors格式。

2.3 将你的AI模型安全“存档”

目标:使用torch.save和safetensors两种方法,保存我们训练好的MLP分类器的权重。

前置:假设你已经运行了上一章的代码,并且训练好的模型实例trained_model可用。

# case_10_3_model_saving.py

import torch
import torch.nn as nn
import os
# 需要安装safetensors库: pip install safetensors
from safetensors.torch import save_file, load_file

# --- 0. 准备工作:定义模型结构 (同评估时一致) ---
class SimpleMLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleMLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# --- 导入训练时定义的超参数 (或直接定义) ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = 'mnist_results' # 结果目录
os.makedirs(MODEL_DIR, exist_ok=True)

# 假设我们有一个已经训练好的模型实例
# 在实际运行中,你可以从上一章的 main_training_loop 返回 trained_model
# 这里我们为了独立运行,先实例化并模拟加载权重
model_to_save = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)
# 模拟加载一个随机权重,表示它是“训练好的”
model_to_save.load_state_dict(torch.load('mnist_results/simple_mnist_mlp.pth', map_location=DEVICE))
model_to_save.eval()
print("模型实例已准备好进行保存。")

# --- 1. 策略一:使用 torch.save 只保存 State Dict (.pth/.pt) ---
print("\n--- 策略一:使用 torch.save 保存模型权重 ---")
pytorch_save_path = os.path.join(MODEL_DIR, 'simple_mnist_mlp_weights_torch.pth')
# model.state_dict() 获取模型的权重字典
torch.save(model_to_save.state_dict(), pytorch_save_path)
print(f"✅ 模型权重已用 torch.save 保存到: {pytorch_save_path}")

# --- 2. 策略二:使用 safetensors 保存 State Dict (.safetensors) ---
print("\n--- 策略二:使用 safetensors 安全保存模型权重 ---")
safetensors_save_path = os.path.join(MODEL_DIR, 'simple_mnist_mlp.safetensors')
# safetensors.torch.save_file() 函数更推荐
save_file(model_to_save.state_dict(), safetensors_save_path)
print(f"✅ 模型权重已用 safetensors 安全保存到: {safetensors_save_path}")

print("\n模型已成功保存到两种文件格式中!")

代码解读与安全警示】
运行这段代码,你会看到两个文件被创建:.pth和.safetensors。它们都包含了模型的权重信息,但背后的安全性却天差地别。
⚠️ 再次提醒: torch.save()(使用Pickle协议)在加载不信任来源的文件时存在任意代码执行的风险。safetensors则从设计上规避了这一风险,是更安全的选择。

第三章:AI的“重生”与“实战”——模型加载与推理

学习如何将保存的模型加载回内存,并用它来对新的、单张图片进行推理预测。
模型重生

3.1 加载模型:让AI的“智慧”重新焕发生机

模型加载,就是将硬盘上的模型文件重新读取到内存中,再次实例化为PyTorch模型对象的过程。

加载state_dict: 始终是首选。你需要先定义模型的类结构(即SimpleMLPClassifier),然后使用

model.load_state_dict(torch.load(…))来加载权重。

map_location: 在torch.load()时,map_location参数非常有用。它可以指定将模型加载到CPU或GPU,尤其当你在GPU上训练但在CPU上推理时。

3.2 加载模型并进行单张图片推理

目标:加载我们之前保存的.safetensors模型,并用它来预测一张新的手写数字图片。
前置:确保simple_mnist_mlp.safetensors文件已存在。

# case_10_3_model_loading_inference.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
# 需要安装safetensors库: pip install safetensors
from safetensors.torch import save_file, load_file # 导入load_file函数

# --- 0. 准备工作:定义模型结构 (必须和训练/保存时完全一致) ---
class SimpleMLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleMLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# --- 定义模型超参数和文件路径 ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH_SAFETENSORS = 'mnist_results/simple_mnist_mlp.safetensors' # 使用safetensors保存的文件

# --- 1. 加载模型 ---
print("--- 1. 加载模型 ---")
# 实例化一个新的模型“躯体”
loaded_model = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)

# 加载保存的权重文件 (.safetensors)
try:
    # load_file 返回一个state_dict
    loaded_state_dict = load_file(MODEL_PATH_SAFETENSORS)
    loaded_model.load_state_dict(loaded_state_dict)
    print(f"✅ 模型权重已从 '{MODEL_PATH_SAFETENSORS}' 成功加载。")
except FileNotFoundError:
    print(f"❌ 错误:未找到模型权重文件 '{MODEL_PATH_SAFETENSORS}'。请先运行上一章代码训练并保存模型!")
    exit()
loaded_model.eval() # 设置为评估模式,非常重要!
print("模型已准备好进行推理。")

# --- 2. 准备一张新的测试图片进行推理 ---
print("\n--- 2. 准备新的测试图片进行推理 ---")
# 从MNIST测试集中随机取一张图片
transform = transforms.ToTensor()
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
# 随机选择一个索引
random_idx = np.random.randint(0, len(test_dataset))
sample_image, true_label = test_dataset[random_idx]

print(f"随机选择的图片索引: {random_idx}")
print(f"真实标签: {true_label}")

# 将图片添加到Batch维度,并移动到设备
# 从 [1, 28, 28] -> [1, 1, 28, 28] (增加Batch维度)
input_for_inference = sample_image.unsqueeze(0).to(DEVICE) 
print(f"用于推理的图片形状: {input_for_inference.shape}")

# --- 3. 进行推理预测 ---
print("\n--- 3. 模型进行推理预测 ---")
with torch.no_grad(): # 推理时禁用梯度计算
    output_logits = loaded_model(input_for_inference)
    # 获取概率分布 (可选,CrossEntropyLoss内部已做)
    probabilities = F.softmax(output_logits, dim=1)
    # 获取预测类别
    _, predicted_class = torch.max(output_logits, 1)

predicted_class_item = predicted_class.item() # 从Tensor中提取预测类别数值
predicted_prob = probabilities[0, predicted_class_item].item() # 获取预测类别的概率

print(f"模型预测类别: {predicted_class_item}")
print(f"预测概率: {predicted_prob*100:.2f}%")

# --- 4. 可视化结果 ---
plt.figure(figsize=(4, 4))
plt.imshow(sample_image.squeeze().numpy(), cmap='gray') # 显示图片
plt.title(f"预测: {predicted_class_item} (真实: {true_label})", 
          color='green' if predicted_class_item == true_label else 'red', 
          fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, f'inference_result_{random_idx}.png'))
plt.show()

print("\n🎉 模型推理完成,结果已可视化!")`在这里插入代码片`

代码解读与见证奇迹】
运行这段代码,你会看到:
模型权重从.safetensors文件被成功加载。
随机从测试集中选择一张图片,并显示出来。
模型对这张图片给出了准确的预测,并且显示了对应的概率。
这证明了你的AI模型,在经过训练、保存、加载之后,能够重新“复活”,并投入到实际的推理应用中。你已经掌握了AI模型从“实验室”走向“真实世界”的最后一步

总结与展望:你已拥有AI的“入门级驾驶证”

总结与展望:你已拥有AI的“入门级驾驶证”
恭喜你!今天你已经通过亲手编写和运行代码,完成了AI学习流程的最后闭环。
✨ 本章惊喜概括 ✨

你掌握了什么? 对应的核心操作/概念
准确评估模型 ✅ model.eval(), torch.no_grad(), 准确率计算
AI“智慧”的持久化 ✅ torch.save和safetensors保存state_dict
AI的“重生” ✅ 加载模型权重,map_location
AI的“实战应用” ✅ 对单张图片进行推理预测
你现在不仅仅是“听说过”AI模型,你已经能够从零开始搭建、训练、评估、保存、加载、并使用一个完整的AI模型了! 你已经拥有了AI世界的“入门级驾驶证”,可以自信地开始探索更复杂的AI应用和架构。

网站公告

今日签到

点亮在社区的每一天
去签到