tensorflow手写数字识别(有注释)
import tensorflow as tf
import numpy as np
# const = tf.constant(2.0, name='const')
# b = tf.placeholder(tf.float32, [None, 1], name='b')
# # b = tf.Variable(2.0, dtype=tf.float32, name='b')
# c = tf.Variable(1.0, dtype=tf.float32, name='c')
#
# d = tf.add(b, c, name='d')
# e = tf.add(c, const, name='e')
# a = tf.multiply(d, e, name='a')
# init = tf.global_variables_initializer()
#
# print(a)
# with tf.Session() as sess:
# sess.run(init)
# ans = sess.run(a, feed_dict={b: np.arange(0, 10)[:, np.newaxis]})
# print(a)
# print(ans) from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # 载入数据集 learning_rate = 0.5 # 学习率
epochs = 10 # 训练10次所有的样本
batch_size = 100 # 每批训练的样本数 x = tf.placeholder(tf.float32, [None, 784]) # 为训练集的特征提供占位符
y = tf.placeholder(tf.float32, [None, 10]) # 为训练集的标签提供占位符 W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1') # 初始化隐藏层的W1参数
b1 = tf.Variable(tf.random_normal([300]), name='b1') # 初始化隐藏层的b1参数
W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2') # 初始化全连接层的W1参数
b2 = tf.Variable(tf.random_normal([10]), name='b2') # 初始化全连接层的b1参数 hidden_out = tf.add(tf.matmul(x, W1), b1) # 定义隐藏层的第一步运算
hidden_out = tf.nn.relu(hidden_out) # 定义隐藏层经过激活函数后的运算 y_ = tf.nn.softmax(tf.add(tf.matmul(hidden_out, W2), b2)) # 定义全连接层的输出运算 y_clipped = tf.clip_by_value(y_, 1e-10, 0.9999999)
cross_entropy = -tf.reduce_mean(tf.reduce_sum(y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped), axis=1))
# 交叉熵 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)
# 梯度下降优化器,传入的参数是交叉熵 init = tf.global_variables_initializer() # 所有参数初始化 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) # 返回true|false
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 将true转化为1,false转化为0 # 开始训练
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(mnist.train.labels) / batch_size) # 计算每个epoch要迭代几次
for epoch in range(epochs):
avg_cost = 0
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
_, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})
# 其实上面这一步只需要跑optimizer这个优化器就好了,因为交叉熵也会同时跑。
# 但是我们想要得到交叉熵的值来作为损失函数,所以还需要跑一个交叉熵。
avg_cost += c / total_batch
print("Epoch:", (epoch + 1), "cost = ", "{:.3f}".format(avg_cost)) # 这是每训练完所有样本得到的损失值
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# 因为之前的计算已经把中间参数计算出来了,所以这里只用最后的计算测试集就行了
tensorflow手写数字识别(有注释)的更多相关文章
- Tensorflow手写数字识别(交叉熵)练习
# coding: utf-8import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #pr ...
- Tensorflow手写数字识别训练(梯度下降法)
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data #p ...
- tensorflow 手写数字识别
https://www.kaggle.com/kakauandme/tensorflow-deep-nn 本人只是负责将这个kernels的代码整理了一遍,具体还是请看原链接 import numpy ...
- Tensorflow手写数字识别---MNIST
MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000
- 卷积神经网络应用于tensorflow手写数字识别(第三版)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...
- 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
随机推荐
- DevExpress中GridColumnCollection实现父子表数据绑定
绑定数据: 父表: DataTable _parent = _dvFlt.ToTable().Copy(); 子表: DataTable _child = _dvLog.ToTable().Copy( ...
- 精通Spring Boot
原 精通Spring Boot—— 第二十一篇:Spring Social OAuth 登录简介 1.什么是OAuth OAuth官网介绍是这样的: An open protocol to allow ...
- 从实践到原理,带你参透 gRPC
gRPC 在 Go 语言中大放异彩,越来越多的小伙伴在使用,最近也在公司安利了一波,希望这一篇文章能带你一览 gRPC 的巧妙之处,本文篇幅比较长,请做好阅读准备.本文目录如下: 简述 gRPC 是一 ...
- Error creating bean with name 'XXX' defined in file
这个错误是我在之前操作时,错将另一个dubbo服务器也加载到了该dubbo服务器上(pom.xml),所以出现了Error creating bean with name 'XXX' defined ...
- 纯css更改图片颜色的技巧
tips: JPG.PNG.GIF 都可以,但是有一个前提要求,就是黑色纯色,背景白色 .pic1 { background-image: url($img), linear-gradient ...
- CSS 之 圣杯布局&双飞翼布局
圣杯布局 和 双飞翼布局 是重要布局方式.两者的功能相同,都是为了实现一个两侧宽度固定,中间宽度自适应的三栏布局. 遵循了以下要点: 两侧宽度固定,中间宽度自适应 中间部分在DOM结构上优先,以便先行 ...
- CentOS - Eclipse安装Shelled
一,下载Shelled: https://sourceforge.net/projects/shelled/ 二,打开Eclipse,以离线方式安装: Help->Install New Sof ...
- cpython多进程
四 同步\异步and阻塞\非阻塞(重点) 同步: #所谓同步,就是在发出一个功能调用时,在没有得到结果之前,该调用就不会返回.按照这个定义,其实绝大多数函数都是同步调用.但是一般而言,我们在说同步.异 ...
- Hoax or what UVA - 11136(multiset的应用)
刚开始把题意理解错了,结果样例没过,后来发现每天只处理最大和最小的,其余的不管,也就是说昨天的元素会影响今天的最大值和最小值,如果模拟的话明显会超时,故用multiset,另外发现rbegin()的功 ...
- django项目中使用手机号登录
本文使用聚合数据的短信接口,需要先获取到申请接口的appkey和模板id 项目目录下创建ubtils文件夹,定义返回随机验证码和调取短信接口的函数 function.py文件 import rando ...