深度学习中的分类 vs 回归问题
1. 核心区别
- 分类问题:预测离散的类别标签(例如“猫/狗”,“0-9数字”)。
- 回归问题:预测连续的数值(例如房价、温度、销售额)。
类比:分类是“分桶”,回归是“量尺”。
2. 模型设计
任务类型 | 输出层激活函数 | 输出节点数 | 例子 |
---|---|---|---|
分类 | Softmax(多类)或 Sigmoid(二类) | 类别数量(如10类→10个节点) | 图像分类(判断动物种类) |
回归 | 无激活(线性输出) | 1个或多个数值 | 预测房价或气温 |
为什么用Softmax/Sigmoid?
将输出转换为概率(例如“猫的概率80%,狗的概率20%”)。
3. 损失函数(关键!)
分类:用交叉熵损失(Cross-Entropy Loss)
- 二分类:
Binary Cross-Entropy
- 多分类:
Categorical Cross-Entropy
作用:衡量预测概率和真实标签的差距(概率越准,损失越低)。
- 二分类:
回归:用均方误差(MSE)或平均绝对误差(MAE)
- MSE = 预测值与真实值的平方差的平均
- MAE = 预测值与真实值的绝对差的平均
选择:MSE对异常值更敏感,MAE更鲁棒。
4. 评估指标
分类:
- 准确率(Accuracy):正确预测的比例。
- 精确率(Precision):预测为“猫”的样本中,有多少真是猫。
- 召回率(Recall):所有真实的“猫”中,模型找出了多少。
- F1分数:精确率和召回率的调和平均。
- AUC-ROC曲线:衡量二分类模型的综合性能。
回归:
- RMSE(均方根误差):MSE的平方根,与数据单位一致。
- MAE(平均绝对误差):直接反映误差大小。
- R²分数:模型解释数据变动的比例(0~1,越高越好)。
5. 常见应用场景
分类任务:
- 图像分类(识别物体)
- 垃圾邮件检测(是/否垃圾)
- 疾病诊断(患病/健康)
回归任务:
- 房价预测
- 股票价格趋势
- 用户点击率预测
6. 注意事项(易错点!)
模型名字的坑:
- 逻辑回归(Logistic Regression)其实是分类模型,用Sigmoid输出概率。
- 不要被名字误导!
任务转换:
- 有些问题可以灵活处理。例如预测年龄:
- 回归:直接输出年龄(如25.3岁)
- 分类:分成“儿童/青年/中年/老年”
- 根据需求选择任务类型!
- 有些问题可以灵活处理。例如预测年龄:
标签预处理:
- 分类:标签需转为One-Hot编码(如数字3→[0,0,0,1,0,0,0,0,0,0])。
- 回归:标签直接是数值,需标准化(如缩放到0~1)。
7. 一句话总结
- 分类:分门别类,输出概率。
- 回归:预测数值,误差越小越好。
损失函数的定义
在模型的训练过程中,最重要的是在计算出参数的梯度值以后,利用梯度下降算法对模型参数进行更新。如果要得到梯度值,必须先定义损失函数。损失函数主要用于衡量模型的预测值与实际值之间的误差,然后模型根据这个损失值调整参数以减小误差,从而找到最优参数值。在深度学习中,根据完成的任务,损失函数可分为两类,第一类为用于回归任务的损失函数,第二类为用于分类任务的损失函数。
交叉熵损失
交叉熵与Softmax公式及示例表
公式类型 | 公式 | 例子(应用前 → 应用后) | 说明 |
---|---|---|---|
Softmax函数 | σ(zi)=∑j=1Cezjezi | 输入:z=[2.0,1.0,0.1] 输出:p=[0.659,0.242,0.099] |
将logits转为概率,总和为1。 |
Log-Softmax | log(σ(zi))=zi−log(∑j=1Cezj) | 输入:z=[3.0,1.0,0.5] 输出:log(p)=[2.17,0.17,−0.33] |
数值稳定,避免指数溢出。 |
二分类交叉熵(单样本) | L=−[ylog(p)+(1−y)log(1−p)] | 输入:y=1, p=0.8 输出:L=−log(0.8)≈0.223 |
二分类任务,标签为0或1。 |
多分类交叉熵(单样本) | L=−∑i=1Cyilog(pi) 等价于 L=−log(pk)(真实类别为第 k 类时) |
输入:y=[0,0,1], p=[0.1,0.2,0.7] 输出:L=−log(0.7)≈0.357 |
真实标签为One-Hot编码,仅计算对应类别的损失。 |
多分类交叉熵(多样本平均) | L=−N1∑n=1N∑i=1Cyn,ilog(pn,i) | 输入: 样本1:y1=[0,1,0], p1=[0.2,0.7,0.1] 样本2:y2=[1,0,0], p2=[0.9,0.1,0.0] 输出:L=1/2(−log(0.7)−log(0.9))≈0.107 |
批量计算时取所有样本损失的平均。 |
稀疏交叉熵(多样本) | L=−N1∑n=1Nlog(pn,k) | 输入: 真实标签 k=[2,0](对应One-Hot为[[0,0,1], [1,0,0]]) 预测概率 p=[[0.1,0.2,0.7],[0.9,0.1,0.0]] 输出:L=1/2(−log(0.7)−log(0.9))≈0.107 |
直接使用整数标签,避免One-Hot编码。 |
交叉熵函数详细介绍
1. 交叉熵损失函数的作用
在多分类任务中,交叉熵损失(Cross-Entropy Loss)用于衡量模型预测的概率分布与真实标签分布之间的差异。
- 目标:让模型的预测概率尽可能接近真实标签的分布。
- 核心思想:如果预测概率越接近真实标签,损失值越低;反之则损失值越高。
2. 数学公式与推导
假设一个多分类问题有 C 个类别,模型对输入样本的预测输出为 logits(未归一化的原始分数),记为 z=[z1,z2,...,zC]。
Step 1: Softmax归一化
通过 Softmax函数 将 logits 转换为概率分布:
pi=∑j=1Cezjezi(i=1,2,...,C)
- pi 表示模型预测样本属于第 i 类的概率。
- 所有 pi 之和为 1(满足概率分布)。
Step 2: 真实标签的表示
真实标签 y 通常以 One-Hot编码 表示。例如,若有 3 个类别且真实类别为第 2 类,则 y=[0,1,0]。
Step 3: 交叉熵损失计算
交叉熵损失公式:
L=−i=1∑Cyilog(pi)
- 由于 One-Hot 编码中只有真实类别 k 的 yk=1,其他为 0,因此公式简化为:
L=−log(pk)
- 直观解释:真实类别对应的预测概率 pk 越接近 1,损失越小;越接近 0,损失越大。
3. 实际计算示例
以手写数字识别(10分类)为例,假设某样本的真实标签是数字 3(即 y=[0,0,0,1,0,0,0,0,0,0]),模型输出的 logits 和计算过程如下:
Logits z | Softmax概率 p |
---|---|
z=[2.0,1.0,0.1,3.0,0.5,1.2,0.3,1.8,0.5,0.9] | 通过 Softmax 计算得: p=[0.13,0.08,0.05,0.48,0.06,0.09,0.04,0.10,0.06,0.07] |
- 真实类别为第 4个位置(索引3,从0开始),对应 p3=0.48。
- 交叉熵损失:
L=−log(0.48)≈0.73
4. 为什么用交叉熵而不用均方误差(MSE)?
- 梯度特性:
- 交叉熵的梯度在错误预测时更大,能更快更新模型参数。
- MSE的梯度在概率接近0或1时趋于平缓,导致训练速度慢。
- 适用性:
- 交叉熵直接衡量概率分布的差异,更适合分类任务。
- MSE更适用于回归问题(连续值误差)。