一个简单的TensorFlow可视化MNIST数据集识别程序
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev))
# -*- coding: utf-8 -*- import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 运行次数
max_steps = 3001
# 图片数量
image_num = 5000
# 文件路径
DIR = "D:/AIdata/tf_data/tf_test1/" sess = tf.Session() # 载入图片,
# tf.stack矩阵拼接函数,
embedding = tf.Variable(tf.stack(mnist.test.images[:image_num]),
trainable=False, name="embedding") def variable_summaries(var):
with tf.name_scope("summaries"):
mean = tf.reduce_mean(var)
with tf.name_scope("stddev"):
# 计算标准差
stddev = tf.sqrt(tf.reduce_mean(tf.square(var-mean)))
# 绘制标准差信息
tf.summary.scalar("stddev", stddev)
# 绘制最大值
tf.summary.scalar("max", tf.reduce_max(var))
tf.summary.scalar("min", tf.reduce_min(var))
# 绘制直方图信息
tf.summary.histogram("histogram", var) with tf.name_scope('Input'):
x = tf.placeholder(tf.float32, [None, 784], name="x_input")
y = tf.placeholder(tf.float32, [None, 10], name="y_input")
LR = tf.Variable(0.001, dtype=tf.float32) # 显示图片
with tf.name_scope("input_reshape"):
# 改变x的形状(28x28x1)
image_shape_input = tf.reshape(x, [-1, 28, 28, 1])
# 将图像写入summary,输出带图像的probuf
tf.summary.image("Input", image_shape_input, 10) with tf.name_scope('layer'):
with tf.name_scope('weights'):
W = tf.Variable(tf.zeros([784, 10]), name='W')
variable_summaries(W)
with tf.name_scope('biases'):
b = tf.Variable(tf.zeros([10]), name='b')
variable_summaries(b)
with tf.name_scope('wxb'):
# tf.matmul实现矩阵乘法功能
wxb = tf.matmul(x, W) + b
with tf.name_scope('softmax'):
prediction = tf.nn.softmax(wxb) with tf.name_scope("loss"):
# 交叉熵函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,
logits=prediction))
# 绘制loss值
tf.summary.scalar("loss", loss) with tf.name_scope("Train"):
# AdamOptimizer优化器
train_step = tf.train.AdamOptimizer(LR).minimize(loss) init_op = tf.global_variables_initializer()
sess.run(init_op) # 变量初始化 with tf.name_scope("Result"):
with tf.name_scope("correct_prediction"):
# 记录预测值和标签值对比结果
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
with tf.name_scope("Accuracy"):
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 绘制准确率
tf.summary.scalar("accuracy", accuracy) # 判断是否已存在metadata.tsv文件,若存在则删除
if tf.gfile.Exists(DIR+"projector/projector/metadata.tsv"):
tf.gfile.Remove(DIR+"projector/projector/metadata.tsv") # 创建并写入metadata.tsv文件
with open(DIR+"projector/projector/metadata.tsv", 'w') as f:
labels = sess.run(tf.argmax(mnist.test.labels[:], 1))
for i in range(image_num):
f.write(str(labels[i]) + '\n') # 合并默认图表管理summary
merged = tf.summary.merge_all() projector_writer = tf.summary.FileWriter(DIR+"/projector/projector", sess.graph)
# 定义saver对象,以保存和恢复模型变量
saver = tf.train.Saver()
# 定义配置
config = projector.ProjectorConfig()
embed = config.embeddings.add()
embed.tensor_name = embedding.name
# metadata_path文件路径
embed.metadata_path = DIR+"projector/projector/metadata.tsv"
# sprite image文件路径
embed.sprite.image_path = DIR+'projector/data/mnist_10k_sprite.png'
# sprite image中每一单个图像的大小
embed.sprite.single_image_dim.extend([28, 28])
# 写入可视化配置
projector.visualize_embeddings(projector_writer, config) for i in range(max_steps):
# 每个批次100个样本
batch_xs, batch_ys = mnist.train.next_batch(100)
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys},
options=run_options, run_metadata=run_metadata)
projector_writer.add_run_metadata(run_metadata, 'step%03d' % i)
projector_writer.add_summary(summary, i) if i % 100 == 0:
sess.run(tf.assign(LR, 0.001))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(i) + ", Testing Accuracy= " + str(acc))
# 保存模型
saver.save(sess, DIR+'projector/projector/mnist_model.ckpt', global_step=max_steps)
projector_writer.close()
sess.close()
在cmd中输入tensorboard --logdir=tensorboard --logdir=D:\AIdata\tf_data\tf_test1\projector\projector --host=127.0.0.1
在浏览器中输入http://127.0.0.1:6006打开,会显示如下内容
显示表(loss表, 权重W...)
显示图片信息
计算图
动态放映训练过程,可在此进行模型训练,动态的观看训练状态
一个简单的TensorFlow可视化MNIST数据集识别程序的更多相关文章
- TensorFlow 下 mnist 数据集的操作及可视化
from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
- 基于TensorFlow的MNIST数据集的实验
一.MNIST实验内容 MNIST的实验比较简单,可以直接通过下面的程序加上程序上的部分注释就能很好的理解了,后面在完善具体的相关的数学理论知识,先记录在这里: 代码如下所示: import tens ...
- TensorFlow 训练MNIST数据集(2)—— 多层神经网络
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
- 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案
最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...
- 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
- TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络
1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...
- 基于Keras 的VGG16神经网络模型的Mnist数据集识别并使用GPU加速
这段话放在前面:之前一种用的Pytorch,用着还挺爽,感觉挺方便的,但是在最近文献的时候,很多实验都是基于Google 的Keras的,所以抽空学了下Keras,学了之后才发现Keras相比Pyto ...
- 基于 tensorflow 的 mnist 数据集预测
1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...
随机推荐
- 2018-2019-2 20175213实验一 《Java开发环境的熟悉》实验报告
第一部分实验要求:1 建立“自己学号exp1”的目录2 在“自己学号exp1”目录下建立src,bin等目录3 javac,java的执行在“自己学号exp1”目录4 提交 Linux或Window或 ...
- vue 自定义组件销毁
今天在开发电商vue前端项目时,用户每次登出再换其它用户登录时,页面显示的用户名和左则导航都还是上个用户的,刚开始以为是localStorage中没有清除全局数据,然后在用户点击退出系统时手动清除lo ...
- Uncaught TypeError: form.attr is not a function 解决办法
前端form表单提交时遇到个问题,一直报错如下 首先说结论:form是个js对象,不是jQuery对象,不能用jquery对象的方法. 代码是: $(document).ready(function( ...
- Ambari2.6.2 HDP2.6.5 大数据集群搭建
Ambari 2.6.2 中 HDFS-2.7.3 YARN-2.7.3 HIVE-1.2.1 HBASE-1.1.2 ZOOKEEPER-3.4.6 SPARK-2.3.0 注:本文基于root用户 ...
- 独立安装CentOS7.4全记录
大学用了四年的笔记本快用废了,闲来想着用来装个centos,当个服务器也行,于是装上了CentOS6.9系统,由于最小化安装,而且在安装时没有安装wpa_supplicant包,笔记本本身网卡接口又坏 ...
- redis缓存雪崩、穿透、击穿概念及解决办法
缓存雪崩 对于系统 A,假设每天高峰期每秒 5000 个请求,本来缓存在高峰期可以扛住每秒 4000 个请求,但是缓存机器意外发生了全盘宕机.缓存挂了,此时 1 秒 5000 个请求全部落数据库,数据 ...
- idea连接mysql
https://blog.csdn.net/Golden_soft/article/details/80952243
- 51单片机学习笔记(清翔版)(13)——LED点阵、74HC595
如图3,点阵屏分单色和彩色,点阵屏是由许多点组成的,在一个点上,只有一颗一种颜色的灯珠,这就是单色点阵屏,彩色的在一个点上有三颗灯珠,分别是RGB三原色. 图4你可能没看出来,那么大块黄色的就是点阵屏 ...
- Dockerfile的alpine时区设置
FROM *** RUN apk add -U tzdataRUN cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
- VB 性能优化点
1.将Single,Double和Currency类型的变量替换为Integer或Long类型的变量:10倍 2.避免使用变体: 慢:Dim FSO as object Set FSO = N ...