PyTorch模型设计入门:从零编写一个完整的__init__
函数
作为初学者,理解如何正确编写PyTorch模型的__init__
函数是构建神经网络的第一步。本文将通过一个的简单模型示例,介绍__init__
的核心要素,并解释每一行代码的作用。
1. 示例模型:用户-物品评分预测器
假设我们要构建一个预测用户对物品评分的简单模型,包含以下功能:
- 用户和物品的嵌入表示(Embedding)
- 全连接层(Linear)进行特征变换
- Dropout层防止过拟合
- 自定义权重初始化
2. 完整代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class RatingPredictor(nn.Module):
def __init__(self, num_users, num_items, embedding_dim=64, hidden_dim=128, dropout_rate=0.2):
"""
初始化评分预测模型
:param num_users: 用户数量
:param num_items: 物品数量
:param embedding_dim: 嵌入维度
:param hidden_dim: 隐藏层维度
:param dropout_rate: Dropout概率
"""
super(RatingPredictor, self).__init__() # 必须调用父类初始化
# === 1. 保存输入参数 ===
self.num_users = num_users
self.num_items = num_items
self.dropout_rate = dropout_rate
# === 2. 定义模型层 ===
# 用户和物品的嵌入层
self.user_embedding = nn.Embedding(num_users, embedding_dim)
self.item_embedding = nn.Embedding(num_items, embedding_dim)
# 全连接层
self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim) # 输入是用户和物品嵌入的拼接
self.fc2 = nn.Linear(hidden_dim, 1) # 输出1个评分值
# Dropout层
self.dropout = nn.Dropout(dropout_rate)
# === 3. 初始化参数 ===
self._init_weights()
def _init_weights(self):
"""自定义参数初始化"""
# 嵌入层使用正态分布初始化
nn.init.normal_(self.user_embedding.weight, mean=0.0, std=0.01)
nn.init.normal_(self.item_embedding.weight, mean=0.0, std=0.01)
# 全连接层使用Xavier初始化
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
# 偏置初始化为0
nn.init.zeros_(self.fc1.bias)
nn.init.zeros_(self.fc2.bias)
def forward(self, user_ids, item_ids):
"""前向传播"""
# 获取嵌入
u_embed = self.user_embed(user_ids)
i_embed = self.item_embed(item_ids)
# 特征拼接
features = torch.cat([u_embed, i_embed], dim=1)
# 特征转换
hidden = self.feature_net(features)
# 预测评分
ratings = self.rating_head(hidden).squeeze(-1)
return ratings
3. 逐行解析__init__
3.1 继承父类
super(RatingPredictor, self).__init__()
- 作用:确保PyTorch能正确管理模型的参数和状态。
- 必须性:漏掉会导致模型无法识别可训练参数!
3.2 保存输入参数
self.num_users = num_users
self.dropout_rate = dropout_rate
- 目的:将外部传入的配置(如用户数量、Dropout率)存储为模型属性,供后续使用。
3.3 定义模型层
(1) 嵌入层(Embedding)
self.user_embedding = nn.Embedding(num_users, embedding_dim)
- 功能:将用户ID(整数)映射为稠密向量(
embedding_dim
维)。 - 类比:类似字典查询,输入用户ID,返回对应的向量。
(2) 全连接层(Linear)
self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim)
- 输入:用户和物品嵌入的拼接(故维度是
embedding_dim * 2
)。 - 输出:隐藏层表示(
hidden_dim
维)。
(3) Dropout层
self.dropout = nn.Dropout(dropout_rate)
- 作用:随机丢弃部分神经元输出,防止过拟合(训练时生效,测试时自动关闭)。
3.4 参数初始化
nn.init.normal_(self.user_embedding.weight, mean=0.0, std=0.01)
nn.init.xavier_uniform_(self.fc1.weight)
- 嵌入层:用小标准差的正态分布初始化,避免初始值过大。
- 全连接层:用Xavier初始化,保持输入输出方差一致。
- 偏置:初始化为0。
4. 为什么这样设计?
组件 | 设计理由 |
---|---|
嵌入层 | 将离散ID转为连续向量,便于模型处理 |
全连接层 | 学习用户和物品嵌入的非线性交互 |
Dropout | 提高泛化能力,防止训练数据过拟合 |
自定义初始化 | 避免梯度消失/爆炸,加速收敛 |
5. 如何使用这个模型?
# 初始化模型
model = RatingPredictor(
num_users=1000,
num_items=500,
embedding_dim=64,
dropout_rate=0.2
)
# 模拟输入
user_ids = torch.tensor([1, 2, 3]) # 3个用户的ID
item_ids = torch.tensor([4, 5, 6]) # 3个物品的ID
# 前向传播(需实现forward方法)
predictions = model(user_ids, item_ids)
6. 初学者常见问题
Q1:可以不写super().__init__()
吗?
- 不行! 缺少它会导致模型无法注册参数,
model.parameters()
将为空。
Q2:为什么用nn.ModuleList
而不是Python列表?
- 如果有多层结构(如多个全连接层),应使用
nn.ModuleList
或nn.Sequential
,PyTorch才能识别其中的参数。
Q3:参数初始化是否必须?
- 不是必须,但好的初始化能加速训练。PyTorch的Linear层已有默认初始化。
7. 总结
一个完整的__init__
函数应包含:
- 继承父类:
super().__init__()
- 保存参数:将配置存储为属性
- 定义网络层:如Embedding、Linear、Dropout
- 初始化参数:控制模型训练的起点