Reading Data from Files

When training a model in TensorFlow, of the three reading approaches given on the official website, the best way to read files is queue-based reading, which takes two steps:

  1. Write the sample data into a TFRecords binary file
  2. Read from the queue

A TFRecords binary file makes better use of memory, is easier to move and copy, and does not require separate label files.

Below is the official code for converting the MNIST data. For the full source, see: tensorflow-master\tensorflow\examples\how_tos\reading_data\convert_to_records.py


Generating the TFRecords file

Define the main function, which converts the training, validation, and test datasets:

def main(unused_argv):
    # Get the data.
    data_sets = mnist.read_data_sets(FLAGS.directory,
                                     dtype=tf.uint8,
                                     reshape=False,
                                     validation_size=FLAGS.validation_size)

    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test')

The convert_to function fills the data into a protocol buffer, serializes it into a string, and writes it to the TFRecords file.


def convert_to(data_set, name):
    """Converts a dataset to tfrecords."""
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples
    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                         (images.shape[0], num_examples))
    rows = images.shape[1]   # 28
    cols = images.shape[2]   # 28
    depth = images.shape[3]  # 1; the images are grayscale, so a single channel

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image_raw = images[index].tostring()
        # Fill the protocol buffer: height, width, depth, and label are
        # encoded as int64; image_raw is encoded as bytes.
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())  # serialize to a string
    writer.close()

The encoding functions are as follows:

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
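To make the encoding concrete, here is a minimal round-trip sketch (not part of the original script) that builds one Example with these helpers, serializes it, and parses it back in pure Python:

import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Build a tiny Example, serialize it to a string, then parse it back.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': _int64_feature(7),
    'image_raw': _bytes_feature(b'\x00\x01\x02'),
}))
serialized = example.SerializeToString()

parsed = tf.train.Example.FromString(serialized)
print(parsed.features.feature['label'].int64_list.value[0])      # 7
print(parsed.features.feature['image_raw'].bytes_list.value[0])  # b'\x00\x01\x02'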

Complete code:

import argparse
import os
import sys

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

FLAGS = None


# Encoding functions.
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(data_set, name):
    """Converts a dataset to tfrecords."""
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples
    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                         (images.shape[0], num_examples))
    rows = images.shape[1]   # 28
    cols = images.shape[2]   # 28
    depth = images.shape[3]  # 1; grayscale images have a single channel

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image_raw = images[index].tostring()
        # Fill the protocol buffer: height, width, depth, and label are
        # encoded as int64; image_raw is encoded as bytes.
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())  # serialize to a string
    writer.close()


def main(unused_argv):
    # Get the data.
    data_sets = mnist.read_data_sets(FLAGS.directory,
                                     dtype=tf.uint8,
                                     reshape=False,
                                     validation_size=FLAGS.validation_size)

    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='MNIST_data/',
        help='Directory to download data files and write the converted result'
    )
    parser.add_argument(
        '--validation_size',
        type=int,
        default=5000,
        help="""\
        Number of examples to separate from the training data for the validation
        set.\
        """
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

When the script finishes, three files are generated under the directory given by --directory (MNIST_data/ by default): train.tfrecords, validation.tfrecords, and test.tfrecords.
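To verify the output, here is a minimal sketch (using the same TF 1.x API as the rest of this article) that iterates over one of the generated files and counts its records:

import tensorflow as tf

# Count the records and inspect the first Example in train.tfrecords.
count = 0
for serialized in tf.python_io.tf_record_iterator('MNIST_data/train.tfrecords'):
    if count == 0:
        example = tf.train.Example.FromString(serialized)
        print('first label:', example.features.feature['label'].int64_list.value[0])
    count += 1
print('records:', count)  # 55000 with the default validation_size of 5000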

Reading from the Queue

Steps for reading TFRecords files

Reading data from TFRecords files with a queue takes three steps:

  1. Create a tensor that reads a single example from the binary file
  2. Create a tensor that reads a random mini-batch from the binary file
  3. Feed each batch tensor into the network as its input

The pattern TensorFlow uses to train on TFRecords files (a minimal sketch follows below):

Set the number of epochs when generating the filename queue.

For training, this can be set to loop indefinitely.

While reading data, terminate once an out-of-range error is caught.
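In outline, that pattern looks like the sketch below (TF 1.x queue API; read_and_decode stands for the per-record parsing function defined in the full program that follows):

import tensorflow as tf

# Epochs are set when the filename queue is created; pass num_epochs=None
# to loop indefinitely.
filename_queue = tf.train.string_input_producer(['train.tfrecords'],
                                                num_epochs=2)
image, label = read_and_decode(filename_queue)

with tf.Session() as sess:
    # string_input_producer keeps its epoch counter in a local variable.
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():     # train "forever" ...
            sess.run([image, label])
    except tf.errors.OutOfRangeError:      # ... until the epochs are exhausted
        pass
    finally:
        coord.request_stop()
        coord.join(threads)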

Source code: tensorflow-master\tensorflow\examples\how_tos\reading_data\fully_connected_reader.py (TensorFlow 1.2.1)

https://blog.csdn.net/fontthrone/article/details/76728083


import argparse
import os
import os.path
import sys
import time

import tensorflow as tf

# Note: this mnist module is different from
# tensorflow.contrib.learn.python.learn.datasets.mnist used in the conversion
# script; this file must use the one below.
from tensorflow.examples.tutorials.mnist import mnist

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Basic model parameters as external flags.
FLAGS = None

# This part of the code was added by FontTian; it comes from the source code
# of tensorflow.examples.tutorials.mnist.
# The MNIST images are always 28x28 pixels.
# IMAGE_SIZE = 28
# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        # The feature keys must be written out explicitly.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Convert from a scalar string tensor (whose single string has
    # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
    # [mnist.IMAGE_PIXELS].
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([mnist.IMAGE_PIXELS])

    # OPTIONAL: Could reshape into a 28x28 image and apply distortions
    # here. Since we are not applying any distortions in this
    # example, and the next step expects the image to be flattened
    # into a vector, we don't bother.

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # Convert label from a scalar uint8 tensor to an int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label


# Use tf.train.shuffle_batch to shuffle the examples produced above and
# collect them into a mini-batch tensor.
def inputs(train, batch_size, num_epochs):
    """Reads input data num_epochs times.

    Args:
      train: Selects between the training (True) and validation (False) data.
      batch_size: Number of examples per returned batch.
      num_epochs: Number of times to read the input data, or 0/None to
        train forever.

    Returns:
      A tuple (images, labels), where:
      * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
        in the range [-0.5, 0.5].
      * labels is an int32 tensor with shape [batch_size] with the true label,
        a number in the range [0, mnist.NUM_CLASSES).

      Note that a tf.train.QueueRunner is added to the graph, which
      must be run using e.g. tf.train.start_queue_runners() to start
      the threads.
    """
    if not num_epochs:
        num_epochs = None
    filename = os.path.join(FLAGS.train_dir,
                            TRAIN_FILE if train else VALIDATION_FILE)

    with tf.name_scope('input'):
        # tf.train.string_input_producer returns a QueueRunner that owns a
        # FIFOQueue. If the dataset is large, it can be split into several
        # files and the list of filenames passed in here.
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)

        # Even when reading in multiple threads, share the filename queue.
        image, label = read_and_decode(filename_queue)

        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in two threads to avoid being a bottleneck.
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            # Ensures a minimum amount of shuffling of examples: keep part of
            # the queue filled so there is always enough data to shuffle.
            min_after_dequeue=1000)

        return images, sparse_labels


def run_training():
    """Train MNIST for a number of steps."""
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Input images and labels.
        images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                                num_epochs=FLAGS.num_epochs)

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)

        # Add the loss calculation to the Graph.
        loss = mnist.loss(logits, labels)

        # Add the training ops to the Graph.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # The op for initializing the variables. Note:
        # string_input_producer internally creates an epoch counter as a
        # local variable, so local_variables_initializer() is needed too.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Create a session for running operations in the Graph.
        sess = tf.Session()

        # Initialize the variables.
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            while not coord.should_stop():
                start_time = time.time()

                # Run one step of the model. The return values are
                # the activations from the `train_op` (which is
                # discarded) and the `loss` op. To inspect the values
                # of your ops or variables, you may include them in
                # the list passed to sess.run() and the value tensors
                # will be returned in the tuple from the call.
                _, loss_value = sess.run([train_op, loss])

                duration = time.time() - start_time

                # Print an overview fairly often.
                if step % 100 == 0:
                    print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                               duration))
                step += 1
        except tf.errors.OutOfRangeError:
            print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                              step))
        finally:
            # Ask the other threads to stop.
            coord.request_stop()

        # Wait for threads to finish.
        coord.join(threads)
        sess.close()


def main(_):
    run_training()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='Initial learning rate.'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=2,
        help='Number of epochs to run trainer.'
    )
    parser.add_argument(
        '--hidden1',
        type=int,
        default=128,
        help='Number of units in hidden layer 1.'
    )
    parser.add_argument(
        '--hidden2',
        type=int,
        default=32,
        help='Number of units in hidden layer 2.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=100,
        help='Batch size.'
    )
    parser.add_argument(
        '--train_dir',
        type=str,
        default='/tmp/data',
        help='Directory with the training data.'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
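As a quick check, the input pipeline can also be exercised on its own, without the model. A minimal sketch (appended to the program above, assuming FLAGS has been parsed and the .tfrecords files already exist under --train_dir) that fetches a single shuffled batch:

def check_pipeline():
    with tf.Graph().as_default():
        images, labels = inputs(train=True, batch_size=4, num_epochs=1)
        with tf.Session() as sess:
            sess.run(tf.group(tf.global_variables_initializer(),
                              tf.local_variables_initializer()))
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            image_batch, label_batch = sess.run([images, labels])
            print(image_batch.shape)  # (4, 784)
            print(label_batch)        # e.g. [3 7 1 9]
            coord.request_stop()
            coord.join(threads)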
