译 | 介绍PyTabKit:一个试图超越 Scikit-Learn的新机器学习库

发布于:2025-08-02 ⋅ 阅读:(19) ⋅ 点赞:(0)

github地址:https://github.com/dholzmueller/pytabkit
译原文地址:Get with the Times: PyTabKit for better Tabular Machine Learning over Sk-Learn (CODE Included)


长期以来,Scikit-Learn 一直是处理表格数据机器学习的首选库,提供了丰富的算法、预处理工具和模型评估功能。它仍然很出色,但为什么还要开着你爷爷那辆老旧的 58 年款雪佛兰车呢?让它保持古董地位吧。现在介绍 PyTabKit —— 一个新框架,旨在取代 Scikit-Learn,用于表格数据的分类和回归,采用了最新技术如 RealMLP 和为梯度提升树(GBDT)优化的默认超参数。

完整文章链接: 2407.04491

PyTabKit 提供了类似 scikit-learn 接口的现代表格分类和回归方法,并在我们的论文中进行了基准测试。它还包含了用于基准测试的相关代码。

支持的模型

  • 神经网络:RealMLP(调优默认、HPO、集成)
  • 梯度提升树:XGBoost、LightGBM、CatBoost(默认、调优、HPO)
  • 其他模型:TabR、TabM、ResNet 等

后处理和校准
支持时序缩放等后处理技术,提升预测概率的准确性。示例:

from pytabkit import RealMLP_TD_Classifier

clf = RealMLP_TD_Classifier(
    val_metric_name='ref-ll-ts',  # 采用对数损失
    calibration_method='ts-mix',  # 时序缩放
    use_ls=False
)

为什么要超越 Scikit-Learn?

Scikit-Learn 为模型开发提供了坚实基础,但缺乏高度优化的深度学习方法和高效的自动调参功能。最新研究表明:

RealMLP 可与 GBDTs 竞争

  • 传统上,表格数据的深度学习模型需要大量调参,导致训练慢且不够实用。
  • RealMLP 是一个经过优化的多层感知机,基于 118 个数据集的基准测试进行了微调,在中等到大型数据集(1K 到 50 万样本)上性能可与 GBDTs 相媲美。
  • RealMLP 的改进包括稳健的数值缩放、数值嵌入和优化的权重初始化,使其成为传统模型的强有力替代。

更好的默认超参数很重要

  • Scikit-Learn 的默认超参数表现通常不如调优后的模型。
  • PyTabKit 为 XGBoost、LightGBM 和 CatBoost 提供了元调优的默认参数,能在无需调参的情况下超越 Scikit-Learn 的基线实现。
  • 这些默认设置在元训练基准上优化,并在 90 个未见过的数据集上验证了效果。

效率与准确性兼顾

  • 超参数优化代价高昂,尤其是深度学习模型。
  • PyTabKit 的优化默认配置让用户在许多情况下可以跳过调参,开箱即用,得到强劲效果。
  • 这使其成为 AutoML 系统中速度与准确性权衡的更佳选择。

RealMLP:表格数据神经网络的变革者

虽然梯度提升是结构化数据的主流方法,但深度学习如果正确实施,能缩小差距。RealMLP 引入了多项架构改进:

预处理改进

  • 对数值特征使用稳健缩放和平滑裁剪。
  • 对低基数类别特征使用独热编码。

架构增强

  • 引入对角权重层,提升表示能力。
  • 采用新颖的数值嵌入,优于传统特征变换。
  • 更智能的初始化策略,加快收敛速度。

性能提升

  • 基准测试显示 RealMLP 在某些场景下能匹配甚至超越 GBDTs。
  • 将 RealMLP 与优化的 GBDT 默认参数结合,能实现无需昂贵调参的最先进结果

未来展望:PyTabKit 作为新标准

PyTabKit 不仅是另一个机器学习库,而是一场范式转变。结合更强的神经网络架构、更优的默认超参数和实用的高效性,它有潜力取代 Scikit-Learn,成为许多实际应用的首选。

对于处理中等到大型数据集的用户,PyTabKit 提供了更快的训练速度、竞争力的准确率和减少的调参工作量,是现代机器学习工作流的理想方案。

代码示例:用 PyTabKit 训练 RealMLP 和树模型

使用方式和 Sklearn 一样简单!

安装

pip install pytabkit
pip install openml

获取数据集

这里使用 OpenML 的 Covertype 数据集,为了演示限制为 15,000 个样本。

import openml
from sklearn.model_selection import train_test_split
import numpy as np

task = openml.tasks.get_task(361113)
dataset = openml.datasets.get_dataset(task.dataset_id, download_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(
    dataset_format='dataframe',
    target=task.target_name
)

index = np.random.choice(range(len(X)), 15000, replace=False)
X = X.iloc[index]
y = y.iloc[index]

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

使用 RealMLP 训练

from pytabkit import RealMLP_TD_Classifier
from sklearn.metrics import accuracy_score

model = RealMLP_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP: {acc}")

预期输出:

Accuracy of RealMLP: 0.8770666666666667

使用 Bagging(交叉验证集成)

RealMLP 支持通过设置 n_cv=5 进行 5 折交叉验证集成,训练仍高效。

model = RealMLP_TD_Classifier(n_cv=5)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with bagging: {acc}")

预期输出:

Accuracy of RealMLP with bagging: 0.8930666666666667

超参数优化

使用 RealMLP_HPO_Classifier 进行超参调优,调优步数可调。

from pytabkit import RealMLP_HPO_Classifier

n_hyperopt_steps = 3
model = RealMLP_HPO_Classifier(n_hyperopt_steps=n_hyperopt_steps)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with {n_hyperopt_steps} steps HPO: {acc}")

预期输出:

Accuracy of RealMLP with 3 steps HPO: 0.8605333333333334

使用优化默认参数的树模型

调优默认(TD)模型使用优化后的超参数,默认(D)模型使用库默认参数。

from pytabkit import (
    CatBoost_TD_Classifier, CatBoost_D_Classifier,
    LGBM_TD_Classifier, LGBM_D_Classifier,
    XGB_TD_Classifier, XGB_D_Classifier
)

for model in [CatBoost_TD_Classifier(), CatBoost_D_Classifier(),
              LGBM_TD_Classifier(), LGBM_D_Classifier(),
              XGB_TD_Classifier(), XGB_D_Classifier()]:
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print(f"Accuracy of {model.__class__.__name__}: {acc}")

预期输出:

Accuracy of CatBoost_TD_Classifier: 0.8685333333333334
Accuracy of CatBoost_D_Classifier: 0.8464
Accuracy of LGBM_TD_Classifier: 0.8602666666666666
Accuracy of LGBM_D_Classifier: 0.8344
Accuracy of XGB_TD_Classifier: 0.8544
Accuracy of XGB_D_Classifier: 0.8472

集成优化默认参数的树模型和 RealMLP

通过集成多个模型可以建立强基线。

from pytabkit import Ensemble_TD_Classifier

model = Ensemble_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of Ensemble_TD_Classifier: {acc}")

以上内容展示了 PyTabKit 在表格数据机器学习中的强大能力,尤其是 RealMLP 的表现和优化默认参数的树模型,为实际应用提供了更高效、准确的解决方案。


网站公告

今日签到

点亮在社区的每一天
去签到