根据前篇博文《神经网络之后向传播算法》,现在用java实现一个bp神经网络。矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法。

上帝说,要有神经网络,于是,便有了一个神经网络。上帝还说,神经网络要有节点,权重,激活函数,输出函数,目标函数,然后也许还要有一个准确率函数,于是,神经网络完成了:

  1. public class Net {
  2. List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
  3. List<DoubleMatrix> bs = new ArrayList<>();
  4. List<ScalarDifferentiableFunction> activations = new ArrayList<>();
  5. CostFunctionFactory costFunc;
  6. CostFunctionFactory accuracyFunc;
  7. int[] nodesNum;
  8. int layersNum;
  9. public Net(int[] nodesNum, ScalarDifferentiableFunction[] activations,CostFunctionFactory costFunc) {
  10. super();
  11. this.initNet(nodesNum, activations);
  12. this.costFunc=costFunc;
  13. this.layersNum=nodesNum.length-1;
  14. }
  15.  
  16. public Net(int[] nodesNum, ScalarDifferentiableFunction[] activations,CostFunctionFactory costFunc,CostFunctionFactory accuracyFunc) {
  17. this(nodesNum,activations,costFunc);
  18. this.accuracyFunc=accuracyFunc;
  19. }
  20. public void resetNet() {
  21. this.initNet(nodesNum, (ScalarDifferentiableFunction[]) activations.toArray());
  22. }
  23.  
  24. private void initNet(int[] nodesNum, ScalarDifferentiableFunction[] activations) {
  25. assert (nodesNum != null && activations != null
  26. && nodesNum.length == activations.length + 1 && nodesNum.length > 1);
  27. this.nodesNum = nodesNum;
  28. this.weights.clear();
  29. this.bs.clear();
  30. this.activations.clear();
  31. for (int i = 0; i < nodesNum.length - 1; i++) {
  32. // 列数==输入;行数==输出。
  33. int columns = nodesNum[i];
  34. int rows = nodesNum[i + 1];
  35. double r1 = Math.sqrt(6) / Math.sqrt(rows + columns + 1);
  36. //r1=0.001;
  37. // W
  38. DoubleMatrix weight = DoubleMatrix.rand(rows, columns).muli(2*r1).subi(r1);
  39. //weight=DoubleMatrix.ones(rows, columns);
  40. weights.add(weight);
  41.  
  42. // b
  43. DoubleMatrix b = DoubleMatrix.zeros(rows, 1);
  44. bs.add(b);
  45.  
  46. // activations
  47. this.activations.add(activations[i]);
  48. }
  49. }
  50. }

上帝造完了神经网络,去休息了。人说,我要使用神经网络,我要利用正向传播计算各层的结果,然后利用反向传播调整网络的状态,最后,我要让它能告诉我猎物在什么方向,花儿为什么这样香。

  1. public class Propagation {
  2. Net net;
  3.  
  4. public Propagation(Net net) {
  5. super();
  6. this.net = net;
  7. }
  8.  
  9. // 多个样本。
  10. public ForwardResult forward(DoubleMatrix input) {
  11.  
  12. ForwardResult result = new ForwardResult();
  13. result.input = input;
  14. DoubleMatrix currentResult = input;
  15. int index = -1;
  16. for (DoubleMatrix weight : net.weights) {
  17. index++;
  18. DoubleMatrix b = net.bs.get(index);
  19. final ScalarDifferentiableFunction activation = net.activations
  20. .get(index);
  21. currentResult = weight.mmul(currentResult).addColumnVector(b);
  22. result.netResult.add(currentResult);
  23.  
  24. // 乘以导数
  25. DoubleMatrix derivative = MatrixUtil.applyNewElements(
  26. new ScalarFunction() {
  27. @Override
  28. public double valueAt(double x) {
  29. return activation.derivativeAt(x);
  30. }
  31.  
  32. }, currentResult);
  33.  
  34. currentResult = MatrixUtil.applyNewElements(activation,
  35. currentResult);
  36. result.finalResult.add(currentResult);
  37.  
  38. result.derivativeResult.add(derivative);
  39. }
  40.  
  41. result.netResult=null;// 不再需要。
  42.  
  43. return result;
  44. }
  45.  
  46. // 多个样本梯度平均值。
  47. public BackwardResult backward(DoubleMatrix target,
  48. ForwardResult forwardResult) {
  49. BackwardResult result = new BackwardResult();
  50. DoubleMatrix cost = DoubleMatrix.zeros(1,target.columns);
  51. DoubleMatrix output = forwardResult.finalResult
  52. .get(forwardResult.finalResult.size() - 1);
  53. DoubleMatrix outputDelta = DoubleMatrix.zeros(output.rows,
  54. output.columns);
  55. DoubleMatrix outputDerivative = forwardResult.derivativeResult
  56. .get(forwardResult.derivativeResult.size() - 1);
  57.  
  58. DoubleMatrix accuracy = null;
  59. if (net.accuracyFunc != null) {
  60. accuracy = DoubleMatrix.zeros(1,target.columns);
  61. }
  62.  
  63. for (int i = 0; i < target.columns; i++) {
  64. CostFunction costFunc = net.costFunc.create(target.getColumn(i)
  65. .toArray());
  66. cost.put(i, costFunc.valueAt(output.getColumn(i).toArray()));
  67. // System.out.println(i);
  68. DoubleMatrix column1 = new DoubleMatrix(
  69. costFunc.derivativeAt(output.getColumn(i).toArray()));
  70. DoubleMatrix column2 = outputDerivative.getColumn(i);
  71. outputDelta.putColumn(i, column1.muli(column2));
  72.  
  73. if (net.accuracyFunc != null) {
  74. CostFunction accuracyFunc = net.accuracyFunc.create(target
  75. .getColumn(i).toArray());
  76. accuracy.put(i,
  77. accuracyFunc.valueAt(output.getColumn(i).toArray()));
  78. }
  79. }
  80. result.deltas.add(outputDelta);
  81. result.cost = cost;
  82. result.accuracy = accuracy;
  83. for (int i = net.layersNum - 1; i >= 0; i--) {
  84. DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);
  85.  
  86. // 梯度计算,取所有样本平均
  87. DoubleMatrix layerInput = i == 0 ? forwardResult.input
  88. : forwardResult.finalResult.get(i - 1);
  89. DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
  90. target.columns);
  91. result.gradients.add(gradient);
  92. // 偏置梯度
  93. result.biasGradients.add(pdelta.rowMeans());
  94.  
  95. // 计算前一层delta,若i=0,delta为输入层误差,即input调整梯度,不作平均处理。
  96. DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
  97. if (i > 0)
  98. delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
  99. result.deltas.add(delta);
  100. }
  101. Collections.reverse(result.gradients);
  102. Collections.reverse(result.biasGradients);
  103.  
  104. //其它的delta都不需要。
  105. DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
  106. result.deltas.clear();
  107. result.deltas.add(inputDeltas);
  108.  
  109. return result;
  110. }
  111.  
  112. public Net getNet() {
  113. return net;
  114. }
  115.  
  116. }

上面是一次正向/反向传播的具体代码。训练方式为批量训练,即所有样本一起训练。然而我们可以传入只有一列的input/target样本实现adapt方式的串行训练,也可以把样本分成很多批传入实现mini-batch方式的训练,这,不是Propagation要考虑的事情,它只是忠实的把传入的数据正向过一遍,反向过一遍,然后把过后的数据原封不动的返回给你。至于传入什么,以及结果怎么运用,是Trainer和Learner要做的事情。下回分解。

用java写bp神经网络(一)的更多相关文章

  1. 用java写bp神经网络(四)

    接上篇. 在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider.这篇重构这个体系. Net 首先是Net,在上篇重新定义了激活函数和 ...

  2. 用java写bp神经网络(三)

    孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...

  3. 用java写bp神经网络(二)

    接上篇. Net和Propagation具备后,我们就可以训练了.训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量):什么时候调用学习师根据训练的结果进 ...

  4. python手写bp神经网络实现人脸性别识别1.0

    写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...

  5. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  6. java写卷积神经网络---CupCnn简介

    https://blog.csdn.net/u011913612/article/details/79253450

  7. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  8. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  9. BP神经网络的直观推导与Java实现

    人工神经网络模拟人体对于外界刺激的反应.某种刺激经过人体多层神经细胞传递后,可以触发人脑中特定的区域做出反应.人体神经网络的作用就是把某种刺激与大脑中的特定区域关联起来了,这样我们对于不同的刺激就可以 ...

随机推荐

  1. IBM Websphere 说明文档

    http://pic.dhe.ibm.com/infocenter/wasinfo/v6r1/index.jsp?topic=%2Fcom.ibm.websphere.nd.doc%2Finfo%2F ...

  2. 两个div之间有空隙

    加句*{ margin:0; padding:0;} 最近在做网页时发现,在IE7下(FF没试过),div与div之间有时会出20个像素左右的空隙,除非把margin设成负值,否则空隙无法去除.我在 ...

  3. UVA- 1504 - Genghis Khan the Conqueror(最小生成树-好题)

    题意: n个点,m个边,然后给出m条边的顶点和权值,其次是q次替换,每次替换一条边,给出每次替换的边的顶点和权值,然后求出这次替换的最小生成树的值; 最后要你输出:q次替换的平均值.其中n<30 ...

  4. 【动态规划】Vijos P1616 迎接仪式

    题目链接: https://vijos.org/p/1616 题目大意: 长度为N的字符串,只含‘j’和‘z’,可以将任意两个字符调换K次,求能够拥有的最多的'jz'串. 题目思路: [动态规划] 首 ...

  5. UVa1658 Admiral(拆点法+最小费用流)

    题目链接:http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id=51253 [思路] 固定流量的最小费用流. 拆点,将u拆分成u1和u ...

  6. python 解析 配置文件

    资料: https://docs.python.org/3/library/configparser.html 环境 python 3.4.4 RawConfigParser方式 example.cf ...

  7. XDocument和XmlDocument的区别

    刚开始使用Xml的时候,没有注意到XDocument和XmlDocument的区别,后来发现两者还是有一些不同的. XDocument和XmlDocument都可以用来操作XML文档,XDocumen ...

  8. VS2008 error C2470

    error C2470: '***类' : looks like a function definition, but there is no parameter list; skipping app ...

  9. 页游AS客户端架构设计历程记录

    以下是一个只用JAVA做过服务器架构的程序员做的AS客户端架构,希望大家能推荐好的框架和意见,也求AS高程们的引导,等到基本功能成形后,低调开源,框架可以支持一个中度型页游的开发,本文不断更新中... ...

  10. linux网络编程学习笔记之五 -----并发机制与线程�

    进程线程分配方式 简述下常见的进程和线程分配方式:(好吧,我仅仅是举几个样例作为笔记...并发的水太深了,不敢妄谈...) 1.进程线程预分配 简言之,当I/O开销大于计算开销且并发量较大时,为了节省 ...