接上篇。

Net和Propagation具备后,我们就可以训练了。训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量);什么时候调用学习师根据训练的结果进行学习,然后改进网络的权重和状态;什么时候决定训练结束。

那么这两位老师儿长的什么样子,又是怎么做到的呢?

public interface Trainer {
public void train(Net net,DataProvider provider);
} public interface Learner {
public void learn(Net net,TrainResult trainResult);
}

所谓Trainer即是给定数据,对指定网络进行训练;所谓Learner即是给定训练结果,然后对指定网络进行权重调整。

下面给出这两个接口的简单实现。

Trainer

Trainer实现简单的批量训练功能,在给定的迭代次数后停止。代码示例如下。

public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
int batchSize = 1; public CommonTrainer(int ecophs, Learner learner) {
super();
this.ecophs = ecophs;
this.learner = learner == null ? new MomentAdaptLearner() : learner;
} public CommonTrainer(int ecophs, Learner learner, int batchSize) {
this(ecophs, learner);
this.batchSize = batchSize;
} public void trainOne(final Net net, DataProvider provider) {
final Propagation propagation = new Propagation(net);
DoubleMatrix input = provider.getInput();
DoubleMatrix target = provider.getTarget();
final int allLen = target.columns;
final int[] nodesNum = net.getNodesNum();
final int layersNum = net.getLayersNum(); List<DoubleMatrix> inputBatches = this.getBatches(input);
final List<DoubleMatrix> targetBatches = this.getBatches(target); final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches); final BackwardResult backwardResult = new BackwardResult(net, allLen);           // 分批并行训练
Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
@Override
public void perform(int index, DoubleMatrix subInput) {
ForwardResult subResult = propagation.forward(subInput);
DoubleMatrix subTarget = targetBatches.get(index);
BackwardResult backResult = propagation.backward(subTarget,
subResult); DoubleMatrix cost = backwardResult.cost;
DoubleMatrix accuracy = backwardResult.accuracy;
DoubleMatrix inputDeltas = backwardResult.getInputDelta(); int start = index == 0 ? 0 : batchLen.get(index - 1);
int end = batchLen.get(index) - 1;
int[] cIndexs = ArraysHelper.makeArray(start, end); cost.put(cIndexs, backResult.cost); if (accuracy != null) {
accuracy.put(cIndexs, backResult.accuracy);
} inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
  cIndexs, backResult.getInputDelta()); for (int i = 0; i < layersNum; i++) {
DoubleMatrix gradients = backwardResult.gradients.get(i);
DoubleMatrix biasGradients = backwardResult.biasGradients
.get(i);    DoubleMatrix subGradients = backResult.gradients.get(i)
.muli(backResult.cost.columns);
DoubleMatrix subBiasGradients = backResult.biasGradients
.get(i).muli(backResult.cost.columns);
gradients.addi(subGradients);
biasGradients.addi(subBiasGradients);
}
}
});
         // 求均值
for(DoubleMatrix gradient:backwardResult.gradients){
gradient.divi(allLen);
}
for(DoubleMatrix gradient:backwardResult.biasGradients){
gradient.divi(allLen);
} // this.mergeBackwardResult(backResults, net, input.columns);
TrainResult trainResult = new TrainResult(null, backwardResult); learner.learn(net, trainResult); Double cost = backwardResult.getMeanCost();
Double accuracy = backwardResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy);    System.out.println(cost);
System.out.println(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
this.trainOne(net, provider);
} }
}

Learner

Learner是具体的调整算法,当梯度计算出来后,它负责对网络权重进行调整。调整算法的选择直接影响着网络收敛的快慢。本文的实现采用简单的动量-自适应学习率算法。

其迭代公式如下:

$$W(t+1)=W(t)+\Delta W(t)$$

$$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$

$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$

$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$

示例代码如下:

public class MomentAdaptLearner implements Learner {

	Net net;
double moment = 0.9;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta=eta;
double currentMoment=moment;
TrainResult preTrainResult; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta=eta;
this.currentMoment=moment;
} @Override
public void learn(Net net, TrainResult trainResult) {
if (this.net == null)
init(net); BackwardResult backwardResult = trainResult.backwardResult;
BackwardResult preBackwardResult = preTrainResult.backwardResult;
double cost=backwardResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:"+this.currentEta);
System.out.println("current moment:"+this.currentMoment);
for (int j = 0; j < net.getLayersNum(); j++) {
DoubleMatrix weight = net.getWeights().get(j);
DoubleMatrix gradient = backwardResult.gradients.get(j); gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.gradients.get(j).muli(this.currentMoment));
preBackwardResult.gradients.set(j, gradient); weight.subi(gradient); DoubleMatrix b = net.getBs().get(j);
DoubleMatrix bgradient = backwardResult.biasGradients.get(j); bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
preBackwardResult.biasGradients.set(j, bgradient); b.subi(bgradient);
} } public void modifyParameter(double cost){
if(cost<this.preCost){
this.currentEta*=1.05;
this.currentMoment=moment;
}else if(cost<1.04*this.preCost){
this.currentEta*=0.7;
this.currentMoment*=0.7;
}else{
this.currentEta=eta;
this.currentMoment=1-moment;
}
this.preCost=cost;
}
public void init(Net net) {
this.net = net;
BackwardResult bResult = new BackwardResult(); for (DoubleMatrix weight : net.getWeights()) {
bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
weight.columns));
}
for (DoubleMatrix b : net.getBs()) {
bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
}
preTrainResult=new TrainResult(null,bResult);
} }

现在,一个简单的神经网路从生成到训练已经简单实现完毕。

下一步,使用Levenberg-Marquardt学习算法改进收敛速率。

用java写bp神经网络(二)的更多相关文章

  1. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  2. 用java写bp神经网络(四)

    接上篇. 在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider.这篇重构这个体系. Net 首先是Net,在上篇重新定义了激活函数和 ...

  3. 用java写bp神经网络(三)

    孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...

  4. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  5. python手写bp神经网络实现人脸性别识别1.0

    写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...

  6. java写卷积神经网络---CupCnn简介

    https://blog.csdn.net/u011913612/article/details/79253450

  7. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  8. 二、单层感知器和BP神经网络算法

    一.单层感知器 1958年[仅仅60年前]美国心理学家FrankRosenblant剔除一种具有单层计算单元的神经网络,称为Perceptron,即感知器.感知器研究中首次提出了自组织.自学习的思想, ...

  9. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

随机推荐

  1. SQL server 触发器、视图

    一.触发器 1.触发器为特殊类型的存储过程,可在执行语言事件时自动生效.SQL Server 包括三种常规类型的触发器:DML 触发器.DDL 触发器和登录触发器. 主要讲述DML触发器,DML触发器 ...

  2. java学习面向对象之内部类

    什么是面向对象内部类呢?所谓的内部类,即从字面意义上来理解的话,就是把类放到类当中. 那么内部类都有什么特点呢? 1.内部类可以访问包裹他的类的成员. 2.如果包裹他的类想访问被其包裹的类的话就得实例 ...

  3. netstat 命令state值

    1.LISTENING状态 FTP服务启动后首先处于侦听(LISTENING)状态. State显示是LISTENING时表示处于侦听状态,就是说该端口是开放的,等待连接,但还没有被连接.就像你房子的 ...

  4. 265行JavaScript代码的第一人称3D H5游戏Demo【个人总结1】

    本文目的是分解前面的代码.其实,它得逻辑很清楚,只是对于我这种只是用过 Canvas 画线(用过 Fabric.js Canvas库)的人来说,这个还是很复杂的.我研究这个背景天空也是搞了一天,下面就 ...

  5. Bootstrap 3 兼容 IE8 浏览器

    公司新上的项目,前端用的Bootstrap3的框架,但它已经放弃对IE9下的支持了.可IE8还是有着许多用户,不能不照顾到他们,IE7以下的,我只想说,现在什么年代了,要解放思想,与时俱进啊,就不能动 ...

  6. 用cmd改计算机名.bat 无需重启电脑生效

    echo offset /p cname=请输入计算机名: echo REGEDIT4 >reg.reg echo [HKEY_LOCAL_MACHINE\SYSTEM\CurrentContr ...

  7. .net code injection

    .NET Internals and Code Injection http://www.ntcore.com/files/netint_injection.htm Windows Hooks in ...

  8. [原]ubuntu下制作openstack-havana源

    ubuntu下可以用apt-mirror下载openstack的源: 1.安装apt-mirror: apt-get install apt-mirror 2.配置/etc/apt/mirror.li ...

  9. hdoj 4324 Triangle LOVE【拓扑排序判断是否存在环】

    Triangle LOVE Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)Tot ...

  10. C语言学习_C如何在一个文件里调用另一个源文件中的函数

    问题 C如何在一个文件里调用另一个源文件中的函数,如题. 解决办法 当程序大了代码多了之后,想模块化开发,不同文件中存一点,是很好的解决办法,那我们如何做才能让各个文件中的代码协同工作呢?我们知道,m ...