21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型
书籍源码:https://github.com/hzy46/Deep-Learning-21-Examples
CNN的发展已经很多了,ImageNet引发的一系列方法,LeNet,GoogLeNet,VGGNet,ResNet每个方法都有很多版本的衍生,tensorflow中带有封装好各方法和网络的函数,只要喂食自己的训练集就可以完成自己的模型,感觉超方便!!!激动!!!因为虽然原理流程了解了,但是要写出来真的。。。。好难,臣妾做不到啊~~~~~~~~
START~~~~
1.数据准备
首先了解下微调的概念: 以VGG为例
他的结构是卷积+全连接,卷积层分为5个部分共13层,conv1~conv5。还有三层全连接,即fc6,fc7,fc8。总共16层,因此被称为VGG16。
a.如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8,因为fc8原本的输出是1000类的概率。需要改为符合自身训练集的输出类别数。
b.训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。因为已经训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,这样做不仅节约大量训练时间,而且有助于分类器性能的提高。
载入VGG16的参数后,即可开始训练。此时需要指定训练层数的范围。一般而言,可以选择以下几种:
- 只训练fc8:训练范围一定要包含fc8这一层。这样的选择一般性能都不会太好,但速度很快;因为他只训练fc8,保持其他层的参数不动,相当于把VGG16当成一个“特征提取器”,用fc7层提取的特征做一个softmax的模型分类。
- 训练所有参数:耗时较慢,但能取得较高性能。
- 训练部分参数:通常是固定浅层参数不变,训练深层参数。如固定conv1、conv2部分的参数不训练,只训练conv3、conv4、conv5、fc6、fc7、fc8的参数。
这种训练方法就是对神经网络做微调。
1.1 切分train&test
书中提供了卫星图像数据集,有6个类别,分别是森林(wood),水域(water),岩石(rock),农田(wetland),冰川(glacier),城市区域(urban)
保存结构为data_prepare/pic,再下层有两个文件夹train和validation,各文件夹下有6个文件夹,放的是该类别下的图片。
1.2 转换成tfrecord格式
python data_convert.py -t pic/ \
--train-shards 2 \
--validation-shards 2 \
--num-threads 2 \
--dataset-name satellite
参数解释:
-t pic/ :表示转换pic文件夹下的数据,该文件夹必须与上面的文件结构保持一致
--train-shards 2 :把训练集分成两块,即最后的训练数据就是两个tfrecord格式的文件。若数据集更大,可以分更多数据块
--validation-shards 2 :把验证集分成两块
--num-thread 2 :用两个线程来产生数据。注意线程数必须要能整除train-shards和validation-shards,来保证每个线程处理的数据块是相同的。
--dataset-name :给生成的数据集起个名字,即表示最后生成文件的开头是satellite_train和satellite_validation
运行上述命令后,就可以在 pic 文件夹中找到 5 个新生成的文件 ,分别是:
- 训练数据 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,
- 验证数据 satellite_validation_00000-of-00002.tfrecord、satellite_validation_00001-of-00002.tfrecord。
- label.txt 它表示图片的内部标签(数字)到真实类别(字符串)之间的映射顺序 。 如图片在 tfrecord 中的标签为 0 ,那么就对应 label.txt 第一行的类别,在 tfrecord的标签为1,就对应 label.txt 中第二行的类别,依此类推。
2.训练模型
2.1 TensorFlow Slim
Google 公司公布的一个图像分类工具包,它不仅定义了一些方便的接口,还提供了很多 ImageNet 数据集上常用的网络结构和预训练模型 。
截至2017年7月,Slim 提供包括 VGG16、VGG19、Inception V1 ~ V4、ResNet 50、ResNet 101、MobileNet 在内大多数常用模型的结构以及预训练模型,更多的模型还会被持续添加进来。
源码地址: https://github.com/tensorflow/models/tree/master/research/slim
可以通过执行 git clone https://github.corn/tensorflow/models.git 来获取
2.2 定义新的datasets文件<修改slim源码>
在slim/datasets中,定义了所有可以使用的数据库,为了使用之前生成的 tfrecord 数据进行训练,必须要在datasets 中定义新的数据库。
- 在 datasets/目录下新建一个文件 satellite,并将 flowers.py 文件中的内容复制到 satellite.py 中。
修改两处代码:
- 在同目录的 dataset_factory. py 文件中注册 satellite 数据库
2.3 准备训练文件夹
在slim文件下新建satellite文件夹,按照上图结构准备各个文件和数据,模型下载地址:http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,也可以在slim的github地址上找到其他模型的下载链接。
2.4 开始训练
python train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
参数解释:
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits:首先来解释参数trainable_scopes 的作用,因为它非常重要。 trainable_scopes 规定了在模型中微调变量的范围 。这里的设定表示只对InceptionV3/Logits,InceptionV3/AuxLogits 两个变量进行微调,其他变量都保持不动 。InceptionV3/Logits,InceptionV3/AuxLogits 就相当于之前所讲的 fc8 ,它们是 Inception V3 的“末端层” 。如果不设定 trainable_scopes ,就会对模型中所有的参数进行训练。
--train_dir=satellite/train_dir : 表明会在 satellite/train_dir 目录下保存日志和 checkpoint 。可通过tensorboard --logdir satellite/train_dir来可视化参数变化。
--dataset_name=satellite 、--dataset_split_name=train : 指定训练的数据集 。之前定义的新的 dataset 就是在这里发挥用处的 。
--dataset_dir=satellite/data :指定训练数据集保存的位置 。
--model_name=inception_v3 :使用的模型名称 。
--checkpoint_path=satellite/pretrained/inception_v3.ckpt :预训练模型的保存位置 。
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits : 在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是 InceptionV3 模型的末端层,对应着 ImageNet 数据集的 1000 类,和当前的数据集不符 ,因此不要去恢复它 。
--max_number_of_steps 100000 :最大的执行步数 。
--batch_size=32 :每步使用的 batch 数量 。
--learning_rate=0.001 :学习率 。
--learning_rate_decay_type=fixed :学习率是否自动下降,此处使用固定的学习率 。
--save_interval_secs=300 :每隔 300s ,程序会把当前模型保存到 train_dir 中 。 此处就是目录 satellite/train_dir 。
--save_summaries_secs=2 :每隔 2s ,就会将日志写入到 train_dir 中 。 可以用 TensorBoard 查看该日志 。 此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔 。
--log_every_n_steps=10 :每隔 10 步,就会在屏幕上打出训练信息。
--optimize=rmsprop:表示选定的优化器 。
--weight_decay=0.00004 :选定的 weight_decay 值 。 即模型中所有参数的二次正则化超参数 。
如果想要对所有层训练,去掉--trainable_scopes参数即可。
训练结果:
程序开始运行时,会先读取train_dir下的模型,若没有,则去读checkpoint_path目录下的预训练模型,以5min的频率保留模型(仅保留最新的5个模型)。
若程序中断再次训练,读取train_dir下的模型,不存在则依旧去找预训练模型,若存在,则以该模型的参数为起点进行参数微调。
3.验证准确率
python eval_image_classifier.py \
--checkpoint_path=satellite/train_dir \
--eval_dir=satellite/eval_dir \
--dataset_name=satellite \
--dataset_split_name=validation \
--dataset_dir=satellite/data \
--model_name=inception_v3
参数说明:
--checkpoint_path=satellite/train_dir :这个参数既可以接收一个目录的路径,也可以接收一个文件的路径。 如果接收的是一个目录的路径,如这里的 satellite/train_dir ,就会在这个目录中寻找最新保存的模型文件,执行验证。也可以指定一个模型进行验证,以第 300 步的模型为例,在satellite/train_dir 文件夹下它被保存为 model.ckpt-300.meta 、model.ckpt-300.index 、 model.ckpt-300.data-00000-of-OOOO1 三个文件 。 此时,如果要对它执行验证,给 checkpoint_path 传递的参数应该为satellite/train_dir/model.ckpt-300 。
--eval_dir=satellite/eval_dir:执行结果的曰志就保存在 eval_dir 中,同样可以通过 TensorBoard 查看。
--dataset_name=satellite、--dataset_split_name=validation : 指定需要执行的数据集 。 注意此处是使用验证集( validation )执行验证。
--dataset dir=satellite/data :数据集保存的位置 。
--model_name=inception_v3 :使用的模型。
运行结果:
Accuracy表示模型的分类准确率,Recall_5表示Top5的准确率,即输出的各类别概率中,正确的类别只要落在前5个中则算对。由于此处类别数不多,可尝试改成Top2或者Top3.
修改eval_image_classifier.py
可再次执行上述验证语句查看预测结果。
前面有讲到可以训练所有层的参数,再进行测试,发现Accuracy的值,在训练所有层的时候可达82%,效果更佳。
4.导出模型并对单张图片分类
4.1 生成.pb文件
在slim文件夹下有 export_inference_graph.py 文件,运行它会生成一个 inception_v3_inf_graph.pb 文件。该文件中仅保存了Inception V3的网络结构,后续需要把checkpoint中的模型参数保存进来。
python export_inference_graph.py \
--alsologtostderr \
--model_name=inception_v3 \
--output_file=satellite/inception_v3_inf_graph.pb \
--dataset_name satellite
4.2 保存模型参数
在chapter_3文件下有 freeze_graph.py 文件,运行它会生成一个 frozen_graph.pb 文件(一个用于识别的模型)。之后就可以用该文件对单张图片进行预测。
python freeze_graph.py \
--input_graph slim/satellite/inception_v3_inf_graph.pb \
--input_checkpoint slim/satellite/train_dir/model.ckpt-5271 \
--input_binary true \
--output_node_names InceptionV3/Predictions/Reshape_1 \
--output_graph slim/satellite/frozen_graph.pb
需将5271改成train_dir中保存的实际的模型训练步数
--input_graph slim/satellite/inception_v3_inf_graph.pb:这个参数很好理解 ,它表示使用的网络结构文件,即之前已经导出的inception_v3_inf_graph.pb 。
--input_checkpoint slim/satellite/train_dir/model.ckpt-5271:具体将哪一个checkpoint 的参数载入到网络结构中 。这里使用的是训练文件夹 train_dir中的第 5271 步模型文件 。 读者需要根据训练文件夹下 checkpoint 的实际步数,将 5271 修改成对应的数值。
--input_binary true:导入的 inception_v3_inf_graph.pb 实际是一个 protobuf 文件 。 而 protobuf 文件有两种保存格式,一种是文本形式,一种是二进制形式。 inception_v3_inf_graph.pb 是二进制形式,所以对应的参数是--input_binary true。初学的话对此可以不用深究,若有兴趣的话可以参考资料。
--output_node_names InceptionV3/Predictions/Reshape_1:在导出的模型中,指定一个输出结点, InceptionV3/Predictions/Reshape_1是 Inception V3最后的输出层 。
--output_graph slim/satellite/frozen_graph.pb:最后导出的模型保存为 slim/satellite/frozen_graph.pb 文件。
4.3 单张图片预测
python classify_image_inception_v3.py \
--model_path slim/satellite/frozen_graph.pb \
--label_path slim/satellite/data/label.txt \
--image_file test_image.jpg
运行结果如下:
[root@node5 chapter_03]# python3 classify_image_inception_v3.py --model_path slim/satellite/frozen_graph.pb --label_path slim/satellite/data/label.txt --image_file test_image.jpg
-- ::39.435166: I tensorflow/core/platform/cpu_feature_guard.cc:] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
water (score = 3.25221)
wetland (score = 1.97180)
urban (score = 1.33430)
wood (score = 0.53297)
rock (score = -0.41706)
score是各个类别对应的Logit。
代码逻辑实现:
拓展阅读
- TensorFlow Slim 是TensorFlow 中用于定义、训练和验证复杂网络的 高层API。官方已经使用TF-Slim 定义了一些常用的图像识别模型, 如AlexNet、VGGNet、Inception模型、ResNet等。本章介绍的Inception V3 模型也是其中之一, 详细文档请参考: https://github.com/tensorflow/models/tree/master/research/slim。
- 在第3.2节中,将图片数据转换成了TFRecord文件。TFRecord 是 TensorFlow 提供的用于高速读取数据的文件格式。读者可以参考博文( http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ )详细了解如何将数据转换为TFRecord 文件,以及 如何从TFRecord 文件中读取数据。
- Inception V3 是Inception 模型(即GoogLeNet)的改进版,可以参考论文Rethinking the Inception Architecture for Computer Vision 了解 其结构细节。
THE END~~~~
21个项目玩转深度学习:基于TensorFlow的实践详解03—打造自己的图像识别模型的更多相关文章
- 21个项目玩转深度学习:基于TensorFlow的实践详解02—CIFAR10图像识别
cifar10数据集 CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集.一共包含 10 个类别的 ...
- 21个项目玩转深度学习:基于TensorFlow的实践详解01—MNIST机器学习入门
数据集 由Yann Le Cun建立,训练集55000,验证集5000,测试集10000,图片大小均为28*28 下载 # coding:utf-8 # 从tensorflow.examples.tu ...
- 21个项目玩转深度学习:基于TensorFlow的实践详解06—人脸检测和识别——项目集锦
摘自:https://github.com/azuredsky/mtcnn-2 mtcnn - Multi-task CNN library language dependencies comment ...
- 【原创 深度学习与TensorFlow 动手实践系列 - 4】第四课:卷积神经网络 - 高级篇
[原创 深度学习与TensorFlow 动手实践系列 - 4]第四课:卷积神经网络 - 高级篇 提纲: 1. AlexNet:现代神经网络起源 2. VGG:AlexNet增强版 3. GoogleN ...
- 【原创 深度学习与TensorFlow 动手实践系列 - 3】第三课:卷积神经网络 - 基础篇
[原创 深度学习与TensorFlow 动手实践系列 - 3]第三课:卷积神经网络 - 基础篇 提纲: 1. 链式反向梯度传到 2. 卷积神经网络 - 卷积层 3. 卷积神经网络 - 功能层 4. 实 ...
- 【原创 深度学习与TensorFlow 动手实践系列 - 1】第一课:深度学习总体介绍
最近一直在研究机器学习,看过两本机器学习的书,然后又看到深度学习,对深度学习产生了浓厚的兴趣,希望短时间内可以做到深度学习的入门和实践,因此写一个深度学习系列吧,通过实践来掌握<深度学习> ...
- faceswap深度学习AI实现视频换脸详解
给大家介绍最近超级火的黑科技应用deepfake,这是一个实现图片和视频换脸的app.前段时间神奇女侠加尔盖朵的脸被换到了爱情动作片上,233333.我们这里将会从github项目faceswap开始 ...
- 深度学习——优化器算法Optimizer详解(BGD、SGD、MBGD、Momentum、NAG、Adagrad、Adadelta、RMSprop、Adam)
在机器学习.深度学习中使用的优化算法除了常见的梯度下降,还有 Adadelta,Adagrad,RMSProp 等几种优化器,都是什么呢,又该怎么选择呢? 在 Sebastian Ruder 的这篇论 ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(一)
卷积神经网络(CNN)详解与代码实现 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10430073.html 目 ...
随机推荐
- OSGi Capabilities
OSGi bundle的Capability就是这个bundle所具有的能力. 就像淘宝上的每个店铺一样,它会说明自己都卖哪些东西,也就是Provide-Capability 我们这些剁手党就会根据自 ...
- JetBrains 系列软件--操作数据库+centos系统
这系列软件贼强大! 能操作数据库 也能操作centos(linux)系统 由于这系列都有这两个功能,下面以最近常用的JetBrains PhpStorm 2017.2.1 x64来举例子: 一.操作数 ...
- node.js(二)各种模块
我们知道Node.js适合于IO密集型应用,不适合于CPU密集型应用. JS和Node.js区别: JS运行于客户端浏览器中,存在兼容性问题:数据类型:值类型+引用类型(ES+D ...
- 如何用KNIME进行情感分析
Customer Intelligence Social Media Finance Credit Scoring Manufacturing Pharma / Health Care Retail ...
- JVM学习篇章(二)
上节我们已经介绍了jvm和监控的一下方法,下面举例说明一下: 瓶颈问题定位: 内存泄漏原因定位: 1.常见的内存泄漏 2.定位的方法
- RabbitMQ默认端口
4369 (epmd), 25672 (Erlang distribution)5672, 5671 (AMQP 0-9-1 without and with TLS)15672 (if manage ...
- Python编码---转自金角大王
本节内容 编码回顾 编码转换 Python的bytes类型 编码回顾 在备编码相关的课件时,在知乎上看到一段关于Python编码的回答 这哥们的这段话说的太对了,搞Python不把编码彻底搞明白,总有 ...
- 2019-8-31-asp-dotnet-core-支持客户端上传文件
title author date CreateTime categories asp dotnet core 支持客户端上传文件 lindexi 2019-08-31 16:55:58 +0800 ...
- SDUT-3375_数据结构实验之查找三:树的种类统计
数据结构实验之查找三:树的种类统计 Time Limit: 400 ms Memory Limit: 65536 KiB Problem Description 随着卫星成像技术的应用,自然资源研究机 ...
- Oracle日期
oracle 日期格式 to_date("要转换的字符串","转换的格式") 两个参数的格式必须匹配,否则会报错. 即按照第二个参数的格式解释第一个参数. ...