python打卡day35@浙大疏锦行

发布于:2025-05-25 ⋅ 阅读:(22) ⋅ 点赞:(0)

知识点回顾:

  1. 三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化
  2. 进度条功能:手动和自动写法,让打印结果更加美观
  3. 推理的写法:评估模式

作业:调整模型定义时的超参数,对比下效果。

 1. 模型可视化方法

①使用torchinfo打印summary

from torchinfo import summary

model = SimpleNN()
summary(model, input_size=(1, 784))  # 输入尺寸

②权重分布可视化

import matplotlib.pyplot as plt

# 可视化第一层权重
plt.hist(model.fc1.weight.data.numpy().flatten(), bins=50)
plt.title('FC1 Weight Distribution')
plt.show()

 2. 进度条功能

①手动写法

from tqdm import tqdm

for epoch in range(10):
    with tqdm(train_loader, desc=f'Epoch {epoch}') as pbar:
        for data, target in pbar:
            # ...训练代码...
            pbar.set_postfix(loss=loss.item())

②自动写法

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(10):
    for i, (data, target) in enumerate(train_loader):
        # ...训练代码...
        writer.add_scalar('Loss/train', loss.item(), epoch*len(train_loader)+i)

③推理写法

model.eval()  # 切换到评估模式
with torch.no_grad():  # 禁用梯度计算
    correct = 0
    for data, target in test_loader:
        output = model(data)
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f'Test Accuracy: {accuracy:.2%}')

 作业:超参数调整对比

1.模型定义示例(带超参数)

class SimpleNN(nn.Module):
    def __init__(self, hidden_size=128, dropout_rate=0.2):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

2.超参数对比实验

hidden_sizes = [64, 128, 256]
dropout_rates = [0.0, 0.2, 0.5]

for h_size in hidden_sizes:
    for d_rate in dropout_rates:
        model = SimpleNN(hidden_size=h_size, dropout_rate=d_rate)
        # ...训练和评估代码...
        print(f'Hidden: {h_size}, Dropout: {d_rate}, Acc: {accuracy:.2%}')

关键点:

1. 评估时一定要用 model.eval() 和 torch.no_grad()
2. 进度条推荐使用tqdm或tensorboard
3. 超参数调整要系统性地对比(如网格搜索)
4. 可视化有助于理解模型行为


网站公告

今日签到

点亮在社区的每一天
去签到