Focal loss是目标检测领域的一篇十分经典的论文,它通过改造损失函数提升了一阶段目标检测的性能,背后关于类别不平衡的学习的思想值得我们深入地去探索和学习。正负样本失衡不仅仅在目标检测算法中会出现,在别的机器学习任务中同样会出现,这篇论文为我们解决类似问题提供了一个很好的启发,所以我认为无论是否从事目标检测领域相关工作,都可以来看一看这篇好论文。

论文的关键性改进在于对损失函数的改造以及对参数初始化的设置。

首先是对损失函数的改造。论文中指出,限制目标检测网络性能的一个关键因素是类别不平衡。二阶段目标检测算法相比于一阶段目标检测算法的优点在于,二阶段的目标检测算法通过候选框筛选算法(proposal stage)过滤了大部分背景样本(负样本),使得正负样本比例适中;而一阶段的目标检测算法中,需要处理大量的负样本,使得包含目标的正样本信息被淹没。这使得一阶段目标检测算法的识别准确度上比不上二阶段的目标检测算法。

为了解决这个问题,Focal loss使用了动态加权的思想,对于置信度高的样本,损失函数进行降权;对于置信度低的样本,损失函数进行加权,使得网络在反向传播时,置信度低的样本能够提供更大的梯度占比,即从未学习好的样本中获取更多的信息(就像高中时期的错题本一样,对于容易错的题目,包含了更多的信息量,需要更加关注这种题目;而对于屡屡正确的题目,可以少点关注,说明已经掌握了这类型的题目)

其巧妙之处就在于,通过了网络本身输出的概率值(置信度)去构建权重,实现了自适应调整权重的目的。

公式的讲解

Focal loss是基于交叉熵损失构建的,二元交叉熵的公式为

\[\mathrm{CE}(p, y)=\left\{\begin{array}{ll}
-\log (p) & \text { if } y = +1 \\
-\log (1-p) & \text { y = -1 }
\end{array}\right.
\]

为了方便表示,定义\(p_t\)为分类正确的概率

\[p_{t}=\left\{\begin{array}{ll}
p & \text { if } y = +1 \\
1-p & \text { y = -1 }
\end{array}\right.
\]

则交叉熵损失表示为\(CE(p,y)=CE(p_t)=-log(p_t)\)。如前文所述,通过置信度对损失进行缩放得到Focal loss。

\[FL(p_t)=-\alpha_t(1-p_t)^\gamma log(p_t)= \alpha_t(1-p_t)^\gamma\times CE(p_t)
\]

其中,\(\alpha_{1}=\left\{\begin{array}{ll}
\alpha & \text { if } y = +1 \\
1-\alpha & \text { y = -1 }
\end{array}\right.\)为缩放乘数(直接调整正负样本的权重),\(\gamma\)为缩放因子,\((1-p_t)\)可以理解为分类错误的概率。公式中起到关键作用的部分是\((1-p_t)^\gamma\)。为了给易分样本降权,通常设置\(\gamma>1\)。

对于正确分类的样本,\(p_t \to 1 \Rightarrow(1-p_t) \to 0\),受到\(\gamma\)的影响很大,\((1-p_t)^\gamma \approx 0\);

对于错误分类的样本,\(p_t \to 0 \Rightarrow(1-p_t) \to 1\),受到\(\gamma\)的影响较小,\((1-p_t)^\gamma \approx (1-p_t)\),对于难分样本的降权较小。

Focal loss本质上是通过置信度给易分样本进行更多的降权,对难分样本进行更少的降权,实现对难分样本的关注。

参数初始化

论文中还有一个比较重要的点是对于子网络最后一层权重的初始化方式,关系到网络初期训练的性能。这里结合论文和我看过的一篇博文进行详细的展开。常规的深度学习网络初始化算法,使用的分布是高斯分布,根据概率论知识,两个高斯分布的变量的乘积仍然服从高斯分布。假设权重\(w\sim N(\mu_w,\sigma_w^2)\),最后一层的特征\(x\sim N(\mu_x,\sigma_x^2)\),则\(wx \sim N(\mu_{wx},\sigma_{wx}^2)\)。

\[\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}\\
\sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2}
\]

其中\(x\)的分布取决于网络的结果,\(w\)的分布参数为\(\mu_w=0,\sigma_w^2=10^{-4}\),只需\(x\)的分布参数满足\(\sigma_x^2\gg 10^{-4},\sigma_x^2\gg10^{-4}\mu_x\)成立,有如下的不等式。(一般情况下,这两个条件是成立的。)

\[\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}\mu_x}{\sigma_x^2+10^{-4}}\ll\frac{10^{-4}\mu_x}{10^{-4}\mu_x+10^{-4}}=\frac{1}{1+\frac{1}{\mu_x}}\approx0 \text{由于}\mu_x\text{一般为分数(网络的输入经过归一化到0至1,随着网络加深的连乘,分数会越来越小)}\\
\sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}}{1+\frac{10^{-4}}{\sigma_x^2}}\approx10^{-4} \text{由于}\sigma_x^2\gg10^{-4}
\]

根据上述推导,\(wx\)服从一个均值为0,方差很小的高斯分布,可以在很大概率上认为它就等于0,所以网络最后一层的输出为

\[p=sigmoid(wx+b)=sigmoid(b)=\frac{1}{1+e^{-b}}=\pi
\]

令\(\pi\)为网络初始化时输出为正类的概率,设置为一个很小的值(0.01),则网络在训练初期,将样本都划分为负类,对于正类\(p_t=0.01\),负类\(p_t=0.99\),则训练初期,正类都被大概率错分,负类都被大概率正确分类,所以在训练初期更加关注正类,避免初期的正类信息被淹没在负类信息中。

总结

总的来说,Focal loss通过对损失函数的简单改进,实现了一种自适应的困难样本挖掘策略,使得网络在学习过程中关注更难学习的样本,在一定程度上解决了正负样本分布不均衡的问题(由于正负样本分布不均衡,对于稀少的正样本学习不足,导致正样本普遍表现为难分样本)。

参考资料

论文原文

一篇不错的解析博客

Focal loss论文解析的更多相关文章

  1. 论文阅读笔记四十四:RetinaNet:Focal Loss for Dense Object Detection(ICCV2017)

    论文原址:https://arxiv.org/abs/1708.02002 github代码:https://github.com/fizyr/keras-retinanet 摘要 目前,具有较高准确 ...

  2. Focal Loss for Dense Object Detection 论文阅读

    何凯明大佬 ICCV 2017 best student paper 作者提出focal loss的出发点也是希望one-stage detector可以达到two-stage detector的准确 ...

  3. [论文理解]Focal Loss for Dense Object Detection(Retina Net)

    Focal Loss for Dense Object Detection Intro 这又是一篇与何凯明大神有关的作品,文章主要解决了one-stage网络识别率普遍低于two-stage网络的问题 ...

  4. 论文阅读|Focal loss

    原文标题:Focal Loss for Dense Object Detection 概要 目标检测主要有两种主流框架,一级检测器(one-stage)和二级检测器(two-stage),一级检测器, ...

  5. 深度学习笔记(八)Focal Loss

    论文:Focal Loss for Dense Object Detection 论文链接:https://arxiv.org/abs/1708.02002 一. 提出背景 object detect ...

  6. Focal Loss笔记

    论文:<Focal Loss for Dense Object Detection> Focal Loss 是何恺明设计的为了解决one-stage目标检测在训练阶段前景类和背景类极度不均 ...

  7. [Network Architecture]Mask R-CNN论文解析(转)

    前言 最近有一个idea需要去验证,比较忙,看完Mask R-CNN论文了,最近会去研究Mask R-CNN的代码,论文解析转载网上的两篇博客 技术挖掘者 remanented 文章1 论文题目:Ma ...

  8. 处理样本不平衡的LOSS—Focal Loss

    0 前言 Focal Loss是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的.在理解Focal Loss前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失.然后我们从样本 ...

  9. 目标检测 | RetinaNet:Focal Loss for Dense Object Detection

    论文分析了one-stage网络训练存在的类别不平衡问题,提出能根据loss大小自动调节权重的focal loss,使得模型的训练更专注于困难样本.同时,基于FPN设计了RetinaNet,在精度和速 ...

随机推荐

  1. 关于bat批处理的一些操作,如启动jar 关闭进程等

    先说一下学习这个的前提: 公司要写个生成uid的工具,整完了之后就又整批处理工具,出于此目的,也是为了丰富自己的知识,就学习了一下,下面是相关的批处理脚本 我花了半天的时间找了相关的bat批处理,但是 ...

  2. Unity报与System.IO相关的错误

    比如这个: Type `System.IO.FileInfo' does not contain a definition for `OpenText' and no extension method ...

  3. python之结合if条件判断和生成随机数的相关知识,完成石头剪刀布的游戏

    程序开始,显示下面提示信息: 请输入:剪刀(0).石头(1).布(2): 用户输入数字0-2中的一个数字,与系统随机生成的数字比较后给出结果信息. 例如:输入0后,显示如下 你的输入为:剪刀(0) 随 ...

  4. Redux异步解决方案之Redux-Thunk原理及源码解析

    前段时间,我们写了一篇Redux源码分析的文章,也分析了跟React连接的库React-Redux的源码实现.但是在Redux的生态中还有一个很重要的部分没有涉及到,那就是Redux的异步解决方案.本 ...

  5. css3渐变色实现小功能 ------ css(linaer-gradient)

    由沿直线两种或多种颜色之间的渐进转换的图像.它的结果是数据类型的对象,这是一种特殊的类型. 与任何梯度一样,线性梯度没有内在维度 ; 即,它没有天然或优选的尺寸,也没有优选的比例.其具体尺寸将与其适用 ...

  6. 有关Sql中时间范围的问题

    背景 有时候需要利用sql中处理关于时间的判别问题,简单的如比较时间的早晚,判断一个时间是否在一段时间内的问题等.如果简单将时间判断与数值比较等同,那就会出现一些问题. 处理方式 处理Sql时间范围的 ...

  7. leetcode刷题-48旋转图像

    题目 给定一个 n × n 的二维矩阵表示一个图像. 将图像顺时针旋转 90 度. 思路 没有想到.看过解答后知道可以转置加翻转即可,且能达到最优的时间复杂度O(N^2). 实现 class Solu ...

  8. 动手编写—动态数组(Java实现)

    目录 数组基础回顾 自定义动态数组 动态数组的设计 抽象父类接口设计 抽象父类设计 动态数组之DynamicArray 补充数组缩容 全局的关系图 声明 数组基础回顾 1.数组是一种常见的数据结构,用 ...

  9. jmeter连数据库

    前提:jmeter不能直接连数据库,需要导入一个jar包 步骤: 1.右键线程组--添加--配置元件--JDBC Connection Configuration 2.jdbc的基本配置:可以修改jd ...

  10. oracle之三手工完全恢复

    手工完全恢复 3.1 完全恢复:通过备份.归档日志.current log ,将database恢复到failure 前的最后一次commit状态. 3.2 完全恢复的步骤 1)restore: OS ...