主体代码

NeuronNetwork.java

package com.rockbb.math.nnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; public class NeutonNetwork {
private List<NeuronLayer> layers; public NeuronNetwork(int[] sizes, double bpFactor, Activator activator) {
layers = new ArrayList<>(sizes.length - 1);
int inputSize = sizes[0];
for (int i = 1; i < sizes.length; i++) {
NeuronLayer layer = new NeuronLayer(inputSize, sizes[i], activator, bpFactor);
layers.add(layer);
inputSize = sizes[i];
}
for (int i = 0; i < layers.size() - 1; i++) {
layers.get(i).setNext(layers.get(i + 1));
}
} public List<NeuronLayer> getLayers() {return layers;}
public void setLayers(List<NeuronLayer> layers) {this.layers = layers;} public double getError() {
return layers.get(layers.size() - 1).getError();
} public List<Double> predict(List<Double> inputs) {
List<Double> middle = inputs;
for (int i = 0; i < layers.size(); i++) {
middle = layers.get(i).forward(middle);
}
return middle;
} public void backward() {
for (int j= layers.size() - 1; j >=0; j--) {
layers.get(j).backward();
}
} public void fillTargets(List<Double> targets) {
layers.get(layers.size() - 1).fillTargets(targets);
} @Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int j = 0; j < layers.size(); j++) {
sb.append(layers.get(j).toString());
}
return sb.toString();
} public static String listToString(List<Double> list) {
StringBuilder sb = new StringBuilder();
for (Double t : list) {
sb.append(String.format("% 10.8f ", t));
}
return sb.toString();
} public static void main(String[] args) {
int[] sz = new int[]{2, 4, 1};
double[][] trainData = {{0d, 0d},{0d, 1d},{1d, 0d},{1d, 1d}};
double[][] targetDate = {{0d},{1d},{1d},{0d}}; NeuronNetwork nn = new NeuronNetwork(sz, 0.5d, new SigmoidActivator());
for (int kk = 0; kk < 20000; kk++) {
double totalError = 0d;
for (int i = 0; i < trainData.length; i++) {
List<Double> inputs = Arrays.asList(trainData[i][0], trainData[i][1]);
List<Double> targets = Arrays.asList(targetDate[i][0]);
nn.fillTargets(targets);
nn.predict(inputs);
//System.out.print(nn);
System.out.println(String.format("kk:%5d, i:%d, error: %.8f\n", kk, i, nn.getError()));
totalError += nn.getError();
nn.backward();
}
System.out.println(String.format("kk:%5d, Total Error: %.8f\n\n", kk, totalError));
if (totalError < 0.0001) {
System.out.println(nn);
break;
}
}
System.out.println(nn);
}
}

NeuronLayer.java

package com.rockbb.math.nnetwork;

import java.util.ArrayList;
import java.util.List; public class NeuronLayer {
private int inputSize;
private List<Neuron> neurons;
private double bias;
private Activator activator;
private NeuronLayer next;
private double bpFactor;
private List<Double> inputs; public NeuronLayer(int inputSize, int size, Activator activator, double bpFactor) {
this.inputSize = inputSize;
this.activator = activator;
this.bpFactor = bpFactor;
this.bias = Math.random() - 0.5; this.neutrons = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
Neuron neuron = new Neuron(this, inputSize);
neurons.add(neuron);
}
} public int getInputSize() {return inputSize;}
public void setInputSize(int inputSize) {this.inputSize = inputSize;}
public List<Neuron> getNeurons() {return neurons;}
public void setNeurons(List<Neuron> neurons) {this.neurons = neurons;}
public double getBias() {return bias;}
public void setBias(double bias) {this.bias = bias;}
public Activator getActivator() {return activator;}
public void setActivator(Activator activator) {this.activator = activator;}
public NeutronLayer getNext() {return next;}
public void setNext(NeutronLayer next) {this.next = next;} public List<Double> forward(List<Double> inputs) {
this.inputs = inputs;
List<Double> outputs = new ArrayList<Double>(neurons.size());
for (int i = 0; i < neurons.size(); i++) {
outputs.add(0d);
}
for (int i = 0; i < neurons.size(); i++) {
double output = neurons.get(i).forward(inputs);
outputs.set(i, output);
}
return outputs;
} public void backward() {
if (this.next == null) {
// If this is the output layer, calculate delta for each neutron
double totalDelta = 0d;
for (int i = 0; i < neurons.size(); i++) {
Neutron n = neurons.get(i);
double delta = -(n.getTarget() - n.getOutput()) * activator.backwardDelta(n.getOutput());
n.setBpDelta(delta);
totalDelta += delta;
// Reflect to each weight under this neuron
for (int j = 0; j < n.getWeights().size(); j++) {
n.getWeights().set(j, n.getWeights().get(j) - bpFactor * delta * inputs.get(j));
}
}
// Relfect to bias
this.bias = this.bias - bpFactor * totalDelta / neutrons.size();
} else {
// if this is the hidden layer
double totalDelta = 0d;
for (int i = 0; i < neurons.size(); i++) {
Neuron n = neurons.get(i);
List<Neuron> downNeurons = next.getNeurons();
double delta = 0;
for (int j = 0; j < downNeurons.size(); j++) {
delta += downNeurons.get(j).getBpDelta() * downNeurons.get(j).getWeights().get(i);
}
delta = delta * activator.backwardDelta(n.getOutput());
n.setBpDelta(delta);
totalDelta += delta;
// Reflect to each weight under this neuron
for (int j = 0; j < n.getWeights().size(); j++) {
n.getWeights().set(j, n.getWeights().get(j) - bpFactor * delta * inputs.get(j));
}
}
// Relfect to bias
this.bias = this.bias - bpFactor * totalDelta / neutrons.size();
}
} public double getError() {
double totalError = 0d;
for (int i = 0; i < neurons.size(); i++) {
totalError += Math.pow(neurons.get(i).getError(), 2);
}
return totalError / (2 * neurons.size());
} public void fillTargets(List<Double> targets) {
for (int i = 0; i < neurons.size(); i++) {
neurons.get(i).setTarget(targets.get(i));
}
} public double filter(double netInput) {
return activator.forward(netInput + bias);
} @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(String.format("Input size: %d, bias: %.8f\n", inputSize, bias));
for (int i = 0; i < neurons.size(); i++) {
sb.append(String.format("%3d: %s\n", i, neurons.get(i).toString()));
}
return sb.toString();
}
}

Neuron.java

package com.rockbb.math.nnetwork;

import java.util.ArrayList;
import java.util.List; public class Neuron {
private NeuronLayer layer;
private List<Double> weights;
private double output;
private double target;
private double bpDelta; public Neuron(NeuronLayer layer, int inputSize) {
this.layer = layer;
this.weights = new ArrayList<>(inputSize);
for (int i = 0; i < inputSize; i++) {
// Initialize each weight with value [0.1, 1)
weights.add(Math.random() * 0.9 + 0.1);
}
this.bpDelta = 0d;
} public NeuronLayer getLayer() {return layer;}
public void setLayer(NeuronLayer layer) {this.layer = layer;}
public List<Double> getWeights() {return weights;}
public void setWeights(List<Double> weights) {this.weights = weights;}
public double getOutput() {return output;}
public void setOutput(double output) {this.output = output;}
public double getTarget() {return target;}
public void setTarget(double target) {this.target = target;}
public double getBpDelta() {return bpDelta;}
public void setBpDelta(double bpDelta) {this.bpDelta = bpDelta;} public double calcNetInput(List<Double> inputs) {
double netOutput = 0f;
for (int i = 0; i < weights.size(); i++) {
netOutput += inputs.get(i) * weights.get(i);
}
return netOutput;
} public double forward(List<Double> inputs) {
double netInput = calcNetInput(inputs);
this.output = layer.filter(netInput);
return this.output;
} public double getError() {
return target - output;
} @Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(String.format("O:% 10.8f T:% 10.8f D:% 10.8f w:{", output, target, bpDelta));
for (int i = 0; i < weights.size(); i++) {
sb.append(String.format("% 10.8f ", weights.get(i)));
}
sb.append('}');
return sb.toString();
}
}

激活函数

Activator.java

package com.rockbb.math.nnetwork;

public interface Activator {

    double forward(double input);

    double backwardDelta(double output);
}

SigmoidActivator.java

package com.rockbb.math.nnetwork;

public class SigmoidActivator implements Activator {

    public double forward(double input) {
return 1 / (1 + Math.exp(-input));
} public double backwardDelta(double output) {
return output * (1 - output);
}
}

在同样的训练数据和误差目标下, 比 http://www.emergentmind.com/neural-network 使用更少的训练次数.

使用Sigmoid激活函数工作正常.

使用ReLu激活函数时总会使某个Neuron冻结, 不能收敛, 待检查

Java实现的简单神经网络(基于Sigmoid激活函数)的更多相关文章

  1. JAX-WS 学习一:基于java的最简单的WebService服务

    JAVA 1.6 之后,自带的JAX-WS API,这使得我们可以很方便的开发一个基于Java的WebService服务. 基于JAVA的WebService 服务 1.创建服务端WebService ...

  2. 深度学习原理与框架-神经网络架构 1.神经网络构架 2.激活函数(sigmoid和relu) 3.图片预处理(减去均值和除标准差) 4.dropout(防止过拟合操作)

    神经网络构架:主要时表示神经网络的组成,即中间隐藏层的结构 对图片进行说明:我们可以看出图中的层数分布: input layer表示输入层,维度(N_num, input_dim)  N_num表示输 ...

  3. day-11 python自带库实现2层简单神经网络算法

    深度神经网络算法,是基于神经网络算法的一种拓展,其层数更深,达到多层,本文以简单神经网络为例,利用梯度下降算法进行反向更新来训练神经网络权重和偏向参数,文章最后,基于Python 库实现了一个简单神经 ...

  4. struts1:(Struts重构)构建一个简单的基于MVC模式的JavaWeb

    在构建一个简单的基于MVC模式的JavaWeb 中,我们使用了JSP+Servlet+JavaBean构建了一个基于MVC模式的简单登录系统,但在其小结中已经指出,这种模式下的Controller 和 ...

  5. 一个简单的基于HTTP协议的屏幕共享应用

    HTTP协议可以能是应用层协议里使用最广泛并且用途最多样的一个了.我们一般使用HTTP协议来浏览网页,但是HTTP协议还用来做很多其它用途.对开发人员来讲很常见的一种就是用HTTP协议作为各种版本控制 ...

  6. 如何用70行Java代码实现深度神经网络算法

    http://www.tuicool.com/articles/MfYjQfV 如何用70行Java代码实现深度神经网络算法 时间 2016-02-18 10:46:17  ITeye 原文  htt ...

  7. java实现一个简单的Web服务器

    注:本段内容来源于<JAVA 实现 简单的 HTTP服务器> 1. HTTP所有状态码 状态码 状态码英文名称 中文描述 100 Continue 继续.客户端应继续其请求 101 Swi ...

  8. 最简单的基于FFmpeg的移动端样例:Android 视频解码器-单个库版

    ===================================================== 最简单的基于FFmpeg的移动端样例系列文章列表: 最简单的基于FFmpeg的移动端样例:A ...

  9. Java语言实现简单FTP软件------>FTP软件主界面的实现(四)

    首先看一下该软件的整体代码框架                        1.首先介绍程序的主入口FTPMain.java,采用了一个漂亮的外观风格 package com.oyp.ftp; im ...

随机推荐

  1. Bluemix结合DevOps Service实现一键部署

    林炳文Evankaka原创作品.转载请注明出处http://blog.csdn.net/evankaka 摘要:本文讲述了怎样通过Bluemix与DevOps Service相结合.来构建与部署一个持 ...

  2. Unhandled Exception: System.BadImageFormatException: Could not load file or assembly (2008R2配置x64website)

    .NET Error Message: Unhandled Exception: System.BadImageFormatException: Could not load file or asse ...

  3. [leetcode]Distinct Subsequences @ Python

    原题地址:https://oj.leetcode.com/problems/distinct-subsequences/ 题意: Given a string S and a string T, co ...

  4. js 处理URL实用技巧

    escape().encodeURI().encodeURIComponent()三种方法都能对一些影响URL完整性的特殊字符进行过滤.     但后两者是将字符串转换为UTF-8的方式来传输,解决了 ...

  5. [ Laravel 5.5 文档 ] 官方扩展包 —— 全文搜索解决方案:Laravel Scout

    简介 Laravel Scout 为 Eloquent 模型全文搜索实现提供了简单的.基于驱动的解决方案.通过使用模型观察者,Scout 会自动同步更新模型记录的索引. 目前,Scout 通过 Alg ...

  6. C指针原理(14)

    tcc源码分析 本博客所有内容是原创,如果转载请注明来源 http://blog.csdn.net/myhaspl/ tcctok.h定义了C语言的词法分析的基本元素,主要定义了关键字. /* key ...

  7. JPA(四):EntityManager

    Persistence Persistence类使用于获取EntityManagerFactory实例,该类包含一个名为createEntityManagerFactory的静态方法. // 创建En ...

  8. Redis 实现队列http://igeekbar.com/igeekbar/post/436.htm

    场景说明: ·用于处理比较耗时的请求,例如批量发送邮件,如果直接在网页触发执行发送,程序会出现超时 ·高并发场景,当某个时刻请求瞬间增加时,可以把请求写入到队列,后台在去处理这些请求 ·抢购场景,先入 ...

  9. 电脑的fn锁,f1-f12与功能键 互换

    提要: 有些机子特别逆天,比如说Thinkpad e系列.好好的f1-f12一定要加上fn才能按出来,默认的是画在上面的功能键,作为娱乐来说其实是还不错的,但是像我等程序员就觉得特别逆天了.你有两个选 ...

  10. postman发送post请求

    在地址栏里输入请求url:http://127.0.0.1:8081/getuser 选择“POST”方式, 点击''body", ''form-data", 添加key:user ...