import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 将数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(label)),
'image_raw': _bytes_feature(image_raw)
}))
return example # 读取mnist训练数据。
mnist = input_data.read_data_sets("E:\\MNIST_data\\",dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples # 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output.tfrecords") as writer:
for index in range(num_examples):
example = _make_example(pixels, labels[index], images[index])
writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。") # 读取mnist测试数据。
images_test = mnist.test.images
labels_test = mnist.test.labels
pixels_test = images_test.shape[1]
num_examples_test = mnist.test.num_examples # 输出包含测试数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output_test.tfrecords") as writer:
for index in range(num_examples_test):
example = _make_example(pixels_test, labels_test[index], images_test[index])
writer.write(example.SerializeToString())
print("TFRecord测试文件已保存。")

# 读取文件。
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["E:\\MNIST_data\\output.tfrecords"])
_,serialized_example = reader.read(filename_queue) # 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}) images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32) sess = tf.Session() # 启动多线程处理输入数据。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])

吴裕雄 python 神经网络——TensorFlow TFRecord样例程序的更多相关文章

  1. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  2. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  3. 吴裕雄 python 神经网络——TensorFlow 输入文件队列

    import tensorflow as tf def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64 ...

  4. 吴裕雄 python 神经网络——TensorFlow 图像预处理完整样例

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def distort_color(image, ...

  5. 吴裕雄 python 神经网络——TensorFlow 完整神经网络样例程序

    import tensorflow as tf from numpy.random import RandomState batch_size = 8 w1= tf.Variable(tf.rando ...

  6. 吴裕雄 python 神经网络——TensorFlow variables_to_restore函数的使用样例

    import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") ema = tf.train.Expo ...

  7. 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例

    import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...

  8. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  9. 吴裕雄 python 神经网络——TensorFlow 数据集基本使用方法

    import tempfile import tensorflow as tf input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_ ...

随机推荐

  1. Iris路由和路由组

    package main import ( "github.com/kataras/iris" "github.com/kataras/iris/context" ...

  2. 每天进步一点点------创建Microblaze软核(二)

    第四步 进入Platform Studio操作界面通过向导创建软核后,进入到PlatformStudio——内核开发环境.Platform Studio主界面如下图. 在Ports项中,右键点击RS2 ...

  3. mongo shell远程连接使用数据库

    mongo mydb --username user1 --host --password --username 用户名 --host 连接ip --port 连接端口号 --password 密码 ...

  4. python numpy中sum()时出现负值

    import numpy a=numpy.random.randint(1, 4095, (5000,5000)) a.sum() 结果为负值, 这是错误的,a.sum()的类型为 int32,如何做 ...

  5. html中的路径详解

    路径指文件存放的位置,在网页中利用路径可以引用文件,插入图像.视频等.表示路径的方法有两种:相对路径,绝对路径.以下讨论均是在HTML环境下进行. 相对路径 相对路径是指目标相对于当前文件的路径,网页 ...

  6. 解决安装VMware Player出错,提示安装程序无法继续,microsoft runtime dll安装程序未能完成安装

    方案一: 以兼容模式运行和管理员方式运行安装程序,右键点击安装文件选择属性,在弹出的面板中修改兼容性如下 方案二: 下载最近版的VMWare player安装包哈哈 方案三: 1.双击VMware P ...

  7. IntelliJ IDEA 2017.3尚硅谷-----滚轮修改字体大小

  8. linux建立动态库的软链接

    复制动态库: /home/wmz/anaconda3/lib/ 删除原链接: 建立新链接: /home/wmz/anaconda3/lib/libstdc++.so. 问题的起源是,安装anacond ...

  9. 微信小程序UDP通信

    前言 UDP通信分为单播 广播 组播,基础库2.7.0之后,小程序开始支持UDP通信,目前小程序只支持单播. 小程序API 小程序UDP通信这一块可以说是很简单了就一个UDPSocket实例.然后bi ...

  10. Mount命令的参数详解

    导读 mount是Linux下的一个命令,它可以将分区挂接到Linux的一个文件夹下,从而将分区和该目录联系起来,因此我们只要访问这个文件夹,就相当于访问该分区了. 挂接命令(mount) 首先,介绍 ...