机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET
DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络。我使用的是由MatConvNet 提供的DAGNN API。选择这套API作为工具的原因有三点,第一:这是matlab的API,相对其他语言我对Matlab比较熟悉;第二:有向图非循环的网络可以实现RPN,Network in Network 等较为复杂的功能,可以随意的引出各层的输入和输出,有利于针对三维视觉任务改造网络结构。MNIST 是手写数字的图片集,也是机器学习网络最简单的试金石。
1、定义层
conv_layer1 = dagnn.Conv('size',single([5,5,1,30]),'hasBias',true);
relu_layer2 = dagnn.ReLU(); conv_layer3 = dagnn.Conv('size',single([5,5,30,16]),'hasBias',true);
relu_layer4 = dagnn.ReLU();
pooling_layer5 = dagnn.Pooling('poolSize',[2,2],'stride',[2 2]); fullConnet_layer6 = dagnn.Conv('size',single([4,4,16,256]),'hasBias',true);
relu_layer7 = dagnn.ReLU();
fullConnet_layer8 = dagnn.Conv('size',single([1,1,256,10]),'hasBias',true);
SoftMat_layer9 = dagnn.SoftMax();
Loss_layer = dagnn.Loss();
首先是利用API构造各层网络,定义网络结构类型。所有的Layer 都继承自dagnn.Layer类,子类中定义了输入输出,前向传播,反向传播的行为。
其中包括卷积层,激活层,池化层,Softmax 分类层,以及计算Loss层。值得注意的是全连接层是通过大卷积层来实现的。本质上全连接就是“输入的等尺寸卷积”。全连接层的作用是将卷积层提取的特征进行高度非线性的映射,将其映射到输出空间中。
2、定义网络
mynet = dagnn.DagNN();
mynet.addLayer('conv1',conv_layer1,{'input'},{'x2'},{'filters_conv1','bias_conv1'});
mynet.addLayer('relu1',relu_layer2,{'x2'},{'x3'});
mynet.addLayer('pool1',pooling_layer5,{'x3'},{'x4'}); mynet.addLayer('conv2',conv_layer3,{'x4'},{'x5'},{'filters_conv2','bias_conv2'});
mynet.addLayer('relu2',relu_layer4,{'x5'},{'x6'});
mynet.addLayer('pool2',pooling_layer5,{'x6'},{'x7'}); mynet.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
mynet.addLayer('relu3',relu_layer7,{'x8'},{'x9'});
mynet.addLayer('full2',fullConnet_layer8,{'x9'},{'x10'},{'filters_fc2','bias_fc2'});
mynet.addLayer('Cls1',SoftMat_layer9,{'x10'},{'pred'});
mynet.addLayer('Loss',Loss_layer,{'pred','label'},{'loss'});
mynet.initParams();
mynet.meta.inputs = {'data',[28,28,1,1]};
mynet.meta.classes.name = {1,2,3,4,5,6,7,8,9,10};
mynet.meta.normalization.imageSize = [28,28,1,1];
mynet.meta.interpolation = 'bicubic';
定义网络调用了addLayer方法,与其他API的网络构建方法不同的是,DAGNN的API需要针对每层定义输入和输出,以及网络中的待求得参数。当然,作为初学者我先实现了链式网络,在下周的工作中会尝试实现Faster R-CNN。
net.addLayer('full1',fullConnet_layer6,{'x7'},{'x8'},{'filters_fc1','bias_fc1'});
以此为例,代表该层的名字是full1 , 该层的结构是fullConnect_layer6,输入为x7、输出x8,参数名为filters_fc1 和 bias_fc1。其中loss 层最为特殊,其具有来自softmax层的pred 和 label (ground truth) 两种输入。
最重要的是一定要initParams()!!!!这会生成初始参数。
3、定义数据输入函数
为了训练网络,我们需要定义一个输入函数。数据量小,可存在内存中,但当数据量大的时候全部存在内存里是不现实的,这就需要一个数据输入函数来对你定义的数据库进行操作。本例中我仅使用5000幅图片进行训练,所以可以把图片放在内存中。getBatch函数如下所示:
function inputs = getBatch(imdb, batch)
% --------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ; % images = gpuArray(images) ; inputs = {'input', images, 'label', labels} ;
其中 imdb 是image data base. 其中包括:
imdb.images.data 图片 W*H*C*N 的4-D single Array
imdb.images.label 标签 N*1 的 single Array
imdb.images.data_mean 图片平均值 用于预处理时去中心
imdb.images.set 集合号 N*1 的 single Array, 其中1 代表训练集 2 代表测试集 3 代表验证集
imdb.meta 存放类型名称等和训练关系不太密切的东西
4、开始训练
直接调用 cnn_train_dag 的API 开始对整个集合进行训练,注意getBatch 输入的是函数句柄。
cnn_train_dag(mynet,imdb_sub,@getBatch);
训练了30个epoch,但是learningRate好像给太高了,掉局部最小里了。。。。。。。不过结果不错,在验证集中拿到了4998/5000.
机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET的更多相关文章
- 机器学习&深度学习基础(目录)
从业这么久了,做了很多项目,一直对机器学习的基础课程鄙视已久,现在回头看来,系统的基础知识整理对我现在思路的整理很有利,写完这个基础篇,开始把AI+cv的也总结完,然后把这么多年做的项目再写好总结. ...
- [转载]机器学习&深度学习经典资料汇总,全到让人震惊
自学成才秘籍!机器学习&深度学习经典资料汇总 转自:中国大数据: http://www.thebigdata.cn/JiShuBoKe/13299.html [日期:2015-01-27] 来 ...
- 机器学习&深度学习资料
机器学习(Machine Learning)&深度学习(Deep Learning)资料(Chapter 1) 机器学习(Machine Learning)&深度学习(Deep Lea ...
- 深度学习之 GAN 进行 mnist 图片的生成
深度学习之 GAN 进行 mnist 图片的生成 mport numpy as np import os import codecs import torch from PIL import Imag ...
- 机器学习&深度学习基础(机器学习基础的算法概述及代码)
参考:机器学习&深度学习算法及代码实现 Python3机器学习 传统机器学习算法 决策树.K邻近算法.支持向量机.朴素贝叶斯.神经网络.Logistic回归算法,聚类等. 一.机器学习算法及代 ...
- 最全的机器学习&深度学习入门视频课程集
资源介绍 链接:http://pan.baidu.com/s/1kV6nWJP 密码:ryfd 链接:http://pan.baidu.com/s/1dEZWlP3 密码:y82m 更多资源 ...
- 深度学习|基于LSTM网络的黄金期货价格预测--转载
深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...
- 机器学习&深度学习经典资料汇总,data.gov.uk大量公开数据
<Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost到随机森林.D ...
- 机器学习&深度学习基础(tensorflow版本实现的算法概述0)
tensorflow集成和实现了各种机器学习基础的算法,可以直接调用. 代码集:https://github.com/ageron/handson-ml 监督学习 1)决策树(Decision Tre ...
随机推荐
- IE内核浏览器的404页面问题和IE自动缓存引发的问题
本站404页面被IE替换成IE自己的404页面 在权限设置正确的情况下,自定义的404页面文件大小如果小于512字节,那么IE内核的浏览器会认为你自定义的404页面不够权威,从而使用其自带的404页面 ...
- maven 学习
最近有项目需要储备maven的技能,就学习了一下,找到了一个很适合入门的博客,这里记录下网址. https://www.cnblogs.com/whgk/p/7112560.html
- 出现明明SQL语句没问题,但是却无法通过代码查询到结果的问题。
问题:SQL语句查询不到记录,导致空指针异常 SQL语句: select * from mixinfo where infotype='网站简介' 代码: publicList<HashMap& ...
- jQuery 学习(1)——认识jQuery
1.下载 下载地址:http://jquery.com/download/ jquery-3.2.1.js——用于开发和学习(229K) jquery-3.2.1.min.js——用于项目和产品(31 ...
- 集群安装Java环境
需要安装一个集群环境,发现全部要手动安装java.记录下安装Java环境的过程.虽然,依旧是挨个安装,但总算是有体系了. java 找到下载地址: https://www.oracle.com/tec ...
- WSDL测试webservice接口记录
收到一个事情,需要对接第三方API,对方给了个service,看了一下,原来是webservices的. 上一次测试webervice的接口,还是至少八九年前的时候了,这种相对比较老旧的也好久不在使用 ...
- 【php】php5.0以上,instanceof 用法
1.instanceof php官网:http://php.net/manual/zh/language.operators.type.php 2.instanceof 用于确定一个 PHP 变量是否 ...
- 使用h5py操作hdf5文件
HDF(Hierarchical Data Format)指一种为存储和处理大容量科学数据设计的文件格式及相应库文件.HDF 最早由美国国家超级计算应用中心 NCSA 开发,目前在非盈利组织 HDF ...
- 【理论面试篇】收集整理来自网络上的一些常见的 经典前端、H5面试题 Web前端开发面试题
##2017.10.30收集 面试技巧 5.1 面试形式 1) 一般而言,小公司做笔试题:大公司面谈项目经验:做地图的一定考算法 2) 面试官喜欢什么样的人 ü 技术好. ...
- 【C++】C++中assert和ENDEGU预处理语句
assert 断言语句是C++中的一种预处理宏语句,它能在程序运行时根据否定条件中断程序. C++中的assert()函数可以实现断言功能,在使用assert函数之前应该先引入<cassert& ...