训练自己的数据集(以bottle为例):

 

1、准备数据

文件夹结构:
models
├── images
├── annotations
│ ├── xmls
│ └── trainval.txt
└── bottle
├── train_logs 训练文件夹
└── val_logs 日志文件夹

1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 
ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内

2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始) 
使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同 
3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:

bottle_1 1 1 1
bottle_2 1 1 1

4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下

item {
id: 1
name: 'bottle'
}
 

2、生成数据

# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`

得到 pet_train.record 和 pet_val.record 移动至bottle文件夹

 

3、准备conf文件

复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config 
对ssd_mobilenet_v1_bottle.config文件进行一下修改:

修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
修改第177行为 input_path: "bottle/pet_train.record"
修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
修改第191行为 input_path: "bottle/pet_val.record"
 

4、训练

# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &
 

5、验证

# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &
 

6、可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

浏览器访问 ip:6006,可看到趋势以及具体image的预测结果

 

7、导出模型

# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle

生成 bottle/frozen_inference_graph.pb 文件

 

8、测试图片

运行object_detection_tutorial.ipynb并修改其中的各种路径即可 
或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:

import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
start_time = time.time()
print(time.ctime())
image = Image.open(PATH_TEST_IMAGE)
image_np = np.array(image).astype(np.uint8)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
category_index, use_normalized_coordinates=True, line_thickness=8)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)

运行infer.py test_images/image1.jpg即可

使用Tensorflow训练自己的数据的更多相关文章

  1. TensorFlow.训练_资料(有视频)

    ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...

  2. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  3. smallcorgi/Faster-RCNN_TF训练自己的数据

    熟悉了github项目提供的训练测试后,可以来训练自己的数据了.本文只介绍改动最少的方法,只训练2个类, 即自己添加的类(如person)和 background,使用的数据格式为pascal_voc ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow.js之根据数据拟合曲线

    这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...

  6. tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了

    tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了 感觉它能找到词与词之间的关系,应该可以用来做推荐系统.自动摘要.相关搜索.联想什么的 tensorflow1.1.0 ...

  7. 人脸检测及识别python实现系列(3)——为模型训练准备人脸数据

    人脸检测及识别python实现系列(3)——为模型训练准备人脸数据 机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为.举一个简单的例子,成年人并没有主动 ...

  8. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  9. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

随机推荐

  1. PetaPoco源代码学习--1.使用的Attribute介绍

    新版本的PetaPoco使用特性进行注解的形式来代替的老版本的映射类的形式.新版本中使用的特性主要包括以下几种: 名称   用途 TableNameAttribute Class 指定POCO实体类对 ...

  2. JavaScript之String总汇

    一.常用属性 ·length:返回字符串中字符长度 let str = 'asd '; str.length = 1;//无法手动修改,只读 console.log(str.length);//4 二 ...

  3. 继续畅通工程(hdu1879)并查集

    继续畅通工程 Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Total Sub ...

  4. RestfulAPI超简单入门

    简单入门 REST -- REpresentational State Transfer,英语的直译就是"表现层状态转移" 是目前最流行的 API 设计规范,用于 Web 数据接口 ...

  5. Android - View的绘制你知道多少?

    https://github.com/android-cn/android-open-project-analysis/tree/master/tech/viewdrawflow Android-La ...

  6. OpenStack的架构详解[精51cto]

    OpenStack既是一个社区,也是一个项目和一个开源软件,它提供了一个部署云的操作平台或工具集.其宗旨在于,帮助组织运行为虚拟计算或存储服务的云,为公有云.私有云,也为大云.小云提供可扩展的.灵活的 ...

  7. django-sql注入攻击

    一.原理 什么是sql注入 所谓SQL注入就是通过把SQL命令插入到Web表单提交或输入域名或页面请求的查询字符串(注入本质上就是把输入的字符串变成可执行的程序语句),最终达到欺骗服务器执行恶意的SQ ...

  8. nodejs 爬虫模板 map&array 数据模型

    app.get('/knowledge', function (req, res, next) { var listUid = req.query.listUid; var url = "h ...

  9. 洛谷P2045 方格取数加强版(费用流)

    题意 题目链接 Sol 这题能想到费用流就不难做了 从S向(1, 1)连费用为0,流量为K的边 从(n, n)向T连费用为0,流量为K的边 对于每个点我们可以拆点限流,同时为了保证每个点只被经过一次, ...

  10. Expo大作战(二十一)--expo如何分离(detach),分离后可以比react native更有优势,但也失去了expo的部分优势,

    简要:本系列文章讲会对expo进行全面的介绍,本人从2017年6月份接触expo以来,对expo的研究断断续续,一路走来将近10个月,废话不多说,接下来你看到内容,讲全部来与官网 我猜去全部机翻+个人 ...