Paper | 多任务学习的鼻祖
论文:Multitask learning
Caruana, Rich. "Multitask learning." Machine learning 28.1 (1997): 41-75.
Over 3600 citations (2019).
1. MTL的定义
Multi-task learning (MTL) is a subfield of machine learning in which multiple learning tasks are solved at the same time, while exploiting commonalities and differences across tasks.
Notice, MTL is a collection of ideas, techniques, and algorithms, not one algorithm.
E.g., k-nearest neighbor and decision trees.
2. MTL的机制
这部分我们主要设计一些简单的实验,来说明MTL的机制,并且纠正一些错误的观点。
这部分主要参考:Multitask Learning (R.CARUANA, 1997), 3531 citations.
2.1. Representation Bias
这一部分告诉我们:如果任务之间是无关的,那么多任务会起反作用。
假设任务T有两个极小值点A和B,任务T'有两个极小值点A和C。也就是说,这两个任务共享一个极小值点A。
经过精心设计,任务T落入A和B的概率相同,任务T'同理。
In the first, we selected the minima so that nets trained on T alone are equally likely to find A or B, and nets trained on
T' alone are equally likely to find A or C.
然后,作者让两个任务在同一个网络中进行训练。结果发现,这些网络通常会落入共同极小值点A。这说明:在MTL任务中,如果某些任务偏好一种对共享层的表示,那么其他任务也会对其产生偏好。
Nets trained on both T and T 0 usually fall into A for both tasks. This shows that MTL tasks prefer hidden layer representations that other tasks prefer.
当然了,前提是这种表示对所有任务都是局部最优的。
其次,作者改变了任务T的极小值点:T非常容易落入B,而不太容易落入A。T'的极小值点不变,落入A和C的概率相同。
此时再同时训练,作者发现:T'很难影响T的偏好,T仍然总是落入B。同时,T'落入C的概率大大增加了!
这表明:在T的影响下,共享层更偏好于远离极小值点B,这影响到了其他共享这一共享层的任务。
T creates a “tide” in the hidden layer representation towards B that flows away from A. T' has no preference for A or C, but is subject to the tide created by T.
这一部分告诉我们:如果两个任务鲜有关联,那么放在同一个网络训练的后果就是:两个任务都会偏离它们的共同极值点(局部,弱),因此共享层是没有意义的。
个人认为,这样反倒有可能限制二者的学习能力,因为两个任务被迫共享没有意义的隐藏层。
2.2. Uncorrelated Tasks May Help?
Holmstrom等人在1992年发现,在BP过程中添加噪声,有时可以提高模型的泛化能力。
有人就想:如果MTL网络中任务之间是无关的,那么它们和共享层之间的BP过程,彼此互为噪声(没有一致性的BP)。既然添加噪声可以提高泛化能力,那么uncorrelated tasks是否能互相促进呢?
作者设计了一个实验。对已有的MTL任务集和一个MTL网络,在训练时,extra tasks的training signals是随机提供的。
此时,如果MTL受益于“噪声”,那么该MTL网络将胜过STL网络;如果MTL受益于任务关联性,那么MTL网络的表现将不如STL网络。
实验证明,MTL要比STL更差。因此MTL一定是受益于相关任务的,这一点无需再质疑。
3. MTL的用途
3.1. Using the Future to Predict the Present
假设我们要设计一个肺炎诊断系统。该系统的输入是病人的体检数据(身高,体重,白细胞指标等),输出即是否得肺炎。
参加体检的有没患肺炎的:做了简单的检查就离开了;也有患肺炎入院的:入院后还进行了进一步的体检。
这两部分体检的数据,我们都有。前者就是所谓的"present",后者就是所谓的"future"。
前者是直接可用的,但后者是不能直接作为训练数据的!因为后者全部来源于肺炎病人,直接作为输入训练很有可能导致泛化能力的严重退化。
那么后者就要被浪费了吗?此时,MTL就提供了一个全新的方法:将这部分数据作为extra任务的输出标签。此时,该任务就可以帮助主任务更好地构建hidden layer representation。
实验表明,新构建的MTL任务,比STL任务的错误率要低5%以上。(参见Multitask learning P9)
更常见的情形是:来自"future"的数据没法在训练时实时获取,但可以在运行过程中获取而用于进一步调整模型。
比如,在自动驾驶任务中,交通标线lane markings是很重要的信息。虚实,单双,黄白,都有各自的意义。
但对于一辆正在行驶的车辆而言,它只能获取一定范围内的lane marking,远处的只能在事后获取。
假设我们将预测lane markings作为extra task,显然该任务对我们的主线任务有很大的帮助。我们可以将实时获取的lane markings作为之前预测结果的label,迭代训练我们的MTL模型。
3.2. Time Series Prediction
实现时序预测问题最简单的方法是:将每一个时刻对应一个输出。
这实际上就是我们的RNN。
3.3. Using Extra Tasks to Focus Attention
刚刚我们提到,lane markings对自动驾驶任务很重要。但在实际应用中,lane markings往往只占输入图片中的一小部分,并且在快速变化甚至消失,很容易被忽略。但它又很重要。
为此,我们将预测标线作为extra task,强迫网络建立起与标线相关的support,这样主线任务能从中受益。
3.4. Quantization Smoothing
我们的训练数据中,几乎都是经过量化得到的。如果我们能补偿量化带来的精度损失,那么模型的预测精度将会提高。
- 方案1:加入量化更细致的支线任务。
- 方案2:加入量化方式与主线任务不同的支线任务。
回到我们的肺炎诊断模型。肺炎诊断是一个复杂的问题,但我们将其简单地归为二分类问题,实际上是非常不合理的。
为了增强模型的预测平滑性,我们可以增加一个支线任务:预测留院时间。
该支线任务非常复杂。高危病人既有可能长时间留园,也有可能很快死亡。因此,该任务强迫我们的网络学习更加复杂的映射。
3.5. Some Inputs Work Better as Outputs
刚刚提到,有些特征无法事先得到,或是容易被忽略,因此我们将其作为输出。还有一些特征,我们将其作为输出,比作为输入更好。
参见Multitask learning P21。
这是一个人为设计的问题,此外还有一些更自然的应用。比如,当特征中存在噪声时,作为输出通常比作为输入更好。我们所谓的Dropout方法,实际上就等价于在输出中增加噪声。
3.6. 其他MTL方法
- KNN:KNN的关键在于衡量样本之间的距离。
在加入其他任务一起衡量距离后,表现会更好。参见Multitask learning P23。 - Decision tree:在自上而下的推导过程中,主线任务和支线任务一起决定是否分叉。如果这些任务是相关的,那么决策会更合理。
具体是计算information gain。可以引入参数lambda来控制支线任务的权重。
4. 讨论
4.1. Predictions for Multiple Tasks
MTL trains many tasks in parallel on one learner, but this does not mean one learned model should be used to make predictions for many tasks.
The reason for training multiple tasks on one learner is so one task can benefit from the information contained in the training signals of other tasks, not to reduce the number of models that must be learned.
比如我们的KNN和Decision tree,支线任务可以有权重。
这里举一个例子。NETtalk是MTL在1989年的一个代表性工作,可以根据输入的句子,同时学习phonemes和stresses,最终输出句子的英文读音。研究发现,当phoneme网络最优时,stress网络已经严重过拟合。
此时,我们应该给两个网络分配不同的学习率,或进行snapshot,不同时刻的网络给不同任务使用。
4.2. Architecture
Regularization methods such as weight decay can be used with MTL. By reducing the effective number of free parameters in the model, regularization promotes sharing.
Too strong a bias for sharing, however, can hurt performance.
MTL performance often drops if the size of the shared hidden layer is much smaller than the sum of the sizes of the STL hidden layers that would provide good performance on the tasks when trained separately.
4.3. What Are Related Tasks?
No clear definition. 事实上,CVPR 2018的best paper还在讨论这一问题。
要注意的是,如果两个任务同时训练时互相促进,不能证明它们相关。
例如,如果在BP网络中增加噪声,那么网络的泛化能力会增强,因为这相当于给hidden layer增加了正则项。但显然噪声任务与主线任务无关。
4.4. Transfer Learning
Sequential transfer learning differ from MTL, where the goal is to learn a better model for one task by learning all available extra tasks in parallel. MTL is a kind of parallel transfer.
5. 番外:生物学启发
人工神经网络在20世纪80年代末和90年代初达到巅峰,随后迅速衰落,其中一个重要原因是深度神经网络的发展严重受挫。
人们发现,如果网络的层数加深,那么最终网络的输出结果对于初始几层的参数影响微乎其微,整个网络的训练过程无法保证收敛。
同时,人们发现大脑具有不同的功能区域,每个区域专门负责同一类的任务,例如视觉图像识别, 语音信号处理和文字处理等等。在这一阶段,计算机科学家为不同的任务发展出不同的算法。例如,为了语音识别,人们发展了隐马尔科夫链模型;为了人脸识别,发展了Gaber滤波器,SIFT滤波,马尔科夫随机场的图模型。因此,在这个阶段,人们倾向于发展专用(task-specific)算法。
但是在2000年后,一系列生物学突破打破了人类对大脑的认识。
- 2000年,Jitendra Sharma等人发现小鼠的视觉、听觉神经系统是通用的。Sharma把幼年鼬鼠的视觉神经和听觉神经剪断,交换后接合,眼睛接到了听觉中枢,耳朵接到了视觉中枢。鼬鼠长大后,依然发展出了视觉和听觉。
- 2009年,Vuillerme等人让盲人用舌头掌握了“视觉”。他们将摄像机的输出表示成二维微电极矩阵,放在舌头表面。盲人经过一段时间的学习训练,可以用舌头“看到”障碍物。
- 2011年,Thaler等研究发现,盲人的视觉中枢经过训练,可以通过回声来探测并规避大的障碍物。
种种研究表明,大脑实际上是一台“万用学习机器”(Universal learning Machine)。那么,神经网络是否也应如此呢?
Paper | 多任务学习的鼻祖的更多相关文章
- [DeeplearningAI笔记]ML strategy_2_3迁移学习/多任务学习
机器学习策略-多任务学习 Learninig from multiple tasks 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.7 迁移学习 Transfer Learninig 神 ...
- caffe实现多任务学习
Github: https://github.com/Haiyang21/Caffe_MultiLabel_Classification Blogs 1. 采用多label的lmdb+Slice L ...
- DLNg[结构化ML项目]第二周迁移学习+多任务学习
1.迁移学习 比如要训练一个放射科图片识别系统,但是图片非常少,那么可以先在有大量其他图片的训练集上进行训练,比如猫狗植物等的图片,这样训练好模型之后就可以转移到放射科图片上,模型已经从其他图片中学习 ...
- 【论文笔记】多任务学习(Multi-Task Learning)
1. 前言 多任务学习(Multi-task learning)是和单任务学习(single-task learning)相对的一种机器学习方法.在机器学习领域,标准的算法理论是一次学习一个任务,也就 ...
- 深度神经网络多任务学习(Multi-Task Learning in Deep Neural Networks)
https://cloud.tencent.com/developer/article/1118159 http://ruder.io/multi-task/ https://arxiv.org/ab ...
- keras函数式编程(多任务学习,共享网络层)
https://keras.io/zh/ https://keras.io/zh/getting-started/functional-api-guide/ https://github.com/ke ...
- 多任务学习Multi-task-learning MTL
https://blog.csdn.net/chanbo8205/article/details/84170813 多任务学习(Multitask learning)是迁移学习算法的一种,迁移学习可理 ...
- [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)
译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...
- ubuntu之路——day11.6 多任务学习
在迁移学习transfer learning中,你的步骤是串行的sequential process 在多任务学习multi-task learning中,你试图让单个神经网络同时做几件事情,然后这里 ...
随机推荐
- webpack 4.0配置2
上个博客记录了webpack 的基本配置今天主要是css-loader的介绍,包括单独提出css,压缩css.js文件 这里使用的插件npm 地址:https://www.npmjs.com/pack ...
- mui-H5获取当前手机通讯录
mui.plusReady(function() { // 扩展API加载完毕,现在可以正常调用扩展API plus.contacts.getAddressBook(plus.contacts.ADD ...
- python 练习题(1-15)
1.给定一个整数数组和一个目标值,找出数组中和为目标值的两个数. 2.生成双色球 3.逻辑运算(运算符优先级) 4.输入一个整数,判断这个数是几位数 5.用while循环计算 1-2+3-4...-9 ...
- 使用MagickNet编辑图片
ImageMagick是一个免费的创建.编辑.合成图片的软件.它可以读取.转换.写入多种格式的图片.图片切割.颜色替换.各种效果的应用,图片的旋转.组合,文本,直线,多边形,椭圆,曲线 ...
- Java框架spring 学习笔记(五):Bean定义继承
子 bean 的定义继承父定义的配置数据.子定义可以根据需要重写一些值,或者添加其他值. 编写HelloWorld.java package com.example.spring; public cl ...
- 聚宽获取财务数据+DataFrame写入txt
from jqdata import jy from jqdata import * #获取股票列表,这里是板块内股票 pool=get_industry_stocks(',date='2016-09 ...
- 域名系统DNS以及跨域问题
域名到Ip地址解析是由分布在因特网上的许多域名服务器程序共同完成的.运行域名服务器程序的机器是域名服务器 域名到ip地址的解析过程: 当一个应用进程需要把主机名解析为ip地址时,该应用就调用解析程 ...
- mysql学习笔记--数据库多表查询
一.内连接[inner join] 1. 语法一:select 列名 from 表1 inner join 表2 on 表1.公共字段=表2.公共字段 2. 语法二:select 列名 from 表1 ...
- 数据结构python编程总结
大数据.空间限制 布隆过滤器 使用很少的空间就可以将准确率做到很高的程度(网页黑名单系统.垃圾邮件过滤系统.爬虫的网址判重系统等) 有一定的失误率 单个样本的大小不影响布隆过滤器的大小 n个输入.k个 ...
- 获取网页title(还有一坑未填)
def getTitle(self,url): #get title title = 'time out' try: self.res = requests.get(url,timeout=5) so ...