tensorflow MNIST新手教程
官方教程代码如下:
import gzip
import os
import tempfile import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets mnist = read_data_sets("MNIST_data/", one_hot=True) x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量
W = tf.Variable(tf.zeros([784,10])) #权重,初始化值为全零
b = tf.Variable(tf.zeros([10])) #偏置,初始化值为全零 #进行模型计算,y是预测,y_ 是实际
y = tf.nn.softmax(tf.matmul(x,W) + b) y_ = tf.placeholder("float", [None,10]) #计算交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#接下来使用BP算法来进行微调,以0.01的学习速率
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #上面设置好了模型,添加初始化创建变量的操作
init = tf.global_variables_initializer()
#启动创建的模型,并初始化变量
sess = tf.Session()
sess.run(init)
#开始训练模型,循环训练1000次
for i in range(1000):
#随机抓取训练数据中的100个批处理数据点
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x:batch_xs,y_:batch_ys}) ''''' 进行模型评估 ''' #判断预测标签和实际标签是否匹配
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#计算所学习到的模型在测试数据集上面的正确率
print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )
运行出现错误,不能导入数据,解决方案如下:
1..坑爹的GWF,严重阻碍人工智能的发展,与习大大十九大报告背道而驰,只能靠爱国青年曲线救国。方法为修改mnist.py文件(tensorflow.contrib.learn.python.learn.datasets.mnist),SOURCE_URL从 'https://storage.googleapis.com/cvdf-datasets/mnist/'改为 'http://yann.lecun.com/exdb/mnist/',如此就能运行了。
2.手动下载 http://yann.lecun.com/exdb/mnist/,处理代码如下,注意一定要把图片像素除以255,否则正确率只有0.098
import tensorflow as tf
import struct
import numpy as np with open('train-labels.idx1-ubyte','rb') as lb:
magic,n=struct.unpack('>II',lb.read(8))
labels = np.fromfile(lb,dtype=np.uint8) with open('train-images.idx3-ubyte','rb') as img:
magic,num,rows,cols=struct.unpack('>IIII',img.read(16))
images = np.fromfile(img,dtype=np.uint8).reshape(-1,784)
images = images.astype(np.float32)
images = np.multiply(images, 1.0 / 255.0) with open('t10k-labels.idx1-ubyte','rb') as lb:
magic,n=struct.unpack('>II',lb.read(8))
testlabels = np.fromfile(lb,dtype=np.uint8) with open('t10k-images.idx3-ubyte','rb') as img:
magic,num,rows,cols=struct.unpack('>IIII',img.read(16))
testimages = np.fromfile(img,dtype=np.uint8).reshape(-1,784)
testimages = testimages.astype(np.float32)
testimages = np.multiply(testimages, 1.0 / 255.0) x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) sess = tf.InteractiveSession()
tf.global_variables_initializer().run() labels = sess.run(tf.one_hot(labels,10))
testlabels = sess.run(tf.one_hot(testlabels,10)) for _ in range(1000):
a=np.random.permutation(np.arange(60000))
bx = images[a[:100]]
by = labels[a[:100]]
sess.run(train_step,feed_dict={x:bx,y_:by}) correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x:testimages, y_:testlabels}))
执行后结果约为0.92.
另外下列onehot方法非常好用:
def dense_to_one_hot(labels_dense, num_classes):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
tensorflow MNIST新手教程的更多相关文章
- Tensorflow的CNN教程解析
之前的博客我们已经对RNN模型有了个粗略的了解.作为一个时序性模型,RNN的强大不需要我在这里重复了.今天,让我们来看看除了RNN外另一个特殊的,同时也是广为人知的强大的神经网络模型,即CNN模型.今 ...
- TensorFlow MNIST(手写识别 softmax)实例运行
TensorFlow MNIST(手写识别 softmax)实例运行 首先要有编译环境,并且已经正确的编译安装,关于环境配置参考:http://www.cnblogs.com/dyufei/p/802 ...
- Mac tensorflow mnist实例
Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...
- Tensorflow 官方版教程中文版
2015年11月9日,Google发布人工智能系统TensorFlow并宣布开源,同日,极客学院组织在线TensorFlow中文文档翻译.一个月后,30章文档全部翻译校对完成,上线并提供电子书下载,该 ...
- Web项目的发布新手教程
ASP.NET服务器发布新手教程 ——本文仅赠予第一次做Web项目,需要发布的新手们,转载的请注明出处. 首先我们说一下我们的需要的一个环境.我使用的是Visual Studio 2010,版本.NE ...
- APP设计尺寸规范大全,APP界面设计新手教程【官方版】(转)
正值25学堂一周年之际,同时站长和APP设计同仁们在群里(APP界面设计 UI设计交流群,APP界面设计⑥群 APPUI设计③群58946771 APP设计资源⑤群 386032923欢迎大家加入交流 ...
- ROS探索总结(三)——ROS新手教程【转】
转自:http://blog.csdn.net/hcx25909/article/details/8811313 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[-] 一ROS的 ...
- 新手教程之使用Xib自定义UITableViewCell
新手教程之使用Xib自定义UITableViewCell 前言 首先:什么是UITableView?看图 其次:什么是cell? 然后:为什么要自定cell,UITableView不是自带的有cell ...
- MATLAB新手教程
MATLAB新手教程 .MATLAB的基本知识 1-1.基本运算与函数 在MATLAB下进行基本数学运算,仅仅需将运算式直接打入提示号(>>)之後,并按入Enter键就可以.比如 ...
随机推荐
- maximum shortest distance
maximum shortest distance Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/O ...
- java 操作格子问题(线段树)
很久之前做过线段树的问题(操作格子),时间长了之后再次接触到,发现当初理解的不是很透彻,然后代码冗长,再遇到的时候发现自己甚至不能独立地完成这个问题. 所以算法这个东西啊, 第一,是要经常练习(我个人 ...
- Java面试常会被问到的经典面试题,学习或者求职,你都要好好掌握
Java现在的热度虽然有所下降,但是,学Java的人依旧很多..Java的岗位也是渗透很多.那么,那些经典的Java知识点,你能看到问题就能说出一二三吗?来一起看看.. 1.JDK和JRE的区别 2. ...
- PHP静态化技术
很多框架的模板引擎都有页面静态化的功能 目的是为了优化网站运行时间 静态化分两种 纯静态和伪静态 一. 纯静态 纯静态展示的是实实在在的静态页面 运行PHP程序 判断是否存在静态页 如果存在 展示 ...
- 通讯框架 T-io 学习——给初学者的Demo:ShowCase设计分析
前言 最近闲暇时间研究Springboot,正好需要用到即时通讯部分了,虽然springboot 有websocket,但是我还是看中了 t-io框架.看了部分源代码和示例,先把helloworld敲 ...
- Spring MVC的配置与DispatcherServlet的分析
Spring MVC是一款Web MVC框架,是目前主流的Web MVC框架之一. Spring MVC工作原理简单来看如下图所示: 接下来进行Spring MVC的配置 首先我们配置Spring M ...
- 深入理解java虚拟机_第二章_读书笔记
1.本章内容目录: 概述 运行时数据区域 程序计数器 java虚拟机栈 本地方法栈 java堆 方法区 运行时常量池 直接内存 HotSpot虚拟机对象探秘 对象的创建 对象的内存布局 对象的访问定位 ...
- R语言高性能编程(二)
接着上一篇 一.减少内存使用的简单方法1.重用对象而不多占用内存 y <- x 是指新变量y指向包含X的那个内存块,只有当y被修改时才会复制到新的内存块,一般来说只要向量没有被其他对象引用,就可 ...
- python的Windows下的安装
1.先打开网址http://www.python.org/download/: 2.在下载列表中选择Window平台安装包, 找到最后 web-based installer 是需要通过联网完成安装的 ...
- php中向前台js中传送一个二维数组
在php中向前台js中传送一个二维数组,并在前台js接收获取其中值的全过程方法: (1),方法说明:现在后台将数组发送到前台 echo json_encode($result); 然后再在js页面中的 ...