【写在前面】

用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inception_resnet_v2等模型。本文验证Inception_resnet_v2基于菜场实拍数据的准确性,测试数据为芹菜、鸡毛菜、青菜,各类别样本约600张,多个菜场拍摄,不同数据源。

补充:自己当初的计划是用别人预训练好的模型来再训练自己的数据集已使可以完成新的分类任务,但必须要修改代码改网络结构,并使用迁移学习(Fine-tune)

本文记录了其间的工作过程 ,  相信也会有一些帮助的  : )

测试环境:Centos7.3-64位  python3.5.4(Anaconda)

目录

一.准备

1.安装python
2.安装tensorflow
3.下载TF-slim图像库
4.准备数据
5.下载模型

二.训练

1.读入数据
2.构建模型
3.开始训练
4.执行脚本,训练自己的数据
5.可视化log
【问题】 tensorboard版本已更新,找不到对应包

三.验证

四.测试

一.准备

1.安装python

推荐Anaconda,可创建虚拟环境,用conda命令易实现虚拟环境管理、包管理,安装包时会查出所有依赖包并一共一键安装, 链接:https://www.anaconda.com/download/

2.安装tensorflow

进入当下Anaconda的运行环境,我安装的是python2.7版,并创建3.5虚拟环境

conda create -n py35 python=3.5    【py35是虚拟环境的名称; 输入y 安装】
source activate py35  【激活py35环境】
conda install tensorflow  【安装tensorflow-cpu版,有GPU可安装cpu版】

3.下载TF-slim代码库

cd  $WORKSPACE   【目录跳转到自己的工作目录下】
git clone https://github.com/tensorflow/models/

4.准备数据

对所有训练样本按不同样本类别存在不同文件夹下

zsy_train
|---jimaocai
  |--- .jpg
  |--- ...
|---qc
|---qingcai

下面的代码是为了生成list.txt , 把不同文件夹下的图片和 数字label对应起来

 import os
class_names_to_ids = {'jimaocai': 0, 'qc': 1, 'qingcai': 2}
data_dir = 'flower_photos/'
output_path = 'list.txt'
fd = open(output_path, 'w')
for class_name in class_names_to_ids.keys():
images_list = os.listdir(data_dir + class_name)
for image_name in images_list:
fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))
fd.close()

为了方便后期查看label标签,也可定义labels.txt

jimaocai
qc
qingcai

随机生成训练集和验证集(在总量中随机选取350个样本作为验证集)

 import random
_NUM_VALIDATION = 350
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'
fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)
fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
fd.write(line)
fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
fd.write(line)
fd.close()

生成TFRecord数据

import sys
# sys.path.insert(0, '../models/slim/') models-master research
sys.path.insert(0, './models/research/slim/') #把后面的路径插入到系统路径中 idx=0
from datasets import dataset_utils
import math
import os
import tensorflow as tf # 根据list路径 把数据转化为TFRecord
# def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=3):
fd = open(list_path)
lines = [line.split() for line in fd]
fd.close()
num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
with tf.Graph().as_default():
decode_jpeg_data = tf.placeholder(dtype=tf.string)
decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
with tf.Session('') as sess:
for shard_id in range(_NUM_SHARDS):
output_path = os.path.join(output_dir,
# 'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
'data_{:03}-of-{:03}.tfrecord'.format(shard_id, _NUM_SHARDS))
tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
start_ndx = shard_id * num_per_shard
end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
for i in range(start_ndx, end_ndx):
sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
i + 1, len(lines), shard_id))
sys.stdout.flush()
image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
height, width = image.shape[0], image.shape[1]
example = dataset_utils.image_to_tfexample(
image_data, b'jpg', height, width, int(lines[i][1]))
tfrecord_writer.write(example.SerializeToString())
tfrecord_writer.close()
sys.stdout.write('\n')
sys.stdout.flush() os.system('mkdir -p train')
convert_dataset('list_train.txt', 'zsy_train', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'zsy_train', 'val/')

得到的文件夹结构如下

WORKSPACE
├── zsy_train
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│ ├── data_000-of-.tfrecord
│ ├── ...
│ └── data_002-of-.tfrecord
└── val
├── data_000-of-.tfrecord
├── ...
└── data_002-of-.tfrecord

5.下载模型

官方提供了预训练,这里以Inception-ResNet-v2以例

cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz

二.训练

1.读入数据

读入自己的数据,需要把下面代码写入models/slim/datasets/dataset_classification.py

import os
import tensorflow as tf
slim = tf.contrib.slim def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
file_pattern = os.path.join(dataset_dir, file_pattern)
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
items_to_descriptions = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and ' + str(num_classes - 1),
}
labels_to_names = None
if labels_to_names_path is not None:
fd = open(labels_to_names_path)
labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
fd.close()
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=num_samples,
items_to_descriptions=items_to_descriptions,
num_classes=num_classes,
labels_to_names=labels_to_names)

2.构建模型

构建模型取决于个人欲构建什么样的模型,官方都有对应模型的下载链接,只需把对应下载(下载链接:https://github.com/tensorflow/models/tree/master/research/slim)好的模型解压放入到checkpoints中即可

3.开始训练

由于是用已有模型训练自己的数据集,故需对原工程代码做适当调整。

from datasets import dataset_factory

改为:

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

改为:

dataset = dataset_classification.get_dataset(
FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入:

tf.app.flags.DEFINE_integer(
'num_samples', 1781, 'Number of samples.')
tf.app.flags.DEFINE_integer(
'num_classes', 3, 'Number of classes.')
tf.app.flags.DEFINE_string(
'labels_to_names_path', None, 'Label names file path.')

4.执行脚本,训练自己的数据

cd $WORKSPACE/models/slim    #跳转到工作环境目录
python train_image_classifier.py \ #运行脚本,后面跟的系统参数
--train_dir=/root/workspace_mrt/model_lab/train_logs \ #train_log目录,当模型训练时,可用tensorboard命令指定该目录,动态监测
--dataset_dir=../../../train \ #训练数据集 里面是转换好的TFRecord格式
--num_samples=1781 \       #训练样本数,即值train_set中的总样本数,不包括valid中随机抽取350个样本
--num_classes=3 \            #样本类别数
--labels_to_names_path=../../../labels.txt \
--model_name=inception_resnet_v2 \
--checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \   #指定模型位置
--checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
--trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
--clone_on_cpu=True        #cpu训练必须加上该参数 #fine-tune要把 --checkpoint_path,--checkpoint_exclude_scopes,--trainable_scopes 加上

5.可视化log

为了可视化训练时的loss或其他指标,可用tensorboard,如下命令

tensorboard --logdir=${TRAIN_DIR}
在本教程中,对应执行下面命令
tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

【问题】 tensorboard版本已更新,找不到对应包

当执行

tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

时,得到如下错误

ImportError: No module named 'tensorflow.tensorboard.tensorboard'

究其原因,是因为在tensorflow更新时,包的位置和所属关系改变了。执行以下代码,可解决该问题。

cd /root/anaconda2/envs/py35/bin    #跳转到对应python环境的bin目录下,修改tensorboard执行脚本代码,使之适应当前版本
vim tensorboard

import tensorflow.tensorboard.tensorboard

修改为:

import tensorboard.main

sys.exit(tensorflow.tensorboard.tensorboard.main())

修改为: sys.exit(tensorboard.main.main())

wq保存,退出,重新执行

tensorboard --logdir=/root/workspace_mrt/model_lab/train_logs

命令,无报错。根据日志提示,进入ip:6006进入tensorboard界面。

三.验证

使用自己的数据集,需修改models/slim/eval_image_classifier.py

from datasets import dataset_factory

改为:

from datasets import dataset_classification

dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

改为:

dataset = dataset_classification.get_dataset(
FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
'num_samples', 350, 'Number of samples.')
tf.app.flags.DEFINE_integer(
'num_classes', 3, 'Number of classes.')
tf.app.flags.DEFINE_string(
'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \
--checkpoint_path=../../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
--eval_dir=/root/workspace_mrt/model_lab/eval_logs \
--dataset_dir=../../../val \
--num_samples=350 \
--num_classes=3 \
--model_name=inception_resnet_v2

可以一边训练一边验证,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir ../../../eval_logs/ --port 6007

四.测试

参考models/slim/eval_image_classifier.py,可编写批量读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import math
import time
import numpy as np
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim tf.app.flags.DEFINE_string(
'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
'checkpoint_path', None,
'The directory where the model was written to or an absolute path to a '
'checkpoint file.')
tf.app.flags.DEFINE_string(
'test_list', '', 'Test image list.')
tf.app.flags.DEFINE_string(
'test_dir', '.', 'Test image directory.')
tf.app.flags.DEFINE_integer(
'batch_size', 16, 'Batch size.')
tf.app.flags.DEFINE_integer(
'num_classes', 3, 'Number of classes.')
tf.app.flags.DEFINE_integer(
'labels_offset', 0,
'An offset for the labels in the dataset. This flag is primarily used to '
'evaluate the VGG and ResNet architectures which do not use a background '
'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
'model_name', 'inception_resnet_v2', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
'test_image_size', None, 'Eval image size')
FLAGS = tf.app.flags.FLAGS
def main(_):
if not FLAGS.test_list:
raise ValueError('You must supply the test list with --test_list')
tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default():
tf_global_step = slim.get_or_create_global_step()
####################
# Select the model #
####################
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
is_training=False)
#####################################
# Select the preprocessing function #
#####################################
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
preprocessing_name,
is_training=False)
test_image_size = FLAGS.test_image_size or network_fn.default_image_size
if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
else:
checkpoint_path = FLAGS.checkpoint_path
batch_size = FLAGS.batch_size
tensor_input = tf.placeholder(tf.float32, [None, test_image_size, test_image_size, 3])
logits, _ = network_fn(tensor_input)
logits = tf.nn.top_k(logits, 5)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
test_ids = [line.strip() for line in open(FLAGS.test_list)]
tot = len(test_ids)
results = list()
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
time_start = time.time()
for idx in range(0, tot, batch_size):
images = list()
idx_end = min(tot, idx + batch_size)
print(idx)
for i in range(idx, idx_end):
image_id = test_ids[i]
test_path = os.path.join(FLAGS.test_dir, image_id)
image = open(test_path, 'rb').read()
image = tf.image.decode_jpeg(image, channels=3)
processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
processed_image = sess.run(processed_image)
images.append(processed_image)
images = np.array(images)
predictions = sess.run(logits, feed_dict = {tensor_input : images}).indices
for i in range(idx, idx_end):
print('{} {}'.format(image_id, predictions[i - idx].tolist())
time_total = time.time() - time_start
print('total time: {}, total images: {}, average time: {}'.format(
time_total, len(test_ids), time_total / len(test_ids)))
if __name__ == '__main__':
tf.app.run()

测试时执行以下命令即可:

CUDA_VISIBLE_DEVICES="" python test_image_classifier.py \
--checkpoint_path=../../../train_logs/ \
--test_list=../../../list_val.txt \
--test_dir=../../../val \
--batch_size= \
--num_classes= \
--model_name=inception_resnet_v2

【参考】

https://lijiancheng0614.github.io/2017/06/29/2017_06_29_TensorFlow-Slim-image-classification/#%E9%AA%8C%E8%AF%81

https://github.com/tensorflow/models/tree/master/research/slim

【总结】

1.使用Tensorflow预训练模型(TF-slim)可以快速地测试已有成熟模型在不同数据集上的效果,且有利于数据代码架构和核心代码

2.若要自己实现已有成熟的网络模型,包括数据读取、训练、验证、测试,调优,则可用Tensorflow底层API或基于TF的高级API(TensorLayer  TFLearn  TF-slim)实现。从目前了解情况看,TensorLayer更好用,在DeepLearning中迁移学习更是如此

3.TensorFlow是一个非常庞大的架构,新手学习成本太高且难以摸清方向,故此借助下载已有models,调整参数和少量代码有助于进一步认识TensorFlow的实现原理,若想要更深入研究,手写CNN,看源码等都是不可避免的

4.当多看TensorFlow源码,迁移学习的一大利器!

【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控的更多相关文章

  1. 【Tensorflow】 Object_detection之训练PASCAL VOC数据集

    参考:Running Locally 1.检查数据.config文件是否配置好 可参考之前博客: Tensorflow Object_detection之配置Training Pipeline Ten ...

  2. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  3. 在C#下使用TensorFlow.NET训练自己的数据集

    在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...

  4. 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)

    如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...

  5. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  6. 【2】TensorFlow光速入门-数据预处理(得到数据集)

    本文地址:https://www.cnblogs.com/tujia/p/13862351.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  7. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  8. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  9. Dive into TensorFlow系列(1)-静态图运行原理

    接触过TensorFlow v1的朋友都知道,训练一个TF模型有三个步骤:定义输入和模型结构,创建tf.Session实例sess,执行sess.run()启动训练.不管是因为历史遗留代码或是团队保守 ...

随机推荐

  1. 从PRISM开始学WPF(六)MVVM(二)Command?

    从PRISM开始学WPF(一)WPF? 从PRISM开始学WPF(二)Prism? 从PRISM开始学WPF(三)Prism-Region? 从PRISM开始学WPF(四)Prism-Module? ...

  2. java 1.7新特性

    try( ... ){ ... } catch(xxx e){ ... } java1.7特性,叫做try-with-resource,实现了AutoCloseable接口的实例可以放在try(... ...

  3. Mongodb中 Documents文档说明

    mongodb使用BSON格式存储数据记录. 如下图: 文档结构 文档有键值对组成, 有以下结构: {    field1: value1,    field2: value2,    ...     ...

  4. python的命名空间

    Python的命名空间是Python程序猿必须了解的内容,对Python命名空间的学习,将使我们在本质上掌握一些Python中的琐碎的规则. 接下来我将分四部分揭示Python命名空间的本质:一.命名 ...

  5. Oracle10g物理DG详细配置方法及步骤

    --测试环境:    OS:Redhat linux(64)    Primary:    IP:192.168.94.198    SID:dgdb1    Hostname:dg1    DB_U ...

  6. 新概念英语(1-131)Don't be so sure

    Lesson 131 Don't be so sure! 别那么肯定! Listen to the tape then answer this question. What's the problem ...

  7. HTTP协议扫盲(四)HTTP协议进阶 - MIME类型

    一.概念和原理 1.什么是MIME类型? MIME类型,即多用途互联网邮件扩展,它是一个互联网标准,在1992年最早应用于电子邮件系统,但后来也应用到浏览器. 服务器会将它们发送的多媒体数据的类型告诉 ...

  8. R数据分析 第一篇:温习概率论

    概率论是人们在长期实践中发现的理论,是客观存在的.自然界和社会上发生的现象是多种多样的,有一类现象,在一定条件下必然发生,称作确定性现象,而概率论研究的现象是不确定性现象,嗯嗯,醒醒,概率论研究的对象 ...

  9. centos单机安装zookeeper+kafaka

    环境如下: CentOS-7-x86_64zookeeper-3.4.11kafka_2.12-1.1.0 一.zookeeper下载与安装1)下载zookeeper [root@localhost ...

  10. hadoop fs:du统计hdfs文件(目录下文件)大小的用法

    hadoop fs 更多用法,请参考官网:http://hadoop.apache.org/docs/r1.0.4/cn/hdfs_shell.html 以下是我的使用统计文件时使用的记录: [t@d ...