一、引言
在机器学习的训练流程中,模型构建是核心环节之一。从传统机器学习的线性模型到深度学习的神经网络,模型的复杂度呈指数级增长。PyTorch 作为主流深度学习框架,通过nn.Module
类提供了统一的模型构建接口,使得复杂网络结构的定义与管理变得高效且规范。
二、三要素
2.1 网络层构建
深度学习模型的基础是各类网络层,常见类型包括:
- 卷积层:
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
,用于提取空间特征(如 LeNet 中的 Conv1/Conv2) - 池化层:
nn.MaxPool2d(kernel_size, stride=None)
,实现特征降维(如 LeNet 中的 Pool1/Pool2) - 全连接层:
nn.Linear(in_features, out_features)
,完成特征到输出的映射(如 LeNet 中的 Fc1/Fc2/Fc3) - 激活函数层:
nn.ReLU()
、nn.Sigmoid()
等,引入非线性(实际代码中常使用nn.functional
中的函数以减少参数)
2.2 网络拼接
以 LeNet 为例,其网络结构可拆解为:
输入`32x32x3`
→ Conv1(6核,5x5,步长1)
→ 输出`28x28x6`
→ MaxPool1(2x2,步长2)
→ 输出`14x14x6`
→ Conv2(16核,5x5,步长1)
→ 输出`10x10x16`
→ MaxPool2(2x2,步长2)
→ 输出`5x5x16`
→ 展平为400维
→ Fc1(400→120)
→ Fc2(120→84)
→ Fc3(84→10)
→ Softmax输出
2.3 权值初始化
合理的初始化可避免梯度消失/爆炸,常见方法:
- Xavier初始化(适用于sigmoid/tanh):
nn.init.xavier_uniform_(weight, gain=1.0)
,保证输入输出方差一致 - Kaiming初始化(适用于ReLU系列):
nn.init.kaiming_normal_(weight, mode='fan_in', nonlinearity='relu')
,考虑激活函数的非线性特性 - 均匀/正态分布:
nn.init.uniform_(weight, -a, a)
、nn.init.normal_(weight, mean=0, std=0.01)
三、nn.Module
3.1 两大要素
每个自定义模型需继承nn.Module
,并实现两大核心方法:
__init__()
:定义子模块(如self.conv1 = nn.Conv2d(...)
),初始化可学习参数forward()
:定义前向传播逻辑,调用子模块并组合运算(禁止直接修改输入张量的内存,需返回新张量)
代码示例:LeNet 模型定义
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 卷积层与池化层
self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,5x5卷积核
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool = nn.MaxPool2d(2, 2) # 2x2池化,步长2
# 全连接层
self.fc1 = nn.Linear(16*5*5, 120) # 5x5是池化后尺寸(10/2=5)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x): # x形状:(batch_size, 3, 32, 32)
x = self.pool(F.relu(self.conv1(x))) # (6, 28, 28) → (6, 14, 14)
x = self.pool(F.relu(self.conv2(x))) # (16, 10, 10) → (16, 5, 5)
x = x.view(-1, 16*5*5) # 展平为批量维度+特征维度
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x) # 输出logits,Softmax在损失函数中处理
return x
3.2 四大属性
nn.Module
通过8个OrderedDict
管理内部状态,核心属性包括:
parameters()
:迭代所有可学习参数(如weight、bias),用于优化器更新model = LeNet() for name, param in model.named_parameters(): print(f"参数名:{name},形状:{param.shape}")
modules()
:递归遍历所有子模块(包括自身),用于模型结构检查for name, module in model.named_modules(): print(f"模块名:{name},类型:{type(module)}")
buffers()
:存储非可学习状态(如 BN 层的running_mean
、running_var
)for name, buf in model.named_buffers(): print(f"缓冲区名:{name},形状:{buf.shape}")
hooks
:注册前向/反向钩子函数,用于获取中间层输出(调试或特征可视化)def hook_fn(module, input, output): print(f"模块{type(module).__name__}的输出形状:{output.shape}") handle = model.conv1.register_forward_hook(hook_fn) # 注册钩子 handle.remove() # 使用后移除避免内存泄漏
四、进阶技巧
4.1 层次化设计:子模块复用
复杂模型(如 ResNet)通过定义子模块(如残差块)提升代码复用性:
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return F.relu(out)
4.2 动态形状处理:避免硬编码尺寸
在forward
中通过x.shape
动态获取维度,提升模型通用性(如支持不同输入尺寸的图像)。
4.3 混合使用nn.Module
与nn.functional
- 推荐场景:
- 具有可学习参数的层(如 Conv、Linear、BN)使用
nn.Module
子类 - 无参数的运算(如激活函数、池化、展平)使用
nn.functional
函数(减少内存占用,更灵活)
- 具有可学习参数的层(如 Conv、Linear、BN)使用
五、注意事项
- 避免在
forward
中使用Python控制流:如需条件判断或循环,尽量使用PyTorch内置函数(如torch.where
),以保证模型可序列化和JIT编译 - 参数初始化的显式调用:在
__init__
中对自定义层进行初始化,避免使用默认初始化(如全零初始化可能导致对称性破缺) - 模型保存与加载:使用
torch.save(model.state_dict(), 'model.pth')
保存参数,加载时通过model.load_state_dict(torch.load('model.pth'))
恢复,保持模块命名一致性
微语录:不要因为走得太远,而忘记为什么出发。— — 卡里·纪伯伦