摘要: 本文是通过Keras实现深度学习入门项目——数字手写体识别,整个流程介绍比较详细,适合初学者上手实践。

对于图像分类任务而言,卷积神经网络(CNN)是目前最优的网络结构,没有之一。在面部识别、自动驾驶、物体检测等领域,CNN被广泛使用,并都取得了最优性能。对于绝大多数深度学习新手而言,数字手写体识别任务可能是第一个上手的项目,网络上也充斥着各种各样的成熟工具箱的相关代码,新手在利用相关工具箱跑一遍程序后就能立刻得到很好的结果,这时候获得的感受只有一个——深度学习真神奇,却没能真正了解整个算法的具体流程。本文将利用Keras和TensorFlow设计一个简单的二维卷积神经网络(CNN)模型,手把手教你用代码完成MNIST数字识别任务,便于理解深度学习的整个流程。

 

 
 
image.png

准备数据

模型使用的MNIST数据集,该数据集是目前最大的数字手写体数据集(0~9),总共包含60,000张训练图像和10,000张测试图像,每张图像的大小为28x28,灰度图。第一步是加载数据集,可以通过Keras API完成:

#源代码不能直接下载,在这里进行稍微修改,下载数据集后指定路径
#下载链接:https://pan.baidu.com/s/1jH6uFFC 密码: dw3d from __future__ import print_function
import keras
import numpy as np
path='./mnist.npz'
f = np.load(path)
X_train, y_train = f['x_train'], f['y_train']
X_test, y_test = f['x_test'], f['y_test']

上述代码中,X_train表示训练数据集,总共60,000张28x28大小的手写体图像,y_train表示训练图像对应的标签。同理,X_test表示测试数据集,总共10,000张28x28大小的手写体图像,y_test表示测试图像对应的标签。下面对数据集部分数据进行可视化,以便更好地了解构建的模型深度学习模型的目的。

import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(9):
plt.subplot(3,3,i+1)
plt.tight_layout()
plt.imshow(X_train[i], cmap='gray', interpolation='none')
plt.title("Digit: {}".format(y_train[i]))
plt.xticks([])
plt.yticks([])
fig
 

image.png

从图中可以看到,左上角是存储在训练集X_train[0]的手写体图像‘5’,y_train[0]表示对应的标签‘5’。整个深度学习模型的功能是训练好之后能够预测出别人手写的数字具体是什么。
对于神经网络而言,一般需要对原始数据进行预处理。常见的预处理方式是调整图像大小、对像素值进行归一化等。

# let's print the actual data shape before we reshape and normalize
print("X_train shape", X_train.shape)
print("y_train shape", y_train.shape)
print("X_test shape", X_test.shape)
print("y_test shape", y_test.shape) #input image size 28*28
img_rows , img_cols = 28, 28 #reshaping
#"channels_first" assumes (channels, conv_dim1, conv_dim2, conv_dim3).
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
#more reshaping
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape) #X_train shape: (60000, 28, 28, 1)

对图像信息进行必要的处理之后,标签数据y_train和y_test被转换为分类格式(向量形式),即标签‘3’被转换为向量[ 0,0,0,1,0,0,0,0,0,0]用于建模,标签向量非零的位置减一(从0开始)后表示该图像的具体标签,即若图像的标签向量在下标5处不为0,则表示该图像代表数字‘4’。

import keras
#set number of categories
num_category = 10
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_category)
y_test = keras.utils.to_categorical(y_test, num_category)

构建和编译模型

在数据准备好提供给模型后,需要定义模型的体系结构并使用必要的优化函数损失函数性能指标进行编译。
构建模型遵循的体系结构是经典卷积神经网络,分别含有2个卷积层,之后是连接全连接层和softmax分类器。如果你对每层的作用不熟悉的话,建议学习CS231课程
在最大池化层和全连接层之后,模型中引入dropout作为正则化来减少过拟合问题。

#导入相关层的结构
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras import backend as k
import matplotlib.pyplot as plt
import numpy as np ## model building
model = Sequential()
#convolutional layer with rectified linear unit activation
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=input_shape))
#32 convolution filters used each of size 3x3
#again
model.add(Conv2D(64, (3, 3), activation='relu'))
#64 convolution filters used each of size 3x3
#choose the best features via pooling
model.add(MaxPooling2D(pool_size=(2, 2)))
#randomly turn neurons on and off to improve convergence
model.add(Dropout(0.25))
#flatten since too many dimensions, we only want a classification output
model.add(Flatten())
#fully connected to get all relevant data
model.add(Dense(128, activation='relu'))
#one more dropout for convergence' sake :)
model.add(Dropout(0.5))
#output a softmax to squash the matrix into output probabilities
model.add(Dense(num_category, activation='softmax'))

模型搭建好之后,需要进行编译。在本文使用categorical_crossentropy多分类损失函数。由于所有的标签都具有相似的权重,因此将其作为性能指标,并使用AdaDelta梯度下降技术来优化模型参数。

#Adaptive learning rate (adaDelta) is a popular form of gradient descent rivaled only by adam and adagrad
#categorical ce since we have multiple classes (10)
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])

训练和评估模型

在定义和编译模型架构之后,需要使用训练数据对模型进行训练,以便能够识别手写数字。即使用X_train和y_train来拟合模型。

batch_size = 128
num_epoch = 10
#model training
model_log = model.fit(X_train, y_train,
batch_size=batch_size,
epochs=num_epoch,
verbose=1,
validation_data=(X_test, y_test))

Epoch表示对所有训练样本进行一个前向传播过程和一个反向传播过程,Batch_Size表示每次前向过程和反向过程时处理的训练样本数,训练输出如下所示:

 

image.png

模型训练好后需要评估其性能:

score = model.evaluate(X_test, y_test, verbose=0)
print('Test loss:', score[0]) #Test loss: 0.0296396646054
print('Test accuracy:', score[1]) #Test accuracy: 0.9904

可以看到,测试准确性高达99%+,这也意味着该模型对于预测训练得很好。对整个过程训练和测试过程进行可视化,即画出训练和测试的准确曲线与损失函数曲线,如下所示。从图中可以看到,随着训练迭代次数的增加,模型在训练和测试数据上的损失和准确性趋于一致,模型最终趋于稳定。

 

image.png

保存模型参数

模型训练好后需要保存训练好的参数,以便下次直接调用。模型的体系结构或结构将存储在json文件中,权重将以hdf5文件格式存储。

#Save the model
# serialize model to JSON
model_digit_json = model.to_json()
with open("model_digit.json", "w") as json_file:
json_file.write(model_digit_json)
# serialize weights to HDF5
model.save_weights("model_digit.h5")
print("Saved model to disk")

因此,保存好的模型可以之后进行重复使用或轻易地迁移到其他应用场景中。

作者信息

Sambit Mahapatra,人工智能和机器学习爱好者
本文由阿里云云栖社区组织翻译。
文章原标题《A simple 2D CNN for MNIST digit recognition》,译者:海棠,审校:[Uncle_LLD]
阅读原文
本文为云栖社区原创内容,未经允许不得转载。

入门项目数字手写体识别:使用Keras完成CNN模型搭建(重要)的更多相关文章

  1. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  2. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  3. keras入门--Mnist手写体识别

    介绍如何使用keras搭建一个多层感知机实现手写体识别及搭建一个神经网络最小的必备知识 import keras # 导入keras dir(keras) # 查看keras常用的模块 ['Input ...

  4. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  5. keras框架下的深度学习(一)手写体识别

    这个系列文章主要记录使用keras框架来搭建深度学习模型的学习过程,其中有一些自己的想法和体会,主要学习的书籍是:Deep Learning with Python,使用的IDE是pycharm. 在 ...

  6. 基于贝叶斯模型和KNN模型分别对手写体数字进行识别

    首先,我们准备了0~9的训练集和测试集,这些手写体全部经过像素转换,用0,1表示,有颜色的区域为0,没有颜色的区域为1.实现代码如下: # 图片处理 # 先将所有图片转为固定宽高,比如32*,然后再进 ...

  7. 数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10

    简介 在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数 ...

  8. Keras入门(四)之利用CNN模型轻松破解网站验证码

    项目简介   在之前的文章keras入门(三)搭建CNN模型破解网站验证码中,笔者介绍介绍了如何用Keras来搭建CNN模型来破解网站的验证码,其中验证码含有字母和数字.   让我们一起回顾一下那篇文 ...

  9. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

随机推荐

  1. drf框架,restful接口规范,源码分析

    复习 """ 1.vue如果控制html 在html中设置挂载点.导入vue.js环境.创建Vue对象与挂载点绑定 2.vue是渐进式js框架 3.vue指令 {{ }} ...

  2. WLC-Virtual Interface IP

    关于思科WLC,有很多接口类型,如下所示,这里主要针对Virtual IP记录一些最佳实践建议. 思科WLC的Virtual IP地址的作用: • Mobility management • DHCP ...

  3. Go并发介绍

    1. 进程.线程.协程 进程(Process),线程(Thread),协程(Coroutine,也叫轻量级线程) 进程 进程是一个程序在一个数据集中的一次动态执行过程,可以简单理解为“正在执行的程序” ...

  4. Hadoop3.1.1源码Client详解 : 写入准备-RPC调用与流的建立

    该系列总览: Hadoop3.1.1架构体系——设计原理阐述与Client源码图文详解 : 总览 关于RPC(Remote Procedure Call),如果没有概念,可以参考一下RMI(Remot ...

  5. 前后端分离之 跨域和JWT

    书接上回:https://www.cnblogs.com/yangyuanhu/p/12081525.html 前后端分离案例 现在把自己当成是前端,要开发一个前后分离的简单页面,用于展示学生信息列表 ...

  6. 安装和配置Linux系统虚拟机

    1.打开虚拟机软件 2.点击创建新的虚拟机,选择典型(推荐)类型的配置. 3.点击稍后安装操作系统. 4.客户机操作系统选择Linux,版本选择CentOS 7 64位. 5.虚拟机名称可自行更改,位 ...

  7. [转] C++ CImage实现的全屏PNG截图

    #include <atlimage.h> #include <atltime.h> #include <conio.h> //截取全屏保存为png CString ...

  8. IIS-URL重写参数

    参考:https://www.cnblogs.com/gggzly/p/5960335.html URL 重写规则由以下部分组成: 模式 - 可以理解为规则,分通配符和正则匹配     条件 - 可以 ...

  9. MYSQL双查询错误1

    一.基础知识 开始讲解MYSQL双查询错误之前,我们先了解一下双查询语句以及需要使用到的几个数据库函数和GROUP BY语句 1. 双查询语句 先了解一下什么是子查询,子查询就是嵌入第一层select ...

  10. python函数1_参数,返回值和嵌套

    函数 将重复的代码,封装到函数,只要使用直接找函数 函数可以增强代码的模块化和提高代码的重复利用率 函数的定义和调用 格式 def 函数名([参数,参数...]): 函数体 定义函数 import r ...