【kneighborsclassifier 函数及其参数介绍】

发布于:2024-07-03 ⋅ 阅读:(9) ⋅ 点赞:(0)

一、kneighborsclassifier是什么?

kneighborsclassifierscikit-learn 库中 K-近邻算法的实现,用于分类任务。KNN 算法的基本思想是给定一个样本数据集,对于每个输入的新数据点,找到其在样本数据集中最近的 K 个数据点,根据这 K 个邻居的类别来预测新数据点的类别。


二、使用步骤

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

# 载入数据
iris = load_iris()
X = iris.data
y = iris.target

# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
scaler.fit(X_train)

X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# 初始化KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)

# 训练模型
knn.fit(X_train, y_train)
# 预测测试集
y_pred = knn.predict(X_test)

# 分类报告
print(classification_report(y_test, y_pred))

# 可视化混淆矩阵
confusion = confusion_matrix(y_test, y_pred)
plt.matshow(confusion)
# 设置中文
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 在每个单元格中添加数字
for i in range(confusion.shape[0]):
    for j in range(confusion.shape[1]):
        plt.text(x=j, y=i, s=str(confusion[i, j]), va='center', ha='center', color='red')

plt.colorbar()
plt.ylabel('实际类型')
plt.xlabel('预测类型')
plt.title('混淆矩阵')
plt.show()

三、kneighborsclassifier函数及其参数详解

1. 参数说明

  • n_neighbors: 用于指定邻居的数目,默认值是 5。
  • weights: 用于确定邻居对预测的贡献。可以是 ‘uniform’(默认值,表示所有邻居的权重相同),‘distance’(邻居的权重与距离成反比),或者用户自定义的权重函数。
  • algorithm: 用于计算最近邻居的算法。可选值有 ‘auto’(默认值,根据数据选择最佳算法),‘ball_tree’,‘kd_tree’,以及 ‘brute’。
  • leaf_size: 用于指定 BallTree 或 KDTree 中叶节点的大小,默认值是 30。影响树的构建和查询速度。
  • p: 用于指定距离度量的方法。p=2 是欧氏距离,p=1 是曼哈顿距离。
  • metric: 用于指定距离度量,默认值是 ‘minkowski’。
  • metric_params: 用于指定距离度量的附加参数,默认是 None。
  • n_jobs: 用于指定并行运行的作业数量。-1 表示使用所有的处理器。