DAY07:【pytorch】模型构建

发布于:2025-04-17 ⋅ 阅读:(32) ⋅ 点赞:(0)

一、引言

在机器学习的训练流程中,模型构建是核心环节之一。从传统机器学习的线性模型到深度学习的神经网络,模型的复杂度呈指数级增长。PyTorch 作为主流深度学习框架,通过nn.Module类提供了统一的模型构建接口,使得复杂网络结构的定义与管理变得高效且规范。

二、三要素

2.1 网络层构建

深度学习模型的基础是各类网络层,常见类型包括:

  • 卷积层nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0),用于提取空间特征(如 LeNet 中的 Conv1/Conv2)
  • 池化层nn.MaxPool2d(kernel_size, stride=None),实现特征降维(如 LeNet 中的 Pool1/Pool2)
  • 全连接层nn.Linear(in_features, out_features),完成特征到输出的映射(如 LeNet 中的 Fc1/Fc2/Fc3)
  • 激活函数层nn.ReLU()nn.Sigmoid()等,引入非线性(实际代码中常使用nn.functional中的函数以减少参数)

2.2 网络拼接

以 LeNet 为例,其网络结构可拆解为:

输入`32x32x3`
→ Conv1(6核,5x5,步长1)
→ 输出`28x28x6`
→ MaxPool1(2x2,步长2)
→ 输出`14x14x6`
→ Conv2(16核,5x5,步长1)
→ 输出`10x10x16`
→ MaxPool2(2x2,步长2)
→ 输出`5x5x16`
→ 展平为400维
→ Fc1(400→120)
→ Fc2(120→84)
→ Fc3(84→10)
→ Softmax输出

2.3 权值初始化

合理的初始化可避免梯度消失/爆炸,常见方法:

  • Xavier初始化(适用于sigmoid/tanh):
    nn.init.xavier_uniform_(weight, gain=1.0),保证输入输出方差一致
  • Kaiming初始化(适用于ReLU系列):
    nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu'),考虑激活函数的非线性特性
  • 均匀/正态分布
    nn.init.uniform_(weight, -a, a)nn.init.normal_(weight, mean=0, std=0.01)

三、nn.Module

3.1 两大要素

每个自定义模型需继承nn.Module,并实现两大核心方法:

  1. __init__():定义子模块(如self.conv1 = nn.Conv2d(...)),初始化可学习参数
  2. forward():定义前向传播逻辑,调用子模块并组合运算(禁止直接修改输入张量的内存,需返回新张量)

代码示例:LeNet 模型定义

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 卷积层与池化层
        self.conv1 = nn.Conv2d(3, 6, 5)  # 输入3通道,输出6通道,5x5卷积核
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)  # 2x2池化,步长2
        # 全连接层
        self.fc1 = nn.Linear(16*5*5, 120)  # 5x5是池化后尺寸(10/2=5)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):  # x形状:(batch_size, 3, 32, 32)
        x = self.pool(F.relu(self.conv1(x)))  # (6, 28, 28) → (6, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (16, 10, 10) → (16, 5, 5)
        x = x.view(-1, 16*5*5)  # 展平为批量维度+特征维度
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # 输出logits,Softmax在损失函数中处理
        return x

3.2 四大属性

nn.Module通过8个OrderedDict管理内部状态,核心属性包括:

  1. parameters():迭代所有可学习参数(如weight、bias),用于优化器更新

    model = LeNet()
    for name, param in model.named_parameters():
        print(f"参数名:{name},形状:{param.shape}")
    
  2. modules():递归遍历所有子模块(包括自身),用于模型结构检查

    for name, module in model.named_modules():
        print(f"模块名:{name},类型:{type(module)}")
    
  3. buffers():存储非可学习状态(如 BN 层的running_meanrunning_var

    for name, buf in model.named_buffers():
        print(f"缓冲区名:{name},形状:{buf.shape}")
    
  4. hooks:注册前向/反向钩子函数,用于获取中间层输出(调试或特征可视化)

    def hook_fn(module, input, output):
        print(f"模块{type(module).__name__}的输出形状:{output.shape}")
    handle = model.conv1.register_forward_hook(hook_fn)  # 注册钩子
    handle.remove()  # 使用后移除避免内存泄漏
    

四、进阶技巧

4.1 层次化设计:子模块复用

复杂模型(如 ResNet)通过定义子模块(如残差块)提升代码复用性:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

4.2 动态形状处理:避免硬编码尺寸

forward中通过x.shape动态获取维度,提升模型通用性(如支持不同输入尺寸的图像)。

4.3 混合使用nn.Modulenn.functional

  • 推荐场景
    • 具有可学习参数的层(如 Conv、Linear、BN)使用nn.Module子类
    • 无参数的运算(如激活函数、池化、展平)使用nn.functional函数(减少内存占用,更灵活)

五、注意事项

  1. 避免在forward中使用Python控制流:如需条件判断或循环,尽量使用PyTorch内置函数(如torch.where),以保证模型可序列化和JIT编译
  2. 参数初始化的显式调用:在__init__中对自定义层进行初始化,避免使用默认初始化(如全零初始化可能导致对称性破缺)
  3. 模型保存与加载:使用torch.save(model.state_dict(), 'model.pth')保存参数,加载时通过model.load_state_dict(torch.load('model.pth'))恢复,保持模块命名一致性

微语录:不要因为走得太远,而忘记为什么出发。— — 卡里·纪伯伦


网站公告

今日签到

点亮在社区的每一天
去签到