Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop

简单总结

主要工作(What)

  1. “蒸馏”(distillation):把大网络的知识压缩成小网络的一种方法
  2. “专用模型”(specialist models):对于一个大网络,可以训练多个专用网络来提升大网络的模型表现

具体做法(How)

  1. 蒸馏:先训练好一个大网络,在最后的softmax层使用合适的温度参数T,最后训练得到的概率称为“软目标”。以这个软目标和真实标签作为目标,去训练一个比较小的网络,训练的时候也使用在大模型中确定的温度参数T
  2. 专用模型:对于一个已经训练好的大网络,可以训练一系列的专用模型,每个专用模型只训练一部分专用的类以及一个“不属于这些专用类的其它类”,比如专用模型1训练的类包括“显示器”,“鼠标”,“键盘”,...,“其它”;专用模型2训练的类包括“玻璃杯”,“保温杯”,“塑料杯”,“其它“。最后以专用模型和大网络的预测输出作为目标,训练一个最终的网络来拟合这个目标。

意义(Why)

  1. 蒸馏把大网络压成小网络,这样就可以先在训练阶段花费大精力训练一个大网络,然后在部署阶段以较小的计算代价来产生一个较小的网络,同时保持一定的网络预测表现。
  2. 对于一个已经训练好的大网络,如果要去做集成的话计算开销是很大的,可以在这个基础上训练一系列专用模型,因为这些模型通常比较小,所以训练会快很多,而且有了这些专用模型的输出可以得到一个软目标,实验证明使用软目标训练可以减小过拟合。最后根据这个大网络和一系列专用模型的输出作为目标,训练一个最终的网络,可以得到不错的表现,而且不需要对大网络做大量的集成计算

Abstract

提高机器学习算法表现的一个简单方法就是,训练不同模型然后对预测结果取平均。
但是要训练多个模型会带来过高的计算复杂度和部署难度。
可以将集成的知识压缩在单一的模型中。
论文使用这种方法在MNIST上做实验,发现取得了不错的效果。
论文还介绍了一种新型的集成,包括一个或多个完整模型和专用模型,能够学习区分完整模型容易混淆的细粒度的类别。

1 Introduction

昆虫有幼虫期和成虫期,幼虫期主要行为是吸收养分,成虫期主要行为是生长繁殖。
类似地,大规模机器学习应用可以分为训练阶段和部署阶段,训练阶段不要求实时操作,允许训练一个复杂缓慢的模型,这个模型可以是分别训练多个模型的集成,也可以是单独的一个很大的带有强正则比如dropout的模型。
一旦模型训练好,可以用不同的训练,这里称为“蒸馏”,去把知识转移到更适合部署的小模型上。

复杂模型学习区分大量的类,通常的训练目标是最大化正确答案的平均log概率,这么做有一个副作用就是训练模型同时也会给所有的错误答案分配概率,即使这个概率很小,而有一些概率会比其它的大很多。错误答案的相对概率体现了复杂模型的泛化能力。举个例子,宝马的图像被错认为垃圾箱的概率很低,但是这被个错认为垃圾桶的概率相比于被错认为胡萝卜的概率来说,是很大的。(可以认为模型不止学到了训练集中的宝马图像特征,还学到了一些别的特征,比如和垃圾桶共有的一些特征,这样就可能捕捉到在新的测试集上的宝马出现这些的特征,这就是泛化能力的体现)

将复杂模型转为小模型需要保留模型的泛化能力,一个方法就是用复杂模型产生的分类概率作为“软目标”来训练小模型。
当软目标的熵值较高时,相对于硬目标,每个训练样本提供更多的信息,训练样本之间会有更小的梯度方差。
所以小模型经常可以被训练在小数据集上,而且可以使用更高的学习率。

像MNIST这种分类任务,复杂模型可以产生很好的表现,大部分信息分布在小概率的软目标中。
为了规避这个问题,Caruana和他的合作者们使用softmax输出前的units值,而不是softmax后的概率,最小化复杂模型和简单模型的units的平方误差来训练小模型。
而更通用的方法,蒸馏法,先提高softmax的温度参数直到模型能产生合适的软目标。然后在训练小模型匹配软目标的时候使用相同的温度T。

被用于训练小模型的转移训练集可以包括未打标签的数据(可以没有原始的实际标签,因为可以通过复杂模型获取一个软目标作为标签),或者使用原始的数据集,使用原始数据集可以得到更好的表现。

2 Distillation

softmax公式: $ q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{ }exp(z_{j}/T)} $
其中温度参数T通常设置为1,T越大可以得到更“软”的概率分布。
T越大,不同激活值的概率差异越小,所有激活值的概率趋于相同;T越小,不同激活值的概率差异越大
在蒸馏训练的时候使用较大的T的原因是,较小的T对于那些远小于平均激活值的单元会给予更少的关注,而这些单元是有用的,使用较高的T能够捕捉这些信息

最简单的蒸馏形式就是,训练小模型的时候,以复杂模型得到的“软目标”为目标,采用复杂模型中的较高的T,训练完之后把T改为1。

当部分或全部转移训练集的正确标签已知时,蒸馏得到的模型会更优。一个方法就是使用正确标签来修改软目标。
但是我们发现一个更好的方法,简单对两个不同的目标函数进行权重平均,第一个目标函数是和复杂模型的软目标做一个交叉熵,使用的复杂模型的温度T;第二个目标函数是和正确标签的交叉熵,温度设置为1。我们发现第二个目标函数被分配一个低权重时通常会取得最好的结果。

3 Preliminary experiments on MNIST

net layers units of each layer activation regularization test errors
single net1 2 1600 relu dropout 67
single net2 2 800 relu no 146

(防止表格黏在一起)

net large net small net temperature test errors
distilled net single net1 single net2 20 74

第一个表格中是两个单独的网络,一个大网络和一个小网络。
第二个表格是使用了蒸馏的方法,先训练大网络,然后根据大网络的“软目标”结果和温度T来训练小网络。
可以看到,通过蒸馏的方法将大网络中的知识压缩到小网络中,取得了不错的效果。

4 Experiments on speech recognition

system Test Frame Accuracy Word Error Rate on dev set
baseline 58.9% 10.9%
10XEnsemble 61.1% 10.7%
Distilled model 60.8% 10.7%

其中basline的配置为

  • 8 层,每层2560个relu单元
  • softmax层的单元数为14000
  • 训练样本大小约为 700M,2000个小时的语音文本数据

10XEnsemble是对baseline训练10次(随机初始化为不同参数)然后取平均

蒸馏模型的配置为

  • 使用的候选温度为{1, 2, 5, 10}, 其中T为2时表现最好
  • hard target 的目标函数给予0.5的相对权重

可以看到,相对于10次集成后的模型表现提升,蒸馏保留了超过80%的效果提升

5 Training ensembles of specialists on very big datasets

训练一个大的集成模型可以利用并行计算来训练,训练完成后把大模型蒸馏成小模型,但是另一个问题就是,训练本身就要花费大量的时间,这一节介绍的是如何学习专用模型集合,集合中的每个模型集中于不同的容易混淆的子类集合,这样可以减小计算需求。专用模型的主要问题是容易集中于区分细粒度特征而导致过拟合,可以使用软目标来防止过拟合。

5.1 JFT数据集

JFT是一个谷歌的内部数据集,有1亿的图像,15000个标签。google用一个深度卷积神经网络,训练了将近6个月。
我们需要更快的方法来提升baseline模型。

5.2 专用模型

将一个复杂模型分为两部分,一部分是一个用于训练所有数据的通用模型,另一部分是很多个专用模型,每个专用模型训练的数据集是一个容易混淆的子类集合。这些专用模型的softmax结合所有不关心的类为一类来使模型更小。

为了减少过拟合,共享学习到的低水平特征,每个专用模型用通用模型的权重进行初始化。另外,专用模型的训练样本一半来自专用子类集合,另一半从剩余训练集中随机抽取。

5.3 将子类分配到专用模型

专用模型的子类分组集中于容易混淆的那些类别,虽然计算出了混淆矩阵来寻找聚类,但是可以使用一种更简单的办法,不需要使用真实标签来构建聚类。对通用模型的预测结果计算协方差,根据协方差把经常一起预测的类作为其中一个专用模型的要预测的类别。几个简单的例子如下。

JFT 1: Tea party; Easter; Bridal shower; Baby shower; Easter Bunny; ...
JFT 2: Bridge; Cable-stayed bridge; Suspension bridge; Viaduct; Chimney; ...
JFT 3: Toyota Corolla E100; Opel Signum; Opel Astra; Mazda Familia; ...

5.4 实验表现

system Conditional Test Accuracy Test Accuracy
baseline 43.1% 25.0%
61 specialist models 45.9% 26.1%

6 Soft Targets as Regularizers

对于前面提到过的,对于大量数据训练好的语音baseline模型,用更少的数据去拟合这个模型的时候,使用软目标可以达到更好的效果,减小过拟合。实验结果如下。

system & training set Train Frame Accuracy Test Frame Accuracy
baseline(100% training set) 63.4% 58.9%
baseline(3% training set) 67.3% 44.5%
soft targets(3% training set) 65.4% 57.0%

论文笔记:蒸馏网络(Distilling the Knowledge in Neural Network)的更多相关文章

  1. 论文笔记:(CVPR2019)Relation-Shape Convolutional Neural Network for Point Cloud Analysis

    目录 摘要 一.引言 二.相关工作 基于视图和体素的方法 点云上的深度学习 相关性学习 三.形状意识表示学习 3.1关系-形状卷积 建模 经典CNN的局限性 变换:从关系中学习 通道提升映射 3.2性 ...

  2. 论文笔记《ImageNet Classification with Deep Convolutional Neural Network》

    一.摘要 了解CNN必读的一篇论文,有些东西还是可以了解的. 二.结构 1. Relu的好处: 1.在训练时间上,比tanh和sigmod快,而且BP的时候求导也很容易 2.因为是非饱和函数,所以基本 ...

  3. 论文笔记之:Hybrid computing using a neural network with dynamic external memory

    Hybrid computing using a neural network with dynamic external memory Nature  2016 原文链接:http://www.na ...

  4. 【论文阅读】Sequence to Sequence Learning with Neural Network

    Sequence to Sequence Learning with NN <基于神经网络的序列到序列学习>原文google scholar下载. @author: Ilya Sutske ...

  5. 论文翻译:2020_WaveCRN: An efficient convolutional recurrent neural network for end-to-end speech enhancement

    论文地址:用于端到端语音增强的卷积递归神经网络 论文代码:https://github.com/aleXiehta/WaveCRN 引用格式:Hsieh T A, Wang H M, Lu X, et ...

  6. 论文翻译:2020_FLGCNN: A novel fully convolutional neural network for end-to-end monaural speech enhancement with utterance-based objective functions

    论文地址:FLGCNN:一种新颖的全卷积神经网络,用于基于话语的目标函数的端到端单耳语音增强 论文代码:https://github.com/LXP-Never/FLGCCRN(非官方复现) 引用格式 ...

  7. 论文翻译:2022_PACDNN: A phase-aware composite deep neural network for speech enhancement

    论文地址:PACDNN:一种用于语音增强的相位感知复合深度神经网络 引用格式:Hasannezhad M,Yu H,Zhu W P,et al. PACDNN: A phase-aware compo ...

  8. 论文阅读笔记十八:ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation(CVPR2016)

    论文源址:https://arxiv.org/abs/1606.02147 tensorflow github: https://github.com/kwotsin/TensorFlow-ENet ...

  9. 论文笔记——Channel Pruning for Accelerating Very Deep Neural Networks

    论文地址:https://arxiv.org/abs/1707.06168 代码地址:https://github.com/yihui-he/channel-pruning 采用方法 这篇文章主要讲诉 ...

随机推荐

  1. 探索PowerShell----函数

    http://marui.blog.51cto.com/1034148/294775/

  2. (转)有关Queue队列

    Queue Queue是python标准库中的线程安全的队列(FIFO)实现,提供了一个适用于多线程编程的先进先出的数据结构,即队列,用来在生产者和消费者线程之间的信息传递 基本FIFO队列 clas ...

  3. Android popupwindow 演示样例程序一

    经过多番測试实践,实现了popupwindow 弹出在指定控件的下方.代码上有凝视.有须要注意的地方.popupwindow 有自已的布局,里面控件的监听实现都有.接下来看代码实现. 项目资源下载:点 ...

  4. Jquery-easyUi------(布局)

    <%@ Master Language="C#" Inherits="System.Web.Mvc.ViewMasterPage" %> <! ...

  5. MCU相关知识

    一个处理器达到 200 DMIPS的性能,这是个什么概念? DMIPS全称叫Dhrystone MIPS 这项测试是用来计算同一秒内系统的处理能力,它的单位以百万来计算,也就是(MIPS) 上面的意思 ...

  6. NSUserDefaults设置bool值重启后bool只设置丢失问题

    本文转载至 http://blog.csdn.net/cerastes/article/details/38036875   NSUserDefaultsbool同步synchronize无效 今天使 ...

  7. python的其他安全隐患

    零.绪论 python这里以python2.7为研究对象,对应的我们会简要说明一下python3,其他指与反序列化无关的安全隐患问题. 一.标准输入输出: 1.首先,我们来看下标准输入输出 impor ...

  8. 170411、java Socket通信的简单例子(UDP)

    服务端代码: package com.bobohe.socket; import java.io.*; import java.net.*; class UDPServer { public stat ...

  9. 160428、JavaScript知识总结—cookie及其应用

    一.cookie基本介绍 cookie是document的对象.cookie可以使得JavaScript代码能够在用户的硬盘上持久地存储数据,并且能够获得以这种方式存储的数据.cookie还可以用于客 ...

  10. Servlet------>jsp输出JavaBean

    JavaBean是遵循特殊写法的java类 它通常具有如下特点: 1.这个java类必须具有一个无参的构造函数 2.属性必须私有化 3.私有化必须通过public类暴露给其他程序,而且方法的命名必须遵 ...