【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)
博文主要内容有:
1.softmax regression的TensorFlow实现代码(教科书级的代码注释)
2.该实现中的函数总结
平台:
1.windows 10 64位
2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)
3.TensorFlow1.1.0
先贴代码,函数
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 16:36:43 2017
@author: ASUS
"""
#import itchat
#from PIL import Image
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot = True) # 是一个tensorflow内部的变量
print(mnist.train.images.shape, mnist.train.labels.shape) # 训练集形状, 标签形状
sess = tf.InteractiveSession() # sess被注册为默认的session
x = tf.placeholder(tf.float32, [None, 784]) # Placeholder是输入数据的地方
#-------------给weights和bias创建Variable对象-------------
# Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 计算 softmax 输出 y 其中x形状是[None, 784],None是为batch数而准备的
y = tf.nn.softmax(tf.matmul(x, W) + b)
#-------------交叉熵损失函数-------------
# y_存放真实标签
y_ = tf.placeholder(tf.float32, [None, 10])
# recude_mean和reduce_sum意思缩减维度的均值,以及缩减维度的求和
# reduce_mean在这里是对一个batch进行求均值
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices = [1]))
#-------------优化算法设置-------------
# 采用梯度下降的优化方法,
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#---------------全局参数初始化-------------------
tf.global_variables_initializer().run()
#---------------迭代地执行训练操作-------------------
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys})
#---------------准确率验证-------------------
# tf.argmax是寻找tensor中值最大的元素的序号 ,在此用来判断类别
# tf.equal是判断两个变量是否相等, 返回的是bool值
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_, 1))
# tf.cast用于数据类型转换
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy)
print(accuracy.eval({x:mnist.test.images, y_: mnist.test.labels}))
其中用到的函数总结:
1. sess = tf.InteractiveSession() 将sess注册为默认的session
2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)
3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
4. tf.matmul() 矩阵相乘函数
5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和
6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别
7. tf.cast() 用于数据类型转换
(ps: 可将每篇博文最后的函数总结复制到一个word文档,便于日后查找)
【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)的更多相关文章
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- TensorFlow(十二):使用RNN实现手写数字识别
上代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist ...
- 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集
import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...
- Tensorflow - Tutorial (7) : 利用 RNN/LSTM 进行手写数字识别
1. 经常使用类 class tf.contrib.rnn.BasicLSTMCell BasicLSTMCell 是最简单的一个LSTM类.没有实现clipping,projection layer ...
- 吴裕雄 python 神经网络——TensorFlow实现AlexNet模型处理手写数字识别MNIST数据集
import tensorflow as tf # 输入数据 from tensorflow.examples.tutorials.mnist import input_data mnist = in ...
- Tensorflow手写数字识别---MNIST
MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000
- Tensorflow笔记——神经网络图像识别(五)手写数字识别
随机推荐
- Python [拷贝copy / 深度拷贝deepcopy] | 可视化理解
Python 是一门面向对象的语言, 在Python一切皆对象. 每一个对象都有由以下三个属性组成: ------------------------------------------------- ...
- SpringBoot之ApplicationContextInitializer的理解和使用
一. ApplicationContextInitializer 介绍 首先看spring官网的介绍: 翻译一下: 用于在spring容器刷新之前初始化Spring ConfigurableAppli ...
- 【CF1023B】Pair of Toys(解方程)
题意:给定n个玩具要你选出两个玩具求出k的价值,第i个玩具的价值为i.若是没有选择方案,输出0 补充:玩具A与玩具B 和 玩具B和玩具A 是同一种选择 n,k<=1e14 思路:列出式子,解不等 ...
- ubuntu登入死循环问题 解决!!
把/etc/environment文件中的 PATH="/usr/local//sbin:/usr/local/bin:/usr/bin:/sbin:/bin:/usr/games" ...
- 三读bootmem【转】
转自:http://blog.csdn.net/lights_joy/article/details/2704788 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[-] 11 ...
- OpenOPC
客户端连接OpenOPC Gateway import OpenOPC gateway='192.168.1.90' opchost='testbox' opcserv='KEPware.KEPSer ...
- hdu 1077(单位圆覆盖问题)
Catching Fish Time Limit: 10000/5000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)To ...
- iOS-tableView上拉加载更多后,界面出现偏移
问题描述: 在做tableview的界面展示的时候,cell用自动计算高度的.但是在上拉加载更多的时候,数据请求完后,刷新界面,界面的顶部就出现了偏移 分析: 查阅资料后发现,当tableView的c ...
- 移动端自动化测试(一)appium环境搭建
自动化测试有主要有两个分类,接口自动化和ui自动化,ui自动化呢又分移动端的和web端的,当然还有c/s架构的,这种桌面程序应用的自动化,使用QTP,只不过现在没人做了. web自动化呢,现在基本上都 ...
- 洛谷——P2386 放苹果
P2386 放苹果 题目背景 (poj1664) 题目描述 把M个同样的苹果放在N个同样的盘子里,允许有的盘子空着不放,问共有多少种不同的分发(5,1,1和1,1,5是同一种方法) 输入输出格式 输入 ...