文章主要来自Tensorflow官方文档,同时加入了自己的理解以及部分代码

数据读取

TensorFlow程序读取数据一共有3种方法:

  • 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
  • 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
  • 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

目录

数据读取

供给数据

TensorFlow的数据供给机制允许你在TensorFlow运算图中将数据注入到任一张量中。因此,python运算可以把数据直接设置到TensorFlow图中。通过给run()或者eval()函数输入feed_dict参数, 可以启动运算过程。

with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
print classifier.eval(feed_dict={input: my_python_preprocessing_fn()})

虽然你可以使用常量和变量来替换任何一个张量, 但是最好的做法应该是使用placeholder op节点。设计placeholder节点的唯一的意图就是为了提供数据供给(feeding)的方法。placeholder节点被声明的时候是未初始化的, 也不包含数据, 如果没有为它供给数据, 则TensorFlow运算的时候会产生错误, 所以千万不要忘了为placeholder提供数据。可以在tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py找到使用placeholder和MNIST训练的例子,MNIST tutorial也讲述了这一例子。

从文件读取数据

一共典型的文件读取管线会包含下面这些步骤:

  • 文件名列表
  • 可配置的 文件名乱序(shuffling)
  • 可配置的 最大训练迭代数(epoch limit)
  • 文件名队列
  • 针对输入文件格式的阅读器
  • 纪录解析器
  • 可配置的 预处理器
  • 样本队列

文件名, 乱序(shuffling), 和最大训练迭代数(epoch limits)

可以使用字符串张量(比如["file0", "file1"], [("file%d" % i) for i in range(2)], [("file%d" % i) for i in range(2)]) 或者tf.train.match_filenames_once 函数来产生文件名列表。

将文件名列表交给tf.train.string_input_producer 函数.string_input_producer来生成一个先入先出的队列, 文件阅读器会需要它来读取数据。

tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。

这个QueueRunner的工作线程是独立于文件阅读器的线程, 因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。

文件格式

根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的read方法。阅读器的read方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

from __future__ import absolute_import, division, print_function
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES']='3' # Download Titanic dataset (in csv format).
filename_queue = tf.train.string_input_producer(["iris.csv"]) # Skip 1 line from the beginning of every file if needed
reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(filename_queue) # Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1.0], [1.0], [1.0], [1.0], ["1"]]
col1, col2, col3, col4, col5, col6 = tf.decode_csv(
value, record_defaults=record_defaults)
features = [col1, col2, col3, col4, col5] with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord) for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col6])
print(label)
coord.request_stop()
coord.join(threads)

每次read的执行都会从文件中读取一行内容, decode_csv 操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。

在调用run或者eval去执行read之前, 你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

固定长度的记录

从二进制文件中读取固定长度纪录, 可以使用tf.FixedLengthRecordReader的tf.decode_raw操作。decode_raw操作可以讲一个字符串转换为一个uint8的张量。

举例来说,the CIFAR-10 dataset的文件格式定义是:每条记录的长度都是固定的,一个字节的标签,后面是3072字节的图像数据。uint8的张量的标准操作就可以从中获取图像片并且根据需要进行重组。 例子代码可以在tensorflow/models/image/cifar10/cifar10_input.py找到,具体讲述可参见教程.

标准Tensorflow格式

另一种保存记录的方法可以允许你将任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。

从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。 MNIST的例子就使用了convert_to_records 所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 你也可以将这个例子跟fully_connected_feed的版本加以比较。接下来用下面的例子解释如何构建tfrecord文件并从tfrecord文件中读取数据

filename_queue = tf.train.string_input_producer(["/home/learning/tensorflow/iris.csv"])

# Create TFRecords
# Generate Integer Features.
def build_int64_feature(data):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[data])) # Generate Float Features.
def build_float_feature(data):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[data])) # Generate String Features.
def build_string_feature(data):
"""Returns a bytes_list from a string / byte."""
if isinstance(data, type(tf.constant(0))):
data = data.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data])) # Generate a TF `Example` parsing all features of the dataset
def convert_to_tfexample(no, sepal_length, sepal_width, petal_length, petal_width, species):
return tf.train.Example(
features=tf.train.Features(
feature={
'no': build_int64_feature(no),
'sepal_length': build_float_feature(sepal_length),
'sepal_width': build_float_feature(sepal_width),
'petal_length': build_float_feature(petal_length),
'petal_width': build_float_feature(petal_width),
'species': build_string_feature(species),
}
)
) with open('/home/learning/tensorflow/iris.csv') as f:
with tf.python_io.TFRecordWriter('/home/learning/tensorflow/iris.tfrecord') as w:
# Generate a TF Example for all row in our dataset.
# CSV reader will read and parse all rows.
reader = csv.reader(f, skipinitialspace=True)
for i, record in enumerate(reader): if i == 0:
continue
no, sepal_length, sepal_width, petal_length, petal_width, species = record species = species.encode('utf-8') # Parse each csv row to TF Example using the above functions.
example = convert_to_tfexample(int(no), float(sepal_length), float(sepal_width), float(petal_length),
float(petal_width), species)
# Serialize each TF Example to string, and write to TFRecord file
w.write(example.SerializeToString())
# Build features template, with types.
features = {
'no': tf.FixedLenFeature([], tf.int64),
'sepal_length': tf.FixedLenFeature([], tf.float32),
'sepal_width': tf.FixedLenFeature([], tf.float32),
'petal_length': tf.FixedLenFeature([], tf.float32),
'petal_width': tf.FixedLenFeature([], tf.float32),
'species': tf.FixedLenFeature([], tf.string),
} # Create TensorFlow session.
sess = tf.Session() # Load TFRecord data.
filenames = ["/home/zhangyiran/learning/tensorflow/iris.tfrecord"]
data = tf.data.TFRecordDataset(filenames) # Parse features, using the above template.
def parse_record(record):
return tf.parse_single_example(record, features=features)
# Apply the parsing to each record from the dataset.
data = data.map(parse_record) # Refill data indefinitely.
data = data.repeat()
# Shuffle data.
data = data.shuffle(buffer_size=1000)
# Batch data (aggregate records together).
data = data.batch(batch_size=4)
# Prefetch batch (pre-load batch for faster consumption).
data = data.prefetch(buffer_size=1) # Create an iterator over the dataset.
iterator = data.make_initializable_iterator()
# Initialize the iterator.
sess.run(iterator.initializer) # Get next data batch.
x = iterator.get_next() # Dequeue data and display.
for i in range(3):
print(sess.run(x))
print("")

输出结果

{'no': array([141,  41,  88,  24]), 'petal_width': array([2.4, 0.3, 1.3, 0.5], dtype=float32), 'sepal_width': array([3.1, 3.5, 2.3, 3.3], dtype=float32), 'sepal_length': array([6.7, 5. , 6.3, 5.1], dtype=float32), 'petal_length': array([5.6, 1.3, 4.4, 1.7], dtype=float32), 'species': array([b'virginica', b'setosa', b'versicolor', b'setosa'], dtype=object)}

{'no': array([84, 56, 64, 35]), 'petal_width': array([1.6, 1.3, 1.4, 0.2], dtype=float32), 'sepal_width': array([2.7, 2.8, 2.9, 3.1], dtype=float32), 'sepal_length': array([6. , 5.7, 6.1, 4.9], dtype=float32), 'petal_length': array([5.1, 4.5, 4.7, 1.5], dtype=float32), 'species': array([b'versicolor', b'versicolor', b'versicolor', b'setosa'],
dtype=object)} {'no': array([ 21, 144, 147, 119]), 'petal_width': array([0.2, 2.3, 1.9, 2.3], dtype=float32), 'sepal_width': array([3.4, 3.2, 2.5, 2.6], dtype=float32), 'sepal_length': array([5.4, 6.8, 6.3, 7.7], dtype=float32), 'petal_length': array([1.7, 5.9, 5. , 6.9], dtype=float32), 'species': array([b'setosa', b'virginica', b'virginica', b'virginica'], dtype=object)}

预处理

你可以对输入的样本进行任意的预处理, 这些预处理不依赖于训练参数, 你可以在tensorflow/models/image/cifar10/cifar10.py找到数据归一化, 提取随机数据片,增加噪声或失真等等预处理的例子。

批处理

在数据输入管线的末端, 我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用tf.train.shuffle_batch函数来对队列中的样本进行乱序处理. 该函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.

import tensorflow as tf
import numpy as np
import os def read_my_file_format(filename_queue):
reader = tf.TextLineReader()
key, value =reader.read(filename_queue) record_defaults = [[1.0],[1.0],[1.0],[1.0],["1"]]
col2, col3, col4, col5, col6 = tf.decode_csv(
value, record_defaults = record_defaults)
return [col2, col3, col4, col5], col6 def input_pipeline(filenames, batch_size, num_epochs = None):
filename_queue = tf.train.string_input_producer(
file_path, num_epochs=num_epochs, shuffle = True)
features, label = read_my_file_format(filename_queue)
min_after_dequeue = 5
capacity = min_after_dequeue+3*batch_size
features_batch, label_batch = tf.train.shuffle_batch(
[features, label], batch_size = batch_size, capacity = capacity,
min_after_dequeue = min_after_dequeue) return features_batch, label_batch file_path = ["/home/Documents/data/iris.data"]
features_batch, label_batch = input_pipeline(file_path, 10) with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for _ in range(5):
feat, lb = sess.run([features_batch, label_batch])
print(feat,lb)
coord.request_stop()
coord.join(threads)

输出结果(部分)

[[4.8 3.  1.4 0.1]
[4.8 3.4 1.9 0.2]
[4.8 3.4 1.6 0.2]
[5.1 3.3 1.7 0.5]
[5. 3. 1.6 0.2]
[5.2 3.4 1.4 0.2]
[5.2 3.5 1.5 0.2]
[5. 3.4 1.6 0.4]
[5.4 3.4 1.5 0.4]
[4.9 3.1 1.5 0.1]] [b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa']
[[5.1 3.5 1.4 0.3]
[4.8 3.1 1.6 0.2]
[5.5 4.2 1.4 0.2]
[4.7 3.2 1.6 0.2]
[4.9 3.1 1.5 0.1]
[5. 3.2 1.2 0.2]
[4.4 3. 1.3 0.2]
[5.5 3.5 1.3 0.2]
[4.4 3.2 1.3 0.2]
[4.5 2.3 1.3 0.3]] [b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa' b'Iris-setosa' b'Iris-setosa'
b'Iris-setosa' b'Iris-setosa']

如果你需要对不同文件中的样子有更强的乱序和并行处理,可以使用tf.train.shuffle_batch_join 函数. 示例:

def read_my_file_format(filename_queue):
# Same as above def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example_list = [read_my_file_format(filename_queue)
for _ in range(read_threads)]
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch_join(
example_list, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch

在这个例子中, 你虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,直到这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)

另一种替代方案是: 使用tf.train.shuffle_batch 函数,设置num_threads的值大于1, 使用多个线程在tensor_list中读取文件.这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:

  • 避免了两个不同的线程从同一个文件中读取同一个样本。
  • 避免了过多的磁盘搜索操作。

    你一共需要多少个读取线程呢? 函数tf.train.shuffle_batch*为TensorFlow图提供了获取文件名队列中的元素个数之和的方法. 如果你有足够多的读取线程, 文件名队列中的元素个数之和应该一直是一个略高于0的数。具体可以参考TensorBoard:可视化学习.

创建线程并使用QueueRunner对象来预取

在我们的代码中tf.train.string_input_producer()生成了文件名队列, 在TensorFlow中,队列不仅仅是一种数据结构,还是异步计算张量取值的一个重要机制。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。TF提供了tf.Coordinatortf.QueueRunner两个类来完成多线程协同的功能.从设计上这两个类必须被一起使用. Coordinator类是线程协调器, 用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:

  • should_stop():如果线程应该停止则返回True。
  • request_stop():请求该线程停止。
  • join():等待被指定的线程终止。

QueueRunner是队列管理器,主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过上面介绍的tf.Coordinator类来统一管理. QueueRunner会协调多个工作线程同时将多个张量推入同一个队列中.

在Python的训练程序中,创建一个QueueRunner来运行几个线程, 这几个线程处理样本,并且将样本推入队列. 创建一个Coordinator,让queue runner使用Coordinator来启动这些线程,创建一个训练的循环, 并且使用Coordinator来控制QueueRunner的线程们的终止, 如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。推荐的代码模板如下:

# Create the graph, etc.
init_op = tf.initialize_all_variables() # Create a session for running operations in the Graph.
sess = tf.Session() # Initialize the variables (like the epoch counter).
sess.run(init_op) # Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord) try:
while not coord.should_stop():
# Run training steps or whatever
sess.run(train_op) except tf.errors.OutOfRangeError:
print 'Done training -- epoch limit reached'
finally:
# When done, ask the threads to stop.
coord.request_stop() # Wait for threads to finish.
coord.join(threads)
sess.close()

疑问:这是怎么回事

首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队

在tf.train中要创建这些队列和执行入队操作,就要添加tf.train.QueueRunner到一个使用tf.train.add_queue_runner函数的数据流图中。每个QueueRunner负责一个阶段,处理那些需要在线程中运行的入队操作的列表。一旦数据流图构造成功,tf.train.start_queue_runners函数就会要求数据流图中每个QueueRunner去开始它的线程运行入队操作.

如果一切顺利的话,你现在可以执行你的训练步骤,同时队列也会被后台线程来填充。如果你设置了最大训练迭代数,在某些时候,样本出队的操作可能会得到一个tf.OutOfRangeError的错误。这其实是TensorFlow的“文件结束”(EOF) ———— 这就意味着已经达到了最大训练迭代数,已经没有更多可用的样本了.

最后一个因素是Coordinator。tf.train.Coordinator()创建进线程协调器. 这是负责在收到任何关闭信号的时候,让所有的线程都知道。最常用的是在发生异常时这种情况就会呈现出来,比如说其中一个线程在运行某些操作时出现错误(或一个普通的Python异常).

想要了解更多的关于threading, queues, QueueRunners, and Coordinators的内容可以看这里.

疑问: 在达到最大训练迭代数的时候如何清理关闭线程?

想象一下,你有一个模型并且设置了最大训练迭代数。这意味着,生成文件的那个线程将只会在产生OutOfRange错误之前运行许多次。该QueueRunner会捕获该错误,并且关闭文件名的队列,最后退出线程。关闭队列做了两件事情:

  • 如果还试着对文件名队列执行入队操作时将发生错误。任何线程不应该尝试去这样做,但是当队列因为其他错误而关闭时,这就会有用了。
  • 任何当前或将来出队操作要么成功(如果队列中还有足够的元素)或立即失败(发生OutOfRange错误)。它们不会防止等待更多的元素被添加到队列中,因为上面的一点已经保证了这种情况不会发生。

关键是,当在文件名队列被关闭时候,有可能还有许多文件名在该队列中,这样下一阶段的流水线(包括reader和其它预处理)还可以继续运行一段时间。 一旦文件名队列空了之后,如果后面的流水线还要尝试从文件名队列中取出一个文件名(例如,从一个已经处理完文件的reader中),这将会触发OutOfRange错误。在这种情况下,即使你可能有一个QueueRunner关联着多个线程。如果这不是在QueueRunner中的最后那个线程,OutOfRange错误仅仅只会使得一个线程退出。这使得其他那些正处理自己的最后一个文件的线程继续运行,直至他们完成为止。 (但如果假设你使用的是tf.train.Coordinator,其他类型的错误将导致所有线程停止)。一旦所有的reader线程触发OutOfRange错误,然后才是下一个队列,再是样本队列被关闭。

同样,样本队列中会有一些已经入队的元素,所以样本训练将一直持续直到样本队列中再没有样本为止。如果样本队列是一个RandomShuffleQueue,因为你使用了shuffle_batch 或者 shuffle_batch_join,所以通常不会出现以往那种队列中的元素会比min_after_dequeue 定义的更少的情况。 然而,一旦该队列被关闭,min_after_dequeue设置的限定值将失效,最终队列将为空。在这一点来说,当实际训练线程尝试从样本队列中取出数据时,将会触发OutOfRange错误,然后训练线程会退出。一旦所有的培训线程完成,tf.train.Coordinator.join会返回,你就可以正常退出了。

筛选记录或产生每个记录的多个样本

举个例子,有形式为[x, y, z]的样本,我们可以生成一批形式为[batch, x, y, z]的样本。 如果你想滤除这个记录(或许不需要这样的设置),那么可以设置batch的大小为0;但如果你需要每个记录产生多个样本,那么batch的值可以大于1。 然后很简单,只需调用批处理函数(比如: shuffle_batch or shuffle_batch_join)去设置enqueue_many=True就可以实现。enqueue_many主要是设置tensor中的数据是否能重复,如果想要实现同一个样本多次出现可以将其设置为:“True”,如果只想要其出现一次,也就是保持数据的唯一性,这时候我们将其设置为默认值:“False”

稀疏输入数据

SparseTensors这种数据类型使用队列来处理不是太好。如果要使用SparseTensors你就必须在批处理之后使用tf.parse_example 去解析字符串记录 (而不是在批处理之前使用 tf.parse_single_example) 。

预取数据

这仅用于可以完全加载到存储器中的小的数据集。有两种方法:

  • 存储在常数中。
  • 存储在变量中,初始化后,永远不要改变它的值。

    使用常数更简单一些,但是会使用更多的内存(因为常数会内联的存储在数据流图数据结构中,这个结构体可能会被复制几次)。
training_data = ...
training_labels = ...
with tf.Session():
input_data = tf.constant(training_data)
input_labels = tf.constant(training_labels)
...

要改为使用变量的方式,你就需要在数据流图建立后初始化这个变量。

training_data = ...
training_labels = ...
with tf.Session() as sess:
data_initializer = tf.placeholder(dtype=training_data.dtype,
shape=training_data.shape)
label_initializer = tf.placeholder(dtype=training_labels.dtype,
shape=training_labels.shape)
input_data = tf.Variable(data_initalizer, trainable=False, collections=[])
input_labels = tf.Variable(label_initalizer, trainable=False, collections=[])
...
sess.run(input_data.initializer,
feed_dict={data_initializer: training_data})
sess.run(input_labels.initializer,
feed_dict={label_initializer: training_lables})

设定trainable=False 可以防止该变量被数据流图的 GraphKeys.TRAINABLE_VARIABLES 收集, 这样我们就不会在训练的时候尝试更新它的值; 设定 collections=[] 可以防止GraphKeys.VARIABLES 收集后做为保存和恢复的中断点。

无论哪种方式,[tf.train.slice_input_producer function](http://www.tensorfly.cn/tfdoc/api_docs/python/io_ops.html#slice_input_producer)函数可以被用来每次产生一个切片。这样就会让样本在整个迭代中被打乱,所以在使用批处理的时候不需要再次打乱样本。所以我们不使用shuffle_batch函数,取而代之的是纯tf.train.batch(http://www.tensorfly.cn/tfdoc/api_docs/python/io_ops.html#batch) 函数。 如果要使用多个线程进行预处理,需要将num_threads参数设置为大于1的数字。

tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py 中可以找到一个MNIST例子,使用常数来预加载。 另外使用变量来预加载的例子在tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py,你可以用上面 fully_connected_feedfully_connected_reader 的描述来进行比较。

多输入管道

通常你会在一个数据集上面训练,然后在另外一个数据集上做评估计算(或称为 "eval")。 这样做的一种方法是,实际上包含两个独立的进程:

训练过程中读取输入数据,并定期将所有的训练的变量写入还原点文件)。

在计算过程中恢复还原点文件到一个推理模型中,读取有效的输入数据。

这两个进程在下面的例子中已经完成了:the example CIFAR-10 model,有以下几个好处:

eval被当做训练后变量的一个简单映射。

你甚至可以在训练完成和退出后执行eval。

你可以在同一个进程的相同的数据流图中有训练和eval,并分享他们的训练后的变量。参考the shared variables tutorial.

tensorflow学习--数据加载的更多相关文章

  1. Redis深入学习笔记(一)Redis启动数据加载流程

    这两年使用Redis从单节点到主备,从主备到一主多从,再到现在使用集群,碰到很多坑,所以决定深入学习下Redis工作原理并予以记录. 本系列主要记录了Redis工作原理的一些要点,当然配置搭建和使用这 ...

  2. arcgis python 使用光标和内存中的要素类将数据加载到要素集 学习:http://zhihu.esrichina.com.cn/article/634

    学习:http://zhihu.esrichina.com.cn/article/634使用光标和内存中的要素类将数据加载到要素集 import arcpy arcpy.env.overwriteOu ...

  3. Tensorflow 2.0 datasets数据加载

    导入包 import tensorflow as tf from tensorflow import keras 加载数据 tensorflow可以调用keras自带的datasets,很方便,就是有 ...

  4. hibernate框架学习第六天:QBC、分页查询、投影、数据加载策略、二级缓存

    QBC查询 1.简单查询 Criteria c = s.createCriteria(TeacherModel.class); 2.获取查询结果 多条:list 单挑:uniqueResult 3.分 ...

  5. python多种格式数据加载、处理与存储

    多种格式数据加载.处理与存储 实际的场景中,我们会在不同的地方遇到各种不同的数据格式(比如大家熟悉的csv与txt,比如网页HTML格式,比如XML格式),我们来一起看看python如何和这些格式的数 ...

  6. 【微信小程序】模仿58同城页面制作以及动态数据加载

    完成动态数据的加载,如下 使用上班的空余时间慢慢的学习,相信总有一天我会很熟悉的掌握这门技术. 本次学习小总结: 微信小程序使用的代码基本与HTML.CSS.JS等前段有关知识一样. 微信小程序js使 ...

  7. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

  8. cesium 学习(五) 加载场景模型

    cesium 学习(五) 加载场景模型 一.前言 现在开始实际的看看效果,目前我所接触到基本上都是使用Cesium加载模型这个内容,以及在模型上进行操作.So,现在进行一些加载模型的学习,数据的话可以 ...

  9. 旷视MegEngine数据加载与处理

    旷视MegEngine数据加载与处理 在网络训练与测试中,数据的加载和预处理往往会耗费大量的精力. MegEngine 提供了一系列接口来规范化这些处理工作. 利用 Dataset 封装一个数据集 数 ...

随机推荐

  1. Mila Fletcher:日常理财应注意的五点

    米拉·弗莱彻于2007年毕业于耶鲁大学,她是一名真正意义上的法学博士,在校期间获得了马歇尔奖学金,毕业后曾在美国多家知名律师事务所任职,目前就职于星盟全球投资公司,专注于帮助公司和客户提供法务咨询,他 ...

  2. Masterboxan INC发布《2019年可持续发展报告》

    近日,Masterboxan INC万事达资产管理有限公司(公司编号:20151264097)发布<2019年可持续发展报告>,全面回顾了在过去一年Masterboxan INC开展的可持 ...

  3. ASP.NET Core获取请求完整的Url

    在ASP.NET项目中获取请求完整的Url: 获取System.Web命名空间下的类名为HttpRequestBase的Url方法: /// <summary>在派生类中替代时,获取有关当 ...

  4. apply方法的实现原理

    apply 的核心原理: 将函数设为对象的属性 执行和删除这个函数 指定 this 到函数并传入给定参数执行函数 如果不传参数,默认指向 window Function.prototype.myApp ...

  5. 手把手教你gitlab汉化

    详细教程如下: 一.在Github上 https://gitlab.com/xhang/gitlab/-/tags 下载对应的版本到服务器中 这种-zh结尾的才是汉化包,下载速度可能比较慢,有条件的可 ...

  6. 微信小程序开发小技巧:

    小技巧:输入view.tabs_content就可以生成下面的代码. 输入p10,就可以得到: 输入jc:c得到:文字水平对齐 输入d:f得到: 输入ai:c得到: 输入bb得到: currentCo ...

  7. Redis基本数据结构之ZSet

    1.1Zset(有序集合) Zset保留了集合不能有重复成员的特性,但不同的是,有序集合中的元素可以排序.但是它和列表使用索引下标作为排序依据不同的是,它给每个元素设置一个分数(score)作为排序的 ...

  8. Python中的sklearn--KFold与StratifiedKFold

    KFold划分数据集的原理:根据n_split直接进行划分 StratifiedKFold划分数据集的原理:划分后的训练集和验证集中类别分布尽量和原数据集一样 #导入相关packages from s ...

  9. MySQL:安装与配置

    记录一次 MySQL 在Windows系统的安装配置过程 安装MySQL 0.下载社区版安装包 官网下载地址:https://dev.mysql.com/downloads/installer/ 1. ...

  10. KeyboardDemo - Android身份证号、车牌号快捷输入键盘

    Android身份证号.车牌号快捷输入键盘 项目地址 Github 键盘部分在 keyboard module 中 键盘与EditText绑定参照 MainActivity