一、概述
支持向量机(Support Vector Machine,SVM)是一种应用范围非常广泛的算法,既可以用于分类,也可以用于回归。
本文将介绍如何将线性支持向量机应用于二元分类问题,以间隔(margin)最大化为基准,得到更好的决策边界。虽然该算法的决策边界与逻辑回归一样是线性的,但有时线性支持向量机得到的结果更好。
下面对同一数据分别应用线性支持向量机和逻辑回归,并比较其结果
左图是学习前的数据,右图是从数据中学习后的结果。右图中用黑色直线标记的是线性支持向量机的决策边界,用蓝色虚线标记的是逻辑回归的决策边界。线性支持向量机的分类结果更佳。线性支持向量机的学习方式是:以间隔最大化为基准,让决策边界尽可能地远离数据。下面来看一下线性支持向量机是如何从数据中学习的。
二、算法说明
间隔的定义:
以平面上的二元分类问题为例进行说明,并且假设数据可以完全分类。线性支持向量机通过线性的决策边界将平面一分为二,据此进行二元分类。此时,训练数据中最接近决策边界的数据与决策边界之间的距离就称为间隔
右图的间隔大于左图的间隔。支持向量机试图通过增大决策边界和训练数据之间的间隔来获得更合理的边界。
三、示例代码
下面生成线性可分的数据,将其分割成训练数据和验证数据,使用训练数据训练线性支持向量机,使用验证数据评估正确率。另外,由于使用了随机数,所以每次运行的结果可能有所不同。
"""
LinearSVC:线性支持向量分类模型。
make_blobs:生成模拟数据集。
train_test_split:划分训练集和测试集。
accuracy_score:计算分类准确率。
"""
from sklearn.svm import LinearSVC
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
"""
生成 50个样本,分为两类,每类25个样本。
数据特征为二维,围绕两个中心点 (-1, -0.125) 和 (0.5, 0.5) 生成。
cluster_std=0.3 标准差表示数据点围绕中心点的分布较集中
"""
centers = [(-1, -0.125), (0.5, 0.5)]
X, y = make_blobs(n_samples=50, n_features=2,
centers=centers, cluster_std=0.3)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) # 15个样本作为测试集(30%)
model = LinearSVC()
model.fit(X_train, y_train) # 训练
y_pred = model.predict(X_test)
print(accuracy_score(y_pred, y_test))
"""
评估函数说明:
输入参数:y_pred:模型对测试集的预测结果(由 model.predict(X_test) 生成)
y_test:测试集的真实标签(实际正确答案)
输出结果:
返回一个 0~1 之间的浮点数,表示预测正确的样本比例。
例如:0.85 表示 85% 的测试样本被正确分类。
"""
四、详细说明
软间隔和支持向量
目前我们了解的都是数据可以线性分离的情况,这种不允许数据进入间隔内侧的情况称为硬间隔。但一般来说,数据并不是完全可以线性分离的,所以要允许一部分数据进入间隔内侧,这种情况叫作软间隔。通过引入软间隔,无法线性分离的数据也可以如图所示进行学习。
基于线性支持向量机的学习结果,我们可以将训练数据分为以下3种。
1. 与决策边界之间的距离比间隔还要远的数据:间隔外侧的数据。
2. 与决策边界之间的距离和间隔相同的数据:间隔上的数据。
3. 与决策边界之间的距离比间隔近,或者误分类的数据:间隔内侧的数据。
其中,我们将间隔上的数据和间隔内侧的数据特殊对待,称为支持向量。支持向量是确定决策边界的重要数据。间隔外侧的数据则不会影响决策边界的形状。由于间隔内侧的数据包含被误分类的数据,所以乍看起来通过调整间隔,使间隔内侧不存在数据的做法更好。但对于线性可分的数据,如果强制训练数据不进入间隔内侧,可能会导致学习结果对数据过拟合。使用由两个标签组成的训练数据训练线性支持向量机而得到的结果如图所示。
左图是不允许数据进入间隔内侧的硬间隔的情况,右图是允许数据进入间隔内侧的软间隔的情况。
另外,在蓝色点表示的训练数据中特意加上了偏离值。
比较两个结果可以发现,使用了硬间隔的左图上的决策边界受偏离值的影响很大;而在引入软间隔的右图上的学习结果不容易受到偏离值的影响。在使用软间隔时,允许间隔内侧进入多少数据由超参数决定。
与其他算法一样,在决定超参数时,需要使用网格搜索(grid search)和随机搜索(random search)等方法反复验证后再做决定。