《零基础入门AI:深度学习之全连接网络学习(过拟合处理、批标准化与模型管理)》

发布于:2025-08-15 ⋅ 阅读:(27) ⋅ 点赞:(0)

一、过拟合与欠拟合

1. 概念认知
  • 欠拟合:模型无法捕捉数据的基本模式
    表现:训练集和测试集表现都差
    原因:模型太简单(如层数不足)、特征不足
    示例:用直线拟合抛物线数据

  • 过拟合:模型过度记忆训练数据噪声
    表现:训练集表现好,测试集表现差
    原因:模型太复杂、数据量不足 、正则化强度不足
    示例:模型完美拟合训练数据但无法泛化新数据

  • 如何判断

    持续下降
    稳定高位
    先降后升
    训练损失
    验证损失变化
    正常学习
    欠拟合
    过拟合
2. 解决欠拟合
  1. 增加模型复杂度
    • 添加更多隐藏层
    • 增加每层神经元数量
  2. 特征工程
    • 添加多项式特征(如x2,x3x^2, x^3x2,x3
    • 组合特征(年龄×收入)
  3. 延长训练
    • 增加epoch数量
    • 减小学习率精细调整
  4. 减少正则化强度
    • 适当减小 L1、L2 正则化强度
3. 解决过拟合
(1) L2正则化(权重衰减)
  • 数学原理
    损失函数添加权重平方和惩罚项:
    L总=L原始+λ2∑wi2L_{\text{总}} = L_{\text{原始}} + \frac{\lambda}{2} \sum w_i^2L=L原始+2λwi2
    其中:

    • L原始L_{\text{原始}}L原始 是原始损失函数(比如均方误差、交叉熵等)。
    • λ\lambdaλ 是正则化强度,控制正则化的力度。
    • wi2{w_i^2 }wi2 是模型的第 iii 个权重参数。
    • fracλ2∑wi2frac{\lambda}{2} \sum w_i^2fracλ2wi2 是所有权重参数的平方和,称为 L2 正则化项。

    L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。

    为什么是λ2\frac{\lambda}{2}2λ

    假设没有1/2,则对L2 正则化项wiw_iwi的梯度为:2λwi2\lambda w_i2λwi,会引入一个额外的系数 2,使梯度计算和更新公式变得复杂。

    添加1/2后,对wiw_iwi的梯度为:λwi\lambda w_iλwi

  • 梯度更新
    ∂L∂w=∂L原始∂w+λw\frac{\partial L}{\partial w} = \frac{\partial L_{\text{原始}}}{\partial w} + \lambda wwL=wL原始+λw
    权重更新:wnew=wold−η(∂L∂w+λw)w_{new} = w_{old} - \eta \left( \frac{\partial L}{\partial w} + \lambda w \right)wnew=woldη(wL+λw)

    其中:

    • η\etaη 是学习率。
    • ∂L∂w\frac{\partial L}{\partial w}wL是损失函数关于参数 w\ w w 的梯度。
    • lambdawlambda wlambdaw是 L2 正则化项的梯度,对应的是参数值本身的衰减。

    很明显,参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。

  • 作用

    • 防止过拟合:当模型过于复杂、参数较多时,模型会倾向于记住训练数据中的噪声,导致过拟合。L2 正则化通过抑制参数的过大值,使得模型更加平滑,降低模型对训练数据噪声的敏感性。
    • 限制模型复杂度:L2 正则化项强制权重参数尽量接近 0,避免模型中某些参数过大,从而限制模型的复杂度。通过引入平方和项,L2 正则化鼓励模型的权重均匀分布,避免单个权重的值过大。
    • 提高模型的泛化能力:正则化项的存在使得模型在测试集上的表现更加稳健,避免在训练集上取得极高精度但在测试集上表现不佳。
    • 平滑权重分布:L2 正则化不会将权重直接变为 0,而是将权重值缩小。这样模型就更加平滑的拟合数据,同时保留足够的表达能力。
  • PyTorch实现

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt
    
    # 设置随机种子以保证可重复性
    torch.manual_seed(42)
    
    # 生成随机数据
    n_samples = 100
    n_features = 20
    X = torch.randn(n_samples, n_features)  # 输入数据
    y = torch.randn(n_samples, 1)  # 目标值
    
    
    # 定义一个简单的全连接神经网络
    class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc1 = nn.Linear(n_features, 50)
            self.fc2 = nn.Linear(50, 1)
    
        def forward(self, x):
            x = torch.relu(self.fc1(x))
            return self.fc2(x)
    
    
    # 训练函数
    def train_model(use_l2=False, weight_decay=0.01, n_epochs=100):
        # 初始化模型
        model = SimpleNet()
        criterion = nn.MSELoss()  # 损失函数(均方误差)
    
        # 选择优化器
        if use_l2:
            optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=weight_decay)  # 使用 L2 正则化
        else:
            optimizer = optim.SGD(model.parameters(), lr=0.01)  # 不使用 L2 正则化
    
        # 记录训练损失
        train_losses = []
    
        # 训练过程
        for epoch in range(n_epochs):
            optimizer.zero_grad()  # 清空梯度
            outputs = model(X)  # 前向传播
            loss = criterion(outputs, y)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
    
            train_losses.append(loss.item())  # 记录损失
    
            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')
    
        return train_losses
    
    
    # 训练并比较两种模型
    train_losses_no_l2 = train_model(use_l2=False)  # 不使用 L2 正则化
    train_losses_with_l2 = train_model(use_l2=True, weight_decay=0.01)  # 使用 L2 正则化
    
    # 绘制训练损失曲线
    plt.plot(train_losses_no_l2, label='Without L2 Regularization')
    plt.plot(train_losses_with_l2, label='With L2 Regularization')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss: L2 Regularization vs No Regularization')
    plt.legend()
    plt.show()
    
(2) L1正则化
  • 数学原理
    L总=L原始+λ∑∣wi∣L_{\text{总}} = L_{\text{原始}} + \lambda \sum |w_i|L=L原始+λwi

    其中:

    • $L_{\text{原始}} $ 是原始损失函数
    • λ\lambdaλ 是正则化强度,控制正则化的力度。
    • ∣wi∣{|w_i| }wi 。是模型第iii 个参数的绝对值
    • lambda∑∣wi∣lambda \sum |w_i|lambdawi 是所有权重参数的绝对值之和,这个项即为 L1 正则化项。
  • 梯度更新
    ∂L∂w=∂L原始∂w+λsign(w)\frac{\partial L}{\partial w} = \frac{\partial L_{\text{原始}}}{\partial w} + \lambda \text{sign}(w)wL=wL原始+λsign(w)

    权重更新:wnew=wold−η(∂L∂w+λw)w_{new} = w_{old} - \eta \left( \frac{\partial L}{\partial w} + \lambda w \right)wnew=woldη(wL+λw)

    其中:

    • η\etaη 是学习率。
    • ∂L∂w\frac{\partial L}{\partial w}wL是损失函数关于参数www 的梯度。
    • lambdawlambda wlambdaw是 L1 正则化项的梯度,对应的是参数值本身的衰减。
  • 作用

    • 稀疏性:L1 正则化的一个显著特性是它会促使许多权重参数变为 。这是因为 L1 正则化倾向于将权重绝对值缩小到零,使得模型只保留对结果最重要的特征,而将其他不相关的特征权重设为零,从而实现 特征选择 的功能。
    • 防止过拟合:通过限制权重的绝对值,L1 正则化减少了模型的复杂度,使其不容易过拟合训练数据。相比于 L2 正则化,L1 正则化更倾向于将某些权重完全移除,而不是减小它们的值。
    • 简化模型:由于 L1 正则化会将一些权重变为零,因此模型最终会变得更加简单,仅依赖于少数重要特征。这对于高维度数据特别有用,尤其是在特征数量远多于样本数量的情况下。
    • 特征选择:因为 L1 正则化会将部分权重置零,因此它天然具有特征选择的能力,有助于自动筛选出对模型预测最重要的特征。
  • 与L2对比

    特性 L1正则化 L2正则化
    解特性 稀疏解 平滑解
    抗噪性 中等
    计算 不可导处需特殊处理 处处可导
(3) Dropout
  • 概念:训练时随机"关闭"部分神经元

  • 工作流程

    • 在每次训练迭代中,随机选择一部分神经元(通常以概率 p丢弃,比如 p=0.5)。

    • 被选中的神经元在当前迭代中不参与前向传播和反向传播。

    • 在测试阶段,所有神经元都参与计算,但需要对权重进行缩放(通常乘以 1−p),以保持输出的期望值一致。

在这里插入图片描述

  • 数学原理
    训练时:y=11−p⋅mask⋅(Wx+b)y = \frac{1}{1-p} \cdot \text{mask} \cdot (Wx + b)y=1p1mask(Wx+b)
    测试时:y=Wx+by = Wx + by=Wx+b
    其中ppp是丢弃概率,mask是伯努利随机矩阵

  • 权重影响:迫使网络不依赖特定神经元,增强鲁棒性

  • 实现逻辑

    # 训练阶段
    mask = (torch.rand(neurons) > p).float()  # 生成掩码
    output = input * mask / (1-p)  # 缩放保持期望不变
    
    # 测试阶段
    output = input  # 使用全部神经元
    
(4) 数据增强

数据增强(Data Augmentation)是一种通过人工生成或修改训练数据来增加数据集多样性的技术,常用于解决过拟合问题。数据增强通过“模拟”更多训练数据,迫使模型学习泛化性更强的规律,而非训练集中的偶然性模式。其本质是一种低成本的正则化手段,尤其在数据稀缺时效果显著。

通过变换原始数据增加样本多样性:

技术 数学变换 作用
缩放 x′=kxx' = kxx=kx 适应不同尺寸
随机裁剪 xcrop=x[i:i+h,j:j+w]x_{crop} = x[i:i+h, j:j+w]xcrop=x[i:i+h,j:j+w] 关注局部特征
水平翻转 xflip[i,j]=x[i,W−j]x_{flip}[i,j] = x[i, W-j]xflip[i,j]=x[i,Wj] 增加对称性
颜色调整 xadj=αx+βx_{adj} = \alpha x + \betaxadj=αx+β 模拟光照变化
旋转 旋转矩阵变换 增强旋转不变性

transforms:

常用变换类

  • transforms.Compose:将多个变换操作组合成一个流水线。
  • transforms.ToTensor:将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,将图像数据从 uint8 类型 (0-255) 转换为 float32 类型 (0.0-1.0)。
  • transforms.Normalize:对张量进行标准化。
  • transforms.Resize:调整图像大小。
  • transforms.CenterCrop:从图像中心裁剪指定大小的区域。
  • transforms.RandomCrop:随机裁剪图像。
  • transforms.RandomHorizontalFlip:随机水平翻转图像。
  • transforms.RandomVerticalFlip:随机垂直翻转图像。
  • transforms.RandomRotation:随机旋转图像。
  • transforms.ColorJitter:随机调整图像的亮度、对比度、饱和度和色调。
  • transforms.RandomGrayscale:随机将图像转换为灰度图像。
  • transforms.RandomResizedCrop:随机裁剪图像并调整大小。

归一化处理

数学原理

  • 均值归一化:x′=x−μσx' = \frac{x - \mu}{\sigma}x=σxμ
  • 范围归一化:x′=x−xmin⁡xmax⁡−xmin⁡x' = \frac{x - x_{\min}}{x_{\max} - x_{\min}}x=xmaxxminxxmin

作用

  • 标准化:将图像的像素值从原始范围(如 [0, 255] 或 [0, 1])转换为均值为 0、标准差为 1 的分布。
  • 加速训练:标准化后的数据分布更均匀,有助于加速模型训练。
  • 提高模型性能:标准化可以使模型更容易学习到数据的特征,提高模型的收敛性和稳定性。

完整增强流程

transform = transforms.Compose([
    transforms.Resize(256),          # 缩放
    transforms.RandomCrop(224),      # 随机裁剪
    transforms.RandomHorizontalFlip(), # 水平翻转
    transforms.ColorJitter(0.2, 0.2, 0.2), # 颜色调整
    transforms.RandomRotation(15),    # 随机旋转
    transforms.ToTensor(),            # 转为张量
    transforms.Normalize([0.5], [0.5]) # 归一化
])

均值(Mean):数据集中所有图像在每个通道上的像素值的平均值。

标准差(Std):数据集中所有图像在每个通道上的像素值的标准差。

RGB 三个通道的均值和标准差 不是随便定义的,而是需要根据具体的数据集进行统计计算。这些值是 ImageNet 数据集的统计结果,已成为计算机视觉任务的默认标准。


二、批量标准化

1. 训练阶段标准化
  • 数学过程

    • 均值:μB=1m∑i=1mxi\mu_B = \frac{1}{m}\sum_{i=1}^m x_iμB=m1i=1mxi

    • 方差:σB2=1m∑i=1m(xi−μB)2\sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2σB2=m1i=1m(xiμB)2

    • 标准化后的值:x^i=xi−μBσB2+ϵ\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}x^i=σB2+ϵ xiμB

      其中,ϵ\epsilonϵ 是一个很小的常数,防止除以零的情况。

    • 平移与缩放:yi=γx^i+βy_i = \gamma \hat{x}_i + \betayi=γx^i+β

      其中,γ\gammaγβ\betaβ 是在训练过程中学习到的参数。它们会随着网络的训练过程通过反向传播进行更新。

  • 全局统计量更新

    通过指数移动平均(Exponential Moving Average, EMA)更新全局均值和方差:

    μ全局=0.9×μ全局+0.1×μB\mu_{\text{全局}} = 0.9 \times \mu_{\text{全局}} + 0.1 \times \mu_Bμ全局=0.9×μ全局+0.1×μB
    σ全局2=0.9×σ全局2+0.1×σB2\sigma^2_{\text{全局}} = 0.9 \times \sigma^2_{\text{全局}} + 0.1 \times \sigma_B^2σ全局2=0.9×σ全局2+0.1×σB2

2. 测试阶段标准化

使用训练时积累的全局统计量:
y=γx−μ全局σ全局2+ϵ+βy = \gamma \frac{x - \mu_{\text{全局}}}{\sqrt{\sigma^2_{\text{全局}} + \epsilon}} + \betay=γσ全局2+ϵ xμ全局+β

然后对标准化后的数据进行缩放和平移:

yi=γ⋅x^i+βyi=γ⋅\hat{x}_i+βyi=γx^i+β

为什么使用全局统计量?

一致性

  • 在测试阶段,输入数据通常是单个样本或少量样本,无法准确计算均值和方差。
  • 使用全局统计量可以确保测试阶段的行为与训练阶段一致。

稳定性

  • 全局统计量是通过训练阶段的大量 mini-batch 数据计算得到的,能够更好地反映数据的整体分布。
  • 使用全局统计量可以减少测试阶段的随机性,使模型的输出更加稳定。

效率

  • 在测试阶段,使用预先计算的全局统计量可以避免重复计算,提高效率。
3. 批标准化的三大作用
  1. 缓解梯度消失/爆炸:标准化处理可以防止激活值过大或过小,保持激活值在稳定区间
  2. 加速训练:由于 BN 使得每层的输入数据分布更为稳定,因此模型允许使用更大学习率(提升10倍)
  3. 轻微正则化:噪声来自批次统计,这有助于提高模型的泛化能力。
4. PyTorch实现
# 定义带BN的神经网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)  # 批标准化层
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)  # 应用批标准化
        x = F.relu(x)
        x = self.fc2(x)
        return x

参数解析

  • γ(缩放因子)初始化为1
  • β(平移因子)初始化为0
  • ε(数值稳定项)默认1e-5

三、模型的保存与加载

训练一个模型通常需要大量的数据、时间和计算资源。通过保存训练好的模型,可以满足后续的模型部署、模型更新、迁移学习、训练恢复等各种业务需要求。

1. 标准网络模型构建
class MyModle(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        # 创建一个全连接网络(full connected layer)
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output
    
# 创建模型实例
model = MyModel(input_size=10, output_size=2)
# 输入数据
x = torch.randn(5, 10)
# 调用模型
output = model(x)
2. 序列化模型对象

保存整个模型

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

参数说明:

  • obj:要保存的对象,可以是模型、张量、字典等。
  • f:保存文件的路径或文件对象。可以是字符串(文件路径)或文件描述符。
  • pickle_module:用于序列化的模块,默认是 Python 的 pickle 模块。
  • pickle_protocol:pickle 模块的协议版本,默认是 DEFAULT_PROTOCOL(通常是最高版本)。

加载方法

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

  • f:文件路径或文件对象。可以是字符串(文件路径)或文件描述符。
  • map_location:指定加载对象的设备位置(如 CPU 或 GPU)。默认是 None,表示保持原始设备位置。例如:map_location=torch.device(‘cpu’) 将对象加载到 CPU。
  • pickle_module:用于反序列化的模块,默认是 Python 的 pickle 模块。
  • pickle_load_args:传递给 pickle_module.load() 的额外参数。

优点:包含模型结构和参数
缺点:文件大,依赖原始类定义

3. 保存模型参数(推荐)

保存检查点

torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),  # 模型参数
    'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态
    'loss': loss,
}, 'checkpoint.pth')

加载参数

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

最佳实践

  • 训练中每N个epoch保存检查点
  • 部署时只加载参数(避免依赖问题)
  • 保存优化器状态实现训练中断恢复

关键原理深度解析

1. Dropout的数学原理

考虑带Dropout的全连接层:

  • 训练时:ytrain=11−p⋅m⋅(Wx+b)y_{\text{train}} = \frac{1}{1-p} \cdot m \cdot (Wx + b)ytrain=1p1m(Wx+b)
  • 测试时:ytest=Wx+by_{\text{test}} = Wx + bytest=Wx+b

为什么训练时除以(1-p)?
保持输出期望值不变:
E[ytrain]=E[11−p⋅m⋅z]=11−p⋅(1−p)⋅z=zE[y_{\text{train}}] = E[\frac{1}{1-p} \cdot m \cdot z] = \frac{1}{1-p} \cdot (1-p) \cdot z = zE[ytrain]=E[1p1mz]=1p1(1p)z=z
其中mmm是伯努利掩码,P(m=1)=1−pP(m=1)=1-pP(m=1)=1p

2. 批标准化反向传播

LLL为损失函数,BN层反向传播梯度:

  • ∂L∂x^i=∂L∂yi⋅γ\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \cdot \gammax^iL=yiLγ
  • ∂L∂σB2=∑i=1m∂L∂x^i⋅(xi−μB)⋅(−12)(σB2+ϵ)−3/2\frac{\partial L}{\partial \sigma_B^2} = \sum_{i=1}^m \frac{\partial L}{\partial \hat{x}_i} \cdot (x_i - \mu_B) \cdot (-\frac{1}{2}) (\sigma_B^2 + \epsilon)^{-3/2}σB2L=i=1mx^iL(xiμB)(21)(σB2+ϵ)3/2
  • ∂L∂μB=(∑i=1m∂L∂x^i⋅−1σB2+ϵ)+∂L∂σB2⋅∑i=1m−2(xi−μB)m\frac{\partial L}{\partial \mu_B} = \left( \sum_{i=1}^m \frac{\partial L}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} \right) + \frac{\partial L}{\partial \sigma_B^2} \cdot \frac{\sum_{i=1}^m -2(x_i - \mu_B)}{m}μBL=(i=1mx^iLσB2+ϵ 1)+σB2Lmi=1m2(xiμB)
3. 正则化的几何解释
原始损失函数
添加正则项
L1正则化
菱形约束区域
L2正则化
圆形约束区域
解在坐标轴上
解在原点附近

实战策略

1. 过拟合解决方案选择
结构化数据
图像数据
序列数据
遇到过拟合
数据类型
L1/L2正则化
数据增强+Dropout
Dropout+权重衰减
添加批标准化
2. 完整训练流程示例
'''
练习:使用全连接网络训练和验证CIFAR10数据集
并思考:为什么CIFAR10数据集的准确率很低?
步骤:
1. 数据预处理
2. 数据准备
3. 加载数据
4. 定义网络结构(批量标准化)
5. 训练模型
6. 验证模型
7. 保存模型
8. 预测模型
9. 绘制损失曲线和准确率曲线
10. 思考:为什么CIFAR10数据集的准确率很低?
11. 思考:如何提高CIFAR10数据集的准确率?
12. 思考:如何使用数据增强技术来提高CIFAR10数据集的准确率?
'''
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch import optim
import matplotlib.pyplot as plt
from PIL import Image

# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1.数据预处理 - 增强版
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# 2.数据准备
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
eval_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

# 3.加载数据
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=64, shuffle=False)

# 4.定义改进的网络结构
class ImprovedNet(nn.Module):
    def __init__(self):
        super(ImprovedNet, self).__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.2)
        
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)
        
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(0.2)
        
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.dropout4 = nn.Dropout(0.2)
        
        self.fc5 = nn.Linear(64, 10)
        
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.dropout3(x)
        
        x = self.fc4(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.dropout4(x)

        x = self.fc5(x)
        return x

# 模型定义
model = ImprovedNet().to(device)
# 损失函数定义
criterion = nn.CrossEntropyLoss()
# 优化器定义
opt = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# 学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3, verbose=True)

# 5.训练
def train(model, train_loader, criterion, opt, scheduler, epochs):
    model.train()
    train_losses = []
    train_accuracies = []
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # 梯度清零
            opt.zero_grad()
            # 前向传播
            output = model(data)
            # 计算损失
            loss = criterion(output, target)
            # 反向传播,计算梯度
            loss.backward()
            # 模型参数更新
            opt.step()
            
            running_loss += loss.item()
            # 获取预测结果
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            # 获取预测正确的数量
            correct += (predicted == target).sum().item()

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        # 计算平均损失和准确率
        avg_loss = running_loss / len(train_loader)
        accuracy = 100. * correct / total
        train_losses.append(avg_loss)
        train_accuracies.append(accuracy)
        print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
        
        # 更新学习率
        scheduler.step(avg_loss)
    
    return train_losses, train_accuracies

# 6.验证
def eval_model(model, eval_loader, criterion):
    model.eval()
    correct = 0
    total_loss = 0.0
    
    with torch.no_grad():
        for data, target in eval_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()

            _, predicted = torch.max(output.data, 1)
            # 获取预测正确的数量
            correct += (predicted == target).sum().item()
    
    # 计算平均损失和准确率
    avg_loss = total_loss / len(eval_loader)
    acc = 100. * correct / len(eval_loader.dataset)
    print(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{len(eval_loader.dataset)} ({acc:.2f}%)\n')
    return avg_loss, acc

# 主程序执行
if __name__ == "__main__":
    epochs = 50

    # 训练模型
    print("开始训练...")
    train_losses, train_accuracies = train(model, train_loader, criterion, opt, scheduler, epochs)
    
    # 验证模型
    print("开始验证...")
    eval_model(model, eval_loader, criterion)

    # 保存模型
    torch.save(model.state_dict(), 'improved_model.pth')
    print("模型已保存为 improved_model.pth")

    # 绘制训练曲线
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs+1), train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs+1), train_accuracies)
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    
    plt.tight_layout()
    plt.show()

网站公告

今日签到

点亮在社区的每一天
去签到