目录
引言
随机森林(Random Forest)和决策树(Decision Tree)是两种在机器学习中广泛使用的分类和回归方法,它们都属于监督学习算法。这两种算法在理解数据、构建预测模型方面有着各自的特点和优势,同时也存在紧密的联系。
决策树(Decision Tree)
决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一个类别(对于分类树)或一个数值(对于回归树)。决策树通过学习简单的决策规则来预测目标变量的值。
构建过程:
- 选择最佳特征进行分割:通常使用信息增益(对于ID3算法)、增益率(对于C4.5算法)或基尼不纯度(对于CART算法)等指标来选择最佳特征。
- 分割数据集:根据选择的特征将数据集分割成子集。
- 递归构建树:对每个子集重复上述过程,直到满足停止条件(如子集为空或达到预设的树深度)。
- 剪枝(可选):为了避免过拟合,可以移除树中的一些子树或叶节点。
优点:
- 易于理解和解释。
- 能够处理非线性关系。
- 不需要数据标准化或归一化。
缺点:
- 容易过拟合。
- 对数据中的噪声敏感。
- 不稳定,不同的样本集可能生成差异较大的树。
随机森林(Random Forest)
随机森林是一种集成学习方法,它构建多个决策树,并通过输出这些树的多数投票(对于分类问题)或平均值(对于回归问题)来预测结果。随机森林通过引入随机性来增强模型的泛化能力。
构建过程:
- 构建多棵决策树:
- 随机选择样本:从原始数据集中随机有放回地抽取多个样本集,每个样本集用于构建一棵树。
- 随机选择特征:在构建树的每个节点时,随机选择一部分特征来寻找最佳分割点。
- 组合多棵树:通过多数投票或平均值来组合多棵树的预测结果。
优点:
- 具有很高的预测准确率,通常优于单个决策树。
- 能够处理高维数据,不需要进行特征选择。
- 对异常值和噪声具有很好的容忍度,不容易过拟合。
- 易于并行化,可以提高计算效率。
缺点:
- 在某些噪声很大的分类或回归问题上会过拟合。
- 相对于单个决策树,随机森林的模型解释性较差。
总结:
决策树和随机森林都是强大的机器学习算法,它们在处理分类和回归问题时各有优势。决策树简单直观,但容易过拟合;随机森林通过集成多个决策树来提高模型的稳定性和准确性,是处理复杂数据集时的优选算法之一。
数据集
数据集是著名的鸢尾花(Iris)数据集,它常被用于分类算法的测试和教学。数据集包含了150个样本,每个样本都有4个特征(花萼长度Sepal.Length、花萼宽度Sepal.Width、花瓣长度Petal.Length、花瓣宽度Petal.Width)和一个目标变量(Species),即鸢尾花的种类。在这个数据集中,鸢尾花被分为三种类型:Setosa、Versicolour和Virginica。
- Sepal.Length:花萼的长度,以厘米为单位。
- Sepal.Width:花萼的宽度,以厘米为单位。
- Petal.Length:花瓣的长度,以厘米为单位。
- Petal.Width:花瓣的宽度,以厘米为单位。
- Species:鸢尾花的种类,有三种可能的值:Setosa、Versicolour和Virginica。
这个数据集非常适合用于分类算法的学习和测试,因为它包含了足够数量的样本和特征,同时又有清晰的分类标签。通过使用这个数据集,可以训练模型来预测给定花萼和花瓣的尺寸时,鸢尾花的种类
结果
单颗决策树如图所示:
代码实现
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
from sklearn import tree
data = pd.read_csv('D:/iris.csv')
X = data.iloc[:, :4]
y = data.iloc[:, 4]
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
accuracy = model.score(X_test, y_test)
precision, recall, _, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')
print("Accuracy: {:.2%}".format(accuracy))
print("Precision: {:.2%}".format(precision))
print("Recall: {:.2%}".format(recall))
plt.figure(figsize=(20, 10))
tree.plot_tree(model.estimators_[0], feature_names=data.columns[:4], class_names=data.columns[4], filled=True)
plt.show()