由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。

JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。

首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:

  1. import org.joone.engine.*;
  2. import org.joone.engine.learning.*;
  3. import org.joone.io.*;
  4. import org.joone.net.*;
  5. /*
  6. *
  7. * JOONE实现
  8. *
  9. * */
  10. public class XOR_using_NeuralNet implements NeuralNetListener
  11. {
  12. private NeuralNet nnet = null;
  13. private MemoryInputSynapse inputSynapse, desiredOutputSynapse;
  14. LinearLayer input;
  15. SigmoidLayer hidden, output;
  16. boolean singleThreadMode = true;
  17. // XOR input
  18. private double[][] inputArray = new double[][]
  19. {
  20. { 0.0, 0.0 },
  21. { 0.0, 1.0 },
  22. { 1.0, 0.0 },
  23. { 1.0, 1.0 } };
  24. // XOR desired output
  25. private double[][] desiredOutputArray = new double[][]
  26. {
  27. { 0.0 },
  28. { 1.0 },
  29. { 1.0 },
  30. { 0.0 } };
  31. /**
  32. * @param args
  33. *            the command line arguments
  34. */
  35. public static void main(String args[])
  36. {
  37. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();
  38. xor.initNeuralNet();
  39. xor.train();
  40. xor.interrogate();
  41. }
  42. /**
  43. * Method declaration
  44. */
  45. public void train()
  46. {
  47. // set the inputs
  48. inputSynapse.setInputArray(inputArray);
  49. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  50. // set the desired outputs
  51. desiredOutputSynapse.setInputArray(desiredOutputArray);
  52. desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");
  53. // get the monitor object to train or feed forward
  54. Monitor monitor = nnet.getMonitor();
  55. // set the monitor parameters
  56. monitor.setLearningRate(0.8);
  57. monitor.setMomentum(0.3);
  58. monitor.setTrainingPatterns(inputArray.length);
  59. monitor.setTotCicles(5000);
  60. monitor.setLearning(true);
  61. long initms = System.currentTimeMillis();
  62. // Run the network in single-thread, synchronized mode
  63. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  64. nnet.go(true);
  65. System.out.println(" Total time=  "
  66. + (System.currentTimeMillis() - initms) + "  ms ");
  67. }
  68. private void interrogate()
  69. {
  70. double[][] inputArray = new double[][]
  71. {
  72. { 1.0, 1.0 } };
  73. // set the inputs
  74. inputSynapse.setInputArray(inputArray);
  75. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  76. Monitor monitor = nnet.getMonitor();
  77. monitor.setTrainingPatterns(4);
  78. monitor.setTotCicles(1);
  79. monitor.setLearning(false);
  80. MemoryOutputSynapse memOut = new MemoryOutputSynapse();
  81. // set the output synapse to write the output of the net
  82. if (nnet != null)
  83. {
  84. nnet.addOutputSynapse(memOut);
  85. System.out.println(nnet.check());
  86. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  87. nnet.go();
  88. for (int i = 0; i < 4; i++)
  89. {
  90. double[] pattern = memOut.getNextPattern();
  91. System.out.println(" Output pattern # " + (i + 1) + " = "
  92. + pattern[0]);
  93. }
  94. System.out.println(" Interrogating Finished ");
  95. }
  96. }
  97. /**
  98. * Method declaration
  99. */
  100. protected void initNeuralNet()
  101. {
  102. // First create the three layers
  103. input = new LinearLayer();
  104. hidden = new SigmoidLayer();
  105. output = new SigmoidLayer();
  106. // set the dimensions of the layers
  107. input.setRows(2);
  108. hidden.setRows(3);
  109. output.setRows(1);
  110. input.setLayerName(" L.input ");
  111. hidden.setLayerName(" L.hidden ");
  112. output.setLayerName(" L.output ");
  113. // Now create the two Synapses
  114. FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
  115. FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
  116. // Connect the input layer whit the hidden layer
  117. input.addOutputSynapse(synapse_IH);
  118. hidden.addInputSynapse(synapse_IH);
  119. // Connect the hidden layer whit the output layer
  120. hidden.addOutputSynapse(synapse_HO);
  121. output.addInputSynapse(synapse_HO);
  122. // the input to the neural net
  123. inputSynapse = new MemoryInputSynapse();
  124. input.addInputSynapse(inputSynapse);
  125. // The Trainer and its desired output
  126. desiredOutputSynapse = new MemoryInputSynapse();
  127. TeachingSynapse trainer = new TeachingSynapse();
  128. trainer.setDesired(desiredOutputSynapse);
  129. // Now we add this structure to a NeuralNet object
  130. nnet = new NeuralNet();
  131. nnet.addLayer(input, NeuralNet.INPUT_LAYER);
  132. nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
  133. nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
  134. nnet.setTeacher(trainer);
  135. output.addOutputSynapse(trainer);
  136. nnet.addNeuralNetListener(this);
  137. }
  138. public void cicleTerminated(NeuralNetEvent e)
  139. {
  140. }
  141. public void errorChanged(NeuralNetEvent e)
  142. {
  143. Monitor mon = (Monitor) e.getSource();
  144. if (mon.getCurrentCicle() % 100 == 0)
  145. System.out.println(" Epoch:  "
  146. + (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "
  147. + mon.getGlobalError());
  148. }
  149. public void netStarted(NeuralNetEvent e)
  150. {
  151. Monitor mon = (Monitor) e.getSource();
  152. System.out.print(" Network started for  ");
  153. if (mon.isLearning())
  154. System.out.println(" training. ");
  155. else
  156. System.out.println(" interrogation. ");
  157. }
  158. public void netStopped(NeuralNetEvent e)
  159. {
  160. Monitor mon = (Monitor) e.getSource();
  161. System.out.println(" Network stopped. Last RMSE= "
  162. + mon.getGlobalError());
  163. }
  164. public void netStoppedError(NeuralNetEvent e, String error)
  165. {
  166. System.out.println(" Network stopped due the following error:  "
  167. + error);
  168. }
  169. }

现在我会逐步解释上面的程序。

【1】 从main方法开始说起,首先第一步新建一个对象:

  1. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

【2】然后初始化神经网络:

  1. xor.initNeuralNet();

初始化神经网络的方法中:

  1. // First create the three layers
  2. input = new LinearLayer();
  3. hidden = new SigmoidLayer();
  4. output = new SigmoidLayer();
  5. // set the dimensions of the layers
  6. input.setRows(2);
  7. hidden.setRows(3);
  8. output.setRows(1);
  9. input.setLayerName(" L.input ");
  10. hidden.setLayerName(" L.hidden ");
  11. output.setLayerName(" L.output ");

上面代码解释:

input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;

hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数

output则是新建一个输出层

之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元

最后的三行代码是给每个输出层取一个名字。

  1. // Now create the two Synapses
  2. FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
  3. FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
  4. // Connect the input layer whit the hidden layer
  5. input.addOutputSynapse(synapse_IH);
  6. hidden.addInputSynapse(synapse_IH);
  7. // Connect the hidden layer whit the output layer
  8. hidden.addOutputSynapse(synapse_HO);
  9. output.addInputSynapse(synapse_HO);

上面代码解释:

上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层

  1. // the input to the neural net
  2. inputSynapse = new MemoryInputSynapse();
  3. input.addInputSynapse(inputSynapse);
  4. // The Trainer and its desired output
  5. desiredOutputSynapse = new MemoryInputSynapse();
  6. TeachingSynapse trainer = new TeachingSynapse();
  7. trainer.setDesired(desiredOutputSynapse);

上面代码解释:

上面的代码是在训练的时候指定输入层的数据和目的输出的数据,

inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据

  1. // Now we add this structure to a NeuralNet object
  2. nnet = new NeuralNet();
  3. nnet.addLayer(input, NeuralNet.INPUT_LAYER);
  4. nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
  5. nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
  6. nnet.setTeacher(trainer);
  7. output.addOutputSynapse(trainer);
  8. nnet.addNeuralNetListener(this);

上面代码解释:

这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。

【3】然后我们来看一下train这个方法:

  1. inputSynapse.setInputArray(inputArray);
  2. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  3. // set the desired outputs
  4. desiredOutputSynapse.setInputArray(desiredOutputArray);
  5. desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

上面代码解释:

inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据

inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。

desiredOutputSynapse这个也同理

  1. Monitor monitor = nnet.getMonitor();
  2. // set the monitor parameters
  3. monitor.setLearningRate(0.8);
  4. monitor.setMomentum(0.3);
  5. monitor.setTrainingPatterns(inputArray.length);
  6. monitor.setTotCicles(5000);
  7. <span style="line-height: 1.5;">monitor.setLearning(true);

上面代码解释:

这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。

  1. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  2. nnet.go(true);

上面代码解释:

nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思

nnet.go(true)表示的是开始训练。

【4】最后来看一下interrogate方法

  1. double[][] inputArray = new double[][]
  2. {
  3. { 1.0, 1.0 } };
  4. // set the inputs
  5. inputSynapse.setInputArray(inputArray);
  6. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  7. Monitor monitor = nnet.getMonitor();
  8. monitor.setTrainingPatterns(4);
  9. monitor.setTotCicles(1);
  10. monitor.setLearning(false);
  11. MemoryOutputSynapse memOut = new MemoryOutputSynapse();
  12. // set the output synapse to write the output of the net
  13. if (nnet != null)
  14. {
  15. nnet.addOutputSynapse(memOut);
  16. System.out.println(nnet.check());
  17. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  18. nnet.go();
  19. for (int i = 0; i < 4; i++)
  20. {
  21. double[] pattern = memOut.getNextPattern();
  22. System.out.println(" Output pattern # " + (i + 1) + " = "
  23. + pattern[0]);
  24. }
  25. System.out.println(" Interrogating Finished ");
  26. }

这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。

接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.

【5】我们看一下测试结果:

  1. Output pattern # 1 = 0.018303527517809233

表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0

【6】最后我们来看一下神经网络训练过程中的一些监听函数:

cicleTerminated:每个循环结束后输出的信息

errorChanged:神经网络错误率变化时候输出的信息

netStarted:神经网络开始运行的时候输出的信息

netStopped:神经网络停止的时候输出的信息

【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:

1,从文件中读取数据构建神经网络

2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。

【8】先看第一个问题:

从文件中读取数据:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。

  1. fileInputSynapse = new FileInputSynapse();
  2. input.addInputSynapse(fileInputSynapse);
  3. fileDisireOutputSynapse = new FileInputSynapse();
  4. TeachingSynapse trainer = new TeachingSynapse();
  5. trainer.setDesired(fileDisireOutputSynapse);

我们看下文件是如何输出数据的:

  1. private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);
  2. fileInputSynapse.setInputFile(inputFile);
  3. fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入
  4. fileInputSynapse.setLastCol(3);
  1. fileDisireOutputSynapse.setInputFile(inputFile);
  2. fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据
  3. fileDisireOutputSynapse.setLastCol(1);

其余的代码和上文的是一样的。

【9】然后看第二个问题:

如何保存神经网络

其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:

    1. public void netStopped(NeuralNetEvent e) {
    2. Monitor mon = (Monitor) e.getSource();
    3. try {
    4. if (mon.isLearning()) {
    5. saveModel(nnet); //序列化对象
    6. }
    7. } catch (IOException ee) {
    8. // TODO Auto-generated catch block
    9. ee.printStackTrace();
    10. }

LSTM java 实现的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  3. 新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs)

    新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs) 本文翻译自:http://deeplearning4j.o ...

  4. [转] 图 + 文 + 公式 理解LSTM

    转自公号“机器之心” LSTM入门必读:从入门基础到工作方式详解 长短期记忆(LSTM)是一种非常重要的神经网络技术,其在语音识别和自然语言处理等许多领域都得到了广泛的应用..在这篇文章中,Edwin ...

  5. 机器学习与Tensorflow(6)——LSTM的Tensorflow实现、Tensorboard简单实现、CNN应用

    最近写的一些程序以及做的一个关于轴承故障诊断的程序 最近学习进度有些慢 而且马上假期 要去补习班 去赚下学期生活费 额.... 抓紧时间再多学习点 1.RNN递归神经网络Tensorflow实现程序 ...

  6. Tesseract:简单的Java光学字符识别

    1.1 介绍 开发具有一定价值的符号是人类特有的特征.对于人们来说识别这些符号和理解图片上的文字是非常正常的事情.与计算机那样去抓取文字不同,我们完全是基于视觉的本能去阅读它们. 另一方面,计算机的工 ...

  7. 尚学堂JAVA基础学习笔记

    目录 尚学堂JAVA基础学习笔记 写在前面 第1章 JAVA入门 第2章 数据类型和运算符 第3章 控制语句 第4章 Java面向对象基础 1. 面向对象基础 2. 面向对象的内存分析 3. 构造方法 ...

  8. java 读取CSV数据并写入txt文本

    java 读取CSV数据并写入txt文本 package com.vfsd; import java.io.BufferedWriter; import java.io.File; import ja ...

  9. Tika结合Tesseract-OCR 实现光学汉字识别(简体、宋体的识别率百分之百)—附Java源码、测试数据和训练集下载地址

     OCR(Optical character recognition) —— 光学字符识别,是图像处理的一个重要分支,中文的识别具有一定挑战性,特别是手写体和草书的识别,是重要和热门的科学研究方向.可 ...

随机推荐

  1. 170313、poi:采用自定义注解的方式导入、导出excel(这种方式比较好扩展)

    步骤一.自定义注解 步骤二.写Excel泛型工具类 步骤三.在需要导出excel的类属相上加上自定义注解,并设置 步骤四.写service,controller 步骤一:自定义注解 import ja ...

  2. 浏览器加载不上css,样式走丢

    来自:http://www.cnblogs.com/crizygo/p/5466444.html 问题描述:使用eclipse修改样式文件,浏览器的页面一时显示一时不显示,最后直接没有加载最新的css ...

  3. ORACLE内存结构:PGA And UGA,ORACLE用户进程、服务器进程

    执行一个SQL语句 执行查询语句的过程: 用户进程执行一个查询语句如select * from emp where empno=7839 用户进程和服务器进程建立连接,把改用户进程的信息存储到PGA的 ...

  4. Reference counted objects

    Reference counted objects · netty/netty Wiki https://github.com/netty/netty/wiki/Reference-counted-o ...

  5. JavaScript自定义函数

    js对象转成用&拼接的请求参数(转) var parseParam=function(param, key){ var paramStr=""; if(param inst ...

  6. Spring Data 之 Repository 接口

    1. 介绍 Repository是一个空接口,即是一个标记性接口; 若我们定义的接口继承了Repository,则该接口会被IOC容器识别为一个 Repository Bean; 也可以通过@Repo ...

  7. scrapy之中间件

    中间件的简介 1.中间件的作用 在scrapy运行的整个过程中,对scrapy框架运行的某些步骤做一些适配自己项目的动作. 例如scrapy内置的HttpErrorMiddleware,可以在http ...

  8. Mybatis框架学习总结-表的关联查询

    一对一关联 创建表和数据:创建一张教师表和班级表,这里假设一个老师只负责教一个班,那么老师和班级之间的关系就是一种一对一的关系. CREATE TABLE teacher( t_id INT PRIM ...

  9. 【云安全与同态加密_调研分析(3)】国内云安全组织及标准——By Me

    ◆3. 国内云安全组织及标准◆ ◆云安全标准机构(主要的)◆ ◆标准机构介绍◆ ◆相关标准制定◆ ◆建立的相关模型参考◆ ◆备注(其他参考信息)◆ ★中国通信标准化协会(CCSA) ●组织简介:200 ...

  10. vue的项目结构

    一. 准备工作 1. 初始化项目    vue init webpack itany    cd itany    cnpm install    cnpm install less less-loa ...