【论文阅读】Deep Mutual Learning
文章:Deep Mutual Learning
出自CVPR2017(18年最佳学生论文)
文章链接:https://arxiv.org/abs/1706.00384
代码链接:https://github.com/YingZhangDUT/Deep-Mutual-Learning
主要贡献:
提出了一种简单且普遍适用的方法,通过在相同/不同的未预训练的网络中进行相互蒸馏,来提高深层神经网络的性能。通过这种方法,我们可以获得比静态教师从强网络中提取的网络性能更好的紧凑网络.
和有教师指导的蒸馏模型相比,相互学习策略具有以下优点:1)随着学生网络的增加其效率也得到提高;2)它可以应用在各种各样的网络中,包括大小不同的网络;3)即使是非常大的网络采用相互学习策略,其性能也能够得到提升
由于是学生网络相互学习,而不是传统知识萃取,文章也说明了两个以上网络共同学习的策略,并从熵值的角度给出理论支持。
知识蒸馏的内容不再赘述,https://blog.csdn.net/nature553863/article/details/80568658
整理得非常完善。
网络结构及损失函数:
每个网络由常规的有监督学习损失和拟态损失来共同训练。拟态损失是指是每个学生的后验类别要和其他学生的类别概率相一致。(拟态在生物学中是指一个动物在进化的过程中会获得与成功物种相似的特征,成功混淆掠食者的认知,从而靠近拟态物种)
如图一所示,同个输入,分别经过Net1和Net2两个网络(和孪生网络不同,这里权重不共享),得到两个logits记为z1,z2(所说文章中有明确提到这里的z是softmax之后的,但是我整的时候取得却是softmax之前的),经过softmax得到预测的软分布p1,p2。送入KL散度计算这两个分布的相似性作为拟态损失。与label比较计算标签损失。
注意,上式表示的是:假设p2是数据的真实分布,p1是数据的理论分布,即是让p1的分布更接近p2的分布。这个别搞反了,毕竟KL散度有几个性质:(1)不对称性:尽管KL散度从直观上是个度量或距离函数,但它并不是一个真正的度量或者距离,因为它不具有对称性,即D(P||Q)!=D(Q||P)。(恩,还有非负和不满足三角不等,这两个在这没用就不写了)。所幸,KL散度求出的值是在0-1之间(分布完全相同则为0),比一般的L1,L2损失更适合拟合分布的损失。但是,如果两个分配P,Q离得很远,完全没有重叠的时候,那么KL散度值是没有意义的,这在学习算法中是比较致命的,这就意味这这一点的梯度为0。梯度消失了。(说实话,文中没有给出其他度量下的结果,我试了KL变为JS,L1,L2这几种分布度量,结果十分相近,KL或许不是最合适的)
网络Net1的总损失为:
其中:
网络Net2的总损失为:
优化步骤:
相互学习策略在每一个基于小批量的模型更新步骤和整个培训过程中执行。在每次迭代中,我们计算两个模型的预测,并根据另一个模型的预测更新两个网络的参数。两个网络的优化是迭代进行的,直到收敛。优化细节总结在算法1中。
恩,这个很简单,就不翻译了。就是固定一个网络更新另一个,循环交替直至收敛。
扩展到更多的网络(理论上你想要几个,只要机器够牛逼都行):
提出的相互学习策略可以扩展到更多的学生网络,假设有K个学生网络,其损失目标函数变为:
添加了系数k-1,以确保培训主要由对真实标签的监督学习指导.拟态损失为与其他所有网络概率分布相似性的平均值。恩,之后的实验都默认使用这一种哈。
另一个学习策略是将所有其他K-1网络的集合作为单个教师,以提供平均的分布概率,这非常类似于蒸馏方法,θk的目标函数可以写成:
拟态损失为与其他所有网络概率分布平均值的相似性。
~_~。这个大家可以类比多分类问题中一个多分类器与多个二分类器。大家可能下意识的以为后一种策略会有更高的精度,但事实恰恰相反,文章最后倒是有给出理由,后面再说吧。
实验部分:
两个数据集,一个是cifar100,一个是行人再检测的maket1501。参数设置什么的文章有给,我不列出来了,反正只要你设定的不是太离谱,都可以得到好结果。
Independent表示单个网络运行结果,DML-Independent,恩,就是DML的结果减去Independent的结果,普遍提升1.5个点(机器问题只试了resnet的,的确能提升这么多)。
这是和知识蒸馏进行对比,当然了,是最原始的histon2014的那版。
这个,算消融实验,只是结果太好了,作者就和其他方法进行了对比。single-query表示单索引,multi-query表示多索引,百度下或问做再检测的朋友是啥子。
这个是扩展到N个网络的结果,恩,之前不是说了两种扩展方法吗,这里都是第一种。左边这个图呢,不是有N个网络嘛,N个网络分别得到N个精度,然后取平均,就得到左图了(应该是凑字数用的吧,应该没人会这么做吧,好浪费的,投票不行吗)。右图表示基于所有成员的连接特征进行匹配,精度明显随网络的增多而增加。
理论:
最后作者回答了为什么这样联合学习能提升精度。《Entropy-sgd:
Biasing gradient descent into wide
valleys》这篇文章大家有兴趣的可以去瞅瞅,个人认为解释了我的某些疑惑。通常情况下有很多的解决方案能够让训练误差变为0,然而有些解决方法的泛化能力要强一点。因为梯度下降法找打的最优点不是在狭小的谷底内,而是在宽阔的峡谷中,当我们加入相对熵以后能够让网络找到更小的值,从而实现更优的结果。
对于DML模型和独立模型,比较了在每个模型参数中加入独立高斯噪声和可变标准差σ后,学习模型的训练损失。发现两个损失基本是相同的,但是在加入这个扰动后,独立模型的训练损失相比DML模型的损失会增加得多。这表明DML模型已经找到了一个更广泛的极小值,这有望提供更好的泛化性能。
在提出的DML策略中,每个学生都是由队列中的所有其他学生单独教授的,不管队列中有多少学生。一种替代的DML策略称为ensemble策略,要求每个学生匹配队列中所有其他学生的联合预测。人们可以合理地期望这种ensemble方法会更好。由于集合预测比单个预测更好,因此它应该提供一个更清晰、更强的教学信号——更像常规蒸馏。在实践中,联合教学(ensemble)比依次单独教学(peer)的效果更差。通过对联合ensemble教学信号的分析,与依次单独peer教学相比,ensemble目标在真标签上的峰值要比peer目标明显得多,从而使DML的预测熵值大于DML-E。因此,虽然集成的噪声平均特性对于做出正确的预测是有效的,但它实际上不利于提供一个教学信号,其中secondary probabilities在总体信号中是突出的,具有高熵的后验代表这模型训练的更鲁棒的解决方案。(大家联想一下软标签与硬标签的不同。)
当然,文中也有说明How
a Better Minima is Found,但是没有看明白,对前置知识(《Entropy-sgd:
Biasing gradient descent into wide valleys》与《Regularizing
neural networks by penalizing confident output
distributions》)没有足够的了解,翻译一下原文:当要求每个网络匹配其对等网络的概率估计时,如果给定网络预测为零,而其教师/对等网络预测为非零,则会受到严重惩罚。因此,DML的总体效果是,当每个网络独立地将一个小的mass放在一组小的secondary
probabilities上时,DML中的所有网络都倾向于聚合它们对secondary
probabilities的预测,i)将更多的mass放在secondary
probabilities上,以及ii)将非零mass置于更明显的secondary
probabilities上。通过比较图中由DML培训的CIFAR-100上的resnet-32获得的top5与独立培训的resnet-32模型的概率来说明这一影响。对于每个训练样本,根据模型产生的后验概率对top5进行排序。在这里,我们可以看到,对于独立学习来说,将mass分配到top1下的概率比DML学习要快得多。这可以用DML训练模型和独立训练模型的所有训练样本的平均熵值来量化,分别为1.7099和0.2602。因此,我们的方法与基于熵正则化的方法有联系,以寻找宽的极小值,但通过对“合理”的选择的相互概率匹配,而不是盲目的高熵偏好。
写在最后:个人认为适用范围挺广的,效果提升也很明显,也还留了许多可以改进的点。知识蒸馏的想法是通过教师网络提供的软标签,给予学生网络硬标签所不能表达的新的信息量(类似与猫更像狗而不是蛋糕)。而文中说的联合学习的方式也提供了更多的信息量。或许有的同学会认为,这是学习了更鲁棒的特征,而不是蕴含了更多的信息量,恩,不相信的朋友不妨试一下把KL(logits)换成KL(features)看下结果。全连接fc之后为logits,之前为features。
【论文阅读】Deep Mutual Learning的更多相关文章
- [论文阅读] Deep Residual Learning for Image Recognition(ResNet)
ResNet网络,本文获得2016 CVPR best paper,获得了ILSVRC2015的分类任务第一名. 本篇文章解决了深度神经网络中产生的退化问题(degradation problem). ...
- Paper | Deep Mutual Learning
目录 1. 动机详述和方法简介 2. 相关工作 3. 方法 3.1 Formulation 3.2 实现 3.3 弱监督学习 4. 实验 4.1 基本实验 4.2 深入实验 [算法和公式很simple ...
- [论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks
[论文阅读笔记] Adversarial Learning on Heterogeneous Information Networks 本文结构 解决问题 主要贡献 算法原理 参考文献 (1) 解决问 ...
- Deep Mutual Learning
论文地址: https://arxiv.org/abs/1706.00384 论文简介 该论文探讨了一种与模型蒸馏(model distillation)相关却不同的模型---即相互学习(mutual ...
- 论文笔记——Deep Residual Learning for Image Recognition
论文地址:Deep Residual Learning for Image Recognition ResNet--MSRA何凯明团队的Residual Networks,在2015年ImageNet ...
- [论文理解]Deep Residual Learning for Image Recognition
Deep Residual Learning for Image Recognition 简介 这是何大佬的一篇非常经典的神经网络的论文,也就是大名鼎鼎的ResNet残差网络,论文主要通过构建了一种新 ...
- 论文阅读:Multi-task Learning for Multi-modal Emotion Recognition and Sentiment Analysis
论文标题:Multi-task Learning for Multi-modal Emotion Recognition and Sentiment Analysis 论文链接:http://arxi ...
- 论文阅读《End-to-End Learning of Geometry and Context for Deep Stereo Regression》
端到端学习几何和背景的深度立体回归 摘要 本文提出一种新型的深度学习网络,用于从一对矫正过的立体图像回归得到其对应的视差图.我们利用问题(对象)的几何知识,形成一个使用深度特征表示的代价量(c ...
- 论文阅读 Inductive Representation Learning on Temporal Graphs
12 Inductive Representation Learning on Temporal Graphs link:https://arxiv.org/abs/2002.07962 本文提出了时 ...
随机推荐
- SpringBoot学习(一)基础篇
目录 关于Springboot Springboot优势 快速入门 关于SpringBoot Spring Boot是由Pivotal团队提供的全新框架,其设计目的是用来简化新Spring应用的初始搭 ...
- (day31) Event+协程+进程/线程池
目录 昨日回顾 GIL全局解释器锁 计算密集型和IO密集型 死锁现象 递归锁 信号量 线程队列 FOFI队列 LIFO队列 优先级队列 今日内容 Event事件 线程池与进程池 异步提交和回调函数 协 ...
- 微信支付 get_brand_wcpay_request fail,Undefined variable: openid
本文将为您描述微信H5支付,微信JSAPI支付返回支付签名验证失败的解决方法 微信JSAPI支付时报这个错误 查看错误详情 alert(JSON.stringify(res)) 微信商户平台相关设置: ...
- mysql中if函数的正确使用姿势
--为了今天要写的内容,运行了将近7个小时的程序,在数据库中存储了1千万条数据.-- 今天要说的是mysql数据库的IF()函数的一个实例. 具体场景如下, 先看看表结构: CREATE TABLE ...
- PHP比较IP大小
function cmpLoginIP($a, $b) { return bindec(decbin(ip2long($a['loginIp']))) > bindec(decbin(ip2lo ...
- ulua、tolua原理解析
在聊ulua.tolua之前,我们先来看看Unity热更新相关知识. 什么是热更新 举例来说: 游戏上线后,玩家下载第一个版本(70M左右或者更大),在运营的过程中,如果需要更换UI显示,或者修改游戏 ...
- 汇编窥探Swift String的底层
String(字符串),是所有编程语言中非常重要的成员,因此非常值得去深入研究.众所周知,字符串的本质是字符序列,由若干个字符组成.比如字符串 "iOS" 由 'i'.'O'.'S ...
- 死磕 java线程系列之线程池深入解析——定时任务执行流程
(手机横屏看源码更方便) 注:java源码分析部分如无特殊说明均基于 java8 版本. 注:本文基于ScheduledThreadPoolExecutor定时线程池类. 简介 前面我们一起学习了普通 ...
- Spring Boot (日志篇):Log4j2整合ELK,搭建实时日志平台
一.安装JDK1.8以上版本 1.从Oracle官网上下载Linux x64版本的 下载地址: http://www.oracle.com/technetwork/java/javase/downlo ...
- 《Effective Java》 读书笔记(一) 使用静态构造方法代替传统构造函数
对象的创建与销毁 ITEM1 使用静态工厂方法代替构造函数 传统的新建一个对象的方法是通过构造函数: Foo foo =new Foo(); 一个类也可以提供一个静态方法产生一个对象: Boolean ...