import tensorflow as tf
import numpy as np
# const = tf.constant(2.0, name='const')
# b = tf.placeholder(tf.float32, [None, 1], name='b')
# # b = tf.Variable(2.0, dtype=tf.float32, name='b')
# c = tf.Variable(1.0, dtype=tf.float32, name='c')
#
# d = tf.add(b, c, name='d')
# e = tf.add(c, const, name='e')
# a = tf.multiply(d, e, name='a')
# init = tf.global_variables_initializer()
#
# print(a)
# with tf.Session() as sess:
# sess.run(init)
# ans = sess.run(a, feed_dict={b: np.arange(0, 10)[:, np.newaxis]})
# print(a)
# print(ans) from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 载入数据集 learning_rate = 0.5 # 学习率
epochs = 10 # 训练10次所有的样本
batch_size = 100 # 每批训练的样本数 x = tf.placeholder(tf.float32, [None, 784]) # 为训练集的特征提供占位符
y = tf.placeholder(tf.float32, [None, 10]) # 为训练集的标签提供占位符 W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1') # 初始化隐藏层的W1参数
b1 = tf.Variable(tf.random_normal([300]), name='b1') # 初始化隐藏层的b1参数
W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2') # 初始化全连接层的W1参数
b2 = tf.Variable(tf.random_normal([10]), name='b2') # 初始化全连接层的b1参数 hidden_out = tf.add(tf.matmul(x, W1), b1) # 定义隐藏层的第一步运算
hidden_out = tf.nn.relu(hidden_out) # 定义隐藏层经过激活函数后的运算 y_ = tf.nn.softmax(tf.add(tf.matmul(hidden_out, W2), b2)) # 定义全连接层的输出运算 y_clipped = tf.clip_by_value(y_, 1e-10, 0.9999999)
cross_entropy = -tf.reduce_mean(tf.reduce_sum(y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped), axis=1))
# 交叉熵 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)
# 梯度下降优化器,传入的参数是交叉熵 init = tf.global_variables_initializer() # 所有参数初始化 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 返回true|false
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将true转化为1,false转化为0 # 开始训练
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(mnist.train.labels) / batch_size) # 计算每个epoch要迭代几次
for epoch in range(epochs):
avg_cost = 0
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
_, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})
# 其实上面这一步只需要跑optimizer这个优化器就好了,因为交叉熵也会同时跑。
# 但是我们想要得到交叉熵的值来作为损失函数,所以还需要跑一个交叉熵。
avg_cost += c / total_batch
print("Epoch:", (epoch + 1), "cost = ", "{:.3f}".format(avg_cost)) # 这是每训练完所有样本得到的损失值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# 因为之前的计算已经把中间参数计算出来了,所以这里只用最后的计算测试集就行了

tensorflow手写数字识别(有注释)的更多相关文章

  1. Tensorflow手写数字识别(交叉熵)练习

    # coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...

  2. Tensorflow手写数字识别训练(梯度下降法)

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...

  3. tensorflow 手写数字识别

    https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...

  4. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

  5. 卷积神经网络应用于tensorflow手写数字识别(第三版)

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

随机推荐

  1. SpringCloud整合sleuth,使用zipkin时不显示服务

    转载于:https://www.cnblogs.com/Dandwj/p/11179141.html 原文地址:https://blog.csdn.net/weixin_30416497/articl ...

  2. linux搭建GitLab

    GitLab CentOS6 1. 安装VMware和CentOS 2. 安装必备Linux插件 3. 准备安装GitLab 4. 开始安装GitLab 5. 配置GitLab 6. 启动GitLab ...

  3. python3 语法 数据类型

     python3中 有6种标准数据类型 数字,字符串,列表,元祖,集合,字典

  4. Python进阶(十三)----面向对象

    Python进阶(十三)----面向对象 一丶面向过程编程vs函数式编程vs面向对象编程 面向过程: ​ 简而言之,step by step 一步一步完成功能,就是分析出解决问题所需要的步骤,然后用函 ...

  5. Jquery 跨Dom窗口操作

    . 子窗口给父窗口元素赋值 function modifyTheme(id){ $("#parent_dom",window.parent.document).attr(" ...

  6. 【转载】C#中List集合使用GetRange方法获取指定索引范围内的所有值

    在C#的List集合中有时候需要获取指定索引位置范围的元素对象来组成一个新的List集合,此时就可使用到List集合的扩展方法GetRange方法,GetRange方法专门用于获取List集合指定范围 ...

  7. cs/bs架构的区别

    Client/Server是建立在局域网的基础上的,基于客户端/服务器,安全,响应快,维护难度大,不易拓展,用户面固定,需要相同的操作系统. Browser/Server是建立在广域网的基础上的,基于 ...

  8. Android数据库GreenDao配置版本问题

    感谢该贴解决我多天的困惑:https://blog.csdn.net/u013472738/article/details/72895747 主要是降低了GreenDao版本 网上很多教程说的版本都是 ...

  9. MyBatis日记(四):MyBatis——insert、update、delete、select

    MyBatis简单增删改查操作,此处所做操作,皆是在之前创建的MyBatis的Hello world的工程基础上所做操作. 首先在接口文件(personMapper.java)中,添加操作方法: pa ...

  10. 肖哥HCNP-正式篇笔记

    21.网工学习环境准备. 一. 关掉所有杀毒软件及管家如阿健. 二. 安装环回网卡 (一定要先安装.) 1. 计算机设备管理 2. 在右侧最上端计算机名上方右键,点击过时硬件. 3. 下一步.手动选择 ...