接上篇。

在(一)和(二)中,程序的体系是Net,Propagation,Trainer,Learner,DataProvider。这篇重构这个体系。

Net

首先是Net,在上篇重新定义了激活函数和误差函数后,内容大致是这样的:

List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ActivationFunction> activations = new ArrayList<>();
CostFunction costFunc;
CostFunction accuracyFunc;
int[] nodesNum;
int layersNum; public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(this.weights,this.bs);
}

函数getCompact()生成对应的超矩阵。

DataProvider

DataProvider是数据的提供者。

public interface DataProvider {
DoubleMatrix getInput();
DoubleMatrix getTarget();
}

如果输入为向量,还包含一个向量字典。

public interface DictDataProvider extends DataProvider {
public DoubleMatrix getIndexs();
public DoubleMatrix getDict();
}

每一列为一个样本。getIndexs()返回输入向量在字典中的索引。

我写了一个有用的类BatchDataProviderFactory来对样本进行批量分割,分割成minibatch。

int batchSize;
int dataLen;
DataProvider originalProvider;
List<Integer> endPositions;
List<DataProvider> providers; public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
super();
this.batchSize = batchSize;
this.originalProvider = originalProvider;
this.dataLen = this.originalProvider.getTarget().columns;
this.initEndPositions();
this.initProviders();
} public BatchDataProviderFactory(DataProvider originalProvider) {
this(4, originalProvider);
} public List<DataProvider> getProviders() {
return providers;
}

batchSize指明要分多少批,getProviders返回生成的minibatch,被分的原始数据为originalProvider。

Propagation

Propagation负责对神经网络的正向传播过程和反向传播过程。接口定义如下:

public interface Propagation {
public PropagationResult propagate(Net net,DataProvider provider);
}

传播函数propagate用指定数据对指定网络进行传播操作,返回执行结果。

BasePropagation实现了该接口,实现了简单的反向传播:

public class BasePropagation implements Propagation{

	// 多个样本。
protected ForwardResult forward(Net net,DoubleMatrix input) { ForwardResult result = new ForwardResult();
result.input = input;
DoubleMatrix currentResult = input;
int index = -1;
for (DoubleMatrix weight : net.weights) {
index++;
DoubleMatrix b = net.bs.get(index);
final ActivationFunction activation = net.activations
.get(index);
currentResult = weight.mmul(currentResult).addColumnVector(b);
result.netResult.add(currentResult); // 乘以导数
DoubleMatrix derivative = activation.derivativeAt(currentResult);
result.derivativeResult.add(derivative); currentResult = activation.valueAt(currentResult);
result.finalResult.add(currentResult); } result.netResult=null;// 不再需要。 return result;
} // 多个样本梯度平均值。
protected BackwardResult backward(Net net,DoubleMatrix target,
ForwardResult forwardResult) {
BackwardResult result = new BackwardResult(); DoubleMatrix output = forwardResult.getOutput();
DoubleMatrix outputDerivative = forwardResult.getOutputDerivative(); result.cost = net.costFunc.valueAt(output, target);
DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target).muli(outputDerivative);
if (net.accuracyFunc != null) {
result.accuracy=net.accuracyFunc.valueAt(output, target);
} result.deltas.add(outputDelta);
for (int i = net.layersNum - 1; i >= 0; i--) {
DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1); // 梯度计算,取所有样本平均
DoubleMatrix layerInput = i == 0 ? forwardResult.input
: forwardResult.finalResult.get(i - 1);
DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
target.columns);
result.gradients.add(gradient);
// 偏置梯度
result.biasGradients.add(pdelta.rowMeans()); // 计算前一层delta,若i=0,delta为输入层误差,即input调整梯度,不作平均处理。
DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
if (i > 0)
delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
result.deltas.add(delta);
}
Collections.reverse(result.gradients);
Collections.reverse(result.biasGradients); //其它的delta都不需要。
DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
result.deltas.clear();
result.deltas.add(inputDeltas); return result;
} @Override
public PropagationResult propagate(Net net, DataProvider provider) {
ForwardResult forwardResult=this.forward(net, provider.getInput());
BackwardResult backwardResult=this.backward(net, provider.getTarget(), forwardResult);
PropagationResult result=new PropagationResult(backwardResult);
result.output=forwardResult.getOutput();
return result;
}

我们定义的PropagationResult略为:

public class PropagationResult{
DoubleMatrix output;// 输出结果矩阵:outputLen*sampleLength
DoubleMatrix cost;// 误差矩阵:1*sampleLength
DoubleMatrix accuracy;// 准确度矩阵:1*sampleLength
private List<DoubleMatrix> gradients;// 权重梯度矩阵
private List<DoubleMatrix> biasGradients;// 偏置梯度矩阵
DoubleMatrix inputDeltas;//输入层delta矩阵:inputLen*sampleLength public CompactDoubleMatrix getCompact(){
return new CompactDoubleMatrix(gradients,biasGradients);
} }

另一个实现了该接口的类为MiniBatchPropagation。他在内部用并行方式对样本进行传播,然后对每个minipatch结果进行综合,内部用到了BatchDataProviderFactory类和BasePropagation类。

Trainer

Trainer接口定义为:

public interface Trainer {
public void train(Net net,DataProvider provider);
}

简单的实现类为:

public class CommonTrainer implements Trainer {
int ecophs;
Learner learner;
Propagation propagation;
List<Double> costs = new ArrayList<>();
List<Double> accuracys = new ArrayList<>();
public void trainOne(Net net, DataProvider provider) {
PropagationResult propResult = this.propagation
.propagate(net, provider);
learner.learn(net, propResult, provider); Double cost = propResult.getMeanCost();
Double accuracy = propResult.getMeanAccuracy();
if (cost != null)
costs.add(cost);
if (accuracy != null)
accuracys.add(accuracy);
} @Override
public void train(Net net, DataProvider provider) {
for (int i = 0; i < this.ecophs; i++) {
System.out.println("echops:"+i);
this.trainOne(net, provider);
} }
}

简单的迭代echops此,没有智能停止功能,每次迭代用Learner调节权重。

Learner

Learner根据每次传播结果对网络权重进行调整,接口定义如下:

public interface Learner<N extends Net,P extends DataProvider> {
public void learn(N net,PropagationResult propResult,P provider);
}

一个简单的根据动量因子-自适应学习率进行调整的实现类为:

public class MomentAdaptLearner<N extends Net, P extends DataProvider>
implements Learner<N, P> {
double moment = 0.7;
double lmd = 1.05;
double preCost = 0;
double eta = 0.01;
double currentEta = eta;
double currentMoment = moment;
CompactDoubleMatrix preGradient; public MomentAdaptLearner(double moment, double eta) {
super();
this.moment = moment;
this.eta = eta;
this.currentEta = eta;
this.currentMoment = moment;
} public MomentAdaptLearner() { } @Override
public void learn(N net, PropagationResult propResult, P provider) {
if (this.preGradient == null)
init(net, propResult, provider); double cost = propResult.getMeanCost();
this.modifyParameter(cost);
System.out.println("current eta:" + this.currentEta);
System.out.println("current moment:" + this.currentMoment);
this.updateGradient(net, propResult, provider); } public void updateGradient(N net, PropagationResult propResult, P provider) {
CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
provider);
CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
propResult, provider);
gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
preGradient.mul(currentMoment));
netCompact.subi(gradCompact);
this.preGradient = gradCompact;
} public CompactDoubleMatrix getNetCompact(N net,
PropagationResult propResult, P provider) {
return net.getCompact();
} public CompactDoubleMatrix getGradientCompact(N net,
PropagationResult propResult, P provider) {
return propResult.getCompact();
} public void modifyParameter(double cost) { if (this.currentEta > 10) {
this.currentEta = 10;
} else if (this.currentEta < 0.0001) {
this.currentEta = 0.0001;
} else if (cost < this.preCost) {
this.currentEta *= 1.05;
this.currentMoment = moment;
} else if (cost < 1.04 * this.preCost) {
this.currentEta *= 0.7;
this.currentMoment *= 0.7;
} else {
this.currentEta = eta;
this.currentMoment = 0.1;
}
this.preCost = cost;
} public void init(Net net, PropagationResult propResult, P provider) {
PropagationResult pResult = new PropagationResult(net);
preGradient = pResult.getCompact().dup();
} }

在上面的代码中,我们可以看到CompactDoubleMatrix类对权重自变量的封装,使代码更加简洁,它在此表现出来的就是一个超矩阵,超向量,完全忽略了内部的结构。

同时,其子类实现了同步更新字典的功能,代码也很简洁,只是简单的把需要调整的矩阵append到超矩阵中去即可,在父类中会统一对其进行调整:

public class DictMomentLearner extends
MomentAdaptLearner<Net, DictDataProvider> { public DictMomentLearner(double moment, double eta) {
super(moment, eta);
} public DictMomentLearner() {
super();
} @Override
public CompactDoubleMatrix getNetCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getNetCompact(net, propResult,
provider);
result.append(provider.getDict());
return result;
} @Override
public CompactDoubleMatrix getGradientCompact(Net net,
PropagationResult propResult, DictDataProvider provider) {
CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
provider);
result.append(DictUtil.getDictGradient(provider, propResult));
return result;
} @Override
public void init(Net net, PropagationResult propResult,
DictDataProvider provider) {
DoubleMatrix preDictGradient = DoubleMatrix.zeros(
provider.getDict().rows, provider.getDict().columns);
super.init(net, propResult, provider);
this.preGradient.append(preDictGradient);
}
}

用java写bp神经网络(四)的更多相关文章

  1. 用java写bp神经网络(一)

    根据前篇博文<神经网络之后向传播算法>,现在用java实现一个bp神经网络.矩阵运算采用jblas库,然后逐渐增加功能,支持并行计算,然后支持输入向量调整,最后支持L-BFGS学习算法. ...

  2. 用java写bp神经网络(三)

    孔子曰,吾日三省吾身.我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码.看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象.代码在不断的重构过程中 ...

  3. 用java写bp神经网络(二)

    接上篇. Net和Propagation具备后,我们就可以训练了.训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量):什么时候调用学习师根据训练的结果进 ...

  4. python手写bp神经网络实现人脸性别识别1.0

    写在前面:本实验用到的图片均来自google图片,侵删! 实验介绍 用python手写一个简单bp神经网络,实现人脸的性别识别.由于本人的机器配置比较差,所以无法使用网上很红的人脸大数据数据集(如lf ...

  5. JAVA实现BP神经网络算法

    工作中需要预测一个过程的时间,就想到了使用BP神经网络来进行预测. 简介 BP神经网络(Back Propagation Neural Network)是一种基于BP算法的人工神经网络,其使用BP算法 ...

  6. java写卷积神经网络---CupCnn简介

    https://blog.csdn.net/u011913612/article/details/79253450

  7. 【机器学习】BP神经网络实现手写数字识别

    最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...

  8. BP神经网络—java实现(转载)

    神经网络的结构 神经网络的网络结构由输入层,隐含层,输出层组成.隐含层的个数+输出层的个数=神经网络的层数,也就是说神经网络的层数不包括输入层.下面是一个三层的神经网络,包含了两层隐含层,一个输出层. ...

  9. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

随机推荐

  1. 【HDOJ】2386 Dart Challenge

    纯粹母函数+滚动数组,水之. /* 2386 */ #include <iostream> #include <string> #include <map> #in ...

  2. 解决Mac OS X Lion狮子系统及win7多分区教程

    [绿茶教程]解决Mac OS X Lion狮子系统及win7多分区教程   工具/原料 8G的u盘制作lion系统安装盘   步骤/方法  插入U盘---开机---按住左下角“Option”键(alt ...

  3. poj -2975 Nim

      Nim Time Limit: 1000MS   Memory Limit: 65536K Total Submissions: 4312   Accepted: 1998 Description ...

  4. NHibernate 存储过程使用

    NHibernate也是能够操作存储过程的,不过第一次配置可能会碰到很多错误. 一.删除 首先,我们新建一个存储过程如下: CREATE PROC DeletePerson @Id int AS DE ...

  5. HDU-1300(基础方程DP-遍历之前所有状态)

    Problem Description In Pearlania everybody is fond of pearls. One company, called The Royal Pearl, p ...

  6. 【长篇高能】ReactiveCocoa 和 MVVM 入门

    翻译自ReactiveCocoa and MVVM, an Introduction. 文中引用的 Gist 可能无法显示.为了和谐社会, 请科学上网. MVC 任何一个正经开发过一阵子软件的人都熟悉 ...

  7. 408. Valid Word Abbreviation

    感冒之后 睡了2天觉 现在痊愈了 重启刷题进程.. Google的题,E难度.. 比较的方法很多,应该是为后面的题铺垫的. 题不难,做对不容易,edge cases很多,修修改改好多次,写完发现是一坨 ...

  8. winform 窗体最大化 分类: WinForm 2014-07-17 15:57 215人阅读 评论(0) 收藏

    1:窗体首次加载时最大化 (1):主窗体 this.WindowState = FormWindowState.Maximized; //窗体显示中间部分,不显示窗体名称和最小化.最大化.关闭按钮   ...

  9. Hibernate配置文件详解

    Hibernate配置方式 Hibernate给人的感受是灵活的,要达到同一个目的,我们可以使用几种不同的办法.就拿Hibernate配置来说,常用的有如下三种方式,任选其一. 在 hibernate ...

  10. [RxJS] What RxJS operators are

    We have covered the basics of what is Observable.create, and other creation functions. Now lets fina ...