论文地址: https://arxiv.org/abs/1706.00384

论文简介

该论文探讨了一种与模型蒸馏(model distillation)相关却不同的模型---即相互学习(mutual learning)。 蒸馏从一个强大的大型预训练教师网络开始,并向未经训练的小型学生网络进行单向知识转移。 相反,在相互学习中,我们从一群未经训练的学生网络开始,他们同时学习一起解决任务。 具体来说,每个学生网络都有两个的损失函数:一种传统的监督性损失函数,以及一种模仿性的损失函数(mimicry loss),使每个学生的后验概率分布与其他学生的类别概率保持一致。

通过这种方式进行训练,结果表明,在这种基于同伴教学(peer-teaching)的情景中,每个学生的学习都比在传统的监督学习方案中单独学习要好得多。 此外,以这种方式训练的学生网络比来自更大的预训练教师的传统蒸馏训练的学生网络获得更好的结果。论文实验表明,各种网络架构可以从相互学习中受益,并在CIFAR-100识别和Market-1501 ReID 数据集上获得令人信服的结果。

模型设计

损失函数

损失函数 函数设计
传统监督损失函数 softmax + 交叉熵损失函数
模仿性的损失函数 softmax + KL散度函数

[两个网络]

[多个网络]

实验结果

实验网络

Cifar-100 实验

[实验设置]

paras values
optimizer SGD with Nesterov
base lr 0.1
momentum 0.9
batch size 64
epoch 200
note The learning rate dropped by 0.1 every 60 epochs

[实验结果]

[实验结论]

  1. 与独立学习相比, ResNet-32,MobileNet 和 WRN-28-10各种不同网络都在DML提高了性能。
  2. 容量较小的网络(ResNet-32和MobileNet)通常可以从DML获取更大的提升。
  3. 虽然WRN-28-10是一个比MobileNet或ResNet-32大得多的网络,但它仍然受益于与较小的网络一起训练。
  4. 与独立学习相比,使用DML训练一组大型网络(WRN-28-10)仍然是有益的。 因此,与模式蒸馏的传统智慧相反,我们看到一个大型的预训练的教师网络不是必不可少的.

Market-1501 实验结果

[实验设置]

每个MobileNet相互学习DML都以双网络方式中进行训练,并报告两个网络的平均性能

paras values
optimizer Adam
β1/β2 0.5/0.999
base lr 0.0002
momentum 0.9
batch size 16
iterations 100,000

[实验结果]

[实验结论]

可以看到,与独立学习相比, 无论是否在ImageNet上进行预训练,DML极大地提高了MobileNet的性能。还可以看出,用两个MobileNets训练的所提出的DML方法的性能显着优于先前的最先进的深度学习方法。

与模型蒸馏比较

[实验结果]

[实验结论]
Table4 将DML与模型蒸馏进行了比较,其中教师网络(Net1)经过预先训练,并为学生网络(Net2)提供固定的后验概率目标。

  1. 正如预期的那样,与独立学习相比,来自强大的预训练教师网络的常规蒸馏方法确实提高了学生网络的表现(Net1蒸馏Net2)
  2. 与蒸馏相比,DML将两个网络一起训练, 两个网络都得到了改进。这意味着在相互学习的过程中,通过与先验未经训练的学生的互动,教师角色的网络实际上变得比预先训练的教师更好。

相互学习的网络与性能联系

[实验设置]

之前实验研究以2名学生队列为例。在这个实验中,论文研究了DML如何与队列中的更多学生进行互动。图2(a)显示了Market-1501上的结果,其中DML训练增加了MobileNets的群组大小。图中显示了平均mAP以及标准偏差。

[实验结果]

[实验结论]

从图2(a)可以看出,平均单一网络的mAP性能随着DML队列中的网络数量的增加而增加。这表明,与越来越多的同龄人一起学习时,学生的泛化能力得到提高。从标准偏差中我们也可以看到,随着DML网络数量的增加,结果越来越稳定。

训练多个网络时的一种常用技术是将它们组合为一个整体并进行组合预测。在图2(b)中,我们使用与图2(a)相同的模型,但基于整体进行预测, 在整体上(基于所有成员的连锁特征进行匹配)而不是报告每个人的平均预测。从结果我们可以看出,集合预测优于预期的各个网络预测(图2 (b)vs(a).

Deep Mutual Learning的更多相关文章

  1. 【论文阅读】Deep Mutual Learning

    文章:Deep Mutual Learning 出自CVPR2017(18年最佳学生论文) 文章链接:https://arxiv.org/abs/1706.00384 代码链接:https://git ...

  2. Paper | Deep Mutual Learning

    目录 1. 动机详述和方法简介 2. 相关工作 3. 方法 3.1 Formulation 3.2 实现 3.3 弱监督学习 4. 实验 4.1 基本实验 4.2 深入实验 [算法和公式很simple ...

  3. 论文笔记: Mutual Learning to Adapt for Joint Human Parsing and Pose Estimation

    Mutual Learning to Adapt for Joint Human Parsing and Pose Estimation 2018-11-03 09:58:58 Paper: http ...

  4. Deep Residual Learning

    最近在做一个分类的任务,输入为3通道车型图片,输出要求将这些图片对车型进行分类,最后分类类别总共是30个. 开始是试用了实验室师姐的方法采用了VGGNet的模型对车型进行分类,据之前得实验结果是训练后 ...

  5. (转) Playing FPS games with deep reinforcement learning

    Playing FPS games with deep reinforcement learning 博文转自:https://blog.acolyer.org/2016/11/23/playing- ...

  6. (zhuan) Deep Reinforcement Learning Papers

    Deep Reinforcement Learning Papers A list of recent papers regarding deep reinforcement learning. Th ...

  7. Learning Roadmap of Deep Reinforcement Learning

    1. 知乎上关于DQN入门的系列文章 1.1 DQN 从入门到放弃 DQN 从入门到放弃1 DQN与增强学习 DQN 从入门到放弃2 增强学习与MDP DQN 从入门到放弃3 价值函数与Bellman ...

  8. Open source packages on Deep Reinforcement Learning

    智能车 self driving car + 强化学习 reinforcement learning + 神经网络 模拟 https://github.com/MorvanZhou/my_resear ...

  9. (转) Deep Reinforcement Learning: Playing a Racing Game

    Byte Tank Posts Archive Deep Reinforcement Learning: Playing a Racing Game OCT 6TH, 2016 Agent playi ...

随机推荐

  1. C++购书系统

    C++购书系统——来自班里某位同学的小学期作业 这是一个购书系统,模拟网上购书的流程.用户可以在这个小程序里输入对应的数字进行浏览书籍信息,查看用户信息,查找书籍,购买书籍以及查询个人订单的操作. 以 ...

  2. 微信内无法自动跳转外部浏览器打开H5分享链接的解决办法

    很多情况下我们用微信分享转发H5链接的时候,都无法在微信内打开,即使开始能打开,过一段时间就会被拦截,拦截后再打开微信会提示 “已停止访问该网址” ,那么导致这个情况的因素有哪些呢,主要有以下四点 1 ...

  3. Map the Debris 轨道周期

    返回一个数组,其内容是把原数组中对应元素的平均海拔转换成其对应的轨道周期. 原数组中会包含格式化的对象内容,像这样 {name: 'name', avgAlt: avgAlt}. 至于轨道周期怎么求, ...

  4. Url校验正则

    最近需要对HTTP请求合法性做一些校验,在网上查找了一些关于URL合法性的正则表达式. 在github上的有个关于weburl匹配的gist: https://gist.github.com/dper ...

  5. linux环境下在springboot项目中获取项目路径(用于保存文件等)

    //application.properties中设置:(file.path=static/qrfile/)//保存到static文件夹下的qrfile目录@Value("${file.pa ...

  6. Salesforce Bulk API 基于.Net平台下的实施

    在最近的salesforce实施项目中应用到Bulk API来做数据接口.顺便把实际应用的例子写下来.希望对做salesforce接口的朋友有借鉴作用. 一 参考网络牛人写好的Demo. 下载地址:h ...

  7. oracle篇 之 组函数

    一,常见组函数 1 . avg:求平均值,操作数值类型 2.sum:求和,操作数值类型 3.min:求最小值,操作任意类型 4.max:求最大值,操作任意类型 select avg(salary),s ...

  8. laravel带参数分页

    <!---分页--> <div id="pagination-box"> {{ $list->appends(['mobile'=>$mobil ...

  9. python学习day16 模块(汇总)

    模块(总) 对于range py2,与py3的区别: py2:range() 在内存中立即把所有的值都创建,xrange() 不会再内存中立即创建,而是在循环时边环边创建. py3:range() 不 ...

  10. 微信小程序与webview交互实现支付

    实现原理:点击h5网页的支付按钮——(跳转)——>嵌套改h5的小程序的支付页面——(处理支付)——>跳转至支付完成后的页面 注意:(1)网页h5中,引入微信的jssdk <scrip ...