http://fantasticinblur.iteye.com/blog/1465497

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

  1. package ml;
  2. import java.util.Random;
  3. /**
  4. * BPNN.
  5. *
  6. * @author RenaQiu
  7. *
  8. */
  9. public class BP {
  10. /**
  11. * input vector.
  12. */
  13. private final double[] input;
  14. /**
  15. * hidden layer.
  16. */
  17. private final double[] hidden;
  18. /**
  19. * output layer.
  20. */
  21. private final double[] output;
  22. /**
  23. * target.
  24. */
  25. private final double[] target;
  26. /**
  27. * delta vector of the hidden layer .
  28. */
  29. private final double[] hidDelta;
  30. /**
  31. * output layer of the output layer.
  32. */
  33. private final double[] optDelta;
  34. /**
  35. * learning rate.
  36. */
  37. private final double eta;
  38. /**
  39. * momentum.
  40. */
  41. private final double momentum;
  42. /**
  43. * weight matrix from input layer to hidden layer.
  44. */
  45. private final double[][] iptHidWeights;
  46. /**
  47. * weight matrix from hidden layer to output layer.
  48. */
  49. private final double[][] hidOptWeights;
  50. /**
  51. * previous weight update.
  52. */
  53. private final double[][] iptHidPrevUptWeights;
  54. /**
  55. * previous weight update.
  56. */
  57. private final double[][] hidOptPrevUptWeights;
  58. public double optErrSum = 0d;
  59. public double hidErrSum = 0d;
  60. private final Random random;
  61. /**
  62. * Constructor.
  63. * <p>
  64. * <strong>Note:</strong> The capacity of each layer will be the parameter
  65. * plus 1. The additional unit is used for smoothness.
  66. * </p>
  67. *
  68. * @param inputSize
  69. * @param hiddenSize
  70. * @param outputSize
  71. * @param eta
  72. * @param momentum
  73. * @param epoch
  74. */
  75. public BP(int inputSize, int hiddenSize, int outputSize, double eta,
  76. double momentum) {
  77. input = new double[inputSize + 1];
  78. hidden = new double[hiddenSize + 1];
  79. output = new double[outputSize + 1];
  80. target = new double[outputSize + 1];
  81. hidDelta = new double[hiddenSize + 1];
  82. optDelta = new double[outputSize + 1];
  83. iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
  84. hidOptWeights = new double[hiddenSize + 1][outputSize + 1];
  85. random = new Random(19881211);
  86. randomizeWeights(iptHidWeights);
  87. randomizeWeights(hidOptWeights);
  88. iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
  89. hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];
  90. this.eta = eta;
  91. this.momentum = momentum;
  92. }
  93. private void randomizeWeights(double[][] matrix) {
  94. for (int i = 0, len = matrix.length; i != len; i++)
  95. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  96. double real = random.nextDouble();
  97. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  98. }
  99. }
  100. /**
  101. * Constructor with default eta = 0.25 and momentum = 0.3.
  102. *
  103. * @param inputSize
  104. * @param hiddenSize
  105. * @param outputSize
  106. * @param epoch
  107. */
  108. public BP(int inputSize, int hiddenSize, int outputSize) {
  109. this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
  110. }
  111. /**
  112. * Entry method. The train data should be a one-dim vector.
  113. *
  114. * @param trainData
  115. * @param target
  116. */
  117. public void train(double[] trainData, double[] target) {
  118. loadInput(trainData);
  119. loadTarget(target);
  120. forward();
  121. calculateDelta();
  122. adjustWeight();
  123. }
  124. /**
  125. * Test the BPNN.
  126. *
  127. * @param inData
  128. * @return
  129. */
  130. public double[] test(double[] inData) {
  131. if (inData.length != input.length - 1) {
  132. throw new IllegalArgumentException("Size Do Not Match.");
  133. }
  134. System.arraycopy(inData, 0, input, 1, inData.length);
  135. forward();
  136. return getNetworkOutput();
  137. }
  138. /**
  139. * Return the output layer.
  140. *
  141. * @return
  142. */
  143. private double[] getNetworkOutput() {
  144. int len = output.length;
  145. double[] temp = new double[len - 1];
  146. for (int i = 1; i != len; i++)
  147. temp[i - 1] = output[i];
  148. return temp;
  149. }
  150. /**
  151. * Load the target data.
  152. *
  153. * @param arg
  154. */
  155. private void loadTarget(double[] arg) {
  156. if (arg.length != target.length - 1) {
  157. throw new IllegalArgumentException("Size Do Not Match.");
  158. }
  159. System.arraycopy(arg, 0, target, 1, arg.length);
  160. }
  161. /**
  162. * Load the training data.
  163. *
  164. * @param inData
  165. */
  166. private void loadInput(double[] inData) {
  167. if (inData.length != input.length - 1) {
  168. throw new IllegalArgumentException("Size Do Not Match.");
  169. }
  170. System.arraycopy(inData, 0, input, 1, inData.length);
  171. }
  172. /**
  173. * Forward.
  174. *
  175. * @param layer0
  176. * @param layer1
  177. * @param weight
  178. */
  179. private void forward(double[] layer0, double[] layer1, double[][] weight) {
  180. // threshold unit.
  181. layer0[0] = 1.0;
  182. for (int j = 1, len = layer1.length; j != len; ++j) {
  183. double sum = 0;
  184. for (int i = 0, len2 = layer0.length; i != len2; ++i)
  185. sum += weight[i][j] * layer0[i];
  186. layer1[j] = sigmoid(sum);
  187. }
  188. }
  189. /**
  190. * Forward.
  191. */
  192. private void forward() {
  193. forward(input, hidden, iptHidWeights);
  194. forward(hidden, output, hidOptWeights);
  195. }
  196. /**
  197. * Calculate output error.
  198. */
  199. private void outputErr() {
  200. double errSum = 0;
  201. for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
  202. double o = output[idx];
  203. optDelta[idx] = o * (1d - o) * (target[idx] - o);
  204. errSum += Math.abs(optDelta[idx]);
  205. }
  206. optErrSum = errSum;
  207. }
  208. /**
  209. * Calculate hidden errors.
  210. */
  211. private void hiddenErr() {
  212. double errSum = 0;
  213. for (int j = 1, len = hidDelta.length; j != len; ++j) {
  214. double o = hidden[j];
  215. double sum = 0;
  216. for (int k = 1, len2 = optDelta.length; k != len2; ++k)
  217. sum += hidOptWeights[j][k] * optDelta[k];
  218. hidDelta[j] = o * (1d - o) * sum;
  219. errSum += Math.abs(hidDelta[j]);
  220. }
  221. hidErrSum = errSum;
  222. }
  223. /**
  224. * Calculate errors of all layers.
  225. */
  226. private void calculateDelta() {
  227. outputErr();
  228. hiddenErr();
  229. }
  230. /**
  231. * Adjust the weight matrix.
  232. *
  233. * @param delta
  234. * @param layer
  235. * @param weight
  236. * @param prevWeight
  237. */
  238. private void adjustWeight(double[] delta, double[] layer,
  239. double[][] weight, double[][] prevWeight) {
  240. layer[0] = 1;
  241. for (int i = 1, len = delta.length; i != len; ++i) {
  242. for (int j = 0, len2 = layer.length; j != len2; ++j) {
  243. double newVal = momentum * prevWeight[j][i] + eta * delta[i]
  244. * layer[j];
  245. weight[j][i] += newVal;
  246. prevWeight[j][i] = newVal;
  247. }
  248. }
  249. }
  250. /**
  251. * Adjust all weight matrices.
  252. */
  253. private void adjustWeight() {
  254. adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
  255. adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
  256. }
  257. /**
  258. * Sigmoid.
  259. *
  260. * @param val
  261. * @return
  262. */
  263. private double sigmoid(double val) {
  264. return 1d / (1d + Math.exp(-val));
  265. }
  266. }

为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

  1. package ml;
  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.List;
  5. import java.util.Random;
  6. public class Test {
  7. /**
  8. * @param args
  9. * @throws IOException
  10. */
  11. public static void main(String[] args) throws IOException {
  12. BP bp = new BP(32, 15, 4);
  13. Random random = new Random();
  14. List<Integer> list = new ArrayList<Integer>();
  15. for (int i = 0; i != 1000; i++) {
  16. int value = random.nextInt();
  17. list.add(value);
  18. }
  19. for (int i = 0; i != 200; i++) {
  20. for (int value : list) {
  21. double[] real = new double[4];
  22. if (value >= 0)
  23. if ((value & 1) == 1)
  24. real[0] = 1;
  25. else
  26. real[1] = 1;
  27. else if ((value & 1) == 1)
  28. real[2] = 1;
  29. else
  30. real[3] = 1;
  31. double[] binary = new double[32];
  32. int index = 31;
  33. do {
  34. binary[index--] = (value & 1);
  35. value >>>= 1;
  36. } while (value != 0);
  37. bp.train(binary, real);
  38. }
  39. }
  40. System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");
  41. while (true) {
  42. byte[] input = new byte[10];
  43. System.in.read(input);
  44. Integer value = Integer.parseInt(new String(input).trim());
  45. int rawVal = value;
  46. double[] binary = new double[32];
  47. int index = 31;
  48. do {
  49. binary[index--] = (value & 1);
  50. value >>>= 1;
  51. } while (value != 0);
  52. double[] result = bp.test(binary);
  53. double max = -Integer.MIN_VALUE;
  54. int idx = -1;
  55. for (int i = 0; i != result.length; i++) {
  56. if (result[i] > max) {
  57. max = result[i];
  58. idx = i;
  59. }
  60. }
  61. switch (idx) {
  62. case 0:
  63. System.out.format("%d是一个正奇数\n", rawVal);
  64. break;
  65. case 1:
  66. System.out.format("%d是一个正偶数\n", rawVal);
  67. break;
  68. case 2:
  69. System.out.format("%d是一个负奇数\n", rawVal);
  70. break;
  71. case 3:
  72. System.out.format("%d是一个负偶数\n", rawVal);
  73. break;
  74. }
  75. }
  76. }
  77. }

运行结果截图如下:


 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

BP神经网络的Java实现(转)的更多相关文章

  1. BP神经网络的Java实现(转载)

    神经网络的计算过程 神经网络结构如下图所示,最左边的是输入层,最右边的是输出层,中间是多个隐含层,隐含层和输出层的每个神经节点,都是由上一层节点乘以其权重累加得到,标上“+1”的圆圈为截距项b,对输入 ...

  2. BP神经网络的Java实现

    http://fantasticinblur.iteye.com/blog/1465497

  3. BP神经网络的直观推导与Java实现

    人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...

  4. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  5. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  6. BP神经网络—java实现

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  7. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  8. 数据挖掘系列(9)——BP神经网络算法与实践

    神经网络曾经很火,有过一段低迷期,现在因为深度学习的原因继续火起来了.神经网络有很多种:前向传输网络.反向传输网络.递归神经网络.卷积神经网络等.本文介绍基本的反向传输神经网络(Backpropaga ...

  9. BP神经网络的数学原理及其算法实现

    什么是BP网络 BP网络的数学原理 BP网络算法实现 转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073  上一篇 ...

随机推荐

  1. 问题记录 为ubuntu16.04添加windows字体(解决JIRA图表乱码的问题)

    最近遇到了JIRA在新的ubuntu机器上图表的中文无法正确显示的问题,解决的方法是,为ubuntu安装中文字体,我们选择把windows上的字体复制到ubuntu上来安装的方法,步骤如下: 从win ...

  2. C++ vs Python向量运算速度评测

    本文的起源来自最近一个让我非常不爽的事. 我最近在改一个开源RNN工具包currennt(http://sourceforge.net/projects/currennt/),想用它实现RNNLM功能 ...

  3. MYSQL中GROUP BY不包含所有的非聚合字段时的注意事项

    本文导读:在MYSQL中使用GROUP BY分组时,我们可以select 多个非聚合字段,但是这些字段不在GROUP BY中,这样的SQL查询在SQL SERVER.ORACLE中是不合理的,且会报错 ...

  4. ElasticSearch在linux上安装部署(转)

    一.安装准备工作安装参考文档: ELK官网:https://www.elastic.co/ ELK官网文档:https://www.elastic.co/guide/index.html ELK中文手 ...

  5. java8新特性之Optional类

    NullPointException可以说是所有java程序员都遇到过的一个异常,虽然java从设计之初就力图让程序员脱离指针的苦海,但是指针确实是实际存在的,而java设计者也只能是让指针在java ...

  6. CodeForces - 779D String Game 常规二分

    题意:给你两个串,S2是S1 的一个子串(可以不连续).给你一个s1字符下标的一个排列,按照这个数列删数,问你最多删到第几个时S2仍是S1 的一个子串. 题解:二分删掉的数.判定函数很好写和单调性也可 ...

  7. LAMP下安装zabbix流水

    一.安装zabbix (1)创建用户和组 [root@dbking zabbix-2.2.1]# groupadd zabbix [root@dbking zabbix-2.2.1]# useradd ...

  8. Oracle备份恢复之无备份情况下恢复undo表空间

    UNDO表空间存储着DML操作数据块的前镜像数据,在数据回滚,一致性读,闪回操作,实例恢复的时候都可能用到UNDO表空间中的数据.如果在生产过程中丢失或破坏了UNDO表空间,可能导致某些事务无法回滚, ...

  9. Oracle安装部署之dbca静默建库和删除库

    dbca查看帮助: [oracle@wen ~]$ dbca -help 1).运行静默建库语句 [oracle@wen ~]$ dbca -silent -cloneTemplate -gdbNam ...

  10. Recv-Q&Send-Q

    最近线上某些服务器老是报cpu load高,同机房其他机器却没有问题.排查发现以下异常 ss -nl Recv-Q Send-Q                 Local Address:Port ...