DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。

1、定义层

 conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();

首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。

其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。

2、定义网络

 mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';

定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。

net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});

以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。

最重要的是一定要initParams()!!!!这会生成初始参数。

3、定义数据输入函数

为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:

 function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;

其中 imdb 是image data base. 其中包括:

imdb.images.data  图片 W*H*C*N 的4-D  single Array

imdb.images.label  标签 N*1  的 single Array

imdb.images.data_mean 图片平均值 用于预处理时去中心

imdb.images.set    集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集

imdb.meta 存放类型名称等和训练关系不太密切的东西

4、开始训练

直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。

cnn_train_dag(mynet,imdb_sub,@getBatch);

  训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.

机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章

  1. 机器学习&深度学习基础(目录)

    从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...

  2. [转载]机器学习&深度学习经典资料汇总,全到让人震惊

    自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...

  3. 机器学习&深度学习资料

    机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...

  4. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  5. 机器学习&深度学习基础(机器学习基础的算法概述及代码)

    参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...

  6. 最全的机器学习&深度学习入门视频课程集

    资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd     链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...

  7. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  8. 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. 机器学习&深度学习基础(tensorflow版本实现的算法概述0)

    tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...

随机推荐

  1. 【容斥+组合数】Massage @2018acm徐州邀请赛 E

    问题 E: Massage 时间限制: 1 Sec  内存限制: 64 MB 题目描述 JSZKC  feels  so  bored  in  the  classroom  that  he  w ...

  2. 【二分】Producing Snow @Codeforces Round #470 Div.2 C

    time limit per test: 1 second memory limit per test: 256 megabytes Alice likes snow a lot! Unfortuna ...

  3. python3 图片文字识别

    最近用到了图片文字识别这个功能,从网上搜查了一下,决定利用百度的文字识别接口.通过测试发现文字识别率还可以.下面就测试过程简要说明一下 1.注册用户 链接:https://login.bce.baid ...

  4. bootstrap-table方法之:合并单元格

    方法一 通过mergeCells方法 演示地址:http://issues.wenzhixin.net.cn/bootstrap-table/methods/mergeCells.html Merge ...

  5. 【一步步学OpenGL 20】 -《点光源》

    教程 20 点光源 原文: http://ogldev.atspace.co.uk/www/tutorial20/tutorial20.html CSDN完整版专栏: http://blog.csdn ...

  6. centos7.0安装cuda驱动

    00.CUDA简介 CUDA和GPU的并行处理能力来加速深度学习和其他计算密集型应用程序 01.CPU+GPU协同架构 02.部署环境 [docker@lab-250 ~]$ cat /etc/*re ...

  7. 转【微信小程序 四】二维码生成/扫描二维码

    原文:https://blog.csdn.net/xbw12138/article/details/75213274 前端 二维码生成 二维码要求:每分钟刷新一次,模拟了个鸡肋,添加了个按分钟显示的时 ...

  8. SQL 数据库结构化查询语言

    1.数据库 常见数据库 MySQL:开源免费的数据库,小型的数据库. Oracle:收费的大型数据库,Oracle 公司的产品 DB2:IBM 公司收费的数据库,常应用在银行系统中 SQLServer ...

  9. Altium Designer重装后图标都变白板或都变一样的解决方法

    https://blog.csdn.net/qq_41995282/article/details/80372113

  10. 使用 Node.js 搭建API 网关

    外部客户端访问微服务架构中的服务时,服务端会对认证和传输有一些常见的要求.API 网关提供共享层来处理服务协议之间的差异,并满足特定客户端(如桌面浏览器.移动设备和老系统)的要求. 微服务和消费者 微 ...