Spark的AI/机器学习实践
基于Java Spark的AI/机器学习实例
以下是一些基于Java Spark的AI/机器学习实例,涵盖分类、回归、聚类、推荐系统等常见任务。这些示例使用Apache Spark MLlib库,适合大数据环境下的分布式计算。
分类任务示例
逻辑回归分类
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load LIBSVM-formatted labeled data (requires an active SparkSession `spark`).
Dataset<Row> training = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Estimator: regularized logistic regression, capped at 10 iterations.
LogisticRegression lr = new LogisticRegression()
        .setRegParam(0.3)
        .setMaxIter(10);
// fit() trains on the DataFrame and returns an immutable model.
LogisticRegressionModel model = lr.fit(training);
随机森林分类
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
RandomForestClassifier rf = new RandomForestClassifier().setNumTrees(10);
RandomForestClassificationModel rfModel = rf.fit(training);
梯度提升树分类
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.classification.GBTClassificationModel;
GBTClassifier gbt = new GBTClassifier().setMaxIter(10);
GBTClassificationModel gbtModel = gbt.fit(training);
回归任务示例
线性回归
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
// Linear regression with L2 regularization strength 0.3, at most 10 iterations.
LinearRegression lr = new LinearRegression()
        .setRegParam(0.3)
        .setMaxIter(10);
LinearRegressionModel model = lr.fit(training);
决策树回归
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
// Single decision-tree regressor with default hyperparameters; `training` is the
// DataFrame loaded in the earlier example.
DecisionTreeRegressor dt = new DecisionTreeRegressor();
DecisionTreeRegressionModel dtModel = dt.fit(training);
聚类任务示例
K-means聚类
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
KMeans kmeans = new KMeans().setK(2).setSeed(1L);
KMeansModel model = kmeans.fit(training);
高斯混合模型
import org.apache.spark.ml.clustering.GaussianMixture;
import org.apache.spark.ml.clustering.GaussianMixtureModel;
GaussianMixture gm = new GaussianMixture().setK(2);
GaussianMixtureModel model = gm.fit(training);
推荐系统示例
交替最小二乘法(ALS)
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
// Collaborative filtering via alternating least squares over
// (userId, movieId, rating) triples.
ALS als = new ALS()
        .setUserCol("userId")
        .setItemCol("movieId")
        .setRatingCol("rating")
        .setMaxIter(5)
        .setRegParam(0.01);
ALSModel model = als.fit(training);
特征处理示例
TF-IDF文本特征提取
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
// Split the "text" column into a "words" array column.
Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
Dataset<Row> wordsData = tokenizer.transform(training);
// Hash each word list into a fixed-length raw term-frequency vector.
HashingTF hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures");
Dataset<Row> featurizedData = hashingTF.transform(wordsData);
// IDF rescales raw term frequencies by inverse document frequency;
// fit() scans the corpus to compute the IDF weights.
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
PCA降维
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.PCAModel;
// Project the "features" vectors onto their top 3 principal components.
PCA pca = new PCA().setInputCol("features").setOutputCol("pcaFeatures").setK(3);
PCAModel pcaModel = pca.fit(training);
模型评估示例
二分类评估
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
// Area under the ROC curve, computed from raw prediction scores against labels.
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
        .setLabelCol("label")
        .setRawPredictionCol("rawPrediction")
        .setMetricName("areaUnderROC");
double auc = evaluator.evaluate(predictions);
多分类评估
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
// Overall accuracy: fraction of rows where prediction equals label.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("accuracy");
double accuracy = evaluator.evaluate(predictions);
模型调优示例
交叉验证
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
// Grid of candidate regularization strengths for the estimator `lr`.
ParamMap[] paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam(), new double[]{0.1, 0.01})
        .build();
// 3-fold cross-validation: selects the grid point with the best evaluator score,
// then refits on the full training set.
CrossValidator cv = new CrossValidator()
        .setEstimator(lr)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(3);
CrossValidatorModel cvModel = cv.fit(training);
其他实用示例
保存与加载模型
// Persist the trained model to disk (the target path must not already exist).
model.save("path/to/model");
// Reload later; the loaded model can transform() immediately without refitting.
LinearRegressionModel sameModel = LinearRegressionModel.load("path/to/model");
管道(Pipeline)
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
// Chain feature extraction (tokenizer -> hashingTF -> idf) and the classifier `lr`
// into one estimator; fit() runs the stages in order over `training`.
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{tokenizer, hashingTF, idf, lr});
PipelineModel pipelineModel = pipeline.fit(training);
这些示例覆盖了Spark MLlib的主要功能模块,可根据实际需求调整参数或扩展。完整项目建议参考Apache Spark官方文档及示例代码库。
Spring Boot 与 Spark 结合实例
Spring Boot 与 Apache Spark 结合可以实现大数据处理与微服务的无缝集成。以下是常见的应用场景和示例代码片段。
集成 Spark 与 Spring Boot
在 pom.xml 中添加 Spark 依赖:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>3.2.0</version>
</dependency>
创建 Spring Boot 服务启动 Spark 上下文:
@SpringBootApplication
public class SparkApp {
    /**
     * Boots the Spring application context, then creates a local Spark context.
     *
     * <p>NOTE(review): in a production application the {@code JavaSparkContext}
     * should be exposed as a Spring {@code @Bean} so the container manages its
     * lifecycle and other beans can inject it; it is created inline here only
     * for demonstration.
     */
    public static void main(String[] args) {
        SpringApplication.run(SparkApp.class, args);
        SparkConf conf = new SparkConf().setAppName("SpringSpark").setMaster("local[*]");
        // try-with-resources stops the Spark context on exit;
        // the original snippet created it and never closed it (resource leak).
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // Submit Spark jobs with `sc` here.
        }
    }
}
数据清洗示例
从 CSV 读取数据并过滤:
// Read raw CSV lines and drop every row containing the literal marker "NULL".
JavaRDD<String> data = sc.textFile("input.csv");
JavaRDD<String> filtered = data.filter(line -> {
    return !line.contains("NULL");
});
filtered.saveAsTextFile("output");
统计单词频率(批处理词频统计)
// Classic batch word count: tokenize on single spaces, emit (word, 1), sum per key.
JavaRDD<String> lines = sc.textFile("textfile.txt");
JavaRDD<String> words = lines.flatMap(line -> Arrays.asList(line.split(" ")).iterator());
JavaPairRDD<String, Integer> counts = words
        .mapToPair(word -> new Tuple2<>(word, 1))
        .reduceByKey(Integer::sum);
counts.saveAsTextFile("wordcounts");
机器学习模型训练
使用 Spark MLlib 实现线性回归:
// Fit a linear regression model (default hyperparameters) on LIBSVM-formatted data.
Dataset<Row> data = spark.read().format("libsvm").load("data.txt");
LinearRegression lr = new LinearRegression();
LinearRegressionModel model = lr.fit(data);
// Print the learned weight vector of the fitted model.
System.out.println("Coefficients: " + model.coefficients());