Deep Learning — RNN — An Image Recognition Application

Published: 2025-07-11
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
from PIL import Image
import random

# Set random seeds so results are reproducible
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
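# (Optional addition, not in the original) for reproducible GPU runs one could also seed CUDA:
# torch.cuda.manual_seed_all(42)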

class ConvolutionalNeuralNetwork(nn.Module):
    """
    卷积神经网络类
    用于图像分类任务,包含卷积层、池化层、全连接层等
    """
    
    def __init__(self, input_channels=3, num_classes=10, dropout_rate=0.5):
        """
        初始化卷积神经网络
        
        参数:
        input_channels (int): 输入图像通道数(RGB为3,灰度为1)
        num_classes (int): 分类类别数量
        dropout_rate (float): Dropout比例
        """
        super(ConvolutionalNeuralNetwork, self).__init__()
        
        # 参数验证
        if not isinstance(input_channels, int) or input_channels <= 0:
            raise ValueError("input_channels必须为正整数")
        if not isinstance(num_classes, int) or num_classes <= 1:
            raise ValueError("num_classes必须为大于1的整数")
        if not (0 <= dropout_rate < 1):
            raise ValueError("dropout_rate必须在[0, 1)区间")
        
        # Convolutional block 1
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)  # first conv layer
        self.bn1 = nn.BatchNorm2d(32)  # batch normalization

        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)  # second conv layer
        self.bn2 = nn.BatchNorm2d(32)  # batch normalization
        self.pool1 = nn.MaxPool2d(2, 2)  # max pooling
        
        # Convolutional block 2
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # third conv layer
        self.bn3 = nn.BatchNorm2d(64)  # batch normalization

        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)  # fourth conv layer
        self.bn4 = nn.BatchNorm2d(64)  # batch normalization
        self.pool2 = nn.MaxPool2d(2, 2)  # max pooling
        
        # Convolutional block 3
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # fifth conv layer
        self.bn5 = nn.BatchNorm2d(128)  # batch normalization

        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)  # sixth conv layer
        self.bn6 = nn.BatchNorm2d(128)  # batch normalization
        self.pool3 = nn.MaxPool2d(2, 2)  # max pooling
        
        # Fully connected layers
        self.dropout = nn.Dropout(dropout_rate)  # dropout layer
        self.fc1 = nn.Linear(128 * 4 * 4, 512)  # first FC layer (assumes 32x32 input images)

        self.fc2 = nn.Linear(512, 256)  # second FC layer
        
        self.fc3 = nn.Linear(256, num_classes)  # output layer
        
        # Store configuration
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        
    def forward(self, x):
        """
        前向传播过程
        
        参数:
        x (torch.Tensor): 输入图像张量 [batch_size, channels, height, width]
        
        返回:
        torch.Tensor: 各类别的logits
        """
        # 输入验证
        if not torch.is_tensor(x):
            raise TypeError("输入x必须为torch.Tensor")
        if len(x.shape) != 4:
            raise ValueError(f"输入张量应为4维 [batch_size, channels, height, width],实际为{len(x.shape)}维")
        if x.shape[1] != self.input_channels:
            raise ValueError(f"输入通道数应为{self.input_channels},实际为{x.shape[1]}")
        
        # 卷积层块1
        x = F.relu(self.bn1(self.conv1(x)))  # 卷积 -> 批归一化 -> ReLU
        x = F.relu(self.bn2(self.conv2(x)))  # 卷积 -> 批归一化 -> ReLU
        x = self.pool1(x)  # 最大池化
        
        # 卷积层块2
        x = F.relu(self.bn3(self.conv3(x)))  # 卷积 -> 批归一化 -> ReLU
        x = F.relu(self.bn4(self.conv4(x)))  # 卷积 -> 批归一化 -> ReLU
        x = self.pool2(x)  # 最大池化
        
        # 卷积层块3
        x = F.relu(self.bn5(self.conv5(x)))  # 卷积 -> 批归一化 -> ReLU
        x = F.relu(self.bn6(self.conv6(x)))  # 卷积 -> 批归一化 -> ReLU
        x = self.pool3(x)  # 最大池化
        
        # 展平特征图
        x = x.view(x.size(0), -1)  # 将特征图展平为一维向量
        
        # 全连接层
        x = self.dropout(x)  # Dropout
        x = F.relu(self.fc1(x))  # 全连接 -> ReLU
        
        x = self.dropout(x)  # Dropout
        x = F.relu(self.fc2(x))  # 全连接 -> ReLU
        x = self.fc3(x)  # 输出层
        
        return x
    
    def predict_proba(self, x):
        """
        预测各类别的概率
        
        参数:
        x (torch.Tensor): 输入图像张量
        
        返回:
        torch.Tensor: 各类别的概率分布
        """
        self.eval()  # 设置为评估模式
        with torch.no_grad():  # 关闭梯度计算
            logits = self.forward(x)  # 前向传播
            probabilities = F.softmax(logits, dim=1)  # 计算概率
        return probabilities
    
    def predict(self, x):
        """
        预测类别
        
        参数:
        x (torch.Tensor): 输入图像张量
        
        返回:
        torch.Tensor: 预测的类别标签
        """
        probabilities = self.predict_proba(x)  # 获取概率
        return torch.argmax(probabilities, dim=1)  # 返回最大概率的类别
    
    def save(self, path):
        """
        保存模型参数到指定路径
        
        参数:
        path (str): 文件保存路径
        """
        torch.save(self.state_dict(), path)
        print(f"模型已保存到: {path}")
    
    @classmethod
    def load(cls, path, input_channels=3, num_classes=10, dropout_rate=0.5):
        """
        从文件加载模型参数
        
        参数:
        path (str): 文件路径
        input_channels (int): 输入通道数
        num_classes (int): 类别数
        dropout_rate (float): Dropout比例
        
        返回:
        ConvolutionalNeuralNetwork: 加载的模型实例
        """
        model = cls(input_channels, num_classes, dropout_rate)
        model.load_state_dict(torch.load(path, map_location='cpu'))
        model.eval()
        print(f"模型已从 {path} 加载")
        return model
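
# A quick sanity check one could run (not part of the original post): fc1 assumes
# 32x32 inputs, which three rounds of 2x2 max pooling reduce to 4x4 feature maps,
# hence the 128 * 4 * 4 flatten size. With a dummy batch:
#   _model = ConvolutionalNeuralNetwork(input_channels=3, num_classes=10)
#   _logits = _model(torch.randn(2, 3, 32, 32))
#   print(_logits.shape)  # expected: torch.Size([2, 10])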

class ImageDataGenerator:
    """
    图像数据生成器类
    用于生成模拟图像数据
    """
    
    @staticmethod
    def generate_synthetic_images(num_samples=1000, image_size=(32, 32), num_classes=10, channels=3):
        """
        生成合成图像数据
        
        参数:
        num_samples (int): 样本数量
        image_size (tuple): 图像尺寸 (height, width)
        num_classes (int): 类别数量
        channels (int): 图像通道数
        
        返回:
        tuple: (图像数据, 标签, 类别名称)
        """
        print("正在生成合成图像数据...")
        
        height, width = image_size
        images = []
        labels = []
        
        # 定义类别名称
        class_names = [f'类别_{i}' for i in range(num_classes)]
        
        for i in range(num_samples):
            # Pick a class at random
            class_id = np.random.randint(0, num_classes)
            
            # Generate class-specific image content
            if channels == 3:  # RGB images
                # Give each class its own base color theme
                base_color = np.array([
                    [1.0, 0.0, 0.0],  # red
                    [0.0, 1.0, 0.0],  # green
                    [0.0, 0.0, 1.0],  # blue
                    [1.0, 1.0, 0.0],  # yellow
                    [1.0, 0.0, 1.0],  # magenta
                    [0.0, 1.0, 1.0],  # cyan
                    [1.0, 0.5, 0.0],  # orange
                    [0.5, 0.0, 1.0],  # violet
                    [0.0, 0.5, 0.0],  # dark green
                    [0.5, 0.5, 0.5],  # gray
                ])[class_id % 10]
                
                # Start from a noisy background
                image = np.random.rand(height, width, channels) * 0.3  # base noise
                
                # Add a class-dependent geometric shape
                center_x, center_y = width // 2, height // 2
                
                if class_id % 4 == 0:  # circle
                    y, x = np.ogrid[:height, :width]
                    mask = (x - center_x)**2 + (y - center_y)**2 <= (min(width, height) // 4)**2
                    image[mask] = base_color + np.random.rand(3) * 0.2
                
                elif class_id % 4 == 1:  # rectangle
                    start_x, end_x = width // 4, 3 * width // 4
                    start_y, end_y = height // 4, 3 * height // 4
                    image[start_y:end_y, start_x:end_x] = base_color + np.random.rand(3) * 0.2
                
                elif class_id % 4 == 2:  # triangle
                    for y in range(height // 4, 3 * height // 4):
                        for x in range(width // 4, 3 * width // 4):
                            if x - width // 4 <= (y - height // 4) * 2:
                                image[y, x] = base_color + np.random.rand(3) * 0.2
                
                else:  # crossing lines
                    image[height // 2 - 2:height // 2 + 2, :] = base_color + np.random.rand(3) * 0.2
                    image[:, width // 2 - 2:width // 2 + 2] = base_color + np.random.rand(3) * 0.2
                
            else:  # grayscale images
                # Generate a grayscale image from base noise
                image = np.random.rand(height, width, 1) * 0.3
                intensity = (class_id + 1) / num_classes
                
                # Add a class-dependent feature
                center_x, center_y = width // 2, height // 2
                if class_id % 2 == 0:  # circular region
                    y, x = np.ogrid[:height, :width]
                    mask = (x - center_x)**2 + (y - center_y)**2 <= (min(width, height) // 3)**2
                    image[mask] = intensity + np.random.rand(1) * 0.2
                else:  # square region
                    image[height//4:3*height//4, width//4:3*width//4] = intensity + np.random.rand(1) * 0.2
            
            # Clip pixel values to [0, 1]
            image = np.clip(image, 0, 1)
            
            images.append(image)
            labels.append(class_id)
        
        # Convert to numpy arrays
        images = np.array(images)
        labels = np.array(labels)
        
        # Reorder dimensions to the PyTorch layout [N, C, H, W]
        images = images.transpose(0, 3, 1, 2)
        
        print(f"Generated {num_samples} samples")
        print(f"Image array shape: {images.shape}")
        print(f"Number of classes: {num_classes}")
        print(f"Samples per class: {np.bincount(labels)}")
        
        return images, labels, class_names
    
    @staticmethod
    def create_data_loaders(images, labels, batch_size=32, test_split=0.2, shuffle=True):
        """
        创建数据加载器
        
        参数:
        images (np.ndarray): 图像数据
        labels (np.ndarray): 标签数据
        batch_size (int): 批次大小
        test_split (float): 测试集比例
        shuffle (bool): 是否打乱数据
        
        返回:
        tuple: (训练数据加载器, 测试数据加载器)
        """
        print("正在创建数据加载器...")
        
        # 转换为PyTorch张量
        images_tensor = torch.FloatTensor(images)
        labels_tensor = torch.LongTensor(labels)
        
        # 分割数据
        num_samples = len(images)
        num_test = int(num_samples * test_split)
        
        if shuffle:
            indices = torch.randperm(num_samples)
        else:
            indices = torch.arange(num_samples)
        
        train_indices = indices[num_test:]
        test_indices = indices[:num_test]
        
        # 创建数据集
        train_dataset = TensorDataset(images_tensor[train_indices], labels_tensor[train_indices])
        test_dataset = TensorDataset(images_tensor[test_indices], labels_tensor[test_indices])
        
        # 创建数据加载器
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        
        print(f"训练集样本数: {len(train_dataset)}")
        print(f"测试集样本数: {len(test_dataset)}")
        print(f"批次大小: {batch_size}")
        
        return train_loader, test_loader

class CNNTrainer:
    """
    CNN训练器类
    用于训练和评估卷积神经网络
    """
    
    def __init__(self, model, device='cpu'):
        """
        初始化训练器
        
        参数:
        model: CNN模型
        device (str): 计算设备
        """
        self.model = model.to(device)
        self.device = device
        self.training_history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': []
        }
    
    def train(self, train_loader, val_loader, epochs=50, learning_rate=0.001, patience=10):
        """
        训练模型
        
        参数:
        train_loader: 训练数据加载器
        val_loader: 验证数据加载器
        epochs (int): 训练轮数
        learning_rate (float): 学习率
        patience (int): 早停耐心值
        """
        print(f"\n开始训练CNN模型...")
        print(f"训练轮数: {epochs}")
        print(f"学习率: {learning_rate}")
        print(f"设备: {self.device}")
        
        # 定义损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
        
        # 早停机制
        best_val_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(epochs):
            # 训练阶段
            self.model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(self.device), target.to(self.device)
                
                # 前向传播
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                
                # 反向传播
                loss.backward()
                optimizer.step()
                
                # 统计
                train_loss += loss.item()
                _, predicted = torch.max(output.data, 1)
                train_total += target.size(0)
                train_correct += (predicted == target).sum().item()
            
            # 验证阶段
            self.model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            
            with torch.no_grad():
                for data, target in val_loader:
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    loss = criterion(output, target)
                    
                    val_loss += loss.item()
                    _, predicted = torch.max(output.data, 1)
                    val_total += target.size(0)
                    val_correct += (predicted == target).sum().item()
            
            # 计算平均损失和准确率
            train_loss_avg = train_loss / len(train_loader)
            val_loss_avg = val_loss / len(val_loader)
            train_accuracy = train_correct / train_total
            val_accuracy = val_correct / val_total
            
            # 记录历史
            self.training_history['train_loss'].append(train_loss_avg)
            self.training_history['train_accuracy'].append(train_accuracy)
            self.training_history['val_loss'].append(val_loss_avg)
            self.training_history['val_accuracy'].append(val_accuracy)
            
            # 学习率调度
            scheduler.step(val_loss_avg)
            
            # 打印进度
            if (epoch + 1) % 5 == 0:
                print(f'轮次 [{epoch+1}/{epochs}], '
                      f'训练损失: {train_loss_avg:.4f}, '
                      f'训练准确率: {train_accuracy:.4f}, '
                      f'验证损失: {val_loss_avg:.4f}, '
                      f'验证准确率: {val_accuracy:.4f}')
            
            # 早停检查
            if val_loss_avg < best_val_loss:
                best_val_loss = val_loss_avg
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"早停在第 {epoch+1} 轮")
                    break
        
        print("训练完成!")
    
    def evaluate(self, test_loader, class_names=None):
        """
        评估模型性能
        
        参数:
        test_loader: 测试数据加载器
        class_names: 类别名称列表
        
        返回:
        dict: 评估结果
        """
        print("\n正在评估模型性能...")
        
        self.model.eval()
        all_predictions = []
        all_targets = []
        test_loss = 0.0
        
        criterion = nn.CrossEntropyLoss()
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = criterion(output, target)
                
                test_loss += loss.item()
                _, predicted = torch.max(output, 1)
                
                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())
        
        # 计算指标
        accuracy = accuracy_score(all_targets, all_predictions)
        test_loss_avg = test_loss / len(test_loader)
        
        # 生成分类报告
        if class_names is not None:
            report = classification_report(all_targets, all_predictions, target_names=class_names)
        else:
            report = classification_report(all_targets, all_predictions)
        
        # 生成混淆矩阵
        cm = confusion_matrix(all_targets, all_predictions)
        
        print(f"测试损失: {test_loss_avg:.4f}")
        print(f"测试准确率: {accuracy:.4f}")
        print(f"\n分类报告:")
        print(report)
        
        return {
            'accuracy': accuracy,
            'test_loss': test_loss_avg,
            'predictions': all_predictions,
            'true_labels': all_targets,
            'confusion_matrix': cm,
            'classification_report': report
        }

class CNNVisualizer:
    """
    CNN可视化工具类
    用于绘制训练过程和结果
    """
    
    @staticmethod
    def plot_training_history(training_history):
        """
        绘制训练历史
        
        参数:
        training_history (dict): 训练历史数据
        """
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        
        fig, ((ax1, ax2)) = plt.subplots(1, 2, figsize=(15, 5))
        
        # 损失函数曲线
        ax1.plot(training_history['train_loss'], label='训练损失', color='blue')
        ax1.plot(training_history['val_loss'], label='验证损失', color='red')
        ax1.set_xlabel('训练轮次')
        ax1.set_ylabel('损失值')
        ax1.set_title('损失函数变化')
        ax1.legend()
        ax1.grid(True)
        
        # 准确率曲线
        ax2.plot(training_history['train_accuracy'], label='训练准确率', color='blue')
        ax2.plot(training_history['val_accuracy'], label='验证准确率', color='red')
        ax2.set_xlabel('训练轮次')
        ax2.set_ylabel('准确率')
        ax2.set_title('准确率变化')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.show()
    
    @staticmethod
    def plot_confusion_matrix(cm, class_names=None):
        """
        绘制混淆矩阵
        
        参数:
        cm (np.ndarray): 混淆矩阵
        class_names (list): 类别名称
        """
        plt.figure(figsize=(10, 8))
        
        if class_names is None:
            class_names = [f'类别 {i}' for i in range(len(cm))]
        
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                   xticklabels=class_names, yticklabels=class_names)
        plt.title('混淆矩阵')
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.show()
    
    @staticmethod
    def plot_sample_images(images, labels, class_names, num_samples=16):
        """
        绘制样本图像
        
        参数:
        images (np.ndarray): 图像数据
        labels (np.ndarray): 标签数据
        class_names (list): 类别名称
        num_samples (int): 显示的样本数量
        """
        plt.figure(figsize=(12, 8))
        
        for i in range(min(num_samples, len(images))):
            plt.subplot(4, 4, i + 1)
            
            # 调整图像维度用于显示
            if images.shape[1] == 3:  # RGB图像
                img = images[i].transpose(1, 2, 0)
            else:  # 灰度图像
                img = images[i].squeeze()
                plt.imshow(img, cmap='gray')
            
            if images.shape[1] == 3:
                plt.imshow(img)
            
            plt.title(f'{class_names[labels[i]]}')
            plt.axis('off')
        
        plt.tight_layout()
        plt.show()
    
    @staticmethod
    def plot_feature_maps(model, image, layer_name='conv1'):
        """
        可视化特征图
        
        参数:
        model: 训练好的模型
        image (torch.Tensor): 输入图像
        layer_name (str): 要可视化的层名称
        """
        model.eval()
        
        # 获取指定层的输出
        def hook_fn(module, input, output):
            global feature_maps
            feature_maps = output
        
        # 注册钩子
        if hasattr(model, layer_name):
            handle = getattr(model, layer_name).register_forward_hook(hook_fn)
        else:
            print(f"模型中未找到层: {layer_name}")
            return
        
        # 前向传播
        with torch.no_grad():
            _ = model(image.unsqueeze(0))
        
        # 移除钩子
        handle.remove()
        
        # 可视化特征图
        if 'feature_maps' in globals():
            feature_maps_np = feature_maps.squeeze().cpu().numpy()
            
            num_filters = min(16, feature_maps_np.shape[0])
            fig, axes = plt.subplots(4, 4, figsize=(12, 12))
            
            for i in range(num_filters):
                ax = axes[i // 4, i % 4]
                ax.imshow(feature_maps_np[i], cmap='viridis')
                ax.set_title(f'特征图 {i+1}')
                ax.axis('off')
            
            plt.tight_layout()
            plt.show()

def main():
    """
    主函数:完整的CNN图像识别流程
    """
    print("=== PyTorch 卷积神经网络图像识别 ===\n")
    
    # 设置设备
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    
    # 1. 生成模拟图像数据
    print("\n1. 生成模拟图像数据")
    images, labels, class_names = ImageDataGenerator.generate_synthetic_images(
        num_samples=2000,
        image_size=(32, 32),
        num_classes=10,
        channels=3
    )
    
    # 2. 创建数据加载器
    print("\n2. 创建数据加载器")
    train_loader, test_loader = ImageDataGenerator.create_data_loaders(
        images, labels, batch_size=32, test_split=0.2
    )
    
    # 3. 显示样本图像
    print("\n3. 显示样本图像")
    CNNVisualizer.plot_sample_images(images, labels, class_names)
    
    # 4. 创建CNN模型
    print("\n4. 创建CNN模型")
    model = ConvolutionalNeuralNetwork(
        input_channels=3,
        num_classes=10,
        dropout_rate=0.5
    )
    
    print(f"模型结构:")
    print(model)
    
    # 5. 训练模型
    print("\n5. 训练模型")
    trainer = CNNTrainer(model, device)
    trainer.train(
        train_loader=train_loader,
        val_loader=test_loader,
        epochs=30,
        learning_rate=0.001,
        patience=10
    )
    
    # 6. 评估模型
    print("\n6. 评估模型")
    results = trainer.evaluate(test_loader, class_names)
    
    # 7. 可视化结果
    print("\n7. 可视化结果")
    
    # 绘制训练历史
    CNNVisualizer.plot_training_history(trainer.training_history)
    
    # 绘制混淆矩阵
    CNNVisualizer.plot_confusion_matrix(results['confusion_matrix'], class_names)
    
    # 8. 示例预测
    print("\n8. 示例预测")
    
    # 获取一些测试样本
    test_images, test_labels = next(iter(test_loader))
    test_images = test_images.to(device)
    
    # 预测
    model.eval()
    with torch.no_grad():
        predictions = model.predict(test_images)
        probabilities = model.predict_proba(test_images)
    
    # 显示预测结果
    for i in range(min(5, len(test_images))):
        pred_class = predictions[i].item()
        true_class = test_labels[i].item()
        prob = probabilities[i][pred_class].item()
        
        print(f"样本 {i+1}:")
        print(f"  真实类别: {class_names[true_class]}")
        print(f"  预测类别: {class_names[pred_class]}")
        print(f"  预测概率: {prob:.4f}")
        print(f"  预测正确: {'是' if pred_class == true_class else '否'}")
        print()
    
    # 9. 可视化特征图
    print("\n9. 可视化特征图")
    sample_image = test_images[0].cpu()
    CNNVisualizer.plot_feature_maps(model.cpu(), sample_image, 'conv1')
    
    # 10. 保存模型
    print("\n10. 保存模型")
    model.save('cnn_image_classifier.pth')
    
    # 11. 模型加载示例
    print("\n11. 模型加载示例")
    loaded_model = ConvolutionalNeuralNetwork.load('cnn_image_classifier.pth', 3, 10, 0.5)
    
    # 验证加载的模型
    with torch.no_grad():
        original_pred = model.predict(test_images[:1])
        loaded_pred = loaded_model.predict(test_images[:1].cpu())
        
        print(f"原始模型预测: {original_pred.item()}")
        print(f"加载模型预测: {loaded_pred.item()}")
        print(f"预测一致性: {'一致' if original_pred.item() == loaded_pred.item() else '不一致'}")
    
    print("\n=== 程序执行完成 ===")

if __name__ == "__main__":
    main()

------------------------------------------------------------- Below is the invocation flow for RNN-based image recognition --------------------------
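
What follows is a minimal, hypothetical sketch of one way such a flow could look; it is not part of the pipeline above. It assumes an LSTM that reads each 32x32 RGB image row by row (32 time steps of 32*3 = 96 features per step) and reuses ImageDataGenerator and CNNTrainer defined earlier. The names RecurrentImageClassifier and run_rnn_image_recognition are illustrative only.

class RecurrentImageClassifier(nn.Module):
    """Minimal RNN-based image classifier: treats each image as a sequence of rows."""

    def __init__(self, input_size=32 * 3, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: [batch, channels, height, width] -> one time step per image row:
        # [batch, height, channels * width]
        batch_size, channels, height, width = x.shape
        x = x.permute(0, 2, 1, 3).reshape(batch_size, height, channels * width)
        output, _ = self.lstm(x)          # output: [batch, height, hidden_size]
        return self.fc(output[:, -1, :])  # classify from the last time step


def run_rnn_image_recognition():
    """Hypothetical invocation flow: same data and trainer as main(), RNN model instead of the CNN."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Reuse the synthetic data and loaders from the CNN pipeline
    images, labels, class_names = ImageDataGenerator.generate_synthetic_images(
        num_samples=2000, image_size=(32, 32), num_classes=10, channels=3
    )
    train_loader, test_loader = ImageDataGenerator.create_data_loaders(
        images, labels, batch_size=32, test_split=0.2
    )

    # Build the RNN model and train/evaluate it with the existing trainer
    model = RecurrentImageClassifier(input_size=32 * 3, hidden_size=128,
                                     num_layers=2, num_classes=10)
    trainer = CNNTrainer(model, device)
    trainer.train(train_loader, test_loader, epochs=30, learning_rate=0.001, patience=10)
    return trainer.evaluate(test_loader, class_names)

# Not called automatically; invoke run_rnn_image_recognition() to run the RNN variant.

Treating image rows as a sequence is a common toy setup for recurrent models; for real image recognition tasks, the CNN above is usually the stronger baseline.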

