由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。

JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。

首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:

  1. import org.joone.engine.*;
  2. import org.joone.engine.learning.*;
  3. import org.joone.io.*;
  4. import org.joone.net.*;
  5. /*
  6. *
  7. * JOONE实现
  8. *
  9. * */
  10. public class XOR_using_NeuralNet implements NeuralNetListener
  11. {
  12. private NeuralNet nnet = null;
  13. private MemoryInputSynapse inputSynapse, desiredOutputSynapse;
  14. LinearLayer input;
  15. SigmoidLayer hidden, output;
  16. boolean singleThreadMode = true;
  17. // XOR input
  18. private double[][] inputArray = new double[][]
  19. {
  20. { 0.0, 0.0 },
  21. { 0.0, 1.0 },
  22. { 1.0, 0.0 },
  23. { 1.0, 1.0 } };
  24. // XOR desired output
  25. private double[][] desiredOutputArray = new double[][]
  26. {
  27. { 0.0 },
  28. { 1.0 },
  29. { 1.0 },
  30. { 0.0 } };
  31. /**
  32. * @param args
  33. *            the command line arguments
  34. */
  35. public static void main(String args[])
  36. {
  37. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();
  38. xor.initNeuralNet();
  39. xor.train();
  40. xor.interrogate();
  41. }
  42. /**
  43. * Method declaration
  44. */
  45. public void train()
  46. {
  47. // set the inputs
  48. inputSynapse.setInputArray(inputArray);
  49. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  50. // set the desired outputs
  51. desiredOutputSynapse.setInputArray(desiredOutputArray);
  52. desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");
  53. // get the monitor object to train or feed forward
  54. Monitor monitor = nnet.getMonitor();
  55. // set the monitor parameters
  56. monitor.setLearningRate(0.8);
  57. monitor.setMomentum(0.3);
  58. monitor.setTrainingPatterns(inputArray.length);
  59. monitor.setTotCicles(5000);
  60. monitor.setLearning(true);
  61. long initms = System.currentTimeMillis();
  62. // Run the network in single-thread, synchronized mode
  63. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  64. nnet.go(true);
  65. System.out.println(" Total time=  "
  66. + (System.currentTimeMillis() - initms) + "  ms ");
  67. }
  68. private void interrogate()
  69. {
  70. double[][] inputArray = new double[][]
  71. {
  72. { 1.0, 1.0 } };
  73. // set the inputs
  74. inputSynapse.setInputArray(inputArray);
  75. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  76. Monitor monitor = nnet.getMonitor();
  77. monitor.setTrainingPatterns(4);
  78. monitor.setTotCicles(1);
  79. monitor.setLearning(false);
  80. MemoryOutputSynapse memOut = new MemoryOutputSynapse();
  81. // set the output synapse to write the output of the net
  82. if (nnet != null)
  83. {
  84. nnet.addOutputSynapse(memOut);
  85. System.out.println(nnet.check());
  86. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  87. nnet.go();
  88. for (int i = 0; i < 4; i++)
  89. {
  90. double[] pattern = memOut.getNextPattern();
  91. System.out.println(" Output pattern # " + (i + 1) + " = "
  92. + pattern[0]);
  93. }
  94. System.out.println(" Interrogating Finished ");
  95. }
  96. }
  97. /**
  98. * Method declaration
  99. */
  100. protected void initNeuralNet()
  101. {
  102. // First create the three layers
  103. input = new LinearLayer();
  104. hidden = new SigmoidLayer();
  105. output = new SigmoidLayer();
  106. // set the dimensions of the layers
  107. input.setRows(2);
  108. hidden.setRows(3);
  109. output.setRows(1);
  110. input.setLayerName(" L.input ");
  111. hidden.setLayerName(" L.hidden ");
  112. output.setLayerName(" L.output ");
  113. // Now create the two Synapses
  114. FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
  115. FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
  116. // Connect the input layer whit the hidden layer
  117. input.addOutputSynapse(synapse_IH);
  118. hidden.addInputSynapse(synapse_IH);
  119. // Connect the hidden layer whit the output layer
  120. hidden.addOutputSynapse(synapse_HO);
  121. output.addInputSynapse(synapse_HO);
  122. // the input to the neural net
  123. inputSynapse = new MemoryInputSynapse();
  124. input.addInputSynapse(inputSynapse);
  125. // The Trainer and its desired output
  126. desiredOutputSynapse = new MemoryInputSynapse();
  127. TeachingSynapse trainer = new TeachingSynapse();
  128. trainer.setDesired(desiredOutputSynapse);
  129. // Now we add this structure to a NeuralNet object
  130. nnet = new NeuralNet();
  131. nnet.addLayer(input, NeuralNet.INPUT_LAYER);
  132. nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
  133. nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
  134. nnet.setTeacher(trainer);
  135. output.addOutputSynapse(trainer);
  136. nnet.addNeuralNetListener(this);
  137. }
  138. public void cicleTerminated(NeuralNetEvent e)
  139. {
  140. }
  141. public void errorChanged(NeuralNetEvent e)
  142. {
  143. Monitor mon = (Monitor) e.getSource();
  144. if (mon.getCurrentCicle() % 100 == 0)
  145. System.out.println(" Epoch:  "
  146. + (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "
  147. + mon.getGlobalError());
  148. }
  149. public void netStarted(NeuralNetEvent e)
  150. {
  151. Monitor mon = (Monitor) e.getSource();
  152. System.out.print(" Network started for  ");
  153. if (mon.isLearning())
  154. System.out.println(" training. ");
  155. else
  156. System.out.println(" interrogation. ");
  157. }
  158. public void netStopped(NeuralNetEvent e)
  159. {
  160. Monitor mon = (Monitor) e.getSource();
  161. System.out.println(" Network stopped. Last RMSE= "
  162. + mon.getGlobalError());
  163. }
  164. public void netStoppedError(NeuralNetEvent e, String error)
  165. {
  166. System.out.println(" Network stopped due the following error:  "
  167. + error);
  168. }
  169. }

现在我会逐步解释上面的程序。

【1】 从main方法开始说起,首先第一步新建一个对象:

  1. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

【2】然后初始化神经网络:

  1. xor.initNeuralNet();

初始化神经网络的方法中:

  1. // First create the three layers
  2. input = new LinearLayer();
  3. hidden = new SigmoidLayer();
  4. output = new SigmoidLayer();
  5. // set the dimensions of the layers
  6. input.setRows(2);
  7. hidden.setRows(3);
  8. output.setRows(1);
  9. input.setLayerName(" L.input ");
  10. hidden.setLayerName(" L.hidden ");
  11. output.setLayerName(" L.output ");

上面代码解释:

input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;

hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数

output则是新建一个输出层

之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元

最后的三行代码是给每个输出层取一个名字。

  1. // Now create the two Synapses
  2. FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */
  3. FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */
  4. // Connect the input layer whit the hidden layer
  5. input.addOutputSynapse(synapse_IH);
  6. hidden.addInputSynapse(synapse_IH);
  7. // Connect the hidden layer whit the output layer
  8. hidden.addOutputSynapse(synapse_HO);
  9. output.addInputSynapse(synapse_HO);

上面代码解释:

上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层

  1. // the input to the neural net
  2. inputSynapse = new MemoryInputSynapse();
  3. input.addInputSynapse(inputSynapse);
  4. // The Trainer and its desired output
  5. desiredOutputSynapse = new MemoryInputSynapse();
  6. TeachingSynapse trainer = new TeachingSynapse();
  7. trainer.setDesired(desiredOutputSynapse);

上面代码解释:

上面的代码是在训练的时候指定输入层的数据和目的输出的数据,

inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据

  1. // Now we add this structure to a NeuralNet object
  2. nnet = new NeuralNet();
  3. nnet.addLayer(input, NeuralNet.INPUT_LAYER);
  4. nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);
  5. nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);
  6. nnet.setTeacher(trainer);
  7. output.addOutputSynapse(trainer);
  8. nnet.addNeuralNetListener(this);

上面代码解释:

这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。

【3】然后我们来看一下train这个方法:

  1. inputSynapse.setInputArray(inputArray);
  2. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  3. // set the desired outputs
  4. desiredOutputSynapse.setInputArray(desiredOutputArray);
  5. desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

上面代码解释:

inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据

inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。

desiredOutputSynapse这个也同理

  1. Monitor monitor = nnet.getMonitor();
  2. // set the monitor parameters
  3. monitor.setLearningRate(0.8);
  4. monitor.setMomentum(0.3);
  5. monitor.setTrainingPatterns(inputArray.length);
  6. monitor.setTotCicles(5000);
  7. <span style="line-height: 1.5;">monitor.setLearning(true);

上面代码解释:

这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。

  1. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  2. nnet.go(true);

上面代码解释:

nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思

nnet.go(true)表示的是开始训练。

【4】最后来看一下interrogate方法

  1. double[][] inputArray = new double[][]
  2. {
  3. { 1.0, 1.0 } };
  4. // set the inputs
  5. inputSynapse.setInputArray(inputArray);
  6. inputSynapse.setAdvancedColumnSelector(" 1,2 ");
  7. Monitor monitor = nnet.getMonitor();
  8. monitor.setTrainingPatterns(4);
  9. monitor.setTotCicles(1);
  10. monitor.setLearning(false);
  11. MemoryOutputSynapse memOut = new MemoryOutputSynapse();
  12. // set the output synapse to write the output of the net
  13. if (nnet != null)
  14. {
  15. nnet.addOutputSynapse(memOut);
  16. System.out.println(nnet.check());
  17. nnet.getMonitor().setSingleThreadMode(singleThreadMode);
  18. nnet.go();
  19. for (int i = 0; i < 4; i++)
  20. {
  21. double[] pattern = memOut.getNextPattern();
  22. System.out.println(" Output pattern # " + (i + 1) + " = "
  23. + pattern[0]);
  24. }
  25. System.out.println(" Interrogating Finished ");
  26. }

这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。

接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.

【5】我们看一下测试结果:

  1. Output pattern # 1 = 0.018303527517809233

表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0

【6】最后我们来看一下神经网络训练过程中的一些监听函数:

cicleTerminated:每个循环结束后输出的信息

errorChanged:神经网络错误率变化时候输出的信息

netStarted:神经网络开始运行的时候输出的信息

netStopped:神经网络停止的时候输出的信息

【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:

1,从文件中读取数据构建神经网络

2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。

【8】先看第一个问题:

从文件中读取数据:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。

  1. fileInputSynapse = new FileInputSynapse();
  2. input.addInputSynapse(fileInputSynapse);
  3. fileDisireOutputSynapse = new FileInputSynapse();
  4. TeachingSynapse trainer = new TeachingSynapse();
  5. trainer.setDesired(fileDisireOutputSynapse);

我们看下文件是如何输出数据的:

  1. private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);
  2. fileInputSynapse.setInputFile(inputFile);
  3. fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入
  4. fileInputSynapse.setLastCol(3);
  1. fileDisireOutputSynapse.setInputFile(inputFile);
  2. fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据
  3. fileDisireOutputSynapse.setLastCol(1);

其余的代码和上文的是一样的。

【9】然后看第二个问题:

如何保存神经网络

其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:

    1. public void netStopped(NeuralNetEvent e) {
    2. Monitor mon = (Monitor) e.getSource();
    3. try {
    4. if (mon.isLearning()) {
    5. saveModel(nnet); //序列化对象
    6. }
    7. } catch (IOException ee) {
    8. // TODO Auto-generated catch block
    9. ee.printStackTrace();
    10. }

LSTM java 实现的更多相关文章

  1. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  2. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  3. 新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs)

    新手教程之:循环网络和LSTM指南 (A Beginner’s Guide to Recurrent Networks and LSTMs) 本文翻译自:http://deeplearning4j.o ...

  4. [转] 图 + 文 + 公式 理解LSTM

    转自公号“机器之心” LSTM入门必读:从入门基础到工作方式详解 长短期记忆(LSTM)是一种非常重要的神经网络技术,其在语音识别和自然语言处理等许多领域都得到了广泛的应用..在这篇文章中,Edwin ...

  5. 机器学习与Tensorflow(6)——LSTM的Tensorflow实现、Tensorboard简单实现、CNN应用

    最近写的一些程序以及做的一个关于轴承故障诊断的程序 最近学习进度有些慢 而且马上假期 要去补习班 去赚下学期生活费 额.... 抓紧时间再多学习点 1.RNN递归神经网络Tensorflow实现程序 ...

  6. Tesseract:简单的Java光学字符识别

    1.1 介绍 开发具有一定价值的符号是人类特有的特征.对于人们来说识别这些符号和理解图片上的文字是非常正常的事情.与计算机那样去抓取文字不同,我们完全是基于视觉的本能去阅读它们. 另一方面,计算机的工 ...

  7. 尚学堂JAVA基础学习笔记

    目录 尚学堂JAVA基础学习笔记 写在前面 第1章 JAVA入门 第2章 数据类型和运算符 第3章 控制语句 第4章 Java面向对象基础 1. 面向对象基础 2. 面向对象的内存分析 3. 构造方法 ...

  8. java 读取CSV数据并写入txt文本

    java 读取CSV数据并写入txt文本 package com.vfsd; import java.io.BufferedWriter; import java.io.File; import ja ...

  9. Tika结合Tesseract-OCR 实现光学汉字识别(简体、宋体的识别率百分之百)—附Java源码、测试数据和训练集下载地址

     OCR(Optical character recognition) —— 光学字符识别,是图像处理的一个重要分支,中文的识别具有一定挑战性,特别是手写体和草书的识别,是重要和热门的科学研究方向.可 ...

随机推荐

  1. 【BZOJ4518】[Sdoi2016]征途 斜率优化

    [BZOJ4518][Sdoi2016]征途 Description Pine开始了从S地到T地的征途. 从S地到T地的路可以划分成n段,相邻两段路的分界点设有休息站. Pine计划用m天到达T地.除 ...

  2. PHP HTTP协议(报头/状态码/缓存)

    一.HTTP协议介绍 1. #HTTP协议       # (1 建立在TCP/IP协议基础上       # (2 web开发数据传输依赖于http协议       # (3 http 协议全称是文 ...

  3. rabbitmq 用户管理

    rabbitmqctl add_user root cor2016 rabbitmqctl set_user_tags root administrator http://host:15672/#/u ...

  4. IBM DEVOPS IN CLOUD--chaos monkey

  5. linux文件与目录管理命令(ubuntu)

    ls:列出目录 选项与参数: -a:全部文件,隐藏档(开头为.的文件)也会列出: -d:仅列出目录本身(也就是 . ),而不是目录下的所有文件及目录: -l:长字符串列出,包括文件的属性.权限等数据.

  6. Python 进程(process)

    1. 进程 1.1 进程的创建 fork 正在运行着的代码,就称为进程 # 示例: import os # 注意: fork 函数,只在 Unix/Linux/Mac 上运行, windows 不可以 ...

  7. FreeMarker 的使用方法

    1.FreeMarker 概述 FreeMarker 是一个用Java语言编写的模板引擎,使用模板来生成文本输出;主要用于做静态页面或页面展示; 2.FreeMarker 使用 // 导入jar包: ...

  8. (0)linux下的Mysql安装与基本使用(编译安装)

    一.大致操作步骤 环境介绍: OS:center OS6.5 mysql:5.6版本 1.关闭防火墙 查看防火墙状态:service iptables status 这样就意味着没有关闭. 运行以下命 ...

  9. python3 用requests 保存网页以及BeautifulSoup保存图片,并且在本地可以正常显示文章的内容和图片

    用requests 模块做了个简单的爬虫小程序,将博客的一篇文章以及图片保存到本地,文章格式存为'.html'.当文章保存到本地后,图片的连接可能是目标站点的绝对或者相对路径,所以要是想在本地也显示图 ...

  10. IOS自动化定位方式

    原文地址http://blog.csdn.net/wuyepiaoxue789/article/details/77885136 元素属性的介绍 type:元素类型,与className作用一致,如: ...