DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。

1、定义层

 conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();

首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。

其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。

2、定义网络

 mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';

定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。

net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});

以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。

最重要的是一定要initParams()!!!!这会生成初始参数。

3、定义数据输入函数

为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:

 function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;

其中 imdb 是image data base. 其中包括:

imdb.images.data  图片 W*H*C*N 的4-D  single Array

imdb.images.label  标签 N*1  的 single Array

imdb.images.data_mean 图片平均值 用于预处理时去中心

imdb.images.set    集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集

imdb.meta 存放类型名称等和训练关系不太密切的东西

4、开始训练

直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。

cnn_train_dag(mynet,imdb_sub,@getBatch);

  训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.

机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章

  1. 机器学习&深度学习基础(目录)

    从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...

  2. [转载]机器学习&深度学习经典资料汇总,全到让人震惊

    自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...

  3. 机器学习&深度学习资料

    机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...

  4. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  5. 机器学习&深度学习基础(机器学习基础的算法概述及代码)

    参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...

  6. 最全的机器学习&深度学习入门视频课程集

    资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd     链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...

  7. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  8. 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. 机器学习&深度学习基础(tensorflow版本实现的算法概述0)

    tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...

随机推荐

  1. IE内核浏览器的404页面问题和IE自动缓存引发的问题

    本站404页面被IE替换成IE自己的404页面 在权限设置正确的情况下,自定义的404页面文件大小如果小于512字节,那么IE内核的浏览器会认为你自定义的404页面不够权威,从而使用其自带的404页面 ...

  2. maven 学习

    最近有项目需要储备maven的技能,就学习了一下,找到了一个很适合入门的博客,这里记录下网址. https://www.cnblogs.com/whgk/p/7112560.html

  3. 出现明明SQL语句没问题,但是却无法通过代码查询到结果的问题。

    问题:SQL语句查询不到记录,导致空指针异常 SQL语句: select * from mixinfo where infotype='网站简介' 代码: publicList<HashMap& ...

  4. jQuery 学习(1)——认识jQuery

    1.下载 下载地址:http://jquery.com/download/ jquery-3.2.1.js——用于开发和学习(229K) jquery-3.2.1.min.js——用于项目和产品(31 ...

  5. 集群安装Java环境

    需要安装一个集群环境,发现全部要手动安装java.记录下安装Java环境的过程.虽然,依旧是挨个安装,但总算是有体系了. java 找到下载地址: https://www.oracle.com/tec ...

  6. WSDL测试webservice接口记录

    收到一个事情,需要对接第三方API,对方给了个service,看了一下,原来是webservices的. 上一次测试webervice的接口,还是至少八九年前的时候了,这种相对比较老旧的也好久不在使用 ...

  7. 【php】php5.0以上,instanceof 用法

    1.instanceof php官网:http://php.net/manual/zh/language.operators.type.php 2.instanceof 用于确定一个 PHP 变量是否 ...

  8. 使用h5py操作hdf5文件

    HDF(Hierarchical Data Format)指一种为存储和处理大容量科学数据设计的文件格式及相应库文件.HDF 最早由美国国家超级计算应用中心 NCSA 开发,目前在非盈利组织 HDF ...

  9. 【理论面试篇】收集整理来自网络上的一些常见的 经典前端、H5面试题 Web前端开发面试题

    ##2017.10.30收集 面试技巧 5.1 面试形式 1)        一般而言,小公司做笔试题:大公司面谈项目经验:做地图的一定考算法 2)        面试官喜欢什么样的人 ü  技术好. ...

  10. 【C++】C++中assert和ENDEGU预处理语句

    assert 断言语句是C++中的一种预处理宏语句,它能在程序运行时根据否定条件中断程序. C++中的assert()函数可以实现断言功能,在使用assert函数之前应该先引入<cassert& ...