PyTorch: CrossEntropyLoss vs. NLLLoss vs. BCELoss

CrossEntropyLoss, NLLLoss 和 BCELoss 本质上都是基于交叉熵(cross entropy)的分类器的损失函数。但是三个函数的输入格式、计算方法和性能(收敛速度)有很大差别。本文记录笔者对此三者的学习笔记和理解。
{:.info}

交叉熵(Cross Entropy)

交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。或曰,概率分布 $p$ 和概率分布 $q$ 的相似程度。如果 $p$ 和 $q$ 越相似,那么越能用 $p$ 近似表示 $q$ 或用 $q$ 近似表示 $p$ 。定义交叉熵为:

注意交叉熵不满足对称性。

nn.CrossEntropyLoss

1
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

计算方法

根据pytorch官方文档,CrossEntropyLoss的输入值为”unnormalized scores for each class”,即未限制在(0,1)上的各个类别的得分。其表达式为

其中 $x$ 为一个样本,$class$ 为一个类别, $x[j]\in(-\infty,\infty)$ 为分类器给样本 $x$ 在类别 $j$ 上赋予的得分,或当weight不为空时,

其中 $weight[class]$ 为类别 $class$ 的权重。其值越大,总损失中 $class$ 类所占有的损失项的比重越大。

最终,总损失为每个样本$x$上的损失的加权平均,即

总结

  • 是否支持多类别分类:支持
  • 输入得分值域:$(-\infty,\infty)$
  • 神经网络输出层是否需要激活/归一化: 不用

nn.NLLLoss

1
torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

计算方法

NLLLoss的用法,实际在CrossEntropyLoss的文档中给出:

This criterion combines LogSoftmax and NLLLoss in one single class.
{:.success}

也就是说,CrossEntropyLossNLLLossLogSoftmax的结合体。为看清这一点,我们回到式(2)。其中 $\log\frac{\exp(\cdot)}{\sum_j \exp(\cdot)}$ 就是LogSoftmax.

因此,如果CrossEntropyLoss的输入值是 $(x,class)$, 那么 NLLLoss 的输入值就是 $LogSoftmax(x), class$。其中 $LogSoftmax(\cdot)$ 需要对 $x$ 的每个分量计算一次,最终 $dim(LogSoftmax(x)) = dim(x)$。除此之外,NLLLoss 在其他部分的计算过程与 CrossEntropyLoss 完全一致。

总结

  • 是否支持多类别分类:支持
  • 输入得分值域:$(-\infty,\infty)$
  • 神经网络输出层是否需要激活/归一化: 需要,最后一层使用LogSoftmax.也可在神经网络中使用softmax或sigmoid,在计算损失函数时显式加入torch.log计算对数

nn.BCELoss

1
torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

计算方法

BCELoss是专门针对二分类问题的交叉熵损失函数。其计算形式更加接近式(1):

总结

  • 是否支持多类别分类:不支持。只支持二分类
  • 输入得分值域:$(0,1)$。常搭配$sigmoid$一起使用
  • 神经网络输出层是否需要激活/归一化: sigmoid进行归一化处理。

Tips: 在实际使用时,出现过使用BCELoss时算法不收敛、AUC奇低,但换成NLLLoss后一切都很好用的情况。

对比表格

项目\方法 CrossEntropyLoss NLLLoss BCELoss
支持多分类
y_pred值域 $(-\infty,\infty)$ $(-\infty,\infty)$ $(0,1)$
type(y_target) torch.LongTensor torch.LongTensor torch.DoubleTensor
输出层是否需要归一化 LogSoftmax sigmoid