DAY 35 Model Visualization and Inference

Published: 2025-06-19

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load the iris dataset (4 features, 3 classes) and hold out 20% for testing.
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Scale features to [0, 1]; fit the scaler on the training split only to avoid leakage.
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Convert to tensors and move everything to the chosen device.
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)

class MLP(nn.Module):
    """A small 4-10-3 fully connected network for the 3-class iris task."""
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
num_epochs = 20000
losses = []
start_time = time.time()

for epoch in range(num_epochs):
    # Forward pass on the full training set (full-batch gradient descent).
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record the loss every 200 epochs for plotting.
    if (epoch + 1) % 200 == 0:
        losses.append(loss.item())

    # Print progress every 100 epochs.
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')

plt.plot([200 * (i + 1) for i in range(len(losses))], losses)  # x-axis in actual epochs
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

Using device: cpu
Epoch [100/20000], Loss: 1.0420
Epoch [200/20000], Loss: 0.9975
Epoch [300/20000], Loss: 0.9480
Epoch [400/20000], Loss: 0.8947
Epoch [500/20000], Loss: 0.8393
Epoch [600/20000], Loss: 0.7838
Epoch [700/20000], Loss: 0.7300
Epoch [800/20000], Loss: 0.6797
Epoch [900/20000], Loss: 0.6337
Epoch [1000/20000], Loss: 0.5927
...
Epoch [19800/20000], Loss: 0.0620
Epoch [19900/20000], Loss: 0.0619
Epoch [20000/20000], Loss: 0.0618
Training time: 8.60 seconds

[Figure: training loss curve over epochs]
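
Because the loss drops steeply at first and then flattens, a logarithmic y-axis can make the late-stage progress easier to read. A minimal variant of the plot above, reusing the losses list recorded every 200 epochs:

plt.plot([200 * (i + 1) for i in range(len(losses))], losses)
plt.yscale('log')  # a log scale spreads out the slowly decaying tail
plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.show()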

1. Three different ways to visualize a model: a torchinfo summary (recommended) plus weight-distribution plots
print(model)

MLP(
  (fc1): Linear(in_features=4, out_features=10, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=10, out_features=3, bias=True)
)
for name, param in model.named_parameters():
    print(f'Parameter name: {name}, Shape: {param.shape}')

Parameter name: fc1.weight, Shape: torch.Size([10, 4])
Parameter name: fc1.bias, Shape: torch.Size([10])
Parameter name: fc2.weight, Shape: torch.Size([3, 10])
Parameter name: fc2.bias, Shape: torch.Size([3])
import numpy as np

# Collect the weight matrices (skipping biases) as NumPy arrays for plotting.
weight_data = {}
for name, param in model.named_parameters():
    if 'weight' in name:
        weight_data[name] = param.detach().cpu().numpy()

fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')

for i, (name, weights) in enumerate(weight_data.items()):
    weights_flat = weights.flatten()
    axes[i].hist(weights_flat, bins=50, alpha=0.7)
    axes[i].set_title(name)
    axes[i].set_xlabel('Weight Value')
    axes[i].set_ylabel('Frequency')
    axes[i].grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()

print("\n=== 权重统计信息 ===")

for name, weights in weight_data.items():
    mean = np.mean(weights)
    std = np.std(weights)
    min_val = np.min(weights)
    max_val = np.max(weights)
    print(f"{name}:")
    print(f"  mean: {mean:.6f}")
    print(f"  std: {std:.6f}")
    print(f"  min: {min_val:.6f}")
    print(f"  max: {max_val:.6f}")
    print("-" * 30)

[Figure: weight-distribution histograms for fc1.weight and fc2.weight]

=== Weight statistics ===
fc1.weight:
  mean: 0.038066
  std: 0.929686
  min: -2.286270
  max: 2.450587
------------------------------
fc2.weight:
  mean: -0.023167
  std: 1.232054
  min: -3.803612
  max: 2.585007
------------------------------
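
To complement the per-layer histograms with a single global view, torch.nn.utils.parameters_to_vector can flatten every parameter into one tensor; a minimal sketch:

from torch.nn.utils import parameters_to_vector

# Flatten all weights and biases into a single 1-D tensor and summarize it.
all_params = parameters_to_vector(model.parameters()).detach().cpu()
print(f'Total parameters: {all_params.numel()}')
print(f'Global mean: {all_params.mean().item():.6f}, std: {all_params.std().item():.6f}')
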
from torchsummary import summary

# torchsummary is the older summary package; torchinfo (used below) is its
# actively maintained successor.
summary(model, input_size=(4,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                   [-1, 10]              50
              ReLU-2                   [-1, 10]               0
            Linear-3                    [-1, 3]              33
================================================================
Total params: 83
Trainable params: 83
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
from torchinfo import summary

# torchinfo: input_size here has no batch dimension, so the shapes below
# are reported for a single sample.
summary(model, input_size=(4,))

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MLP                                      [3]                       --
├─Linear: 1-1                            [10]                      50
├─ReLU: 1-2                              [10]                      --
├─Linear: 1-3                            [3]                       33
==========================================================================================
Total params: 83
Trainable params: 83
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
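
Note that torchinfo also accepts a leading batch dimension in input_size, in which case the reported output shapes and size estimates reflect a batched forward pass. A minimal sketch (the batch size 16 is an arbitrary choice):

from torchinfo import summary

# With a batch dimension the output shapes are reported as [16, 10], [16, 3], etc.
summary(model, input_size=(16, 4))
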
2. Progress bars: manual and automatic styles for cleaner console output
from tqdm import tqdm
import time

# Manual usage: create the bar with a known total, then call update() yourself.
with tqdm(total=10) as pbar:
    for i in range(10):
        time.sleep(0.5)
        pbar.update(1)
        
100%|██████████| 10/10 [00:05<00:00,  1.95it/s]
from tqdm import tqdm
import time

with tqdm(total=5, desc='Downloading files', unit='file') as pbar:
    for i in range(5):
        time.sleep(1)
        pbar.update(1)

Downloading files: 100%|██████████| 5/5 [00:05<00:00,  1.01s/file]
from tqdm import tqdm
import time

# Automatic usage: wrap the iterable and let tqdm handle the updates.
for i in tqdm(range(3), desc='Processing tasks', unit='epoch'):
    time.sleep(1)

Processing tasks: 100%|██████████| 3/3 [00:03<00:00,  1.01s/epoch]
from tqdm import tqdm
import time

total = 0
with tqdm(total=10, desc='Running sum') as pbar:
    for i in range(1, 11):
        time.sleep(0.3)
        total += i
        pbar.update(1)
        pbar.set_postfix({'running total': total})
        
Running sum: 100%|██████████| 10/10 [00:03<00:00,  3.27it/s, running total=55]
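
tqdm also ships trange, a shorthand for tqdm(range(...)) that keeps epoch loops compact; a minimal sketch:

from tqdm import trange
import time

# trange(n, ...) is equivalent to tqdm(range(n), ...).
for i in trange(3, desc='Processing', unit='step'):
    time.sleep(0.2)
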
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm 

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
num_epochs = 20000
losses = []
epochs = []
start_time = time.time()

with tqdm(total=num_epochs, desc='Training progress', unit='epoch') as pbar:
    for epoch in range(num_epochs):
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the loss every 200 epochs and show it on the progress bar.
        if (epoch + 1) % 200 == 0:
            losses.append(loss.item())
            epochs.append(epoch + 1)
            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        # Advance the bar in chunks of 1000 epochs to keep overhead low.
        if (epoch + 1) % 1000 == 0:
            pbar.update(1000)

    # Ensure the bar reaches 100% if num_epochs is not a multiple of 1000.
    if pbar.n < num_epochs:
        pbar.update(num_epochs - pbar.n)

time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')

Using device: cpu

Training progress: 100%|██████████| 20000/20000 [00:08<00:00, 2395.41epoch/s, Loss=0.0608]

Training time: 8.35 seconds
3. Writing inference code: evaluation mode
# Switch to evaluation mode (this matters for layers like dropout and
# batchnorm) and disable gradient tracking during inference.
model.eval()

with torch.no_grad():
    outputs = model(X_test)
    # The predicted class is the index of the largest logit.
    _, predicted = torch.max(outputs, 1)

    correct = (predicted == y_test).sum().item()
    accuracy = correct / y_test.size(0)
    print(f'Test set accuracy: {accuracy * 100:.2f}%')

Test set accuracy: 96.67%
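
In practice, inference usually starts from a saved checkpoint rather than a model still sitting in memory. A minimal sketch of the standard state_dict save/load pattern (the filename mlp_iris.pth is made up for illustration):

# Save only the learned parameters (the recommended PyTorch pattern).
torch.save(model.state_dict(), 'mlp_iris.pth')

# Rebuild the architecture, load the weights, and switch to eval mode.
loaded_model = MLP().to(device)
loaded_model.load_state_dict(torch.load('mlp_iris.pth', map_location=device))
loaded_model.eval()

with torch.no_grad():
    predicted = loaded_model(X_test).argmax(dim=1)
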
Homework: tweak the hyperparameters in the model definition and compare the results.
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm 

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}\n')

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)

class MLP_Original(nn.Module):
    def __init__(self):
        super(MLP_Original, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

class MLP_Larger(nn.Module):
    def __init__(self):
        super(MLP_Larger, self).__init__()
        self.fc1 = nn.Linear(4, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        return out

class MLP_Smaller(nn.Module):
    def __init__(self):
        super(MLP_Smaller, self).__init__()
        self.fc1 = nn.Linear(4, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

class MLP_Tanh(nn.Module):
    def __init__(self):
        super(MLP_Tanh, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.act = nn.Tanh()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.act(out)
        out = self.fc2(out)
        return out

def train_and_evaluate(model_class, optimizer_class, lr, num_epochs=20000):
    """Train one (model, optimizer, lr) configuration; return its loss curve and test accuracy."""
    model = model_class().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_class(model.parameters(), lr=lr)
    
    losses = []
    epochs = []
    start_time = time.time()
    
    with tqdm(total=num_epochs, desc=f'Training {model_class.__name__}', unit='epoch') as pbar:
        for epoch in range(num_epochs):
            outputs = model(X_train)
            loss = criterion(outputs, y_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 200 == 0:
                losses.append(loss.item())
                epochs.append(epoch + 1)
                pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

            if (epoch + 1) % 1000 == 0:
                pbar.update(1000)

        if pbar.n < num_epochs:
            pbar.update(num_epochs - pbar.n)

    time_all = time.time() - start_time
    
    # Evaluate in eval mode with gradients disabled, as in section 3.
    model.eval()
    with torch.no_grad():
        outputs = model(X_test)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == y_test).sum().item() / y_test.size(0)

    print(f'{model_class.__name__} training time: {time_all:.2f} s, test accuracy: {accuracy:.4f}\n')
    
    return epochs, losses, accuracy

configs = [
    (MLP_Original, optim.SGD, 0.01),
    (MLP_Larger, optim.SGD, 0.01),
    (MLP_Smaller, optim.SGD, 0.01),
    (MLP_Tanh, optim.SGD, 0.01),
    (MLP_Original, optim.Adam, 0.001),
    (MLP_Original, optim.SGD, 0.1),
    (MLP_Original, optim.SGD, 0.001)
]

plt.figure(figsize=(12, 8))
for config in configs:
    epochs, losses, accuracy = train_and_evaluate(*config)
    plt.plot(epochs, losses, label=f'{config[0].__name__} {config[1].__name__} lr={config[2]} (Acc:{accuracy:.2f})')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Comparison with Different Hyperparameters')
plt.legend()
plt.grid(True)
plt.show()

Using device: cpu

Training MLP_Original: 100%|██████████| 20000/20000 [00:08<00:00, 2347.88epoch/s, Loss=0.0629]
MLP_Original training time: 8.52 s, test accuracy: 0.9667

Training MLP_Larger: 100%|██████████| 20000/20000 [00:10<00:00, 1848.93epoch/s, Loss=0.0480]
MLP_Larger training time: 10.82 s, test accuracy: 1.0000

Training MLP_Smaller: 100%|██████████| 20000/20000 [00:08<00:00, 2366.75epoch/s, Loss=0.1377]
MLP_Smaller training time: 8.45 s, test accuracy: 0.9667

Training MLP_Tanh: 100%|██████████| 20000/20000 [00:08<00:00, 2326.77epoch/s, Loss=0.0646]
MLP_Tanh training time: 8.60 s, test accuracy: 0.9667

Training MLP_Original: 100%|██████████| 20000/20000 [00:13<00:00, 1468.79epoch/s, Loss=0.0466]
MLP_Original training time: 13.62 s, test accuracy: 1.0000

Training MLP_Original: 100%|██████████| 20000/20000 [00:08<00:00, 2334.48epoch/s, Loss=0.0468]
MLP_Original training time: 8.57 s, test accuracy: 1.0000

Training MLP_Original: 100%|██████████| 20000/20000 [00:08<00:00, 2334.72epoch/s, Loss=0.4256]
MLP_Original training time: 8.57 s, test accuracy: 0.9000

[Figure: training-loss comparison across hyperparameter configurations]
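
Summarizing the runs above:

Model          Optimizer  lr     Time (s)  Test accuracy
MLP_Original   SGD        0.01   8.52      0.9667
MLP_Larger     SGD        0.01   10.82     1.0000
MLP_Smaller    SGD        0.01   8.45      0.9667
MLP_Tanh       SGD        0.01   8.60      0.9667
MLP_Original   Adam       0.001  13.62     1.0000
MLP_Original   SGD        0.1    8.57      1.0000
MLP_Original   SGD        0.001  8.57      0.9000

On this small dataset the wider network, Adam, and the larger SGD learning rate all reach 100% test accuracy within the 20000-epoch budget, while SGD with lr=0.001 appears to converge too slowly and stops at 90%.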

@浙大疏锦行

