Demo侠可能是我等小白进阶的必经之路了,如今在AI领域,我也是个研究Demo的小白。用了两三天装好环境,跑通Demo,自学Python语法,进而研究这个Demo。当然过程中查了很多资料,充分发挥了小白的主观能动性,总算有一些收获需要总结下。

  不多说,算法在代码中,一切也都在代码中。

  1. import os
  2. os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
  3.  
  4. #获得数据集
  5. from tensorflow.examples.tutorials.mnist import input_data
  6. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  7.  
  8. import tensorflow as tf
  9.  
  10. #输入图像数据占位符
  11. x = tf.placeholder(tf.float32, [None, 784])
  12.  
  13. #权值和偏差
  14. W = tf.Variable(tf.zeros([784, 10]))
  15. b = tf.Variable(tf.zeros([10]))
  16.  
  17. #使用softmax模型
  18. y = tf.nn.softmax(tf.matmul(x, W) + b)
  19.  
  20. #代价函数占位符
  21. y_ = tf.placeholder(tf.float32, [None, 10])
  22.  
  23. #交叉熵评估代价
  24. cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
  25.  
  26. #使用梯度下降算法优化:学习速率为0.5
  27. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  28.  
  29. #Session(交互方式)
  30. sess = tf.InteractiveSession()
  31.  
  32. #初始化变量
  33. tf.global_variables_initializer().run()
  34.  
  35. #训练模型,训练1000次
  36. for _ in range(1000):
  37. batch_xs, batch_ys = mnist.train.next_batch(100)
  38. sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  39.  
  40. #计算正确率
  41. correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
  42.  
  43. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  44. print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

  看完这个Demo,顿时感觉Python真是一门好语言,Tensorflow是一个好框架,就跟之前掌握Matlab以后,用Matlab做仿真的感觉一样。

  为什么看这几行代码看了两三天,因为看懂很容易,但了解代码背后的意义更重要,如果把一个Demo看透了,那么后边举一反三就会很容易了,我向来就是这样学习的,本小白当年也是个学霸?!

  来一起看下这里边有什么玄机和坑吧,记录一下,人老了记性不好(^-^)。

  看到1,2行代码,不要懵,这个作用是设置日志级别,os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error,等于1是显示所有信息。不加这两行会有个提示(Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2,具体可以看这里) 

  第5行是一个引用声明,从tensorflow.examples.tutorials.mnist 引用一个名为 input_data 的函数,可以看一下input_data是什么样子的:

  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4.  
  5. import gzip
  6. import os
  7. import tempfile
  8.  
  9. import numpy
  10. from six.moves import urllib
  11. from six.moves import xrange # pylint: disable=redefined-builtin
  12. import tensorflow as tf
  13. from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

  原来input_data里边也是引用声明,真正想用到的实际是tensorflow.contrib.learn.python.learn.datasets.mnist里的read_data_sets,看一下代码:

  1. def read_data_sets(train_dir,
  2. fake_data=False,
  3. one_hot=False,
  4. dtype=dtypes.float32,
  5. reshape=True,
  6. validation_size=5000,
  7. seed=None,
  8. source_url=DEFAULT_SOURCE_URL):
  9. if fake_data:
  10. ...
  11.  
  12. if not source_url: # empty string check
  13. ...
  14.  
  15. local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
  16. source_url + TRAIN_IMAGES)
  17. with gfile.Open(local_file, 'rb') as f:
  18. train_images = extract_images(f)
  19.  
  20. ...
  21.  
  22. if not 0 <= validation_size <= len(train_images):
  23. raise ValueError('Validation size should be between 0 and {}. Received: {}.'
  24. .format(len(train_images), validation_size))
  25.  
  26. validation_images = train_images[:validation_size]
  27. validation_labels = train_labels[:validation_size]
  28. train_images = train_images[validation_size:]
  29. train_labels = train_labels[validation_size:]
  30.  
  31. options = dict(dtype=dtype, reshape=reshape, seed=seed)
  32.  
  33. train = DataSet(train_images, train_labels, **options)
  34. validation = DataSet(validation_images, validation_labels, **options)
  35. test = DataSet(test_images, test_labels, **options)
  36.  
  37. return base.Datasets(train=train, validation=validation, test=test)

  mnist最终得到的是base.Datasets,完成了数据读取。这里边的细节还需要完了再仔细研究下。

  顺便记录下自编的函数的定义方法:

  1. def Mycollect(My , thing):
  2.  
  3. try:
  4. count = My[thing]
  5. except KeyError:
  6. count = 0
  7.  
  8. return count
  9.  
  10. from TestFunction import Mycollect
  11. My = {'a':10, 'b':15, 'c':5}
  12. thing = 'a'
  13. print(Mycollect(My , thing));

  第11行的placeholder,需要注意下,是用了占位符,也就是先安排位置,而不先提供具体数据,也就是说都是模型(管道)的构建过程(这里用管道来类比,我觉得比较恰当)。注意下placeholder的语法就可以,指定了type和shape,这里的None表示有多少幅图片是未知的,也就是说样本数是未知的。这里的坑在于,如果我们用print看的话会发现,构建的是张量(Tensor)而不是矩阵,这里对熟悉matlab的同学来说可能是个坑。可以注意下张量的定义方式。

  第14和15行是定义了变量,如果只看tf.zeros([10])的话也是个张量的,只是外边又加了变量的声明。所以后边可以直接乘的,这个也不难理解了。

  第18行的matmul是张量相乘,然后使用了softmax模型,目的是把结果进行概率化。巧妙,只想说这两个字,这个就是进行归一化,搞算法这个是比较常用的,学校时候这个词很火,我们最终想得到的是一个指定的数组,所以用这个模型来匹配我的规则。

  21行是什么,看完就知道是实际的输出,然后在24行做交叉熵。终于又碰到熵这个老朋友了。交叉熵简单理解为概率分布的距离,在这里作为一个loss_function。第27行使用了梯度下降来优化这个loss_function,最终是想找到最优时候的一个模型,这里的最优指的是通过这个模型,得到的结果和实际值最接近。

  第30行,创建一个session。

  第33行,初始化变量。

  第37行,可以去看下next_batch的源码,作用是选取100个样本来训练。

  第41行,注意equal函数的作用,第43行来做类型转换,然后取平均值。(代码很巧妙,很优雅,很爽)

  最终第44行输出模型的准确率。

  好了,这大概就是我的一点点总结了,算是入了个门,接下来我会更多的举一反三,深入掌握其精髓,我会努力走得更远。

  作为一个小白,我要继续努力向大牛学习,吃饭去咯,下周再战。

 

  

MNIST手写识别的更多相关文章

  1. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  2. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  3. win10下通过Anaconda安装TensorFlow-GPU1.3版本,并配置pycharm运行Mnist手写识别程序

    折腾了一天半终于装好了win10下的TensorFlow-GPU版,在这里做个记录. 准备安装包: visual studio 2015: Anaconda3-4.2.0-Windows-x86_64 ...

  4. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  5. 使用tensorflow实现mnist手写识别(单层神经网络实现)

    import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import n ...

  6. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

  7. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  8. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  9. Tensorflow实践:CNN实现MNIST手写识别模型

    前言 本文假设大家对CNN.softmax原理已经比较熟悉,着重点在于使用Tensorflow对CNN的简单实践上.所以不会对算法进行详细介绍,主要针对代码中所使用的一些函数定义与用法进行解释,并给出 ...

  10. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

随机推荐

  1. cannot import name '_imaging' 与No module named PIL解决方法

    今天学习廖雪峰的python 第三方模块pillow一章. 直接使用from PIL import Image 会报"No module named PIL",显然这是没有安装pi ...

  2. Application "org.eclipse.ui.ide.workbench" could not be found in the registry.问题的解决

    今天升级Eclipse,升级完Restart,碰到启动不了让看日志,日志里主要错误信息即是Application "org.eclipse.ui.ide.workbench" co ...

  3. Using SSH and SFTP in Mac OS X

    http://answers.stat.ucla.edu/groups/answers/wiki/7a848/ SH and SFTP are command line applications av ...

  4. java基本数据类型及其包装类

    1.String类 String s1 = "hello world"; String s2 = "hello world"; String s3 = s1 + ...

  5. 【Java入门提高篇】Day16 Java异常处理(下)

    今天继续讲解java中的异常处理机制,主要介绍Exception家族的主要成员,自定义异常,以及异常处理的正确姿势. Exception家族 一图胜千言,先来看一张图. Exception这是一个父类 ...

  6. java8完全解读一

    java8完全解读 java8完全解读前言java8的一些新特性1.为什么要用java8?1.1首先想到的逻辑应该是如下1.2使用策略模式来解这个问题1.3使用策略模式和内部类来解决问题1.4使用策略 ...

  7. QT窗体的小技巧

    1.界面透明 setWindowOpacity(0.8);//构造函数中加此句,1为不透明,0为完全透明,0.8为80%不透明. 2.设置背景图片 QPixmap pixmap = QPixmap(& ...

  8. webpack学习之路01

    webpack是什么 1.模块化 能将css等静态文件模块化 2.借助于插件和加载器 webpack优势是什么 1.代码分离 各做各的 2.装载器(css,sass,jsx,es6等等) 3.智能解析 ...

  9. django-团队简介的网页

    团队简介的网页,是使用Django完成的 关于Django的教程网址:http://www.runoob.com/django/django-tutorial.html 小组作业成果如下:

  10. tomcat 构建问题记录

    mvng构建程序包com.sun.image.codec.jpeg不存在------->缺少serlet的jar包 MasterSlaveRoutingDataSource不是抽象的, 并且未覆 ...