Kotlin OpenCV 机器学习70 DTrees 梯度提升树
1 OpenCV 机器学习算法
算法 | 适用场景 | 优点 | 缺点 |
---|---|---|---|
支持向量机 (SVM) | 分类问题 回归问题 异常检测 |
在高维空间中有效 在数据维度大于样本数量时仍然有效 使用不同的核函数可以解决各种非线性问题 |
对大规模数据集计算成本高 需要仔细调参 对特征缩放敏感 |
决策树 (Decision Trees) | 分类和回归问题 特征重要性分析 |
易于理解和解释 可处理数值型和类别型数据 不需要数据归一化 |
容易过拟合,特别是树很深时 可能创建有偏差的树,如果某些类别占主导地位 |
随机森林 (Random Forests) | 分类和回归问题 特征选择 |
减少过拟合风险 对异常值不敏感 可以处理高维数据 |
对于非常高维的稀疏数据可能表现不佳 模型解释性较差 |
梯度提升树 (Gradient Boosting Trees) | 分类和回归问题 特征重要性排序 |
通常性能优于其他机器学习算法 可以处理不同类型的特征 可以自动处理特征交互 |
容易过拟合,需要仔细调参 训练时间可能较长 |
K-最近邻 (K-Nearest Neighbors) | 分类和回归问题 推荐系统 |
简单易实现 不需要训练过程 适用于多分类问题 |
计算成本高,特别是对大数据集 对异常值敏感 需要大量内存来存储训练数据 |
朴素贝叶斯 (Naive Bayes) | 文本分类 垃圾邮件检测 情感分析 |
对小数据集效果好 可处理多类别问题 训练速度快 |
假设特征间独立,实际可能不成立 对数据分布敏感 |
2 OpenCV 深度学习算法
算法 | 适用场景 | 优点 | 缺点 |
---|---|---|---|
卷积神经网络 (CNN) | 图像分类 物体检测 图像分割 |
自动学习特征 参数共享减少了模型大小 适合处理具有空间结构的数据 |
需要大量标注数据 计算资源需求高 黑盒模型,解释性差 |
循环神经网络 (RNN) / 长短期记忆网络 (LSTM) | 序列数据处理 自然语言处理 时间序列预测 |
可以处理变长序列 能捕捉长期依赖关系 适合处理时序数据 |
训练困难(梯度消失/爆炸问题) 计算速度较慢 难以并行化 |
深度神经网络 (DNN) | 复杂非线性映射 特征学习 大规模数据集 |
可以学习高度非线性的关系 可以自动学习特征 适用于大规模数据 |
需要大量数据和计算资源 调参复杂 容易过拟合 |
3 OpenCV 无监督学习算法
算法 | 适用场景 | 优点 | 缺点 |
---|---|---|---|
K-均值聚类 (K-Means Clustering) | 数据分组 图像分割 异常检测 |
简单易实现 可扩展到大数据集 收敛速度快 |
需要预先指定簇的数量 对初始质心选择敏感 不适合处理非凸形状的簇 |
主成分分析 (PCA) | 降维 特征提取 数据压缩 |
可以减少数据的维度 去除数据中的噪声 可以用于可视化高维数据 |
只能捕捉线性关系 可能丢失有用信息 结果难以解释 |
4 Kotlin 梯度提升树
package com.xu.com.xu.ml
import cn.hutool.core.util.CharsetUtil
import cn.hutool.extra.compress.CompressUtil
import org.opencv.core.Core
import org.opencv.core.CvType
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgcodecs.Imgcodecs
import org.opencv.imgproc.Imgproc
import org.opencv.ml.DTrees
import org.opencv.ml.Ml
import java.io.File
import java.util.*
object Train {
init {
val os = System.getProperty("os.name")
val type = System.getProperty("sun.arch.data.model")
if (os.uppercase(Locale.getDefault()).contains("WINDOWS")) {
val lib = if (type.endsWith("64")) {
File("lib\\opencv-4.9\\x64\\" + System.mapLibraryName(Core.NATIVE_LIBRARY_NAME))
} else {
File("lib\\opencv-4.9\\x86\\" + System.mapLibraryName(Core.NATIVE_LIBRARY_NAME))
}
System.load(lib.absolutePath)
}
println(Core.VERSION)
}
@JvmStatic
fun main(args: Array<String>) {
val (trainImages, trainLabels) = load("lib/data/image/train/")
val (testImages, testLabels) = load("lib/data/image/predict/")
// 梯度提升树
val model = DTrees.create()
model.maxDepth = 20
model.minSampleCount = 2
model.useSurrogates = false
model.cvFolds = 0
model.use1SERule = false
model.truncatePrunedTree = false
model.regressionAccuracy = 0.01f
// 转换为OpenCV的Mat格式
val trainImagesData = Mat(trainImages.size, 784, CvType.CV_32F)
trainImages.forEachIndexed { index, floatArray ->
trainImagesData.put(index, 0, floatArray)
}
val trainLabelsData = Mat(trainLabels.size, 1, CvType.CV_32S)
trainLabelsData.put(0, 0, trainLabels.toIntArray())
// 训练模型
model.train(trainImagesData, Ml.ROW_SAMPLE, trainLabelsData)
model.save("lib/data/image/ml/DTrees.xml")
// 评估训练集准确率
val train = accuracy(model, trainImages, trainLabels)
println("训练集准确率: $train")
// 评估测试集准确率
val test = accuracy(model, testImages, testLabels)
println("测试集准确率: $test")
}
/**
* 加载数据
*/
private fun load(path: String): Pair<List<FloatArray>, List<Int>> {
val images = mutableListOf<FloatArray>()
val labels = mutableListOf<Int>()
for (i in 0..9) {
val dir = File("$path/$i")
dir.listFiles()?.forEach { file ->
val img = Imgcodecs.imread(file.absolutePath, Imgcodecs.IMREAD_GRAYSCALE)
if (!img.empty()) {
Imgproc.resize(img, img, Size(28.0, 28.0))
val array = ByteArray(784)
img.get(0, 0, array)
images.add(array.map { it / 255.0f }.toFloatArray())
labels.add(i)
}
}
}
return Pair(images, labels)
}
/**
* 计算准确率
*/
private fun accuracy(model: DTrees, images: List<FloatArray>, labels: List<Int>): Double {
var correct = 0
images.forEachIndexed { index, image ->
val sample = Mat(1, 784, CvType.CV_32F)
sample.put(0, 0, image)
val response = model.predict(sample)
if (response.toInt() == labels[index]) {
correct++
}
}
return correct.toDouble() / images.size * 100
}
private fun unzip() {
// 解压训练图片
CompressUtil.createExtractor(
CharsetUtil.defaultCharset(),
File("lib/data/image/train.7z")
).extract(File("lib/data/image/train/"))
// 解压测试图片
CompressUtil.createExtractor(
CharsetUtil.defaultCharset(),
File("lib/data/image/predict.7z")
).extract(File("lib/data/image/predict/"))
}
}
5 Kotlin 梯度提升树 训练结果
4.9.0
训练集准确率: 98.19333333333333
测试集准确率: 84.54