主体代码

NeuronNetwork.java

  1. package com.rockbb.math.nnetwork;
  2.  
  3. import java.util.ArrayList;
  4. import java.util.Arrays;
  5. import java.util.List;
  6.  
  7. public class NeutonNetwork {
  8. private List<NeuronLayer> layers;
  9.  
  10. public NeuronNetwork(int[] sizes, double bpFactor, Activator activator) {
  11. layers = new ArrayList<>(sizes.length - 1);
  12. int inputSize = sizes[0];
  13. for (int i = 1; i < sizes.length; i++) {
  14. NeuronLayer layer = new NeuronLayer(inputSize, sizes[i], activator, bpFactor);
  15. layers.add(layer);
  16. inputSize = sizes[i];
  17. }
  18. for (int i = 0; i < layers.size() - 1; i++) {
  19. layers.get(i).setNext(layers.get(i + 1));
  20. }
  21. }
  22.  
  23. public List<NeuronLayer> getLayers() {return layers;}
  24. public void setLayers(List<NeuronLayer> layers) {this.layers = layers;}
  25.  
  26. public double getError() {
  27. return layers.get(layers.size() - 1).getError();
  28. }
  29.  
  30. public List<Double> predict(List<Double> inputs) {
  31. List<Double> middle = inputs;
  32. for (int i = 0; i < layers.size(); i++) {
  33. middle = layers.get(i).forward(middle);
  34. }
  35. return middle;
  36. }
  37.  
  38. public void backward() {
  39. for (int j= layers.size() - 1; j >=0; j--) {
  40. layers.get(j).backward();
  41. }
  42. }
  43.  
  44. public void fillTargets(List<Double> targets) {
  45. layers.get(layers.size() - 1).fillTargets(targets);
  46. }
  47.  
  48. @Override
  49. public String toString() {
  50. StringBuilder sb = new StringBuilder();
  51. for (int j = 0; j < layers.size(); j++) {
  52. sb.append(layers.get(j).toString());
  53. }
  54. return sb.toString();
  55. }
  56.  
  57. public static String listToString(List<Double> list) {
  58. StringBuilder sb = new StringBuilder();
  59. for (Double t : list) {
  60. sb.append(String.format("% 10.8f ", t));
  61. }
  62. return sb.toString();
  63. }
  64.  
  65. public static void main(String[] args) {
  66. int[] sz = new int[]{2, 4, 1};
  67. double[][] trainData = {{0d, 0d},{0d, 1d},{1d, 0d},{1d, 1d}};
  68. double[][] targetDate = {{0d},{1d},{1d},{0d}};
  69.  
  70. NeuronNetwork nn = new NeuronNetwork(sz, 0.5d, new SigmoidActivator());
  71. for (int kk = 0; kk < 20000; kk++) {
  72. double totalError = 0d;
  73. for (int i = 0; i < trainData.length; i++) {
  74. List<Double> inputs = Arrays.asList(trainData[i][0], trainData[i][1]);
  75. List<Double> targets = Arrays.asList(targetDate[i][0]);
  76. nn.fillTargets(targets);
  77. nn.predict(inputs);
  78. //System.out.print(nn);
  79. System.out.println(String.format("kk:%5d, i:%d, error: %.8f\n", kk, i, nn.getError()));
  80. totalError += nn.getError();
  81. nn.backward();
  82. }
  83. System.out.println(String.format("kk:%5d, Total Error: %.8f\n\n", kk, totalError));
  84. if (totalError < 0.0001) {
  85. System.out.println(nn);
  86. break;
  87. }
  88. }
  89. System.out.println(nn);
  90. }
  91. }

NeuronLayer.java

  1. package com.rockbb.math.nnetwork;
  2.  
  3. import java.util.ArrayList;
  4. import java.util.List;
  5.  
  6. public class NeuronLayer {
  7. private int inputSize;
  8. private List<Neuron> neurons;
  9. private double bias;
  10. private Activator activator;
  11. private NeuronLayer next;
  12. private double bpFactor;
  13. private List<Double> inputs;
  14.  
  15. public NeuronLayer(int inputSize, int size, Activator activator, double bpFactor) {
  16. this.inputSize = inputSize;
  17. this.activator = activator;
  18. this.bpFactor = bpFactor;
  19. this.bias = Math.random() - 0.5;
  20.  
  21. this.neutrons = new ArrayList<>(size);
  22. for (int i = 0; i < size; i++) {
  23. Neuron neuron = new Neuron(this, inputSize);
  24. neurons.add(neuron);
  25. }
  26. }
  27.  
  28. public int getInputSize() {return inputSize;}
  29. public void setInputSize(int inputSize) {this.inputSize = inputSize;}
  30. public List<Neuron> getNeurons() {return neurons;}
  31. public void setNeurons(List<Neuron> neurons) {this.neurons = neurons;}
  32. public double getBias() {return bias;}
  33. public void setBias(double bias) {this.bias = bias;}
  34. public Activator getActivator() {return activator;}
  35. public void setActivator(Activator activator) {this.activator = activator;}
  36. public NeutronLayer getNext() {return next;}
  37. public void setNext(NeutronLayer next) {this.next = next;}
  38.  
  39. public List<Double> forward(List<Double> inputs) {
  40. this.inputs = inputs;
  41. List<Double> outputs = new ArrayList<Double>(neurons.size());
  42. for (int i = 0; i < neurons.size(); i++) {
  43. outputs.add(0d);
  44. }
  45. for (int i = 0; i < neurons.size(); i++) {
  46. double output = neurons.get(i).forward(inputs);
  47. outputs.set(i, output);
  48. }
  49. return outputs;
  50. }
  51.  
  52. public void backward() {
  53. if (this.next == null) {
  54. // If this is the output layer, calculate delta for each neutron
  55. double totalDelta = 0d;
  56. for (int i = 0; i < neurons.size(); i++) {
  57. Neutron n = neurons.get(i);
  58. double delta = -(n.getTarget() - n.getOutput()) * activator.backwardDelta(n.getOutput());
  59. n.setBpDelta(delta);
  60. totalDelta += delta;
  61. // Reflect to each weight under this neuron
  62. for (int j = 0; j < n.getWeights().size(); j++) {
  63. n.getWeights().set(j, n.getWeights().get(j) - bpFactor * delta * inputs.get(j));
  64. }
  65. }
  66. // Relfect to bias
  67. this.bias = this.bias - bpFactor * totalDelta / neutrons.size();
  68. } else {
  69. // if this is the hidden layer
  70. double totalDelta = 0d;
  71. for (int i = 0; i < neurons.size(); i++) {
  72. Neuron n = neurons.get(i);
  73. List<Neuron> downNeurons = next.getNeurons();
  74. double delta = 0;
  75. for (int j = 0; j < downNeurons.size(); j++) {
  76. delta += downNeurons.get(j).getBpDelta() * downNeurons.get(j).getWeights().get(i);
  77. }
  78. delta = delta * activator.backwardDelta(n.getOutput());
  79. n.setBpDelta(delta);
  80. totalDelta += delta;
  81. // Reflect to each weight under this neuron
  82. for (int j = 0; j < n.getWeights().size(); j++) {
  83. n.getWeights().set(j, n.getWeights().get(j) - bpFactor * delta * inputs.get(j));
  84. }
  85. }
  86. // Relfect to bias
  87. this.bias = this.bias - bpFactor * totalDelta / neutrons.size();
  88. }
  89. }
  90.  
  91. public double getError() {
  92. double totalError = 0d;
  93. for (int i = 0; i < neurons.size(); i++) {
  94. totalError += Math.pow(neurons.get(i).getError(), 2);
  95. }
  96. return totalError / (2 * neurons.size());
  97. }
  98.  
  99. public void fillTargets(List<Double> targets) {
  100. for (int i = 0; i < neurons.size(); i++) {
  101. neurons.get(i).setTarget(targets.get(i));
  102. }
  103. }
  104.  
  105. public double filter(double netInput) {
  106. return activator.forward(netInput + bias);
  107. }
  108.  
  109. @Override
  110. public String toString() {
  111. StringBuilder sb = new StringBuilder();
  112. sb.append(String.format("Input size: %d, bias: %.8f\n", inputSize, bias));
  113. for (int i = 0; i < neurons.size(); i++) {
  114. sb.append(String.format("%3d: %s\n", i, neurons.get(i).toString()));
  115. }
  116. return sb.toString();
  117. }
  118. }

Neuron.java

  1. package com.rockbb.math.nnetwork;
  2.  
  3. import java.util.ArrayList;
  4. import java.util.List;
  5.  
  6. public class Neuron {
  7. private NeuronLayer layer;
  8. private List<Double> weights;
  9. private double output;
  10. private double target;
  11. private double bpDelta;
  12.  
  13. public Neuron(NeuronLayer layer, int inputSize) {
  14. this.layer = layer;
  15. this.weights = new ArrayList<>(inputSize);
  16. for (int i = 0; i < inputSize; i++) {
  17. // Initialize each weight with value [0.1, 1)
  18. weights.add(Math.random() * 0.9 + 0.1);
  19. }
  20. this.bpDelta = 0d;
  21. }
  22.  
  23. public NeuronLayer getLayer() {return layer;}
  24. public void setLayer(NeuronLayer layer) {this.layer = layer;}
  25. public List<Double> getWeights() {return weights;}
  26. public void setWeights(List<Double> weights) {this.weights = weights;}
  27. public double getOutput() {return output;}
  28. public void setOutput(double output) {this.output = output;}
  29. public double getTarget() {return target;}
  30. public void setTarget(double target) {this.target = target;}
  31. public double getBpDelta() {return bpDelta;}
  32. public void setBpDelta(double bpDelta) {this.bpDelta = bpDelta;}
  33.  
  34. public double calcNetInput(List<Double> inputs) {
  35. double netOutput = 0f;
  36. for (int i = 0; i < weights.size(); i++) {
  37. netOutput += inputs.get(i) * weights.get(i);
  38. }
  39. return netOutput;
  40. }
  41.  
  42. public double forward(List<Double> inputs) {
  43. double netInput = calcNetInput(inputs);
  44. this.output = layer.filter(netInput);
  45. return this.output;
  46. }
  47.  
  48. public double getError() {
  49. return target - output;
  50. }
  51.  
  52. @Override
  53. public String toString() {
  54. StringBuilder sb = new StringBuilder();
  55. sb.append(String.format("O:% 10.8f T:% 10.8f D:% 10.8f w:{", output, target, bpDelta));
  56. for (int i = 0; i < weights.size(); i++) {
  57. sb.append(String.format("% 10.8f ", weights.get(i)));
  58. }
  59. sb.append('}');
  60. return sb.toString();
  61. }
  62. }

激活函数

Activator.java

package com.rockbb.math.nnetwork;

/**
 * Activation function contract used by every layer.
 */
public interface Activator {

    /** Applies the activation to the biased net input and returns the neuron output. */
    double forward(double input);

    /**
     * Derivative term used in back-propagation, expressed as a function of the
     * neuron's OUTPUT (not its net input) — see SigmoidActivator, which
     * returns output * (1 - output).
     */
    double backwardDelta(double output);
}

SigmoidActivator.java

  1. package com.rockbb.math.nnetwork;
  2.  
  3. public class SigmoidActivator implements Activator {
  4.  
  5. public double forward(double input) {
  6. return 1 / (1 + Math.exp(-input));
  7. }
  8.  
  9. public double backwardDelta(double output) {
  10. return output * (1 - output);
  11. }
  12. }

在同样的训练数据和误差目标下, 比 http://www.emergentmind.com/neural-network 使用更少的训练次数.

使用Sigmoid激活函数工作正常.

使用ReLU激活函数时总会使某个Neuron冻结(即 "dying ReLU" 现象: 神经元输出恒为0, 梯度无法再回传), 不能收敛, 待检查

Java实现的简单神经网络(基于Sigmoid激活函数)的更多相关文章

  1. JAX-WS 学习一:基于java的最简单的WebService服务

    JAVA 1.6 之后,自带的JAX-WS API,这使得我们可以很方便的开发一个基于Java的WebService服务. 基于JAVA的WebService 服务 1.创建服务端WebService ...

  2. 深度学习原理与框架-神经网络架构 1.神经网络构架 2.激活函数(sigmoid和relu) 3.图片预处理(减去均值和除标准差) 4.dropout(防止过拟合操作)

    神经网络构架:主要时表示神经网络的组成,即中间隐藏层的结构 对图片进行说明:我们可以看出图中的层数分布: input layer表示输入层,维度(N_num, input_dim)  N_num表示输 ...

  3. day-11 python自带库实现2层简单神经网络算法

    深度神经网络算法,是基于神经网络算法的一种拓展,其层数更深,达到多层,本文以简单神经网络为例,利用梯度下降算法进行反向更新来训练神经网络权重和偏向参数,文章最后,基于Python 库实现了一个简单神经 ...

  4. struts1:(Struts重构)构建一个简单的基于MVC模式的JavaWeb

    在构建一个简单的基于MVC模式的JavaWeb 中,我们使用了JSP+Servlet+JavaBean构建了一个基于MVC模式的简单登录系统,但在其小结中已经指出,这种模式下的Controller 和 ...

  5. 一个简单的基于HTTP协议的屏幕共享应用

    HTTP协议可以能是应用层协议里使用最广泛并且用途最多样的一个了.我们一般使用HTTP协议来浏览网页,但是HTTP协议还用来做很多其它用途.对开发人员来讲很常见的一种就是用HTTP协议作为各种版本控制 ...

  6. 如何用70行Java代码实现深度神经网络算法

    http://www.tuicool.com/articles/MfYjQfV 如何用70行Java代码实现深度神经网络算法 时间 2016-02-18 10:46:17  ITeye 原文  htt ...

  7. java实现一个简单的Web服务器

    注:本段内容来源于<JAVA 实现 简单的 HTTP服务器> 1. HTTP所有状态码 状态码 状态码英文名称 中文描述 100 Continue 继续.客户端应继续其请求 101 Swi ...

  8. 最简单的基于FFmpeg的移动端样例:Android 视频解码器-单个库版

    ===================================================== 最简单的基于FFmpeg的移动端样例系列文章列表: 最简单的基于FFmpeg的移动端样例:A ...

  9. Java语言实现简单FTP软件------>FTP软件主界面的实现(四)

    首先看一下该软件的整体代码框架                        1.首先介绍程序的主入口FTPMain.java,采用了一个漂亮的外观风格 package com.oyp.ftp; im ...

随机推荐

  1. Spring Data JPA @Column 注解无效 打出的语句有下划线

    最近再写一个Restful API的小例子,遇到这样一个问题,在Spring Boot 下使用CrudRepository,总是提示如下错误: Caused by: java.sql.SQLSynta ...

  2. 【UOJ Round #1】

    枚举/DP+排列组合 缩进优化 QAQ我当时一直在想:$min\{ \sum_{i=1}^n (\lfloor\frac{a[i]}{x}\rfloor + a[i] \ mod\ x) \}$ 然而 ...

  3. scala编程第16章学习笔记(1)

    List列表的基本操作 head方法获得列表的第一个元素 tail方法获得列表除第一个元素之外的其它元素 isEmpty:判断列表是否为空,空的话返回真 last:获得列表最后一个元素 init:获得 ...

  4. Git-忽略规则.gitignore生效

    Git中如果忽略掉某个文件,不让这个文件提交到版本库中,可以使用修改根目录中 .gitignore 文件的方法,如下这个文件每一行保存了一个匹配的规则例如,忽略单个文件或者整个目录的文件: *.css ...

  5. (十一) 整合spring cloud云架构 - SSO单点登录之OAuth2.0登录流程(2)

    上一篇是站在巨人的肩膀上去研究OAuth2.0,也是为了快速帮助大家认识OAuth2.0,闲话少说,我根据框架中OAuth2.0的使用总结,画了一个简单的流程图(根据用户名+密码实现OAuth2.0的 ...

  6. Android -- Interpolator

    Interpolator 被用来修饰动画效果,定义动画的变化率,可以使存在的动画效果accelerated(加速),decelerated(减速),repeated(重复),bounced(弹跳)等. ...

  7. [Functional Programming] Arrow Functor with contramap

    What is Arrow Functor? Arrow is a Profunctor that lifts a function of type a -> b and allows for ...

  8. ASP.NET MVC 基于页面的权限管理

    菜单表 namespace AspNetMvcAuthDemo1.Models { public class PermissionItem { public int ID { set; get; } ...

  9. AI单挑Dota 2世界冠军:被电脑虐哭……

    OpenAI的机器人刚刚在 Dota2 1v1 比赛中战胜了人类顶级职业玩家 Denti.以建设安全的通用人工智能为己任的 OpenAI,通过“Self-Play”的方式,从零开始训练出了这个机器人. ...

  10. C# list与数组互相转换

    1,从System.String[]转到List<System.String>System.String[] str={"str","string" ...