In Part (1) we introduced the task the model needs to accomplish; this post covers the model's concrete architecture and the training process.
I am an undergraduate majoring in artificial intelligence, so my abilities are limited and the architecture is admittedly rough. Criticism and corrections are welcome; if you have better suggestions for the model, feel free to message me so we can learn from each other.
Project resources, Baidu Netdisk link: https://pan.baidu.com/s/1THHHOxzT3GBFtXDwY0kDlw?pwd=6vcf
1. Model Architecture
Let's go straight to the diagram; everything is laid out in it.
Next, a breakdown of each layer.
User Information Embedding Layer
It contains three embedding matrices:
a user-gender embedding matrix [2, 32], a user-age embedding matrix [7, 32], and a user-occupation embedding matrix [21, 32].
Each discrete feature value indexes the corresponding row of its embedding matrix, converting it into a continuous 32-dimensional vector.
The embedding matrices are trainable parameters.
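A minimal sketch of how such a lookup works (standalone, with illustrative values):

import torch
import torch.nn as nn

gender_embedding = nn.Embedding(2, 32)       # 2 rows, one per gender
age_embedding = nn.Embedding(7, 32)          # 7 rows, one per age bucket
occupation_embedding = nn.Embedding(21, 32)  # 21 rows, one per occupation

gender = torch.tensor([0, 1])          # a batch of two discrete gender values
gender_vec = gender_embedding(gender)  # row lookup -> [2, 32], trainable
print(gender_vec.shape)                # torch.Size([2, 32])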
User Information Fully-Connected Layer
The three (green) user feature vectors are concatenated (cat) into a 96-dimensional vector, which then passes through the fully-connected stack.
Internally this stack actually consists of two fully-connected layers:
Batch normalization, applied per batch, helps prevent vanishing or exploding gradients.
Different activation functions (Tanh, then ReLU) add nonlinearity to the model.
Note the dropout, added to reduce overfitting.
Text Convolution Layer
Text convolution extracts local features from text with sliding windows, capturing word combinations and word-order information. Kernels of different sizes extract diverse features, and pooling reduces the dimensionality to simplify the representation, improving both the model's expressiveness and its computational efficiency.
Here is a schematic of text convolution. Each row of the matrix is one word's vector. The red box is a convolution kernel applied to the matrix with stride 1, yielding one column of feature values (the second picture from the left appears to be missing one cell). Likewise, yellow marks another kernel; the different colors denote kernels of different sizes, and each size can have multiple kernels, num_filters of them (64 in the code).
The resulting feature vectors are reduced by max pooling and then passed through a fully-connected layer.
The above illustrates the basic principle of text convolution; the figure below shows the text-convolution structure in my model:
It starts from a 10×50 word-vector matrix.
Applying the 64 copies of kernel 1 (2×50) to it produces 64 feature vectors of shape 9×1; each passes through a max-pooling layer, yielding 64 feature values.
Kernels 2 and 3 follow the same process.
Finally, all of these feature values are concatenated into a movie-title feature vector of length 192.
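A minimal sketch of those shapes (standalone; it mirrors the TextCNN module shown later):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 10, 50)      # [batch, channel, seq_len, embed_dim]
conv1 = nn.Conv2d(1, 64, (2, 50))  # kernel 1: 2 x 50, with 64 filters
out = conv1(x)                     # -> [1, 64, 9, 1]: 64 feature vectors of 9 x 1
pooled = nn.AdaptiveMaxPool2d((1, 1))(out)  # -> [1, 64, 1, 1]: 64 feature values
# repeating this for kernel sizes 3 and 4 and concatenating gives 64 * 3 = 192 features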
Movie Title Fully-Connected Layer
This fully-connected layer is dedicated to the movie title and really belongs to the convolutional network, since a convolution is usually followed by a fully-connected layer for dimensionality reduction.
Not much to add; here is the figure.
Movie Genre Embedding Layer
This embedding layer is similar to the user-information ones, but there is one problem to solve.
User attributes such as gender and age are each a single discrete value, so we can use that value to index the corresponding row vector of the embedding matrix directly.
A movie, however, often has multiple genres; the example here,
for instance, has three genres.
Solution:
Each genre corresponds to a discrete index. Look up the corresponding vector in the embedding matrix for each index, then simply sum the resulting vectors. You could additionally average them to keep the magnitudes in check; the code does not implement that, but feel free to try adding it, as in the sketch below.
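A minimal sketch of the summation, plus the averaging variant suggested above (masking out padding index 0 is my assumption, matching padding_idx=0 in the model code below):

import torch
import torch.nn as nn

genre_embedding = nn.Embedding(18 + 1, 32, padding_idx=0)  # index 0 = padding, embeds to zeros
genres = torch.tensor([[7, 12, 3, 0, 0, 0]])               # genre indices, zero-padded

emb = genre_embedding(genres)  # [1, 6, 32]; padding positions contribute zero vectors
summed = emb.sum(dim=1)        # [1, 32], the sum used in the model code

# averaging variant: divide by the number of real (non-padding) genres
count = (genres != 0).sum(dim=1, keepdim=True).clamp(min=1)  # [1, 1]
averaged = summed / count      # [1, 32]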
Movie Information Fully-Connected Layer
Again, not much to add; the figure below is self-explanatory.
At this point we have a combined user representation (256) and a combined movie representation (256). These two vectors are ultimately what we need, but the model's parameters must be trained, and training needs a target. Our target is the user's rating of the movie, so let's consider how to turn the two representations into a rating.
Rating Prediction Layer
Initially I simply concatenated (cat) the user and movie representations and fed them into a fully-connected layer, but the results were poor.
I then added element-wise multiplication to capture interaction information, concatenated the three vectors, and fed them into the fully-connected layer; this lowered the model's loss further.
The result then passes through two fully-connected layers that progressively reduce the dimension, finally outputting the rating.
The predicted rating is compared against the true rating with an MSE loss to train the model.
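For reference, the loss minimized here is the mean squared error over a batch of $N$ samples:

$$\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}(\hat{y}_i - y_i)^2$$

where $\hat{y}_i$ is the predicted rating and $y_i$ the true rating.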
That is the model's basic architecture; the code follows:
2. Model Code (PyTorch)
import torch
import torch.nn as nn
# Text convolutional network
class TextCNN(nn.Module):
    def __init__(self, embed_dim=50, num_filters=64):
        super().__init__()
        # three kernel sizes (2, 3, 4 words tall), each spanning the full embedding width
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (kernel_size, embed_dim))
            for kernel_size in [2, 3, 4]
        ])
        # adaptive max pooling, to unify the output size
        self.pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((1, 1)) for _ in range(3)
        ])

    def forward(self, x):
        # x shape: [batch_size, 1, seq_len(10), embed_dim(50)]
        features = []
        for conv, pool in zip(self.convs, self.pools):
            conv_out = conv(x)                               # [batch, num_filters, H, 1]
            pooled = pool(conv_out)                          # [batch, num_filters, 1, 1]
            features.append(pooled.squeeze(-1).squeeze(-1))  # [batch, num_filters]
        combined = torch.cat(features, dim=1)                # [batch, num_filters * 3]
        return combined
class MovieRecommendationModel_Rating(nn.Module):
    def __init__(self):
        super().__init__()
        # ----------------- user feature branch -----------------
        # embedding layers
        self.user_gender_embedding = nn.Embedding(2, 32)
        self.user_age_embedding = nn.Embedding(7, 32)
        self.user_occupation_embedding = nn.Embedding(21, 32)
        # fully-connected stack (with BatchNorm and Dropout)
        self.user_fc = nn.Sequential(
            nn.Linear(32 * 3, 200),
            nn.BatchNorm1d(200),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.Linear(200, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        # ----------------- movie feature branch -----------------
        # genre embedding layer (18 genres + padding index 0)
        self.movie_genres_embedding = nn.Embedding(18 + 1, 32, padding_idx=0)
        # text convolution layer (the TextCNN module above)
        self.text_cnn = TextCNN(embed_dim=50, num_filters=64)
        self.movie_title_fc = nn.Sequential(
            nn.Linear(64 * 3, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        # movie fully-connected stack
        self.movie_fc = nn.Sequential(
            nn.Linear(32 + 32, 128),  # genre embedding + title features
            nn.BatchNorm1d(128),
            nn.Tanh(),
            nn.Dropout(0.3),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        # ----------------- rating prediction head -----------------
        self.rating_predictor = nn.Sequential(
            nn.Linear(256 * 3, 128),  # user features + movie features + interaction features
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 1)
        )
    def forward(self, user_gender, user_age, user_occupation,
                movie_genres_vector, movie_title_matrix):
        # ----------------- user features -----------------
        # embedding lookups
        gender_emb = self.user_gender_embedding(user_gender)              # [B, 32]
        age_emb = self.user_age_embedding(user_age)                       # [B, 32]
        occupation_emb = self.user_occupation_embedding(user_occupation)  # [B, 32]
        # concatenate and pass through the FC stack
        user_features = torch.cat([gender_emb, age_emb, occupation_emb], dim=1)
        user_features = self.user_fc(user_features)  # [B, 256]
        # ----------------- movie features -----------------
        # genre embedding (multi-genre: look up each genre index, 0 = padding, then sum)
        genre_emb = self.movie_genres_embedding(movie_genres_vector)  # [B, num_genres, 32]
        genre_emb = torch.sum(genre_emb, dim=1)                       # [B, 32]
        # text convolution over the title matrix
        title_feat = movie_title_matrix.unsqueeze(1)  # [B, 1, seq_len, 50]
        title_feat = self.text_cnn(title_feat)        # [B, 64*3]
        title_feat = self.movie_title_fc(title_feat)  # [B, 32]
        # concatenate and pass through the FC stack
        movie_features = torch.cat([genre_emb, title_feat], dim=1)
        movie_features = self.movie_fc(movie_features)  # [B, 256]
        # ----------------- feature interaction and rating prediction -----------------
        # interaction features: user features * movie features (element-wise)
        interaction = user_features * movie_features  # [B, 256]
        # concatenate all features
        combined = torch.cat([
            user_features,
            movie_features,
            interaction
        ], dim=1)  # [B, 256*3]
        # predict the rating; squeeze only the last dim so a batch of 1 still returns [B]
        rating = self.rating_predictor(combined).squeeze(-1)  # [B]
        return rating
'''
# Usage example
# initialize the model
model = MovieRecommendationModel_Rating()
model.eval()
# sample inputs
user_gender = torch.tensor([0])
user_age = torch.tensor([3])
user_occupation = torch.tensor([10])
movie_genres_vector = torch.tensor(
    [[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
movie_title_matrix = torch.tensor([[[ 3.1474e-01, 4.1662e-01, 1.3480e-01, 1.5854e-01, 8.8812e-01,
4.3317e-01, -5.5916e-01, 3.0476e-02, -1.4623e-01, -1.4273e-01,
-1.7949e-01, -1.7343e-01, -4.9264e-01, 2.6775e-01, 4.8799e-01,
-2.9537e-01, 1.8485e-01, 1.4937e-01, -7.5009e-01, -3.5651e-01,
-2.3699e-01, 1.8490e-01, 1.7237e-01, 2.3611e-01, 1.4077e-01,
-1.9031e+00, -6.5353e-01, -2.2539e-02, 1.0383e-01, -4.3705e-01,
3.7810e+00, -4.4077e-02, -4.6643e-02, 2.7274e-02, 5.1883e-01,
1.3353e-01, 2.3231e-01, 2.5599e-01, 6.0888e-02, -6.5618e-02,
-1.5556e-01, 3.0818e-01, -9.3586e-02, 3.3296e-01, -1.4613e-01,
1.6332e-02, -2.4251e-01, -2.0526e-01, 7.0090e-02, -1.1568e-01],
[ 9.0684e-01, -4.4680e-02, 4.0558e-01, -4.0515e-01, -2.0227e-02,
-5.5786e-01, -1.2427e+00, 1.4452e-01, -1.0421e-01, -9.5475e-01,
1.5414e-01, -2.5213e-01, -3.1168e-01, 8.3801e-01, -4.4238e-01,
-3.7311e-03, -9.3742e-01, 2.5382e-01, -1.6882e+00, -6.4045e-02,
-4.0331e-02, 1.0587e+00, 1.7816e-01, -7.2627e-01, 2.1079e-01,
-1.4382e+00, 8.1284e-01, 6.3073e-01, -1.3734e-01, -4.5729e-02,
1.2641e+00, 4.0018e-01, -5.1226e-01, 2.4990e-01, 3.7707e-01,
-2.2095e-01, 2.6125e-01, 1.8248e-01, -7.7887e-01, 5.9554e-01,
-9.2233e-02, 5.9457e-01, 4.9302e-01, -8.4481e-01, 7.2107e-02,
2.1698e-01, -4.1040e-01, -3.8655e-01, -8.4270e-01, -2.8183e-01],
[ 1.2972e-01, 8.8073e-02, 2.4375e-01, 7.8102e-02, -1.2783e-01,
2.7831e-01, -4.8693e-01, 1.9649e-01, -3.9558e-01, -2.8362e-01,
-4.7425e-01, -5.9317e-01, -5.8804e-01, -3.1702e-01, 4.9593e-01,
8.7594e-03, 3.9613e-02, -4.2495e-01, -9.7641e-01, -4.6534e-01,
2.0675e-02, 8.6042e-02, 3.9317e-01, -5.1255e-01, -1.7913e-01,
-1.8333e+00, 5.6220e-01, 4.1626e-01, 7.5127e-02, 2.1890e-02,
3.7840e+00, 7.1067e-01, -7.3943e-02, 1.5373e-01, -3.8530e-01,
-7.0163e-02, -3.5374e-01, 7.4501e-02, -8.4228e-02, -4.5548e-01,
-8.1068e-02, 3.9157e-01, 1.7300e-01, 2.2540e-01, -1.2836e-01,
4.0951e-01, -2.6079e-01, 9.0912e-02, -6.0515e-01, -9.8270e-01],
[ 4.1800e-01, 2.4968e-01, -4.1242e-01, 1.2170e-01, 3.4527e-01,
-4.4457e-02, -4.9688e-01, -1.7862e-01, -6.6023e-04, -6.5660e-01,
2.7843e-01, -1.4767e-01, -5.5677e-01, 1.4658e-01, -9.5095e-03,
1.1658e-02, 1.0204e-01, -1.2792e-01, -8.4430e-01, -1.2181e-01,
-1.6801e-02, -3.3279e-01, -1.5520e-01, -2.3131e-01, -1.9181e-01,
-1.8823e+00, -7.6746e-01, 9.9051e-02, -4.2125e-01, -1.9526e-01,
4.0071e+00, -1.8594e-01, -5.2287e-01, -3.1681e-01, 5.9213e-04,
7.4449e-03, 1.7778e-01, -1.5897e-01, 1.2041e-02, -5.4223e-02,
-2.9871e-01, -1.5749e-01, -3.4758e-01, -4.5637e-02, -4.4251e-01,
1.8785e-01, 2.7849e-03, -1.8411e-01, -1.1514e-01, -7.8581e-01],
[ 1.2304e+00, 9.2381e-01, 1.5314e-01, -5.8599e-01, 5.9628e-01,
7.6089e-01, -1.0720e+00, -5.0380e-01, 2.4693e-01, -1.0284e+00,
1.2579e-02, 4.3244e-01, 8.5405e-01, 8.6211e-01, -3.7902e-01,
9.6824e-01, 8.2349e-01, 1.7716e-01, -1.4482e+00, 8.9373e-02,
-5.7099e-01, -6.4416e-01, 3.3518e-01, -2.3957e-01, -2.2411e-01,
-5.4009e-02, 1.5806e-01, 6.6506e-01, 9.1831e-01, -4.6689e-01,
9.8718e-01, 4.4871e-01, -2.1724e-01, 3.7540e-02, -5.8915e-01,
4.4585e-01, 2.9472e-01, -1.7280e-01, 1.4035e-01, -6.6766e-01,
-1.4425e-02, -1.1219e+00, -2.2225e-01, 1.2725e+00, 1.0586e+00,
              -2.1209e-01, 4.0087e-02, -9.5260e-01, -3.0329e-01, -3.9670e-02]]])
# the remaining 5 rows of the 10 x 50 title matrix are zero padding; append them programmatically
movie_title_matrix = torch.cat([movie_title_matrix, torch.zeros(1, 5, 50)], dim=1)
# Forward pass
rating = model(user_gender, user_age, user_occupation, movie_genres_vector, movie_title_matrix)
print(f"Predicted Rating: {rating.tolist()}")
'''
Takeaways: increasing num_filters in the text convolution matters a lot; it alone lowered the model loss by 0.2.
Normalization is also important, and so are the interaction (cross) features in the final rating prediction.
My abilities are genuinely limited, and this is the best result I could tune: a training loss of about 0.9 after 5 epochs. If you have better ideas, feel free to message me.
3. Model Training Code
import torch
from data_process_load import MovieLensDataset
from model_1 import MovieRecommendationModel_Rating
from torch.utils.data import DataLoader, random_split
from torch import optim
import torch.nn as nn
# file paths
ratings_file = 'data/ml-1m/ratings.dat'
movies_file = 'data/ml-1m/movies.dat'
users_file = 'data/ml-1m/users.dat'
glove_file = 'data/glove.6B.50d.txt'
print(f"data reading ...")
dataset = MovieLensDataset(ratings_file, movies_file, users_file, glove_file)
# split into training and test sets
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# DataLoader
batch_size = 64
print(f"train_data loading ...")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f"test_data loading ...")
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# initialize the model
model = MovieRecommendationModel_Rating()
# optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.0015)
criterion = nn.MSELoss()  # mean squared error loss
def train(model, train_loader, optimizer, criterion, device, print_interval=64):
    model.train()
    running_loss = 0.0
    batch_count = 0  # number of batches processed so far
    for i, data in enumerate(train_loader):
        user_gender = data["user_gender"]
        user_age = data["user_age"]
        user_occupation = data["user_occupation"]
        movie_genres_vector = data["movie_genres_vector"]
        movie_title_matrix = data["movie_title_matrix"]
        ratings = data["rating"]
        # move the data to the device
        user_gender, user_age, user_occupation = user_gender.to(device), user_age.to(device), user_occupation.to(device)
        movie_genres_vector, movie_title_matrix = movie_genres_vector.to(device), movie_title_matrix.to(device)
        ratings = ratings.to(device)
        # zero the gradients
        optimizer.zero_grad()
        # forward pass
        outputs = model(user_gender, user_age, user_occupation, movie_genres_vector, movie_title_matrix)
        # compute the loss
        loss = criterion(outputs.squeeze(), ratings)
        # backward pass
        loss.backward()
        # update the parameters
        optimizer.step()
        running_loss += loss.item()
        batch_count += 1
        # print the running average loss every print_interval (64) batches
        if batch_count % print_interval == 0:
            avg_loss = running_loss / batch_count
            print(f"Batch [{batch_count}], Loss: {avg_loss:.4f}")
    return running_loss / len(train_loader)
# training
import time

start_time = time.time()
print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch: {epoch}")
    train_loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}")
    # save the model parameters
    torch.save(model.state_dict(), 'model_params.pth')
end_time = time.time()
elapsed_time = end_time - start_time  # total runtime
print(f"Total runtime: {elapsed_time:.6f} s")
The final model loss is around 0.9. Training takes fairly long, about 1h 15min; adjust the batch size to fit your GPU memory to speed it up.
You can also use my trained parameters directly; the .pth file is in the Baidu Netdisk link at the top of this article.
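One gap worth noting: the script builds test_loader but never evaluates on it. Below is a minimal evaluation sketch of my own, assuming the same batch dictionary keys as in train:

def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():  # no gradients needed during evaluation
        for data in test_loader:
            user_gender = data["user_gender"].to(device)
            user_age = data["user_age"].to(device)
            user_occupation = data["user_occupation"].to(device)
            movie_genres_vector = data["movie_genres_vector"].to(device)
            movie_title_matrix = data["movie_title_matrix"].to(device)
            ratings = data["rating"].to(device)
            outputs = model(user_gender, user_age, user_occupation,
                            movie_genres_vector, movie_title_matrix)
            total_loss += criterion(outputs.squeeze(), ratings).item()
    return total_loss / len(test_loader)

# test_loss = evaluate(model, test_loader, criterion, device)
# print(f"Test Loss: {test_loss:.4f}")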
This post, Part (2), covered the model's architecture and parameter layout as well as the training code.
For the other parts, please visit my profile; if you run into any problems, feel free to message me.
I'm a broke college student, so if you found this useful, tips are welcome: