http://fantasticinblur.iteye.com/blog/1465497

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

  1. package ml;
  2. import java.util.Random;
  3. /**
  4. * BPNN.
  5. *
  6. * @author RenaQiu
  7. *
  8. */
  9. public class BP {
  10. /**
  11. * input vector.
  12. */
  13. private final double[] input;
  14. /**
  15. * hidden layer.
  16. */
  17. private final double[] hidden;
  18. /**
  19. * output layer.
  20. */
  21. private final double[] output;
  22. /**
  23. * target.
  24. */
  25. private final double[] target;
  26. /**
  27. * delta vector of the hidden layer .
  28. */
  29. private final double[] hidDelta;
  30. /**
  31. * output layer of the output layer.
  32. */
  33. private final double[] optDelta;
  34. /**
  35. * learning rate.
  36. */
  37. private final double eta;
  38. /**
  39. * momentum.
  40. */
  41. private final double momentum;
  42. /**
  43. * weight matrix from input layer to hidden layer.
  44. */
  45. private final double[][] iptHidWeights;
  46. /**
  47. * weight matrix from hidden layer to output layer.
  48. */
  49. private final double[][] hidOptWeights;
  50. /**
  51. * previous weight update.
  52. */
  53. private final double[][] iptHidPrevUptWeights;
  54. /**
  55. * previous weight update.
  56. */
  57. private final double[][] hidOptPrevUptWeights;
  58. public double optErrSum = 0d;
  59. public double hidErrSum = 0d;
  60. private final Random random;
  61. /**
  62. * Constructor.
  63. * <p>
  64. * <strong>Note:</strong> The capacity of each layer will be the parameter
  65. * plus 1. The additional unit is used for smoothness.
  66. * </p>
  67. *
  68. * @param inputSize
  69. * @param hiddenSize
  70. * @param outputSize
  71. * @param eta
  72. * @param momentum
  73. * @param epoch
  74. */
  75. public BP(int inputSize, int hiddenSize, int outputSize, double eta,
  76. double momentum) {
  77. input = new double[inputSize + 1];
  78. hidden = new double[hiddenSize + 1];
  79. output = new double[outputSize + 1];
  80. target = new double[outputSize + 1];
  81. hidDelta = new double[hiddenSize + 1];
  82. optDelta = new double[outputSize + 1];
  83. iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
  84. hidOptWeights = new double[hiddenSize + 1][outputSize + 1];
  85. random = new Random(19881211);
  86. randomizeWeights(iptHidWeights);
  87. randomizeWeights(hidOptWeights);
  88. iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
  89. hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];
  90. this.eta = eta;
  91. this.momentum = momentum;
  92. }
  93. private void randomizeWeights(double[][] matrix) {
  94. for (int i = 0, len = matrix.length; i != len; i++)
  95. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  96. double real = random.nextDouble();
  97. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  98. }
  99. }
  100. /**
  101. * Constructor with default eta = 0.25 and momentum = 0.3.
  102. *
  103. * @param inputSize
  104. * @param hiddenSize
  105. * @param outputSize
  106. * @param epoch
  107. */
  108. public BP(int inputSize, int hiddenSize, int outputSize) {
  109. this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
  110. }
  111. /**
  112. * Entry method. The train data should be a one-dim vector.
  113. *
  114. * @param trainData
  115. * @param target
  116. */
  117. public void train(double[] trainData, double[] target) {
  118. loadInput(trainData);
  119. loadTarget(target);
  120. forward();
  121. calculateDelta();
  122. adjustWeight();
  123. }
  124. /**
  125. * Test the BPNN.
  126. *
  127. * @param inData
  128. * @return
  129. */
  130. public double[] test(double[] inData) {
  131. if (inData.length != input.length - 1) {
  132. throw new IllegalArgumentException("Size Do Not Match.");
  133. }
  134. System.arraycopy(inData, 0, input, 1, inData.length);
  135. forward();
  136. return getNetworkOutput();
  137. }
  138. /**
  139. * Return the output layer.
  140. *
  141. * @return
  142. */
  143. private double[] getNetworkOutput() {
  144. int len = output.length;
  145. double[] temp = new double[len - 1];
  146. for (int i = 1; i != len; i++)
  147. temp[i - 1] = output[i];
  148. return temp;
  149. }
  150. /**
  151. * Load the target data.
  152. *
  153. * @param arg
  154. */
  155. private void loadTarget(double[] arg) {
  156. if (arg.length != target.length - 1) {
  157. throw new IllegalArgumentException("Size Do Not Match.");
  158. }
  159. System.arraycopy(arg, 0, target, 1, arg.length);
  160. }
  161. /**
  162. * Load the training data.
  163. *
  164. * @param inData
  165. */
  166. private void loadInput(double[] inData) {
  167. if (inData.length != input.length - 1) {
  168. throw new IllegalArgumentException("Size Do Not Match.");
  169. }
  170. System.arraycopy(inData, 0, input, 1, inData.length);
  171. }
  172. /**
  173. * Forward.
  174. *
  175. * @param layer0
  176. * @param layer1
  177. * @param weight
  178. */
  179. private void forward(double[] layer0, double[] layer1, double[][] weight) {
  180. // threshold unit.
  181. layer0[0] = 1.0;
  182. for (int j = 1, len = layer1.length; j != len; ++j) {
  183. double sum = 0;
  184. for (int i = 0, len2 = layer0.length; i != len2; ++i)
  185. sum += weight[i][j] * layer0[i];
  186. layer1[j] = sigmoid(sum);
  187. }
  188. }
  189. /**
  190. * Forward.
  191. */
  192. private void forward() {
  193. forward(input, hidden, iptHidWeights);
  194. forward(hidden, output, hidOptWeights);
  195. }
  196. /**
  197. * Calculate output error.
  198. */
  199. private void outputErr() {
  200. double errSum = 0;
  201. for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
  202. double o = output[idx];
  203. optDelta[idx] = o * (1d - o) * (target[idx] - o);
  204. errSum += Math.abs(optDelta[idx]);
  205. }
  206. optErrSum = errSum;
  207. }
  208. /**
  209. * Calculate hidden errors.
  210. */
  211. private void hiddenErr() {
  212. double errSum = 0;
  213. for (int j = 1, len = hidDelta.length; j != len; ++j) {
  214. double o = hidden[j];
  215. double sum = 0;
  216. for (int k = 1, len2 = optDelta.length; k != len2; ++k)
  217. sum += hidOptWeights[j][k] * optDelta[k];
  218. hidDelta[j] = o * (1d - o) * sum;
  219. errSum += Math.abs(hidDelta[j]);
  220. }
  221. hidErrSum = errSum;
  222. }
  223. /**
  224. * Calculate errors of all layers.
  225. */
  226. private void calculateDelta() {
  227. outputErr();
  228. hiddenErr();
  229. }
  230. /**
  231. * Adjust the weight matrix.
  232. *
  233. * @param delta
  234. * @param layer
  235. * @param weight
  236. * @param prevWeight
  237. */
  238. private void adjustWeight(double[] delta, double[] layer,
  239. double[][] weight, double[][] prevWeight) {
  240. layer[0] = 1;
  241. for (int i = 1, len = delta.length; i != len; ++i) {
  242. for (int j = 0, len2 = layer.length; j != len2; ++j) {
  243. double newVal = momentum * prevWeight[j][i] + eta * delta[i]
  244. * layer[j];
  245. weight[j][i] += newVal;
  246. prevWeight[j][i] = newVal;
  247. }
  248. }
  249. }
  250. /**
  251. * Adjust all weight matrices.
  252. */
  253. private void adjustWeight() {
  254. adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
  255. adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
  256. }
  257. /**
  258. * Sigmoid.
  259. *
  260. * @param val
  261. * @return
  262. */
  263. private double sigmoid(double val) {
  264. return 1d / (1d + Math.exp(-val));
  265. }
  266. }

为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

  1. package ml;
  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.List;
  5. import java.util.Random;
  6. public class Test {
  7. /**
  8. * @param args
  9. * @throws IOException
  10. */
  11. public static void main(String[] args) throws IOException {
  12. BP bp = new BP(32, 15, 4);
  13. Random random = new Random();
  14. List<Integer> list = new ArrayList<Integer>();
  15. for (int i = 0; i != 1000; i++) {
  16. int value = random.nextInt();
  17. list.add(value);
  18. }
  19. for (int i = 0; i != 200; i++) {
  20. for (int value : list) {
  21. double[] real = new double[4];
  22. if (value >= 0)
  23. if ((value & 1) == 1)
  24. real[0] = 1;
  25. else
  26. real[1] = 1;
  27. else if ((value & 1) == 1)
  28. real[2] = 1;
  29. else
  30. real[3] = 1;
  31. double[] binary = new double[32];
  32. int index = 31;
  33. do {
  34. binary[index--] = (value & 1);
  35. value >>>= 1;
  36. } while (value != 0);
  37. bp.train(binary, real);
  38. }
  39. }
  40. System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");
  41. while (true) {
  42. byte[] input = new byte[10];
  43. System.in.read(input);
  44. Integer value = Integer.parseInt(new String(input).trim());
  45. int rawVal = value;
  46. double[] binary = new double[32];
  47. int index = 31;
  48. do {
  49. binary[index--] = (value & 1);
  50. value >>>= 1;
  51. } while (value != 0);
  52. double[] result = bp.test(binary);
  53. double max = -Integer.MIN_VALUE;
  54. int idx = -1;
  55. for (int i = 0; i != result.length; i++) {
  56. if (result[i] > max) {
  57. max = result[i];
  58. idx = i;
  59. }
  60. }
  61. switch (idx) {
  62. case 0:
  63. System.out.format("%d是一个正奇数\n", rawVal);
  64. break;
  65. case 1:
  66. System.out.format("%d是一个正偶数\n", rawVal);
  67. break;
  68. case 2:
  69. System.out.format("%d是一个负奇数\n", rawVal);
  70. break;
  71. case 3:
  72. System.out.format("%d是一个负偶数\n", rawVal);
  73. break;
  74. }
  75. }
  76. }
  77. }

运行结果截图如下:


 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

BP神经网络的Java实现(转)的更多相关文章

  1. BP神经网络的Java实现(转载)

    神经网络的计算过程 神经网络结构如下图所示,最左边的是输入层,最右边的是输出层,中间是多个隐含层,隐含层和输出层的每个神经节点,都是由上一层节点乘以其权重累加得到,标上“+1”的圆圈为截距项b,对输入 ...

  2. BP神经网络的Java实现

    http://fantasticinblur.iteye.com/blog/1465497

  3. BP神经网络的直观推导与Java实现

    人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...

  4. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  5. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  6. BP神经网络—java实现

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  7. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  8. 数据挖掘系列(9)——BP神经网络算法与实践

    神经网络曾经很火,有过一段低迷期,现在因为深度学习的原因继续火起来了.神经网络有很多种:前向传输网络.反向传输网络.递归神经网络.卷积神经网络等.本文介绍基本的反向传输神经网络(Backpropaga ...

  9. BP神经网络的数学原理及其算法实现

    什么是BP网络 BP网络的数学原理 BP网络算法实现 转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073  上一篇 ...

随机推荐

  1. Coding和Git的环境搭建

    Github太慢了.打开网页慢,下载也只有几kb. 于是找了国内的Git,据说coding不错.就申请了个. 其实csdn也有...但是没人家的专业... 1 注册coding  https://co ...

  2. ELK到底是什么?那么多公司用!

    Sina.饿了么.携程.华为.美团.freewheel.畅捷通 .新浪微博.大讲台.魅族.IBM...... 这些公司都在使用ELK!ELK!ELK! ELK竟然重复了三遍,是个什么?   一.ELK ...

  3. Android 应用内切换语言

    extends :http://bbs.51cto.com/thread-1075165-1.html,http://www.cnblogs.com/loulijun/p/3164746.html 1 ...

  4. Ubuntu 16.04系统下解决Vim乱码问题

    方法: 打开终端输入:vim /etc/vim/vimrc,进入编辑模式,加入如下配置: set fileencodings=utf-8,gb2312,gbk,gb18030 set termenco ...

  5. thinkphp实现采集功能的三种方法!

    最近在做一些数据分析,由于上网找数据比较麻烦,所以写了一个采集网站数据的方法.具体方法如下: 方法一:QueryList 个人感觉比较好用,采集详情比较不错的选择,但是采集复杂一点的列表,不好用.具体 ...

  6. Windows运行python脚本文件

    开始学习python就是听说这个语言写脚本文件特别方便,简单使用.学了一段时间,但是直到现在我才直到直到怎么在Windows的cmd上运行脚本文件. 之前一直都是在pycharm上运行,并不实用. 百 ...

  7. Ubuntu16.04 安装lamp环境

    拿到新装的ubuntu16.04新系统 首先 apt-get update 更新一下 我这里是root用户,如果您不是超级管理员,命令前加sudo即可 如果您加了sudo也不好使,那就联系管理员,给你 ...

  8. HTTP与HTTPS对访问速度(性能)的影响【转】

    1 前言 HTTPS 在保护用户隐私,防止流量劫持方面发挥着非常关键的作用,但与此同时,HTTPS 也会降低用户访问速度,增加网站服务器的计算资源消耗. 本文主要介绍 https 对用户体验的影响. ...

  9. Oracle备份恢复之无备份情况下恢复undo表空间

    UNDO表空间存储着DML操作数据块的前镜像数据,在数据回滚,一致性读,闪回操作,实例恢复的时候都可能用到UNDO表空间中的数据.如果在生产过程中丢失或破坏了UNDO表空间,可能导致某些事务无法回滚, ...

  10. cordova 跨平台APP版本升级

    利用cordova+ionic开发好项目,之后就是打包发布,在这之前,还要做一个版本升级的小功能. 首先我们项目根目录里自然少不了配置:config.xml中 如图.version,我们以后每次升级A ...