人工智能

  人工智能(Artificial Intelligence,简称AI)一词最初是在1956年Dartmouth学会上提出的,从那以后,研究者们发展了众多理论和原理,人工智能的概念也随之扩展。由于人工智能的研究是高度技术性和专业的,各分支领域都是深入且各不相通的,因而涉及范围极广 。 人工智能的核心问题包括建构能够跟人类似甚至超越人类的推理、知识、学习、交流、感知、使用工具和操控机械的能力等,当前人工智能已经有了初步成果,甚至在一些影像识别、语言分析、棋类游戏等等单方面的能力达到了超越人类的水平 。

  人工智能的分支领域非常多,主要有演绎推理、知识表示、规划、学习、自然语言处理……等十多个分支领域,而以机器学习为代表的“学习”领域,是目前研究最广泛的分支之一。

机器学习

   机器学习(Machine Learning)是人工智能的一个分支,它是实现人工智能的一个途径,即以机器学习为手段解决人工智能中的问题。机器学习在近30多年已发展为一门多领域交叉性的学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。

   机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法,该算法是一类从数据中自动分析获得规律,并利用规律对未知数据进行预测的算法。

深度学习

  深度学习(Deep Learning)是机器学习的分支,是一种以人工神经网络为架构,对数据进行表征学习的算法。表征学习的目标是寻求更好的表示方法并创建更好的模型来从大规模未标记数据中学习这些表示方法。表示方法来自神经科学,并松散地创建在类似神经系统中的信息处理和对通信模式的理解上,如神经编码,试图定义拉动神经元的反应之间的关系以及大脑中的神经元的电活动之间的关系。

  因此,人工智能、机器学习、深度学习的关系如下图所示。

  至今已有数种深度学习模型,如深度神经网络、卷积神经网络和深度置信网络和递归神经网络已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

  目前,业内也已经产生了多种优秀的深度学习框架,例如TensorFlow、PyTorch、Caffe、Mxnet等等。但这些都不是本文讨论的重点,本文主要以机器学习初学者的身份,使用最基本的机器学习算法,加以微积分、线性代数、概率统计等基础数学知识,来解决手写数字的识别的问题。

问题背景

  为什么要去研究数字的识别问题呢?因为最近刚过双11,又看了到许多曝光快递行业野蛮分拣的新闻。据某快递公司负责人回应称,之所以会出现野蛮分拣的问题,主要是双11期间快递数量巨增,为了尽快派发收到的快递,他们不得不请了许多“临时工”,而这些“临时工”缺乏培训,缺乏规范操作,所以出现了暴力分拣快递的问题。

问题分析

  针对分拣快递这种简单、重复的工作,可以交给机器去做吗?我们来分析一下“临时工”所做的工作。“临时工”拿到一个快递,找到快递上的快递单,然后再找到目的省(城市),如果是华北城市,则将快递扔进“北京”的筐里;如果是华东城市,则扔进“上海”的筐里;如果是西南城市,则扔进“成都”的筐里。那么机器可以完成吗?

  如上图所示,快递通过传送带进入“目的地识别系统”,识别后将该快递分配到对应地点的传送带即可。那么,该问题的关键就是“目的地识别系统”如何将快递单上的目的地识别出来,并作出正确的判断。

  以顺丰速运的快递单为例,快递单上已将目的地翻译为了城市代码,通过识别该代码即可让机器“知道”快递的目的地,然后配置对应传送带接收的具体城市代码即可。

  识别数字技术在深度学习领域已经非常成熟,常见的解决方案是OpenCV+Keras+TensorFlow,例如Github上有比较完善的车牌识别项目,但本文并不打算使用这些库,而是采用最底层、最基础的机器学习方法来实现。

数学建模

  数字照片通过扫描后,以像素点的方式进行存储,因此输入数据即是像素点,通过机器学习算法后,结果则是识别出来的0-9的数字。根据机器学习理论,每个样本都有对应的标签,因此属于“监督式学习”的范畴。而样本的输出值为0-9固定的10种情况,因此可以采用逻辑回归的机器学习模型来解决,分别计算结果为0-9的概率,建模就是找到一个假设函数(Hypothesis Function),函数值是数据通过假设函数后获得对应的输出结果,即概率。

  逻辑回归的假设函数是由S型函数(Sigmoid Function)演变而来,S型函数的表达式及曲线如下图所示:

  从曲线中可以看到,当变量z趋近于正无穷时,函数值趋近于1,当变量z趋近于负无穷时,函数值趋近于0。这样就能够很好的匹配逻辑回归,因为逻辑回归的输出为0或1,当输出值为0.7时,则表示结果为1的概率是70%,为0的概率是30%,正好可以进行概率的预测。

  受线性回归所启发,逻辑回归的假设函数公式为(其中θ为模型的参数矩阵,x为输入变量矩阵,变量z变成了θ的转置乘以x):

  如果要让机器来识别数字,那么首先就要先用样本去教会机器,即用样本“训练”模型。为了获得“最好”的模型,我们需要计算样本在模型下的代价函数(Cost Function,也有资料称为“损失函数”)。所谓代价函数,就是在该模型下产生的输出与实际结果间产生的偏差,偏差越小,则可以在一定程度上表明模型越好(也不是绝对的,可能会出现模型过度拟合(Overfit)的情况,需要一些手段来避免)。

  通过概率统计理论中的“最大似然估计”,可以得到如下的逻辑回归的代价函数:

  该函数看起来很复杂,可以将其拆开来看,log(h(x))是y=1时的代价函数,log(1-h(x))是y=0时的代价函数,最右边的一项为正则化参数,可以减小出现过度拟合的几率。为了找到最好的模型(假设函数),我们需要找到该代价函数的最小值。找到最小值后,自变量θ即为我们要找的逻辑回归的模型参数。

  根据高等数学中的“拉格朗日中值定理”,可以得知该函数为凹函数,存在最小值。证明过程比较复杂,不在此阐述。

模型训练

  为了得到代价函数J(θ)的最小值,我们可以采用机器学习中最常用的“梯度下降”算法(Gradient Decent)来求得函数在区间内的极小值。所谓梯度下降算法,就是对于任一函数,首先取任一点(x1或者x2均可),在这一点减去这一点对应的梯度(即该点的导数),那么这一点就会向该函数的某一极小值运动,反复进行梯度下降,则可以得到区间内的极小值x0。如果函数为凹函数,那么该极小值就是函数的最小值。

  因此,执行梯度下降的公式为:

  这里需要对J(θ)求“偏导数”,求得后的结果为:

  至此,理论工作准备完毕,可以进行编码实战。

在Matlab/Octave中训练

  输入样本为手写数字,以20 * 20像素点的形式存储,将像素点数据摊开作为一行,每行就有400个像素点信息。训练样本中搜集了5000个手写数字的照片,因此样本X为5000 * 400的矩阵,样本结果y为5000 * 1的列向量。

% 计算代价函数
J = 1 / m * (-y' * log(sigmoid(X * theta)) - (1 - y)' * log(1 - sigmoid(X * theta))); % 计算梯度
grad = 1 / m * X' * (sigmoid(X * theta) - y); % 代价函数正则化
J = J + lambda / (2 * m) * (sum(theta(2:end) .^ 2)); % 梯度正则化
theta_temp = theta;
theta_temp(1) = 0;
grad = grad + lambda / m * theta_temp;

  通过以上代码,就可以实现一次代价函数的计算,并返回当前点的梯度。根据之前的分析,只要重复进行梯度下降即可。而Matlab提供了一种更加简便的方式“fmincg”函数,它能采用类似梯度下降的方式,来自动优化参数θ。

% 初始化theta
initial_theta = zeros(n + 1, 1);
% 参数
options = optimset('GradObj', 'on', 'MaxIter', 50); % 循环所有数字
for c = 1:num_labels
% 训练出最优theta
theta = fmincg(@(t)(lrCostFunction(t, X, (y == c), lambda)), initial_theta, options);
all_theta(c, :) = theta;
endfor

  可以看到,上述代码进行了1-10总共10个模型的训练,每个模型就是识别0-9这10个数字的概率。通过以上训练后,就可以得到最后的theta,将其带入假设函数hθ(x),于是就得到了我们训练后的10个模型,可以用该模型来进行手写数字的识别。

数字识别

  利用已经训练好的10个模型,我们就可以将机器从未见过的手写数字通过10个模型,让每个模型计算出他是对应数字的概率,然后我们取最高的概率,就可以得到机器识别出的数字。我们来举个例子:

  如图红框中的数字,可能有的人会认成“4”,而有的人却会认成“6”。到底是4还是6呢?可能众说纷纭,因为有的人习惯这样写4,而有的人却不习惯这样写。在机器学习中,机器会学习之前样本中的数据,学习到作者写数字的习惯,将该测试样本分别输入到10个模型后,得到如下的概率输出(均保留5位有效数字):

数字 概率 数字 概率
0 0.0000021328% 5 0.0000015184%
1 0.0000021719% 6 99.987%
2 2.3224% 7 0.000033508%
3 0.0000012768% 8 0.0011023%
4 0.013391% 9 0.032251%

  通过如上数据可以看到,数字6的匹配度高达99.987%占据了绝对领先,第二则是数字2的2.3224%。而数字4只有0.013391的概率,看来在机器学习看来,这个数字基本可以判定为“6”,只是稍微有一丁点像“2”,跟其他数字都特别不像。我们取概率最大值,得出了正确的结果为数字“6”。

  我们用测试样本的真实值对模型进行校验,最终获得训练的正确率为94.9%。那么取100个测试样本的识别结果如何呢?

  可以看到,100个测试样本的识别结果有5个数字识别错误,测试识别率为95%。那么有什么方法可以提高识别率呢?

提高梯度下降次数

  通过之前的理论分析我们知道,梯度下降次数越多,代价函数就越接近最小值,于是我把次数从50提高到100时,测试样本准确率达到了95.98,提升了约1%。然后又提高到200时,达到了96.4%,提升了约0.5%。最后提高到500时,仍为96.4%,没有提升。

  看来在逻辑回归模型下,手写数字的识别率最高仅可以提升到96.4%,已经达到了最高。还有其他办法可以提升识别率吗?

神经网络

  人工神经网络(Artificial Neural Network),简称神经网络(Neural Network,NN),在机器学习和认知科学领域,是一种模仿生物神经网络(动物的中枢神经系统,特别是大脑)的结构和功能的数学模型或计算模型,用于对函数进行估计或近似。

  神经网络由大量的人工神经元联结进行计算,大多数情况下人工神经网络能在外界信息的基础上改变内部结构,是一种自适应系统,通俗的讲就是具备学习功能,并且是一种非线性统计性数据建模工具。

  不得不说,人类是真的聪明,居然可以想到建立类似于生物大脑神经的模型来模拟大脑,从而实现部分人类的能力。神经网络模型如下:

  可以看到,基本的神经网络模型有输入层、隐藏层、输出层。输入层用于接受输入信号,类似于人类感知视觉信号、声音信号、触觉信号等等。隐藏层可以是多层,可以让数据在不同层之间传递与处理,类似于人类的神经元,可以逐级传递。输出层用于输出处理后的数据。如果有非常多的隐藏层,又可以称为深度神经网络,在这种模型下的机器学习又称作深度学习。

  由于神经网络是一种非线性模型,属于逻辑结构,因此没有简单的“假设函数”。要计算数据通过输入层、隐藏层后到输出层的数据,可以通过“正向传播算法”(Forward Propagation)。

正向传播

  神经网络模型看似复杂,如果只看一层的话,就可以用逻辑回归的模型来推导,因为每一层都是逻辑回归问题。若θ1与θ2已知(图中标识),那么就可以用逻辑回归来计算每一层的输出,然后逐渐从左到右,正向传递,所以称为正向传播,最终得出输出值hθ(x)。

  正向传播的步骤如下:

  假设函数有了,如果我们能找到模型的θ1与θ2,那么模型就有了,就可以用这个模型来进行数字识别了。那么怎么才能找到合适的θ1与θ2呢?与之前讲的逻辑回归类似,我们也可以先找到该模型的代价函数,然后通过梯度下降找到代价函数的最小值,就可以找到神经网络的参数了。

代价函数

  前面已经提到,神经网络模型其实就是有很多层的逻辑回归模型,那么代价函数也可以采用逻辑回归的代价函数,然后将每一层网络叠加起来就可以了,所以神经网络的代价函数如下:

  公式看起来比较吓人,实际上只是多了网络层数K,并且参数θ从向量变成了矩阵而已。如果把这个公式转化为矩阵形式,其实非常的简单(不含最右边的正则化):

% m为样本数,y为样本结果矩阵,h为由正向传播计算出的输出矩阵。
J = - 1 / m * (sum(sum(y .* log(h))) + sum(sum((1 - y) .* log(1-h))));

  代价函数有了,梯度该怎么算呢?与逻辑回归不同,因为增加了层数的概念,所以梯度计算也会变得比较复杂,神经网络里称为“反向传播算法”(Backward Propagation)。

反向传播

  求解梯度,最终还是对代价函数进行求偏导数,但由于模型是非线性的,无法直接求偏导数。所以,反向传播的基本思路就是将最终的计算偏差分摊到每一层,逐渐从右向左,反向传递,所以称为反向传播。

  反向传播的步骤如下:

  通过第4步,我们就得到了J(θ)的偏导数,即梯度。

随机初始化

  在神经网络模型中,θ的初始化非常重要。我在训练神经网络的实践中,就因为忘了随机初始化θ,而导致模型非常糟糕,无论怎么训练,识别率都只有40%。吃一堑,长一智。神经网络不像逻辑回归中的θ,逻辑回归中的θ初始为0或者其他任何数都可以,神经网络中θ的初始化值直接影响了模型的好坏。

  θ的随机初始化的要求如下:

在Matlab/Octave中训练

  正向传播:

% 计算h(x)
a1 = [ones(m, 1) X]';
z2 = Theta1 * a1;
a2 = [ones(1, m); sigmoid(z2)];
z3 = Theta2 * a2;
a3 = sigmoid(z3);
h = a3';

  计算代价函数:

J = - 1 / m * (sum(sum(y2 .* log(h))) + sum(sum((1 - y2) .* log(1-h))));
% 计算正则化后的J
% 去掉theta的第一列,即去掉theta0
Theta1_new = Theta1(:, 2:end);
Theta2_new = Theta2(:, 2:end);
J = J + (sum(sum(Theta1_new .^ 2)) + sum(sum(Theta2_new .^ 2))) * lambda / (2 * m);

  反向传播:

% 将y2转化为每一列为一个样本
delta3 = a3 - y2';
Theta2_grad = delta3 * a2' / m;
% Theta2需要使用去掉了第一列针对偏置的权重
delta2 = Theta2_new' * delta3 .* sigmoidGradient(z2);
Theta1_grad = delta2 * a1' / m; % 正则化,需要将theta的第一列设置为0
Theta1(:, 1) = 0;
Theta2(:, 1) = 0;
Theta2_grad = Theta2_grad + lambda / m * Theta2;
Theta1_grad = Theta1_grad + lambda / m * Theta1;

  梯度下降:

options = optimset('MaxIter', 50);
lambda = 1;
costFunction = @(p) nnCostFunction(p, ...
input_layer_size, ...
hidden_layer_size, ...
num_labels, X, y, lambda);
[nn_params, cost] = fmincg(costFunction, initial_nn_params, options);

  然后用同样的训练样本对神经网络模型训练后,循环50次,测试样本识别率达到了95.82%。提高到100次,达到98.14%。提高到200次,达到了98.94%。最后提高到500次,最终达到了99.36%。可以看到,神经网络模型对于手写数字识别率高于逻辑回归模型。

总结

  1. 对于手写数字,神经网络模型一般比逻辑回归模型的准确率更高。
  2. 逻辑回归模型只能处理二维的输出,如果是高维的输出,需要用多个模型来降维,而神经网络可以直接处理多维输出。
  3. 使用简单的(非深度)神经网络,就可以实现较高的手写数字识别率。
  4. 神经网络模型的初始化参数非常重要。
  5. 在Maltab/Octave中,矩阵的运算效率要远远高于循环的运算效率,因此数据处理尽量采用矩阵形式。

  (由于水平有限,本文如有分析得不对之处,还请指正。)

使用AI算法进行手写数字识别的更多相关文章

  1. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  2. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  3. KNN分类算法实现手写数字识别

    需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...

  4. 实验楼 1. k-近邻算法实现手写数字识别系统--《机器学习实战 》

    首先看看一些关键词:K-NN算法,训练集,测试集,特征(空间),标签 举实验楼中的样例,通俗的讲讲K-NN算法:电影有两个分类(标签)-动作片-爱情片.两个特征--打斗场面--亲吻画面. 将那些数字和 ...

  5. KNN算法案例--手写数字识别

    import numpy as np import matplotlib .pyplot as plt import pandas as pd from sklearn.neighbors impor ...

  6. CNN:人工智能之神经网络算法进阶优化,六种不同优化算法实现手写数字识别逐步提高,应用案例自动驾驶之捕捉并识别周围车牌号—Jason niu

    import mnist_loader from network3 import Network from network3 import ConvPoolLayer, FullyConnectedL ...

  7. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  8. 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

    一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...

  9. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

随机推荐

  1. Python list 遇到的问题

    1.list“+” 运算 <list += > diff. <ndarray +=> list1 += list2是追加,而不是加法运算 list1 = [0,0,0] lis ...

  2. Qt 串口通信 高速发送出错的解决方法总结

    使用网上的qextserialport-1.2类,自行开发多线程串口通信.开发的过程中,出现两个问题:   问题1:我用信号槽跨线程调用串口类MyCom 发送和接收数据,中间运行的时候,会内存错误,Q ...

  3. C# 7 .NET / CLR / Visual Studio version requirements

    C# 7 .NET / CLR / Visual Studio version requirements   You do NOT need to target .NET 4.6 and above, ...

  4. 数据包从物理网卡流经 Open vSwitch 进入 OpenStack 云主机的流程

    目录 文章目录 目录 前言 数据包从物理网卡进入虚拟机的流程 物理网卡处理 如何将网卡收到的数据写入到内核内存? 中断下半部分软中断处理 数据包在内核态 OvS Bridge(Datapath)中的处 ...

  5. idea使用generatorconfig生成

    在maven工程中的resource中创建generatorConfigxml配置generatorConfigxml的配置pomxml生成对象的两种方式方式一使用idea的maven插件直接快速生成 ...

  6. [CDH] Redis: Remote Dictionary Server

    基本概念 一.安装 Redis: Remote Dictionary Server 远程字典服务 使用ANSI C语言编写.支持网络.可基于内存亦可持久化的日志型.Key-Value数据库,并提供多种 ...

  7. R语言与概率统计(二) 假设检验

    > ####################5.2 > X<-c(159, 280, 101, 212, 224, 379, 179, 264, + 222, 362, 168, 2 ...

  8. Git(3):分支管理

    Git 分支管理 几乎每一种版本控制系统都以某种形式支持分支.使用分支意味着你可以从开发主线上分离开来,然后在不影响主线的同时继续工作. 创建分支命令 $git branch <branch n ...

  9. robotframework 接口测试 +RSA 加密

    首先,实现RSA加密,需要用到pycrypto这个库,这个库又依赖openssl,所以需要先下载openssl,具体教程可以参考http://bbs.csdn.net/topics/392193545 ...

  10. 转:OPC协议解析-OPC UA OPC统一架构

    1    什么是OPC UA 为了应对标准化和跨平台的趋势,为了更好的推广OPC,OPC基金会近些年在之前OPC成功应用的基础上推出了一个新的OPC标准-OPC UA.OPC UA接口协议包含了之前的 ...