TF.Learn 手写文字识别

 

转载请注明作者:梦里风林
Google Machine Learning Recipes 7
官方中文博客 - 视频地址
Github工程地址 https://github.com/ahangchen/GoogleML
欢迎Star,也欢迎到Issue区讨论

mnist问题

  • 计算机视觉领域的Hello world
  • 给定55000个图片,处理成28*28的二维矩阵,矩阵中每个值表示一个像素点的灰度,作为feature
  • 给定每张图片对应的字符,作为label,总共有10个label,是一个多分类问题

TensorFlow

  • 可以按教程用Docker安装,也可以直接在Linux上安装
  • 你可能会担心,不用Docker的话怎么开那个notebook呢?其实notebook就在主讲人的Github页
  • 可以用这个Chrome插件:npviewer直接在浏览器中阅读ipynb格式的文件,而不用在本地启动iPython notebook
  • 我们的教程在这里:ep7.ipynb
  • 把代码从ipython notebook中整理出来:tflearn_mnist.py

代码分析

  • 下载数据集
  1. mnist = learn.datasets.load_dataset('mnist')

恩,就是这么简单,一行代码下载解压mnist数据,每个img已经灰度化成长784的数组,每个label已经one-hot成长度10的数组

在我的深度学习笔记看One-hot是什么东西

  • numpy读取图像到内存,用于后续操作,包括训练集(只取前10000个)和验证集
  1. data = mnist.train.images
  2. labels = np.asarray(mnist.train.labels, dtype=np.int32)
  3. test_data = mnist.test.images
  4. test_labels = np.asarray(mnist.test.labels, dtype=np.int32)
  5. max_examples = 10000
  6. data = data[:max_examples]
  7. labels = labels[:max_examples]
  • 可视化图像
  1. def display(i):
  2. img = test_data[i]
  3. plt.title('Example %d. Label: %d' % (i, test_labels[i]))
  4. plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
  5. plt.show()

用matplotlib展示灰度图

  • 训练分类器
  • 提取特征(这里每个图的特征就是784个像素值)
  1. feature_columns = learn.infer_real_valued_columns_from_input(data)
  • 创建线性分类器并训练
  1. classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
  2. classifier.fit(data, labels, batch_size=100, steps=1000)

注意要制定n_classes为labels的数量

  • 分类器实际上是在根据每个feature判断每个label的可能性,
  • 不同的feature有的重要,有的不重要,所以需要设置不同的权重
  • 一开始权重都是随机的,在fit的过程中,实际上就是在调整权重

  • 最后可能性最高的label就会作为预测输出

  • 传入测试集,预测,评估分类效果

  1. result = classifier.evaluate(test_data, test_labels)
  2. print result["accuracy"]

速度非常快,而且准确率达到91.4%

可以只预测某张图,并查看预测是否跟实际图形一致

  1. # here's one it gets right
  2. print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0]))
  3. display(0)
  4. # and one it gets wrong
  5. print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8]))
  6. display(8)
  • 可视化权重以了解分类器的工作原理
  1. weights = classifier.weights_
  2. a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)

  • 这里展示了8个张图中,每个像素点(也就是feature)的weights,
  • 红色表示正的权重,蓝色表示负的权重
  • 作用越大的像素,它的颜色越深,也就是权重越大
  • 所以权重中红色部分几乎展示了正确的数字

Next steps

TF.Learn的更多相关文章

  1. Google机器学习笔记(七)TF.Learn 手写文字识别

    转载请注明作者:梦里风林 Google Machine Learning Recipes 7 官方中文博客 - 视频地址 Github工程地址 https://github.com/ahangchen ...

  2. 学习笔记TF043:TF.Learn 机器学习Estimator、DataFrame、监督器Monitors

    线性.逻辑回归.input_fn()建立简单两个特征列数据,用特证列API建立特征列.特征列传入LinearClassifier建立逻辑回归分类器,fit().evaluate()函数,get_var ...

  3. 学习笔记TF042:TF.Learn、分布式Estimator、深度学习Estimator

    TF.Learn,TensorFlow重要模块,各种类型深度学习及流行机器学习算法.TensorFlow官方Scikit Flow项目迁移,谷歌员工Illia Polosukhin.唐源发起.Scik ...

  4. TF.learn学习

    官网地址:https://www.tensorflow.org/versions/r1.1/get_started/tflearn 1.代码例子 实现自定义的Estimator 使用DNNClassi ...

  5. 学习笔记TF044:TF.Contrib组件、统计分布、Layer、性能分析器tfprof

    TF.Contrib,开源社区贡献,新功能,内外部测试,根据反馈意见改进性能,改善API友好度,API稳定后,移到TensorFlow核心模块.生产代码,以最新官方教程和API指南参考. 统计分布.T ...

  6. CNN网络介绍与实践:王者荣耀英雄图片识别

    欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者介绍:高成才,腾讯Android开发工程师,2016.4月校招加入腾讯,主要负责企鹅电竞推流SDK.企鹅电竞APP的功能开发和技术优化工作 ...

  7. TensorFlow与主流深度学习框架对比

    引言:AlphaGo在2017年年初化身Master,在弈城和野狐等平台上横扫中日韩围棋高手,取得60连胜,未尝败绩.AlphaGo背后神秘的推动力就是TensorFlow--Google于2015年 ...

  8. TensorFlow 中文资源全集,官方网站,安装教程,入门教程,实战项目,学习路径。

    Awesome-TensorFlow-Chinese TensorFlow 中文资源全集,学习路径推荐: 官方网站,初步了解. 安装教程,安装之后跑起来. 入门教程,简单的模型学习和运行. 实战项目, ...

  9. 第九章——运行tensorflow(Up and Running with TensorFlow)

    本章简单介绍了TensorFlow的安装以及使用.一些细节需要在后续的应用中慢慢把握. TensorFlow并不仅仅局限于神经网络和机器学习,它甚至可以用于量子物理仿真. TensorFlow的优势: ...

随机推荐

  1. 2014-07-24 .NET实现微信公众号的消息回复与自定义菜单

    今天是在吾索实习的第12天.我们在这一天中,基本实现了微信公众号的消息回复与自定义菜单的创建. 首先,是实现消息回复,其关键点如下: 读取POST来的数据流:Stream 数据流变量 = HttpCo ...

  2. hdu 1860 统计字符

    Problem Description 统计一个给定字符串中指定的字符出现的次数 Input 测试输入包含若干测试用例,每个测试用例包含2行,第1行为一个长度不超过5的字符串,第2行为一个长度不超过8 ...

  3. zabbix linux被监控端部署

    测试使用agentd监听获取数据. 服务端的安装可以查看http://blog.chinaunix.net/space.php?uid=25266990&do=blog&id=3380 ...

  4. [置顶] 白话最小边覆盖总结--附加 hdu1151结题报告

    刚开始看到这个题目的时候就觉得想法很明了,就是不知道如何去匹配... 去网上看了不少人的解题报告,但是对于刚接触“最小边覆盖”的我来说....还是很困难滴....于是自己又开始一如以往学习“最大独立集 ...

  5. s16_day01

    一.基础 1.编码 ascii-->GB2312-->GB18030-->GBK-->unicode-->UTF8可变长 2.数据类型 int,long,float,co ...

  6. zabbix流量汇聚

    "服务器流量汇总"领导提需求,要把几个数据中心的数据汇总起来,于是就google了一下"zabbix流量汇总" 按照其中一篇博客做了出来,博客地址如下. htt ...

  7. 海量数据挖掘--DB优化篇

    上一篇博客我们介绍了针对大数据量的处理,我们应该对程序做什么样的处理,但是一个程序的优化是有底线的,我们要考虑人力,物力,程序的优化是海量数据处理的一部分,这里介绍我们的重头戏,对数据库的优化! 这里 ...

  8. VB.NET入门基础

    众所周知,Visual Basic.NET是由Visual Basic发展而来,这两者之间的升级使得Visual Basic语言发生了革命性的变革,使得由基于对象编程的Visual Basic过渡到了 ...

  9. mvc之验证IEnumerable<T> 类型

    假设我们有这么一种需求,我们要同时添加年级和年级下面的多个班级,我们一般会像下面这种做法. Action中我们这样接收: [HttpPost] public ActionResult CreateGr ...

  10. 关于CCRect

    一直有一个误区,因为之前处理的公司引擎是屏幕坐标系 导致觉得CCRect的坐标起始值(x,y),习惯性的认为就是左上角的点. 但是,真正的x,y值,是跟x轴与y轴相对应的.