http://fantasticinblur.iteye.com/blog/1465497

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

  1. package ml;
  2. import java.util.Random;
  3. /**
  4. * BPNN.
  5. *
  6. * @author RenaQiu
  7. *
  8. */
  9. public class BP {
  10. /**
  11. * input vector.
  12. */
  13. private final double[] input;
  14. /**
  15. * hidden layer.
  16. */
  17. private final double[] hidden;
  18. /**
  19. * output layer.
  20. */
  21. private final double[] output;
  22. /**
  23. * target.
  24. */
  25. private final double[] target;
  26. /**
  27. * delta vector of the hidden layer .
  28. */
  29. private final double[] hidDelta;
  30. /**
  31. * output layer of the output layer.
  32. */
  33. private final double[] optDelta;
  34. /**
  35. * learning rate.
  36. */
  37. private final double eta;
  38. /**
  39. * momentum.
  40. */
  41. private final double momentum;
  42. /**
  43. * weight matrix from input layer to hidden layer.
  44. */
  45. private final double[][] iptHidWeights;
  46. /**
  47. * weight matrix from hidden layer to output layer.
  48. */
  49. private final double[][] hidOptWeights;
  50. /**
  51. * previous weight update.
  52. */
  53. private final double[][] iptHidPrevUptWeights;
  54. /**
  55. * previous weight update.
  56. */
  57. private final double[][] hidOptPrevUptWeights;
  58. public double optErrSum = 0d;
  59. public double hidErrSum = 0d;
  60. private final Random random;
  61. /**
  62. * Constructor.
  63. * <p>
  64. * <strong>Note:</strong> The capacity of each layer will be the parameter
  65. * plus 1. The additional unit is used for smoothness.
  66. * </p>
  67. *
  68. * @param inputSize
  69. * @param hiddenSize
  70. * @param outputSize
  71. * @param eta
  72. * @param momentum
  73. * @param epoch
  74. */
  75. public BP(int inputSize, int hiddenSize, int outputSize, double eta,
  76. double momentum) {
  77. input = new double[inputSize + 1];
  78. hidden = new double[hiddenSize + 1];
  79. output = new double[outputSize + 1];
  80. target = new double[outputSize + 1];
  81. hidDelta = new double[hiddenSize + 1];
  82. optDelta = new double[outputSize + 1];
  83. iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
  84. hidOptWeights = new double[hiddenSize + 1][outputSize + 1];
  85. random = new Random(19881211);
  86. randomizeWeights(iptHidWeights);
  87. randomizeWeights(hidOptWeights);
  88. iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
  89. hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];
  90. this.eta = eta;
  91. this.momentum = momentum;
  92. }
  93. private void randomizeWeights(double[][] matrix) {
  94. for (int i = 0, len = matrix.length; i != len; i++)
  95. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  96. double real = random.nextDouble();
  97. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  98. }
  99. }
  100. /**
  101. * Constructor with default eta = 0.25 and momentum = 0.3.
  102. *
  103. * @param inputSize
  104. * @param hiddenSize
  105. * @param outputSize
  106. * @param epoch
  107. */
  108. public BP(int inputSize, int hiddenSize, int outputSize) {
  109. this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
  110. }
  111. /**
  112. * Entry method. The train data should be a one-dim vector.
  113. *
  114. * @param trainData
  115. * @param target
  116. */
  117. public void train(double[] trainData, double[] target) {
  118. loadInput(trainData);
  119. loadTarget(target);
  120. forward();
  121. calculateDelta();
  122. adjustWeight();
  123. }
  124. /**
  125. * Test the BPNN.
  126. *
  127. * @param inData
  128. * @return
  129. */
  130. public double[] test(double[] inData) {
  131. if (inData.length != input.length - 1) {
  132. throw new IllegalArgumentException("Size Do Not Match.");
  133. }
  134. System.arraycopy(inData, 0, input, 1, inData.length);
  135. forward();
  136. return getNetworkOutput();
  137. }
  138. /**
  139. * Return the output layer.
  140. *
  141. * @return
  142. */
  143. private double[] getNetworkOutput() {
  144. int len = output.length;
  145. double[] temp = new double[len - 1];
  146. for (int i = 1; i != len; i++)
  147. temp[i - 1] = output[i];
  148. return temp;
  149. }
  150. /**
  151. * Load the target data.
  152. *
  153. * @param arg
  154. */
  155. private void loadTarget(double[] arg) {
  156. if (arg.length != target.length - 1) {
  157. throw new IllegalArgumentException("Size Do Not Match.");
  158. }
  159. System.arraycopy(arg, 0, target, 1, arg.length);
  160. }
  161. /**
  162. * Load the training data.
  163. *
  164. * @param inData
  165. */
  166. private void loadInput(double[] inData) {
  167. if (inData.length != input.length - 1) {
  168. throw new IllegalArgumentException("Size Do Not Match.");
  169. }
  170. System.arraycopy(inData, 0, input, 1, inData.length);
  171. }
  172. /**
  173. * Forward.
  174. *
  175. * @param layer0
  176. * @param layer1
  177. * @param weight
  178. */
  179. private void forward(double[] layer0, double[] layer1, double[][] weight) {
  180. // threshold unit.
  181. layer0[0] = 1.0;
  182. for (int j = 1, len = layer1.length; j != len; ++j) {
  183. double sum = 0;
  184. for (int i = 0, len2 = layer0.length; i != len2; ++i)
  185. sum += weight[i][j] * layer0[i];
  186. layer1[j] = sigmoid(sum);
  187. }
  188. }
  189. /**
  190. * Forward.
  191. */
  192. private void forward() {
  193. forward(input, hidden, iptHidWeights);
  194. forward(hidden, output, hidOptWeights);
  195. }
  196. /**
  197. * Calculate output error.
  198. */
  199. private void outputErr() {
  200. double errSum = 0;
  201. for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
  202. double o = output[idx];
  203. optDelta[idx] = o * (1d - o) * (target[idx] - o);
  204. errSum += Math.abs(optDelta[idx]);
  205. }
  206. optErrSum = errSum;
  207. }
  208. /**
  209. * Calculate hidden errors.
  210. */
  211. private void hiddenErr() {
  212. double errSum = 0;
  213. for (int j = 1, len = hidDelta.length; j != len; ++j) {
  214. double o = hidden[j];
  215. double sum = 0;
  216. for (int k = 1, len2 = optDelta.length; k != len2; ++k)
  217. sum += hidOptWeights[j][k] * optDelta[k];
  218. hidDelta[j] = o * (1d - o) * sum;
  219. errSum += Math.abs(hidDelta[j]);
  220. }
  221. hidErrSum = errSum;
  222. }
  223. /**
  224. * Calculate errors of all layers.
  225. */
  226. private void calculateDelta() {
  227. outputErr();
  228. hiddenErr();
  229. }
  230. /**
  231. * Adjust the weight matrix.
  232. *
  233. * @param delta
  234. * @param layer
  235. * @param weight
  236. * @param prevWeight
  237. */
  238. private void adjustWeight(double[] delta, double[] layer,
  239. double[][] weight, double[][] prevWeight) {
  240. layer[0] = 1;
  241. for (int i = 1, len = delta.length; i != len; ++i) {
  242. for (int j = 0, len2 = layer.length; j != len2; ++j) {
  243. double newVal = momentum * prevWeight[j][i] + eta * delta[i]
  244. * layer[j];
  245. weight[j][i] += newVal;
  246. prevWeight[j][i] = newVal;
  247. }
  248. }
  249. }
  250. /**
  251. * Adjust all weight matrices.
  252. */
  253. private void adjustWeight() {
  254. adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
  255. adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
  256. }
  257. /**
  258. * Sigmoid.
  259. *
  260. * @param val
  261. * @return
  262. */
  263. private double sigmoid(double val) {
  264. return 1d / (1d + Math.exp(-val));
  265. }
  266. }

为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

  1. package ml;
  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.List;
  5. import java.util.Random;
  6. public class Test {
  7. /**
  8. * @param args
  9. * @throws IOException
  10. */
  11. public static void main(String[] args) throws IOException {
  12. BP bp = new BP(32, 15, 4);
  13. Random random = new Random();
  14. List<Integer> list = new ArrayList<Integer>();
  15. for (int i = 0; i != 1000; i++) {
  16. int value = random.nextInt();
  17. list.add(value);
  18. }
  19. for (int i = 0; i != 200; i++) {
  20. for (int value : list) {
  21. double[] real = new double[4];
  22. if (value >= 0)
  23. if ((value & 1) == 1)
  24. real[0] = 1;
  25. else
  26. real[1] = 1;
  27. else if ((value & 1) == 1)
  28. real[2] = 1;
  29. else
  30. real[3] = 1;
  31. double[] binary = new double[32];
  32. int index = 31;
  33. do {
  34. binary[index--] = (value & 1);
  35. value >>>= 1;
  36. } while (value != 0);
  37. bp.train(binary, real);
  38. }
  39. }
  40. System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");
  41. while (true) {
  42. byte[] input = new byte[10];
  43. System.in.read(input);
  44. Integer value = Integer.parseInt(new String(input).trim());
  45. int rawVal = value;
  46. double[] binary = new double[32];
  47. int index = 31;
  48. do {
  49. binary[index--] = (value & 1);
  50. value >>>= 1;
  51. } while (value != 0);
  52. double[] result = bp.test(binary);
  53. double max = -Integer.MIN_VALUE;
  54. int idx = -1;
  55. for (int i = 0; i != result.length; i++) {
  56. if (result[i] > max) {
  57. max = result[i];
  58. idx = i;
  59. }
  60. }
  61. switch (idx) {
  62. case 0:
  63. System.out.format("%d是一个正奇数\n", rawVal);
  64. break;
  65. case 1:
  66. System.out.format("%d是一个正偶数\n", rawVal);
  67. break;
  68. case 2:
  69. System.out.format("%d是一个负奇数\n", rawVal);
  70. break;
  71. case 3:
  72. System.out.format("%d是一个负偶数\n", rawVal);
  73. break;
  74. }
  75. }
  76. }
  77. }

运行结果截图如下:


 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

BP神经网络的Java实现(转)的更多相关文章

  1. BP神经网络的Java实现(转载)

    神经网络的计算过程 神经网络结构如下图所示,最左边的是输入层,最右边的是输出层,中间是多个隐含层,隐含层和输出层的每个神经节点,都是由上一层节点乘以其权重累加得到,标上“+1”的圆圈为截距项b,对输入 ...

  2. BP神经网络的Java实现

    http://fantasticinblur.iteye.com/blog/1465497

  3. BP神经网络的直观推导与Java实现

    人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...

  4. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  5. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  6. BP神经网络—java实现

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  7. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  8. 数据挖掘系列(9)——BP神经网络算法与实践

    神经网络曾经很火,有过一段低迷期,现在因为深度学习的原因继续火起来了.神经网络有很多种:前向传输网络.反向传输网络.递归神经网络.卷积神经网络等.本文介绍基本的反向传输神经网络(Backpropaga ...

  9. BP神经网络的数学原理及其算法实现

    什么是BP网络 BP网络的数学原理 BP网络算法实现 转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073  上一篇 ...

随机推荐

  1. tars环境部署

    author: headsen  chen date: 2018-10-18 12:35:40 注意:依据Git上的tars搭建步骤整理而来 参考: https://max.book118.com/h ...

  2. TypeScript中处理大数字(会丢失后面部分数字)

    为啥要弄这玩意? 最近做数值游戏,需要用到很大的数字,在前端大数字会自动变成e的科学计数法. 有啥问题? 问题: 1. 在传递给服务端时,服务端因为不能处理大数字(怎么就处理不了?!),就想要我传字符 ...

  3. 【JSP】JSP中的Java脚本

    前言 现代Web开发中,在JSP中嵌入Java脚本不是推荐的做法,因为这样 不利于代码的维护.有很多好的,替代的方法避免在JSP中写Java脚本.本文仅做为JSP体系技术的一个了解.     类成员定 ...

  4. virgo-tomcat-server的生产环境线上配置与管理

    Virgo Tomcat Server简称VTS,VTS是一个应用服务器,它是轻量级, 模块化, 基于OSGi系统.与OSGi紧密结合并且可以开发bundles形式的Spring web apps应用 ...

  5. python之traceback

    traceback 模块允许你在程序里打印异常的跟踪返回 (Traceback)信息 1.1 traceback.print_exc() File: traceback-example-1.py # ...

  6. iOS SwiftMonkey 随机暴力测试

    参考源文章 https://github.com/zalando/SwiftMonkey https://kemchenj.github.io/2017/03/16/2017-03-16/ 简介 这个 ...

  7. MongoDB 学习笔记2----条件操作符

    条件操作符:用于两个比较两个表达式并从mongdb中获取文档 mongodb常见的操作符及解析说明 $lt:小于 example:ago<20 $lte:小于等于 example:<=20 ...

  8. ELK之filebate收集日志传递至Logstash

    软件版本查看(版本最好一致) 安装过程不详叙 本次使用filebeat监控nginx日志(已经配置json输出)收集并且传递给Logstash进行处理 filebeat配置文件/etc/filebea ...

  9. Ucloud云主机无法yum安装处理办法

    Ucloud云主机在yum安装的时候出现这个提示 执行一下命令 yum --disablerepo=salttestyum-config-manager --disable salttestyum-c ...

  10. stat命令的实现-mysate

    任务详情 学习使用stat(1),并用C语言实现 提交学习stat(1)的截图 man -k,grep -r的使用 伪代码 产品代码mystate.c,提交码云链接 测试代码,mysate与stat( ...