CrossEntropyLoss, NLLLoss 和 BCELoss 本质上都是基于交叉熵(cross entropy)的分类器的损失函数。但是三个函数的输入格式、计算方法和性能(收敛速度)有很大差别。本文记录笔者对此三者的学习笔记和理解。
{:.info}
交叉熵(Cross Entropy)
交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。或曰,概率分布 $p$ 和概率分布 $q$ 的相似程度。如果 $p$ 和 $q$ 越相似,那么越能用 $p$ 近似表示 $q$ 或用 $q$ 近似表示 $p$ 。定义交叉熵为:
注意交叉熵不满足对称性。
nn.CrossEntropyLoss
1 | torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean') |
计算方法
根据pytorch官方文档,CrossEntropyLoss的输入值为”unnormalized scores for each class”,即未限制在(0,1)上的各个类别的得分。其表达式为
其中 $x$ 为一个样本,$class$ 为一个类别, $x[j]\in(-\infty,\infty)$ 为分类器给样本 $x$ 在类别 $j$ 上赋予的得分,或当weight不为空时,
其中 $weight[class]$ 为类别 $class$ 的权重。其值越大,总损失中 $class$ 类所占有的损失项的比重越大。
最终,总损失为每个样本$x$上的损失的加权平均,即
总结
- 是否支持多类别分类:支持
- 输入得分值域:$(-\infty,\infty)$
- 神经网络输出层是否需要激活/归一化: 不用
nn.NLLLoss
1 | torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean') |
计算方法
NLLLoss
的用法,实际在CrossEntropyLoss
的文档中给出:
This criterion combines LogSoftmax and NLLLoss in one single class.
{:.success}
也就是说,CrossEntropyLoss
是NLLLoss
和LogSoftmax
的结合体。为看清这一点,我们回到式(2)。其中 $\log\frac{\exp(\cdot)}{\sum_j \exp(\cdot)}$ 就是LogSoftmax.
因此,如果CrossEntropyLoss
的输入值是 $(x,class)$, 那么 NLLLoss
的输入值就是 $LogSoftmax(x), class$。其中 $LogSoftmax(\cdot)$ 需要对 $x$ 的每个分量计算一次,最终 $dim(LogSoftmax(x)) = dim(x)$。除此之外,NLLLoss
在其他部分的计算过程与 CrossEntropyLoss
完全一致。
总结
- 是否支持多类别分类:支持
- 输入得分值域:$(-\infty,\infty)$
- 神经网络输出层是否需要激活/归一化: 需要,最后一层使用LogSoftmax.也可在神经网络中使用softmax或sigmoid,在计算损失函数时显式加入torch.log计算对数
nn.BCELoss
1 | torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean') |
计算方法
BCELoss
是专门针对二分类问题的交叉熵损失函数。其计算形式更加接近式(1):
总结
- 是否支持多类别分类:不支持。只支持二分类
- 输入得分值域:$(0,1)$。常搭配$sigmoid$一起使用
- 神经网络输出层是否需要激活/归一化: sigmoid进行归一化处理。
Tips: 在实际使用时,出现过使用BCELoss时算法不收敛、AUC奇低,但换成NLLLoss后一切都很好用的情况。
对比表格
项目\方法 | CrossEntropyLoss | NLLLoss | BCELoss |
---|---|---|---|
支持多分类 | 是 | 是 | 否 |
y_pred 值域 |
$(-\infty,\infty)$ | $(-\infty,\infty)$ | $(0,1)$ |
type(y_target) |
torch.LongTensor |
torch.LongTensor |
torch.DoubleTensor |
输出层是否需要归一化 | 否 | LogSoftmax |
sigmoid |