本章已机器学习领域的Hello World任务----MNIST手写识别做为TensorFlow的开始。MNIST是一个非常简单的机器视觉数据集,是由几万张28像素*28像素的手写数字组成,这些图片只包含灰度值信息。

下面提取了784维的特征,也就是2828个点展开成一维的结果,所以训练数据是一个55000784的Tensor,label是一个55000*10的tensor。当我们处理多分类任务时,通常需要使用Softmax Regression模型。它的工作原理很简单,将可以判定为某类的特征相加,然后将这些特征转化为判定是这一类的概率。其本质就是多类别逻辑回归。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#从TensorFlow读取数据 print (mnist.train.images.shape,mnist.train.labels.shape)
print (mnist.test.images.shape,mnist.test.labels.shape)
print (mnist.validation.images.shape,mnist.validation.labels.shape) import tensorflow as tf
sess = tf.InteractiveSession()#创建一个session,之后的运算都在这个session里,不同session的数据和运算是相互独立的
x = tf.placeholder(tf.float32,[None,784])#输入数据的地方,第一个参数是数据类型,第二个是tensor的shape W = tf.Variable(tf.zeros([784,10]))#Variable是存储模型参数的,不同于存储数据的tensor一旦使用掉就消失,Variable在模型训练迭代过程中是持久化的。
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W)+b)#实现Softmax Regression算法 y_ = tf.placeholder(tf.float32,[None,10])#定义一个真实的label,与下面的结果做比较
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))#计算模型的loss train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#定义了损失函数之后,再定义一个优化算法,本代码使用SGD算法
tf.global_variables_initializer().run()#使用全局参数初始化器初始化参数 for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)#每次选择100条数据
train_step.run({x:batch_xs,y_:batch_ys})#选择好数据之后用SGD算法做迭代 correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#比较预测结果是否准确
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#转化成准确率
print (accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))#输出结果,正确率为91.57%

TensorFlow实现Softmax Regression识别手写数字的更多相关文章

  1. TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题

    出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...

  2. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  3. 使用TensorFlow的卷积神经网络识别手写数字(2)-训练篇

    import numpy as np import tensorflow as tf import matplotlib import matplotlib.pyplot as plt import ...

  4. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

  5. 使用TensorFlow的卷积神经网络识别手写数字(3)-识别篇

    from PIL import Image import numpy as np import tensorflow as tf import time bShowAccuracy = True # ...

  6. 使用TensorFlow的卷积神经网络识别手写数字(1)-预处理篇

    功能: 将文件夹下的20*20像素黑白图片,根据重心位置绘制到28*28图片上,然后保存.经过预处理的图片有利于数字的准确识别.参见MNIST对图片的要求. 此处可下载已处理好的图片: https:/ ...

  7. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  8. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  9. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

随机推荐

  1. [原][资料整理][osg]osgDB文件读取插件,工作机制,支持格式,自定义插件

    参考: osgPlugins相关 osg读取文件的原理(插件工作机制) 当使用osgDB读取文件时,会自动根据文件的扩展名来到插件目录中寻找相应的插件,来实现. 比如: osgviewer cow.o ...

  2. 算法习题---5.5集合栈计算机(Uva12096)*****

    一:题目 对于一个以集合为元素的栈,初始时栈为空. 输入的命令有如下几种: PUSH:将空集{}压栈 DUP:将栈顶元素复制一份压入栈中 UNION:先进行两次弹栈,将获得的集合A和B取并集,将结果压 ...

  3. VS2010配置OpenGL开发环境(转)

    OpenGL(Open Graphics Library)是一个跨编程语言.跨平台的专业图形程序接口.OpenGL是SGI公司开发的一套计算机图形处理系统,是图形硬件的软件接口,任何一个OpenGL应 ...

  4. networkx详细教程

    写在前面:城市计算研究中经常涉及到图论的相关知识,而且常常面对某些术语时,根本不知道在说什么.最近接触了NetworkX这个graph处理工具,发现这个工具已经解决绝大部分的图论问题(也许只是我自己认 ...

  5. Juniper总结

    Juniper的路由器分为两个部分——RE和PFE.不过貌似大部分路由器都分为这两个部分.... Routing Engine: 当密码授权通过之后,用户就进入了RoutingEngine中,在其中可 ...

  6. LVS-TUN模式

    TUN模式: 其实数据转发原理和上图是一样的,不过这个我个人认为主要是位于不同位置(不同机房):LB是通过隧道进行了信息传输,虽然增加了负载,可是因为地理位置不同的优势,还是可以参考的一种方案: 优点 ...

  7. [bzoj3829][Poi2014]FarmCraft_树形dp

    FarmCraft 题目链接:https://lydsy.com/JudgeOnline/problem.php?id=3829 数据范围:略. 题解: 因为每条边只能必须走两次,所以我们的路径一定是 ...

  8. LeetCode 637. 二叉树的层平均值(Average of Levels in Binary Tree)

    637. 二叉树的层平均值 637. Average of Levels in Binary Tree LeetCode637. Average of Levels in Binary Tree 题目 ...

  9. Java基础笔试练习(六)

    1.在Java中,一个类可同时定义许多同名的方法,这些方法的形式参数个数.类型或顺序各不相同,传回的值也可以不相同.这种面向对象程序的特性称为? A.隐藏 B.覆盖 C.重载 D.Java不支持此特性 ...

  10. (三)Spring Boot 官网文档学习之默认配置

    文章目录 继承 `spring-boot-starter-parent` 覆盖默认配置 启动器 原文地址:https://docs.spring.io/spring-boot/docs/2.1.3.R ...