DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。

1、定义层

 conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();

首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。

其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。

2、定义网络

 mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';

定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。

net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});

以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。

最重要的是一定要initParams()!!!!这会生成初始参数。

3、定义数据输入函数

为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:

 function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;

其中 imdb 是image data base. 其中包括:

imdb.images.data  图片 W*H*C*N 的4-D  single Array

imdb.images.label  标签 N*1  的 single Array

imdb.images.data_mean 图片平均值 用于预处理时去中心

imdb.images.set    集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集

imdb.meta 存放类型名称等和训练关系不太密切的东西

4、开始训练

直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。

cnn_train_dag(mynet,imdb_sub,@getBatch);

  训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.

机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章

  1. 机器学习&深度学习基础(目录)

    从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...

  2. [转载]机器学习&深度学习经典资料汇总,全到让人震惊

    自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...

  3. 机器学习&深度学习资料

    机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...

  4. 深度学习之 GAN 进行 mnist 图片的生成

    深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...

  5. 机器学习&深度学习基础(机器学习基础的算法概述及代码)

    参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...

  6. 最全的机器学习&深度学习入门视频课程集

    资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd     链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...

  7. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  8. 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据

    <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...

  9. 机器学习&深度学习基础(tensorflow版本实现的算法概述0)

    tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...

随机推荐

  1. Java中的RASP实现

    RSAP RASP是Gartner公司提出的一个概念,称:程序不应该依赖于外部组件进行运行时保护,而应该自身拥有运行时环境保护机制: RASP就是运行时应用自我保护(Runtime applicati ...

  2. idea其他人把jar更新之后更新不到

    昨天下午开始就发现这个问题,其他同事把jar更新了之后,我一直获取不到更新之后的内容.尝试了很多方法,删除具体的更新不到的jar,一直不停的mvn clean install -U -Dmaven.t ...

  3. JDK提供的几种线程池比较

    JDK提供的几种线程池 newFixedThreadPool创建一个指定工作线程数量的线程池.每当提交一个任务就创建一个工作线程,如果工作线程数量达到线程池初始的最大数,则将提交的任务存入到池队列中. ...

  4. REST风格的增删改查(2)

    一.源码 1.页面 index.jsp <a href="emps">List All Employee</a> <br><br> ...

  5. java解决手机上传竖拍照片旋转90\180\270度问题

    <dependency> <groupId>com.drewnoakes</groupId> <artifactId>metadata-extracto ...

  6. 倾斜摄影数据OSGB进入到ArcGIS平台相关问题小结

    版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/zglybl/article/details/75252288      随着倾斜摄影技术的发展,大家 ...

  7. 配置gitlab通过smtp发送邮件

    1. 编辑/etc/gitlab/gitlab.rb文件(加到文件最后面就好了,或者通过搜索找到,新版已有这些配置,只不过都是被注释掉了).以QQ企业邮箱为例: gitlab_rails['smtp_ ...

  8. Linux目录/usr缩写及目录结构说明

    在 linux 文件结构中,有一个很神奇的目录 —— /usr.   讨论中,大部分观点认为:   usr 是 unix system resources 的缩写: usr 是 user 的缩写: u ...

  9. C# Task WhenAny和WhenAll 以及TaskFactory 的ContinueWhenAny和ContinueWhenAll的实现

    个人感觉Task 的WaitAny和WhenAny以及TaskFactory 的ContinueWhenAny有相似的地方,而WaitAll和WhenAll以及TaskFactory 的Continu ...

  10. Spring中的CharacterEncodingFilter

    spring的配置文件如下: <?xml version="1.0" encoding="UTF-8"?> <web-app xmlns=&q ...