决策树像人类的思考过程,用一系列“是/否”问题层层逼近答案
目录
1. DecisionTreeClassifier(决策树分类器)
gini、entropy与经典算法ID3、C4.5、CART的关系
2. 信息增益(Entropy) → ID3/C4.5 的启发
3. scikit-learn 的决策树实现本质是 CART
2. DecisionTreeRegressor(决策树回归器)
3. ExtraTreeClassifier(极端随机树分类器)
4. ExtraTreeRegressor(极端随机树回归器)
一、决策树的核心本质
决策树是一种模仿人类决策过程的树形结构分类/回归模型。它通过节点(问题) 和 边(答案) 构建路径,最终在叶节点(决策结果) 输出预测值。这种白盒模型的优势在于极高的可解释性。
二、决策树的核心构成
- 根节点:初始特征划分点
- 内部节点:特征测试点(每个节点对应一个判断条件)
- 分支:判断条件的可能结果
- 叶节点:最终决策结果(分类/回归值)
关键概念:
- 纯度(Purity):节点内样本类别的统一程度(Gini指数/熵越小越纯)
- 信息增益(Information Gain):分裂后纯度的提升量
- 剪枝(Pruning):防止过拟合的关键技术(预剪枝/后剪枝)
三、决策树的数学原理
决策树通过递归分割寻找最优特征:
1、选择分裂特征:
- ID3算法:使用信息增益(缺陷:偏好多值特征)
- C4.5算法:改进为增益率(消除特征取值数量的影响)
- CART算法:使用Gini指数(计算效率更高)
2、停止条件:
- 节点样本全属同一类
- 特征已用完
- 样本数低于阈值(超参数控制)
四、算法对比
参考:
算法 | 原生设计 | 能否用于回归 | 原因/限制 |
---|---|---|---|
ID3 | 分类(信息增益) | ❌ 不能直接使用 | 依赖离散标签的熵计算,回归任务是连续值,无法直接计算类别纯度。 |
C4.5 | 分类(增益率) | ❌ 不能直接使用 | 同ID3,分裂标准基于分类熵,且要求离散特征。 |
CART | 分类+回归 | ✅ 可直接使用 | 设计时同时支持Gini指数(分类)和最小方差(回归),天然兼容连续值目标变量。 |
维度 | 分类树(ID3/C4.5/CART) | 回归树(CART) |
---|---|---|
分裂标准 | 信息增益、增益率、Gini指数 | 最小化方差(MSE)或绝对误差(MAE) |
叶节点输出 | 类别标签 | 连续值(均值/中位数) |
特征类型 | 离散特征(ID3/C4.5)或混合(CART) | 支持连续和离散特征 |
ID3和C4.5原生仅支持分类,但通过替换分裂标准和离散化连续特征,可间接适配回归任务。实际应用中,CART是更高效且通用的选择。
五、决策树的双面性
优势 ✅:
- 直观可视化(业务人员可理解)
- 无需数据标准化
- 支持混合特征(数值+类别)
局限 ⚠️:
- 对数据扰动敏感(小变动可能导致结构剧变)
- 容易过拟合(必须剪枝)
- 不适合学习复杂关系(如异或问题)
延展思考:决策树作为集成学习的基模型(如随机森林/XGBoost)时,通过“群体智慧”能极大克服自身缺陷。在实际应用中,超过80%的预测场景会优先尝试树模型家族。
六、Python代码实战
tree包有什么?
from sklearn.tree import DecisionTreeClassifier, plot_tree
看一下tree包有些什么?
"""Decision tree based models for classification and regression."""
# 模块文档字符串,说明这个模块提供基于决策树的分类和回归模型
# Authors: The scikit-learn developers
# 标明作者是 scikit-learn 开发团队
# SPDX-License-Identifier: BSD-3-Clause
# 软件许可证声明,使用 BSD 3-Clause 许可证
from ._classes import (
BaseDecisionTree, # 决策树的基类(抽象类)
DecisionTreeClassifier, # 决策树分类器
DecisionTreeRegressor, # 决策树回归器
ExtraTreeClassifier, # 极端随机树分类器
ExtraTreeRegressor, # 极端随机树回归器
)
# 从当前目录的 _classes.py 模块导入决策树相关类
from ._export import export_graphviz, export_text, plot_tree
# 从当前目录的 _export.py 模块导入可视化导出函数:
# - export_graphviz: 导出Graphviz格式的可视化
# - export_text: 导出文本形式的决策规则
# - plot_tree: 绘制决策树图形
__all__ = [
"BaseDecisionTree", # 公开的基类
"DecisionTreeClassifier", # 公开的分类器
"DecisionTreeRegressor", # 公开的回归器
"ExtraTreeClassifier", # 公开的极端随机树分类器
"ExtraTreeRegressor", # 公开的极端随机树回归器
"export_graphviz", # 公开的Graphviz导出函数
"plot_tree", # 公开的绘图函数
"export_text", # 公开的文本导出函数
]
# 定义模块的公开API,当使用 from sklearn.tree import * 时,只有这里列出的名称会被导入
关键点说明:
模块结构分为两部分:
_classes.py
包含决策树的核心实现类_export.py
包含可视化相关工具函数
提供的5个核心类:
基类
BaseDecisionTree
包含通用实现(BaseDecisionTree
是 scikit-learn 中所有决策树模型的基类(抽象基类),它定义了决策树的核心框架和通用方法,但普通用户通常不会直接使用它。它的主要用途是作为DecisionTreeClassifier
、DecisionTreeRegressor
等具体实现类的父类,提供共享的逻辑和接口。除非你要自定义一种新的决策树变体(例如实现一种新的分裂准则),否则不需要直接使用BaseDecisionTree
。)标准决策树和极端随机树(ExtraTrees)两种变体
每种变体都有分类和回归版本
可视化工具:
支持图形化(
plot_tree
)和文本(export_text
)两种展示方式支持导出到Graphviz格式(
export_graphviz
)用于进一步处理
__all__
严格控制了模块的公开接口
鸢尾花数据集(Iris Dataset)
鸢尾花数据集是机器学习领域最经典的数据集之一,由英国统计学家和生物学家 Ronald Fisher 在1936年提出,常用于分类算法的入门和测试。
- 样本数量:150 个样本(3 类 × 50 个样本)
- 特征数量:4 个数值型特征
- 目标类别:3 种鸢尾花品种
特征(Feature) | 描述 |
---|---|
花萼长度(sepal length) | 单位:cm |
花萼宽度(sepal width) | 单位:cm |
花瓣长度(petal length) | 单位:cm |
花瓣宽度(petal width) | 单位:cm |
类别(Target) | 描述 |
-------------- | ------ |
Setosa(山鸢尾) | 线性可分,容易分类 |
Versicolor(杂色鸢尾) | 与 Virginica 部分重叠 |
Virginica(维吉尼亚鸢尾) | 与 Versicolor 部分重叠 |
from sklearn.datasets import load_iris
import pandas as pd
iris = load_iris()
X = iris.data # 特征矩阵 (150, 4)
y = iris.target # 类别标签 (0, 1, 2)
feature_names = iris.feature_names # 特征名称
target_names = iris.target_names # 类别名称
# 转为DataFrame(可选)
df = pd.DataFrame(X, columns=feature_names)
df['species'] = [target_names[label] for label in y]
print(df.head())
在 sklearn.tree
模块中,提供了几种不同的决策树模型,适用于分类和回归任务。
1. DecisionTreeClassifier
(决策树分类器)
适用于分类问题(预测离散类别标签)。
基本用法
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树分类器
clf = DecisionTreeClassifier(
max_depth=3, # 树的最大深度(防止过拟合)
criterion="gini", # 分裂标准:"gini"(基尼系数)或 "entropy"(信息增益)
random_state=42, # 随机种子(确保结果可复现)
)
# 训练模型
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 评估准确率
accuracy = clf.score(X_test, y_test)
print(f"测试集准确率: {accuracy:.2f}") #测试集准确率: 1.00
关键参数
max_depth
:树的最大深度(控制过拟合)criterion
:分裂标准("gini"
或"entropy",默认"gini"
)min_samples_split
:节点分裂所需的最小样本数min_samples_leaf
:叶子节点所需的最小样本数random_state
:随机种子(确保结果可复现)
所有参数如下
def __init__(
self,
*,
criterion="gini",
splitter="best",
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_features=None,
random_state=None,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
)
1. 树的结构控制
参数 | 默认值 | 作用 |
---|---|---|
max_depth |
None |
树的最大深度。None 表示不限制,直到所有叶子节点纯净或达到 min_samples_split 。防止过拟合的关键参数。 |
max_leaf_nodes |
None |
最大叶子节点数。优先调整 max_depth ,此参数作为补充限制。 |
min_samples_split |
2 |
节点分裂所需的最小样本数。若样本数 < 此值,则不再分裂。 |
min_samples_leaf |
1 |
叶子节点所需的最小样本数。分裂后子节点样本数必须 ≥ 此值。 |
min_weight_fraction_leaf |
0.0 |
叶子节点样本权重和的最小占比(加权数据时使用)。 |
2. 分裂策略
参数 | 默认值 | 作用 |
---|---|---|
criterion |
"gini" |
分裂质量的衡量标准: |
splitter |
"best" |
分裂策略: - "best" :选择最优分裂- "random" :随机选择分裂(更快的训练,适合 ExtraTree ) |
max_features |
None |
寻找最优分裂时考虑的最大特征数: - None :全部特征- "sqrt" :√(总特征数)- "log2" :log₂(总特征数)- 整数/浮点数:直接指定数量/比例 |
3. 正则化与防过拟合
参数 | 默认值 | 作用 |
---|---|---|
min_impurity_decrease |
0.0 |
分裂的最小不纯度减少量。若分裂后不纯度减少 < 此值,则停止分裂。 |
ccp_alpha |
0.0 |
代价复杂度剪枝的 α 参数(≥0)。值越大,剪枝越激进。 |
monotonic_cst |
None |
单调性约束(高级功能),强制预测值随特征单调变化。 |
4. 随机性与权重
参数 | 默认值 | 作用 |
---|---|---|
random_state |
None |
随机种子,控制特征/分裂的随机选择(确保结果可复现)。 |
class_weight |
None |
类别权重: - None :所有类别权重=1- "balanced" :自动按类别频率反比加权- 字典:手动指定类别权重(如 {0: 0.5, 1: 1.0} ) |
gini、
entropy与
经典算法ID3、C4.5、CART的关系
在 scikit-learn 的决策树实现中,criterion="gini"
(基尼系数)和 criterion="entropy"
(信息增益)的选择与经典算法(ID3、C4.5、CART)的关系如下:
1. 基尼系数(Gini) → CART 算法
对应算法:CART(Classification and Regression Trees,分类与回归树)
特点:
基尼系数是 CART 算法默认的分裂标准(用于分类任务)。
计算更高效(无需对数运算),但结果通常与信息增益非常接近。
2. 信息增益(Entropy) → ID3/C4.5 的启发
对应算法:ID3 和 C4.5 使用信息增益(或增益比),但 scikit-learn 并未完全实现 C4.5。
特点:
信息增益基于信息熵,计算稍慢(涉及对数运算)。
注意:scikit-learn 的
entropy
仅实现信息增益,未实现 C4.5 的增益比(Gain Ratio),因此不完全等同于 C4.5。
3. scikit-learn 的决策树实现本质是 CART
无论选择 gini
还是 entropy
,scikit-learn 的底层实现均基于 CART 框架,与经典算法的主要区别如下:
特性 | CART (scikit-learn) | ID3 | C4.5 |
---|---|---|---|
分裂标准 | 基尼系数或信息增益 | 信息增益 | 增益比(Gain Ratio) |
任务类型 | 分类 + 回归 | 仅分类 | 仅分类 |
特征类型 | 支持连续和离散特征 | 仅离散特征 | 支持连续和离散 |
二叉树/多叉树 | 二叉树(总是二元分裂) | 多叉树 | 多叉树 |
缺失值处理 | 内置支持 | 不支持 | 支持 |
剪枝方式 | 代价复杂度剪枝(CCP) | 无 | 悲观剪枝 |
为什么 scikit-learn 选择 CART 框架?
统一性:CART 同时支持分类和回归任务(ID3/C4.5 仅支持分类)。
效率:二叉树结构比多叉树更高效,适合大规模数据。
灵活性:支持连续特征和缺失值处理(无需像 ID3 那样预处理离散化)。
如何选择 gini
或 entropy
?
基尼系数(Gini):计算更快(推荐默认使用)。对类别分布不均匀的数据更鲁棒。
信息增益(Entropy):理论更贴近信息论。可能生成更平衡的树(但对性能影响通常很小)。
实际应用中,两者的分类效果通常差异不大,优先选择 gini
(除非有特定需求)。
总结
criterion="gini"
→ CART 算法的标准实现。criterion="entropy"
→ 借鉴了 ID3/C4.5 的思想,但仍在 CART 框架下运行。scikit-learn 没有完整实现 ID3/C4.5(如多叉树、增益比等功能),其决策树本质是 CART 的优化版本。
2. DecisionTreeRegressor
(决策树回归器)
适用于回归问题(预测连续值)。
基本用法
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_squared_error
# 加载数据
housing = fetch_california_housing()
X, y = housing.data, housing.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树回归器
reg = DecisionTreeRegressor(
max_depth=4, # 树的最大深度
criterion="squared_error", # 分裂标准(MSE)
random_state=42,
)
# 训练模型
reg.fit(X_train, y_train)
# 预测
y_pred = reg.predict(X_test)
# 评估(均方误差 MSE)
mse = mean_squared_error(y_test, y_pred)
print(f"测试集均方误差: {mse:.2f}")
关键参数
criterion
:分裂标准("squared_error"
(MSE)、"friedman_mse"
或"absolute_error"
(MAE))其他参数与
DecisionTreeClassifier
类似(max_depth
、min_samples_split
等)
3. ExtraTreeClassifier
(极端随机树分类器)
与 DecisionTreeClassifier
类似,但分裂时随机选择特征和阈值(更随机化,训练更快,可能泛化更好)。
基本用法
from sklearn.tree import ExtraTreeClassifier
# 创建极端随机树分类器
clf = ExtraTreeClassifier(
max_depth=3,
criterion="gini",
random_state=42,
)
# 训练和预测(与 DecisionTreeClassifier 相同)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"测试集准确率: {clf.score(X_test, y_test):.2f}")
关键区别
分裂时随机选择特征和阈值(比普通决策树更随机)
训练速度更快,但可能牺牲一些准确率
4. ExtraTreeRegressor
(极端随机树回归器)
与 DecisionTreeRegressor
类似,但分裂时随机选择特征和阈值。
基本用法
from sklearn.tree import ExtraTreeRegressor
# 创建极端随机树回归器
reg = ExtraTreeRegressor(
max_depth=4,
criterion="squared_error",
random_state=42,
)
# 训练和预测(与 DecisionTreeRegressor 相同)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"测试集均方误差: {mse:.2f}")
关键区别
分裂时随机选择特征和阈值
训练更快,适用于大数据集
5. 可视化决策树
可以使用 plot_tree
或 export_graphviz
可视化决策树。
使用 plot_tree
(推荐)
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 8))
plot_tree(
clf, # 训练好的决策树模型
feature_names=iris.feature_names, # 特征名
class_names=iris.target_names, # 类别名
filled=True, # 填充颜色
rounded=True, # 圆角节点
)
plt.show()
输出示例:
使用 export_text
(文本形式)
from sklearn.tree import export_text
tree_rules = export_text(
clf,
feature_names=iris.feature_names,
)
print(tree_rules)
输出示例:
|--- median_income <= 5.03
| |--- ocean_proximity_INLAND <= 0.50
| | |--- median_income <= 3.11
| | | |--- median_income <= 2.21
| | | | |--- truncated branch of depth 21
| | | |--- median_income > 2.21
| | | | |--- truncated branch of depth 20
| | |--- median_income > 3.11
| | | |--- longitude <= -118.31
| | | | |--- truncated branch of depth 24
| | | |--- longitude > -118.31
| | | | |--- truncated branch of depth 22
| |--- ocean_proximity_INLAND > 0.50
| | |--- median_income <= 3.04
| | | |--- median_income <= 2.22
| | | | |--- truncated branch of depth 17
| | | |--- median_income > 2.22
| | | | |--- truncated branch of depth 19
| | |--- median_income > 3.04
| | | |--- median_income <= 4.07
| | | | |--- truncated branch of depth 16
| | | |--- median_income > 4.07
| | | | |--- truncated branch of depth 14
|--- median_income > 5.03
| |--- median_income <= 6.87
| | |--- ocean_proximity_INLAND <= 0.50
| | | |--- housing_median_age <= 36.50
| | | | |--- truncated branch of depth 18
| | | |--- housing_median_age > 36.50
| | | | |--- truncated branch of depth 11
| | |--- ocean_proximity_INLAND > 0.50
| | | |--- housing_median_age <= 32.50
| | | | |--- truncated branch of depth 11
| | | |--- housing_median_age > 32.50
| | | | |--- truncated branch of depth 5
| |--- median_income > 6.87
| | |--- median_income <= 8.16
| | | |--- housing_median_age <= 27.50
| | | | |--- truncated branch of depth 12
| | | |--- housing_median_age > 27.50
| | | | |--- truncated branch of depth 8
| | |--- median_income > 8.16
| | | |--- total_bedrooms <= 33.00
| | | | |--- truncated branch of depth 2
| | | |--- total_bedrooms > 33.00
| | | | |--- truncated branch of depth 9
总结
模型 | 适用任务 | 关键特点 | 示例场景 |
---|---|---|---|
DecisionTreeClassifier |
分类 | 标准决策树 | 鸢尾花分类 |
DecisionTreeRegressor |
回归 | 标准决策树 | 房价预测 |
ExtraTreeClassifier |
分类 | 更随机的分裂 | 高维数据分类 |
ExtraTreeRegressor |
回归 | 更随机的分裂 | 大数据回归 |
默认用
DecisionTreeClassifier/Regressor
(更稳定)数据量大时用
ExtraTree
(训练更快)防止过拟合:调整
max_depth
、min_samples_split
等参数