这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环。我也是!

这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟。在TensorFlow的中文介绍文档中的内容,有些可能与你使用的tensorflow的版本不一致了,我这里用到的tensorflow的版本就有这个问题。 另外,还给大家说下,例子中的MNIST所用到的资源图片,在原始的官网上,估计很多人都下载不到了。我也提供一下下载地址。

我的tensorflow的版本信息:

>>> import tensorflow as tf
>>> print tf.VERSION
1.0.
>>> print tf.GIT_VERSION
v1.0.0--g4763edf-dirty
>>> print tf.COMPILER_VERSION
4.8.

下面,就看看,我参考的中文tensorflow网站的代码,在自己的环境里,运行的结果。

 [root@bogon tensorflow]# python
Python 2.7. (default, Nov , ::)
[GCC 4.8. (Red Hat 4.8.-)] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow.examples.tutorials.mnist.input_data as input_data
>>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Traceback (most recent call last):
File "<stdin>", line , in <module>
File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line , in read_data_sets
SOURCE_URL + TRAIN_IMAGES)
File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line , in maybe_download
temp_file_name, _ = urlretrieve_with_retry(source_url)
File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line , in wrapped_fn
return fn(*args, **kwargs)
File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line , in urlretrieve_with_retry
return urllib.request.urlretrieve(url, filename)
File "/usr/lib64/python2.7/urllib.py", line , in urlretrieve
return _urlopener.retrieve(url, filename, reporthook, data)
File "/usr/lib64/python2.7/urllib.py", line , in retrieve
fp = self.open(url, data)
File "/usr/lib64/python2.7/urllib.py", line , in open
return self.open_unknown_proxy(proxy, fullurl, data)
File "/usr/lib64/python2.7/urllib.py", line , in open_unknown_proxy
raise IOError, ('url error', 'invalid proxy for %s' % type, proxy)
IOError: [Errno url error] invalid proxy for http: '10.90.1.101:8080'
>>>
>>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
28 Extracting MNIST_data/train-images-idx3-ubyte.gz
29 Extracting MNIST_data/train-labels-idx1-ubyte.gz
30 Extracting MNIST_data/t10k-images-idx3-ubyte.gz
31 Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
>>> import tensorflow as tf
>>> x = tf.placeholder(tf.float32, [None, ])
>>> W = tf.Variable(tf.zeros([,]))
>>> b = tf.Variable(tf.zeros([]))
>>> y = tf.nn.softmax(tf.matmul(x,W) + b)
>>> y_ = tf.placeholder("float", [None,])
>>> cross_entropy = -tf.reduce_sum(y_*tf.log(y))
>>> train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
40 >>> init = tf.initialize_all_variables()
WARNING:tensorflow:From <stdin>:: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after --.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
44 >>> init = tf.global_variables_initializer()
>>> sess = tf.Session()
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
>>> sess.run(init)
>>> for i in range():
... batch_xs, batch_ys = mnist.train.next_batch()
... sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
...
>>> correct_prediction = tf.equal(tf.argmax(y,), tf.argmax(y_,))
>>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
>>> print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
0.9088
>>>

上述日志,是我的测试全过程记录,上面反映的信息有如下几点:

1. 红色部分的错误,因为我本地机器是通过代理上网的,这个过程中,tensorflow会用urllib进行MNIST的图片资源的下载,由于网络问题,资源文件下载失败。

2. 都有哪些资源文件要下载呢?追踪日志中的文件/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py第211行前后:

def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=):
if fake_data: def fake():
return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) train = fake()
validation = fake()
test = fake()
return base.Datasets(train=train, validation=validation, test=test) TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
SOURCE_URL + TRAIN_IMAGES)
with open(local_file, 'rb') as f:
train_images = extract_images(f) local_file = base.maybe_download(TRAIN_LABELS, train_dir,
SOURCE_URL + TRAIN_LABELS)
with open(local_file, 'rb') as f:
train_labels = extract_labels(f, one_hot=one_hot) local_file = base.maybe_download(TEST_IMAGES, train_dir,
SOURCE_URL + TEST_IMAGES)
with open(local_file, 'rb') as f:
test_images = extract_images(f) local_file = base.maybe_download(TEST_LABELS, train_dir,
SOURCE_URL + TEST_LABELS)
with open(local_file, 'rb') as f:
test_labels = extract_labels(f, one_hot=one_hot) if not <= validation_size <= len(train_images):
raise ValueError(
'Validation size should be between 0 and {}. Received: {}.'
.format(len(train_images), validation_size)) validation_images = train_images[:validation_size]
validation_labels = train_labels[:validation_size]
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:] train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
validation = DataSet(validation_images,
validation_labels,
dtype=dtype,
reshape=reshape)
test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape) return base.Datasets(train=train, validation=validation, test=test)

看到上面红色的部分,就是这里需要下载的图片资源文件。这个,我的网络环境是下载不了的。我通过其他途径下载到了这里需要的资源。我将下载的图片资源,放在了我进入python时所在的路径下。虽然直接下载没有成功,但是在当前路径下还是创建了MNIST_data的目录的。如下图,红色圈目录就是程序创建的目录。我将下载的train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz放在MNIST_data目录了

然后,再次执行mnist = input_data.read_data_sets("MNIST_data/", one_hot=True),就ok了,不会报错。得到28-31行的输出信息。

3. 执行到第40行的代码时,爆出WARNING,提示用新的函数,按照提示信息,执行了第41行的代码,OK。说明版本兼容性,在tensorflow中需要注意

4. 执行后,得到结果,如60行显示,识别率为0.9088。

关于MNIST的这个例子的手写识别性能的理论,不是本博文的重点,读者可以参照MNIST相关的文章自行学习。

最后,附上MNIST这个例子中,用到的资源图片下载地址,点击进行下载。(说明:需要积分才能下载的,谅解)

基于tensorflow的MNIST手写识别的更多相关文章

  1. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  4. 使用tensorflow实现mnist手写识别(单层神经网络实现)

    import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import n ...

  5. 基于TensorFlow的MNIST手写数字识别-深入

    构建多层卷积神经网络时需要多组W和偏移项b,我们封装2个方法来产生W和b 初级MNIST中用0初始化W和b,这里用噪声初始化进行对称打破,防止产生梯度0,同时用一个小的正值来初始化b避免dead ne ...

  6. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  9. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

随机推荐

  1. 使用python绘出常见函数

    '''''' ''' mpl.rcParams['font.sans-serif'] = ['SimHei'] mpl.rcParams['axes.unicode_minus'] = False用来 ...

  2. react native 之 获取键盘高度

    多说不如多撸: /** * Created by shaotingzhou on 2017/2/23. *//** * Sample React Native App * https://github ...

  3. 后台返回json字符串 页面js报错 Uncaught SyntaxError: Unexpected identifier

    后台json字符串是 [{"name": "报销申请", "id": "start"}, {"name&quo ...

  4. phpcms pc_base::load

    //数据库pc_base::load_model(‘*_model’) 加载数据库模型 (一张表对应一个数据库模型类),即 modle/*_model.class.php每一个数据库模型类都会继承底层 ...

  5. MySQL篇,第四章:数据库知识4

    MySQL 数据库 4 数据备份(在Linux终端操作) 1.命令格式 mysqldump -u用户名 -p 源库名 > 路径/XXX.sql 2.源库名的表示方式 --all-database ...

  6. 《网页文档/文字复制方法大全》 - imsoft.cnblogs

    <网页文档/文字复制方法大全> 一: 1.首先,找到自己要的文档. 2.文章题目复制,在搜索引擎的框框里输入:site:wenku.baidu.com "题目"/sit ...

  7. QT学习相关

    1. vs2012的编译器对execution_character_set("utf-8")无反应的bug在vs2013中解决 2. 安装上vs2013后,重装的qt插件,发现不能 ...

  8. 20155208实验二 Java面向对象程序设计

    20155208实验二 Java面向对象程序设计 一.实验内容 1.初步掌握单元测试和TDD 2.理解并掌握面向对象三要素:封装.继承.多态 3.初步掌握UML建模 4.熟悉S.O.L.I.D原则 5 ...

  9. Beta周第7次Scrum会议(11/16)【王者荣耀交流协会】

    一.小组信息 队名:王者荣耀交流协会 小组成员 队长:高远博 成员:王超,袁玥,任思佳,王磊,王玉玲,冉华 小组照片 二.开会信息 时间:2017/11/16 17:03~17:17,总计14min. ...

  10. oracle 数据库备份、还原、和使用心得(表丢失、视图丢失的解决办法)

    一.oracle数据备份:exp 关键字     说明(默认值)                  关键字      说明(默认值) --------------------------------- ...