http://fantasticinblur.iteye.com/blog/1465497

课程作业要求实现一个BPNN。这次尝试使用Java实现了一个。现共享之。版权属于大家。关于BPNN的原理,就不赘述了。

下面是BPNN的实现代码。类名为BP。

  1. package ml;
  2. import java.util.Random;
  3. /**
  4. * BPNN.
  5. *
  6. * @author RenaQiu
  7. *
  8. */
  9. public class BP {
  10. /**
  11. * input vector.
  12. */
  13. private final double[] input;
  14. /**
  15. * hidden layer.
  16. */
  17. private final double[] hidden;
  18. /**
  19. * output layer.
  20. */
  21. private final double[] output;
  22. /**
  23. * target.
  24. */
  25. private final double[] target;
  26. /**
  27. * delta vector of the hidden layer .
  28. */
  29. private final double[] hidDelta;
  30. /**
  31. * output layer of the output layer.
  32. */
  33. private final double[] optDelta;
  34. /**
  35. * learning rate.
  36. */
  37. private final double eta;
  38. /**
  39. * momentum.
  40. */
  41. private final double momentum;
  42. /**
  43. * weight matrix from input layer to hidden layer.
  44. */
  45. private final double[][] iptHidWeights;
  46. /**
  47. * weight matrix from hidden layer to output layer.
  48. */
  49. private final double[][] hidOptWeights;
  50. /**
  51. * previous weight update.
  52. */
  53. private final double[][] iptHidPrevUptWeights;
  54. /**
  55. * previous weight update.
  56. */
  57. private final double[][] hidOptPrevUptWeights;
  58. public double optErrSum = 0d;
  59. public double hidErrSum = 0d;
  60. private final Random random;
  61. /**
  62. * Constructor.
  63. * <p>
  64. * <strong>Note:</strong> The capacity of each layer will be the parameter
  65. * plus 1. The additional unit is used for smoothness.
  66. * </p>
  67. *
  68. * @param inputSize
  69. * @param hiddenSize
  70. * @param outputSize
  71. * @param eta
  72. * @param momentum
  73. * @param epoch
  74. */
  75. public BP(int inputSize, int hiddenSize, int outputSize, double eta,
  76. double momentum) {
  77. input = new double[inputSize + 1];
  78. hidden = new double[hiddenSize + 1];
  79. output = new double[outputSize + 1];
  80. target = new double[outputSize + 1];
  81. hidDelta = new double[hiddenSize + 1];
  82. optDelta = new double[outputSize + 1];
  83. iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
  84. hidOptWeights = new double[hiddenSize + 1][outputSize + 1];
  85. random = new Random(19881211);
  86. randomizeWeights(iptHidWeights);
  87. randomizeWeights(hidOptWeights);
  88. iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
  89. hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];
  90. this.eta = eta;
  91. this.momentum = momentum;
  92. }
  93. private void randomizeWeights(double[][] matrix) {
  94. for (int i = 0, len = matrix.length; i != len; i++)
  95. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  96. double real = random.nextDouble();
  97. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  98. }
  99. }
  100. /**
  101. * Constructor with default eta = 0.25 and momentum = 0.3.
  102. *
  103. * @param inputSize
  104. * @param hiddenSize
  105. * @param outputSize
  106. * @param epoch
  107. */
  108. public BP(int inputSize, int hiddenSize, int outputSize) {
  109. this(inputSize, hiddenSize, outputSize, 0.25, 0.9);
  110. }
  111. /**
  112. * Entry method. The train data should be a one-dim vector.
  113. *
  114. * @param trainData
  115. * @param target
  116. */
  117. public void train(double[] trainData, double[] target) {
  118. loadInput(trainData);
  119. loadTarget(target);
  120. forward();
  121. calculateDelta();
  122. adjustWeight();
  123. }
  124. /**
  125. * Test the BPNN.
  126. *
  127. * @param inData
  128. * @return
  129. */
  130. public double[] test(double[] inData) {
  131. if (inData.length != input.length - 1) {
  132. throw new IllegalArgumentException("Size Do Not Match.");
  133. }
  134. System.arraycopy(inData, 0, input, 1, inData.length);
  135. forward();
  136. return getNetworkOutput();
  137. }
  138. /**
  139. * Return the output layer.
  140. *
  141. * @return
  142. */
  143. private double[] getNetworkOutput() {
  144. int len = output.length;
  145. double[] temp = new double[len - 1];
  146. for (int i = 1; i != len; i++)
  147. temp[i - 1] = output[i];
  148. return temp;
  149. }
  150. /**
  151. * Load the target data.
  152. *
  153. * @param arg
  154. */
  155. private void loadTarget(double[] arg) {
  156. if (arg.length != target.length - 1) {
  157. throw new IllegalArgumentException("Size Do Not Match.");
  158. }
  159. System.arraycopy(arg, 0, target, 1, arg.length);
  160. }
  161. /**
  162. * Load the training data.
  163. *
  164. * @param inData
  165. */
  166. private void loadInput(double[] inData) {
  167. if (inData.length != input.length - 1) {
  168. throw new IllegalArgumentException("Size Do Not Match.");
  169. }
  170. System.arraycopy(inData, 0, input, 1, inData.length);
  171. }
  172. /**
  173. * Forward.
  174. *
  175. * @param layer0
  176. * @param layer1
  177. * @param weight
  178. */
  179. private void forward(double[] layer0, double[] layer1, double[][] weight) {
  180. // threshold unit.
  181. layer0[0] = 1.0;
  182. for (int j = 1, len = layer1.length; j != len; ++j) {
  183. double sum = 0;
  184. for (int i = 0, len2 = layer0.length; i != len2; ++i)
  185. sum += weight[i][j] * layer0[i];
  186. layer1[j] = sigmoid(sum);
  187. }
  188. }
  189. /**
  190. * Forward.
  191. */
  192. private void forward() {
  193. forward(input, hidden, iptHidWeights);
  194. forward(hidden, output, hidOptWeights);
  195. }
  196. /**
  197. * Calculate output error.
  198. */
  199. private void outputErr() {
  200. double errSum = 0;
  201. for (int idx = 1, len = optDelta.length; idx != len; ++idx) {
  202. double o = output[idx];
  203. optDelta[idx] = o * (1d - o) * (target[idx] - o);
  204. errSum += Math.abs(optDelta[idx]);
  205. }
  206. optErrSum = errSum;
  207. }
  208. /**
  209. * Calculate hidden errors.
  210. */
  211. private void hiddenErr() {
  212. double errSum = 0;
  213. for (int j = 1, len = hidDelta.length; j != len; ++j) {
  214. double o = hidden[j];
  215. double sum = 0;
  216. for (int k = 1, len2 = optDelta.length; k != len2; ++k)
  217. sum += hidOptWeights[j][k] * optDelta[k];
  218. hidDelta[j] = o * (1d - o) * sum;
  219. errSum += Math.abs(hidDelta[j]);
  220. }
  221. hidErrSum = errSum;
  222. }
  223. /**
  224. * Calculate errors of all layers.
  225. */
  226. private void calculateDelta() {
  227. outputErr();
  228. hiddenErr();
  229. }
  230. /**
  231. * Adjust the weight matrix.
  232. *
  233. * @param delta
  234. * @param layer
  235. * @param weight
  236. * @param prevWeight
  237. */
  238. private void adjustWeight(double[] delta, double[] layer,
  239. double[][] weight, double[][] prevWeight) {
  240. layer[0] = 1;
  241. for (int i = 1, len = delta.length; i != len; ++i) {
  242. for (int j = 0, len2 = layer.length; j != len2; ++j) {
  243. double newVal = momentum * prevWeight[j][i] + eta * delta[i]
  244. * layer[j];
  245. weight[j][i] += newVal;
  246. prevWeight[j][i] = newVal;
  247. }
  248. }
  249. }
  250. /**
  251. * Adjust all weight matrices.
  252. */
  253. private void adjustWeight() {
  254. adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
  255. adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
  256. }
  257. /**
  258. * Sigmoid.
  259. *
  260. * @param val
  261. * @return
  262. */
  263. private double sigmoid(double val) {
  264. return 1d / (1d + Math.exp(-val));
  265. }
  266. }

为了验证正确性,我写了一个测试用例,目的是对于任意的整数(int型),BPNN在经过训练之后,能够准确地判断出它是奇数还是偶数,正数还是负数。首先对于训练的样本(是随机生成的数字),将它转化为一个32位的向量,向量的每个分量就是其二进制形式对应的位上的0或1。将目标输出视作一个4维的向量,[1,0,0,0]代表正奇数,[0,1,0,0]代表正偶数,[0,0,1,0]代表负奇数,[0,0,0,1]代表负偶数。

训练样本为1000个,学习200次。

  1. package ml;
  2. import java.io.IOException;
  3. import java.util.ArrayList;
  4. import java.util.List;
  5. import java.util.Random;
  6. public class Test {
  7. /**
  8. * @param args
  9. * @throws IOException
  10. */
  11. public static void main(String[] args) throws IOException {
  12. BP bp = new BP(32, 15, 4);
  13. Random random = new Random();
  14. List<Integer> list = new ArrayList<Integer>();
  15. for (int i = 0; i != 1000; i++) {
  16. int value = random.nextInt();
  17. list.add(value);
  18. }
  19. for (int i = 0; i != 200; i++) {
  20. for (int value : list) {
  21. double[] real = new double[4];
  22. if (value >= 0)
  23. if ((value & 1) == 1)
  24. real[0] = 1;
  25. else
  26. real[1] = 1;
  27. else if ((value & 1) == 1)
  28. real[2] = 1;
  29. else
  30. real[3] = 1;
  31. double[] binary = new double[32];
  32. int index = 31;
  33. do {
  34. binary[index--] = (value & 1);
  35. value >>>= 1;
  36. } while (value != 0);
  37. bp.train(binary, real);
  38. }
  39. }
  40. System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");
  41. while (true) {
  42. byte[] input = new byte[10];
  43. System.in.read(input);
  44. Integer value = Integer.parseInt(new String(input).trim());
  45. int rawVal = value;
  46. double[] binary = new double[32];
  47. int index = 31;
  48. do {
  49. binary[index--] = (value & 1);
  50. value >>>= 1;
  51. } while (value != 0);
  52. double[] result = bp.test(binary);
  53. double max = -Integer.MIN_VALUE;
  54. int idx = -1;
  55. for (int i = 0; i != result.length; i++) {
  56. if (result[i] > max) {
  57. max = result[i];
  58. idx = i;
  59. }
  60. }
  61. switch (idx) {
  62. case 0:
  63. System.out.format("%d是一个正奇数\n", rawVal);
  64. break;
  65. case 1:
  66. System.out.format("%d是一个正偶数\n", rawVal);
  67. break;
  68. case 2:
  69. System.out.format("%d是一个负奇数\n", rawVal);
  70. break;
  71. case 3:
  72. System.out.format("%d是一个负偶数\n", rawVal);
  73. break;
  74. }
  75. }
  76. }
  77. }

运行结果截图如下:


 这个测试的例子非常简单。大家可以根据自己的需要去使用BP这个类。

BP神经网络的Java实现(转)的更多相关文章

  1. BP神经网络的Java实现(转载)

    神经网络的计算过程 神经网络结构如下图所示,最左边的是输入层,最右边的是输出层,中间是多个隐含层,隐含层和输出层的每个神经节点,都是由上一层节点乘以其权重累加得到,标上“+1”的圆圈为截距项b,对输入 ...

  2. BP神经网络的Java实现

    http://fantasticinblur.iteye.com/blog/1465497

  3. BP神经网络的直观推导与Java实现

    人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...

  4. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  5. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  6. BP神经网络—java实现

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  7. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  8. 数据挖掘系列(9)——BP神经网络算法与实践

    神经网络曾经很火,有过一段低迷期,现在因为深度学习的原因继续火起来了.神经网络有很多种:前向传输网络.反向传输网络.递归神经网络.卷积神经网络等.本文介绍基本的反向传输神经网络(Backpropaga ...

  9. BP神经网络的数学原理及其算法实现

    什么是BP网络 BP网络的数学原理 BP网络算法实现 转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/44514073  上一篇 ...

随机推荐

  1. vue-resource使用 (vue仿百度搜索)

    1.this.$http.get()方法2.this.$http.post()方法3.this.$http.jsonp()方法 (vue仿百度搜索) 在输入框中输入a, 然后在百度f12 ==> ...

  2. 为gitlab10.x增加使用remote_user HTTP头的方式登录

    项目的结构是这样的: 客户端通过Apache来访问后端的gitlab(gitlab的版本是10.4,手动从源码安装的简体中文版) , Apache作为gitlab的反向代理服务器 Apache内置了C ...

  3. (TOJ 4413)IP address

    描述 To give you an IP address, it may be dotted decimal IP address, it may be 32-bit binary IP addres ...

  4. h5页面弹窗滚动穿透的思考

    可能我们经常做这样的弹窗对吧,兴许我们绝对很简单,两下搞定: 弹窗的页面结构代码: <!-- 弹窗模块 引用时移除static_tip类--> <div class="ma ...

  5. oracle的日期相减

    oracle的日期相减 : 两个date类型的 日期相减,得到的是天数,可能是带小数点的.如下:

  6. Shell转义字符与变量替换

    转义字符 含义 \\ 反斜杠 \a 警报,响铃 \b 退格(删除键) \f 换页(FF),将当前位置移到下页开头 \n 换行 \r 回车 \t 水平制表符(tab键)  \v 垂直制表符 vim te ...

  7. R的transform

    函数transform 作用:为原数据框添加新的列,改变原变量列的值,通过赋值NULL删除列变量 用法: transform(‘data’,….) data就是要修改的data,  '…..'代表你要 ...

  8. 9.10Django模板

    2018-9-10 16:37:29 模板就一个 不能嵌套 模板:  http://www.cnblogs.com/liwenzhou/p/7931828.html 2018-9-10 21:23:3 ...

  9. ELK之生产日志收集构架(filebeat-logstash-redis-logstash-elasticsearch-kibana)

    本次构架图如下 说明: 1,前端服务器只启动轻量级日志收集工具filebeat(不需要JDK环境) 2,收集的日志不进过处理直接发送到redis消息队列 3,redis消息队列只是暂时存储日志数据,不 ...

  10. 详解Oracle多种表连接方式

    1. 内连接(自然连接) 2. 外连接 (1)左外连接 (左边的表不加限制) (2)右外连接(右边的表不加限制) (3)全外连接(左右两表都不加限制) 3. 自连接(同一张表内的连接) SQL的标准语 ...