SciKit-Learn 全面分析分类任务 breast_cancer 数据集

发布于:2025-09-13 ⋅ 阅读:(21) ⋅ 点赞:(0)

背景

乳腺癌数据集,569个样本,30个特征,2个类别(良性/恶性)

步骤

  1. 加载数据集
  2. 拆分训练集、测试集
  3. 数据预处理(标准化)
  4. 选择模型
  5. 模型训练(拟合)
  6. 测试模型效果
  7. 评估模型

分析方法

对数据集使用 7 种分类方法进行分析

  • K 近邻(K-NN)
  • 决策树
  • 支持向量机(SVM)
  • 逻辑回归
  • 随机森林
  • 朴素贝叶斯
  • 多层感知机(MLP)

模型分析结果

各模型 ROC & AUC 对比


--- 正在训练 K近邻 (K-NN) 模型 ---
K近邻 (K-NN) 模型的准确率: 0.9591
K近邻 (K-NN) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.95      0.94      0.94        63
      benign       0.96      0.97      0.97       108

    accuracy                           0.96       171
   macro avg       0.96      0.95      0.96       171
weighted avg       0.96      0.96      0.96       171


--- 正在训练 决策树 模型 ---
决策树 模型的准确率: 0.9415
决策树 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.90      0.95      0.92        63
      benign       0.97      0.94      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.94      0.94       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 支持向量机 (SVM) 模型 ---
支持向量机 (SVM) 模型的准确率: 0.9766
支持向量机 (SVM) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.97      0.97        63
      benign       0.98      0.98      0.98       108

    accuracy                           0.98       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 逻辑回归 模型 ---
逻辑回归 模型的准确率: 0.9825
逻辑回归 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.97      0.98      0.98        63
      benign       0.99      0.98      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171


--- 正在训练 随机森林 模型 ---
随机森林 模型的准确率: 0.9708
随机森林 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.94      0.96        63
      benign       0.96      0.99      0.98       108

    accuracy                           0.97       171
   macro avg       0.97      0.96      0.97       171
weighted avg       0.97      0.97      0.97       171


--- 正在训练 朴素贝叶斯 模型 ---
朴素贝叶斯 模型的准确率: 0.9357
朴素贝叶斯 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.92      0.90      0.91        63
      benign       0.94      0.95      0.95       108

    accuracy                           0.94       171
   macro avg       0.93      0.93      0.93       171
weighted avg       0.94      0.94      0.94       171


--- 正在训练 多层感知器 (MLP) 模型 ---
多层感知器 (MLP) 模型的准确率: 0.9825
多层感知器 (MLP) 模型的分类报告:
              precision    recall  f1-score   support

   malignant       0.98      0.97      0.98        63
      benign       0.98      0.99      0.99       108

    accuracy                           0.98       171
   macro avg       0.98      0.98      0.98       171
weighted avg       0.98      0.98      0.98       171

代码

from sklearn.datasets import load_breast_cancer

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier

from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc

import matplotlib.pyplot as plt
import numpy as np

# 设置 Matplotlib 字体以正确显示中文
plt.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Zen Hei', 'STHeiti', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False  

def perform_breast_cancer_analysis():
    """
    使用 scikit-learn 对乳腺癌数据集进行全面的分析。
    该函数包含数据加载、预处理、模型训练、评估和 ROC/AUC 可视化。
    """
    print("--- 正在加载乳腺癌数据集 ---")
    # 加载乳腺癌数据集
    cancer = load_breast_cancer()
    
    # 获取数据特征和目标标签
    X = cancer.data
    y = cancer.target
    target_names = cancer.target_names

    print("\n--- 数据集概览 ---")
    print(f"数据形状: {X.shape}")
    print(f"目标名称: {target_names}")

    # 将数据集划分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    print("\n--- 数据划分结果 ---")
    print(f"训练集形状: {X_train.shape}")
    print(f"测试集形状: {X_test.shape}")
    
    # 数据标准化
    print("\n--- 正在对数据进行标准化处理 ---")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    # 定义并训练多个分类器模型
    models = {
        "K近邻 (K-NN)": KNeighborsClassifier(n_neighbors=5),
        "决策树": DecisionTreeClassifier(random_state=42),
        "支持向量机 (SVM)": SVC(kernel='rbf', C=1.0, random_state=42, probability=True),
        "逻辑回归": LogisticRegression(random_state=42, max_iter=10000),
        "随机森林": RandomForestClassifier(random_state=42),
        "朴素贝叶斯": GaussianNB(),
        "多层感知器 (MLP)": MLPClassifier(random_state=42, max_iter=10000)
    }

    print("\n--- 模型训练与评估 ---")
    for name, model in models.items():
        print(f"\n--- 正在训练 {name} 模型 ---")
        model.fit(X_train_scaled, y_train)
        y_pred = model.predict(X_test_scaled)
        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred, target_names=target_names)
        print(f"{name} 模型的准确率: {accuracy:.4f}")
        print(f"{name} 模型的分类报告:\n{report}")

    print("\n--- ROC 曲线和 AUC 对比 ---")
    num_models = len(models)
    cols = 3
    rows = (num_models + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(18, 6 * rows))
    axes = axes.flatten()

    for i, (name, model) in enumerate(models.items()):
        ax = axes[i]
        
        # 获取预测概率
        if hasattr(model, "predict_proba"):
            y_score = model.predict_proba(X_test_scaled)[:, 1]
        else:
            y_score = model.decision_function(X_test_scaled)
        
        # 计算 ROC 曲线和 AUC
        fpr, tpr, _ = roc_curve(y_test, y_score)
        roc_auc = auc(fpr, tpr)
        
        # 绘制 ROC 曲线并填充
        ax.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}', alpha=0.7)
        ax.fill_between(fpr, tpr, alpha=0.1)
        
        # 绘制对角线 (随机猜测)
        ax.plot([0, 1], [0, 1], 'k--', lw=2)
        
        # 设置图表属性
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('假正率 (FPR)')
        ax.set_ylabel('真正率 (TPR)')
        ax.set_title(f'{name} - ROC 曲线')
        ax.legend(loc="lower right", fontsize='small')
        ax.grid(True)
    
    for j in range(num_models, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    perform_breast_cancer_analysis()

网站公告

今日签到

点亮在社区的每一天
去签到