In Section 32, Using the Google Object Detection API for Object Detection and Training a New Model (on the VOC 2012 Dataset), we covered how to run object detection with Google's Object Detection API and how to train one of the provided detection models on your own data. Training on your own dataset involves the following steps:

  • Prepare your dataset, annotating it in the format expected by the conversion scripts, then run the matching script under object_detection\dataset_tools to generate the tfrecord files. For example, to generate tfrecord files with create_pascal_tf_record.py, the dataset must be annotated the same way as VOC 2012; reading create_pascal_tf_record.py itself is a good way to understand the expected annotation format.

  • Download the detection model you plan to fine-tune from; training from scratch is prohibitively time-consuming.
  • Pick the config file under object_detection/samples/configs that matches the chosen model and edit it as needed.
  • Train with object_detection/train.py.
  • Export the trained model with the export_inference_graph.py script and run detection with it (a hedged sketch of the whole command sequence follows this list).
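
For orientation, here is a hedged sketch of the full command sequence for the VOC 2012 case; the paths and the checkpoint step number are placeholders, and the flags follow each script's documented usage:

# 1) Generate tfrecord files from a VOC-style dataset.
python object_detection/dataset_tools/create_pascal_tf_record.py \
    --data_dir=VOCdevkit --year=VOC2012 --set=train \
    --label_map_path=voc/pascal_label_map.pbtxt \
    --output_path=voc/pascal_train.record

# 2) Train from the edited sample config.
python object_detection/train.py \
    --logtostderr \
    --train_dir=voc/train_dir \
    --pipeline_config_path=voc/faster_rcnn_inception_resnet_v2_atrous_voc.config

# 3) Export the trained graph for inference.
python object_detection/export_inference_graph.py \
    --input_type=image_tensor \
    --pipeline_config_path=voc/faster_rcnn_inception_resnet_v2_atrous_voc.config \
    --trained_checkpoint_prefix=voc/train_dir/model.ckpt-200000 \
    --output_directory=voc/exported_model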

Here I will mainly walk through how train.py works.

I. Analyzing train.py

First, the full source:

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Training executable for detection models.

This executable is used to train DetectionModels. There are two ways of
configuring the training job:

1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used and
a train_pb2.TrainConfig file to configure training parameters.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --model_config_path=model_config.pbtxt \
        --train_config_path=train_config.pbtxt \
        --input_config_path=train_input_config.pbtxt
"""

import functools
import json
import os
import tensorflow as tf

from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU.  Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


if __name__ == '__main__':
  tf.app.run()

1. The script first defines tf.app.flags so it can take command-line arguments; this is essentially a wrapper around parsing argv.

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU.  Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS

A few of these flags matter most. train_dir is the directory where checkpoints and training summaries are saved. pipeline_config_path is the full path to a pipeline_pb2.TrainEvalPipelineConfig config file; if you do not pass it, you must instead pass train_config_path, input_config_path, and model_config_path, which are simply the pipeline_pb2.TrainEvalPipelineConfig file split into three parts.
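
As a quick illustration of the mechanism (a toy sketch, not part of train.py), tf.app.run() parses sys.argv into FLAGS and then calls main():

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('train_dir', '', 'Directory for checkpoints.')
FLAGS = flags.FLAGS

def main(_):
  # Running `python demo.py --train_dir voc/train_dir` prints voc/train_dir.
  print(FLAGS.train_dir)

if __name__ == '__main__':
  tf.app.run()  # parses the command line, then calls main()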

2. Now let's look at the main function, which we will read in several parts.

Suppose we launch training from the console with:

python train.py --train_dir voc/train_dir/ --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config

3. Part one

  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

Because we passed the train_dir and pipeline_config_path arguments, on startup the program will:

  • Read the pipeline_config_path config file and return a dict holding its `model`, `train_config`, `train_input_config`, `eval_config`, and `eval_input_config` sections (see the sketch after this list).
  • Copy the pipeline_config_path config file into the train_dir directory under the name pipeline.config.
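
As a sketch of the first step (reusing the config path from the command above), config_util parses the pipeline file into a plain dict, one parsed proto per section:

from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file(
    'voc/faster_rcnn_inception_resnet_v2_atrous_voc.config')
# configs has the keys 'model', 'train_config', 'train_input_config',
# 'eval_config' and 'eval_input_config'.
print(configs['train_config'].batch_size)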

4. Part two

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)
  • The variables model_config, train_config, and input_config are initialized from the configs dict.
  • functools.partial fixes the model_config and is_training arguments of model_builder.build and returns a new function, model_fn. This function matters a lot: it builds the actual detection model, and we look at it in detail below.
  • get_next is likewise bound to the fixed argument input_config; it implements the tfrecord reading, also covered below. (A small functools.partial demo follows this list.)
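
functools.partial simply freezes some arguments of a function and returns a new callable; a minimal demo of the pattern used for model_fn and get_next:

import functools

def build(model_config, is_training):
  return (model_config, is_training)

# Freeze both keyword arguments: calling model_fn() is now equivalent to
# calling build(model_config='cfg', is_training=True).
model_fn = functools.partial(build, model_config='cfg', is_training=True)
print(model_fn())  # ('cfg', True)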

5. Part three

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target
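
This part decides between single-machine and distributed training by parsing the TF_CONFIG environment variable. In our case TF_CONFIG is not set, so the single-worker defaults above are kept and master stays ''. As a hedged sketch, here is a hypothetical TF_CONFIG for a one-master, one-worker, one-ps cluster and what the parsing yields (the host names are made up):

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['host0:2222'],
        'worker': ['host1:2222'],
        'ps': ['host2:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})

env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_data = env.get('cluster', None)
task_data = env.get('task', None) or {'type': 'master', 'index': 0}
# type() builds a tiny class whose attributes come from task_data, so the
# dict entries become readable as task_info.type / task_info.index.
task_info = type('TaskSpec', (object,), task_data)

worker_replicas = len(cluster_data['worker']) + 1  # "worker"s plus the "master" -> 2
ps_tasks = len(cluster_data['ps'])                 # -> 1
print(worker_replicas, ps_tasks, task_info.type, task_info.index)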

6. Part four

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)
  • Since our pipeline config defines no graph_rewriter_config (a hedged example of such a section follows), execution goes straight to trainer.train, which reads the data, preprocesses it, and starts training.
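
For reference, a graph_rewriter section in the pipeline config looks roughly like the following (a hedged example; we do not use one here). It inserts FakeQuant ops into the graph for quantization-aware training:

graph_rewriter {
  quantization {
    delay: 48000
    weight_bits: 8
    activation_bits: 8
  }
}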

II. The dataset_builder.build function

Here is the code:

def build(input_reader_config, transform_input_data_fn=None,
          batch_size=None, max_num_boxes=None, num_classes=None,
          spatial_image_shape=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
  records. Applies a padded batch to the resulting dataset.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.
    transform_input_data_fn: Function to apply to all records, or None if
      no extra decoding is required.
    batch_size: Batch size. If None, batching is not performed.
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding. If None, will use a dynamic shape.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding. If None, will use a dynamic shape.
    spatial_image_shape: A list of two integers of the form [height, width]
      containing expected spatial shape of the image after applying
      transform_input_data_fn. If None, will use dynamic shapes.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    # Initialize the fields to decode and the handler for each field.
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file)

    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # Read from config.input_path with tf.data.TFRecordDataset, decode each
    # record with process_fn, and prefetch input_reader_config.prefetch_size
    # records.
    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)
    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')

The overall flow:

  • Get the path of the training tfrecord file and of the label_map_path file; our input_reader_config is set up as follows:
train_input_reader: {
  tf_record_input_reader {
    input_path: "voc/pascal_train.record"
  }
  label_map_path: "voc/pascal_label_map.pbtxt"
}
  • Initialize the fields to decode and the handler for each field.
  • Read records from config.input_path with tf.data.TFRecordDataset, decode them with process_fn (which defines the decoding format), and prefetch input_reader_config.prefetch_size records.
  • Apply tf.contrib.data.padded_batch_and_drop_remainder to the dataset, dropping any leftover records that do not fill a whole batch_size.
  • Return the dataset; get_next in train.py then wraps it in an initializable iterator. (A usage sketch follows this list.)
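
A minimal sketch of driving dataset_builder.build directly, assuming the tfrecord and label map paths from the config above:

from google.protobuf import text_format
from object_detection.builders import dataset_builder
from object_detection.protos import input_reader_pb2

input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge("""
  tf_record_input_reader { input_path: 'voc/pascal_train.record' }
  label_map_path: 'voc/pascal_label_map.pbtxt'
""", input_reader_proto)

# Yields a tf.data.Dataset of decoded tensor dicts (image, groundtruth_boxes,
# groundtruth_classes, ...); train.py wraps it in an initializable iterator.
dataset = dataset_builder.build(input_reader_proto)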

III. The model_builder.build function

Here is the code:

def build(model_config, is_training, add_summaries=True,
          add_background_class=True):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.
    add_background_class: Whether to add an implicit background class to one-hot
      encodings of groundtruth labels. Set to false if using groundtruth labels
      with an explicit background class or using multiclass scores instead of
      truth in the case of distillation. Ignored in the case of faster_rcnn.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training, add_summaries,
                            add_background_class)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
                                    add_summaries)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))

The function first reads which meta-architecture the config selects. Since we are using faster_rcnn_inception_resnet_v2, it calls _build_faster_rcnn_model, passing in faster_rcnn, is_training, and add_summaries. The faster_rcnn section of our config looks like this:

model {
  faster_rcnn {
    num_classes:
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension:
        max_dimension:
      }
    }
    feature_extractor {
      type: 'faster_rcnn_inception_resnet_v2'
      first_stage_features_stride:
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride:
        width_stride:
      }
    }
    first_stage_atrous_rate:
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals:
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size:
    maxpool_kernel_size:
    maxpool_stride:
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class:
        max_total_detections:
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

Now let's look at the source of _build_faster_rcnn_model:

def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds R-FCN model if the second_stage_box_predictor in the config is of type
  `rfcn_box_predictor` else builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    FasterRCNNMetaArch based on the config.

  Raises:
    ValueError: If frcnn_config.type is not recognized (i.e. not registered in
      model_class_map).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  feature_extractor = _build_faster_rcnn_feature_extractor(
      frcnn_config.feature_extractor, is_training,
      frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
      frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  first_stage_positive_balance_fraction = (
      frcnn_config.first_stage_positive_balance_fraction)
  first_stage_nms_score_threshold = frcnn_config.first_stage_nms_score_threshold
  first_stage_nms_iou_threshold = frcnn_config.first_stage_nms_iou_threshold
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_box_predictor = box_predictor_builder.build(
      hyperparams_builder.build,
      frcnn_config.second_stage_box_predictor,
      is_training=is_training,
      num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_balance_fraction = frcnn_config.second_stage_balance_fraction
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
          first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_positive_balance_fraction':
          first_stage_positive_balance_fraction,
      'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
      'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_balance_fraction': second_stage_balance_fraction,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss':
          second_stage_classification_loss,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries}

  if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  else:
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)
  • The first half of the code loads the Faster R-CNN parameters from the config: num_classes, the parameters of the first-stage RPN, and the parameters of the second-stage Fast R-CNN network.
  • Because the config specifies the faster_rcnn_inception_resnet_v2 architecture, the inception_resnet_v2 network (under the object_detection\models folder) is used for feature extraction. Extraction happens in two parts: the first half of inception_resnet_v2 produces the feature map fed to the RPN, and the network then propagates forward through its remaining convolutional layers to produce higher-level feature maps for the Fast R-CNN stage. For details, see Section 31, Object Detection Algorithms: Faster R-CNN Explained.
  • The second half of the code builds the detection model and its loss functions. (A usage sketch follows this list.)
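
A minimal sketch (reusing the pipeline config path from earlier) of what model_fn() ultimately produces:

from object_detection.builders import model_builder
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file(
    'voc/faster_rcnn_inception_resnet_v2_atrous_voc.config')
# For our config this returns a FasterRCNNMetaArch wrapping the
# inception_resnet_v2 feature extractor, the RPN and the second stage.
detection_model = model_builder.build(configs['model'], is_training=True)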

IV. The trainer.train function

Here is the code:

def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph is
      built (before optimization). This is helpful to perform additional changes
      to the training graph such as adding FakeQuant ops. The function should
      modify the default graph.
  """

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
  • Before batching, the preprocessing (data augmentation) steps are specified through data_augmentation_options.

  • Batched reading creates two queues (a simplified sketch follows this list):
    Queue 1: N threads each read one record at a time from the dataset and enqueue it into queue 1; one thread then dequeues batch_size records at a time from queue 1.
    Queue 2: the batches dequeued from queue 1 are enqueued into queue 2; when dequeue is called, a full batch_size batch is read from queue 2.
  • The loss function, optimizer, and Saver object are defined.
  • After a batch is read, it goes through the model's own preprocessing, detection_model.preprocess, and is then fed to the model to start training.
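
A simplified, self-contained sketch of the two-queue pattern (toy tensors and assumed capacities; not the actual create_input_queue code):

import tensorflow as tf

# Stand-in for one decoded, augmented example from the dataset.
example = tf.random_uniform([300, 300, 3])

# Queue 1: num_threads threads each enqueue single examples; tf.train.batch
# dequeues batch_size of them at a time.
images = tf.train.batch([example], batch_size=8, num_threads=4, capacity=32)

# Queue 2: keeps a few ready-made batches prefetched ahead of the train op.
batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue([images], capacity=4)

with tf.train.MonitoredTrainingSession() as sess:
  batch = sess.run(batch_queue.dequeue())
  print(batch.shape)  # (8, 300, 300, 3)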

V. Summary

Walking through a full training run shows how heavily Google's source code is wrapped and layered, which makes it hard to adapt. That makes implementing an object detection model yourself all the more worthwhile, and later I will implement a simple object detection algorithm.

