tensorflow之数据读取探究(2)
tensorflow之tfrecord数据读取
Tensorflow关于TFRecord格式文件的处理、模型的训练的架构为:
1、获取文件列表、创建文件队列:http://blog.csdn.net/lovelyaiq/article/details/78711944(tfrecord格式,保存,读取)
2、图像预处理:http://blog.csdn.net/lovelyaiq/article/details/78716325
3、合成Batch:http://blog.csdn.net/lovelyaiq/article/details/78727189
4、设计损失函数、梯度下降算法:http://blog.csdn.net/lovelyaiq/article/details/78616736
- 首先了解tfrecord的格式;TensorFlow提供了TFRecord的格式统一管理存储数据
# tf.train.Example
message Example{
Features features = ;
}
message Features{
map<string,Features> feature = ;
}
message Feature {
oneof kind {
BytesList bytes_list = ;
FloateList float_list = ;
Int64List int64_list = ;
}
}
从定义中可以看出tf.train.Example是以字典的形式存储数据格式,string为字典的key值,字典的属性值有三种类型:bytes、float、int64。接下来通过例子说明如果通过TFRecord保存和读取文件。保存和读取用到函数分别为:tf.python_io.TFRecordWriter和tf.TFRecordReader()。
然后将原数据转换为tfrecord;(参考tensorflow/models/deeplab api)
def _int64_list_feature(values):
"""Returns a TF-Feature of int64_list. Args:
values: A scalar or list of values. Returns:
A TF-Feature.
"""
if not isinstance(values, collections.Iterable):
values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def _bytes_list_feature(values):
"""Returns a TF-Feature of bytes. Args:
values: A string. Returns:
A TF-Feature.
"""
def norm2bytes(value):
return value.encode() if isinstance(value, str) and six.PY3 else value return tf.train.Feature(
bytes_list=tf.train.BytesList(value=[norm2bytes(values)])) def image_seg_to_tfexample(image_data, filename, height, width, seg_data):
"""Converts one image/segmentation pair to tf example. Args:
image_data: string of image data.
filename: image filename.
height: image height.
width: image width.
seg_data: string of semantic segmentation data. Returns:
tf example of one image/segmentation pair.
"""
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': _bytes_list_feature(image_data),
'image/filename': _bytes_list_feature(filename),
'image/format': _bytes_list_feature(
_IMAGE_FORMAT_MAP[FLAGS.image_format]),
'image/height': _int64_list_feature(height),
'image/width': _int64_list_feature(width),
'image/channels': _int64_list_feature(),
'image/segmentation/class/encoded': (
_bytes_list_feature(seg_data)),
'image/segmentation/class/format': _bytes_list_feature(
FLAGS.label_format),
})) def _convert_dataset(dataset_split):
"""Converts the specified dataset split to TFRecord format. Args:
dataset_split: The dataset split (e.g., train, val). Raises:
RuntimeError: If loaded image and label have different shape, or if the
image file with specified postfix could not be found.
"""
image_files = _get_files('image', dataset_split) //得到文件列表
label_files = _get_files('label', dataset_split) num_images = len(image_files)
num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS))) image_reader = build_data.ImageReader('png', channels=)
label_reader = build_data.ImageReader('png', channels=) for shard_id in range(_NUM_SHARDS): //保存_NUM_SHARDS个tfrecord文件
shard_filename = '%s-%05d-of-%05d.tfrecord' % (
dataset_split, shard_id, _NUM_SHARDS)
output_filename = os.path.join(FLAGS.output_dir, shard_filename)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
start_idx = shard_id * num_per_shard
end_idx = min((shard_id + ) * num_per_shard, num_images)
for i in range(start_idx, end_idx): //逐个文件进行读取转换;
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
i + , num_images, shard_id))
sys.stdout.flush()
# Read the image.
image_data = tf.gfile.FastGFile(image_files[i], 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_data = tf.gfile.FastGFile(label_files[i], 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
if height != seg_height or width != seg_width:
raise RuntimeError('Shape mismatched between image and label.')
# Convert to tf example.
re_match = _IMAGE_FILENAME_RE.search(image_files[i])
if re_match is None:
raise RuntimeError('Invalid image filename: ' + image_files[i])
filename = os.path.basename(re_match.group())
example = build_data.image_seg_to_tfexample(
image_data, filename, height, width, seg_data)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush() def main(unused_argv):
# Only support converting 'train' and 'val' sets for now.
for dataset_split in ['train', 'val']:
_convert_dataset(dataset_split)
然后在train时读取;(分三种,一种原始读取,一种tf.data.TFRecordDataset,一种用slim实现)分别参考:
- http://blog.csdn.net/lovelyaiq/article/details/78711944(tfrecord格式,保存,读取)
- tensorflow入门:tfrecord 和tf.data.TFRecordDataset
- https://github.com/NanqingD/DeepLabV3-Tensorflow/blob/master/libs/datasets/dataset_factory.py
- 参考deeplab实现:
import tensorflow as tf
slim = tf.contrib.slim
dataset = slim.dataset
tfexample_decoder = slim.tfexample_decoder def get_dataset(dataset_name, split_name, dataset_dir):
"""Gets an instance of slim Dataset. Args:
dataset_name: Dataset name.
split_name: A train/val Split name.
dataset_dir: The directory of the dataset sources. Returns:
An instance of slim Dataset. Raises:
ValueError: if the dataset_name or split_name is not recognized.
"""
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.') splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes if split_name not in splits_to_sizes:
raise ValueError('data split name %s not recognized' % split_name) # Prepare the variables for different datasets.
num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Specify how the TF-Examples are decoded.
keys_to_features = {
'image/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/format': tf.FixedLenFeature(
(), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/segmentation/class/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/segmentation/class/format': tf.FixedLenFeature(
(), tf.string, default_value='png'),
}
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3),
'image_name': tfexample_decoder.Tensor('image/filename'),
'height': tfexample_decoder.Tensor('image/height'),
'width': tfexample_decoder.Tensor('image/width'),
'labels_class': tfexample_decoder.Image(
image_key='image/segmentation/class/encoded',
format_key='image/segmentation/class/format',
channels=1),
} decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=splits_to_sizes[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
ignore_label=ignore_label,
num_classes=num_classes,
name=dataset_name,
multi_label=True)
- 再经过:
......
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].') label.set_shape([None, None, 1])
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset.ignore_label,
is_training=is_training,
model_variant=model_variant)
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label if not is_training:
# Original image is only used during visualization.
sample[common.ORIGINAL_IMAGE] = original_image,
num_threads = 1 return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
- 主要讨论tensorflow的tfrecord读取方法;及slim读取数据;
def read_data(is_training, split_name):
file_pattern = '{}_{}.tfrecord'.format(args.data_name, split_name)
tfrecord_path = os.path.join(args.data_dir,'records',file_pattern) if is_training:
dataset = get_dataset(tfrecord_path) //通过slim方式读取tfrecord;
image, gt_mask = extract_batch(dataset, args.batch_size, is_training)
else:
image, gt_mask = read_tfrecord(tfrecord_path) //通过原始方式读取tfrecord;
image, gt_mask = preprocess.preprocess_image(image, gt_mask, is_training)
return image, gt_mask
1. 数据处理流程
对于输入数据的处理,大体上流程都差不多,可以归结如下:
将数据转为 TFRecord 格式的多个文件
用 tf.train.match_filenames_once() 创建文件列表
用 tf.train.string_input_producer() 创建输入文件队列,可以将输入文件顺序随机打乱
用 tf.TFRecordReader() 读取文件中的数据
用 tf.parse_single_example() 解析数据
对数据进行解码及预处理
用 tf.train.shuffle_batch() 将数据组合成 batch
将 batch 用于训练
2. 输入数据处理框架
框架主要是三方面的内容:
TFRecord 输入数据格式
图像数据处理
多线程输入数据处理
3. reference:
tensorflow之数据读取探究(2)的更多相关文章
- tensorflow之数据读取探究(1)
Tensorflow中之前主要用的数据读取方式主要有: 建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用.使用这种方法十分灵活,可以一下子将所有数据 ...
- Tensorflow学习-数据读取
Tensorflow数据读取方式主要包括以下三种 Preloaded data:预加载数据 Feeding: 通过Python代码读取或者产生数据,然后给后端 Reading from file: 通 ...
- 『TensorFlow』数据读取类_data.Dataset
一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...
- 关于Tensorflow 的数据读取环节
Tensorflow读取数据的一般方式有下面3种: preloaded直接创建变量:在tensorflow定义图的过程中,创建常量或变量来存储数据 feed:在运行程序时,通过feed_dict传入数 ...
- 机器学习: TensorFlow 的数据读取与TFRecords 格式
最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式: 1) 传送 (feeding): Pytho ...
- TensorFlow的数据读取机制
一.tensorflow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取的过程可以用下图来表示 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003 ...
- TensorFlow中数据读取之tfrecords
关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...
- 由浅入深之Tensorflow(3)----数据读取之TFRecords
转载自http://blog.csdn.net/u012759136/article/details/52232266 原文作者github地址 概述 关于Tensorflow读取数据,官网给出了三种 ...
- TensorFlow中数据读取—如何载入样本
考虑到要是自己去做一个项目,那么第一步是如何把数据导入到代码中,何种形式呢?是否需要做预处理?官网中给的实例mnist,数据导入都是写好的模块,那么自己的数据呢? 一.从文件中读取数据(CSV文件.二 ...
随机推荐
- 并发之线程封闭与ThreadLocal解析
并发之线程封闭与ThreadLocal解析 什么是线程封闭 实现一个好的并发并非易事,最好的并发代码就是尽量避免并发.而避免并发的最好办法就是线程封闭,那什么是线程封闭呢? 线程封闭(thread c ...
- 兼容IE8以下,获取className节点的元素(document.getElementsByClassName()兼容写法)。
因为ie8一下不兼容 document.getElementsByClassName() 功能:通过class的名字获取符合条件的元素 ...
- 2017-2018-2 20155225《网络对抗技术》实验九 Web安全基础
2017-2018-2 20155225<网络对抗技术>实验九 Web安全基础 WebGoat 1.String SQL Injection 题目是想办法得到数据库所有人的信用卡号,用Sm ...
- P3331 [ZJOI2011]礼物(GIFT)
题解: 首先转化为平面问题 对于每一个z,f(x,y)的值为它能向上延伸的最大高度 ...莫名其妙想出来的是n^4 以每个点作为右下边界n^3枚举再o(n)枚举左下边界计算z的最大值 然而很显然这种做 ...
- python常用内建模块--collections
1.namedtuple #namedtuple是一个函数,它用来创建一个自定义的tuple对象,并且规定了tuple元素的个数,并可以用属性而不是索引来引用tuple的某个元素.#这样一来,我们用n ...
- k8s 环境搭建
转自:https://blog.csdn.net/running_free/article/details/78388948 一.概述 1.简介 官方中文文档:https://www.kubernet ...
- shell scripts 之 代码量统计
代码统计1 文件only中的内容为多个文件的文件名,code如下: xargs说明:xargs 读入stdin的值, 并默认以空白或者回车作为分隔符,将分割的值作为参数传给后面紧接着的的命令行操作.- ...
- Python3 图像边界识别
# -*- coding: utf-8 -*- """ Created on Wed Mar 7 11:04:15 2018 @author: markli " ...
- 《大型分布式网站架构》学习笔记--01SOA
"学无长幼,达者为先",作者陈康贤通过3年左右时间就能写出如此著作确实令人钦佩,加油,熊二,早日成为一个合格的后端程序员. 基础概念 SOA(Service-Oriented Ar ...
- [译]javascript中的依赖注入
前言 在上文介绍过控制反转之后,本来打算写篇文章介绍下控制反转的常见模式-依赖注入.在翻看资料的时候,发现了一篇好文Dependency injection in JavaScript,就不自己折腾了 ...