聊聊损失函数1. 噪声鲁棒损失函数简析 & 代码实现
今天来聊聊非常规的损失函数。在常用的分类交叉熵,以及回归均方误差之外,针对训练样本可能存在的数据长尾,标签噪声,数据不均衡等问题,我们来聊聊适用不同场景有针对性的损失函数。第一章我们介绍,当标注标签存在噪声时可以尝试的损失函数,这里的标签噪声主要指独立于特征分布的标签噪声。代码详见pytorch, Tensorflow
Symmetric Loss Function
paper: Making Risk Minimization Tolerant to Label Noise
这里我们用最基础的二分类问题,和一个简化的假设"标注噪声和标签独立且均匀分布",来解释下什么是对标注噪声鲁棒的损失函数。假设整体误标注的样本占比为\(\eta\),则在真实标签y=0和y=1中均有\(\eta\)比例的误标注,1被标成0,0被标称1。带噪声的损失函数如下
L(f(x), y_{noise}) &= (1-\eta)*L(f(x), y) + \eta * L(f(x), 1-y) \\
& = (1-2\eta)*L(f(x),y) + \eta*[L(f(x),y)+L(f(x),1-y)] \\
& = (1-2\eta)*L(f(x),y) + \eta K \\
\end{align}
\]
因此如果损失函数满足\(L(f(x),y)+L(f(x),1-y)=constant\),则带噪声的损失函数会和不带噪声的\(L(f(x),y)\)收敛到相同的解。作者认为这样的损失函数就是symmetric的。
那有哪些常见的损失函数是symmetric loss呢?
MAE就是!对于二分类的softmax的输出层\(L(f(x),y)+L(f(x),1-y)=|y-f(x)| + |1-y-f(x)| = 1\)
敲黑板!记住这一点,因为后面的GCE和SCE其实都和MAE有着脱不开的关系。这里对symmetric loss的论证做了简化,细节详见论文~
Generalized Cross Entropy(GCE)
paper:Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels
话接上文,MAE虽然是一种noise robust的损失函数,但是在深度学习中,因为MAE的梯度不是1就是-1,所有样本梯度scale都相同,缺乏对样本难易程度和模型置信度的刻画,因此MAE很难收敛。
作者提出了一种融合MAE和Cross Entropy的方案,话不多说直接上Loss
\]
作者使用了negative box-cox来作为损失函数,乍看和MAE没啥关系。不过改变q的取值,就会发现玄妙所在
- q->1: \(L=1-f_j(x)\), 就是MAE Loss
- q->0: 根据洛必达法则,对分子分母同时求导,就会得到\(L=-log(f_j(x))\), 就是Cross Entropy
所以GCE损失函数通过控制q的取值,在MAE和CrossEntropy中寻找折中点。这个和Huber Loss的设计有些相似,只不过Huber是显式的用alpha权重来融合RMSE和MAE,而GCE是隐式的融合。q->1, 对噪声的鲁棒性更好,但更难收敛。作者还提出了截断GCE,对过大的loss进行截断,这里就不细说了~
pytorch实现如下,TF实现见文首链接
class GeneralizeCrossEntropy(nn.Module):
def __init__(self, q=0.7):
super(GeneralizeCrossEntropy, self).__init__()
self.q = q
def forward(self, logits, labels):
# Negative box cox: (1-f(x)^q)/q
labels = torch.nn.functional.one_hot(labels, num_classes=logits.shape[-1])
probs = F.softmax(logits, dim=-1)
loss = (1 - torch.pow(torch.sum(labels * probs, dim=-1), self.q)) / self.q
loss = torch.mean(loss)
return loss
Symmetric Cross Entropy(SCE)
Symmetric Cross Entropy for Robust Learning with Noisy Labels
作者是从交叉熵的另一个含义出发, 最小化交叉熵实际是为了最小化预测分布和真实分布的KL散度, 二者关联如下,其中H(y)是真实标签的信息熵是个常数
KL(y||f(x)) &= \sum ylog(f(x)) - \sum ylog(y) \\
& = H(y, f(x)) - H(y) = CrossEntropy(y, f(x)) - H(y)
\end{align}
\]
考虑KL散度是非对称的,KL(y||f(x))!=KL(f(x)||y), 前者度量的是使用预测分布对数据进行编码导致的信息损失。然而当y本身存在噪声时,y可能不是正确标签,f(x)才是,这时就需要考虑另一个方向KL散度KL(f(x)||y)。于是作者使用对称KL对应的对称交叉熵(SCE)作为损失函数
= \sum_j y_jlog(f_j(x)) + \sum_j f_j(x)log(y_j)
\]
看到这里多少会有一种作者又拍脑袋了的感觉>.<.不过只需要对RCE的部分做下变换就豁然开朗了。以二分类为例,log(0)无法计算用常数A代替
\]
RCE的部分就是一个MAE!所以SCE本质上是显式的融合交叉熵和MAE!pytorch实现如下,TF实现见文首链接
class SymmetricCrossEntropy(nn.Module):
def __init__(self, alpha=0.1, beta=1):
super(SymmetricCrossEntropy, self).__init__()
self.alpha = alpha
self.beta = beta
self.epsilon = 1e-10
def forward(self, logits, labels):
# KL(p|q) + KL(q|p)
labels = torch.nn.functional.one_hot(labels, num_classes=logits.shape[-1])
probs = F.softmax(logits, dim=-1)
# KL
y_true = torch.clip(labels, self.eps, 1.0 - self.eps)
y_pred = probs
ce = -torch.mean(torch.sum(y_true * torch.log(y_pred), dim=-1))
# reverse KL
y_true = probs
y_pred = torch.clip(labels, self.eps, 1.0 - self.eps)
rce = -torch.mean(torch.sum(y_true * torch.log(y_pred), dim=-1))
return self.alpha * ce + self.beta * rce
Peer Loss
- Peer Loss Functions:Learning from Noisy Labels without Knowning Noise Rates
- NLNL: Negative Learning for Noisy Labels
Peer Loss相比GCE和SCE只适用于Cross Entropy, 它的设计更加灵活。每个样本的损失函数由常规loss和随机label的loss加权得到,权重为alpha,这里的loss支持任意的分类损失函数。随机label作者通过打乱一个batch里面的label顺序得到~
原理上感觉Peer Loss和NLNL很是相似都是negative learning的思路。对比下二者的损失函数,PL是最小化带噪标签y的损失的同时,最大化模型在随机标签上的损失。NL是直接最大化模型在非真实标签y上的损失。本质上都是negative learning,模型学习的不是x是什么,而是x不是什么,通过推动所有不正确分类的p->0,来得到正确的标签。从这个逻辑上说感觉Peer Loss和NLNL在高维的多分类场景下应该有更好的表现~
\]
\]
pytorch实现如下,TF实现见文首链接
class PeerLoss(nn.Module):
def __init__(self, alpha=0.5, loss):
super(PeerLoss, self).__init__()
self.alpha = alpha
self.loss = loss
def forward(self, preds, labels):
index = list(range(labels.shape[0]))
rand_index = random.shuffle(index)
rand_labels = labels[rand_index]
loss_true = self.loss(preds, labels)
loss_rand = self.loss(preds, rand_labels)
loss = loss_true - self.alpha * loss_rand
return loss
Bootstrap Loss
Training Deep Neural Networks on Noisy Labels with Bootstrapping
Bootstrap Loss是从预测一致性的角度来降低噪声标签对模型的影响,作者给了soft和hard两种损失函数。
soft Bootstrap是在Cross Entropy的基础上加上预测熵值,在最小化预测误差的同时最小化概率熵值,推动概率趋近于0/1,得到更置信的预测。这里其实用到了之前在半监督时提到的最小熵原则(小样本利器3. 半监督最小熵正则)也就是推动分类边界远离高密度区。
对噪声标签,模型初始预估的熵值会较大(p->0.5), 因为加入了熵正则项,模型即便不去拟合噪声标签,而是向正确标签移动(提高预测置信度降低熵值),也会降低损失函数.不过这里感觉熵正则的引入也有可能使得模型预测置信度过高而导致过拟合
\]
而Hard Bootstrap是把以上的预测概率值替换为预测概率最大的分类,Hard相比Soft更加类似label smoothing。举个栗子:当真实标签为y=0,噪声标签y=1,预测概率为[0.7,0.3]时,\(\beta=0.9\)时Bootstrap拟合的y实际为[0.1,0.9], 会降低错误标签的置信度,给模型学习其他标签的机会。而当模型预测和标签一致时y值不变,所以不会对正确有样本有太多影响,效果上作者评估也是Hard Bootstrap的效果要显著更好~
\]
pytorch实现如下,TF实现见文首链接
class BootstrapCrossEntropy(nn.Module):
def __init__(self, beta=0.95, is_hard=0):
super(BootstrapCrossEntropy, self).__init__()
self.beta = beta
self.is_hard = is_hard
def forward(self, logits, labels):
# (beta * y + (1-beta) * p) * log(p)
labels = F.one_hot(labels, num_classes=logits.shape[-1])
probs = F.softmax(logits, dim=-1)
probs = torch.clip(probs, self.eps, 1 - self.eps)
if self.is_hard:
pred_label = F.one_hot(torch.argmax(probs, dim=-1), num_classes=logits.shape[-1])
else:
pred_label = probs
loss = torch.sum((self.beta * labels + (1 - self.beta) * pred_label) * torch.log(probs), dim=-1)
loss = torch.mean(- loss)
return loss
对更多降噪loss感兴趣的朋友望过来https://github.com/subeeshvasu/Awesome-Learning-with-Label-Noise
又到年末填坑时间,争取把今年写了一半的草稿都补完,冲鸭!
Reference
- https://zhuanlan.zhihu.com/p/147371861
- https://blog.csdn.net/suredied/article/details/113528384
- https://zhuanlan.zhihu.com/p/370775044
- https://zhuanlan.zhihu.com/p/569526954
- https://zhuanlan.zhihu.com/p/299404214
聊聊损失函数1. 噪声鲁棒损失函数简析 & 代码实现的更多相关文章
- Huber鲁棒损失函数
在统计学习角度,Huber损失函数是一种使用鲁棒性回归的损失函数,它相比均方误差来说,它对异常值不敏感.常常被用于分类问题上. 下面先给出Huber函数的定义: 这个函数对于小的a值误差函数是二次的, ...
- H∞一般控制问题的鲁棒叙述性说明
Robust Control System:反馈控制有承受一定类不确定能力的影响,这一直保持在这种不确定的条件(制)稳定.动态特性(灵敏度)和稳态特性(逐步调整)的能力. 非结构不确定性(Unstru ...
- 如何编写高质量的 JS 函数(2) -- 命名/注释/鲁棒篇
本文首发于 vivo互联网技术 微信公众号 链接:https://mp.weixin.qq.com/s/sd2oX0Z_cMY8_GvFg8pO4Q作者:杨昆 上篇<如何编写高质量的 JS 函数 ...
- CVPR2020:基于自适应采样的非局部神经网络鲁棒点云处理(PointASNL)
CVPR2020:基于自适应采样的非局部神经网络鲁棒点云处理(PointASNL) PointASNL: Robust Point Clouds Processing Using Nonlocal N ...
- 基于2D-RNN的鲁棒行人跟踪
基于2D-RNN的鲁棒行人跟踪 Recurrent Neural Networks RNN 行人跟踪 读"G.L. Masala, et.al., 2D Recurrent Neural N ...
- SIFT+HOG+鲁棒统计+RANSAC
今天的计算机视觉课老师讲了不少内容,不过都是大概讲了下,我先记录下,细讲等以后再补充. SIFT特征: 尺度不变性:用不同参数的高斯函数作用于图像(相当于对图像进行模糊,得到不同尺度的图像),用得到的 ...
- Robust Locally Weighted Regression 鲁棒局部加权回归 -R实现
鲁棒局部加权回归 [转载时请注明来源]:http://www.cnblogs.com/runner-ljt/ Ljt 作为一个初学者,水平有限,欢迎交流指正. 算法参考文献: (1) Robust L ...
- 鲁棒图(Robustness Diagram)
鲁棒图与系统需求分析 鲁棒图(Robustness Diagram)是由Ivar Jacobson于1991年发明的,用以回答“每个用例需要哪些对象”的问题.后来的UML并没有将鲁棒图列入UML标准, ...
- python练习 英文字符的鲁棒输入+数字的鲁棒输入
鲁棒 = Robust 健壮 英文字符的鲁棒输入 描述 获得用户的任何可能输入,将其中的英文字符进行打印输出,程序不出现错误. ...
- 【论文阅读】Beyond OCR + VQA: 将OCR融入TextVQA的执行流程中形成更鲁棒更准确的模型
论文题目:Beyond OCR + VQA: Involving OCR into the Flow for Robust and Accurate TextVQA 论文链接:https://dl.a ...
随机推荐
- 又拍云+PicGo搭建图床教程
具体搭建方法 https://blog.csdn.net/qq_41684621/article/details/114068076 这里有个细节 注意这里一定要加上 http:// 否则在自动生成 ...
- OpenGL 模型加载详解
1. Assimp 目前为止,我们已经可以绘制一个物体,并添加不同的光照效果了.但是我们的顶点数据太过简单,只能绘制简单的立方体.但是房子汽车这种不规则的形状我们的顶点数据就很难定制了.索性,这部分并 ...
- inget
万能密码考点 payload ?id=1' or 1=1--+
- Java核心知识体系7:线程安全性讨论
Java核心知识体系1:泛型机制详解 Java核心知识体系2:注解机制详解 Java核心知识体系3:异常机制详解 Java核心知识体系4:AOP原理和切面应用 Java核心知识体系5:反射机制详解 J ...
- extern关键字的用法
extern关键字的理解 extern是C/C++语言中的一个关键字,用于声明一个变量或函数具有外部链接性(external linkage),即这些变量或函数可以被其他文件访问. 在C/C++中,如 ...
- DDD学习与感悟——总是觉得自己在CRUD怎么办?
一.DDD是什么? DDD全名叫做Domins drives Design:领域驱动设计.再说的通俗一点就是:通过领域建模的方式来实现软件设计. 问题来了:什么是软件设计?为什么要进行软件设计? 软件 ...
- 掌握这些,轻松管理BusyBox:inittab文件的配置和作用解析
BusyBox 是一个轻量级的开源工具箱,其中包含了许多标准的 Unix 工具,例如 sh.ls.cp.sed.awk.grep 等,同时它也支持大多数关键的系统功能,例如自启动.进程管理.启动脚本等 ...
- LeetCode15:三数之和(双指针)
解题思路:常规解法很容易想到O(n^3)的解法,但是,n最大为1000,很显然会超时. 如何优化到O(n^2),a+b+c =0,我们只需要判断 a+b的相反数是否在数组中出现,而且元素的取值范围在 ...
- Cloudeye对接Prometheus实现华为云全方位监控
本文分享自华为云社区<Cloudeye对接Prometheus实现华为云全方位监控>,作者:可以交个朋友 . 一. 背景 云眼系统Cloudeye服务为我们提供了针对弹性云服务器.宽带等资 ...
- android webview(外部浏览器)调起app
最近写的项目中涉及外部浏览器以及项目webview中调起app,所以总结下,和大家分享下. 总的实现方法还是比较简单的, 1:在清单中注册 首先在AndroidManifest文件中,注册一个过滤器 ...