DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。

1、定义层

 conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();

首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。

其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。

2、定义网络

 mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';

定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。

net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});

以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。

最重要的是一定要initParams()!!!!这会生成初始参数。

3、定义数据输入函数

为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:

 function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;

其中 imdb 是image data base. 其中包括:

imdb.images.data  图片 W*H*C*N 的4-D  single Array

imdb.images.label  标签 N*1  的 single Array

imdb.images.data_mean 图片平均值 用于预处理时去中心

imdb.images.set    集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集

imdb.meta 存放类型名称等和训练关系不太密切的东西

4、开始训练

直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。

cnn_train_dag(mynet,imdb_sub,@getBatch);

  训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.

机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章

  1. 机器学习&深度学习基础(目录)

    从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...

  2. [转载]机器学习&深度学习经典资料汇总,全到让人震惊

    自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...

  3. 机器学习&深度学习资料

    机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...

  4. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  5. 机器学习&深度学习基础(机器学习基础的算法概述及代码)

    参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...

  6. 最全的机器学习&深度学习入门视频课程集

    资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd     链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...

  7. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  8. 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. 机器学习&深度学习基础(tensorflow版本实现的算法概述0)

    tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...

随机推荐

  1. 用python实现ARP欺骗

    首先介绍一个python第三方库--Scapy,这个库不是标准库,默认是没有的,需要安装,不过在kali-linux里边是默认安装的, 这里我用kali做攻击者,xp做受害者 关于Scapy Scap ...

  2. JS_高程3.基本概念(5)语句

    1.if语句 2.do-while语句:后测循环语句,循环体内的代码至少执行一次. 3.while语句:前测循环语句. 4.for语句:前测循环语句. 注意:在ECMAScript中不存在块级作用域, ...

  3. 一款开源免费的WPF图表控件ModernuiCharts

    一款简洁好看的Chart控件  支持WPF.silverlight.Windows8  ,基本够用,主要是开源免费的.(商业控件ComponentOne for WPF要4w多呢) This proj ...

  4. SharePoint JavaScript 更新用户和组字段

    前言 最近,需要更新列表字段,字段的类型是用户和组,so写了这么一段代码 function updateUserField(){ var ctx = new SP.ClientContext.get_ ...

  5. 搜索历史命令 Ctrl + R ( ctrl + r to search the history command )

    Linux下的神器 ctrl + r (reverse-i-search ) 的使用方法:   (reverse-i-search usage: ) (press ctl + r ) 输入任意字符,例 ...

  6. springboot 注解整理

    项目用到的注解作用: bean的分类标识@Service: 注解在类上,表示这是一个业务层bean@Controller:注解在类上,表示这是一个控制层bean@Repository: 注解在类上,表 ...

  7. 【JVM】垃圾收集器

    程序计数器.Java虚拟机栈.本地方法栈分配的内存是确定的,生命周期与线程同样.所以不须要过多考虑回收问题. 而Java堆和方法区仅仅有运行时才知道有哪些对象被创建,须要多少内存,这部分的内存分配和回 ...

  8. jvm理论-运行时数据区

    三大流行jvm sun HotSpot ibm j9 BEA JRockit Oracle 会基于HotSpot整合 JRockit. jvm运行时数据区 java虚拟机所管理的内存将会包括以下几个运 ...

  9. Substr与mb_substr区别

      <?php $str = substr('helloword',3,4);//从下标3开始截取截取4个字符 $str = substr('helloword',3);//从截取掉前三个字符 ...

  10. Swift 拷贝文件夹,实现文件夹内容整体复制

    我们知道,在沙盒内,iOS要拷贝一个文件,可以使用 fileManager.copyItem(atPath: fullPath, toPath: fulltoPath) 方法简单实现,不过当我们要拷贝 ...