迁移学习(Transfer Learning)

如果你要做一个计算机视觉的应用,相比于从头训练权重,或者说从随机初始化权重开始,如果你下载别人已经训练好网络结构的权重,你通常能够进展的相当快,用这个作为预训练,然后转换到你感兴趣的任务上。

计算机视觉的研究社区非常喜欢把许多数据集上传到网上,如果你听说过,比如ImageNet,或者MS_COCO,或者Pascal类型的数据集,这些都是不同数据集的名字,它们都是由大家上传到网络的,并且有大量的计算机视觉研究者已经用这些数据集训练过他们的算法了。

有时候这些训练过程需要花费好几周,并且需要很多的GPU,其它人已经做过了,并且经历了非常痛苦的寻最优过程,这就意味着你可以下载花费了别人好几周甚至几个月而做出来的开源的权重参数,把它当作一个很好的初始化用在你自己的神经网络上。用迁移学习把公共的数据集的知识迁移到你自己的问题上,让我们看一下怎么做。

来个栗子

举个例子,假如说你要建立一个猫咪检测器,用来检测你自己的宠物猫。

比如网络上的Tigger,是一个常见的猫的名字,Misty也是比较常见的猫名字。

假如你的两只猫叫Tigger和Misty,还有一种情况是,两者都不是。所以你现在有一个三分类问题,图片里是Tigger还是Misty,或者都不是,我们忽略两只猫同时出现在一张图片里的情况。现在你可能没有Tigger或者Misty的大量的图片,所以你的训练集会很小,你该怎么办呢?

 

我建议你从网上下载一些神经网络开源的实现,不仅把代码下载下来,也把权重下载下来

有许多训练好的网络,你都可以下载。举个例子,ImageNet数据集,它有1000个不同的类别,因此这个网络会有一个Softmax单元,它可以输出1000个可能类别之一。

 
 

你可以去掉上图中的这个Softmax层,创建你自己的Softmax单元,用来输出Tigger、Misty和neither三个类别

就网络而言,我建议你把所有的层看作是冻结的,你冻结网络中所有层的参数,你只需要训练和你的Softmax层有关的参数。这个Softmax层有三种可能的输出,Tigger、Misty或者都不是。

通过使用其他人预训练的权重,你很可能得到很好的性能,即使只有一个小的数据集。幸运的是,大多数深度学习框架都支持这种操作,事实上,取决于用的框架,它也许会有trainableParameter=0这样的参数,对于这些前面的层,你可能会设置这个参数。

为了不训练这些权重,有时也会有freeze=1这样的参数。不同的深度学习编程框架有不同的方式,允许你指定是否训练特定层的权重。在这个例子中,你只需要训练softmax层的权重,把前面这些层的权重都冻结。

 

另一个技巧,也许对一些情况有用,由于前面的层都冻结了,相当于一个固定的函数,不需要改变。因为你不需要改变它,也不训练它,取输入图像X,然后把它映射到这层(softmax的前一层)的激活函数

所以这个能加速训练的技巧就是,如果我们先计算这一层(紫色箭头标记),计算特征或者激活值,然后把它们存到硬盘里。你所做的就是用这个固定的函数,在这个神经网络的前半部分(softmax层之前的所有层视为一个固定映射),取任意输入图像X,然后计算它的某个特征向量,这样你训练的就是一个很浅的softmax模型,用这个特征向量来做预测。

对你的计算有用的一步就是对你的训练集中所有样本的这一层的激活值进行预计算,然后存储到硬盘里,然后在此之上训练softmax分类器。所以,存储到硬盘或者说预计算方法的优点就是,你不需要每次遍历训练集再重新计算这个激活值了。

因此如果你的任务只有一个很小的数据集,你可以这样做。

要有一个更大的训练集怎么办呢?

根据经验,如果你有一个更大的标定的数据集,也许你有大量的Tigger和Misty的照片,还有两者都不是的,这种情况,你应该冻结更少的层,比如只把上图中前面括起来的这些层冻结,然后训练后面的层。如果你的输出层的类别不同,那么你需要构建自己的输出单元,Tigger、Misty或者两者都不是三个类别。有很多方式可以实现,你可以取后面几层的权重,用作初始化,然后从这里开始梯度下降。

或者你可以直接去掉这几层,换成你自己的隐藏单元和你自己的softmax输出层,这些方法值得一试。但是有一个规律,如果你有越来越多的数据,你需要冻结的层数越少,你能够训练的层数就越多。这个理念就是,如果你有一个更大的数据集,也许有足够多的数据,那么不要单单训练一个softmax单元,而是考虑训练中等大小的网络,包含你最终要用的网络的后面几层。

最后,如果你有大量数据,你应该做的就是用开源的网络和它的权重,把这、所有的权重当作初始化,然后训练整个网络。再次注意,如果这是一个1000节点的softmax,而你只有三个输出,你需要你自己的softmax输出层来输出你要的标签。

如果你有越多的标定的数据,或者越多的Tigger、Misty或者两者都不是的图片,你可以训练越多的层。极端情况下,你可以用下载的权重只作为初始化,用它们来代替随机初始化,接着你可以用梯度下降训练,更新网络所有层的所有权重。

这就是卷积网络训练中的迁移学习,事实上,网上的公开数据集非常庞大,并且你下载的其他人已经训练好几周的权重,已经从数据中学习了很多了,你会发现,对于很多计算机视觉的应用,如果你下载其他人的开源的权重,并用作你问题的初始化,你会做的更好。

在所有不同学科中,在所有深度学习不同的应用中,我认为计算机视觉是一个你经常用到迁移学习的领域,除非你有非常非常大的数据集,你可以从头开始训练所有的东西。总之,迁移学习是非常值得你考虑的,除非你有一个极其大的数据集和非常大的计算量预算来从头训练你的网络。

【47】迁移学习(Transfer Learning)的更多相关文章

  1. 【转载】 迁移学习(Transfer learning),多任务学习(Multitask learning)和端到端学习(End-to-end deep learning)

    --------------------- 作者:bestrivern 来源:CSDN 原文:https://blog.csdn.net/bestrivern/article/details/8700 ...

  2. 迁移学习-Transfer Learning

    迁移学习两种类型: ConvNet as fixed feature extractor:利用在大数据集(如ImageNet)上预训练过的ConvNet(如AlexNet,VGGNet),移除最后几层 ...

  3. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  4. pytorch例子学习——TRANSFER LEARNING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 以下是两种主要的迁移学习场景 微调convnet : ...

  5. 图像识别 | AI在医学上的应用 | 深度学习 | 迁移学习

    参考:登上<Cell>封面的AI医疗影像诊断系统:机器之心专访UCSD张康教授 Identifying Medical Diagnoses and Treatable Diseases b ...

  6. AI小白必读:深度学习、迁移学习、强化学习别再傻傻分不清

    摘要:诸多关于人工智能的流行词汇萦绕在我们耳边,比如深度学习 (Deep Learning).强化学习 (Reinforcement Learning).迁移学习 (Transfer Learning ...

  7. 【迁移学习】2010-A Survey on Transfer Learning

    资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...

  8. [DeeplearningAI笔记]ML strategy_2_3迁移学习/多任务学习

    机器学习策略-多任务学习 Learninig from multiple tasks 觉得有用的话,欢迎一起讨论相互学习~Follow Me 2.7 迁移学习 Transfer Learninig 神 ...

  9. 迁移学习(Transformer),面试看这些就够了!(附代码)

    1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...

随机推荐

  1. 暑假第六周总结(对HBASE进行编程实践并且安装Redis)

    本周主要是根据教程对HBASE进行了编程实践,对于hadoop的编程来说需要用到很多的.jar 包,在进行编程实践的时候需要参照相关的教程将jar包添加至程序当中去.教程上给的代码还是比较详细的,加上 ...

  2. uml图六种箭头的含义

    转:https://blog.csdn.net/wglla/article/details/52225571 在看一些技术博客的时候,经常会见到博客里画上很多uml图.因为经常会被这几种表达关系的箭头 ...

  3. 新的起航从这里开始 Encantado!

    大家好,我是一名DBA之前也在其它地方写过blog,但是可惜目前在greatwall之内都不能访问了. 如果有小伙伴可以在墙外访问的话 可以尝试着看看这个地址 https://liuleiit.wix ...

  4. 第3章 JDK并发包(三)

    3.2 线程复用:线程池 一种最为简单的线程创建和回收的方法类似如下代码: new Thread(new Runnable() { @Override public void run() { // d ...

  5. Source Code Structure - Python 源码目录结构

    Source Code Structure - Python 源码目录结构 Include 目录包含了 Python 提供的所有头文件, 如果用户需要用 C 或 C++ 编写自定义模块扩展 Pytho ...

  6. 如何用Python实现do...while语句

    我在编程的时候可能会遇到如下代码: a = 0 while a != 0: a = input() print a 我所设想的运行过程是这样的: 很显然我是想先运行后判断的模式,即 do...whil ...

  7. k8s系列---pod介绍

    # yaml格式的pod定义文件完整内容: apiVersion: v1 #必选,版本号,例如v1 kind: Pod #必选,Pod metadata: #必选,元数据 name: string # ...

  8. zabbix-mysql迁移分离

    io过高,迁移mysql 停掉zabbix 导出数据库的zabbix库 导入到新机器,并启动mysql 1:修改zabbix_server.conf文件里DB相关的地址,用户名和密码. vim /et ...

  9. kubeadm安装Kubernetes 1.14最佳实践

    前言 Kubernetes作为容器编排工具,简化容器管理,提升工作效率而颇受青睐.很多新手部署Kubernetes由于“上网”问题举步维艰,本文以实战经验详解kubeadm不用“翻墙”部署Kubern ...

  10. redis系列-14点的灵异事件

    概述 项目组每天14点都会遭遇惊魂时刻.一条条告警短信把工程师从午后小憩中拉回现实.之后问题又神秘消失.是PM喊你上工了?还是服务器给你开玩笑?下面请看工程师如何一步一步揪出真凶,解决问题. 如果不想 ...