GHM Loss
Focal loss 的问题
Focal Loss的重点是关注难以分类的样本, 这些样本对模型的优化贡献最大. 但GHM(Gradient Harmonizing Mechanism) Loss认为让模型过多关注那些特别难分的样本是存在问题的. 样本中的离群点(Outliers)在模型已经收敛后仍然会被判别错误, 属于超难分类的样本. 让模型取关注这些样本, 会使得模型过度地去拟合噪音, 损伤了泛化性, 造成了过拟合现象.
GHM Loss 的原理
Focal Loss是从模型输出的分类置信度p的角度入手衰减样本的loss, GHM则是从梯度密度的角度去衰减样本的loss.
定义梯度密度, 从定义梯度模长开始. 考虑二分类交叉熵损失函数, 定义梯度模长为:
g=∣p−p∗∣={1−pp if p∗=1 if p∗=0 这样定义梯度模长是因为输出层的sigmoid函数与二分类交叉熵损失函数结合有:
p=sigmoid(x) LCE={−log(p),−log(1−p), if if p∗=1p∗=0 对应的损失函数对输出层的输入的梯度如下, 这一步的推到可以参考Softmax与交叉熵求导:
∂x∂LCE={p−1,p, if p∗=1 if p∗=0=p−p∗ 因此有:
g=∂x∂LCE 对于易于分类的样本, 它们的输出值会非常接近真实标签, 所以对梯度的贡献很小, 梯度模长接近0; 而难以分类的样本, 则会产生大梯度, 梯度模长接近于1. 因此梯度模长的大小可以衡量样本的分类难易程度.
衡量梯度模长与样本数量的关系, 如下图, 可以看到除了大量易分样本, 分类困难的样本数量也非常多, 这些样本可能是标注错误的噪音离群点, 过多的关注这些十分困难的分类样本, 不仅不会提升模型的分类效果, 反而会对模型质量带来一定的损伤.
因此GHM Loss在控制每个样本对梯度贡献的权重时, 选择使用样本梯度模长所在的一定范围内, 如果在这个范围内样本数量较多, 对其梯度加以抑制. 定义一个变量, 衡量出一定梯度范围内的样本数量, 这个变量的概念类似于密度. 定义梯度密度GD(g):
GD(g)=lε(g)1k=1∑Nδε(gk,g) δε(gk,g): 所有样本中, 梯度模长分布在(g−2ε,g+2ε)范围内的样本个数
lε(g): 代表了(g−2ε,g+2ε)区间的长度
GHM Loss定义, 对于每个样本, 使用梯度密度的倒数对样本的交叉熵损失进行缩放, 即:
LGHM−C=N1i=1∑NβiLCE(pi,pi∗)=i=1∑NGD(gi)LCE(pi,pi∗) 计算过程
关键是在于如何计算梯度密度. 首先, 把梯度模长范围划分成10个区域. 这里要求输入必须经过sigmoid计算, 这样梯度模长的范围就限制在0~1
之间:
class GHMC(nn.Module):
def __init__(self, bins=10, ......):
self.bins = bins
edges = torch.arange(bins + 1).float() / bins
......
>>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000,
0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])
edges是每个区域的边界, 有了边界就很容易计算出梯度模长落入哪个区间内. 然后根据网络输出pred和ground true计算loss, 再将梯度密度作为权重, 对计算得到的BCE Loss进行缩放:
# 计算梯度模长
g = torch.abs(pred.sigmoid().detach() - target)
# n 用来统计有效的区间数。
# 假如某个区间没有落入任何梯度模长,密度为0,需要额外考虑,不然取个倒数就无穷了。
n = 0 # n valid bins
# 通过循环计算落入10个bins的梯度模长数量
for i in range(self.bins):
inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
num_in_bin = inds.sum().item()
if num_in_bin > 0:
# 重点,所谓的梯度密度就是1/num_in_bin
weights[inds] = num_labels / num_in_bin
n += 1
if n > 0:
weights = weights / n
# 把上面计算的weights填到binary_cross_entropy_with_logits里就行了
loss = torch.nn.functional.binary_cross_entropy_with_logits(
pred, target, weights, reduction='sum') / num_labels
参考资料