吴裕雄 python 神经网络——TensorFlow TFRecord样例程序
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 定义函数转化变量类型。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) # 将数据转化为tf.train.Example格式。
def _make_example(pixels, label, image):
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(label)),
'image_raw': _bytes_feature(image_raw)
}))
return example # 读取mnist训练数据。
mnist = input_data.read_data_sets("E:\\MNIST_data\\",dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples # 输出包含训练数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output.tfrecords") as writer:
for index in range(num_examples):
example = _make_example(pixels, labels[index], images[index])
writer.write(example.SerializeToString())
print("TFRecord训练文件已保存。") # 读取mnist测试数据。
images_test = mnist.test.images
labels_test = mnist.test.labels
pixels_test = images_test.shape[1]
num_examples_test = mnist.test.num_examples # 输出包含测试数据的TFRecord文件。
with tf.python_io.TFRecordWriter("E:\\MNIST_data\\output_test.tfrecords") as writer:
for index in range(num_examples_test):
example = _make_example(pixels_test, labels_test[index], images_test[index])
writer.write(example.SerializeToString())
print("TFRecord测试文件已保存。")
# 读取文件。
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(["E:\\MNIST_data\\output.tfrecords"])
_,serialized_example = reader.read(filename_queue) # 解析读取的样例。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
}) images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32) sess = tf.Session() # 启动多线程处理输入数据。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord) for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])
吴裕雄 python 神经网络——TensorFlow TFRecord样例程序的更多相关文章
- 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集
import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...
- 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架
import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...
- 吴裕雄 python 神经网络——TensorFlow 输入文件队列
import tensorflow as tf def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64 ...
- 吴裕雄 python 神经网络——TensorFlow 图像预处理完整样例
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def distort_color(image, ...
- 吴裕雄 python 神经网络——TensorFlow 完整神经网络样例程序
import tensorflow as tf from numpy.random import RandomState batch_size = 8 w1= tf.Variable(tf.rando ...
- 吴裕雄 python 神经网络——TensorFlow variables_to_restore函数的使用样例
import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") ema = tf.train.Expo ...
- 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例
import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...
- 吴裕雄 python 神经网络——TensorFlow 数据集高层操作
import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...
- 吴裕雄 python 神经网络——TensorFlow 数据集基本使用方法
import tempfile import tensorflow as tf input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_ ...
随机推荐
- Iris路由和路由组
package main import ( "github.com/kataras/iris" "github.com/kataras/iris/context" ...
- 每天进步一点点------创建Microblaze软核(二)
第四步 进入Platform Studio操作界面通过向导创建软核后,进入到PlatformStudio——内核开发环境.Platform Studio主界面如下图. 在Ports项中,右键点击RS2 ...
- mongo shell远程连接使用数据库
mongo mydb --username user1 --host --password --username 用户名 --host 连接ip --port 连接端口号 --password 密码 ...
- python numpy中sum()时出现负值
import numpy a=numpy.random.randint(1, 4095, (5000,5000)) a.sum() 结果为负值, 这是错误的,a.sum()的类型为 int32,如何做 ...
- html中的路径详解
路径指文件存放的位置,在网页中利用路径可以引用文件,插入图像.视频等.表示路径的方法有两种:相对路径,绝对路径.以下讨论均是在HTML环境下进行. 相对路径 相对路径是指目标相对于当前文件的路径,网页 ...
- 解决安装VMware Player出错,提示安装程序无法继续,microsoft runtime dll安装程序未能完成安装
方案一: 以兼容模式运行和管理员方式运行安装程序,右键点击安装文件选择属性,在弹出的面板中修改兼容性如下 方案二: 下载最近版的VMWare player安装包哈哈 方案三: 1.双击VMware P ...
- IntelliJ IDEA 2017.3尚硅谷-----滚轮修改字体大小
- linux建立动态库的软链接
复制动态库: /home/wmz/anaconda3/lib/ 删除原链接: 建立新链接: /home/wmz/anaconda3/lib/libstdc++.so. 问题的起源是,安装anacond ...
- 微信小程序UDP通信
前言 UDP通信分为单播 广播 组播,基础库2.7.0之后,小程序开始支持UDP通信,目前小程序只支持单播. 小程序API 小程序UDP通信这一块可以说是很简单了就一个UDPSocket实例.然后bi ...
- Mount命令的参数详解
导读 mount是Linux下的一个命令,它可以将分区挂接到Linux的一个文件夹下,从而将分区和该目录联系起来,因此我们只要访问这个文件夹,就相当于访问该分区了. 挂接命令(mount) 首先,介绍 ...