DBN(深度信念网络)
DBN运用CD算法逐层进行训练,得到每一层的参数Wi和ci用于初始化DBN,之后再用监督学习算法对参数进行微调。本例中采用softmax分类器(下一篇随笔中)作为监督学习算法。
RBM与上一篇随笔中一致,通过多层RBM将softmax parameter从 (10L, 784L)降低到(10L, 50L)。单独用softmax分类器也可以得到相近(或者略好)的正确率,所需的时间略长一点。
- from rbm2 import RBM
- from softmax import SoftMax
- import os
- import numpy as np
- import cPickle
- class DBN:
- def __init__(self,nlayers,ntype,vlen,hlen):
- self.rbm_layers = []
- self.nlayers = nlayers
- self.ntype = ntype
- self.vlen=vlen
- self.hlen=hlen
- def calcRBMForward(self,x):
- for rbm in self.rbm_layers:
- x = rbm.forward(x.T)
- return x
- def load_param(self,dbnpath,softmaxpath):
- weights = cPickle.load(open(dbnpath,'rb'))
- self.nlayers = len(weights)
- for i in range(self.nlayers):
- weight = weights[i]
- v,h= np.shape(weight)
- rbm = RBM(v,h)
- rbm.w = weight
- self.rbm_layers.append(rbm)
- print "RBM layer%d shape:%s" %(i,str(rbm.w.shape))
- self.softmax = SoftMax()
- self.softmax.load_theta(softmaxpath)
- print "softmax parameter: "+str(self.softmax.theta.shape)
- def pretrainRBM(self,trainset):
- weights = []
- for i in range(self.nlayers):
- rbm = RBM(self.vlen,self.hlen)
- if i == 0:
- traindata = trainset
- else:
- traindata = np.array(outdata.T)
- rbm.rbmBB(traindata)
- outdata = np.mat(rbm.forward(traindata))
- self.rbm_layers.append(rbm)
- weights.append(rbm.w)
- self.vlen = self.hlen
- self.hlen = self.hlen/2
- f= open("data/dbn.pkl",'wb')
- cPickle.dump(weights,f)
- f.close()
- def fineTune(self,trainset,labelset):
- rbm_output = self.calcRBMForward(trainset)
- MAXT,step,landa = 100,1,0.01
- self.softmax = SoftMax(MAXT,step,landa)
- self.softmax.process_train(rbm_output,labelset,self.ntype)
- def predict(self,x):
- rbm_output = self.calcRBMForward(x)
- return self.softmax.predict(rbm_output)
- def validate(self,testset,labelset):
- testnum = len(testset)
- correctnum = 0
- for i in range(testnum):
- x = testset[i]
- testtype = self.predict(x)
- orgtype = labelset[i]
- if testtype == orgtype:
- correctnum += 1
- rate = float(correctnum)/testnum
- print "correctnum = %d, sumnum = %d" %(correctnum,testnum)
- print "Accuracy:%.2f" %(rate)
- return rate
- dbn = DBN(3,10,784,200)
- f = open('mnist.pkl', 'rb')
- training_data, validation_data, test_data = cPickle.load(f)
- training_inputs = [np.reshape(x, 784) for x in training_data[0]]
- data = np.array(training_inputs[:5000]).T
- training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
- vdata = np.array(training_inputs[:5000])
- if not os.path.exists('data/softmax.pkl'): # Run twice
- dbn.pretrainRBM(data)
- dbn.fineTune(data.T,training_data[1][:5000])
- else:
- dbn.load_param("data/dbn.pkl","data/softmax.pkl")
- dbn.validate(vdata,validation_data[1][:5000])
- #RBM layer0 shape:(784L, 200L)
- #RBM layer1 shape:(200L, 100L)
- #RBM layer2 shape:(100L, 50L)
- #softmax parameter: (10L, 50L)
- #correctnum = 4357, sumnum = 5000
- #Accuracy:0.87
DBN(深度信念网络)的更多相关文章
- 机器学习——DBN深度信念网络详解(转)
深度神经网路已经在语音识别,图像识别等领域取得前所未有的成功.本人在多年之前也曾接触过神经网络.本系列文章主要记录自己对深度神经网络的一些学习心得. 简要描述深度神经网络模型. 1. 自联想神经网络 ...
- 深度学习(二)--深度信念网络(DBN)
深度学习(二)--深度信念网络(Deep Belief Network,DBN) 一.受限玻尔兹曼机(Restricted Boltzmann Machine,RBM) 在介绍深度信念网络之前需要先了 ...
- 受限玻尔兹曼机(RBM, Restricted Boltzmann machines)和深度信念网络(DBN, Deep Belief Networks)
受限玻尔兹曼机对于当今的非监督学习有一定的启发意义. 深度信念网络(DBN, Deep Belief Networks)于2006年由Geoffery Hinton提出.
- Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3
Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3 http://blog.csdn.net/sunbow0 第二章Deep ...
- Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.1
Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.1 http://blog.csdn.net/sunbow0 Spark ML ...
- Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.2
Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.2 http://blog.csdn.net/sunbow0 第二章Deep ...
- 理论优美的深度信念网络--Hinton北大最新演讲
什么是深度信念网络 深度信念网络是第一批成功应用深度架构训练的非卷积模型之一. 在引入深度信念网络之前,研究社区通常认为深度模型太难优化,还不如使用易于优化的浅层ML模型.2006年,Hinton等研 ...
- 八.DBN深度置信网络
BP神经网络是1968年由Rumelhart和Mcclelland为首的科学家提出的概念,是一种按照误差反向传播算法进行训练的多层前馈神经网络,是目前应用比较广泛的一种神经网络结构.BP网络神经网络由 ...
- RBM(受限玻尔兹曼机)和深层信念网络(Deep Brief Network)
目录: 一.RBM 二.Deep Brief Network 三.Deep Autoencoder 一.RBM 1.定义[无监督学习] RBM记住三个要诀:1)两层结构图,可视层和隐藏层:[没输出层] ...
随机推荐
- 一篇介绍jquery中的ajax的结合
<script type="text/javascript"> function Text_ajax() { $.aja ...
- 03-树2 Tree Traversals Again
这题是第二次做了,两次都不是独立完成,不过我发现我第一次参考的程序,也是参考老师(陈越)的范例做出来的.我对老师给的做了小幅修改,因为我不想有全局变量的的存在,所以我多传了三个参数进去.正序遍历每次都 ...
- Airbase-ng帮助
Airbase-ng 1.2 rc2 - (C) 2008-2014 Thomas d'Otreppe Original work: Martin Beck http://www.aircrack ...
- java中判断字符串是否为数字的三种方法
以下内容引自 http://www.blogjava.net/Javaphua/archive/2007/06/05/122131.html 1用JAVA自带的函数 public static ...
- HDU 5253 最小生成树(kruskal)+ 并查集
题目链接 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> ...
- SQL语句建表、设置主键、外键、check、default、unique约束
· 什么是数据库? 存放数据的仓库. · 数据库和数据结构有什么区别? 数据结构要解决在内存中操作数据的问题,数据库要解决在硬盘中操作数据的问题.数据结构研究一些抽象数据模型(ADT)和以及定义在该模 ...
- PHP 上传图片和安全处理
上传图片 public function images() { $data = $_FILES['file']; switch($data['type']) { case 'image/jpeg': ...
- JS回车键处理
HTML <input type="text" id="keyword" name="keyword" autocomplete=&q ...
- [转]设计模式(22)-Strategy Pattern
一. 策略(Strategy)模式 策略模式的用意是针对一组算法,将每一个算法封装到具有共同接口的独立的类中,从而使得它们可以相互替换.策略模式使得算法可以在不影响到客户端的情况下发生变化. 假 设现 ...
- 设备、像素和点 、 9切片技术 、 颜色和外观 、 NavigationBar的美化
1 TMessage项目的输入面板界面 1.1 问题 IOS中经常会使用到九切片技术对图片进行处理.本案例使用九切片技术完成Tmessage项目的输入板界面,如图-1所示: 图-1 1.2 方案 首先 ...