上次使用Google ML Engine跑了一下TensorFlow Object Detection API中的Quick Start(http://www.cnblogs.com/take-fetter/p/8384564.html),但是遇到了很多错误,索性放弃了

这两天直接开始从自己的数据集开始制作手掌识别器。先放运行结果吧

所有代码文件可在https://github.com/takefetter/hand-detection查看,欢迎star和issue

使用前所需要的准备:1.clone tensorflow models(site:https://github.com/tensorflow/models)

          2.在model/research目录下运行setup.py安装object detection API

          3.其余必要条件:安装tensorflow(版本需大于等于1.4),opencv-python等必须的package

          4.安装Google Cloud SDK,激活免费试用300美金(需要一张信用卡来验证)和在命令行中使用gcloud init设置等

  •  准备数据集

  (关于手的图片的dataset仍旧使用的dlib训练(site:http://www.cnblogs.com/take-fetter/p/8321158.html)中的Hand Images Databases - https://www.mutah.edu.jo/biometrix/hand-images-databases.html提供的数据集,只不过这次使用了WEHI系列的图片(MOHI的图片我也试过,导入后会导致standard-gpu版的训练无法进行(内存不足)),作为示例目前我只使用了1-50人的共计250张图片)

   tensorflow训练的数据集需为TFRecord格式,我们需要对训练数据进行标注,但是我并没有找到直接可以标注生成的工具,还好有工具可以生成Pascal VOC格式的xml文件      https://github.com/tzutalin/labelImg,推荐将图片文件放于research/images中,保存xml文件夹位于research/images/xmls中

根据你要训练的数据集,创建.pbtxt文件

  • 转换为tfrecord格式

   完成图片标注后在xmls文件夹中运行xml_to_csv.py即可生成csv文件,再通过create_hand_tfrecord.py即可将图片转换为hand.record文件

   需要注意的是,如果你需要训练的数据集和我这里的不一样的话,create_hand_tfrecord.py的todo部分需要与你的.pbtxt文件内的内容一致

   (方法参考至https://github.com/datitran/raccoon_dataset 使用本作者的文件还可以完成划分测试集和分析数据等功能,当然我这里并没有使用)

  • 下载预训练模型

   重新开始一个模型的训练时间是很长的时间,而tensorflow model zoo为我们提供好了预训练的模型(site:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models-coco-models),选择并下载一个 我选择的是

速度最快的ssd_mobilenet_v1,下载后解压可找到3个含有ckpt的文件,如图

  之后还需下载并配置model对应的config文件(https://github.com/tensorflow/models/tree/master/research/object_detection/samples/configs)并修改文件中的内容

需要修改的地方有:

  1. num_classes: 改为pbtxt中类的数目
  2. PATH_TO_BE_CONFIGURED的部分改为相应的目录
  3. num_steps定义了学习的上限 默认是200000 可自己更改,训练过程中也可以随时停止
  • 上传文件并在Google Cloud Platform中训练

  1.上传3个ckpt文件以及config文件和.record文件

      到google cloud控制台-存储目录下,创建存储分区(这里使用takefetter_hand_detector),并新建data文件夹,拖拽上传到该目录中完成后的目录和文件如下

+ takefetter_hand_detector/
+ data/
- ssd_mobilenet_v1_hand.config
- model.ckpt.index
- model.ckpt.meta
- model.ckpt.data-00000-of-00001
- hand_label_map.pbtxt
- hand.record

  2. 打包tf slim和object detection

     在research目录下运行

python setup.py sdist
(cd slim && python setup.py sdist)

  3.创建机器学习任务

    在research目录下运行此命令 开始训练

gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
--runtime-version 1.4 \
--job-dir=gs://takefetter_hand_detector/train \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
--module-name object_detection.train \
--region us-central1 \
--config object_detection/samples/cloud/cloud.yml \
-- \
--train_dir=gs://takefetter_hand_detector/train \
--pipeline_config_path=gs://takefetter_hand_detector/data/ssd_mobilenet_v1_hand.config

需要注意的地方有

  1. windows下需要放在同一行运行 并删除\
  2. cloud.yml文件中的内容可以自行更改,我这里的设置为
    trainingInput:
    runtimeVersion: "1.4"
    scaleTier: CUSTOM
    masterType: standard_gpu
    workerCount: 2
    workerType: standard_gpu
    parameterServerCount: 2
    parameterServerType: standard

在提交任务后在 机器学习引擎-作业中即可看到具体情况,每运行几千次后在 takefetter_hand_detector/train中存储对应cheakpoint的文件 如图

之后下载需要的cheak的3个文件 复制到research目录下(这里用30045的3个文件作为示例),并将research/object_detectIon目录下的export_inference_graph.py复制到research目录下 运行例如

python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v1_hand.config \
--trained_checkpoint_prefix model.ckpt-30045 \
--output_directory exported_graphs

在运行完成后research目录中会生成文件夹exported_graphs_30045 包含的文件如图所示

拷贝frozen_inference_graph.pb和pbtxt文件到test/hand_inference_graph文件夹,并运行hand_detector.py 即可得到如文章开头的结果

后记:

1.如果需要视频实时的hand tracking,可使用https://github.com/victordibia/handtracking 在我的渣本上FPS太低了......

2.我目前使用的数据集还是较小训练次数也比较少,很容易出现一些误识别的情况,之后还会加大数据集和训练次数

3.换用其他model应该也会显著改善识别精确度

4.遇到任何问题,欢迎提问(虽然感觉大多数stack overflow都有)

5.本地训练要好很多,如果使用在Google Cloud训练中可能会遇到问题,但是解决方法是将tensorflow版本改为1.2,但是1.2版本的object detection在准备阶段就会遇到问题,目前来看确实无解。(毕竟API Caller)

6.本地训练建议使用tensorflow版本为1.2

感谢:

  1. https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md
  2. https://github.com/victordibia/handtracking
  3. https://pythonprogramming.net/testing-custom-object-detector-tensorflow-object-detection-api-tutorial/?completed=/training-custom-objects-tensorflow-object-detection-api-tutorial/
  4. https://github.com/datitran/raccoon_dataset
  5. https://www.mutah.edu.jo/biometrix/hand-images-databases.html

使用TensorFlow Object Detection API+Google ML Engine训练自己的手掌识别器的更多相关文章

  1. 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(二)

    前言 已完成数据预处理工作,具体参照: 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(一) 设置配置文件 新建目录face_faster_rcn ...

  2. TensorFlow Object Detection API(Windows下训练)

    本文为作者原创,转载请注明出处(http://www.cnblogs.com/mar-q/)by 负赑屃 最近事情比较多,前面坑挖的有点久,今天终于有时间总结一下,顺便把Windows下训练跑通.Li ...

  3. [Tensorflow] Object Detection API - predict through your exclusive model

    开始预测 一.训练结果 From: Testing Custom Object Detector - TensorFlow Object Detection API Tutorial p.6 训练结果 ...

  4. TensorFlow object detection API

    cloud执行:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pet ...

  5. Tensorflow object detection API ——环境搭建与测试

    1.开发环境搭建 ①.安装Anaconda 建议选择 Anaconda3-5.0.1 版本,已经集成大多数库,并将其作为默认python版本(3.6.3),配置好环境变量(Anaconda安装则已经配 ...

  6. Tensorflow object detection API 搭建物体识别模型(一)

    一.开发环境 1)python3.5 2)tensorflow1.12.0 3)Tensorflow object detection API :https://github.com/tensorfl ...

  7. Tensorflow object detection API 搭建物体识别模型(二)

    二.数据准备 1)下载图片 图片来源于ImageNet中的鲤鱼分类,下载地址:https://pan.baidu.com/s/1Ry0ywIXVInGxeHi3uu608g 提取码: wib3 在桌面 ...

  8. 基于TensorFlow Object Detection API进行相关开发的步骤

    *以下二/三.四步骤确保你当前的文件目录是以research文件夹为相对目录. 一/安装或升级protoc 查看protoc版本命令: protoc --version 如果发现版本低于2.6.0或运 ...

  9. TensorFlow object detection API应用

    前一篇讲述了TensorFlow object detection API的安装与配置,现在我们尝试用这个API搭建自己的目标检测模型. 一.准备数据集 本篇旨在人脸识别,在百度图片上下载了120张张 ...

随机推荐

  1. 【深度学习系列】CNN模型的可视化

    前面几篇文章讲到了卷积神经网络CNN,但是对于它在每一层提取到的特征以及训练的过程可能还是不太明白,所以这节主要通过模型的可视化来神经网络在每一层中是如何训练的.我们知道,神经网络本身包含了一系列特征 ...

  2. Linux 将本地文件上传Linux服务器, 即ssh 命令上传本地文件

    利用ssh传输文件   在linux下一般用scp这个命令来通过ssh传输文件. 1.从服务器上下载文件 scp username@servername:/path/filename /var/www ...

  3. mysql中的union和order by、limit

      我有一个表 CREATE TABLE `test1` (  `id` int(10) unsigned NOT NULL AUTO_INCREMENT,  `name` varchar(20) N ...

  4. PHP和Python如何选择?或许可以考虑这三个问题

    撤稿纠错 文/黄小天.李亚洲 (选自Hackernoon 机器之心编译) 2017 年可谓是网页应用与 API 之年,开发者不用每次重新发明轮子,而是利用脚手架和第三方库就能确保项目在几天内实时部署. ...

  5. twitter的ID生成器的snowFlake算法的自造版

    snowFlake算法在生成ID时特别高效,可参考:https://segmentfault.com/a/1190000011282426 SnowFlake算法生成id的结果是一个64bit大小的整 ...

  6. Struts2与Ajax数据交互

    写在前面: ajax请求在项目中常常使用,今天就平时掌握的总结一下,关于使用ajax请求到Struts2中的action时,前台页面与后台action之间的数据传递交互问题. 这里我主要记录下自己所掌 ...

  7. windows安装xampp时出现,unable to realloc xxxxxxxx bytes

    摘录自:http://blog.csdn.net/lz610756247/article/details/70842166 Windows虚拟内存的设置 问题描述:由于开启虚拟内存会导致硬盘IO性能下 ...

  8. SVN的安装和配置

    SVN为程序开发团队常用的代码管理,版本控制软件:下面我们来介绍TortoiseSVN的安装,和其服务器的搭建:(下面为windows 64位系统下的搭建) 闲来无事,就在本地搭建了一个SVN环境,网 ...

  9. NtDuplicateObject小解读

    源进程和目标进程可以是一个吗 当然执行进程可以是同一个吗 ,当然标志位重要!有一个关闭源进程的标志位 第一步通过ObReferenceHandleTable获得源进程对象(数据结构) //为新的句柄构 ...

  10. Python实现一个简单的微信跳一跳辅助

    1.  前言 微信的跳一跳相信大家都很熟悉了,而且现在各种外挂.辅助也是满天飞,反正本人的好友排行榜中已经是八九百都不足为奇了.某宝上一搜一堆结果,最低的居然只要3块多,想刷多少分就刷多少分,真是离谱 ...