场景:拟合一个简单的正弦函数
我们将训练一个小的神经网络去拟合一个正弦曲线(带噪声),并比较 使用和不使用 weight decay 的效果。
1. 准备数据
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
# 构造数据(sin 函数 + 噪声)
x = torch.linspace(0, 2 * np.pi, 100).unsqueeze(1)
y = torch.sin(x) + 0.1 * torch.randn_like(x) # 添加噪声
2. 定义模型(一个简单的 MLP)
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(1, 64),
nn.Tanh(),
nn.Linear(64, 1)
)
def forward(self, x):
return self.net(x)
3. 分别训练两个模型(一个使用 weight decay,一个不使用)
def train(weight_decay_value):
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=weight_decay_value)
loss_fn = nn.MSELoss()
for epoch in range(500):
#告诉模型:进入训练模式(比如启用 dropout、batchnorm 的训练行为)。
model.train()
#前向传播(预测)
y_pred = model(x)
# 计算损失
loss = loss_fn(y_pred, y)
# 反向传播准备
#清空之前累积的梯度,防止干扰新一轮的计算。
optimizer.zero_grad()
#反向传播(求梯度)
#根据损失反向传播,计算参数的梯度。
loss.backward()
# 根据梯度和优化器策略(含 weight decay)更新模型参数。
optimizer.step()
return model
4. 可视化结果对比
model_no_decay = train(0.0)
model_with_decay = train(0.01)
# 绘图
plt.figure(figsize=(10, 5))
plt.scatter(x.numpy(), y.numpy(), label='Data', color='gray', alpha=0.5)
with torch.no_grad():
y_pred1 = model_no_decay(x)
y_pred2 = model_with_decay(x)
plt.plot(x.numpy(), y_pred1.numpy(), label='No Weight Decay', color='blue')
plt.plot(x.numpy(), y_pred2.numpy(), label='With Weight Decay', color='red')
plt.legend()
plt.title("Effect of Weight Decay on Fitting Sin Curve")
plt.show()
运行结果如下
5. 解释结果
模型 | 效果 |
---|---|
❌ 无 weight decay | 拟合得非常贴近噪声,容易过拟合,曲线不光滑 |
✅ 有 weight decay | 曲线更平滑,不那么贴合噪声,更接近真实函数 |
✅ 原因:
没有 weight decay:网络自由度太高,学到了很多噪声特征;
加了 weight decay:对权重值大小施加了惩罚,网络“更保守”,只学到主要趋势。
🔚 总结
weight_decay
本质上就是 L2 正则化,防止参数变得太大;它可以 减少过拟合、提高泛化能力;
在 LoRA 微调、预训练、分类任务中都非常重要;
推荐值通常在
0.01
左右,需调参。