四、经典入门demo:识别手写数字(MNIST)

常规的编程入门有“Hello world”程序,而深度学习的入门程序则是MNIST,一个识别28*28像素的图片中的手写数字的程序。

MNIST的数据和官网:
http://yann.lecun.com/exdb/mnist/

深度学习的内容,其背后会涉及比较多的数学原理,作为一个初学者,受限于我个人的数学和技术水平,也许并不足以准确讲述相关的数学原理,因此,本文会更多的关注“应用层面”,不对背后的数学原理进行展开,感谢谅解。

  1. 加载数据
    程序执行的第一步当然是加载数据,根据我们之前获得的数据集主要包括两部分:60000的训练数据集(mnist.train)和10000的测试数据集(mnist.test)。里面每一行,是一个2828=784的数组,数组的本质就是将2828像素的图片,转化成对应的像素点阵。
    例如手写字1的图片转换出来的对应矩阵表示如下:

    之前我们经常听说,图片方面的深度学习需要大量的计算能力,甚至需要采用昂贵、专业的GPU(Nvidia的GPU),从上述转化的案例我们就已经可以获得一些答案了。一张784像素的图片,对学习模型来说,就有784个特征,而我们实际的相片和图片动辄几十万、百万级别,则对应的基础特征数也是这个数量级,基于这样数量级的数组进行大规模运算,没有强大的计算能力支持,确实寸步难行。当然,这个入门的MNIST的demo还是可以比较快速的跑完。
    Demo中的关键代码(读取并且加载数据到数组对象中,方便后面使用):

  2. 构建模型
    MNIST的每一张图片都表示一个数字,从0到9。而模型最终期望获得的是:给定一张图片,获得代表每个数字的概率。比如说,模型可能推测一张数字9的图片代表数字9的概率是80%但是判断它是8的概率是5%(因为8和9都有上半部分的小圆),然后给予它代表其他数字的概率更小的值。

    MNIST的入门例子,采用的是softmax回归(softmax regression),softmax模型可以用来给不同的对象分配概率。
    为了得到一张给定图片属于某个特定数字类的证据(evidence),我们对图片的784个特征(点阵里的各个像素值)进行加权求和。如果某个特征(像素值)具有很强的证据说明这张图片不属于该类,那么相应的权重值为负数,相反如果某个特征(像素值)拥有有利的证据支持这张图片属于这个类,那么权重值是正数。类似前面提到的房价估算例子,对每一个像素点作出了一个权重分配。
    假设我们获得一张图片,需要计算它是8的概率,转化成数学公式则如下:

    公式中的i代表需要预测的数字(8),代表预测数字为8的情况下,784个特征的不同权重值,代表8的偏置量(bias),X则是该图片784个特征的值。通过上述计算,我们则可以获得证明该图片是8的证据(evidence)的总和,softmax函数可以把这些证据转换成概率 y。(softmax的数学原理,辛苦各位查询相关资料哈)
    将前面的过程概括成一张图(来自官方)则如下:

    不同的特征x和对应不同数字的权重进行相乘和求和,则获得在各个数字的分布概率,取概率最大的值,则认为是我们的图片预测结果。
    将上述过程写成一个等式,则如下:

    该等式在矩阵乘法里可以非常简单地表示,则等价为:

    不展开里面的具体数值,则可以简化为:

    如果我们对线性代数中矩阵相关内容有适当学习,其实,就会明白矩阵表达在一些问题上,更易于理解。如果对矩阵内容不太记得了,也没有关系,后面我会附加上线性代数的视频。
    虽然前面讲述了这么多,其实关键代码就四行:

    上述代码都是类似变量占位符,先设置好模型计算方式,在真实训练流程中,需要批量读取源数据,不断给它们填充数据,模型计算才会真实跑起来。tf.zeros则表示,先给它们统一赋值为0占位。X数据是从数据文件中读取的,而w、b是在训练过程中不断变化和更新的,y则是基于前面的数据进行计算得到。

  3. 损失函数和优化设置
    为了训练我们的模型,我们首先需要定义一个指标来衡量这个模型是好还是坏。这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。简单的说,就是我们需要最小化loss的值,loss的值越小,则我们的模型越逼近标签的真实结果。
    Demo中使用的损失函数是“交叉熵”(cross-entropy),它的公式如下:

    y 是我们预测的概率分布, y' 是实际的分布(我们输入的),交叉熵是用来衡量我们的预测结果的不准确性。TensorFlow拥有一张描述各个计算单元的图,也就是整个模型的计算流程,它可以自动地使用反向传播算法(backpropagation algorithm),来确定我们的权重等变量是如何影响我们想要最小化的那个loss值的。然后,TensorFlow会用我们设定好的优化算法来不断修改变量以降低loss值。
    其中,demo采用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使loss值不断降低的方向更新。
    对应的关键代码如下:

    备注内容:
    交叉熵:http://colah.github.io/posts/2015-09-Visual-Information/
    反向传播:http://colah.github.io/posts/2015-08-Backprop/

在代码中会看见one-hot vector的概念和变量名,其实这个是个非常简单的东西,就是设置一个10个元素的数组,其中只有一个是1,其他都是0,以此表示数字的标签结果。
例如表示数字3的标签值:
[0,0,0,1,0,0,0,0,0,0]

  1. 训练运算和模型准确度测试
    通过前面的实现,我们已经设置好了整个模型的计算“流程图”,它们都成为TensorFlow框架的一部分。于是,我们就可以启动我们的训练程序,下面的代码的含义是,循环训练我们的模型500次,每次批量取50个训练样本。

    其训练过程,其实就是TensorFlow框架的启动训练过程,在这个过程中,python批量地将数据交给底层库进行处理。
    我在官方的demo里追加了两行代码,每隔50次则额外计算一次当前模型的识别准确率。它并非必要的代码,仅仅用于方便观察整个模型的识别准确率逐步变化的过程。

    当然,里面涉及的accuracy(预测准确率)等变量,需要在前面的地方定义占位:

    当我们训练完毕,则到了验证我们的模型准确率的时候,和前面相同:

    我的demo跑出来的结果如下(softmax回归的例子运行速度还是比较快的),当前的准确率是0.9252:

  2. 实时查看参数的数值的方法
    刚开始跑官方的demo的时候,我们总想将相关变量的值打印出来看看,是怎样一种格式和状态。从demo的代码中,我们可以看见很多的Tensor变量对象,而实际上这些变量对象都是无法直接输出查看,粗略地理解,有些只是占位符,直接输出的话,会获得类似如下的一个对象:
    Tensor("Equal:0", shape=(?,), dtype=bool)
    既然它是占位符,那么我们就必须喂一些数据给它,它才能将真实内容展示出来。因此,正确的方法是,在打印时通常需要加上当前的输入数据给它。
    例如,查看y的概率数据:
    print(sess.run(y, feed_dict={x: batchxs, y: batch_ys}))
    部分非占位符的变量还可以这样输出来:
    print(W.eval())

总的来说,92%的识别准确率是比较令人失望,因此,官方的MNIST其实也有多种模型的不同版本,其中比较适合图片处理的CNN(卷积神经网络)的版本,可以获得99%以上的准确率,当然,它的执行耗时也是比较长的。
(备注:cnn_mnist.py就是卷积神经网络版本的,后面有附带微云网盘的下载url)
前馈神经网络(feed-forward neural network)版本的MNIST,可达到97%:

分享在微云上的数据和源码:
http://url.cn/44aZOpP
(备注:国外网站下载都比较慢,我这份下载相对会快一些,在环境已经搭建完毕的情况下,执行里面的run.py即可)

五、和业务场景结合的demo:预测用户是否是超级会员身份
根据前面的内容,我们对上述基于softmax只是三层(输入、处理、输出)的神经网络模型已经比较熟悉,那么,这个模型是否可以应用到我们具体的业务场景中,其中的难度大吗?为了验证这一点,我拿了一些现网的数据来做了这个试验。

  1. 数据准备

    我将一个现网的电影票活动的用户参与数据,包括点击过哪些按钮、手机平台、IP地址、参与时间等信息抓取了出来。其实这些数据当中是隐含了用户的身份信息的,例如,某些礼包的必须是超级会员身份才能领取,如果这个按钮用户点击领取成功,则可以证明该用户的身份肯定是超级会员身份。当然,我只是将这些不知道相不相关的数据特征直观的整理出来,作为我们的样本数据,然后对应的标签为超级会员身份。
    用于训练的样本数据格式如下:

    第一列是QQ号码,只做认知标识的,第二列表示是否超级会员身份,作为训练的标签值,后面的就是IP地址,平台标志位以及参与活动的参与记录(0是未成功参与,1表示成功参与)。则获得一个拥有11个特征的数组(经过一些转化和映射,将特别大的数变小):
    [0.9166666666666666, 0.4392156862745098, 0.984313725490196, 0.7411764705882353, 0.2196078431372549, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    对应的是否是超级数据格式如下,作为监督学习的标签:
    超级会员:[0, 1]
    非超级会员:[1, 0]

这里需要专门解释下,在实际应用中需要做数据转换的原因。一方面,将这些数据做一个映射转化,有助于简化数据模型。另一方面,是为了规避NaN的问题,当数值过大,在一些数学指数和除法的浮点数运算中,有可能得到一个无穷大的数值,或者其他溢出的情形,在Python里会变为NaN类型,这个类型会破坏掉后续全部计算结果,导致计算异常。
例如下图,就是特征数值过大,在训练过程中,导致中间某些参数累计越来越大,最终导致产生NaN值,后续的计算结果全部被破坏掉:

而导致NaN的原因在复杂的数学计算里,会产生无穷大或者无穷小。例如,在我们的这个demo中,产生NaN的原因,主要是因为softmax的计算导致。

RuntimeWarning: divide by zero encountered in log

刚开始做实际的业务应用,就发现经常跑出极奇怪异的结果(遇到NaN问题,我发现程序也能继续走下去),几经排查才发现是NAN值问题,是非常令人沮丧的。当然,经过仔细分析问题,发现也并非没有排查的方式。因为,NaN值是个奇特的类型,可以采用下述编码方式NaN != NaN来检测自己的训练过程中,是否出现的NaN。
关键程序代码如下:

我采用上述方法,非常顺利地找到自己的深度学习程序,在学习到哪一批数据时产生的NaN。因此,很多原始数据我们都会做一个除以某个值,让数值变小的操作。例如官方的MNIST也是这样做的,将256的像素颜色的数值统一除以255,让它们都变成一个小于1的浮点数。
MNIST在处理原始图片像素特征数据时,也对特征数据进行了变小处理:

NaN值问题一度深深地困扰着我(往事不堪回首-__-!!),特别放到这里,避免入门的同学踩坑。

  1. 执行结果
    我准备的训练集(6700)和测试集(1000)数据并不多,不过,超级会员身份的预测准确率最终可以达到87%。虽然,预测准确率是不高,这个可能和我的训练集数据比较少有关系,不过,整个模型也没有花费多少时间,从整理数据、编码、训练到最终跑出结果,只用了2个晚上的时间。

    下图是两个实际的测试例子,例如,该模型预测第一个QQ用户有82%的概率是非超级会员用户,17.9%的概率为超级会员用户(该预测是准确的)。

    通过上面的这个例子,我们会发觉其实对于某些比较简单的场景下应用,我们是可以比较容易就实现的。

六、其他模型

  1. CIFAR-10识别图片分类的demo(官方)
    CIFAR-10数据集的分类是机器学习中一个公开的基准测试问题,它任务是对一组32x32RGB的图像进行分类,这些图像涵盖了10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船和卡车。
    这也是官方的重要demo之一。

    更详细的介绍内容:
    http://www.cs.toronto.edu/~kriz/cifar.html
    http://tensorfly.cn/tfdoc/tutorials/deep_cnn.html

该例子执行的过程比较长,需要耐心等待。
我在机器上的执行过程和结果:
cifar10_train.py用于训练:

cifar10_eval.py用于检验结果:

识别率不高是因为该官方模型的识别率本来就不高:

另外,官方的例子我首次在1月5日跑的时候,还是有一些小问题的,无法跑起来(最新的官方可能已经修正),建议可以直接使用我放到微云上的版本(代码里面的log和读取文件的路径,需要调整一下)。
源码下载:http://url.cn/44mRzBh

微云盘里,不含训练集和测试集的图片数据,但是,程序如果检测到这些图片不存在,会自行下载:

  1. 是否大于5岁的测试demo
    为了检验softma回归模型是否能够学习到一些我自己设定好的规则,我做了一个小demo来测试。我通过随机数生成的方式构造了一系列的数据,让前面的softmax回归模型去学习,最终看看模型能否通过训练集的学习,最终100%预测这个样本数据是否大于5岁。
    模型和数据本身都比较简单,构造的数据的方式:
    我随机构造一个只有2个特征纬度的样本数据,[year, 1],其中year随机取值0-10,数字1是放进去作为干扰。
    如果year大于5岁,则标签设置为:[0, 0, 1];
    否则,标签设置为:[0, 1, 0]。

生成了6000条假训练集去训练该模型,最终它能做到100%成功预测准确:

微云下载(源码下载):
http://url.cn/44mKFNK

  1. 基于RNN的古诗学习
    最开头的AI写古诗,非常令人感到惊艳,那个demo是美国的一个研究者做出来的,能够根据主题生成不能的古诗,而且古诗的质量还比较高。于是,我也尝试在自己的机器上也跑一个能够写古诗的模型,后来我找到的是一个基于RNN的模型。RNN循环神经网络(Recurrent Neural Networks),是非常常用的深度学习模型之一。我基于一个外部的demo,进行一些调整后跑起一个能够学习古诗和写古诗的比较简单的程序。

执行写诗(让它写了十首):
抑滴留居潋罅斜,二川还羡五侯家。古刘称士身相染,桃李栽林欲称家。回首二毛相喘日,万当仙性尽甘无。如何羽马嘶来泪,不信红峰一寸西。
废寺松阴月似空,垂杨风起晚光催。乌心不把嫌香径,出定沧洲几好清。兰逐白头邻斧蝶,苍苍归路自清埃。渔樵若欲斜阳羡,桂苑西河碧朔来。
遥天花落甚巫山,凤珮飞驰不骋庄。翠初才象饮毫势,上月朱炉一重牛。香催戍渚同虚客,石势填楼取蕊红。佳句旧清箱畔意,剪颜相激菊花繁。
江上萧条第一取,名长经起月还游。数尺温皋云战远,放船乡鬼蘸云多。相逢槛上西风动,莫听风烟认钓鱼。堤费禽雏应昨梦,去朝从此满玄尘。
避命抛醺背暮时,见川谁哭梦知年。却随筵里腥消极,不遇嘉唐两带春。大岁秘魔窥石税,鹤成应听白云中。朝浮到岸鸱巇恨,不向青青听径长。
楚田馀绝宇氤氲,细雨洲头万里凉。百叶长看如不尽,水东春夜足残峰。湖头风浪斜暾鼓,北阙别罹初里村。山在四天三顾客,辘轳争养抵丹墀。
九日重门携手时,吟疑须渴辞金香。钓来犹绕结茶酒,衣上敬亭宁强烧。自明不肯疑恩日,琴馆寒霖急暮霜。划口濡于孤姹末,出谢空卿寄银机。莲龛不足厌丝屦,华骑敷砧出钓矶。
为到席中逢旧木,容华道路不能休。时闲客后多时石,暗水天边暖人说。风弄霜花嗥明镜,犀成磨逐乍牵肠。何劳相听真行侍,石石班场古政蹄。
听巾邑外见朱兰,杂时临厢北满香。门外玉坛花府古,香牌风出即升登。陵桥翠黛销仙妙,晓接红楼叠影闻。敢把苦谣金字表,应从科剑独频行。
昨日荣枯桃李庆,紫骝坚黠自何侵。险知河在皆降月,汉县烟波白发来。仍省封身明月阁,不知吹水洽谁非。更拟惭送风痕去,只怕鲸雏是后仙。

另外,我抽取其中一些个人认为写得比较好的诗句(以前跑出来的,不在上图中):

该模型比较简单,写诗的水平不如最前面我介绍的美国研究者demo,但是,所采用的基本方法应该是类似的,只是他做的更为复杂。
另外,这是一个通用模型,可以学习不同的内容(古诗、现代诗、宋词或者英文诗等),就可以生成对应的结果。

七、深度学习的入门学习体会

  1. 人工智能和深度学习技术并不神秘,更像是一个新型的工具,通过喂数据给它,然后,它能发现这些数据背后的规律,并为我们所用。
  2. 数学基础比较重要,这样有助于理解模型背后的数学原理,不过,从纯应用角度来说,并不一定需要完全掌握数学,也可以提前开始做一些尝试和学习。
  3. 我深深地感到计算资源非常缺乏,每次调整程序的参数或训练数据后,跑完一次训练集经常要很多个小时,部分场景不跑多一些训练集数据,看不出差别,例如写诗的案例。个人感觉,这个是制约AI发展的重要问题,它直接让程序的“调试”效率非常低下。
  4. 中文文档比较少,英文文档也不多,开源社区一直在快速更新,文档的内容过时也比较快。因此,入门学习时遇到的问题会比较多,并且缺乏成型的文档。

八、小结
我不知道人工智能的时代是否真的会来临,也不知道它将要走向何方,但是,毫无疑问,它是一种全新的技术思维模式。更好的探索和学习这种新技术,然后在业务应用场景寻求结合点,最终达到帮助我们的业务获得更好的成果,一直以来,就是我们工程师的核心宗旨。另一方面,对发展有重大推动作用的新技术,通常会快速的发展并且走向普及,就如同我们的编程一样,因此,人人都可以做深度学习应用,并非只是一句噱头。

参考文档:
http://www.tensorfly.cn/
https://www.tensorflow.org/

数学相关的内容:
高中和大学数学部分内容
线性代数视频

相关推荐:

人人都可以做深度学习应用:入门篇(上)

作者:腾讯QQ会员技术团队(代号:小时光茶社)


此文已由作者授权腾讯云技术社区发布,转载请注明文章出处,获取更多云计算技术干货,可请前往腾讯云技术社区

欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~

腾讯QQ会员技术团队:人人都可以做深度学习应用:入门篇(下)的更多相关文章

  1. 【腾讯Bugly干货分享】人人都可以做深度学习应用:入门篇

    导语 2016年,继虚拟现实(VR)之后,人工智能(AI)的概念全面进入大众的视野.谷歌,微软,IBM等科技巨头纷纷重点布局,AI 貌似将成为互联网的下一个风口. 很多开发同学,对人工智能非常感兴趣, ...

  2. 腾讯QQ会员技术团队:以手机QQ会员H5加速为例,为你揭开sonic技术内幕

    目前移动端越多越多的网页开始H5化,一方面可以减少安装包体积,另一方面也方便运营.但是相对于原生界面而言,H5的慢速问题一定被大家所诟病,针对这个问题,目前手Q存在几种方案,最常见的便是离线包方案,但 ...

  3. C#7.2——编写安全高效的C#代码 c# 中模拟一个模式匹配及匹配值抽取 走进 LINQ 的世界 移除Excel工作表密码保护小工具含C#源代码 腾讯QQ会员中心g_tk32算法【C#版】

    C#7.2——编写安全高效的C#代码 2018-11-07 18:59 by 沉睡的木木夕, 123 阅读, 0 评论, 收藏, 编辑 原文地址:https://docs.microsoft.com/ ...

  4. 腾讯QQ会员中心g_tk32算法【C#版】

    最近用C#写qq活动辅助类程序,碰到了会员签到的gtk算法不一样,后来网上找了看,发现有php版的(https://www.oschina.net/code/snippet_1378052_48831 ...

  5. 腾讯优图&港科大提出一种基于深度学习的非光流 HDR 成像方法

    目前最好的高动态范围(HDR)成像方法通常是先利用光流将输入图像对齐,随后再合成 HDR 图像.然而由于输入图像存在遮挡和较大运动,这种方法生成的图像仍然有很多缺陷.最近,腾讯优图和香港科技大学的研究 ...

  6. 当微信小程序遇到云开发,再加上一个类似 ColorUI 的模板,人人都能做小程序了

    作为一个 Java 程序员,早就想尝试一把微信小程序,但是一直苦于没有想法,再加上做一个漂亮的页面实在不太擅长. 由于自己比较喜欢历史,经常看历史方面的书.在一次梳理中国现有的朝代时,突然想到,要是可 ...

  7. QQ 腾讯QQ(简称“QQ”)是腾讯公司开发的一款基于Internet的即时通信(IM)软件

    QQ 编辑 腾讯QQ(简称“QQ”)是腾讯公司开发的一款基于Internet的即时通信(IM)软件.腾讯QQ支持在线聊天.视频通话.点对点断点续传文件.共享文件.网络硬盘.自定义面板.QQ邮箱等多种功 ...

  8. Kube-OVN:大型银行技术团队推荐的金融级云原生网络方案

    近日,由TWT社区主办的2021容器云职业技能大赛团队赛的冠军作品:<适用于大中型银行的云原生技术体系建设方案>中,Kube-OVN成为银行技术团队推荐的金融级云原生网络最佳实践.本文部分 ...

  9. 深度学习的异构加速技术(一):AI 需要一个多大的“心脏”?

    欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者:kevinxiaoyu,高级研究员,隶属腾讯TEG-架构平台部,主要研究方向为深度学习异构计算与硬件加速.FPGA云.高速视觉感知等方向 ...

随机推荐

  1. overflow:hidden 你所不知道的事

    overflow:hidden 你所不知道的事 overflow:hidden这个CSS样式是大家常用到的CSS样式,但是大多数人对这个样式的理解仅仅局限于隐藏溢出,而对于清除浮动这个含义不是很了解. ...

  2. linux下安装TensorFlow(centos)

    一.python安装 centos自带python2.7.5,这一步可以省略掉. 二.python-pip pip--python index package,累世linux的yum,安装管理pyth ...

  3. 大大维的贪吃蛇v1

    虽然本人一直是个免费的游戏测试员(/手动滑稽),但一直有着一个游戏架构师的梦想.正如马爸爸所说,梦想还是要有的,万一实现了呢? 这些天放寒假,有些空闲时间,就想着做一个简单的游戏机.能达到小时候十几块 ...

  4. cygwin 运行窗口程序

    首先, 默认安装的cygwin是不能运行窗口程序的 比如,一段python窗口程序: import * from tkinter Tk() mainloop() 如果使用命令行: python3 py ...

  5. WinForm 控件(下)

    10.PictureBox 外观,Image可以选择图片路径行为,SizeMode可以设置图片大小布局方式 11.Imagelist--图片集 imageList1.Images[n]; 12.not ...

  6. PL/SQL基本概念

    首先明确PL/SQL主要作用作用: SQL语言适合管理关系型数据库但是它无法满足更复杂的数据处理,所以产生PLSQL.PLSQL用户创建存储过程.函数.触发器.包及用户自定义的函数. 特点: PLSQ ...

  7. java socket tcp(服务器循环检测)

    刚才看了下以前写了的代码,tcp通信,发现太简单了,直接又摘抄了一个,运行 博客:http://blog.csdn.net/zhy_cheng/article/details/7819659 优点是服 ...

  8. 原生js实现轮播图

    原生js实现轮播图 很多网站上都有轮播图,但找到一个系统讲解的却很难,因此这里做一个简单的介绍,希望大家都能有所收获,如果有哪些不正确的地方,希望大家可以指出. 原理: 将一些图片在一行中平铺,然后计 ...

  9. weex官方demo weex-hackernews代码解读(下)

    weex 是阿里出品的一个类似RN的框架,可以使用前端技术来开发移动应用,实现一份代码支持H5,IOS和Android.而weex-hacknews则是weex官方出品的,首个使用 Weex 和 Vu ...

  10. IOS编程学习笔记

    @interface -实例对象 +类名 #import "MyClass" @implementation MyClass -(id)initWithString:(NSStri ...