人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练

MXNet 是一个轻量级、可移植、灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 Apache 的孵化器项目。尽管现在已经有很多深度学习框架,包括 TensorFlow, Keras, Torch,以及 Caffe,但 Apache MXNet 因其对多 GPU 的分布式支持而越来越受欢迎。

环境准备
1.安装 Anaconda。Anaconda 是一个用于科学计算的 Python 发行版,提供了包管理与环境管理的功能。Anaconda 利用 conda 来进行 package 和 environment 的管理,并且已经包含了 Python 和相关的配套工具。

Anaconda3-4.4下载地址: https://repo.continuum.io/archive/Anaconda3-4.4.0-Windows-x86_64.exe

2.在 conda 下安装 pip,安装命令为‘conda install pip’

3.安装 OpenCV-python 库。OpenCV-python 是一个很强大的计算机视觉库,在这个项目中可以用于处理图像。使用‘pip install openvc-python’在 Anaconda 环境下安装 OpenCV。也可以从源文件进行编译(注意:conda 安装 opencv3.0 不能运行)。

4.安装 scikit learn,一个开源的 python 机器学习科学计算库,它将用于对数据进行预处理。安装命令为‘conda install scikit-learn’。

5.安装 Jupyter Notebook,安装命令为‘conda install jupyter notebook’。

6.安装 MXNet。安装命令为‘pip install mxnet’。

------------------
数据库

使用的数据库是德国交通标志识别基准,来自论文《德国交通标志识别基准:多类别分类竞赛》( J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. "The German Traffic Sign Recognition Benchmark: A multi-class classification competition." ),发表在 IEEE International Joint Conference on Neural Networks,2011。该数据集包含 39209 张训练样例和 12630 张测试样例,有 43 种不同的交通标志——停车标志,限速标志,各种警示标志以及其他标志。
数据库中的每张图像大小为 32×32,均为三通道彩色图。每幅图属于一种交通标志。图像种类标签由 0 到 42 的整数表示。

从一个 NumPy 阵列中下载数据,数据分为训练,验证和测试集。训练集包含 39209 张大小为 32×32,通道数为 3 的图像,所以 NumPy 阵列的维度为 39209×32×32×3。该项目中作者仅使用了训练集和验证集。作者将使用网上的真实图像来测试所构建的模型。X_train 存储图像,维度为 39209×32×32×3。Y_train 存储图像对应的类标,维度为 39209,包含 0-42 的整数,对应每张图的类标。

训练过程

1. 准备数据集
X_train 和 Y_train 组成了训练数据集。可以使用 scikit-learn 对训练数据集进行分割得到验证集,这样可以避免使用出现过的图片测试模型。代码如下:

2. 训练数据预处理
批训练
神经网络训练需要花费大量时间和内存。所以作者将数据分批训练,一批大小为 64. 不仅是为了让数据适应内存,而且它可以让 MXNet 尽量利用 GPU 的计算效率。
归一化
除此之外,图像的像素值也进行了归一化,可以使学习算法更快收敛。下面是对训练数据进行预处理的代码:

3. 构建深度网络
目前,对于图像识别这类处在探索研究热点的问题,学界已经设计了很多效果良好的网络结构。所以最好的方法是实现一个已经发表出来的网络结构,然后对其进行改进。基于 AlexNet 结构,构建了一个简化版的卷积神经网络。AlexNet 是 2012 年发表的一个经典网络,在当年取得了 ImageNet 的最好成绩。

网络共有 8 层,其中前 5 层是卷积层,后边 3 层是全连接层,在每一个卷积层中包含了激励函数 RELU 以及局部响应归一化(LRN)处理,然后再经过池化(max pooling),最后的一个全连接层的输出是具有 1000 个输出的 softmax 层,最后的优化目标是最大化平均的多元逻辑回归。
在此之后也有很多更优秀的网络结构被提出,例如 VGGNet 和 ResNet,大家可以选择更好的网络结构去实现。
由于 MXNet 的符号计算构架,该神经网络的代码十分简洁明了

4. 训练网络
训练 epoch 为 10,训练好的模型存在 JSON 文件中,并且可以通过测量训练和验证准确率来观测网络“学习”的情况。

5. 载入预训练模型
下面给出了加载第 10 个 epoch 模型(最终模型)的代码。由于将在单张图片上进行测试,所以批尺寸由 64 减到 1,数据维度也变成了 1×3×32×32。

测试过程
测试图像(32×32×3)样例:

从结果可以看出可能性最高的种类为停车标志,说明预测准确。如果需要对模型有一个更完整的衡量,还需要用测试数据库进行测试,得到最终的分类准确率。

总结
本文我们介绍了使用 MXNet 进行多目标分类任务的方法。使用 MXNet,在 AlexNet 的结构基础上构建了一个更为简单的卷积神经网络结构。网络由卷积层,激活函数层,池化层和全连接层组成,采用德国交通标志图像训练数据库对该网络进行训练,实验结果证明网络可以将交通标志进行正确的分类。介绍了如何使用 MXNet 对数据进行预处理,构建网络,以及如何加载预训练好的网络模型。可以看出,MXNet 因其在多 GPU 上进行并行训练的能力,以及网络模型构建简单灵活的特性,是一个十分优秀的深度学习框架。

人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练的更多相关文章

  1. Tensorflow 实战Google深度学习框架 第五章 5.2.1Minister数字识别 源代码

    import os import tab import tensorflow as tf print "tensorflow 5.2 " from tensorflow.examp ...

  2. 吴裕雄--天生自然python Google深度学习框架:TensorFlow实现神经网络

    http://playground.tensorflow.org/

  3. 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%

    源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...

  4. TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤

    一.TensorFlow实战Google深度学习框架学习 1.步骤: 1.定义神经网络的结构和前向传播的输出结果. 2.定义损失函数以及选择反向传播优化的算法. 3.生成会话(session)并且在训 ...

  5. TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习

    自学人工智能的第一天 "TensorFlow 是谷歌 2015 年开源的主流深度学习框架,目前已得到广泛应用.本书为 TensorFlow 入门参考书,旨在帮助读者以快速.有效的方式上手 T ...

  6. 转:TensorFlow和Caffe、MXNet、Keras等其他深度学习框架的对比

    http://geek.csdn.net/news/detail/138968 Google近日发布了TensorFlow 1.0候选版,这第一个稳定版将是深度学习框架发展中的里程碑的一步.自Tens ...

  7. [Tensorflow实战Google深度学习框架]笔记4

    本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...

  8. Reading | 《TensorFlow:实战Google深度学习框架》

    目录 三.TensorFlow入门 1. TensorFlow计算模型--计算图 I. 计算图的概念 II. 计算图的使用 2.TensorFlow数据类型--张量 I. 张量的概念 II. 张量的使 ...

  9. 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)

    学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...

随机推荐

  1. python-面向对象-07_继承

    继承 目标 单继承 多继承 面向对象三大特性 封装 根据 职责 将 属性 和 方法 封装 到一个抽象的 类 中 继承 实现代码的重用,相同的代码不需要重复的编写 多态 不同的对象调用相同的方法,产生不 ...

  2. 如何让帝国CMS7.2搜索模板支持动态标签调用

    帝国cms站内搜索一般不支持动态标签调用,如果要调用如何实现呢?修改两个地方就可以实现了.打开 /e/search/result/index.php 文件,找到(文件改了,不会调用也是徒劳!看看这个帝 ...

  3. 20170718 关于Mysql 安装于虚拟机Ubuntu中,内网中Windows系统无法访问

    -- 1. 前提Mysql 已经安装在Ubuntu中 -- 2. 防火墙已经关闭 命令确认防护墙状态 -- 3.问题如果Ubuntu是基于Docker容器的环境,是否需要把Docker做端口映射? 解 ...

  4. Python3学习之路~2.8 文件操作实现简单的shell sed替换功能

    程序:实现简单的shell sed替换功能 #实现简单的shell sed替换功能,保存为file_sed.py #打开命令行输入python file_sed.py 我 Alex,回车后会把文件中的 ...

  5. 解决npm ERR! Unexpected end of JSON input while parsing near的方法汇总

    参考链接:https://segmentfault.com/a/1190000015646531

  6. [django]python异步神器-celery

    python异步神器celery https://segmentfault.com/a/1190000007780963

  7. lua加载函数require和dofile

    lua加载函数require和dofile Lua提供高级的require函数来加载运行库.粗略的说require和dofile完成同样的功能但有两点不同: 1. require会搜索目录加载文件; ...

  8. Swagger Editor本地安装

    一:安装Node JS 二:下载源码swagger-editor源码,解压 下载地址:https://github.com/swagger-api/swagger-editor 三:在解压目录下运行进 ...

  9. jenkins集成sonar

    用于我的sonar已经在一台机器上搭建好了,但是每次都要人工去执行sonar-run,很麻烦,所以就想着集成到jenkins上,在jenkins上点点按钮就可以看sonar结果,所以很抱歉,本博客不设 ...

  10. 26-Python3 面向对象

    26-Python3 面向对象 ''' 面向对象技术简介 ''' ''' 类定义 ''' ''' 类对象 ''' class MyClass: i = 12345 def f(self): retur ...