【基于ALS模型的教育视频推荐系统(Java实现)】

发布于:2025-05-13 ⋅ 阅读:(15) ⋅ 点赞:(0)

【基于ALS模型的教育视频推荐系统(Java实现)】

下面是一个基于交替最小二乘法(ALS)的教育视频推荐系统的Java实现,包含数据模型、示例数据生成、模型训练、推荐生成和评估模块。

1. 系统架构

edu-recommender/
├── src/
│   ├── main/
│   │   ├── java/
│   │   │   ├── model/          # 数据模型
│   │   │   ├── algorithm/      # ALS算法实现
│   │   │   ├── service/        # 业务逻辑
│   │   │   ├── util/           # 工具类
│   │   │   └── Main.java       # 入口类
│   │   └── resources/          # 配置文件
├── pom.xml                     # Maven依赖
└── data/                       # 示例数据集

2.1 数据模型类

// Video.java
package model;

/** An educational video that can be recommended. */
public class Video {
    private int id;
    private String title;
    private String category;
    private double duration; // length in minutes

    /**
     * @param id       video id (the service also uses it as the index into its video list)
     * @param title    display title
     * @param category subject category
     * @param duration length in minutes
     */
    public Video(int id, String title, String category, double duration) {
        this.id = id;
        this.title = title;
        this.category = category;
        this.duration = duration;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getTitle() {
        return title;
    }

    public void setTitle(String title) {
        this.title = title;
    }

    public String getCategory() {
        return category;
    }

    public void setCategory(String category) {
        this.category = category;
    }

    /** @return length in minutes */
    public double getDuration() {
        return duration;
    }

    public void setDuration(double duration) {
        this.duration = duration;
    }
}

// User.java
package model;

/** A learner who rates and receives video recommendations. */
public class User {
    private int id;
    private String username;
    private String educationLevel;

    /**
     * @param id             user id (the service also uses it as the index into its user list)
     * @param username       display name
     * @param educationLevel education level label, used for cold-start category selection
     */
    public User(int id, String username, String educationLevel) {
        this.id = id;
        this.username = username;
        this.educationLevel = educationLevel;
    }

    public int getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getEducationLevel() {
        return educationLevel;
    }

    public void setEducationLevel(String educationLevel) {
        this.educationLevel = educationLevel;
    }
}

// Rating.java
package model;

/** One user's score for one video. */
public class Rating {
    private int userId;
    private int videoId;
    private double score; // expected on a 1-5 scale

    /**
     * @param userId  id of the rating user
     * @param videoId id of the rated video
     * @param score   rating value, expected on a 1-5 scale
     */
    public Rating(int userId, int videoId, double score) {
        this.userId = userId;
        this.videoId = videoId;
        this.score = score;
    }

    public int getUserId() {
        return userId;
    }

    public void setUserId(int userId) {
        this.userId = userId;
    }

    public int getVideoId() {
        return videoId;
    }

    public void setVideoId(int videoId) {
        this.videoId = videoId;
    }

    public double getScore() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }
}

2.2 ALS算法实现

package algorithm;

import java.util.*;
import model.Rating;

/**
 * Latent-factor recommender trained with Alternating Least Squares (ALS).
 *
 * <p>Every user and every video gets a vector of {@code numFeatures} latent factors. The
 * predicted rating is the dot product of the two vectors, clamped to the 1-5 rating scale.
 * Training alternates between re-solving all user vectors with the video vectors held fixed and
 * re-solving all video vectors with the user vectors held fixed. Each sub-problem is an
 * L2-regularized least-squares fit solved through its normal equations.
 *
 * <p>User ids and video ids are used directly as row indices into the factor matrices. They
 * must therefore lie in {@code [0, numUsers)} and {@code [0, numItems)} respectively.
 */
public class ALS {
    private final int numFeatures; // latent dimension
    private final double lambda; // regularization strength added to the normal-equation diagonal
    private final int maxIter; // number of alternating (user sweep + item sweep) rounds
    private double[][] userFeatures; // numUsers x numFeatures; null until train() runs
    private double[][] itemFeatures; // numItems x numFeatures; null until train() runs

    /**
     * @param numFeatures number of latent factors per user/video; must be positive
     * @param lambda      regularization strength; must be non-negative. A positive value keeps
     *                    every normal-equation system non-singular.
     * @param maxIter     number of alternating rounds; must be non-negative
     * @throws IllegalArgumentException if an argument is out of range
     */
    public ALS(int numFeatures, double lambda, int maxIter) {
        if (numFeatures <= 0) {
            throw new IllegalArgumentException("numFeatures must be positive: " + numFeatures);
        }
        if (!(lambda >= 0)) { // written this way so NaN is rejected too
            throw new IllegalArgumentException("lambda must be non-negative: " + lambda);
        }
        if (maxIter < 0) {
            throw new IllegalArgumentException("maxIter must be non-negative: " + maxIter);
        }
        this.numFeatures = numFeatures;
        this.lambda = lambda;
        this.maxIter = maxIter;
    }

    /**
     * Fits the factor matrices to {@code ratings}, printing the training RMSE after each round.
     *
     * @param ratings  observed ratings; ids must be in range (see class doc)
     * @param numUsers number of user rows to allocate
     * @param numItems number of video rows to allocate
     * @throws IllegalArgumentException if a size is negative or a rating id is out of range
     */
    public void train(List<Rating> ratings, int numUsers, int numItems) {
        validateRatings(ratings, numUsers, numItems);

        // Random (unseeded) initialization: users first, then items.
        Random rand = new Random();
        userFeatures = randomMatrix(numUsers, rand);
        itemFeatures = randomMatrix(numItems, rand);

        // The grouping depends only on the ratings, which never change during training.
        // Build it once instead of twice per round.
        Map<Integer, List<Rating>> ratingsByUser = groupRatings(ratings, true);
        Map<Integer, List<Rating>> ratingsByItem = groupRatings(ratings, false);

        for (int iter = 0; iter < maxIter; iter++) {
            // Item factors fixed: solve for user factors.
            updateFeatures(ratingsByUser, userFeatures, itemFeatures, true);

            // User factors fixed: solve for item factors.
            updateFeatures(ratingsByItem, itemFeatures, userFeatures, false);

            double error = calculateRMSE(ratings);
            System.out.printf("Iteration %d, RMSE: %.4f\n", iter, error);
        }
    }

    /** Rejects negative sizes and ratings whose ids would index outside the factor matrices. */
    private static void validateRatings(List<Rating> ratings, int numUsers, int numItems) {
        if (numUsers < 0 || numItems < 0) {
            throw new IllegalArgumentException(
                    "numUsers and numItems must be non-negative: " + numUsers + ", " + numItems);
        }
        for (Rating r : ratings) {
            if (r.getUserId() < 0 || r.getUserId() >= numUsers) {
                throw new IllegalArgumentException(
                        "userId out of range [0, " + numUsers + "): " + r.getUserId());
            }
            if (r.getVideoId() < 0 || r.getVideoId() >= numItems) {
                throw new IllegalArgumentException(
                        "videoId out of range [0, " + numItems + "): " + r.getVideoId());
            }
        }
    }

    /** Allocates a {@code rows x numFeatures} matrix filled row by row with values in [0, 1). */
    private double[][] randomMatrix(int rows, Random rand) {
        double[][] matrix = new double[rows][numFeatures];
        for (double[] row : matrix) {
            for (int j = 0; j < numFeatures; j++) {
                row[j] = rand.nextDouble();
            }
        }
        return matrix;
    }

    /**
     * Re-solves one side's factor vectors while the other side is held fixed.
     *
     * <p>For each entity e with ratings R(e), this solves
     * {@code (sum v_r v_r^T + lambda * I) x = sum score_r * v_r}, where {@code v_r} is the fixed
     * vector of the counterpart in rating r. The solution x is stored as e's new row. Entities
     * with no ratings keep their current vector.
     *
     * @param groupedRatings ratings keyed by the id of the entity being updated
     * @param mainFeatures   matrix updated in place
     * @param fixedFeatures  counterpart matrix (read only)
     * @param isUser         true when updating users (counterparts are videos)
     */
    private void updateFeatures(Map<Integer, List<Rating>> groupedRatings, double[][] mainFeatures,
                                double[][] fixedFeatures, boolean isUser) {
        for (Map.Entry<Integer, List<Rating>> entry : groupedRatings.entrySet()) {
            double[][] A = new double[numFeatures][numFeatures];
            double[] b = new double[numFeatures];

            for (Rating r : entry.getValue()) {
                int otherId = isUser ? r.getVideoId() : r.getUserId();
                double[] otherVec = fixedFeatures[otherId];

                // A += otherVec * otherVec^T ; b += score * otherVec
                for (int i = 0; i < numFeatures; i++) {
                    for (int j = 0; j < numFeatures; j++) {
                        A[i][j] += otherVec[i] * otherVec[j];
                    }
                    b[i] += otherVec[i] * r.getScore();
                }
            }

            // Regularization: A += lambda * I
            for (int i = 0; i < numFeatures; i++) {
                A[i][i] += lambda;
            }

            double[] newFeatures = solveLinearSystem(A, b);
            System.arraycopy(newFeatures, 0, mainFeatures[entry.getKey()], 0, numFeatures);
        }
    }

    /** Groups ratings by user id ({@code byUser}) or by video id. */
    private Map<Integer, List<Rating>> groupRatings(List<Rating> ratings, boolean byUser) {
        Map<Integer, List<Rating>> map = new HashMap<>();
        for (Rating r : ratings) {
            int key = byUser ? r.getUserId() : r.getVideoId();
            map.computeIfAbsent(key, k -> new ArrayList<>()).add(r);
        }
        return map;
    }

    /**
     * Solves {@code A x = b} by Gaussian elimination with partial pivoting.
     * Both {@code A} and {@code b} are overwritten.
     *
     * @throws ArithmeticException if the system is singular (possible only when lambda is 0)
     */
    private double[] solveLinearSystem(double[][] A, double[] b) {
        int n = b.length;
        double[] x = new double[n];

        // Forward elimination.
        for (int i = 0; i < n; i++) {
            // Pick the row with the largest |A[k][i]| as the pivot for numerical stability.
            int maxRow = i;
            for (int k = i + 1; k < n; k++) {
                if (Math.abs(A[k][i]) > Math.abs(A[maxRow][i])) {
                    maxRow = k;
                }
            }

            double[] tempRow = A[i];
            A[i] = A[maxRow];
            A[maxRow] = tempRow;
            double tempVal = b[i];
            b[i] = b[maxRow];
            b[maxRow] = tempVal;

            // The pivot has the largest magnitude in its column, so zero here means the column
            // is all zeros and the system is singular. Fail instead of spreading NaN into the model.
            if (A[i][i] == 0.0) {
                throw new ArithmeticException(
                        "Singular normal-equation system at column " + i + "; use lambda > 0");
            }

            for (int k = i + 1; k < n; k++) {
                double factor = A[k][i] / A[i][i];
                b[k] -= factor * b[i];
                for (int j = i; j < n; j++) {
                    A[k][j] -= factor * A[i][j];
                }
            }
        }

        // Back substitution.
        for (int i = n - 1; i >= 0; i--) {
            double sum = 0;
            for (int j = i + 1; j < n; j++) {
                sum += A[i][j] * x[j];
            }
            x[i] = (b[i] - sum) / A[i][i];
        }

        return x;
    }

    /**
     * Computes the root-mean-square error of {@link #predict} (clamped) against {@code ratings}.
     *
     * @return the RMSE, or NaN when {@code ratings} is empty
     * @throws IllegalStateException if the model has not been trained
     */
    public double calculateRMSE(List<Rating> ratings) {
        double sumSquaredError = 0;
        for (Rating r : ratings) {
            double diff = predict(r.getUserId(), r.getVideoId()) - r.getScore();
            sumSquaredError += diff * diff;
        }
        return Math.sqrt(sumSquaredError / ratings.size());
    }

    /**
     * Predicts a rating, clamped to the 1-5 scale.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    public double predict(int userId, int videoId) {
        ensureTrained();
        return Math.max(1, Math.min(5, rawScore(userId, videoId)));
    }

    /**
     * Returns the ids of the {@code numRecommendations} videos with the highest predicted
     * affinity, best first.
     *
     * <p>Videos are ranked by the unclamped dot product. Clamping to [1, 5] would turn every
     * video scoring 5 or more into a tie and make the top-k arbitrary. Equal scores are broken
     * by preferring the lower video id. Videos the user has already rated are not excluded.
     *
     * @param numVideos candidate ids are {@code 0 .. numVideos-1}
     * @return at most {@code numRecommendations} ids; empty if that count is not positive
     * @throws IllegalStateException if the model has not been trained
     */
    public List<Integer> recommendVideos(int userId, int numRecommendations, int numVideos) {
        if (numRecommendations <= 0) {
            return new ArrayList<>();
        }
        ensureTrained();

        // Min-heap of the current top-k: the head is the weakest candidate and is evicted first.
        PriorityQueue<VideoScore> topK = new PriorityQueue<>();
        for (int videoId = 0; videoId < numVideos; videoId++) {
            topK.offer(new VideoScore(videoId, rawScore(userId, videoId)));
            if (topK.size() > numRecommendations) {
                topK.poll();
            }
        }

        List<Integer> recommendations = new ArrayList<>(topK.size());
        while (!topK.isEmpty()) {
            recommendations.add(topK.poll().videoId);
        }
        Collections.reverse(recommendations); // the heap yields worst first
        return recommendations;
    }

    /** Dot product of the user's and video's factor vectors, without clamping. */
    private double rawScore(int userId, int videoId) {
        double[] user = userFeatures[userId];
        double[] item = itemFeatures[videoId];
        double score = 0;
        for (int i = 0; i < numFeatures; i++) {
            score += user[i] * item[i];
        }
        return score;
    }

    private void ensureTrained() {
        if (userFeatures == null || itemFeatures == null) {
            throw new IllegalStateException("ALS model has not been trained; call train() first");
        }
    }

    /** Candidate video with its score; ordered weakest first for the top-k min-heap. */
    private static class VideoScore implements Comparable<VideoScore> {
        final int videoId;
        final double score;

        VideoScore(int videoId, double score) {
            this.videoId = videoId;
            this.score = score;
        }

        @Override
        public int compareTo(VideoScore other) {
            int byScore = Double.compare(this.score, other.score);
            // On a tie the higher id counts as weaker, so it is evicted first and the result
            // is deterministic.
            return byScore != 0 ? byScore : Integer.compare(other.videoId, this.videoId);
        }
    }
}

2.3 推荐服务类

package service;

import model.*;
import algorithm.ALS;
import java.util.*;

/**
 * Facade that trains an ALS model on in-memory users, videos and ratings and turns its
 * predictions into {@link Video} recommendations.
 *
 * <p>NOTE(review): ids are treated as list indices. The model is sized with
 * {@code users.size()} / {@code videos.size()}, and results are looked up with
 * {@code videos.get(videoId)}. Ids are therefore expected to be 0..n-1 in list order.
 */
public class RecommendationService {
    // Hyperparameters shared by trainModel() and evaluate() so both fit the same kind of model.
    private static final int NUM_FEATURES = 10;
    private static final double LAMBDA = 0.01;
    private static final int MAX_ITER = 20;
    // Fraction of ratings used for training in evaluate(); the remainder is the test set.
    private static final double TRAIN_FRACTION = 0.8;

    private final List<User> users;
    private final List<Video> videos;
    private final List<Rating> ratings;
    private ALS alsModel; // null until trainModel() has run

    public RecommendationService(List<User> users, List<Video> videos, List<Rating> ratings) {
        this.users = users;
        this.videos = videos;
        this.ratings = ratings;
    }

    /** Trains the serving model on all ratings (10 factors, lambda = 0.01, 20 rounds). */
    public void trainModel() {
        alsModel = new ALS(NUM_FEATURES, LAMBDA, MAX_ITER);
        alsModel.train(ratings, users.size(), videos.size());
    }

    /**
     * Returns the top {@code numRecs} videos for a user, best first.
     *
     * @throws IllegalStateException if {@link #trainModel()} has not been called
     */
    public List<Video> getRecommendations(int userId, int numRecs) {
        if (alsModel == null) {
            throw new IllegalStateException("Model has not been trained; call trainModel() first");
        }
        List<Integer> videoIds = alsModel.recommendVideos(userId, numRecs, videos.size());
        List<Video> recommendations = new ArrayList<>(videoIds.size());
        for (int videoId : videoIds) {
            recommendations.add(videos.get(videoId));
        }
        return recommendations;
    }

    /**
     * Hold-out evaluation. Trains a separate model on a random 80% of the ratings and prints
     * the RMSE on the remaining 20%. The serving model is not affected.
     */
    public void evaluate() {
        // Shuffle a copy. Shuffling the shared list in place reorders the caller's data, and it
        // throws UnsupportedOperationException on an unmodifiable list.
        List<Rating> shuffled = new ArrayList<>(ratings);
        Collections.shuffle(shuffled);
        int split = (int) (shuffled.size() * TRAIN_FRACTION);
        List<Rating> trainSet = shuffled.subList(0, split);
        List<Rating> testSet = shuffled.subList(split, shuffled.size());

        ALS tempModel = new ALS(NUM_FEATURES, LAMBDA, MAX_ITER);
        tempModel.train(trainSet, users.size(), videos.size());

        // NaN when the test split is empty (very small datasets).
        double rmse = tempModel.calculateRMSE(testSet);
        System.out.printf("Test RMSE: %.4f\n", rmse);
    }
}

2.4 主程序

import model.*;
import service.RecommendationService;
import java.util.*;

/**
 * Demo entry point. Builds a small random dataset, trains the recommender, prints
 * recommendations for user 0, then runs a hold-out evaluation.
 */
public class Main {
    public static void main(String[] args) {
        // 1. Build sample data.
        List<User> users = createUsers();
        List<Video> videos = createVideos();
        List<Rating> ratings = createRatings(users, videos);

        // 2. Create the recommendation service.
        RecommendationService service = new RecommendationService(users, videos, ratings);

        // 3. Train the model.
        service.trainModel();

        // 4. Recommend for the first user.
        int targetUserId = 0;
        List<Video> recommendations = service.getRecommendations(targetUserId, 5);

        System.out.println("为用户 " + users.get(targetUserId).getUsername() + " 推荐的视频:");
        for (Video video : recommendations) {
            System.out.println("- " + video.getTitle() + " (" + video.getCategory() + ")");
        }

        // 5. Evaluate the model.
        service.evaluate();
    }

    private static List<User> createUsers() {
        List<User> users = new ArrayList<>();
        users.add(new User(0, "张三", "本科"));
        users.add(new User(1, "李四", "硕士"));
        users.add(new User(2, "王五", "博士"));
        return users;
    }

    private static List<Video> createVideos() {
        List<Video> videos = new ArrayList<>();
        videos.add(new Video(0, "Java入门", "编程", 120));
        videos.add(new Video(1, "机器学习基础", "AI", 180));
        videos.add(new Video(2, "高等数学", "数学", 240));
        videos.add(new Video(3, "英语写作", "语言", 90));
        videos.add(new Video(4, "数据结构", "编程", 150));
        videos.add(new Video(5, "深度学习", "AI", 210));
        videos.add(new Video(6, "线性代数", "数学", 160));
        videos.add(new Video(7, "商务英语", "语言", 95));
        return videos;
    }

    /**
     * Gives every user 3-5 random integer ratings (1-5) on distinct videos. Video list indices
     * are used as video ids.
     */
    private static List<Rating> createRatings(List<User> users, List<Video> videos) {
        List<Rating> ratings = new ArrayList<>();
        Random rand = new Random();

        List<Integer> videoIds = new ArrayList<>(videos.size());
        for (int id = 0; id < videos.size(); id++) {
            videoIds.add(id);
        }

        for (User user : users) {
            // Cap at the number of videos. The previous rejection-sampling loop never terminated
            // when a user needed more distinct videos than existed.
            int numRatings = Math.min(3 + rand.nextInt(3), videoIds.size());

            // Taking a prefix of a fresh shuffle samples distinct videos without retries.
            Collections.shuffle(videoIds, rand);
            for (int i = 0; i < numRatings; i++) {
                double score = 1 + rand.nextInt(5); // 1-5
                ratings.add(new Rating(user.getId(), videoIds.get(i), score));
            }
        }

        return ratings;
    }
}

3. 系统优化建议

冷启动问题解决方案:

// 在RecommendationService中添加混合推荐方法
/**
 * Hybrid entry point. Users with no ratings get content-based recommendations; everyone else
 * gets the collaborative (ALS) recommendations.
 */
public List<Video> getHybridRecommendations(int userId, int numRecs) {
    boolean coldStart = isNewUser(userId);
    return coldStart
        ? getContentBasedRecommendations(userId, numRecs)
        : getRecommendations(userId, numRecs);
}

/** True when {@code ratings} contains no rating by {@code userId}. */
private boolean isNewUser(int userId) {
    for (Rating rating : ratings) {
        if (rating.getUserId() == userId) {
            return false;
        }
    }
    return true;
}

/**
 * Cold-start fallback. Recommends videos from the category matched to the user's education
 * level, longest first.
 *
 * <p>NOTE(review): {@code users.get(userId)} treats the id as a list index, which only holds when
 * users are stored in id order starting at 0.
 *
 * @param userId  index of the user in {@code users}
 * @param numRecs maximum number of videos to return; a non-positive value yields an empty list
 * @return up to {@code numRecs} videos, or fewer when the category has fewer videos
 */
private List<Video> getContentBasedRecommendations(int userId, int numRecs) {
    if (numRecs <= 0) {
        // Stream.limit throws IllegalArgumentException for negative sizes.
        return new ArrayList<>();
    }
    User user = users.get(userId);
    // Resolve the category once instead of once per video inside the filter.
    String category = getPreferredCategory(user);
    return videos.stream()
        .filter(v -> category.equals(v.getCategory()))
        .sorted(Comparator.comparingDouble(Video::getDuration).reversed())
        .limit(numRecs)
        // Fully qualified: java.util.* does not bring java.util.stream.Collectors into scope.
        .collect(java.util.stream.Collectors.toList());
}

/**
 * Maps a user's education level to the category used for cold-start recommendations.
 *
 * @return {@code "编程"} for {@code "本科"}, {@code "AI"} for {@code "硕士"}, {@code "数学"} for
 *     {@code "博士"}, and {@code "语言"} for any other level, including a missing (null) level
 */
private String getPreferredCategory(User user) {
    String level = user.getEducationLevel();
    if (level == null) {
        // A switch on a null String throws NullPointerException; treat it as an unmatched level.
        return "语言";
    }
    switch (level) {
        case "本科": return "编程";
        case "硕士": return "AI";
        case "博士": return "数学";
        default: return "语言";
    }
}

实时更新模型:

// 在ALS类中添加增量更新方法
/**
 * Folds one new rating into a trained model without a full retrain.
 *
 * <p>Re-solves the regularized least-squares problem for the affected user (video vectors
 * fixed), then for the affected video (user vectors fixed). Each step mirrors one ALS
 * half-step.
 *
 * <p>NOTE(review): this snippet assumes ALS keeps a mutable {@code List<Rating> ratings} field
 * holding the training ratings. The ALS class as shown has no such field, so one must be added
 * (e.g. populated in {@code train}).
 *
 * @param newRating the rating to add; its ids must already be valid rows of the trained
 *                  factor matrices
 */
public void updateModel(Rating newRating) {
    int userId = newRating.getUserId();
    int videoId = newRating.getVideoId();

    // Record the rating so the re-solves below actually see it. Skip it if the caller already
    // appended it (Rating as shown defines no equals(), so this is an identity check).
    if (!ratings.contains(newRating)) {
        ratings.add(newRating);
    }

    updateUserFeatures(userId);
    updateVideoFeatures(videoId);
}

/** Re-solves one user's factor vector from all of that user's ratings. */
private void updateUserFeatures(int userId) {
    List<Rating> userRatings = new ArrayList<>();
    for (Rating r : ratings) {
        if (r.getUserId() == userId) {
            userRatings.add(r);
        }
    }
    if (!userRatings.isEmpty()) {
        userFeatures[userId] = solveRegularized(userRatings, itemFeatures, true);
    }
}

/** Re-solves one video's factor vector from all ratings of that video. */
private void updateVideoFeatures(int videoId) {
    List<Rating> videoRatings = new ArrayList<>();
    for (Rating r : ratings) {
        if (r.getVideoId() == videoId) {
            videoRatings.add(r);
        }
    }
    if (!videoRatings.isEmpty()) {
        itemFeatures[videoId] = solveRegularized(videoRatings, userFeatures, false);
    }
}

/**
 * Solves {@code (sum v v^T + lambda * I) x = sum score * v} over {@code group}, where v is the
 * counterpart's fixed vector. This replaces the former "simplified" update: that code only
 * summed {@code score * v}, which ignored the normal equations and regularization and produced
 * wrongly scaled vectors.
 *
 * @param forUser true when solving for a user (counterparts are videos)
 */
private double[] solveRegularized(List<Rating> group, double[][] fixedFeatures, boolean forUser) {
    double[][] A = new double[numFeatures][numFeatures];
    double[] b = new double[numFeatures];
    for (Rating r : group) {
        double[] other = fixedFeatures[forUser ? r.getVideoId() : r.getUserId()];
        for (int i = 0; i < numFeatures; i++) {
            for (int j = 0; j < numFeatures; j++) {
                A[i][j] += other[i] * other[j];
            }
            b[i] += other[i] * r.getScore();
        }
    }
    for (int i = 0; i < numFeatures; i++) {
        A[i][i] += lambda;
    }
    return solveLinearSystem(A, b);
}

性能优化:

// 使用矩阵运算库替代手动实现
import org.apache.commons.math3.linear.*;

// 修改ALS中的solveLinearSystem方法
/**
 * Solves {@code A x = b} through Commons Math's LU decomposition.
 * Replaces the hand-written Gaussian elimination.
 * Since A here is symmetric positive definite when lambda > 0, a Cholesky decomposition would
 * also apply.
 */
private double[] solveLinearSystem(double[][] A, double[] b) {
    RealMatrix coefficients = MatrixUtils.createRealMatrix(A);
    RealVector rhs = MatrixUtils.createRealVector(b);
    DecompositionSolver lu = new LUDecomposition(coefficients).getSolver();
    RealVector solution = lu.solve(rhs);
    return solution.toArray();
}
评估指标扩展:
// 在RecommendationService中添加更多评估指标
/**
 * Hold-out evaluation. Trains a fresh model on a random 80% of the ratings and prints RMSE,
 * Precision@5 and Recall@5 on the remaining 20%.
 */
public void fullEvaluation() {
    // 1. Split on a shuffled copy. Shuffling the shared list in place reorders the caller's
    //    data, and it throws UnsupportedOperationException on an unmodifiable list.
    List<Rating> shuffled = new ArrayList<>(ratings);
    Collections.shuffle(shuffled);
    int split = (int) (shuffled.size() * 0.8);
    List<Rating> trainSet = shuffled.subList(0, split);
    List<Rating> testSet = shuffled.subList(split, shuffled.size());

    // 2. Train on the training split only.
    ALS tempModel = new ALS(10, 0.01, 20);
    tempModel.train(trainSet, users.size(), videos.size());

    // 3. Compute metrics on the test split.
    double rmse = calculateRMSE(tempModel, testSet);
    double precision = calculatePrecision(tempModel, testSet);
    double recall = calculateRecall(tempModel, testSet);

    System.out.println("=== 评估结果 ===");
    System.out.printf("RMSE: %.4f\n", rmse);
    System.out.printf("Precision@5: %.4f\n", precision);
    System.out.printf("Recall@5: %.4f\n", recall);
}

/** RMSE of {@code model} on {@code testRatings}; delegates to {@code ALS.calculateRMSE}. */
private double calculateRMSE(ALS model, List<Rating> testRatings) {
    final double rmse = model.calculateRMSE(testRatings);
    return rmse;
}

/**
 * Computes micro-averaged Precision@5: the share of recommended videos that the user rated 4
 * or higher in {@code testRatings}.
 * Users with no such test rating are skipped entirely.
 *
 * @return hits / recommended, or 0 when nothing was recommended
 */
private double calculatePrecision(ALS model, List<Rating> testRatings) {
    int relevantHits = 0;
    int recommendedCount = 0;

    for (User user : users) {
        int uid = user.getId();
        Set<Integer> liked = testRatings.stream()
            .filter(r -> r.getUserId() == uid)
            .filter(r -> r.getScore() >= 4)
            .map(Rating::getVideoId)
            .collect(Collectors.toSet());
        if (liked.isEmpty()) {
            continue;
        }

        List<Integer> topK = model.recommendVideos(uid, 5, videos.size());
        relevantHits += (int) topK.stream().filter(liked::contains).count();
        recommendedCount += topK.size();
    }

    if (recommendedCount == 0) {
        return 0;
    }
    return (double) relevantHits / recommendedCount;
}

/**
 * Computes micro-averaged Recall@5: the share of (user, video) pairs rated 4 or higher in
 * {@code testRatings} that appear in that user's top-5 recommendations.
 *
 * @return hits / relevant pairs, or 0 when the test set has no such pair
 */
private double calculateRecall(ALS model, List<Rating> testRatings) {
    int relevantHits = 0;
    int relevantTotal = 0;

    for (User user : users) {
        int uid = user.getId();
        Set<Integer> liked = testRatings.stream()
            .filter(r -> r.getUserId() == uid)
            .filter(r -> r.getScore() >= 4)
            .map(Rating::getVideoId)
            .collect(Collectors.toSet());
        relevantTotal += liked.size();
        if (liked.isEmpty()) {
            continue;
        }

        List<Integer> topK = model.recommendVideos(uid, 5, videos.size());
        relevantHits += (int) topK.stream().filter(liked::contains).count();
    }

    return relevantTotal == 0 ? 0 : (double) relevantHits / relevantTotal;
}

这个实现提供了基于ALS的教育视频推荐系统完整框架,可以根据实际需求进一步扩展和优化。系统包含核心推荐算法、业务逻辑和评估模块,适合作为学术研究或中小型教育平台的推荐系统基础。


网站公告

今日签到

点亮在社区的每一天
去签到