【逻辑回归实现多分类】

发布于:2024-07-04 ⋅ 阅读:(68) ⋅ 点赞:(0)


一、逻辑回归是什么?

逻辑回归是一种预测分析算法,其基于概率理论。对于二分类问题,逻辑回归模型会预测一个事件发生的概率。这是通过使用逻辑函数(也称为Sigmoid函数)实现的,该函数将任意实数映射到0到1的区间内,表示概率。

1. Sigmoid函数

Sigmoid函数公式如下:

σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1

其中,z 是输入特征和权重的线性组合。

二、如何扩展到多分类

逻辑回归可以通过多种方法扩展到多分类问题,包括“一对多”(One-vs-Rest, OvR)、“一对一”(One-vs-One, OvO)以及“多对多”(Multinomial or Softmax,简称MvM)方法。

1. 一对多(OvR)

在OvR方法中,对于有N个类别的分类问题,我们训练N个不同的二分类逻辑回归模型。每个模型负责将其中一个类别与其它所有类别区分开。这种方法的主要优点是模型数量相对较少,每个模型的训练和预测速度都较快。

实现细节

在OvR方法中,为每个类别创建一个模型。例如,如果有三个类别A、B和C,会训练三个模型:

  • 模型1: A与非A(即B和C)
  • 模型2: B与非B(即A和C)
  • 模型3: C与非C(即A和B)

2. 一对一(OvO)

OvO方法则是在每对类别之间训练一个二分类模型。对于N个类别,这需要训练 ( N \times (N-1) / 2 ) 个模型。每个模型都是一个独立的逻辑回归分类器,专注于区分两个特定的类别。

3. 多对多(MvM)

多对多方法,也称为Softmax回归或多项式逻辑回归,是逻辑回归的直接扩展,用于处理多类分类问题。与OvR和OvO不同,它不是通过组合多个二分类器来实现的,而是直接对多个类别进行建模。

实现细节

在MvM方法中,模型试图学习每个类别相对于其他类别的概率。通过使用Softmax函数实现,它是Sigmoid函数的一种推广,用于多类别。Softmax函数将一个含多个值的向量映射为一个概率分布。

Softmax函数的公式如下:

S o f t m a x ( z ) i = e z i ∑ j = 1 K e z j \mathrm{Softmax}(z)_i=\frac{e^{z_i}}{\sum_{j=1}^Ke^{z_j}} Softmax(z)i=j=1Kezjezi

其中, z z z是输入特征与权重的线性组合, K K K 是类别的总数, z i z_i zi是第 i i i个类别的分数。

在训练过程中,模型学习区分各个类别的特征。在预测时,给定一个输入样本,模型会计算每个类别的分数,然后应用Softmax函数来估计属于每个类别的概率,最终选择概率最高的类别。

根据具体的问题和数据集,可以选择适合的方法来将逻辑回归应用于多分类问题。OvR适用于类别数量较少且数据平衡的情况,OvO适用于类别较多的复杂问题,而MvM提供了一种直接对多个类别进行建模的方法,尤其适用于类别之间存在内在顺序或关联的情况。

三、sklearn库实现

使用Python的scikit-learn库实现逻辑回归的多分类。

1. 引入库

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.datasets import load_iris
from sklearn.multiclass import OneVsOneClassifier

2. 加载并准备数据

使用scikit-learn中的鸢尾花数据集作为示例。数据集包含三种鸢尾花的特征数据,总共150个样本。

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 分割数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

3. 使用OvR方法训练模型

# 使用OvR方法训练模型
model_ovr = LogisticRegression(multi_class='ovr', max_iter=200)
model_ovr.fit(X_train, y_train)

# 预测
y_pred_ovr = model_ovr.predict(X_test)

# 评估模型
print("OvR方法")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr)}")
print("分类报告:")
print(classification_report(y_test, y_pred_ovr))

OvR方法
准确率: 0.9666666666666667
分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      0.89      0.94         9
           2       0.92      1.00      0.96        11

    accuracy                           0.97        30
   macro avg       0.97      0.96      0.97        30
weighted avg       0.97      0.97      0.97        30

4. 使用OvO方法训练模型

# 使用OvO方法训练模型
model_ovo = OneVsOneClassifier(LogisticRegression(max_iter=200))
model_ovo.fit(X_train, y_train)

# 预测
y_pred_ovo = model_ovo.predict(X_test)

# 评估模型
print("OvO方法")
print(f"准确率: {accuracy_score(y_test, y_pred_ovo)}")
print("分类报告:")
print(classification_report(y_test, y_pred_ovo))

OvO方法
准确率: 1.0
分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

5. 使用MvM方法训练模型

model_mvm = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=200)
model_mvm.fit(X_train, y_train)

# 预测
y_pred_mvm = model_mvm.predict(X_test)

# 评估模型
print("MvM方法")
print(f"准确率: {accuracy_score(y_test, y_pred_mvm)}")
print("分类报告:")
print(classification_report(y_test, y_pred_mvm))

MvM方法
准确率: 1.0
分类报告:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

网站公告

今日签到

点亮在社区的每一天
去签到