【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控
【写在前面】
用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inception_resnet_v2等模型。本文验证Inception_resnet_v2基于菜场实拍数据的准确性,测试数据为芹菜、鸡毛菜、青菜,各类别样本约600张,多个菜场拍摄,不同数据源。
补充:自己当初的计划是用别人预训练好的模型来再训练自己的数据集已使可以完成新的分类任务,但必须要修改代码改网络结构,并使用迁移学习(Fine-tune)
本文记录了其间的工作过程 , 相信也会有一些帮助的 : )
测试环境:Centos7.3-64位 python3.5.4(Anaconda)
目录
一.准备
1.安装python
2.安装tensorflow
3.下载TF-slim图像库
4.准备数据
5.下载模型二.训练
1.读入数据
2.构建模型
3.开始训练
4.执行脚本,训练自己的数据
5.可视化log
【问题】 tensorboard版本已更新,找不到对应包三.验证
四.测试
一.准备
1.安装python
推荐Anaconda,可创建虚拟环境,用conda命令易实现虚拟环境管理、包管理,安装包时会查出所有依赖包并一共一键安装, 链接:https://www.anaconda.com/download/
2.安装tensorflow
进入当下Anaconda的运行环境,我安装的是python2.7版,并创建3.5虚拟环境
- conda create -n py35 python=3.5 【py35是虚拟环境的名称; 输入y 安装】
- source activate py35 【激活py35环境】
- conda install tensorflow 【安装tensorflow-cpu版,有GPU可安装cpu版】
3.下载TF-slim代码库
- cd $WORKSPACE 【目录跳转到自己的工作目录下】
- git clone https://github.com/tensorflow/models/
4.准备数据
对所有训练样本按不同样本类别存在不同文件夹下
- zsy_train
- |---jimaocai
- |--- .jpg
- |--- ...
- |---qc
- |---qingcai
下面的代码是为了生成list.txt , 把不同文件夹下的图片和 数字label对应起来
- import os
- class_names_to_ids = {'jimaocai': 0, 'qc': 1, 'qingcai': 2}
- data_dir = 'flower_photos/'
- output_path = 'list.txt'
- fd = open(output_path, 'w')
- for class_name in class_names_to_ids.keys():
- images_list = os.listdir(data_dir + class_name)
- for image_name in images_list:
- fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))
- fd.close()
为了方便后期查看label标签,也可定义labels.txt
- jimaocai
- qc
- qingcai
随机生成训练集和验证集(在总量中随机选取350个样本作为验证集)
- import random
- _NUM_VALIDATION = 350
- _RANDOM_SEED = 0
- list_path = 'list.txt'
- train_list_path = 'list_train.txt'
- val_list_path = 'list_val.txt'
- fd = open(list_path)
- lines = fd.readlines()
- fd.close()
- random.seed(_RANDOM_SEED)
- random.shuffle(lines)
- fd = open(train_list_path, 'w')
- for line in lines[_NUM_VALIDATION:]:
- fd.write(line)
- fd.close()
- fd = open(val_list_path, 'w')
- for line in lines[:_NUM_VALIDATION]:
- fd.write(line)
- fd.close()
生成TFRecord数据
- import sys
- # sys.path.insert(0, '../models/slim/') models-master research
- sys.path.insert(0, './models/research/slim/') #把后面的路径插入到系统路径中 idx=0
- from datasets import dataset_utils
- import math
- import os
- import tensorflow as tf
- # 根据list路径 把数据转化为TFRecord
- # def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
- def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=3):
- fd = open(list_path)
- lines = [line.split() for line in fd]
- fd.close()
- num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
- with tf.Graph().as_default():
- decode_jpeg_data = tf.placeholder(dtype=tf.string)
- decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
- with tf.Session('') as sess:
- for shard_id in range(_NUM_SHARDS):
- output_path = os.path.join(output_dir,
- # 'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
- 'data_{:03}-of-{:03}.tfrecord'.format(shard_id, _NUM_SHARDS))
- tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
- start_ndx = shard_id * num_per_shard
- end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
- for i in range(start_ndx, end_ndx):
- sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
- i + 1, len(lines), shard_id))
- sys.stdout.flush()
- image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
- image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
- height, width = image.shape[0], image.shape[1]
- example = dataset_utils.image_to_tfexample(
- image_data, b'jpg', height, width, int(lines[i][1]))
- tfrecord_writer.write(example.SerializeToString())
- tfrecord_writer.close()
- sys.stdout.write('\n')
- sys.stdout.flush()
- os.system('mkdir -p train')
- convert_dataset('list_train.txt', 'zsy_train', 'train/')
- os.system('mkdir -p val')
- convert_dataset('list_val.txt', 'zsy_train', 'val/')
得到的文件夹结构如下
- WORKSPACE
- ├── zsy_train
- ├── labels.txt
- ├── list_train.txt
- ├── list.txt
- ├── list_val.txt
- ├── train
- │ ├── data_000-of-.tfrecord
- │ ├── ...
- │ └── data_002-of-.tfrecord
- └── val
- ├── data_000-of-.tfrecord
- ├── ...
- └── data_002-of-.tfrecord
5.下载模型
官方提供了预训练,这里以Inception-ResNet-v2
以例
- cd $WORKSPACE/checkpoints
- wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
- tar zxf inception_resnet_v2_2016_08_30.tar.gz
二.训练
1.读入数据
读入自己的数据,需要把下面代码写入models/slim/datasets/dataset_classification.py
- import os
- import tensorflow as tf
- slim = tf.contrib.slim
- def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
- file_pattern = os.path.join(dataset_dir, file_pattern)
- keys_to_features = {
- 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
- 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
- 'image/class/label': tf.FixedLenFeature(
- [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
- }
- items_to_handlers = {
- 'image': slim.tfexample_decoder.Image(),
- 'label': slim.tfexample_decoder.Tensor('image/class/label'),
- }
- decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
- items_to_descriptions = {
- 'image': 'A color image of varying size.',
- 'label': 'A single integer between 0 and ' + str(num_classes - 1),
- }
- labels_to_names = None
- if labels_to_names_path is not None:
- fd = open(labels_to_names_path)
- labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
- fd.close()
- return slim.dataset.Dataset(
- data_sources=file_pattern,
- reader=tf.TFRecordReader,
- decoder=decoder,
- num_samples=num_samples,
- items_to_descriptions=items_to_descriptions,
- num_classes=num_classes,
- labels_to_names=labels_to_names)
2.构建模型
构建模型取决于个人欲构建什么样的模型,官方都有对应模型的下载链接,只需把对应下载(下载链接:https://github.com/tensorflow/models/tree/master/research/slim)好的模型解压放入到checkpoints中即可
3.开始训练
由于是用已有模型训练自己的数据集,故需对原工程代码做适当调整。
把
- from datasets import dataset_factory
改为:
- from datasets import dataset_classification
把
- dataset = dataset_factory.get_dataset(
- FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
改为:
- dataset = dataset_classification.get_dataset(
- FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
在
- tf.app.flags.DEFINE_string(
- 'dataset_dir', None, 'The directory where the dataset files are stored.')
后加入:
- tf.app.flags.DEFINE_integer(
- 'num_samples', 1781, 'Number of samples.')
- tf.app.flags.DEFINE_integer(
- 'num_classes', 3, 'Number of classes.')
- tf.app.flags.DEFINE_string(
- 'labels_to_names_path', None, 'Label names file path.')
4.执行脚本,训练自己的数据
- cd $WORKSPACE/models/slim #跳转到工作环境目录
- python train_image_classifier.py \ #运行脚本,后面跟的系统参数
- --train_dir=/root/workspace_mrt/model_lab/train_logs \ #train_log目录,当模型训练时,可用tensorboard命令指定该目录,动态监测
- --dataset_dir=../../../train \ #训练数据集 里面是转换好的TFRecord格式
- --num_samples=1781 \ #训练样本数,即值train_set中的总样本数,不包括valid中随机抽取350个样本
- --num_classes=3 \ #样本类别数
- --labels_to_names_path=../../../labels.txt \
- --model_name=inception_resnet_v2 \
- --checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \ #指定模型位置
- --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
- --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
- --clone_on_cpu=True #cpu训练必须加上该参数
- #fine-tune要把 --checkpoint_path,--checkpoint_exclude_scopes,--trainable_scopes 加上
5.可视化log
为了可视化训练时的loss或其他指标,可用tensorboard,如下命令
- tensorboard --logdir=${TRAIN_DIR}
在本教程中,对应执行下面命令
tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs
【问题】 tensorboard版本已更新,找不到对应包
当执行
- tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs
时,得到如下错误
- ImportError: No module named 'tensorflow.tensorboard.tensorboard'
究其原因,是因为在tensorflow更新时,包的位置和所属关系改变了。执行以下代码,可解决该问题。
- cd /root/anaconda2/envs/py35/bin #跳转到对应python环境的bin目录下,修改tensorboard执行脚本代码,使之适应当前版本
vim tensorboard
把
- import tensorflow.tensorboard.tensorboard
修改为:
- import tensorboard.main
把
- sys.exit(tensorflow.tensorboard.tensorboard.main())
修改为: sys.exit(tensorboard.main.main())
wq保存,退出,重新执行
tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs
命令,无报错。根据日志提示,进入ip:6006进入tensorboard界面。
三.验证
使用自己的数据集,需修改models/slim/eval_image_classifier.py
把
from datasets import dataset_factory
改为:
from datasets import dataset_classification
把
- dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
改为:
- dataset = dataset_classification.get_dataset(
- FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
在
- tf.app.flags.DEFINE_string(
- 'dataset_dir', None, 'The directory where the dataset files are stored.')
后加入
- tf.app.flags.DEFINE_integer(
- 'num_samples', 350, 'Number of samples.')
- tf.app.flags.DEFINE_integer(
- 'num_classes', 3, 'Number of classes.')
- tf.app.flags.DEFINE_string(
- 'labels_to_names_path', None, 'Label names file path.')
验证时执行以下命令即可:
- python eval_image_classifier.py \
- --checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
- --eval_dir=/root/workspace_mrt/model_lab/eval_logs \
- --dataset_dir=../../../val \
- --num_samples=350 \
- --num_classes=3 \
- --model_name=inception_resnet_v2
可以一边训练一边验证,注意使用其它的GPU或合理分配显存。
同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:
tensorboard --logdir ../../../eval_logs/ --port 6007
四.测试
参考models/slim/eval_image_classifier.py
,可编写批量读取图片用模型进行推导的脚本models/slim/test_image_classifier.py
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import json
- import math
- import time
- import numpy as np
- import tensorflow as tf
- from nets import nets_factory
- from preprocessing import preprocessing_factory
- slim = tf.contrib.slim
- tf.app.flags.DEFINE_string(
- 'master', '', 'The address of the TensorFlow master to use.')
- tf.app.flags.DEFINE_string(
- 'checkpoint_path', None,
- 'The directory where the model was written to or an absolute path to a '
- 'checkpoint file.')
- tf.app.flags.DEFINE_string(
- 'test_list', '', 'Test image list.')
- tf.app.flags.DEFINE_string(
- 'test_dir', '.', 'Test image directory.')
- tf.app.flags.DEFINE_integer(
- 'batch_size', 16, 'Batch size.')
- tf.app.flags.DEFINE_integer(
- 'num_classes', 3, 'Number of classes.')
- tf.app.flags.DEFINE_integer(
- 'labels_offset', 0,
- 'An offset for the labels in the dataset. This flag is primarily used to '
- 'evaluate the VGG and ResNet architectures which do not use a background '
- 'class for the ImageNet dataset.')
- tf.app.flags.DEFINE_string(
- 'model_name', 'inception_resnet_v2', 'The name of the architecture to evaluate.')
- tf.app.flags.DEFINE_string(
- 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
- 'as `None`, then the model_name flag is used.')
- tf.app.flags.DEFINE_integer(
- 'test_image_size', None, 'Eval image size')
- FLAGS = tf.app.flags.FLAGS
- def main(_):
- if not FLAGS.test_list:
- raise ValueError('You must supply the test list with --test_list')
- tf.logging.set_verbosity(tf.logging.INFO)
- with tf.Graph().as_default():
- tf_global_step = slim.get_or_create_global_step()
- ####################
- # Select the model #
- ####################
- network_fn = nets_factory.get_network_fn(
- FLAGS.model_name,
- num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
- is_training=False)
- #####################################
- # Select the preprocessing function #
- #####################################
- preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
- image_preprocessing_fn = preprocessing_factory.get_preprocessing(
- preprocessing_name,
- is_training=False)
- test_image_size = FLAGS.test_image_size or network_fn.default_image_size
- if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
- checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
- else:
- checkpoint_path = FLAGS.checkpoint_path
- batch_size = FLAGS.batch_size
- tensor_input = tf.placeholder(tf.float32, [None, test_image_size, test_image_size, 3])
- logits, _ = network_fn(tensor_input)
- logits = tf.nn.top_k(logits, 5)
- config = tf.ConfigProto()
- config.gpu_options.allow_growth = True
- test_ids = [line.strip() for line in open(FLAGS.test_list)]
- tot = len(test_ids)
- results = list()
- with tf.Session(config=config) as sess:
- sess.run(tf.global_variables_initializer())
- saver = tf.train.Saver()
- saver.restore(sess, checkpoint_path)
- time_start = time.time()
- for idx in range(0, tot, batch_size):
- images = list()
- idx_end = min(tot, idx + batch_size)
- print(idx)
- for i in range(idx, idx_end):
- image_id = test_ids[i]
- test_path = os.path.join(FLAGS.test_dir, image_id)
- image = open(test_path, 'rb').read()
- image = tf.image.decode_jpeg(image, channels=3)
- processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
- processed_image = sess.run(processed_image)
- images.append(processed_image)
- images = np.array(images)
- predictions = sess.run(logits, feed_dict = {tensor_input : images}).indices
- for i in range(idx, idx_end):
- print('{} {}'.format(image_id, predictions[i - idx].tolist())
- time_total = time.time() - time_start
- print('total time: {}, total images: {}, average time: {}'.format(
- time_total, len(test_ids), time_total / len(test_ids)))
- if __name__ == '__main__':
- tf.app.run()
测试时执行以下命令即可:
- CUDA_VISIBLE_DEVICES="" python test_image_classifier.py \
- --checkpoint_path=../../../train_logs/ \
- --test_list=../../../list_val.txt \
- --test_dir=../../../val \
- --batch_size= \
- --num_classes= \
- --model_name=inception_resnet_v2
【参考】
https://github.com/tensorflow/models/tree/master/research/slim
【总结】
1.使用Tensorflow预训练模型(TF-slim)可以快速地测试已有成熟模型在不同数据集上的效果,且有利于数据代码架构和核心代码
2.若要自己实现已有成熟的网络模型,包括数据读取、训练、验证、测试,调优,则可用Tensorflow底层API或基于TF的高级API(TensorLayer TFLearn TF-slim)实现。从目前了解情况看,TensorLayer更好用,在DeepLearning中迁移学习更是如此
3.TensorFlow是一个非常庞大的架构,新手学习成本太高且难以摸清方向,故此借助下载已有models,调整参数和少量代码有助于进一步认识TensorFlow的实现原理,若想要更深入研究,手写CNN,看源码等都是不可避免的
4.当多看TensorFlow源码,迁移学习的一大利器!
【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控的更多相关文章
- 【Tensorflow】 Object_detection之训练PASCAL VOC数据集
参考:Running Locally 1.检查数据.config文件是否配置好 可参考之前博客: Tensorflow Object_detection之配置Training Pipeline Ten ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- 在C#下使用TensorFlow.NET训练自己的数据集
在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)
基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html 摘要 在前面的博文中,我详细介绍了Caf ...
- 【2】TensorFlow光速入门-数据预处理(得到数据集)
本文地址:https://www.cnblogs.com/tujia/p/13862351.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- Dive into TensorFlow系列(1)-静态图运行原理
接触过TensorFlow v1的朋友都知道,训练一个TF模型有三个步骤:定义输入和模型结构,创建tf.Session实例sess,执行sess.run()启动训练.不管是因为历史遗留代码或是团队保守 ...
随机推荐
- TFTP通信原理
TFTP的通信流程 TFTP共定义了五种类型的包格式,格式的区分由包数据前两个字节的Opcode字段区分,分别是: · l 读文件请求包:Read request,简写为RRQ,对应Opcode字段值 ...
- Python设计TFTP客户端
#coding=utf-8 from socket import * from threading import Thread import struct def recvData(fileName, ...
- 在thinkphp框架中使用后台传值过来的数组,在hightcart中使用数组
js的数组是和php里面数组是不一样的,所以模板文件需要先接受,然后利用Js代码转化之后再使用,接受后台的数组有几种办法 1.后台传过来的json数组,利用Js是可以接受的,然后将json数据利用js ...
- python pickle 模块的使用
用于序列化的两个模块 json:用于字符串和Python数据类型间进行转换 pickle: 用于python特有的类型和python的数据类型间进行转换 json提供四个功能:dumps,dump,l ...
- WPF 自定义ComboBox样式
一.ComboBox基本样式 ComboBox有两种状态,可编辑和不可编辑状态.通过设置IsEditable属性可以切换控件状态. 先看基本样式效果: 基本样式代码如下: <!--ComboBo ...
- java实现两个int数交换
普通方法,进阶方法,大神方法 @Test public void test3(){ int m = 5; int n = 12; //要求m和n交换位置 System.out.println(&quo ...
- O(logN)中logN的底数
转载:http://blog.csdn.net/jdbc/article/details/42173751 问题: 无论是计算机算法概论.还是数据结构书中, 关于算法的时间复杂度很多都用包含O(log ...
- python爬虫requests json与字典对象互相转换
import requests import json ''' json.loads(json_str) json字符串转换成字典 json.dumps(dict) 字典转换成json字符串 ''' ...
- 网络配置及shell基础
一:集群已做完 二:临时配置网络(ip,网关,dns)+永久配置 临时配置网络: ip: [root@localhost ~]# ifconfig [root@localhost ~]# ifc ...
- WPF绘制光滑连续贝塞尔曲线
1.需求 WPF本身没有直接把点集合绘制成曲线的函数.可以通过贝塞尔曲线函数来绘制. 贝塞尔曲线类是:BezierSegment,三次贝塞尔曲线,通过两个控制点来控制开始和结束方向. Quadrati ...