过拟合、欠拟合
在机器学习和深度学习中,过拟合(Overfitting)和欠拟合(Underfitting)是模型训练过程中常见的两种问题,直接影响模型的泛化能力(即对未见过的数据的预测能力)。
1. 欠拟合
欠拟合指模型无法充分捕捉训练数据中的规律,导致在训练集和测试集上的表现都很差(误差高、准确率低)。模型对数据的 “学习” 不够充分,甚至没有学到核心特征。
表现
- 训练集准确率低,损失函数值大;
- 测试集表现与训练集接近(同样差),没有明显差距。
原因
- 模型复杂度太低:例如用简单的线性模型去拟合非线性数据(如用直线拟合曲线分布),无法捕捉数据中的复杂模式。
- 特征不足:输入的特征数量少或质量差,无法有效区分不同类别或预测目标。
- 训练不充分:迭代次数太少,模型还未学到数据的规律就停止训练。
解决方法
- 增加模型复杂度:例如将线性模型改为多项式模型,或在神经网络中增加层数、神经元数量。
- 补充高质量特征:通过特征工程提取更多有效特征(如文本的 TF-IDF、图像的边缘特征)。
- 延长训练时间:增加迭代次数,确保模型充分学习数据规律(但需避免过拟合)。
- 减少正则化强度:若正则化过强(如 L1/L2 惩罚系数太大),会限制模型学习能力,适当降低正则化可缓解欠拟合。
2. 过拟合
定义
过拟合指模型过度学习训练数据中的细节(包括噪声),导致在训练集上表现极好,但在未见过的测试集上表现很差,泛化能力弱。
表现
- 训练集准确率极高(接近 100%),损失函数值极小;
- 测试集准确率远低于训练集,误差显著上升,两者差距很大。
原因
- 模型复杂度太高:模型能力过强(如深度很深的神经网络),不仅学到了数据的核心规律,还记住了训练集中的噪声、异常值等偶然信息。
- 训练数据不足或有噪声:数据量太少时,模型容易 “死记硬背” 所有样本;数据中存在错误标注或噪声时,模型会将这些错误当作规律学习。
- 训练过度:迭代次数过多,模型在训练集上过度优化,导致对微小波动也过度敏感。
3. 解决过拟合
1. Dropout(随机丢弃神经元)
概念
在训练过程中随机关闭部分神经元,减少神经元之间的依赖,提高模型泛化能力。
核心思想:
- 打破神经元的共适应性(Co-adaptation):传统神经网络中,神经元可能过度依赖其他特定神经元,导致模型脆弱。Dropout 通过随机丢弃神经元,迫使每个神经元独立学习有用特征。
- 模拟集成学习:每次迭代中,Dropout 训练的是不同的子网络,相当于训练多个子模型的集成,最终效果接近多个模型的平均预测。
Dropout 在训练和测试阶段的行为不同:
假设某个神经元的输出为 x,Dropout 的操作可以表示为:
在训练阶段:
每次前向传播时,以概率
p
(通常设置为 0.5)随机"关闭"(丢弃)网络中的神经元被丢弃的神经元在前向传播和反向传播中都不参与计算
剩余的神经元被按比例放大(乘以
1/(1-p)
)以保持输出的期望值不变
y={x1−p以概率1−p保留神经元0以概率p丢弃神经元 y=\begin{cases}\frac{x}{1−p} & 以概率1−p保留神经元 \\ 0 & 以概率 p 丢弃神经元 \end{cases} y={1−px0以概率1−p保留神经元以概率p丢弃神经元
在测试阶段:
- 关闭 Dropout:所有神经元均被保留,不做任何丢弃。
- 无需缩放:由于训练时已通过
1/p
缩放,测试时直接使用原始权重即可,无需额外调整。
y=x y=x y=x
作用:
- 防止共适应:
- 迫使神经元独立学习有用特征,而不是过度依赖其他特定神经元
- 模型平均(隐式集成):
- 每次训练都在不同的子网络上进行
- 相当于在训练多个不同的模型
- 测试时相当于这些子模型的平均
- 噪声注入:
- 为训练过程添加随机性
- 增强模型对输入扰动的鲁棒性
PyTorch API 示例
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.dropout = nn.Dropout(p=0.5) # 随机关闭50%神经元
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
2. L1/L2 正则化(权重惩罚)
概念
- L1正则化:通过惩罚权重绝对值和,促使部分权重为0(稀疏化)。
- L2正则化:通过惩罚权重平方和,限制权重大小(平滑化)。
原理
L1正则化:
Ltotal=L+λ∑∣wi∣ L_{\text{total}} = L + \lambda \sum |w_i| Ltotal=L+λ∑∣wi∣L2正则化:
Ltotal=L+λ∑wi2 L_{\text{total}} = L + \lambda \sum w_i^2 Ltotal=L+λ∑wi2$ \lambda $:正则化系数,控制惩罚强度。
L2正则化(PyTorch内置):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
3. 早停法(Early Stopping)
概念
监控验证集性能,当损失不再下降时提前终止训练,防止过拟合。
原理
设验证损失为 L_val(t),若在 Δt 轮内 L_val 不再下降,则停止训练。
公式
判断条件:
if Lval(t)≤Lval(t+Δt)∀Δt, stop training \text{if } L_{\text{val}}(t) \leq L_{\text{val}}(t + \Delta t) \quad \forall \Delta t, \text{ stop training} if Lval(t)≤Lval(t+Δt)∀Δt, stop training
PyTorch 实现示例
best_loss = float('inf')
patience = 5 # 容忍无改进的轮数
trigger_times = 0
for epoch in range(100):
# 训练和验证
val_loss = validate(model, val_loader)
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), 'best_model.pth') # 保存最佳模型
trigger_times = 0
else:
trigger_times += 1
if trigger_times >= patience:
print("Early stopping!")
break
4. 数据增强(Data Augmentation)
概念
通过对训练数据进行变换(如翻转、旋转),增加数据多样性,提高模型泛化能力。
原理
图像变换:随机水平翻转、旋转、裁剪、颜色抖动等。
数学表达式:
Inew(x,y)=Iold(x,y)×M(x,y) I_{\text{new}}(x, y) = I_{\text{old}}(x, y) \times M(x, y) Inew(x,y)=Iold(x,y)×M(x,y)- M(x, y):二值矩阵(0 或 1),表示像素的保留(1)或裁除(0)。
PyTorch 示例
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(10), # 随机旋转10度
transforms.ColorJitter(brightness=0.2, contrast=0.2), # 颜色抖动
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
批量标准化
批量标准化(Batch Normalization, BN)是一种用于加速深度神经网络训练并提升模型稳定性的技术。它通过规范化每一层的输入,减少内部协变量偏移(Internal Covariate Shift)的影响,从而改善梯度传播和模型收敛速度。
在深度神经网络中,每一层的输入分布会随着参数的更新而变化,这种现象称为内部协变量偏移。例如,某一层的输入可能因为前一层参数的调整而发生剧烈波动,导致后续层需要不断适应新的输入分布,从而减缓训练速度并增加训练难度。
计算均值和方差
批量标准化的核心思想:
在训练过程中,对每一层的输入进行标准化(零均值、单位方差),并引入可学习的参数以保留模型的表达能力。这使得每一层的输入分布保持稳定,从而缓解内部协变量偏移问题。
批量标准化的处理流程分为以下步骤:
计算小批量均值和方差
对于输入特征 x_i ∈ B,其中 B 是当前小批量数据,计算均值 μ_B 和方差 σ_B²:
μB=1m∑i=1mxi \mu_B = \frac{1}{m} \sum_{i=1}^m x_i μB=m1i=1∑mxiσB2=1m∑i=1m(xi−μB)2 \sigma_B^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_B)^2 σB2=m1i=1∑m(xi−μB)2
其中 m 是小批量的大小。
标准化
使用均值和方差对输入进行标准化:
x^i=xi−μBσB2+ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵxi−μB
ε 是一个小的常数(如 1e-5),用于防止除零错误。
缩放和平移
引入可学习的参数 γ(缩放因子)和 β(平移因子),以恢复模型的表达能力:
yi=γx^i+β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β- γ 和 β 会在训练过程中通过反向传播进行更新。
训练阶段的处理
在训练过程中,批量标准化需要同时更新全局统计量(全局均值和方差),以便在测试阶段使用。具体步骤如下:
计算当前小批量的均值和方差(如公式所示)。
标准化当前小批量的输入。
使用可学习的 γ 和 β 进行缩放和平移。
更新全局统计量
通过指数移动平均(Exponential Moving Average, EMA)更新全局均值 μ_global 和方差 σ_global²:
μglobal←momentum⋅μglobal+(1−momentum)⋅μB \mu_{\text{global}} \leftarrow \text{momentum} \cdot \mu_{\text{global}} + (1 - \text{momentum}) \cdot \mu_B μglobal←momentum⋅μglobal+(1−momentum)⋅μBσglobal2←momentum⋅σglobal2+(1−momentum)⋅σB2 \sigma_{\text{global}}^2 \leftarrow \text{momentum} \cdot \sigma_{\text{global}}^2 + (1 - \text{momentum}) \cdot \sigma_B^2 σglobal2←momentum⋅σglobal2+(1−momentum)⋅σB2
- 默认的
momentum
值为 0.1(在 PyTorch 中)。
- 默认的
测试阶段的处理
在测试阶段,由于没有小批量数据(或小批量大小为 1),直接使用训练阶段计算的全局统计量进行标准化:
使用全局均值和方差标准化输入:
x^i=xi−μglobalσglobal2+ϵ \hat{x}_i = \frac{x_i - \mu_{\text{global}}}{\sqrt{\sigma_{\text{global}}^2 + \epsilon}} x^i=σglobal2+ϵxi−μglobal缩放和平移:
yi=γx^i+β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
为什么使用全局统计量?
- 一致性:确保测试阶段的行为与训练阶段一致。
- 稳定性:全局统计量基于训练数据的分布,能更好地反映整体数据特性。
API
在 PyTorch 中,可以通过 nn.BatchNorm1d
、nn.BatchNorm2d
等类实现批量标准化。
参数:
num_features
(必填)
输入的特征数量(通道数)。- 1d 用于一维数据(如全连接层输出),填特征维度
- 2d 用于二维数据(如卷积特征图),填通道数
eps
(默认 1e-5)
防止分母为 0 的微小值momentum
(默认 0.1)
计算移动均值和方差的动量(用于推理时的统计量)affine
(默认 True)
是否添加可学习的缩放(gamma)和偏移(beta)参数track_running_stats
(默认 True)
是否跟踪训练中的移动统计量(推理时使用)
例如:
import torch
import torch.nn as nn
# 全连接层 + 批量标准化
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.bn1 = nn.BatchNorm1d(256) # 输入特征维度为256
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.bn1(self.fc1(x))) # 标准化后接激活函数
x = self.fc2(x)
return x
批量标准化的优点
- 加速训练:减少内部协变量偏移,使模型更快收敛。
- 允许更大的学习率:标准化后的输入分布更稳定,可尝试更高的学习率。
- 正则化效果:在训练阶段,小批量的统计量会引入轻微噪声,具有类似 Dropout 的正则化效果。
- 减少对初始化的依赖:标准化后的输入分布更均匀,降低了对参数初始化的敏感性。
- 简化超参数调优:减少对学习率、权重初始化等超参数的敏感性。