简单的 PyTorch 示例,可视化和解释 weight decay 的作用

发布于:2025-07-01 ⋅ 阅读:(14) ⋅ 点赞:(0)

场景:拟合一个简单的正弦函数

我们将训练一个小的神经网络去拟合一个正弦曲线(带噪声),并比较 使用和不使用 weight decay 的效果

1. 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# 构造数据(sin 函数 + 噪声)
x = torch.linspace(0, 2 * np.pi, 100).unsqueeze(1)
y = torch.sin(x) + 0.1 * torch.randn_like(x)  # 添加噪声

2. 定义模型(一个简单的 MLP)

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)

3. 分别训练两个模型(一个使用 weight decay,一个不使用)

def train(weight_decay_value):
    model = SimpleNet()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=weight_decay_value)
    loss_fn = nn.MSELoss()

    for epoch in range(500):
        #告诉模型:进入训练模式(比如启用 dropout、batchnorm 的训练行为)。
        model.train()
        #前向传播(预测)
        y_pred = model(x)
        # 计算损失
        loss = loss_fn(y_pred, y)
        # 反向传播准备
        #清空之前累积的梯度,防止干扰新一轮的计算。
        optimizer.zero_grad()
        #反向传播(求梯度)
        #根据损失反向传播,计算参数的梯度。
        loss.backward()
        # 根据梯度和优化器策略(含 weight decay)更新模型参数。
        optimizer.step()
    return model

4. 可视化结果对比

model_no_decay = train(0.0)
model_with_decay = train(0.01)

# 绘图
plt.figure(figsize=(10, 5))
plt.scatter(x.numpy(), y.numpy(), label='Data', color='gray', alpha=0.5)

with torch.no_grad():
    y_pred1 = model_no_decay(x)
    y_pred2 = model_with_decay(x)

plt.plot(x.numpy(), y_pred1.numpy(), label='No Weight Decay', color='blue')
plt.plot(x.numpy(), y_pred2.numpy(), label='With Weight Decay', color='red')
plt.legend()
plt.title("Effect of Weight Decay on Fitting Sin Curve")
plt.show()

运行结果如下

5. 解释结果

模型 效果
❌ 无 weight decay 拟合得非常贴近噪声,容易过拟合,曲线不光滑
✅ 有 weight decay 曲线更平滑,不那么贴合噪声,更接近真实函数

✅ 原因:

  • 没有 weight decay:网络自由度太高,学到了很多噪声特征;

  • 加了 weight decay:对权重值大小施加了惩罚,网络“更保守”,只学到主要趋势。


🔚 总结

  • weight_decay 本质上就是 L2 正则化,防止参数变得太大;

  • 它可以 减少过拟合、提高泛化能力

  • 在 LoRA 微调、预训练、分类任务中都非常重要;

  • 推荐值通常在 0.01 左右,需调参。


网站公告

今日签到

点亮在社区的每一天
去签到