简介

在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数据集进行训练。

如果对keras不是很熟悉的话,可以去看一看官方文档。或者看一看我前面的博客:数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST,在数据挖掘入门系列教程(十一)这篇博客中使用了keras构建一个DNN网络,并对keras的做了一个入门使用介绍。

CIFAR-10数据集

CIFAR-10数据集是图像的集合,通常用于训练机器学习和计算机视觉算法。它是机器学习研究中使用比较广的数据集之一。CIFAR-10数据集包含10 种不同类别的共6w张32x32彩色图像。10个不同的类别分别代表飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,轮船 和卡车。每个类别有6,000张图像

在keras恰好提供了这些数据集。加载数据集的代码如下所示:

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print(x_train.shape, 'x_train samples')
print(x_test.shape, 'x_test samples')
print(y_train.shape, 'y_trian samples')
print(y_test.shape, 'Y_test samples')

输出结果如下:

训练集有5w张图片,测试集有1w张图片。在\(x\)数据集中,图片是\((32,32,3)\),代表图片的大小是\(32 \times 32\),为3通道(R,G,B)的图片。

展示图片内容

我们可以稍微的展示一下图片的内容,python代码如下所示:

import matplotlib.pyplot as plt
%matplotlib inline plt.figure(figsize=(12,10))
x, y = 8, 6 for i in range(x*y):
plt.subplot(y, x, i+1)
plt.imshow(x_train[i],interpolation='nearest')
plt.show()

下面就是数据集中的部分图片:

数据集变换

同样,我们需要将类标签进行one-hot编码:

import keras
# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

实际上这一步还有很多牛逼(骚)操作,比如说对数据集进行增强,变换等等,这样都可以在一定程度上提高模型的鲁棒性,防止过拟合。这里我们就怎么简单怎么来,就只对数据集标签进行one-hot编码就行了。

构建CNN网络

构建的网络模型代码如下所示:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,Conv2D, MaxPooling2D # 构建CNN网络
model = Sequential() # 添加卷积层
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
# 添加激活层
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu')) # 添加最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) # 将上一层输出的数据变成一维
model.add(Flatten())
# 添加全连接层
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax')) # 网络模型的介绍
print(model.summary())

这里解释一下代码:

Conv2D

Conv2D代表2D的卷积层,可能这里会有人问,我的图片不是3通道(RGB)的吗?为什么使用的是Conv2D而不是Conv3D。首先先说明,在Conv2D中的这个“2”代表的是卷积层可以在两个维度(也就是width,length)进行移动。那么同理Conv3D中的“3”代表这个卷积层可以在3个维度进行移动(比如说视频中的width ,length,time)。那么针对RGB这种3通道(channels),卷积过程中输入有多少个通道,则滤波器(卷积核)就有多少个通道。

简单点来说就是:

输入

单色图片的input,是2D, \(w \times h\)

彩色图片的input,是3D,\(w \times h \times channels\)

卷积核filter

单色图片的filter,是2D, \(w \times h\)

彩色图片的filter,是3D, \(w \times h \times channels\)

值得注意的是,卷积之后的结果是二维的。(因为会将3维卷积得到的结果进行相加)

接着继续解释Conv2D的参数:

Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:])

  • 32表示的是输出空间的维度(也就是filter滤波器的输出数量)
  • (3,3)代表的是卷积核的大小
  • strides(这里没有用到):这个代表是滑动的步长。
  • input_shape:输入的维度,这里是(28,28,3)

padding在上一篇博客介绍过,在keras中有两个取值:"valid""same" (大小写敏感)。

  • valid padding:不进行任何处理,只使用原始图像,不允许卷积核超出原始图像边界
  • same padding:进行填充,允许卷积核超出原始图像边界,并使得卷积后结果的大小与原来的一致

Flatten

Flatten这一层就是为了将多维数据变成一维数据:

构建网络

from keras.optimizers import RMSprop
# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']
)

其他的参数在上两篇博客中已经讲了,就不再赘述。

进行训练评估

这里大家可以根据自己的电脑配置适当调整一下batch_size的大小。

history = model.fit(x_train, y_train,
batch_size=32,
epochs=64,
verbose=1,
validation_data=(x_test, y_test)
)

在i5-10代u,mx250的情况下,训练一轮大概需要27s左右。

训练完成之后,进行评估:

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

结果如下所示:

这个结果可以说的上是一言难尽,

数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10的更多相关文章

  1. 数据挖掘入门系列教程(二)之分类问题OneR算法

    数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...

  2. 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

    简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...

  3. 数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例)

    数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例) 简介 scikit-learn 估计器 加载数据集 进行fit训练 设置参数 预处理 流水线 结尾 数据挖掘入门系 ...

  4. 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST

    目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...

  5. 数据挖掘入门系列教程(九)之基于sklearn的SVM使用

    目录 介绍 基于SVM对MINIST数据集进行分类 使用SVM SVM分析垃圾邮件 加载数据集 分词 构建词云 构建数据集 进行训练 交叉验证 炼丹术 总结 参考 介绍 在上一篇博客:数据挖掘入门系列 ...

  6. CRL快速开发框架系列教程十二(MongoDB支持)

    本系列目录 CRL快速开发框架系列教程一(Code First数据表不需再关心) CRL快速开发框架系列教程二(基于Lambda表达式查询) CRL快速开发框架系列教程三(更新数据) CRL快速开发框 ...

  7. webpack4 系列教程(十二):处理第三方JavaScript库

    教程所示图片使用的是 github 仓库图片,网速过慢的朋友请移步<webpack4 系列教程(十二):处理第三方 JavaScript 库>原文地址.或者来我的小站看更多内容:godbm ...

  8. 数据挖掘入门系列教程(四)之基于scikit-lean实现决策树

    目录 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理Iris 加载数据集 数据特征 训练 随机森林 调参工程师 结尾 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理 ...

  9. 数据挖掘入门系列教程(四点五)之Apriori算法

    目录 数据挖掘入门系列教程(四点五)之Apriori算法 频繁(项集)数据的评判标准 Apriori 算法流程 结尾 数据挖掘入门系列教程(四点五)之Apriori算法 Apriori(先验)算法关联 ...

随机推荐

  1. 1642: 【USACO】Payback(还债)

    1642: [USACO]Payback(还债) 时间限制: 1 Sec 内存限制: 64 MB 提交: 190 解决: 95 [提交] [状态] [讨论版] [命题人:外部导入] 题目描述 &quo ...

  2. Oracle创建函数例子

    编写一个函数计算学生某一门课程在班级内的排名. 表结构如下: create or replace function fun_score_rank( p_in_stuid in number,--学号 ...

  3. js中的位置属性

    原生js中位置信息 clientLeft,clientTop:表示内容区域的左上角相对于整个元素左上角的位置(包括边框),实测,clientLeft=左侧边框的宽度,clientTop=顶部边框的宽度 ...

  4. JavaScript基本数据类型及其转换规则

    ECMAScript 数据类型 ECMAScript中有五种基本数据类型:Undefined, Null, Boolean, Number, String 一种复杂数据类型:Object 数据类型检测 ...

  5. 提示要安装Python-OpenSSL

    PyOpenSSL是OpenSSL的python接口,用于提供加密传输支持(SSL),如果没用该模组,会导致goagent无法生成证书而影响使用. 若系统没有openssl,先安装openssl,** ...

  6. Scratch 第2课淘气男孩儿

    素材及视频下载 链接:https://pan.baidu.com/s/1qX0T2B_zczcLaCCpiRrsnA提取码:xfp8

  7. Apache Hudi 设计与架构最强解读

    感谢 Apache Hudi contributor:王祥虎 翻译&供稿. 欢迎关注微信公众号:ApacheHudi 本文将介绍Apache Hudi的基本概念.设计以及总体基础架构. 1.简 ...

  8. Python操作rabbitmq系列(四):根据类型订阅消息

    在上一章中,所有的接收端获取的所有的消息.这一章,我们将讨论,一些消息,仍然发送给所有接收端.其中,某个接收端,只对其中某些消息感兴趣,它只想接收这一部分消息.如下图:C1,只对error感兴趣,C2 ...

  9. SpringMVC框架详细教程(六)_HelloWorld

    HelloWorld 在src下创建包com.pudding.controller,然后创建一个类HelloWorldController: package com.pudding.controlle ...

  10. SaaS、PaaS、IaaS的含义与区别

    先上个图,直观的了解一下 云计算有SPI,即SaaS.PaaS和IaaS三大服务模式. PaaS和IaaS源于SaaS SaaS Software as a Service 软件即服务,提供给客户的服 ...