DBN运用CD算法逐层进行训练,得到每一层的参数Wi和ci用于初始化DBN,之后再用监督学习算法对参数进行微调。本例中采用softmax分类器(下一篇随笔中)作为监督学习算法。

RBM与上一篇随笔中一致,通过多层RBM将softmax parameter从 (10L, 784L)降低到(10L, 50L)。单独用softmax分类器也可以得到相近(或者略好)的正确率,所需的时间略长一点。

  1. from rbm2 import RBM
  2. from softmax import SoftMax
  3. import os
  4. import numpy as np
  5. import cPickle
  6.  
  7. class DBN:
  8. def __init__(self,nlayers,ntype,vlen,hlen):
  9. self.rbm_layers = []
  10. self.nlayers = nlayers
  11. self.ntype = ntype
  12. self.vlen=vlen
  13. self.hlen=hlen
  14.  
  15. def calcRBMForward(self,x):
  16. for rbm in self.rbm_layers:
  17. x = rbm.forward(x.T)
  18. return x
  19.  
  20. def load_param(self,dbnpath,softmaxpath):
  21. weights = cPickle.load(open(dbnpath,'rb'))
  22. self.nlayers = len(weights)
  23. for i in range(self.nlayers):
  24. weight = weights[i]
  25. v,h= np.shape(weight)
  26. rbm = RBM(v,h)
  27. rbm.w = weight
  28. self.rbm_layers.append(rbm)
  29. print "RBM layer%d shape:%s" %(i,str(rbm.w.shape))
  30. self.softmax = SoftMax()
  31. self.softmax.load_theta(softmaxpath)
  32. print "softmax parameter: "+str(self.softmax.theta.shape)
  33.  
  34. def pretrainRBM(self,trainset):
  35. weights = []
  36. for i in range(self.nlayers):
  37. rbm = RBM(self.vlen,self.hlen)
  38. if i == 0:
  39. traindata = trainset
  40. else:
  41. traindata = np.array(outdata.T)
  42. rbm.rbmBB(traindata)
  43. outdata = np.mat(rbm.forward(traindata))
  44. self.rbm_layers.append(rbm)
  45. weights.append(rbm.w)
  46. self.vlen = self.hlen
  47. self.hlen = self.hlen/2
  48. f= open("data/dbn.pkl",'wb')
  49. cPickle.dump(weights,f)
  50. f.close()
  51.  
  52. def fineTune(self,trainset,labelset):
  53. rbm_output = self.calcRBMForward(trainset)
  54. MAXT,step,landa = 100,1,0.01
  55. self.softmax = SoftMax(MAXT,step,landa)
  56. self.softmax.process_train(rbm_output,labelset,self.ntype)
  57.  
  58. def predict(self,x):
  59. rbm_output = self.calcRBMForward(x)
  60. return self.softmax.predict(rbm_output)
  61.  
  62. def validate(self,testset,labelset):
  63. testnum = len(testset)
  64. correctnum = 0
  65. for i in range(testnum):
  66. x = testset[i]
  67. testtype = self.predict(x)
  68. orgtype = labelset[i]
  69. if testtype == orgtype:
  70. correctnum += 1
  71. rate = float(correctnum)/testnum
  72. print "correctnum = %d, sumnum = %d" %(correctnum,testnum)
  73. print "Accuracy:%.2f" %(rate)
  74. return rate
  75.  
  76. dbn = DBN(3,10,784,200)
  77. f = open('mnist.pkl', 'rb')
  78. training_data, validation_data, test_data = cPickle.load(f)
  79. training_inputs = [np.reshape(x, 784) for x in training_data[0]]
  80. data = np.array(training_inputs[:5000]).T
  81. training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
  82. vdata = np.array(training_inputs[:5000])
  83. if not os.path.exists('data/softmax.pkl'): # Run twice
  84. dbn.pretrainRBM(data)
  85. dbn.fineTune(data.T,training_data[1][:5000])
  86. else:
  87. dbn.load_param("data/dbn.pkl","data/softmax.pkl")
  88. dbn.validate(vdata,validation_data[1][:5000])
  89.  
  90. #RBM layer0 shape:(784L, 200L)
  91. #RBM layer1 shape:(200L, 100L)
  92. #RBM layer2 shape:(100L, 50L)
  93. #softmax parameter: (10L, 50L)
  94. #correctnum = 4357, sumnum = 5000
  95. #Accuracy:0.87

DBN(深度信念网络)的更多相关文章

  1. 机器学习——DBN深度信念网络详解(转)

    深度神经网路已经在语音识别,图像识别等领域取得前所未有的成功.本人在多年之前也曾接触过神经网络.本系列文章主要记录自己对深度神经网络的一些学习心得. 简要描述深度神经网络模型. 1.  自联想神经网络 ...

  2. 深度学习(二)--深度信念网络(DBN)

    深度学习(二)--深度信念网络(Deep Belief Network,DBN) 一.受限玻尔兹曼机(Restricted Boltzmann Machine,RBM) 在介绍深度信念网络之前需要先了 ...

  3. 受限玻尔兹曼机(RBM, Restricted Boltzmann machines)和深度信念网络(DBN, Deep Belief Networks)

    受限玻尔兹曼机对于当今的非监督学习有一定的启发意义. 深度信念网络(DBN, Deep Belief Networks)于2006年由Geoffery Hinton提出.

  4. Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3

    Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3 http://blog.csdn.net/sunbow0 第二章Deep ...

  5. Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.1

    Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.1 http://blog.csdn.net/sunbow0 Spark ML ...

  6. Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.2

    Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.2 http://blog.csdn.net/sunbow0 第二章Deep ...

  7. 理论优美的深度信念网络--Hinton北大最新演讲

    什么是深度信念网络 深度信念网络是第一批成功应用深度架构训练的非卷积模型之一. 在引入深度信念网络之前,研究社区通常认为深度模型太难优化,还不如使用易于优化的浅层ML模型.2006年,Hinton等研 ...

  8. 八.DBN深度置信网络

    BP神经网络是1968年由Rumelhart和Mcclelland为首的科学家提出的概念,是一种按照误差反向传播算法进行训练的多层前馈神经网络,是目前应用比较广泛的一种神经网络结构.BP网络神经网络由 ...

  9. RBM(受限玻尔兹曼机)和深层信念网络(Deep Brief Network)

    目录: 一.RBM 二.Deep Brief Network 三.Deep Autoencoder 一.RBM 1.定义[无监督学习] RBM记住三个要诀:1)两层结构图,可视层和隐藏层:[没输出层] ...

随机推荐

  1. 一篇介绍jquery中的ajax的结合

    <script type="text/javascript">        function Text_ajax()        {           $.aja ...

  2. 03-树2 Tree Traversals Again

    这题是第二次做了,两次都不是独立完成,不过我发现我第一次参考的程序,也是参考老师(陈越)的范例做出来的.我对老师给的做了小幅修改,因为我不想有全局变量的的存在,所以我多传了三个参数进去.正序遍历每次都 ...

  3. Airbase-ng帮助

    Airbase-ng 1.2 rc2 - (C) 2008-2014 Thomas d'Otreppe  Original work: Martin Beck  http://www.aircrack ...

  4. java中判断字符串是否为数字的三种方法

    以下内容引自  http://www.blogjava.net/Javaphua/archive/2007/06/05/122131.html 1用JAVA自带的函数   public static ...

  5. HDU 5253 最小生成树(kruskal)+ 并查集

    题目链接 #include<cstdio> #include<cmath> #include<cstring> #include<algorithm> ...

  6. SQL语句建表、设置主键、外键、check、default、unique约束

    · 什么是数据库? 存放数据的仓库. · 数据库和数据结构有什么区别? 数据结构要解决在内存中操作数据的问题,数据库要解决在硬盘中操作数据的问题.数据结构研究一些抽象数据模型(ADT)和以及定义在该模 ...

  7. PHP 上传图片和安全处理

    上传图片 public function images() { $data = $_FILES['file']; switch($data['type']) { case 'image/jpeg': ...

  8. JS回车键处理

    HTML <input type="text" id="keyword" name="keyword" autocomplete=&q ...

  9. [转]设计模式(22)-Strategy Pattern

    一. 策略(Strategy)模式 策略模式的用意是针对一组算法,将每一个算法封装到具有共同接口的独立的类中,从而使得它们可以相互替换.策略模式使得算法可以在不影响到客户端的情况下发生变化. 假 设现 ...

  10. 设备、像素和点 、 9切片技术 、 颜色和外观 、 NavigationBar的美化

    1 TMessage项目的输入面板界面 1.1 问题 IOS中经常会使用到九切片技术对图片进行处理.本案例使用九切片技术完成Tmessage项目的输入板界面,如图-1所示: 图-1 1.2 方案 首先 ...