用java写bp神经网络(二)
接上篇。
Net和Propagation具备后,我们就可以训练了。训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量);什么时候调用学习师根据训练的结果进行学习,然后改进网络的权重和状态;什么时候决定训练结束。
那么这两位老师儿长的什么样子,又是怎么做到的呢?
public interface Trainer {
public void train(Net net,DataProvider provider);
} public interface Learner {
public void learn(Net net,TrainResult trainResult);
}
所谓Trainer即是给定数据,对指定网络进行训练;所谓Learner即是给定训练结果,然后对指定网络进行权重调整。
下面给出这两个接口的简单实现。
Trainer
Trainer实现简单的批量训练功能,在给定的迭代次数后停止。代码示例如下。
public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
int batchSize = 1; public CommonTrainer(int ecophs, Learner learner) {
super();
this.ecophs = ecophs;
this.learner = learner == null ? new MomentAdaptLearner() : learner;
} public CommonTrainer(int ecophs, Learner learner, int batchSize) {
this(ecophs, learner);
this.batchSize = batchSize;
} public void trainOne(final Net net, DataProvider provider) {
final Propagation propagation = new Propagation(net);
DoubleMatrix input = provider.getInput();
DoubleMatrix target = provider.getTarget();
final int allLen = target.columns;
final int[] nodesNum = net.getNodesNum();
final int layersNum = net.getLayersNum(); List<DoubleMatrix> inputBatches = this.getBatches(input);
final List<DoubleMatrix> targetBatches = this.getBatches(target); final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches); final BackwardResult backwardResult = new BackwardResult(net, allLen); // 分批并行训练
Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
@Override
public void perform(int index, DoubleMatrix subInput) {
ForwardResult subResult = propagation.forward(subInput);
DoubleMatrix subTarget = targetBatches.get(index);
BackwardResult backResult = propagation.backward(subTarget,
subResult); DoubleMatrix cost = backwardResult.cost;
DoubleMatrix accuracy = backwardResult.accuracy;
DoubleMatrix inputDeltas = backwardResult.getInputDelta(); int start = index == 0 ? 0 : batchLen.get(index - 1);
int end = batchLen.get(index) - 1;
int[] cIndexs = ArraysHelper.makeArray(start, end); cost.put(cIndexs, backResult.cost); if (accuracy != null) {
accuracy.put(cIndexs, backResult.accuracy);
} inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
cIndexs, backResult.getInputDelta()); for (int i = 0; i < layersNum; i++) {
DoubleMatrix gradients = backwardResult.gradients.get(i);
DoubleMatrix biasGradients = backwardResult.biasGradients
.get(i); DoubleMatrix subGradients = backResult.gradients.get(i)
.muli(backResult.cost.columns);
DoubleMatrix subBiasGradients = backResult.biasGradients
.get(i).muli(backResult.cost.columns);
gradients.addi(subGradients);
biasGradients.addi(subBiasGradients);
}
}
});
// 求均值
for(DoubleMatrix gradient:backwardResult.gradients){
gradient.divi(allLen);
}
for(DoubleMatrix gradient:backwardResult.biasGradients){
gradient.divi(allLen);
} // this.mergeBackwardResult(backResults, net, input.columns);
TrainResult trainResult = new TrainResult(null, backwardResult); learner.learn(net, trainResult); Double cost = backwardResult.getMeanCost();
Double accuracy = backwardResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy); System.out.println(cost);
System.out.println(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
this.trainOne(net, provider);
} }
}
Learner
Learner是具体的调整算法,当梯度计算出来后,它负责对网络权重进行调整。调整算法的选择直接影响着网络收敛的快慢。本文的实现采用简单的动量-自适应学习率算法。
其迭代公式如下:
$$W(t+1)=W(t)+\Delta W(t)$$
$$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$
$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$
$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$
示例代码如下:
public class MomentAdaptLearner implements Learner { Net net;
double moment = 0.9;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta=eta;
double currentMoment=moment;
TrainResult preTrainResult; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta=eta;
this.currentMoment=moment;
} @Override
public void learn(Net net, TrainResult trainResult) {
if (this.net == null)
init(net); BackwardResult backwardResult = trainResult.backwardResult;
BackwardResult preBackwardResult = preTrainResult.backwardResult;
double cost=backwardResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:"+this.currentEta);
System.out.println("current moment:"+this.currentMoment);
for (int j = 0; j < net.getLayersNum(); j++) {
DoubleMatrix weight = net.getWeights().get(j);
DoubleMatrix gradient = backwardResult.gradients.get(j); gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.gradients.get(j).muli(this.currentMoment));
preBackwardResult.gradients.set(j, gradient); weight.subi(gradient); DoubleMatrix b = net.getBs().get(j);
DoubleMatrix bgradient = backwardResult.biasGradients.get(j); bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
preBackwardResult.biasGradients.set(j, bgradient); b.subi(bgradient);
} } public void modifyParameter(double cost){
if(cost<this.preCost){
this.currentEta*=1.05;
this.currentMoment=moment;
}else if(cost<1.04*this.preCost){
this.currentEta*=0.7;
this.currentMoment*=0.7;
}else{
this.currentEta=eta;
this.currentMoment=1-moment;
}
this.preCost=cost;
}
public void init(Net net) {
this.net = net;
BackwardResult bResult = new BackwardResult(); for (DoubleMatrix weight : net.getWeights()) {
bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
weight.columns));
}
for (DoubleMatrix b : net.getBs()) {
bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
}
preTrainResult=new TrainResult(null,bResult);
} }
现在,一个简单的神经网路从生成到训练已经简单实现完毕。
下一步,使用Levenberg-Marquardt学习算法改进收敛速率。
用java写bp神经网络(二)的更多相关文章
- 用java写bp神经网络(一)
根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...
- 用java写bp神经网络(四)
接上篇. 在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider.这篇重构这个体系. Net 首先是Net,在上篇重新定义了激活函数和 ...
- 用java写bp神经网络(三)
孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...
- JAVA实现BP神经网络算法
工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...
- python手写bp神经网络实现人脸性别识别1.0
写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...
- java写卷积神经网络---CupCnn简介
https://blog.csdn.net/u011913612/article/details/79253450
- 【机器学习】BP神经网络实现手写数字识别
最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...
- 二、单层感知器和BP神经网络算法
一.单层感知器 1958年[仅仅60年前]美国心理学家FrankRosenblant剔除一种具有单层计算单元的神经网络,称为Perceptron,即感知器.感知器研究中首次提出了自组织.自学习的思想, ...
- BP神经网络—java实现(转载)
神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...
随机推荐
- struts2表单验证里field-validator type值一共可以取哪些?都什么含义?
int 整数: double 实数: date 日期: expression 两数的关系比较: email Email地址: url visitor conversion regex 正则表达式验证: ...
- Windows Azure云服务价格调整通知
好消息!由世纪互联运营的 Windows Azure推出优惠啦.我们采纳了多渠道客户的意见和建议,为了更好地服务大家,将降低多种云服务的价格,其中包括我们最受欢迎的服务 -虚拟机和 Block ...
- Linux Shell编程(14)——内部变量
内建变量影响Bash脚本行为的变量.$BASHBash二进制程序文件的路径 bash$ echo $BASH /bin/bash$BASH_ENV该环境变量保存一个Bash启动文件路径,当启动一个脚本 ...
- netstat 命令state值
1.LISTENING状态 FTP服务启动后首先处于侦听(LISTENING)状态. State显示是LISTENING时表示处于侦听状态,就是说该端口是开放的,等待连接,但还没有被连接.就像你房子的 ...
- unity3d shader之Roberts,Sobel,Canny 三种边缘检测方法
方法其实都差不多,就是用两个过滤器,分别处理两个分量 Sobel算子 先说Sobel算子 GX为水平过滤器,GY为垂直过滤器,垂直过滤器就是水平过滤器旋转90度.过滤器为3x3的矩阵,将与图像作平面卷 ...
- 51单片机的堆栈指针(SP)
堆栈指针(SP,Stack Pointer),专门用于指出堆栈顶部数据的地址. 那么51单片机的堆栈在什么地方呢?由于单片机中存放数据的区域有限,我们不能够专门分配一块地方做堆栈,所以就在内存(RAM ...
- 把测试app打包成ipa文件
我终于把我的程序放到我的touch上了,其实把app放到touch上还有很多办法,这篇教程是主要讲怎么把app注册了,然后打包成一个ipa文件的. 先上官方文档:https://developer.a ...
- U盘做启动盘后,如何恢复原始容量
上次用U盘装系统后,U盘缩水1G多,格式化和快速格式化,没有用,无法恢复U盘原来的容量,后来在网上查到一个方法,成功释放U盘空间,故将恢复方法写在下面. (1)右击“我的电脑”,选择“管理”选项,之后 ...
- Bzoj 2763: [JLOI2011]飞行路线 拆点,分层图,最短路,SPFA
2763: [JLOI2011]飞行路线 Time Limit: 10 Sec Memory Limit: 128 MBSubmit: 1694 Solved: 635[Submit][Statu ...
- Node.js 初探
概念 Node.js 是构建在Chrome javascript runtime之上的平台,能够很容易的构建快速的,可伸缩性的网络应用程序.Node.js使用事件驱动,非阻塞I/O 模式,这使它能够更 ...