MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt mnist = input_data.read_data_sets('MNIST_data', one_hot=True) tf.reset_default_graph() x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10])) pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred) cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1)) learning_rate = 0.01 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) training_epochs = 25
batch_size = 100 display_step = 1 save_path = 'model/' saver = tf.train.Saver() with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels})) save = saver.save(sess, save_path=save_path+'mnist.cpkt') print(" starting 2nd session ...... ") with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt') correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval) im = batch_xs[0]
im = im.reshape(-1, 28) plt.imshow(im, cmap='gray')
plt.show() im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()

TensorFlow——MNIST手写数据集的更多相关文章

  1. TensorFlow实战第五课(MNIST手写数据集识别)

    Tensorflow实现softmax regression识别手写数字 MNIST手写数字识别可以形象的描述为机器学习领域中的hello world. MNIST是一个非常简单的机器视觉数据集.它由 ...

  2. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  3. matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...

  4. TensorFlow系列专题(六):实战项目Mnist手写数据集识别

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...

  5. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  6. 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握.新手入门可以考虑使用Keras框架达到快速实现的目的. 完整代码如下: # 1. 导入库和模块 fro ...

  7. MNIST手写数据集在运行中出现问题解决方案

    今天在运行手写数据集的过程中,出现一个问题,代码没有问题,但是运行的时候一直报错,错误如下: urllib.error.URLError: <urlopen error [SSL: CERTIF ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

随机推荐

  1. 2019-9-2-git-需要知道的1000个问题

    title author date CreateTime categories git 需要知道的1000个问题 lindexi 2019-09-02 10:12:31 +0800 2018-2-13 ...

  2. 深入java面向对象四:Java 内部类种类及使用解析(转)

    内部类Inner Class 将相关的类组织在一起,从而降低了命名空间的混乱. 一个内部类可以定义在另一个类里,可以定义在函数里,甚至可以作为一个表达式的一部分. Java中的内部类共分为四种: 静态 ...

  3. [转]Springboot和SpringMVC区别

    spring boot只是一个配置工具,整合工具,辅助工具. springmvc是框架,项目中实际运行的代码 Spring 框架就像一个家族,有众多衍生产品例如 boot.security.jpa等等 ...

  4. 随机抽样 (numpy.random)

    随机抽样 (numpy.random) 简单的随机数据 rand(d0, d1, ..., dn) 随机值 >>> np.random.rand(3,2) array([[ 0.14 ...

  5. 【31.42%】【CF 714A】Meeting of Old Friends

    time limit per test 1 second memory limit per test 256 megabytes input standard input output standar ...

  6. Linux 内核总线

    一个总线是处理器和一个或多个设备之间的通道. 为设备模型的目的, 所有的设备都通过 一个总线连接, 甚至当它是一个内部的虚拟的,"平台"总线. 总线可以插入另一个 - 一个 USB ...

  7. ZR提高失恋测4

    ZR提高失恋测4 比赛链接 A (方便讨论,设读入的串为\(S,T\)答案串为\(A\)) 首先\(*\)只会有一个 这是这道题目中非常重要的一个结论 简单证明一下? 因为\(*\)可以代表所有的字符 ...

  8. codeforces 1194F (组合数学)

    Codeforces 11194F (组合数学) 传送门:https://codeforces.com/contest/1194/problem/F 题意: 你有n个事件,你需要按照1~n的顺序完成这 ...

  9. boostrap-非常好用但是容易让人忽略的地方【7】:list-unstyled list-inline

    无样式列表 list-unstyled:去掉ul的默认样式 内联列表 list-inline:将ul子元素放置于同一行

  10. 创意app1

      app名称: 与我相似的人 app目的: 旨在通过云匹配,搜索到与自己类似爱好或者性格的人用户相似的内容:衣服品牌鞋子手机笔记本键盘鼠标相机刮胡刀自行车工作  说明: 现有的格局 百度贴吧是面向多 ...