为什么 nn.CrossEntropyLoss
= LogSoftmax
+ nn.NLLLoss
?
在使用 PyTorch 时,我们经常听说 nn.CrossEntropyLoss
是 LogSoftmax
和 nn.NLLLoss
的组合。这句话听起来简单,但背后到底是怎么回事?为什么这两个分开的功能加起来就等于一个完整的交叉熵损失?今天我们就从数学公式到代码实现,彻底搞清楚它们的联系。
1. 先认识三个主角
要理解这个等式,先得知道每个部分的定义和作用:
nn.CrossEntropyLoss
:交叉熵损失,直接接受未归一化的 logits,计算模型预测与真实标签的差距,适用于多分类任务。LogSoftmax
:将 logits 转为对数概率(log probabilities),输出范围是负值。nn.NLLLoss
:负对数似然损失,接受对数概率,计算正确类别的负对数值。
表面上看,nn.CrossEntropyLoss
是一个独立的损失函数,而 LogSoftmax
和 nn.NLLLoss
是两步操作。为什么说它们本质上是一回事呢?答案藏在数学公式和计算逻辑里。
2. 数学上的拆解
让我们从交叉熵的定义开始,逐步推导。
(1) 交叉熵的数学形式
交叉熵(Cross-Entropy)衡量两个概率分布的差异。在多分类任务中:
- ( p p p ):真实分布,通常是 one-hot 编码(比如
[0, 1, 0]
表示第 1 类)。 - ( q q q ):预测分布,是模型输出的概率(比如
[0.2, 0.5, 0.3]
)。
交叉熵公式为:
H ( p , q ) = − ∑ c = 1 C p c log ( q c ) H(p, q) = -\sum_{c=1}^{C} p_c \log(q_c) H(p,q)=−c=1∑Cpclog(qc)
对于 one-hot 编码,( p c p_c pc ) 在正确类别上为 1,其他为 0,所以简化为:
H ( p , q ) = − log ( q correct ) H(p, q) = -\log(q_{\text{correct}}) H(p,q)=−log(qcorrect)
其中 ( q correct q_{\text{correct}} qcorrect ) 是正确类别对应的预测概率。对 ( N N N ) 个样本取平均,损失为:
Loss = − 1 N ∑ i = 1 N log ( q i , y i ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) Loss=−N1i=1∑Nlog(qi,yi)
这正是交叉熵损失的核心。
(2) 从 logits 到概率
神经网络输出的是原始分数(logits),比如 ( z = [ z 1 , z 2 , z 3 ] z = [z_1, z_2, z_3] z=[z1,z2,z3] )。要得到概率 ( q q q ),需要用 Softmax:
q j = e z j ∑ k = 1 C e z k q_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} qj=∑k=1Cezkezj
交叉熵损失变成:
Loss = − 1 N ∑ i = 1 N log ( e z i , y i ∑ k = 1 C e z i , k ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log\left(\frac{e^{z_{i, y_i}}}{\sum_{k=1}^{C} e^{z_{i,k}}}\right) Loss=−N1i=1∑Nlog(∑k=1Cezi,kezi,yi)
这就是 nn.CrossEntropyLoss
的数学形式。
(3) 分解为两步
现在我们把这个公式拆开:
第一步:LogSoftmax
计算对数概率:
log ( q j ) = log ( e z j ∑ k = 1 C e z k ) = z j − log ( ∑ k = 1 C e z k ) \log(q_j) = \log\left(\frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}\right) = z_j - \log\left(\sum_{k=1}^{C} e^{z_k}\right) log(qj)=log(∑k=1Cezkezj)=zj−log(k=1∑Cezk)
这正是LogSoftmax
的定义。它把 logits ( z z z ) 转为对数概率 ( log ( q ) \log(q) log(q) )。第二步:NLLLoss
有了对数概率 ( log ( q ) \log(q) log(q) ),取出正确类别的值,取负号并平均:
NLL = − 1 N ∑ i = 1 N log ( q i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) NLL=−N1i=1∑Nlog(qi,yi)
这就是nn.NLLLoss
的公式。
组合起来:
LogSoftmax
把 logits 转为 ( log ( q ) \log(q) log(q) )。nn.NLLLoss
对 ( log ( q ) \log(q) log(q) ) 取负号,计算损失。- 两步合起来正好是 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) ),与交叉熵一致。
3. PyTorch 中的实现验证
从数学上看,nn.CrossEntropyLoss
的确可以分解为 LogSoftmax
和 nn.NLLLoss
。我们用代码验证一下:
import torch
import torch.nn as nn
# 输入数据
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]]) # [batch_size, num_classes]
target = torch.tensor([1, 2]) # 真实类别索引
# 方法 1:直接用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())
# 方法 2:LogSoftmax + nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll_loss_fn = nn.NLLLoss()
log_probs = log_softmax(logits) # 计算对数概率
nll_loss = nll_loss_fn(log_probs, target)
print("LogSoftmax + NLLLoss:", nll_loss.item())
运行结果:两个输出的值完全相同(比如 0.75)。这证明 nn.CrossEntropyLoss
在内部就是先做 LogSoftmax
,再做 nn.NLLLoss
。
4. 为什么 PyTorch 这么设计?
既然 nn.CrossEntropyLoss
等价于 LogSoftmax
+ nn.NLLLoss
,为什么 PyTorch 提供了两种方式?
便利性:
nn.CrossEntropyLoss
是一个“一体式”工具,直接输入 logits 就能用,适合大多数场景,省去手动搭配的麻烦。模块化:
LogSoftmax
和nn.NLLLoss
分开设计,给开发者更多灵活性:- 你可以在模型里加
LogSoftmax
,只用nn.NLLLoss
计算损失。 - 可以单独调试对数概率(比如打印
log_probs
)。 - 在某些自定义损失中,可能需要用到独立的
LogSoftmax
。
- 你可以在模型里加
数值稳定性:
nn.CrossEntropyLoss
内部优化了计算,避免了分开操作时可能出现的溢出问题(比如 logits 很大时,Softmax 的分母溢出)。
5. 为什么不直接用 Softmax?
你可能好奇:为什么不用 Softmax
+ 对数 + 取负,而是用 LogSoftmax
?
答案是数值稳定性:
- 单独计算
Softmax
(指数运算)可能导致溢出(比如 ( e 1000 e^{1000} e1000 ))。 LogSoftmax
把指数和对数合并为 ( z j − log ( ∑ e z k ) z_j - \log(\sum e^{z_k}) zj−log(∑ezk) ),计算更稳定。
6. 使用场景对比
nn.CrossEntropyLoss
:- 输入:logits。
- 场景:标准多分类任务(图像分类、文本分类)。
- 优点:简单直接。
LogSoftmax
+nn.NLLLoss
:- 输入:logits 需手动转为对数概率。
- 场景:需要显式控制 Softmax,或者模型已输出对数概率。
- 优点:灵活性高。
7. 小结:为什么等价?
- 数学上:交叉熵 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) ) 可以拆成两步:
LogSoftmax
:从 logits 到 ( log ( q ) \log(q) log(q) )。nn.NLLLoss
:从 ( log ( q ) \log(q) log(q) ) 到 ( − log ( q correct ) -\log(q_{\text{correct}}) −log(qcorrect) )。
- 实现上:
nn.CrossEntropyLoss
把这两步封装成一个函数,结果一致。 - 设计上:PyTorch 提供两种方式,满足不同需求。
所以,nn.CrossEntropyLoss
= LogSoftmax
+ nn.NLLLoss
不是巧合,而是交叉熵计算的自然分解。理解这一点,能帮助你更灵活地使用 PyTorch 的损失函数。
8. 彩蛋:手动推导
想自己验证?试试手动计算:
- logits
[1.0, 2.0, 0.5]
,目标是 1。 - Softmax:
[0.23, 0.63, 0.14]
。 - LogSoftmax:
[-1.47, -0.47, -1.97]
。 - NLL:
-(-0.47) = 0.47
。 - 直接用
nn.CrossEntropyLoss
,结果一样!
希望这篇博客解开了你的疑惑!
后记
2025年2月28日18点51分于上海,在grok3 大模型辅助下完成。