一个简单的TensorFlow可视化MNIST数据集识别程序
下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev))
# -*- coding: utf-8 -*- import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)
# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 运行次数
max_steps = 3001
# 图片数量
image_num = 5000
# 文件路径
DIR = "D:/AIdata/tf_data/tf_test1/" sess = tf.Session() # 载入图片,
# tf.stack矩阵拼接函数,
embedding = tf.Variable(tf.stack(mnist.test.images[:image_num]),
trainable=False, name="embedding") def variable_summaries(var):
with tf.name_scope("summaries"):
mean = tf.reduce_mean(var)
with tf.name_scope("stddev"):
# 计算标准差
stddev = tf.sqrt(tf.reduce_mean(tf.square(var-mean)))
# 绘制标准差信息
tf.summary.scalar("stddev", stddev)
# 绘制最大值
tf.summary.scalar("max", tf.reduce_max(var))
tf.summary.scalar("min", tf.reduce_min(var))
# 绘制直方图信息
tf.summary.histogram("histogram", var) with tf.name_scope('Input'):
x = tf.placeholder(tf.float32, [None, 784], name="x_input")
y = tf.placeholder(tf.float32, [None, 10], name="y_input")
LR = tf.Variable(0.001, dtype=tf.float32) # 显示图片
with tf.name_scope("input_reshape"):
# 改变x的形状(28x28x1)
image_shape_input = tf.reshape(x, [-1, 28, 28, 1])
# 将图像写入summary,输出带图像的probuf
tf.summary.image("Input", image_shape_input, 10) with tf.name_scope('layer'):
with tf.name_scope('weights'):
W = tf.Variable(tf.zeros([784, 10]), name='W')
variable_summaries(W)
with tf.name_scope('biases'):
b = tf.Variable(tf.zeros([10]), name='b')
variable_summaries(b)
with tf.name_scope('wxb'):
# tf.matmul实现矩阵乘法功能
wxb = tf.matmul(x, W) + b
with tf.name_scope('softmax'):
prediction = tf.nn.softmax(wxb) with tf.name_scope("loss"):
# 交叉熵函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,
logits=prediction))
# 绘制loss值
tf.summary.scalar("loss", loss) with tf.name_scope("Train"):
# AdamOptimizer优化器
train_step = tf.train.AdamOptimizer(LR).minimize(loss) init_op = tf.global_variables_initializer()
sess.run(init_op) # 变量初始化 with tf.name_scope("Result"):
with tf.name_scope("correct_prediction"):
# 记录预测值和标签值对比结果
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
with tf.name_scope("Accuracy"):
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 绘制准确率
tf.summary.scalar("accuracy", accuracy) # 判断是否已存在metadata.tsv文件,若存在则删除
if tf.gfile.Exists(DIR+"projector/projector/metadata.tsv"):
tf.gfile.Remove(DIR+"projector/projector/metadata.tsv") # 创建并写入metadata.tsv文件
with open(DIR+"projector/projector/metadata.tsv", 'w') as f:
labels = sess.run(tf.argmax(mnist.test.labels[:], 1))
for i in range(image_num):
f.write(str(labels[i]) + '\n') # 合并默认图表管理summary
merged = tf.summary.merge_all() projector_writer = tf.summary.FileWriter(DIR+"/projector/projector", sess.graph)
# 定义saver对象,以保存和恢复模型变量
saver = tf.train.Saver()
# 定义配置
config = projector.ProjectorConfig()
embed = config.embeddings.add()
embed.tensor_name = embedding.name
# metadata_path文件路径
embed.metadata_path = DIR+"projector/projector/metadata.tsv"
# sprite image文件路径
embed.sprite.image_path = DIR+'projector/data/mnist_10k_sprite.png'
# sprite image中每一单个图像的大小
embed.sprite.single_image_dim.extend([28, 28])
# 写入可视化配置
projector.visualize_embeddings(projector_writer, config) for i in range(max_steps):
# 每个批次100个样本
batch_xs, batch_ys = mnist.train.next_batch(100)
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys},
options=run_options, run_metadata=run_metadata)
projector_writer.add_run_metadata(run_metadata, 'step%03d' % i)
projector_writer.add_summary(summary, i) if i % 100 == 0:
sess.run(tf.assign(LR, 0.001))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Iter " + str(i) + ", Testing Accuracy= " + str(acc))
# 保存模型
saver.save(sess, DIR+'projector/projector/mnist_model.ckpt', global_step=max_steps)
projector_writer.close()
sess.close()
在cmd中输入tensorboard --logdir=tensorboard --logdir=D:\AIdata\tf_data\tf_test1\projector\projector --host=127.0.0.1

在浏览器中输入http://127.0.0.1:6006打开,会显示如下内容
显示表(loss表, 权重W...)

显示图片信息

计算图

动态放映训练过程,可在此进行模型训练,动态的观看训练状态

一个简单的TensorFlow可视化MNIST数据集识别程序的更多相关文章
- TensorFlow 下 mnist 数据集的操作及可视化
from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
- 基于TensorFlow的MNIST数据集的实验
一.MNIST实验内容 MNIST的实验比较简单,可以直接通过下面的程序加上程序上的部分注释就能很好的理解了,后面在完善具体的相关的数学理论知识,先记录在这里: 代码如下所示: import tens ...
- TensorFlow 训练MNIST数据集(2)—— 多层神经网络
在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...
- 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案
最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...
- 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)
1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...
- TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络
1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...
- 基于Keras 的VGG16神经网络模型的Mnist数据集识别并使用GPU加速
这段话放在前面:之前一种用的Pytorch,用着还挺爽,感觉挺方便的,但是在最近文献的时候,很多实验都是基于Google 的Keras的,所以抽空学了下Keras,学了之后才发现Keras相比Pyto ...
- 基于 tensorflow 的 mnist 数据集预测
1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...
随机推荐
- 逻辑读为何消耗CPU?
在数据库系统中,经常会看到这个说法:“逻辑读很消耗CPU”,然后开始把这句话当作一个定理来使用.但是为什么“同样是读,为什么逻辑读会使用那么多CPU?” 查了一些资料,配合自己的理解,有下面几点体会: ...
- Echars 地图属性详解
1.引入echarts库文件 <script charset="utf-8" type="text/javascript" language=" ...
- jquery的$(selector).each(function(index,element))和$.each(dataresource,function(index,element))的区别
$(selector).each(function(index,element)) 定义和用法 each() 方法规定为每个匹配元素规定运行的函数. $(selector).each(function ...
- hibernate多生成一个外键以及映射文件中含有<list-index>标签
(原文地址: http://blog.csdn.net/xiaoxian8023/article/details/15380529) 一.Inverse是hibernate双向关系中的基本概念.inv ...
- vue踩坑(二):跨域以及携带cookie
最近后台需求要在请求的时候传cooki给后台,正常情况下拿到cookie后存在cookie里,同域名下是会自己带到请求头里的,但是因为要在本地调试,那么问题就来了,localhost:8080下面的c ...
- Ubuntu16.04 藍牙連上,但是聲音裏面找不到設備
解決辦法: 1. sudo apt-get install blueman bluez* 2. sudo vim /etc/pulse/default.pa 注釋掉下面的代碼: #.ifexists ...
- 统一集中管理系统cronsun简介,替代crontab
一.背景 crontab 是 Linux 系统里面最简单易用的定时任务管理工具,相信绝大多数开发和运维都用到过.在咱们公司,很多业务系统的定时任务都是通过 crontab 来定义的,时间长了后会发现存 ...
- java中mysql查询报错java.sql.SQLException: Before start of result set
异常:java.sql.SQLException: Before start of result set 解决方法:使用rs.getString();前一定要加上rs.next(); sm = con ...
- iOS socket常用数据类型转换
int -> data /** int -> data */ + (NSData *)intToData:(int)value { Byte byte[4] = {}; byte[0] = ...
- 【Spring学习】SpringMVC demo搭建
前言:今天会通过IDEA建立一个SpringMVC的demo项目,在其中会涉及到一些基础模块和相关知识,然后结合这个具体的知识点,理解清楚SpringMVC的框架原理(以图的形式展示),顺藤摸瓜分析源 ...