TensorFlow之单层(全连接层)实现手写数字识别训练及测试实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train',1,'指定程序是预测还是训练') def full_connected():
# 获取真实的数据
mnist = input_data.read_data_sets('./data/mnists/', one_hot=True) # 1.建立数据的占位符 X [None,784] y_true [None,10]
# 创建一个作用域
with tf.variable_scope('data'):
# 特征值
x = tf.placeholder(tf.float32, [None, 784]) # 目标值(真实值)
y_true = tf.placeholder(tf.int32, [None, 10]) # 2. 建立一个全连接层的神经网络 W [784,10] b [10]
with tf.variable_scope('fc_model'):
# 随机初始化权重和偏置
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w') bias = tf.Variable(tf.constant(0.0, shape=[10])) # 预测None个样本的输出结果matrix [None,784]*[784,10]+[10] = [None,10]
y_predict = tf.matmul(x, weight) + bias # 3. 求出所有样本的损失,然后求平均值
with tf.variable_scope('soft_cross'):
# 求平均交叉熵损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)) # 4. 梯度下降求出损失(优化)
with tf.variable_scope('optimizer'):
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 5. 计算准确率
with tf.variable_scope('acc'):
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) # equal_list None个样本 [1,0,1,0,1,1....]
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32)) # 收集变量,单个数字值收集
tf.summary.scalar('losses',loss)
tf.summary.scalar('acc',accuracy) # 高纬度变量收集
tf.summary.histogram('weightes',weight)
tf.summary.histogram('biases',bias) # 定义一个初始化变量的op
init_op = tf.global_variables_initializer() # 定义一个合并变量的op
merged = tf.summary.merge_all() # 创建一个saver
saver = tf.train.Saver() # 开启会话去训练
with tf.Session() as sess:
# 初始化变量
sess.run(init_op) # 建立events文件,然后写入
filewriter = tf.summary.FileWriter('./tmp/summary/test/',graph=sess.graph) if FLAGS.is_train == 1:
# 迭代步数去训练,更新参数预测
for i in range(2000):
# 取出真实存在的特征值和目标值
mnist_x, mnist_y = mnist.train.next_batch(50) # 运行train_op训练
sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y}) # 写入每步训练的值
summary = sess.run(merged,feed_dict={x: mnist_x, y_true: mnist_y}) filewriter.add_summary(summary,i) print('训练第%d步,准确率为:%f' % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y}))) # 保存模型
saver.save(sess,'./tmp/summary/model/fc_model')
else:
# 加载模型
saver.restore(sess,'./tmp/summary/model/fc_model') # 如果是0,做出预测
for i in range(100): # 每次测试一张图片,[0,0,0,0,0,1,0,0,0]
x_test,y_test = mnist.test.next_batch(1) print('第%d章图片,手写数字目标是:%d,预测结果是:%d' % (
i,
tf.argmax(y_test,1).eval(),
tf.argmax(sess.run(y_predict,feed_dict={x: x_test,y_true: y_test}),1).eval()
)) return None if __name__ == '__main__':
full_connected()

TensorFlow------单层(全连接层)实现手写数字识别训练及测试实例的更多相关文章

  1. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  2. Tensorflow项目实战一:MNIST手写数字识别

    此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...

  3. TensorFlow(十):卷积神经网络实现手写数字识别以及可视化

    上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  4. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  5. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  6. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  7. Tensorflow2.0-mnist手写数字识别示例

    Tensorflow2.0-mnist手写数字识别示例   读书不觉春已深,一寸光阴一寸金. 简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0.1.2.4.     ...

  8. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  9. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

随机推荐

  1. hadoop3.1 ha高可用部署

    1.资源角色规划

  2. 《深入浅出MyBatis技术原理与实战》——6. MyBatis的解析和运行原理

    MyBatis的运行分为两大部分,第一部分是读取配置文件缓存到Configuration对象,用以创建SqlSessionFactory,第二部分是SqlSession的执行过程. 6.1 涉及的技术 ...

  3. ELK日志处理

    ELK的工作原理: 使用多播进行机器发现同一个集群内的节点,并汇总各个节点的返回组成一个集群,主节点要读取各个节点的状态,在关键时候进行数据的恢复,主节点会坚持各个节点的状态,并决定每个分片的位置,通 ...

  4. Git----创建远程分支,并将文件上传到创建的远程分支上

    1.首先创建一个远程仓库 2.将远程仓库克隆到本地 (1)本地新建文件夹,命令行进入文件夹,执行clone操作 (2) git clone git@github.com:Lucky-Syw/lucky ...

  5. AC日记——「SCOI2015」国旗计划 LiBreOJ 2007

    #2007. 「SCOI2015」国旗计划 思路: 跪烂Claris 代码: #include <cstdio> #include <algorithm> #define ma ...

  6. css文本、字母、数字过长 自动换行处理

    ---恢复内容开始--- white-space: normal|pre|nowrap|pre-wrap|pre-line|inherit;white-space 属性设置如何处理元素内的空白norm ...

  7. markdown与html之间转换引发的问题

    https://www.hackersb.cn/hacker/235.html 看了这位师傅的文章有感而发 前言 对于支持markdown语法的网站,一般都是在后端将markdown语法渲染为html ...

  8. Linux操作命令(五)

    find . -name ”*.c" -exec ./command.sh {} \; 本次实验将介绍 Linux 命令中 find 和 xargs 命令的用法. find xargs 1. ...

  9. Spring Cloud Feign 总结

    Spring Cloud中, 服务又该如何调用 ? 各个服务以HTTP接口形式暴露 , 各个服务底层以HTTP Client的方式进行互相访问. SpringCloud开发中,Feign是最方便,最为 ...

  10. python excellent code link

    1. Howdoi Howdoi is a code search tool, written in Python. 2. Flask Flask is a microframework for Py ...