【基于ALS模型的教育视频推荐系统(Java实现)】
下面是一个基于交替最小二乘法(ALS)的教育视频推荐系统示例实现,包含数据模型、模型训练、推荐生成和评估模块(示例数据为随机生成)。
1. 系统架构
edu-recommender/
├── src/
│ ├── main/
│ │ ├── java/
│ │ │ ├── model/ # 数据模型
│ │ │ ├── algorithm/ # ALS算法实现
│ │ │ ├── service/ # 业务逻辑
│ │ │ ├── util/ # 工具类
│ │ │ └── Main.java # 入口类
│ │ └── resources/ # 配置文件
├── pom.xml # Maven依赖
└── data/ # 示例数据集
2. 核心代码实现
2.1 数据模型类
// Video.java
package model;
/**
 * An educational video in the catalogue.
 *
 * <p>Other components of this listing use {@code id} directly as an array/list index, so ids are
 * expected to run contiguously from 0.
 */
public class Video {
    private int id;
    private String title;
    private String category;
    private double duration; // length in minutes

    /**
     * Creates a video.
     *
     * @param id       video id (also used as an index by the recommender)
     * @param title    display title
     * @param category subject category, used by content-based recommendations
     * @param duration length in minutes
     */
    public Video(int id, String title, String category, double duration) {
        this.id = id;
        this.title = title;
        this.category = category;
        this.duration = duration;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getTitle() {
        return title;
    }

    public void setTitle(String title) {
        this.title = title;
    }

    public String getCategory() {
        return category;
    }

    public void setCategory(String category) {
        this.category = category;
    }

    /** Returns the video length in minutes. */
    public double getDuration() {
        return duration;
    }

    /** Sets the video length in minutes. */
    public void setDuration(double duration) {
        this.duration = duration;
    }
}
// User.java
package model;
/**
 * A learner who rates videos.
 *
 * <p>Other components of this listing use {@code id} directly as an array/list index, so ids are
 * expected to run contiguously from 0.
 */
public class User {
    private int id;
    private String username;
    private String educationLevel; // free-form label; the sample data uses Chinese degree names

    /**
     * Creates a user.
     *
     * @param id             user id (also used as an index by the recommender)
     * @param username       display name
     * @param educationLevel education level label, used for cold-start category selection
     */
    public User(int id, String username, String educationLevel) {
        this.id = id;
        this.username = username;
        this.educationLevel = educationLevel;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getEducationLevel() {
        return educationLevel;
    }

    public void setEducationLevel(String educationLevel) {
        this.educationLevel = educationLevel;
    }
}
// Rating.java
package model;
/**
 * One explicit rating of a video by a user.
 */
public class Rating {
    private int userId;
    private int videoId;
    private double score; // expected on a 1-5 scale

    /**
     * Creates a rating.
     *
     * @param userId  id of the rating user
     * @param videoId id of the rated video
     * @param score   rating value, expected on a 1-5 scale (not validated here)
     */
    public Rating(int userId, int videoId, double score) {
        this.userId = userId;
        this.videoId = videoId;
        this.score = score;
    }

    public int getUserId() {
        return userId;
    }

    public void setUserId(int userId) {
        this.userId = userId;
    }

    public int getVideoId() {
        return videoId;
    }

    public void setVideoId(int videoId) {
        this.videoId = videoId;
    }

    public double getScore() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }
}
2.2 ALS算法实现
package algorithm;
import java.util.*;
import model.Rating;
/**
 * Matrix-factorization recommender trained with Alternating Least Squares (ALS).
 *
 * <p>Every user and every video gets a latent vector of length {@code numFeatures}; a rating is
 * predicted as the dot product of the two vectors. Each training sweep first re-solves every user
 * vector with the video vectors held fixed, then every video vector with the user vectors held
 * fixed. Each solve is the regularized normal equation
 * {@code (sum(v * v^T) + lambda * I) x = sum(r * v)} over that row's ratings.
 *
 * <p>User and video ids are used directly as row indices, so they must lie in
 * {@code [0, numUsers)} and {@code [0, numItems)} as passed to {@link #train}.
 * Instances are mutable ({@link #train} replaces the model state) and perform no synchronization.
 */
public class ALS {
    /** Predictions are clamped into [MIN_SCORE, MAX_SCORE], the 1-5 rating scale. */
    private static final double MIN_SCORE = 1.0;
    private static final double MAX_SCORE = 5.0;
    /** Pivots with a smaller magnitude are treated as zero, i.e. the system is singular. */
    private static final double PIVOT_EPSILON = 1e-12;

    private final int numFeatures; // latent dimension
    private final double lambda; // L2 regularization strength
    private final int maxIter; // number of alternating sweeps
    private final boolean seeded; // true if random initialization must be reproducible
    private final long seed; // initialization seed, only used when seeded
    private double[][] userFeatures; // numUsers x numFeatures; null until train() runs
    private double[][] itemFeatures; // numItems x numFeatures; null until train() runs
    /** Copy of the ratings from the last train() call, available to incremental-update methods. */
    private final List<Rating> ratings = new ArrayList<>();

    /**
     * Creates an untrained model whose random initialization differs on every {@link #train} call.
     *
     * @param numFeatures latent dimension, must be positive
     * @param lambda      regularization strength, must be non-negative; use a positive value so
     *                    users/videos with fewer ratings than features still get a solvable system
     * @param maxIter     number of alternating sweeps, must be non-negative
     * @throws IllegalArgumentException if an argument is out of range
     */
    public ALS(int numFeatures, double lambda, int maxIter) {
        this(numFeatures, lambda, maxIter, false, 0L);
    }

    /**
     * Creates an untrained model whose random initialization is derived from {@code seed}, so
     * repeated {@link #train} calls on the same data give the same result.
     *
     * @param numFeatures latent dimension, must be positive
     * @param lambda      regularization strength, must be non-negative
     * @param maxIter     number of alternating sweeps, must be non-negative
     * @param seed        seed for the initial random feature values
     * @throws IllegalArgumentException if an argument is out of range
     */
    public ALS(int numFeatures, double lambda, int maxIter, long seed) {
        this(numFeatures, lambda, maxIter, true, seed);
    }

    private ALS(int numFeatures, double lambda, int maxIter, boolean seeded, long seed) {
        if (numFeatures <= 0) {
            throw new IllegalArgumentException("numFeatures must be positive: " + numFeatures);
        }
        if (!(lambda >= 0.0)) { // written this way so NaN is rejected as well
            throw new IllegalArgumentException("lambda must be non-negative: " + lambda);
        }
        if (maxIter < 0) {
            throw new IllegalArgumentException("maxIter must be non-negative: " + maxIter);
        }
        this.numFeatures = numFeatures;
        this.lambda = lambda;
        this.maxIter = maxIter;
        this.seeded = seeded;
        this.seed = seed;
    }

    /**
     * Trains the model from scratch, printing the training RMSE after every sweep.
     *
     * @param ratings  training ratings; ids must lie inside the matrix dimensions below
     * @param numUsers number of user rows to allocate
     * @param numItems number of video rows to allocate
     * @throws IllegalArgumentException if a dimension is negative or a rating id is out of range
     * @throws ArithmeticException      if a normal-equation system is singular (only possible
     *                                  with {@code lambda == 0})
     */
    public void train(List<Rating> ratings, int numUsers, int numItems) {
        Objects.requireNonNull(ratings, "ratings");
        if (numUsers < 0 || numItems < 0) {
            throw new IllegalArgumentException(
                    "dimensions must be non-negative: numUsers=" + numUsers + ", numItems=" + numItems);
        }
        // Validate up front so a bad id fails with a clear message instead of an index error mid-sweep.
        for (Rating r : ratings) {
            if (r.getUserId() < 0 || r.getUserId() >= numUsers
                    || r.getVideoId() < 0 || r.getVideoId() >= numItems) {
                throw new IllegalArgumentException("rating out of range: userId=" + r.getUserId()
                        + ", videoId=" + r.getVideoId());
            }
        }

        Random rand = seeded ? new Random(seed) : new Random();
        userFeatures = randomMatrix(numUsers, rand);
        itemFeatures = randomMatrix(numItems, rand);
        this.ratings.clear();
        this.ratings.addAll(ratings);

        // The grouping depends only on the ratings, so compute it once rather than every sweep.
        Map<Integer, List<Rating>> byUser = groupRatings(ratings, true);
        Map<Integer, List<Rating>> byItem = groupRatings(ratings, false);
        for (int iter = 0; iter < maxIter; iter++) {
            // Fix video vectors, re-solve users; then fix user vectors, re-solve videos.
            updateFeatures(byUser, userFeatures, itemFeatures, true);
            updateFeatures(byItem, itemFeatures, userFeatures, false);
            double error = calculateRMSE(ratings);
            System.out.printf("Iteration %d, RMSE: %.4f\n", iter, error);
        }
    }

    /** Returns a {@code rows x numFeatures} matrix with entries drawn uniformly from [0, 1). */
    private double[][] randomMatrix(int rows, Random rand) {
        double[][] matrix = new double[rows][numFeatures];
        for (double[] row : matrix) {
            for (int j = 0; j < numFeatures; j++) {
                row[j] = rand.nextDouble();
            }
        }
        return matrix;
    }

    /**
     * Re-solves every row that has ratings in {@code groupedRatings}, holding {@code fixedFeatures}
     * constant. Rows without ratings keep their current values.
     *
     * @param groupedRatings ratings grouped by the id of the row being solved
     * @param mainFeatures   matrix whose rows are overwritten
     * @param fixedFeatures  matrix providing the "other side" vectors
     * @param isUser         true if rows of {@code mainFeatures} are users (other side = videos)
     */
    private void updateFeatures(Map<Integer, List<Rating>> groupedRatings, double[][] mainFeatures,
                                double[][] fixedFeatures, boolean isUser) {
        for (Map.Entry<Integer, List<Rating>> entry : groupedRatings.entrySet()) {
            double[][] a = new double[numFeatures][numFeatures];
            double[] b = new double[numFeatures];
            for (Rating r : entry.getValue()) {
                double[] other = fixedFeatures[isUser ? r.getVideoId() : r.getUserId()];
                // a += other * other^T ; b += score * other
                for (int i = 0; i < numFeatures; i++) {
                    for (int j = 0; j < numFeatures; j++) {
                        a[i][j] += other[i] * other[j];
                    }
                    b[i] += other[i] * r.getScore();
                }
            }
            // Ridge term: a += lambda * I
            for (int i = 0; i < numFeatures; i++) {
                a[i][i] += lambda;
            }
            double[] solved = solveLinearSystem(a, b);
            System.arraycopy(solved, 0, mainFeatures[entry.getKey()], 0, numFeatures);
        }
    }

    /** Groups ratings by user id ({@code byUser == true}) or by video id. */
    private Map<Integer, List<Rating>> groupRatings(List<Rating> ratings, boolean byUser) {
        Map<Integer, List<Rating>> map = new HashMap<>();
        for (Rating r : ratings) {
            int key = byUser ? r.getUserId() : r.getVideoId();
            map.computeIfAbsent(key, k -> new ArrayList<>()).add(r);
        }
        return map;
    }

    /**
     * Solves {@code A x = b} by Gaussian elimination with partial pivoting.
     * Both {@code A} and {@code b} are overwritten.
     *
     * @throws ArithmeticException if no usable pivot exists (singular system)
     */
    private double[] solveLinearSystem(double[][] A, double[] b) {
        int n = b.length;
        double[] x = new double[n];
        // Forward elimination
        for (int i = 0; i < n; i++) {
            int maxRow = i;
            for (int k = i + 1; k < n; k++) {
                if (Math.abs(A[k][i]) > Math.abs(A[maxRow][i])) {
                    maxRow = k;
                }
            }
            // Dividing by a zero pivot would silently fill the model with NaN/Infinity.
            if (Math.abs(A[maxRow][i]) < PIVOT_EPSILON) {
                throw new ArithmeticException("singular system in ALS solve (column " + i
                        + "); use lambda > 0");
            }
            double[] tempRow = A[i];
            A[i] = A[maxRow];
            A[maxRow] = tempRow;
            double tempVal = b[i];
            b[i] = b[maxRow];
            b[maxRow] = tempVal;
            for (int k = i + 1; k < n; k++) {
                double factor = A[k][i] / A[i][i];
                b[k] -= factor * b[i];
                for (int j = i; j < n; j++) {
                    A[k][j] -= factor * A[i][j];
                }
            }
        }
        // Back substitution
        for (int i = n - 1; i >= 0; i--) {
            double sum = 0;
            for (int j = i + 1; j < n; j++) {
                sum += A[i][j] * x[j];
            }
            x[i] = (b[i] - sum) / A[i][i];
        }
        return x;
    }

    /**
     * Root-mean-square error of the clamped predictions over {@code ratings}.
     * Returns NaN for an empty list (0 / 0).
     *
     * @throws IllegalStateException if the model has not been trained and the list is non-empty
     */
    public double calculateRMSE(List<Rating> ratings) {
        double sumSquaredError = 0;
        for (Rating r : ratings) {
            double predicted = predict(r.getUserId(), r.getVideoId());
            sumSquaredError += Math.pow(predicted - r.getScore(), 2);
        }
        return Math.sqrt(sumSquaredError / ratings.size());
    }

    /**
     * Predicted rating, clamped to the 1-5 scale.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    public double predict(int userId, int videoId) {
        return Math.max(MIN_SCORE, Math.min(MAX_SCORE, rawScore(userId, videoId)));
    }

    /** Unclamped dot product of the user and video latent vectors. */
    private double rawScore(int userId, int videoId) {
        ensureTrained();
        double[] u = userFeatures[userId];
        double[] v = itemFeatures[videoId];
        double score = 0;
        for (int i = 0; i < numFeatures; i++) {
            score += u[i] * v[i];
        }
        return score;
    }

    /**
     * Returns up to {@code numRecommendations} video ids from {@code [0, numVideos)}, best first.
     *
     * <p>Ranking uses the unclamped score: clamping to [1, 5] would turn every prediction above 5
     * (or below 1) into a tie and make the order among them arbitrary. Videos the user has
     * already rated are not excluded.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    public List<Integer> recommendVideos(int userId, int numRecommendations, int numVideos) {
        ensureTrained();
        // Min-heap holding the best candidates so far; the weakest is evicted when it overflows.
        PriorityQueue<VideoScore> heap = new PriorityQueue<>();
        for (int videoId = 0; videoId < numVideos; videoId++) {
            heap.offer(new VideoScore(videoId, rawScore(userId, videoId)));
            if (heap.size() > numRecommendations) {
                heap.poll();
            }
        }
        List<Integer> recommendations = new ArrayList<>(heap.size());
        while (!heap.isEmpty()) {
            recommendations.add(heap.poll().videoId);
        }
        Collections.reverse(recommendations); // the heap yields ascending scores; return best first
        return recommendations;
    }

    private void ensureTrained() {
        if (userFeatures == null || itemFeatures == null) {
            throw new IllegalStateException("model has not been trained; call train() first");
        }
    }

    /** (video id, score) pair ordered by score ascending. */
    private static class VideoScore implements Comparable<VideoScore> {
        final int videoId;
        final double score;

        VideoScore(int videoId, double score) {
            this.videoId = videoId;
            this.score = score;
        }

        @Override
        public int compareTo(VideoScore other) {
            return Double.compare(this.score, other.score);
        }
    }
}
2.3 推荐服务类
package service;
import model.*;
import algorithm.ALS;
import java.util.*;
/**
 * Wires the catalogue data to an {@link ALS} model: training, top-N recommendation and a simple
 * hold-out evaluation.
 *
 * <p>Video ids returned by the model are used as indices into {@code videos}, so each video's id
 * must equal its position in that list. The same holds for users.
 */
public class RecommendationService {
    // ALS hyper-parameters shared by the production model and the evaluation model.
    private static final int NUM_FEATURES = 10;
    private static final double LAMBDA = 0.01;
    private static final int MAX_ITER = 20;
    /** Fraction of ratings used for training in {@link #evaluate()}; the rest is the test set. */
    private static final double TRAIN_FRACTION = 0.8;

    private final List<User> users;
    private final List<Video> videos;
    private final List<Rating> ratings;
    private ALS alsModel; // null until trainModel() succeeds

    /**
     * @param users   all users, indexed by id
     * @param videos  all videos, indexed by id
     * @param ratings observed ratings (the list is referenced, not copied)
     */
    public RecommendationService(List<User> users, List<Video> videos, List<Rating> ratings) {
        this.users = Objects.requireNonNull(users, "users");
        this.videos = Objects.requireNonNull(videos, "videos");
        this.ratings = Objects.requireNonNull(ratings, "ratings");
    }

    /** Creates a fresh, untrained model with the service's hyper-parameters. */
    private static ALS newModel() {
        return new ALS(NUM_FEATURES, LAMBDA, MAX_ITER);
    }

    /** Trains the recommendation model on all ratings. */
    public void trainModel() {
        ALS model = newModel();
        model.train(ratings, users.size(), videos.size());
        // Publish only after training succeeded, so a failure never leaves an untrained model set.
        alsModel = model;
    }

    /**
     * Returns up to {@code numRecs} videos for {@code userId}, best first.
     *
     * @throws IllegalStateException if {@link #trainModel()} has not been called
     */
    public List<Video> getRecommendations(int userId, int numRecs) {
        if (alsModel == null) {
            throw new IllegalStateException("trainModel() must be called before getRecommendations()");
        }
        List<Integer> videoIds = alsModel.recommendVideos(userId, numRecs, videos.size());
        List<Video> recommendations = new ArrayList<>(videoIds.size());
        for (int videoId : videoIds) {
            recommendations.add(videos.get(videoId));
        }
        return recommendations;
    }

    /**
     * Trains a separate model on a random 80% of the ratings and prints its RMSE on the
     * remaining 20%. The service's own rating list and trained model are left untouched.
     */
    public void evaluate() {
        // Shuffle a copy: shuffling `ratings` in place would reorder the caller's list as a side effect.
        List<Rating> shuffled = new ArrayList<>(ratings);
        Collections.shuffle(shuffled);
        int split = (int) (shuffled.size() * TRAIN_FRACTION);
        List<Rating> trainSet = shuffled.subList(0, split);
        List<Rating> testSet = shuffled.subList(split, shuffled.size());
        ALS tempModel = newModel();
        tempModel.train(trainSet, users.size(), videos.size());
        double rmse = tempModel.calculateRMSE(testSet);
        System.out.printf("Test RMSE: %.4f\n", rmse);
    }
}
2.4 主程序
import model.*;
import service.RecommendationService;
import java.util.*;
/**
 * Demo entry point: builds a small random data set, trains the ALS model, prints the top-5
 * recommendations for the first user and then runs a hold-out evaluation.
 */
public class Main {
    public static void main(String[] args) {
        // 1. Prepare sample data
        List<User> users = createUsers();
        List<Video> videos = createVideos();
        List<Rating> ratings = createRatings(users, videos);
        // 2. Create the recommendation service
        RecommendationService service = new RecommendationService(users, videos, ratings);
        // 3. Train the model
        service.trainModel();
        // 4. Recommend for the first user
        int targetUserId = 0;
        List<Video> recommendations = service.getRecommendations(targetUserId, 5);
        System.out.println("为用户 " + users.get(targetUserId).getUsername() + " 推荐的视频:");
        for (Video video : recommendations) {
            System.out.println("- " + video.getTitle() + " (" + video.getCategory() + ")");
        }
        // 5. Evaluate the model
        service.evaluate();
    }

    /** Three sample users whose ids match their list positions. */
    private static List<User> createUsers() {
        List<User> users = new ArrayList<>();
        users.add(new User(0, "张三", "本科"));
        users.add(new User(1, "李四", "硕士"));
        users.add(new User(2, "王五", "博士"));
        return users;
    }

    /** Eight sample videos whose ids match their list positions. */
    private static List<Video> createVideos() {
        List<Video> videos = new ArrayList<>();
        videos.add(new Video(0, "Java入门", "编程", 120));
        videos.add(new Video(1, "机器学习基础", "AI", 180));
        videos.add(new Video(2, "高等数学", "数学", 240));
        videos.add(new Video(3, "英语写作", "语言", 90));
        videos.add(new Video(4, "数据结构", "编程", 150));
        videos.add(new Video(5, "深度学习", "AI", 210));
        videos.add(new Video(6, "线性代数", "数学", 160));
        videos.add(new Video(7, "商务英语", "语言", 95));
        return videos;
    }

    /**
     * Gives every user 3-5 random integer ratings (1-5) on distinct videos.
     * Video ids are taken as list positions, matching how the recommender indexes videos.
     */
    private static List<Rating> createRatings(List<User> users, List<Video> videos) {
        List<Rating> ratings = new ArrayList<>();
        Random rand = new Random();
        int numVideos = videos.size();
        for (User user : users) {
            // Cap at the catalogue size: a user cannot rate more distinct videos than exist.
            int numRatings = Math.min(3 + rand.nextInt(3), numVideos);
            // Distinct videos = a random prefix of all shuffled video indices.
            List<Integer> candidates = new ArrayList<>(numVideos);
            for (int videoId = 0; videoId < numVideos; videoId++) {
                candidates.add(videoId);
            }
            Collections.shuffle(candidates, rand);
            for (int i = 0; i < numRatings; i++) {
                double score = 1 + rand.nextInt(5); // integer score in [1, 5]
                ratings.add(new Rating(user.getId(), candidates.get(i), score));
            }
        }
        return ratings;
    }
}
3. 系统优化建议
冷启动问题解决方案:
// 在RecommendationService中添加混合推荐方法(需额外导入 java.util.stream.Collectors)
/**
 * Hybrid entry point: users without any rating history get content-based suggestions,
 * everyone else gets the ALS top-N list.
 *
 * @param userId   user to recommend for
 * @param numRecs  maximum number of videos to return
 * @return recommended videos
 */
public List<Video> getHybridRecommendations(int userId, int numRecs) {
    boolean coldStart = isNewUser(userId);
    return coldStart
            ? getContentBasedRecommendations(userId, numRecs)
            : getRecommendations(userId, numRecs);
}
/** Returns true if {@code ratings} contains no rating by {@code userId}. */
private boolean isNewUser(int userId) {
    for (Rating r : ratings) {
        if (r.getUserId() == userId) {
            return false;
        }
    }
    return true;
}
/**
 * Cold-start fallback: the longest videos in the category mapped from the user's education level.
 *
 * <p>Assumes user ids equal their position in {@code users}.
 *
 * @param userId  user to recommend for
 * @param numRecs maximum number of videos to return
 * @return up to {@code numRecs} videos of the preferred category, longest first
 */
private List<Video> getContentBasedRecommendations(int userId, int numRecs) {
    User user = users.get(userId);
    // The preferred category depends only on the user, so resolve it once instead of once per video.
    String preferredCategory = getPreferredCategory(user);
    return videos.stream()
            // Compare from the known non-null side so a video with no category is simply skipped.
            .filter(v -> preferredCategory.equals(v.getCategory()))
            .sorted(Comparator.comparingDouble(Video::getDuration).reversed())
            .limit(numRecs)
            .collect(Collectors.toList());
}
/**
 * Maps a user's education level to a video category (simple rule-based cold-start heuristic).
 * Unknown or missing levels fall back to the default category.
 *
 * @param user user whose education level is inspected
 * @return the preferred category label
 */
private String getPreferredCategory(User user) {
    String level = user.getEducationLevel();
    // switch on a null String throws NullPointerException, so route a missing level to the default.
    if (level == null) {
        return "语言";
    }
    switch (level) {
        case "本科": return "编程";
        case "硕士": return "AI";
        case "博士": return "数学";
        default: return "语言";
    }
}
实时更新模型:
// 在ALS类中添加增量更新方法(要求ALS在训练时保存一份可变的评分列表字段 ratings)
/**
 * Adds one new rating and refreshes the affected user and video vectors.
 *
 * <p>Only those two rows are re-solved; all other vectors keep their current values, so this
 * approximates a full retrain. If the user already rated the video, both ratings are kept.
 *
 * <p>Requires ALS to hold a mutable {@code List<Rating> ratings} field with the training ratings.
 * The model cannot grow: both ids must already have rows in the trained feature matrices.
 *
 * @param newRating the rating to incorporate
 * @throws IllegalStateException    if the model has not been trained
 * @throws IllegalArgumentException if the rating refers to an id outside the feature matrices
 */
public void updateModel(Rating newRating) {
    if (userFeatures == null || itemFeatures == null) {
        throw new IllegalStateException("model has not been trained; call train() first");
    }
    int userId = newRating.getUserId();
    int videoId = newRating.getVideoId();
    if (userId < 0 || userId >= userFeatures.length || videoId < 0 || videoId >= itemFeatures.length) {
        throw new IllegalArgumentException("rating out of range: userId=" + userId + ", videoId=" + videoId);
    }
    // Record the rating first: both refreshes below read `ratings`, so without this the new
    // rating would have no effect on the model.
    ratings.add(newRating);
    updateUserFeatures(userId);
    updateVideoFeatures(videoId);
}

/**
 * Re-solves one video's latent vector with all user vectors held fixed, using the same
 * regularized normal equation as training. A video with no ratings keeps its current vector.
 */
private void updateVideoFeatures(int videoId) {
    double[][] a = new double[numFeatures][numFeatures];
    double[] b = new double[numFeatures];
    boolean hasRatings = false;
    for (Rating r : ratings) {
        if (r.getVideoId() != videoId) {
            continue;
        }
        hasRatings = true;
        double[] u = userFeatures[r.getUserId()];
        for (int i = 0; i < numFeatures; i++) {
            for (int j = 0; j < numFeatures; j++) {
                a[i][j] += u[i] * u[j];
            }
            b[i] += u[i] * r.getScore();
        }
    }
    if (!hasRatings) {
        return;
    }
    for (int i = 0; i < numFeatures; i++) {
        a[i][i] += lambda;
    }
    itemFeatures[videoId] = solveLinearSystem(a, b);
}
/**
 * Re-solves one user's latent vector with all video vectors held fixed.
 *
 * <p>This is the same regularized least-squares step training performs for every user:
 * {@code (sum(v * v^T) + lambda * I) x = sum(r * v)} over the user's ratings. Merely summing
 * {@code r * v} is not a least-squares fit and grows with the number of ratings. A user with no
 * ratings keeps the current vector.
 *
 * @param userId row of {@code userFeatures} to refresh
 */
private void updateUserFeatures(int userId) {
    double[][] a = new double[numFeatures][numFeatures];
    double[] b = new double[numFeatures];
    boolean hasRatings = false;
    for (Rating r : ratings) {
        if (r.getUserId() != userId) {
            continue;
        }
        hasRatings = true;
        double[] v = itemFeatures[r.getVideoId()];
        for (int i = 0; i < numFeatures; i++) {
            for (int j = 0; j < numFeatures; j++) {
                a[i][j] += v[i] * v[j];
            }
            b[i] += v[i] * r.getScore();
        }
    }
    if (!hasRatings) {
        return; // nothing to fit
    }
    for (int i = 0; i < numFeatures; i++) {
        a[i][i] += lambda; // ridge term keeps the system solvable
    }
    userFeatures[userId] = solveLinearSystem(a, b);
}
性能优化:
// 使用矩阵运算库替代手动实现(需在pom.xml中添加Apache Commons Math 3依赖)
import org.apache.commons.math3.linear.*;
// 修改ALS中的solveLinearSystem方法
/**
 * Solves {@code A x = b} via Commons Math's LU decomposition instead of hand-written elimination.
 *
 * @param A coefficient matrix
 * @param b right-hand side
 * @return the solution vector
 */
private double[] solveLinearSystem(double[][] A, double[] b) {
    DecompositionSolver lu = new LUDecomposition(MatrixUtils.createRealMatrix(A)).getSolver();
    double[] solution = lu.solve(MatrixUtils.createRealVector(b)).toArray();
    return solution;
}
评估指标扩展:
// 在RecommendationService中添加更多评估指标
/**
 * Trains a separate model on a random 80% of the ratings and prints RMSE, Precision@5 and
 * Recall@5 on the remaining 20%. The service's own rating list is left untouched.
 */
public void fullEvaluation() {
    // 1. Split into train/test on a shuffled copy; shuffling `ratings` in place would reorder the
    //    caller's list, which isNewUser() and the production model also use.
    List<Rating> shuffled = new ArrayList<>(ratings);
    Collections.shuffle(shuffled);
    int split = (int) (shuffled.size() * 0.8);
    List<Rating> trainSet = shuffled.subList(0, split);
    List<Rating> testSet = shuffled.subList(split, shuffled.size());
    // 2. Train the evaluation model
    ALS tempModel = new ALS(10, 0.01, 20);
    tempModel.train(trainSet, users.size(), videos.size());
    // 3. Compute the metrics
    double rmse = calculateRMSE(tempModel, testSet);
    double precision = calculatePrecision(tempModel, testSet);
    double recall = calculateRecall(tempModel, testSet);
    System.out.println("=== 评估结果 ===");
    System.out.printf("RMSE: %.4f\n", rmse);
    System.out.printf("Precision@5: %.4f\n", precision);
    System.out.printf("Recall@5: %.4f\n", recall);
}
/** RMSE of {@code model} on {@code testRatings}, as computed by ALS#calculateRMSE. */
private double calculateRMSE(ALS model, List<Rating> testRatings) {
    final double rmse = model.calculateRMSE(testRatings);
    return rmse;
}
/**
 * Micro-averaged Precision@5: among the top-5 lists of users who have at least one test rating
 * of 4 or more, the fraction of recommended videos that are such high-rated test videos.
 * Returns 0 when no user qualifies.
 */
private double calculatePrecision(ALS model, List<Rating> testRatings) {
    int hitCount = 0;
    int recommendedCount = 0;
    for (User user : users) {
        int uid = user.getId();
        // Relevant set: videos this user rated >= 4 in the test split.
        Set<Integer> relevant = new HashSet<>();
        for (Rating r : testRatings) {
            if (r.getUserId() == uid && r.getScore() >= 4) {
                relevant.add(r.getVideoId());
            }
        }
        if (relevant.isEmpty()) {
            continue;
        }
        List<Integer> topN = model.recommendVideos(uid, 5, videos.size());
        for (Integer vid : topN) {
            if (relevant.contains(vid)) {
                hitCount++;
            }
        }
        recommendedCount += topN.size();
    }
    if (recommendedCount == 0) {
        return 0;
    }
    return (double) hitCount / recommendedCount;
}
/**
 * Micro-averaged Recall@5: the fraction of all high-rated (>= 4) test videos, summed over users,
 * that appear in the corresponding user's top-5 list. Returns 0 when there are none.
 */
private double calculateRecall(ALS model, List<Rating> testRatings) {
    int relevantTotal = 0;
    int matched = 0;
    for (User user : users) {
        int uid = user.getId();
        Set<Integer> liked = new HashSet<>();
        for (Rating r : testRatings) {
            if (r.getUserId() == uid && r.getScore() >= 4) {
                liked.add(r.getVideoId());
            }
        }
        if (liked.isEmpty()) {
            continue;
        }
        relevantTotal += liked.size();
        for (Integer vid : model.recommendVideos(uid, 5, videos.size())) {
            if (liked.contains(vid)) {
                matched++;
            }
        }
    }
    return relevantTotal == 0 ? 0 : (double) matched / relevantTotal;
}
这个实现提供了基于ALS的教育视频推荐系统完整框架,可以根据实际需求进一步扩展和优化。系统包含核心推荐算法、业务逻辑和评估模块,适合作为学术研究或中小型教育平台的推荐系统基础。