用java写bp神经网络(二)
接上篇。
Net和Propagation具备后,我们就可以训练了。训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量);什么时候调用学习师根据训练的结果进行学习,然后改进网络的权重和状态;什么时候决定训练结束。
那么这两位老师儿长的什么样子,又是怎么做到的呢?
public interface Trainer {
public void train(Net net,DataProvider provider);
} public interface Learner {
public void learn(Net net,TrainResult trainResult);
}
所谓Trainer即是给定数据,对指定网络进行训练;所谓Learner即是给定训练结果,然后对指定网络进行权重调整。
下面给出这两个接口的简单实现。
Trainer
Trainer实现简单的批量训练功能,在给定的迭代次数后停止。代码示例如下。
public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
int batchSize = 1; public CommonTrainer(int ecophs, Learner learner) {
super();
this.ecophs = ecophs;
this.learner = learner == null ? new MomentAdaptLearner() : learner;
} public CommonTrainer(int ecophs, Learner learner, int batchSize) {
this(ecophs, learner);
this.batchSize = batchSize;
} public void trainOne(final Net net, DataProvider provider) {
final Propagation propagation = new Propagation(net);
DoubleMatrix input = provider.getInput();
DoubleMatrix target = provider.getTarget();
final int allLen = target.columns;
final int[] nodesNum = net.getNodesNum();
final int layersNum = net.getLayersNum(); List<DoubleMatrix> inputBatches = this.getBatches(input);
final List<DoubleMatrix> targetBatches = this.getBatches(target); final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches); final BackwardResult backwardResult = new BackwardResult(net, allLen); // 分批并行训练
Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
@Override
public void perform(int index, DoubleMatrix subInput) {
ForwardResult subResult = propagation.forward(subInput);
DoubleMatrix subTarget = targetBatches.get(index);
BackwardResult backResult = propagation.backward(subTarget,
subResult); DoubleMatrix cost = backwardResult.cost;
DoubleMatrix accuracy = backwardResult.accuracy;
DoubleMatrix inputDeltas = backwardResult.getInputDelta(); int start = index == 0 ? 0 : batchLen.get(index - 1);
int end = batchLen.get(index) - 1;
int[] cIndexs = ArraysHelper.makeArray(start, end); cost.put(cIndexs, backResult.cost); if (accuracy != null) {
accuracy.put(cIndexs, backResult.accuracy);
} inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
cIndexs, backResult.getInputDelta()); for (int i = 0; i < layersNum; i++) {
DoubleMatrix gradients = backwardResult.gradients.get(i);
DoubleMatrix biasGradients = backwardResult.biasGradients
.get(i); DoubleMatrix subGradients = backResult.gradients.get(i)
.muli(backResult.cost.columns);
DoubleMatrix subBiasGradients = backResult.biasGradients
.get(i).muli(backResult.cost.columns);
gradients.addi(subGradients);
biasGradients.addi(subBiasGradients);
}
}
});
// 求均值
for(DoubleMatrix gradient:backwardResult.gradients){
gradient.divi(allLen);
}
for(DoubleMatrix gradient:backwardResult.biasGradients){
gradient.divi(allLen);
} // this.mergeBackwardResult(backResults, net, input.columns);
TrainResult trainResult = new TrainResult(null, backwardResult); learner.learn(net, trainResult); Double cost = backwardResult.getMeanCost();
Double accuracy = backwardResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy); System.out.println(cost);
System.out.println(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
this.trainOne(net, provider);
} }
}
Learner
Learner是具体的调整算法,当梯度计算出来后,它负责对网络权重进行调整。调整算法的选择直接影响着网络收敛的快慢。本文的实现采用简单的动量-自适应学习率算法。
其迭代公式如下:
$$W(t+1)=W(t)+\Delta W(t)$$
$$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$
$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$
$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$
示例代码如下:
public class MomentAdaptLearner implements Learner { Net net;
double moment = 0.9;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta=eta;
double currentMoment=moment;
TrainResult preTrainResult; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta=eta;
this.currentMoment=moment;
} @Override
public void learn(Net net, TrainResult trainResult) {
if (this.net == null)
init(net); BackwardResult backwardResult = trainResult.backwardResult;
BackwardResult preBackwardResult = preTrainResult.backwardResult;
double cost=backwardResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:"+this.currentEta);
System.out.println("current moment:"+this.currentMoment);
for (int j = 0; j < net.getLayersNum(); j++) {
DoubleMatrix weight = net.getWeights().get(j);
DoubleMatrix gradient = backwardResult.gradients.get(j); gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.gradients.get(j).muli(this.currentMoment));
preBackwardResult.gradients.set(j, gradient); weight.subi(gradient); DoubleMatrix b = net.getBs().get(j);
DoubleMatrix bgradient = backwardResult.biasGradients.get(j); bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
preBackwardResult.biasGradients.set(j, bgradient); b.subi(bgradient);
} } public void modifyParameter(double cost){
if(cost<this.preCost){
this.currentEta*=1.05;
this.currentMoment=moment;
}else if(cost<1.04*this.preCost){
this.currentEta*=0.7;
this.currentMoment*=0.7;
}else{
this.currentEta=eta;
this.currentMoment=1-moment;
}
this.preCost=cost;
}
public void init(Net net) {
this.net = net;
BackwardResult bResult = new BackwardResult(); for (DoubleMatrix weight : net.getWeights()) {
bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
weight.columns));
}
for (DoubleMatrix b : net.getBs()) {
bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
}
preTrainResult=new TrainResult(null,bResult);
} }
现在,一个简单的神经网路从生成到训练已经简单实现完毕。
下一步,使用Levenberg-Marquardt学习算法改进收敛速率。
用java写bp神经网络(二)的更多相关文章
- 用java写bp神经网络(一)
根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...
- 用java写bp神经网络(四)
接上篇. 在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider.这篇重构这个体系. Net 首先是Net,在上篇重新定义了激活函数和 ...
- 用java写bp神经网络(三)
孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...
- JAVA实现BP神经网络算法
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...
- python手写bp神经网络实现人脸性别识别1.0
写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...
- java写卷积神经网络---CupCnn简介
https://blog.csdn.net/u011913612/article/details/79253450
- 【机器学习】BP神经网络实现手写数字识别
最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...
- 二、单层感知器和BP神经网络算法
一.单层感知器 1958年[仅仅60年前]美国心理学家FrankRosenblant剔除一种具有单层计算单元的神经网络,称为Perceptron,即感知器.感知器研究中首次提出了自组织.自学习的思想, ...
- BP神经网络—java实现(转载)
神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...
随机推荐
- linux内核学习-
我的博客:www.while0.com 1.端口地址的设置主要有统一编址和独立编址. cat /proc/ioports 可以查询linux主机的设备端口. 2.数据传输控制方式有循环查询,中断和D ...
- Executors常用的创建ExecutorService的几个方法说明
一.线程池的创建 我们可以通过ThreadPoolExecutor来创建一个线程池. new ThreadPoolExecutor(corePoolSize, maximumPoolSize, kee ...
- -_-#【Markdown】
nswbmw / N-blog 第2章 使用 Markdown Markdown 语法说明 (简体中文版)Markdown: Basics (快速入门) 这里示范了一些 Markdown 的语法, 请 ...
- Java项目中使用配置文件配置
private String readConfig() { Properties p = new Properties(); InputStream in = getClass().getClassL ...
- HDOJ/HDU 2535 Vote(排序、)
Problem Description 美国大选是按各州的投票结果来确定最终的结果的,如果得到超过一半的州的支持就可以当选,而每个州的投票结果又是由该州选民投票产生的,如果某个州超过一半的选民支持希拉 ...
- JavaScript高级程序设计14.pdf
继承,ECMAScript只支持实现继承,而且其实现继承主要是依靠原型链来实现的 构造函数.原型.和实例的关系:每个构造函数都有一个原型对象,每个原型对象都包含一个指向构造函数的指针,每个实例都包含一 ...
- poj 1050 To the Max (简单dp)
题目链接:http://poj.org/problem?id=1050 #include<cstdio> #include<cstring> #include<iostr ...
- Android 解决ScrollView下嵌套ListView进页面不在顶部的问题
以下为整理: 方法1 刚开始还可以,后来再调试时就不行了. 为了解决scrollview和listview冲突 设置了listview的高度 结果进页面就不是在顶部了 . 解决方案1:Scrol ...
- python之json学习
1. 从python原始类型向json类型的转换过程,具体的转换如下: import json json.dump(obj, fp, skipkeys=False,ensure_ascii=True, ...
- 【题解】警位安排( 树形 DP)
[题目描述]一个重要的基地被分成了 n 个连通的区域 , 出于某种原因 , 这个基地以某一个区域为核心,呈一树形分布.在每个区域里安排警卫的费用是不同的,而每个区域的警卫都可以望见其相邻的区域 .如果 ...