Tensorflow中使用TFRecords高效读取数据--结合Attention-over-Attention Neural Network for Reading Comprehension
原文链接:https://arxiv.org/pdf/1607.04423.pdf 本片论文主要讲了Attention Model在完形填空类的阅读理解上的应用。
转载:https://blog.csdn.net/liuchonge/article/details/73649251
在进行论文仿真的时候用到了TFRecords进行数据的读取操作,所以进行深入学习。这两天看了一下相关博客,结合该代码记录一下TFRecords的相关操作。
首先说一下为什么要使用TFRecords来进行文件的读写,在TF中数据的传入方式主要包含以下几种:
- 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
- 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
- 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。
之前都是使用1和3进行数据的操作,但是当我们遇到数据集比较大的情况时,这两种方法会及其占用内存,效率很差。那么为甚么使用TFRecords会比较快呢?因为其使用二进制存储文件,也就是将数据存储在一个内存块中,相比其它文件格式要快很多,特别是如果你使用hdd而不是ssd,因为它涉及移动磁盘阅读器头并且需要相当长的时间。总体而言,通过使用二进制文件,您可以更轻松地分发数据,使数据更好地对齐,以实现高效的读取。接下来我们看一下具体的操作。
这里可以参见官网给的建议:
Another approach is to convert whatever data you have into a supported format. This approach makes it easier to mix and match data sets and network architectures. The recommended format for TensorFlow is a TFRecords file containing tf.train.Example protocol buffers (which contain Features as a field). You write a little program that gets your data, stuffs it in an Example protocol buffer, serializes the protocol buffer to a string, and then writes the string to a TFRecords file using the tf.python_io.TFRecordWriter. For example, tensorflow/examples/how_tos/reading_data/convert_to_records.py converts MNIST data to this format. To read a file of TFRecords, use tf.TFRecordReader with the tf.parse_single_example decoder. The parse_single_example op decodes the example protocol buffers into tensors. An MNIST example using the data produced by convert_to_records can be found in tensorflow/examples/how_tos/reading_data/fully_connected_reader.py, which you can compare with the fully_connected_feed version.
个人感觉可以分成两部分,一是使用tf.train.Example协议流将文件保存成TFRecords格式的.tfrecords文件,这里主要涉及到使用tf.python_io.TFRecordWriter("train.tfrecords")
和tf.train.Example
以及tf.train.Features
三个函数,第一个是生成需要对应格式的文件,后面两个函数主要是将我们要传入的数据按照一定的格式进行规范化。这里还要提到一点就是使用TFRecords可以避免多个文件的使用,比如说我们一般会将一次要传入的数据的不同部分分别存放在不同文件夹中,question一个,answer一个,query一个等等,但是使用TFRecords之后,我们可以将一批数据同时保存在一个文件之中,这样方便我们在后续程序中的使用。
另一部分就是在训练模型时将我们生成的.tfrecords文件读入并传到模型中进行使用。这部分主要涉及到使用tf.TFRecordReader("train.tfrecords")
和tf.parse_single_example
两个函数。第一个函数是将我们的二进制文件读入,第二个则是进行解析然后得到我们想要的数据。
接下来我们结合代码进行理解:
生成TFRecords文件
这里关于要使用的数据集的介绍可以参考我的下一篇,主要是QA任务的数据集。代码如下所示:
def tokenize(index, word):
#index是每个单词对应词袋子之中的索引值,word是所有出现的单词
directories = ['cnn/questions/training/', 'cnn/questions/validation/', 'cnn/questions/test/']
for directory in directories:
#分别读取训练测试验证集的数据
out_name = directory.split('/')[-2] + '.tfrecords'
#生成对应.tfrecords文件
writer = tf.python_io.TFRecordWriter(out_name)
#每个文件夹下面都有若干文件,每个文件代表一个QA队,也就是一条训练数据
files = map(lambda file_name: directory + file_name, os.listdir(directory))
for file_name in files:
with open(file_name, 'r') as f:
lines = f.readlines()
#对每条数据分别获得文档,问题,答案三个值,并将相应单词转化为索引
document = [index[token] for token in lines[2].split()]
query = [index[token] for token in lines[4].split()]
answer = [index[token] for token in lines[6].split()]
#调用Example和Features函数将数据格式化保存起来。注意Features传入的参数应该是一个字典,方便后续读数据时的操作
example = tf.train.Example(
features = tf.train.Features(
feature = {
'document': tf.train.Feature(
int64_list=tf.train.Int64List(value=document)),
'query': tf.train.Feature(
int64_list=tf.train.Int64List(value=query)),
'answer': tf.train.Feature(
int64_list=tf.train.Int64List(value=answer))
}))
#写数据
serialized = example.SerializeToString()
writer.write(serialized)
读取.tfrecords文件
因为在读取数据之后我们可能还会进行一些额外的操作,使我们的数据格式满足模型输入,所以这里会引入一些额外的函数来实现我们的目的。这里介绍几个个人感觉较重要常用的函数。不过还是推荐到官网API去查,或者有某种需求的时候到Stack Overflow上面搜一搜,一般都能找到满足自己需求的函数。
1,string_input_producer(
其输出是一个输入管道的队列,这里需要注意的参数是num_epochs和shuffle。对于每个epoch其会将所有的文件添加到文件队列当中,如果设置shuffle,则会对文件顺序进行打乱。其对文件进行均匀采样,而不会导致上下采样。
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
2,shuffle_batch(
产生随机打乱之后的batch数据
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
3,sparse_ops.serialize_sparse(sp_input, name=None)
: 返回一个字符串的3-vector(1-D的tensor),分别表示索引、值、shape
4,deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None)
: 将多个稀疏的serialized_sparse合并成一个
def read_records(index=0):
#生成读取数据的队列,要指定epoches
train_queue = tf.train.string_input_producer(['training.tfrecords'], num_epochs=FLAGS.epochs)
validation_queue = tf.train.string_input_producer(['validation.tfrecords'], num_epochs=FLAGS.epochs)
test_queue = tf.train.string_input_producer(['test.tfrecords'], num_epochs=FLAGS.epochs) queue = tf.QueueBase.from_list(index, [train_queue, validation_queue, test_queue])
#定义一个recordreader对象,用于数据的读取
reader = tf.TFRecordReader()
#从之前的队列中读取数据到serialized_example
_, serialized_example = reader.read(queue)
#调用parse_single_example函数解析数据
features = tf.parse_single_example(
serialized_example,
features={
'document': tf.VarLenFeature(tf.int64),
'query': tf.VarLenFeature(tf.int64),
'answer': tf.FixedLenFeature([], tf.int64)
}) #返回索引、值、shape的三元组信息
document = sparse_ops.serialize_sparse(features['document'])
query = sparse_ops.serialize_sparse(features['query'])
answer = features['answer'] #生成batch切分数据
document_batch_serialized, query_batch_serialized, answer_batch = tf.train.shuffle_batch(
[document, query, answer], batch_size=FLAGS.batch_size,
capacity=2000,
min_after_dequeue=1000) sparse_document_batch = sparse_ops.deserialize_many_sparse(document_batch_serialized, dtype=tf.int64)
sparse_query_batch = sparse_ops.deserialize_many_sparse(query_batch_serialized, dtype=tf.int64) document_batch = tf.sparse_tensor_to_dense(sparse_document_batch)
document_weights = tf.sparse_to_dense(sparse_document_batch.indices, sparse_document_batch.shape, 1) query_batch = tf.sparse_tensor_to_dense(sparse_query_batch)
query_weights = tf.sparse_to_dense(sparse_query_batch.indices, sparse_query_batch.shape, 1) return document_batch, document_weights, query_batch, query_weights, answer_batch
最后,我们要在模型开始训练之前,执行下面两行代码:
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
这是填充队列的指令,如果不执行程序会等在队列文件的读取处无法运行。至此,我们就可以使用TFRecords来读写文件了。最后总结一下,大概格式如下,这里并未指定某种读写函数,而是可以自定义的方式用的伪代码来说一下整个流程:
def read_my_file_format(filename_queue):
reader = tf.SomeReader()
key, record_string = reader.read(filename_queue)
example, label = tf.some_decoder(record_string)
processed_example = some_processing(example)
return processed_example, label def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
Tensorflow中使用TFRecords高效读取数据--结合Attention-over-Attention Neural Network for Reading Comprehension的更多相关文章
- Tensorflow中使用tfrecord方式读取数据-深度学习-周振洋
本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用 ...
- Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取
内容概要: 单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中 ...
- Tensorflow高效读取数据
关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow ...
- Tensorflow高效读取数据的方法
最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码 ...
- TensorFlow高效读取数据的方法——TFRecord的学习
关于TensorFlow读取数据,官网给出了三种方法: 供给数据(Feeding):在TensorFlow程序运行的每一步,让python代码来供给数据. 从文件读取数据:在TensorFlow图的起 ...
- "笨方法"学习CNN图像识别(二)—— tfrecord格式高效读取数据
原文地址:https://finthon.com/learn-cnn-two-tfrecord-read-data/-- 全文阅读5分钟 -- 在本文中,你将学习到以下内容: 将图片数据制作成tfre ...
- MySQL数据库中tinyint类型字段读取数据为true和false
今天遇到这么一个问题,公司最近在做一个活动,然后数据库需要建表,其中有个字段是关于奖励发放的状态的字段,结果读取出来的值为true 一.解决读取数据为true/false的问题 场景: 字段:stat ...
- R中利用SQL语言读取数据框(sqldf库的使用)
熟悉MySQL的朋友可以使用sqldf来操作数据框 # 引入sqldf库(sqldf) library(sqldf) # 释放RMySQL库的加载(针对sqldf报错) #detach("p ...
- MySQL数据库中tinyint类型字段读取数据为true和false (MySQL的boolean和tinyint(1))
数据库一个表中有一个tinyint类型的字段,值为0或者1,如果取出来的话,0会变成false,1会变成true. MySQL保存boolean值时用1代表TRUE,0代表FALSE.boolean在 ...
随机推荐
- mybatis 批量插入 返回主键id
我们都知道Mybatis在插入单条数据的时候有两种方式返回自增主键: 1.对于支持生成自增主键的数据库:增加 useGenerateKeys和keyProperty ,<insert>标签 ...
- SpringMVC Ajax两种传参方式
1.采用@RequestParam或Request对象获取参数的方法 注:contentType必须指定为:application/x-www-form-urlencoded @ResponseBod ...
- 包装类 integer 当做 list的参数时候 会出现无法删除成功的现象
- AtCoder Grand Contest 019 A: Ice Tea Store
tourist出的题诶!想想就很高明,老年选手可能做不太动.不过A题还是按照惯例放水的. AtCoder Grand Contest 019 A: Ice Tea Store 题意:买0.25L,0. ...
- 【bzoj5173】[Jsoi2014]矩形并 扫描线+二维树状数组区间修改区间查询
题目描述 JYY有N个平面坐标系中的矩形.每一个矩形的底边都平行于X轴,侧边平行于Y轴.第i个矩形的左下角坐标为(Xi,Yi),底边长为Ai,侧边长为Bi.现在JYY打算从这N个矩形中,随机选出两个不 ...
- 【JavaScript&jQuery】购物车自动结算
<!doctype html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- Debugging QML Applications
Debugging QML Applications Console API Log console.log, console.debug, console.info, console.warn an ...
- Java (Socket,ServerSocket)与(SocketChannel,ServerSocketChannel)区别和联系
Socket 和ServerSocke 是一对 他们是java.net下面实现socket通信的类SocketChannel 和ServerSocketChannel是一对 他们是java.nio下面 ...
- Java日期格式转换
Java时间格式转换大全 import java.text.*;import java.util.Calendar;public class VeDate {/** * 获取现在时间 * ...
- 【BZOJ4300】绝世好题(动态规划)
[BZOJ4300]绝世好题(动态规划) 题面 BZOJ Description 给定一个长度为n的数列ai,求ai的子序列bi的最长长度,满足bi&bi-1!=0(2<=i<=l ...