Tensorflow细节-P196-输入数据处理框架
要点
1、filename_queue = tf.train.string_input_producer(files, shuffle=False) 表示创建一个队列来维护列表
2、min_after_dequeue = 10000queue runner线程要保证队列中至少剩下min_after_dequeue个数据。
如果min_after_dequeue设置的过少,则即使shuffle为true,也达不到好的混合效果。
3、·sess.run((tf.global_variables_initializer(),
tf.local_variables_initializer()))· 记得要加一个tf.local_variables_initializer()
import tensorflow as tf
files = tf.train.match_filenames_once("output.tfrecords") # 把文件读进来output.tfrecords
filename_queue = tf.train.string_input_producer(files, shuffle=False) # 创建一个队列来维护列表
# 读取文件。
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'pixels': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
})
decoded_images = tf.decode_raw(features['image_raw'],tf.uint8) # 解码图像
retyped_images = tf.cast(decoded_images, tf.float32) # 将图像转换为整数
labels = tf.cast(features['label'], tf.int32)
#pixels = tf.cast(features['pixels'],tf.int32)
images = tf.reshape(retyped_images, [784])
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size,
capacity=capacity, min_after_dequeue=min_after_dequeue)
def inference(input_tensor, weights1, biases1, weights2, biases2):
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
return tf.matmul(layer1, weights2) + biases2
# 模型相关的参数
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
REGULARAZTION_RATE = 0.0001
TRAINING_STEPS = 5000
weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
# 计算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_batch)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
# 损失函数的计算
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
regularaztion = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularaztion
# 优化损失函数
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# 初始化会话,并开始训练过程。
with tf.Session() as sess:
# tf.global_variables_initializer().run()
sess.run((tf.global_variables_initializer(),
tf.local_variables_initializer())) # 记得要加一个tf.local_variables_initializer()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 循环的训练神经网络。
for i in range(TRAINING_STEPS):
if i % 1000 == 0:
print("After %d training step(s), loss is %g " % (i, sess.run(loss)))
sess.run(train_step)
coord.request_stop()
coord.join(threads)
Tensorflow细节-P196-输入数据处理框架的更多相关文章
- TensorFlow多线程输入数据处理框架(四)——输入数据处理框架
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 输入数据处理的整个流程. #!/usr/bin/env python # -*- coding: UTF-8 -* ...
- tensorflow学习笔记——多线程输入数据处理框架
之前我们学习使用TensorFlow对图像数据进行预处理的方法.虽然使用这些图像数据预处理的方法可以减少无关因素对图像识别模型效果的影响,但这些复杂的预处理过程也会减慢整个训练过程.为了避免图像预处理 ...
- TensorFlow多线程输入数据处理框架(三)——组合训练数据
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 通过TensorFlow提供的tf.train.batch和tf.train.shuffle_batch函数来将单 ...
- TensorFlow多线程输入数据处理框架(二)——输入文件队列
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 一个简单的程序来生成样例数据. #!/usr/bin/env python # -*- coding: UTF-8 ...
- Tensorflow多线程输入数据处理框架
Tensorflow提供了一系列的对图像进行预处理的方法,但是复杂的预处理过程会减慢整个训练过程,所以,为了避免图像的预处理成为训练神经网络效率的瓶颈,Tensorflow提供了多线程处理输入数据的框 ...
- Tensorflow多线程输入数据处理框架(一)——队列与多线程
参考书 <TensorFlow:实战Google深度学习框架>(第2版) 对于队列,修改队列状态的操作主要有Enqueue.EnqueueMany和Dequeue.以下程序展示了如何使用这 ...
- 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架
import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...
- 吴裕雄--天生自然 pythonTensorFlow图形数据处理:输入数据处理框架
import tensorflow as tf # 1. 创建文件列表,通过文件列表创建输入文件队列 files = tf.train.match_filenames_once("F:\\o ...
- [Tensorflow实战Google深度学习框架]笔记4
本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...
随机推荐
- Response知识点小结
HTTP协议: 1. 响应消息:服务器端发送给客户端的数据 * 数据格式: 1. 响应行 1. 组成:协议/版本 响应状态码 状态码描述 2. 响应状态码:服务器告诉客户端浏览器本次请求和响应的一个状 ...
- Java开发笔记(一百三十七)JavaFX的标签
前面介绍了JavaFX的窗口框架,其中舞台.场景.窗格都能与AWT/Swing体系的相关概念一一对应,不仅如此,JavaFX的常见控件也能在Swing中找到相应的控件.比如JavaFX的按钮控件名叫B ...
- Go语言 ( 切片)
本文主要介绍Go语言中切片(slice)及它的基本使用. 引子 因为数组的长度是固定的并且数组长度属于类型的一部分,所以数组有很多的局限性. 例如: func arraySum(x []int) in ...
- JavaScript进行UTF-8编码与解码
JavaScript本身可通过charCodeAt方法得到一个字符的Unicode编码,并通过fromCharCode方法将Unicode编码转换成对应字符. 但charCodeAt方法得到的应该是一 ...
- JDK8-lambda表达式以及接口可以定义默认方法
一.Lambda表达式 1.Lamdba Lambda 允许把函数作为一个方法的参数,使用Lamdba可以让开发的代码更加简洁,但是易读性差,新人不了解Lamdba表达式或者代码功底有点差,不容易读懂 ...
- 实战远程文件同步(Remote File Sync)
1. 远程文件同步的常见方式: 1.cron + rsync 优点: 简单 缺点:定时执行,实时性比较差:另外,rsync同步数据时,需要扫描所有文件后进行比对,进行差量传输.如果文件数量达到了百万甚 ...
- Unity - Profiler参数详解
CPU Usage ● GC Alloc - 记录了游戏运行时代码产生的堆内存分配.这会导致ManagedHeap增大,加速GC的到来.我们要尽可能避免不必要的堆内存分配,同时注意:1 ...
- elasticsearch授权访问
1.search guard插件 https://www.cnblogs.com/shifu204/p/6376683.html 2.Elasticsearch-http-basic 不支持es5,忽 ...
- 关于Eclipse导入maven项目报空指针异常
今天新建了一个maven项目,因为是通过公司的工具新建的,代码拉下来就有src.pom.xml文件. 导入Eclipse却报空指针异常.具体如下: An error has occurred. See ...
- 转换属性transform
transform: rotate(45deg);旋转 rotate(值) 值为正,表示元素顺时针旋转 值为负,表示元素逆时针旋转 transform: translate(200px,100px); ...