处理样本不平衡的LOSS—Focal Loss
0 前言
Focal Loss
是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的。在理解Focal Loss
前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失。然后我们从样本权重的角度出发,理解Focal Loss
是如何分配样本权重的。Focal是动词Focus的形容词形式,那么它究竟Focus在什么地方呢?(详细的代码请看Gitee)。
1 交叉熵
1.1 交叉熵损失(Cross Entropy Loss)
有\(N\)个样本,输入一个\(C\)分类器,得到的输出为\(X\in \mathcal{R}^{N\times C}\),它共有\(C\)类;其中某个样本的输出记为\(x\in \mathcal{R}^{1\times C}\),即\(x[j]\)是\(X\)的某个行向量,那么某个交叉熵损失可以写为如下公式:
=-\log \left( \frac{\exp \left( x\left[\text{class} \right] \right)}{\sum_j{\exp\left( x\left[ j \right] \right)}} \right)
=-x\left[\text{class} \right] +\log \left( \sum_j{\exp\left( x\left[ j \right] \right)} \right)
\tag{1-1}
\]
其中\(\text{class}\in [0,\ C)\)是这个样本的类标签
,如果给出了类标签的权重向量\(W\in \mathcal{R}^{1\times C}\),那么带权重的交叉熵损失可以更改为如下公式:
\tag{1-2}
\]
最终对这个\(N\)个样本的损失求和
或者求平均
:
\sum_{i}^{N}{\text{loss}(x^{(i)},\ \text{class}^{(i)})}&\text{, sum}\\
\dfrac{1}{N}\sum_{i}^{N}{\text{loss}(x^{(i)},\ \text{class}^{(i)})}&\text{, mean}
\end{cases}
\tag{1-3}
\]
这个就是我们平时经常用到的交叉熵损失了。
1.2 二分类交叉熵损失(Binary Cross Entropy Loss)
上面所提到的交叉熵损失是适用于多分类(二分类及以上)的,但是它的公式看起来似乎与我们平时在书上或论文中看到的不一样,一般我们常见的交叉熵损失公式如下:
\]
这是一个典型的二分类交叉熵损失,其中\(y\in\{0,\ 1\}\)表示标签值,\(\hat{y}\in[0,\ 1]\)表示分类模型的类别1预测值
。上面这个公式是一个综合的公式,它等价于:
-\log{\hat{y}_0} &y=0 \\
-\log{\hat{y}_1} &y=1
\end{cases}; \quad
\text{where}\quad \hat{y}_0+\hat{y}_1 = 1
\]
其中\(\hat{y}_0, \hat{y}_1\)是二分类模型输出的2个伪概率值
。
例:如果二分类模型是神经网络,且最后一层为: 2个神经元+Softmax,那么\(\hat{y}_0, \hat{y}_1\)就对应着这两个神经元的输出值。当然它也可以带上类别的权重。
同样地,有\(N\)个样本,输入一个2分类器,得到的输出为\(X\in \mathcal{R}^{N\times 2}\),再经过Softmax函数,\(\hat{Y}=\sigma(X)\in \mathcal{R}^{N\times 2}\),标签为\(Y\in \mathcal{R}^{N\times 2}\),每个样本的二分类损失记为\(l^{(i)}, i=0,1,2,\cdots,N\),最终对这个\(N\)个样本的损失求和
或者求平均
:
\sum_{i}^{N}l^{(i)}&\text{, sum}\\
\dfrac{1}{N}\sum_{i}^{N}l^{(i)}&\text{, mean}
\end{cases}; \ \ \
l^{(i)} = -y^{(i)}\log{\hat{y}^{(i)}}-(1-y^{(i)})\log{(1-\hat{y}^{(i)})}
\]
注:如果一次只训练一个样本,即\(N=1\),那么上面带类别权重的损失中的权重是无效的。因为
权重
是相对的,某一个样本的权重大,那么必然需要有另一个样本的权重小,这样才能体现出这一批样本中某些样本的重要性。\(N=1\)时,已没有权重的概念,它是唯一的,也是最重要的。\(N=1\),或者说batch_size=1
这种情况在训练视频\文章数据时,是会常出现的。由于我们显示/内存的限制,而视频/文章数据又比较大,一次只能训练一个样本,此时我们就需要注意权重的问题了。
2 Focal Loss
2.1 基本思想
一般来讲,Focal Loss(以下简称FL)[1]是为解决样本不平衡
的问题,但是更准确地讲,它是为解决难分类样本(Hard Example)
和易分类样本(Easy Example)
的不平衡问题。对于样本不平衡,其实通过上面的带权重的交叉熵损失便可以一定程度上解决这个问题,但是在实际问题中,以权重来解决样本不平衡问题的效果不够理想,此时我们应当思考,表面上我们的样本不平衡,但实质上导致效果不好的原因也许并不是简单地因为样本不平衡,而是因为样本中存在一些Hard Example
,同时存在许多Easy Example
,Easy Example虽然容易被分类器分辨,损失较小,但是由于其数量大,它们累积起来依然于大于Hard Example的Loss值,因此我们需要给Hard Example较大的权重,而Easy Example较小的权重。
那么什么叫Hard Example,什么叫Easy Example呢?看下面的图就知道了。
图2-1 Hard Example | 图2-2 Easy Example1 | 图2-3 Easy Example2 | 图2-4 Example Space |
假设,我们的任务是训练一个分类器,分类出人和马,对于上面的三张图,图2-2和图2-3应该是非常容易判断出来的,但是图2-1就是不那么容易了,它即有人的特征,又有马的特征,非常容易混淆。这种样本虽然在数据集中出现的频率可能并不高,但是想要提高分类器的性能,需要着力解决这种样本分类问题。
提出Hard Example和Easy Example后,可以将样本空间划分为如图2-4所示的样本空间。其中纵轴为多数类样本(Majority Class)
和少数类样本(Minority Class)
,上面的带权重的交叉熵损失只能解决Majority Class和Minority Class的样本不平衡问题,并没有考虑Hard Example和Easy Example的问题,Focal Loss的提出就是为解决这个难易样本的分类问题。
2.2 Focal Loss解决方案
要解决难易样本的分类问题,首先就需要找出Hard Example和Easy Example。这对于神经网络来说,应该是一件比较容易的事情。如图2-6所示,这是一个5分类的网络,神经网络的最后一层输出时,加上一个Softmax
或者Sigmoid
就会得到输出的伪概率值,代表着模型预测的每个类别的概率,
图2-6 Easy Example Classifier Output | 图2-7 Hard Example Classifier Output |
图2-6中,样本标签为1,分类器输出值最大的为第1个神经元(以0开始计数),这刚好预测准确,而且其输出值2也比其它神经元的输出值要大不少,因此可以认为这是一个易分类样本(Easy Example);图2-7的样本标签是3,分类器输出值最大的为第4个神经元,并且这几个神经元的输出值都相差不大,神经网络无法准确判断这个样本的类别,所以可以认为这是一个难分类样本(Hard Example)。其实说白了,判断Easy/Hard Example的方法就是看分类网络的最后的输出值。如果网络预测准确,且其概率较大,那么这是一个Easy Example,如果网络输出的概率较小,这是一个Hard Example。下面用数学公式严谨地表达来Focal Loss的表达式。
令一个\(C\)类分类器的输出为\(\boldsymbol{y}\in \mathcal{R}^{C\times 1}\),定义函数\(f\)将输出\(\boldsymbol{y}\)转为伪概率值\(\boldsymbol{p}=f(\boldsymbol{y})\),当前样本的类标签为\(t\),记\(p_t=\boldsymbol{p}[t]\),它表示分类器预测为\(t\)类的概率值,再结合上面的交叉熵损失,定义Focal Loss为:
\tag{2-1}
\]
这实质就是交叉熵损失前加了一个权重,只不过这个权重有点不一样的来头。为了更好地控制前面权重的大小,可以给前面的权重系数添加一个指数\(\gamma\),那么更改式(2-1):
\tag{2-2}
\]
其中\(\gamma\)一值取值为2就好,\(\gamma\)取值为0时与交叉熵损失等价,\(\gamma\)越大,就越抑制Easy Example的损失,相对就会越放大Hard Example的损失。同时为解决样本类别不平衡的问题,可以再给式(2-2)添加一个类别的权重\(\alpha_t\)(这个类别权重上面的交叉熵损失已经实现):
\tag{2-3}
\]
到这里,Focal Loss理论就结束了,非常简单,但是有效。
3 Focal Loss实现(Pytorch)
3.1 交叉熵损失实现(numpy)
为了更好的理解Focal Loss的实现,先理解交叉熵损失的实现,我这里用numpy简单地实现了一下交叉熵损失。
import numpy as np
def cross_entropy(output, target):
out_exp = np.exp(output)
out_cls = np.array([out_exp[i, t] for i, t in enumerate(target)])
ce = -np.log(out_cls / out_exp.sum(1))
return ce
代码中第5行,可能稍微有点难以理解,它不过是为了找出标签对应的输出值。比如第2个样本的标签值为3,那它分类器的输出应当选择第2行,第3列的值。
3.2 Focal Loss实现
下面的代码的1012行:依据输出,计算概率,再将其转为`focal_weight`;1516行,将类权重和focal_weight
添加到交叉熵损失,得到最终的focal_loss
;18~21行,实现mean
和sum
两种reduction方法,注意求平均不是简单的直接平均,而是加权平均。
class FocalLoss(nn.Module):
def __init__(self, gamma=2, weight=None, reduction='mean'):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.reduction = reduction
def forward(self, output, target):
# convert output to pseudo probability
out_target = torch.stack([output[i, t] for i, t in enumerate(target)])
probs = torch.sigmoid(out_target)
focal_weight = torch.pow(1-probs, self.gamma)
# add focal weight to cross entropy
ce_loss = F.cross_entropy(output, target, weight=self.weight, reduction='none')
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
focal_loss = (focal_loss/focal_weight.sum()).sum()
elif self.reduction == 'sum':
focal_loss = focal_loss.sum()
return focal_loss
注:上面实现中,output的维度应当满足
output.dim==2
,并且其形状为(batch_size, C)
,且target.max()<C
。
总结
Focal Loss从2017年提出至今,该论文已有2000多引用,足以说明其有效性。其实从本质上讲,它也只不过是给样本重新分配权重,它相对类别权重的分配方法,只不过是将样本空间进行更为细致的划分,从图2-4很容易理解,类别权重的方法,只是将样本空间划分为蓝色线上下两个部分,而加入难易样本的划分,又可以将空间划分为左右两个部分,如此,样本空间便被划分4个部分,这样更加细致。其实借助于这个思想,我们是否可以根据不同任务的需求,更加细致划分我们的样本空间,然后再相应的分配不同的权重呢?
参考文献
处理样本不平衡的LOSS—Focal Loss的更多相关文章
- 【深度学习】Focal Loss 与 GHM——解决样本不平衡问题
Focal Loss 与 GHM Focal Loss Focal Loss 的提出主要是为了解决难易样本数量不平衡(注意:这有别于正负样本数量不均衡问题)问题.下面以目标检测应用场景来说明. 一些 ...
- Focal Loss笔记
论文:<Focal Loss for Dense Object Detection> Focal Loss 是何恺明设计的为了解决one-stage目标检测在训练阶段前景类和背景类极度不均 ...
- focal loss和ohem
公式推导:https://github.com/zimenglan-sysu-512/paper-note/blob/master/focal_loss.pdf 使用的代码:https://githu ...
- 焦点损失函数 Focal Loss 与 GHM
文章来自公众号[机器学习炼丹术] 1 focal loss的概述 焦点损失函数 Focal Loss(2017年何凯明大佬的论文)被提出用于密集物体检测任务. 当然,在目标检测中,可能待检测物体有10 ...
- 技术干货 | 基于MindSpore更好的理解Focal Loss
[本期推荐专题]物联网从业人员必读:华为云专家为你详细解读LiteOS各模块开发及其实现原理. 摘要:Focal Loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失 ...
- 深度学习笔记(八)Focal Loss
论文:Focal Loss for Dense Object Detection 论文链接:https://arxiv.org/abs/1708.02002 一. 提出背景 object detect ...
- 论文阅读笔记四十四:RetinaNet:Focal Loss for Dense Object Detection(ICCV2017)
论文原址:https://arxiv.org/abs/1708.02002 github代码:https://github.com/fizyr/keras-retinanet 摘要 目前,具有较高准确 ...
- Focal Loss理解
1. 总述 Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题.该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘. 2. 损失函数形式 ...
- Focal Loss(RetinaNet) 与 OHEM
Focal Loss for Dense Object Detection-RetinaNet YOLO和SSD可以算one-stage算法里的佼佼者,加上R-CNN系列算法,这几种算法可以说是目标检 ...
随机推荐
- 源映射错误:request failed with status 404
源映射错误:request failed with status 404:源映射错误:request failed with status 404
- Springboot应用中@EntityScan和@EnableJpaRepositories的用法
在Springboot应用开发中使用JPA时,通常在主应用程序所在包或者其子包的某个位置定义我们的Entity和Repository,这样基于Springboot的自动配置,无需额外配置,我们定义的E ...
- getopt、getopt_long和getopt_long_only解析命令行参数
一:posix约定: 下面是POSIX标准中关于程序名.参数的约定: 程序名不宜少于2个字符且不多于9个字符: 程序名应只包含小写字母和阿拉伯数字: 选项名应该是单字符或单数字,且以短横 '-' 为前 ...
- postman 中post方式提交数据
post方式提交数据时,把参数填写在body中而不是pOST下面的哪一行
- webstorm破解教程
1.下载地址 官网:https://www.jetbrains.com/webstorm/ 下载好之后按照提示安装即可,这里就不再多说了.下面直接说说如何使用补丁破解. 2.使用补丁破解 (http: ...
- ubuntu环境变量的三种设置方法
一:设置环境变量的三种方法 1.1 临时设置 export PATH=/home/yan/share/usr/local/arm/3.4.1/bin:$PATH 1.2 当前用户的全局设置 打开~/. ...
- codeforces1253F(图转换为树减少复杂度)
题意: 给定一个无向图,其中1-k为充电桩,然后给定q个询问\(u_i, v_i\)(都是充电桩),然后问从其中一个充电桩到达另外一个充电桩需要最小的电池的容量. 每经过一条边都需要消耗一定的能量,到 ...
- LEMP--如何在Ubuntu上安装Linux、Nginx、MySQL和PHP
简介 LEMP是用来搭建动态网站的一组软件,首字母缩写分别表示Linux.Nginx(Engine-X).MySQL和PHP. 本文将讲述如何在Ubuntu安装LEMP套件.当然,首先要安装Ubunt ...
- iptables单个规则实例
iptables -F? # -F 是清除的意思,作用就是把 FILTRE TABLE 的所有链的规则都清空 iptables -A INPUT -s 172.20.20.1/32 -m state ...
- <climits>头文件
<climits>头文件定义的符号常量 CHAR_MIN char的最小值SCHAR_MAX signed char 最大值SCHAR_MIN signed char 最小值UCHAR_ ...