完全版见github:TransforLearning

零、迁移学习

将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习。用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络迁移到一个全新的网络里,而不是从头开始,为每特定的个任务训练一个神经网络。

假设你已经有了一个可以高精确度分辨猫和狗的深度神经网络,你之后想训练一个能够分别不同品种的狗的图片模型,你需要做的不是从头训练那些用来分辨直线,锐角的神经网络的前几层,而是利用训练好的网络,提取初级特征,之后只训练最后几层神经元,让其可以分辨狗的品种。节省资源是迁移学习最大意义之一,举图像识别中最常见的例子,训练一个神经网络。来识别不同的品种的猫,你若是从头开始训练,你需要百万级的带标注数据,海量的显卡资源。而若是使用迁移学习,你可以使用Google发布的Inception或VGG16这样成熟的物品分类的网络,只训练最后的softmax层,你只需要几千张图片,使用普通的CPU就能完成,而且模型的准确性不差。

使用迁徙学习时要注意,本来预训练的神经网络要和当前的任务差距不大,不然迁徙学习的效果会很差。例如如果你要训练一个神经网络来识别肺部X光片中是否包含肿瘤,那么使用VGG16的网络就不如使用一个已训练好的判断脑部是否包含肿瘤的神经网络。后者与当前的任务有相似的场景,很多底层的神经员可以做相同的事,而用来识别日常生活中照片的网络,则难以从X光片中提取有效的特征。

另一种迁移学习的方法是对整个网络进行微调(fine turing),假设你已训练好了识别猫品种的神经网络,你的网络能对50种猫按品种进行分类。接下来你想对网络进行升级,让其能够识别100种猫,这时你不应该只训练网络的最后一层,而应该逐层对网络中每个节点的权重进行微调。显然,只训练最后几层,是迁移学习最简单的1.0版,而对节点权重进行微调,就是更难的2.0版,通过将其他层的权重固定,只训练一层这样的逐层训练,可以更好的完成上述任务。

迁移方式和数据集规模关系

1)右下角场景,待训练的数据集较小,已训练的模型和当前任务相似。此时可以只是重新训练已有模型的靠近输出的几层,例如将ImageNet中输出层原来可以判别一万种输出的网络改的只能判别猫的品种,从而利用已有网络来做低层次的特征提取。

2)左下角场景,待训练的数据集较小,已训练的模型和当前任务场景差距较大。例如你有的已训练网络能识别出白天高速路上的违章车辆,你需要训练一个能识别出夜间违章车辆的模型,由于不管白天夜晚,交通规则是没有变化的,所以你需要将网络靠近输入的那几层重新训练,等到新的网络能够提取出夜间车辆的基本信息后,就可以借用已有的,在大数据集下训练好的神经网络来识别违章车辆,而不用等夜间违章的车辆的照片积累的足够多之后再重新训练。

3)左上角场景,待训练的数据集较大,已有的模型和新模型的数据差异度很高。此时应该做的是从头开始,重新训练。

4)右上角场景,待训练的数据集较大,已有模型的训练数据和现有的训练数据类似。此时应该使用原网络的结构微调。

一、实验目的

使用google已经训练好的模型,将最后的全连接层修改为我们自己的全连接层,将原有的1000分类分类器修改为我们自己的5分类分类器,利用原有模型的特征提取能力实现我们自己数据对应模型的快速训练。实际中对于一个陌生的数据集,原有模型经过不高的迭代次数即可获得很好的准确率。

二、代码实战

实机文件夹如下:

花朵图片数据下载:

curl -O http://download.tensorflow.org/example_images/flower_photos.tgz

已经训练好的Inception-v3的1000分类模型下载:

wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

迁移学习代码以及使用指南见github,本次的代码可以将新保存的模型迁移到自己的数据集上,并对自己的数据进行预测。

新添加的测试部分,直接运行TransferLearning_reload.py即可,输出如下,

第二行白字的五个值对应第一行的5个分类的概率。

三、问题&建议

1.建议从main函数开始阅读,跳到哪里读到那里;

2.我给的注释很详尽,原书《TensorFlow实战Google深度学习框架》也有更为详尽的注释,所以这里不多说了

之前本部分对输入图片的过程进行了分析,当时水平有限,现在看来很幼稚,实际上分析一下图上节点即可了解:

  • InceptionV3接受二进制数据即可自行解码,即接收open().read()的二进制流即可
  • 保存的模型文件Graph包含了InceptionV3和新的classer,但是两者是隔离的,这是由于程序中并没将两者联通,是先把InceptionV3的瓶颈张量feed出来,然后是用这个数组去feed新的classer,但是由于saver、InceptionV3、classer使用的sess是同一个,所以最终两者都保存在了model模型中

『TensorFlow』迁移学习的更多相关文章

  1. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  2. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  3. 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理

    Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...

  4. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

  5. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  6. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  7. 『TensorFlow』分布式训练_其三_多机分布式

    本节中的代码大量使用『TensorFlow』分布式训练_其一_逻辑梳理中介绍的概念,是成熟的多机分布式训练样例 一.基本概念 Cluster.Job.task概念:三者可以简单的看成是层次关系,tas ...

  8. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  9. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

随机推荐

  1. BZOJ5018: [Snoi2017]英雄联盟

    Description 正在上大学的小皮球热爱英雄联盟这款游戏,而且打的很菜,被网友们戏称为「小学生」.现在,小皮球终于受不 了网友们的嘲讽,决定变强了,他变强的方法就是:买皮肤!小皮球只会玩N个英雄 ...

  2. 02:httpd-2.2基础配置

    ---恢复内容开始--- 9.日志设定 错误日志: ErrorLog logs/error_log //这里使用了相对路径,相对于/etc/httpd/路径 LogLevel warn  //定义日志 ...

  3. Using keytool to import keystore

    open command line and locate to the location of  keytool.exe. import cert to keystore command: keyto ...

  4. Jquery中的DOM操作:

    DOM是 Document Object Model的缩写,是一种与浏览器,平台,语言无关的接口,使用该接口可以访问页面中所有的标准组件,下面介绍一下常用的一些DOM操作: 选择节点: 将在下篇博客中 ...

  5. Java Virtual Machine(Java虚拟机)

    JVM是Java Virtual Machine(Java虚拟机)的缩写,JVM是一种用于计算设备的规范,它是一个虚构出来的计算机,是通过在实际的计算机上仿真模拟各种计算机功能来实现的. Java语言 ...

  6. 虹软2.0免费离线人脸识别 Demo [C++]

    环境: win10(10.0.16299.0)+ VS2017 sdk版本:ArcFace v2.0 OPENCV3.43版本 x64平台Debug.Release配置都已通过编译 下载地址:http ...

  7. 学习笔记22—PS小技巧

    1.将图片四角变弧形:菜单-->选择-->平滑-->设置参数: 2.画曲线的方法是: 1)选择钢笔工具, 2)工具属性选择路径:用钢笔点下路径的起点,点下即松开鼠标:在下一个锚点,点 ...

  8. sessionId的生成机制

    目录 面试问道这个我居然不知道怎么回答,当然也是因为我确实没有研究过.下面就是百度了一篇文章后简单回答这个问题. 参考:http://www.cnblogs.com/sharpxiajun/p/339 ...

  9. cocos2dx 编译遇到资源里有.svn文件不能删除报错的问题

    使用cocos compile -p android 对项目进行编译的时候,遇到res文件中包含了只读属性的svn目录,不能进行删除而报错. 错误如下图(build_android.py里面对.svn ...

  10. 在linux中,我为什么不能安装VMware Tools?

    在linux中,我为什么不能安装VMware Tools? 应该是操作不正确导致,以下为linux安装VMware Tools的方法. 1.在安装Linux的虚拟机中,单击“虚拟机”菜单下的“安装Vm ...