1.添加项目
maven添加依赖 or 导入jar包 or 使用jvm

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <groupId>YOURPROJECTNAME.com</groupId>
<artifactId>YOURPROJECTNAME</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging> <name>YOURNAME</name>
<url>http://maven.apache.org</url> <properties>
<nd4j.backend>nd4j-native-platform</nd4j.backend>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<shadedClassifier>bin</shadedClassifier>
<java.version>1.7</java.version>
<nd4j.version>0.6.0</nd4j.version>
<dl4j.version>0.6.0</dl4j.version>
<datavec.version>0.6.0</datavec.version>
<arbiter.version>0.6.0</arbiter.version>
<guava.version>19.0</guava.version>
<logback.version>1.1.7</logback.version>
<jfreechart.version>1.0.13</jfreechart.version>
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
<exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
<maven.minimum.version>3.3.1</maven.minimum.version>
</properties> <dependencyManagement>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-7.5-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
</dependencyManagement> <dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<!-- ND4J后端。每个DL4J项目都需要一个。一般将artifactId指定为"nd4j-native-platform"或者"nd4j-cuda-7.5-platform" -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
</dependency> <!-- DL4J核心功能 -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency> <dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency> <!-- deeplearning4j-ui用于HistogramIterationListener + 可视化:参见http://deeplearning4j.org/cn/visualization -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui</artifactId>
<version>${dl4j.version}</version>
</dependency> <!-- 强制指定使用UI/HistogramIterationListener时的guava版本 -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency> <!-- datavec-data-codec:仅用于在视频处理示例中加载视频数据 -->
<dependency>
<artifactId>datavec-data-codec</artifactId>
<groupId>org.datavec</groupId>
<version>${datavec.version}</version>
</dependency> <!-- 用于前馈/分类/MLP*和前馈/回归/RegressionMathFunctions示例 -->
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>${jfreechart.version}</version>
</dependency> <!-- Arbiter:用于超参数优化示例 -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>${arbiter.version}</version>
</dependency>
</dependencies>
</project>

2.项目引用库

import org.deeplearning4j.nn.multilayer._
import org.deeplearning4j.nn.graph._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs._
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.weights._
import org.deeplearning4j.optimize.listeners._
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.deeplearning4j.eval.Evaluation import org.nd4j.linalg.learning.config._ // for different updaters like Adam, Nesterovs, etc.
import org.nd4j.linalg.activations.Activation // defines different activation functions like RELU, SOFTMAX, etc.
import org.nd4j.linalg.lossfunctions.LossFunctions // mean squared error, multiclass cross entropy, etc.

3.准备加载数据
dl4j有数据迭代器。帮助批处理和迭代数据集。Deeplearning4j带有一个内置的BaseDatasetIteratorEMNIST 实现,
称为EmnistDataSetIterator。这个特殊的迭代器是一个便利实用程序,用于处理数据的下载和准备。
可以创建多个数据迭代器,用于训练模型或者评估模型等。
创建迭代器代码

import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator //引入数据迭代器库

val batchSize = 16 // how many examples to simultaneously train in the network //数据集大小
val emnistSet = EmnistDataSetIterator.Set.BALANCED
val emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true) //实例化训练迭代器
val emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false) //实例化评估迭代器

4.建立神经网络
在dl4j中使用的任何与神经网络有关的操作是在NeuralNetConfiguration类中的。可在此处配置超参数和算法的学习方式。

val outputNum = EmnistDataSetIterator.numLabels(emnistSet) // total output classes
val rngSeed = 123 // integer for reproducability of a random number generator
val numRows = 28 // number of "pixel rows" in an mnist digit
val numColumns = 28 val conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam())
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder()
.nIn(numRows * numColumns) // Number of input datapoints.
.nOut(1000) // Number of output datapoints.
.activation(Activation.RELU) // Activation function.
.weightInit(WeightInit.XAVIER) // Weight initialization.
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(1000)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.build())
.pretrain(false).backprop(true)
.build()

5.训练模型
现在我们已经构建了一个NeuralNetConfiguration,我们可以使用配置来实例化一个MultiLayerNetwork。当我们init()在网络上调用该 方法时,它会在网络上应用所选的权重初始化,并允许我们将数据传递给训练。如果我们想在培训期间看到损失分数,我们也可以将听众传递给网络。
实例化模型有一个fit()接受数据集迭代器(扩展的迭代器BaseDatasetIterator),单个DataSet或ND数组(实现INDArray)的方法。由于我们的EMNIST迭代器已经扩展了迭代器基类,我们可以直接传递它来适应。如果我们想要训练多个时代,DL4J还提供了一个MultipleEpochsIterator可以为我们处理多个时代的类。

// create the MLN
val network = new MultiLayerNetwork(conf)
network.init() // pass a training listener that reports score every 10 iterations
val eachIterations = 5
network.addListeners(new ScoreIterationListener(eachIterations)) // fit a dataset for a single epoch
// network.fit(emnistTrain) // fit for multiple epochs
// val numEpochs = 2
// network.fit(new MultipleEpochsIterator(numEpochs, emnistTrain))

6.评估模型
Deeplearning4j公开了几种工具来评估模型的性能。您可以执行基本评估并获取精度和准确度等指标,或使用接收器操作特性(ROC)。请注意,通用ROC类适用于二进制分类器,而ROCMultiClass适用于分类器,例如我们在此构建的模型。

A MultiLayerNetwork方便地有一些内置的方法来帮助我们进行评估。您可以将包含测试/验证数据的数据集迭代器传递给evaluate()方法。

// evaluate basic performance
val eval = network.evaluate(emnistTest)
eval.accuracy()
eval.precision()
eval.recall() // evaluate ROC and calculate the Area Under Curve
val roc = network.evaluateROCMultiClass(emnistTest)
roc.calculateAverageAUC() val classIndex = 0
roc.calculateAUC(classIndex) // optionally, you can print all stats from the evaluations
print(eval.stats())
print(roc.stats())
// evaluate basic performance
val eval = network.evaluate(emnistTest)
eval.accuracy()
eval.precision()
eval.recall() // evaluate ROC and calculate the Area Under Curve
val roc = network.evaluateROCMultiClass(emnistTest)
roc.calculateAverageAUC() val classIndex = 0
roc.calculateAUC(classIndex) // optionally, you can print all stats from the evaluations
print(eval.stats())
print(roc.stats())

Java机器学习框架deeplearing4j入门教程的更多相关文章

  1. (转)Java任务调度框架Quartz入门教程指南(二) 使用job、trigger、schedule调用定时任务

    http://blog.csdn.net/zixiao217/article/details/53044890 读完第一节,我们已经对Quartz有了一个大体的认识,它可以定时帮我们执行一些处理程序, ...

  2. (转)Java任务调度框架Quartz入门教程指南(四)Quartz任务调度框架之触发器精讲SimpleTrigger和CronTrigger、最详细的Cron表达式范例

    http://blog.csdn.net/zixiao217/article/details/53075009 Quartz的主要接口类是Schedule.Job.Trigger,而触发器Trigge ...

  3. (转)Java任务调度框架Quartz入门教程指南(三)任务调度框架Quartz实例详解深入理解Scheduler,Job,Trigger,JobDetail

    http://blog.csdn.net/zixiao217/article/details/53053598 首先给一个简明扼要的理解: Scheduler 调度程序-任务执行计划表,只有安排进执行 ...

  4. Veins(车载通信仿真框架)入门教程(四)——调试及记录结果

    Veins(车载通信仿真框架)入门教程(四)——调试及记录结果 在Veins入门教程(三)最后的动图中(如下图)可以看到大大小小的光圈,这个怎么实现的呢? 很简单,以收到RTS消息为例,通过finHo ...

  5. Veins(车载通信仿真框架)入门教程(三)——多跳路由实现指导

    Veins(车载通信仿真框架)入门教程(三)——多跳路由实现指导 Veins(车载通信仿真框架)入门教程(三)——多跳路由实现指导 必要的message类实现 从下面开始是在veins/src/vei ...

  6. Veins(车载通信仿真框架)入门教程(二)——调用第三方库

    Veins(车载通信仿真框架)入门教程(二)——调用第三方库 在借助Veins进行自己的研究时我们经常需要实现一些比较复杂的功能,有时就需要借助第三方库的帮助. 博主的研究需要使用神经网络,但是自己编 ...

  7. Bootstrap框架菜鸟入门教程

    Bootstrap菜鸟入门教程 Bootstrap简介 Bootstrap,来自 Twitter,是目前最受欢迎的前端框架.Bootstrap 是基于 HTML.CSS.JAVASCRIPT 的,它简 ...

  8. java springboot整合zookeeper入门教程(增删改查)

    java springboot整合zookeeper增删改查入门教程 zookeeper的安装与集群搭建参考:https://www.cnblogs.com/zwcry/p/10272506.html ...

  9. Veins(车载通信仿真框架)入门教程

    Veins入门教程——教你如何下手研究 目录 Veins入门教程——教你如何下手研究 目录 废话少说! 讲解omnetpp.ini!(挑关键的) 讲解RSUExampleScnario.ned! 注意 ...

随机推荐

  1. 几个原生js方法总结

    一.document.getElementById('emoji').addEventListener('click', function(e) { var emojiwrapper = docume ...

  2. Python学习过程笔记整理(二)

    程序三大结构 -顺序 -分支 -循环 分支 -语法: if 条件表达式: 语句 ... -双向分支 if 条件表达式: 语句 ... else: 语句 ... -多路分支 if 条件表达式: 语句 . ...

  3. Cause: com.mysql.jdbc.exceptions.jdbc4.CommunicationsException: The last packet successfully received from the server was 78,050,512 milliseconds ago.

    今天访问已经架上服务器的网站,报错: Cause: com.mysql.jdbc.exceptions.jdbc4.CommunicationsException: The last packet s ...

  4. access和MySQL mssql

    Access.MSSQL.MYSQL数据库之间有什么区别?     Access数据库.MSSQL数据库.MYSQL数据库之间有什么区别?        不少企业和个人站长在网站制作时,会对数据库的概 ...

  5. 从汉诺塔游戏理解python递归函数

    汉诺塔游戏规则: 有三根相邻的柱子,标号为A,B,C,A柱子上从下到上按金字塔状叠放着n个不同大小的圆盘,现在把所有盘子一个一个移动到柱子B上,并且每次移动同一根柱子上都不能出现大盘子在小盘子上方 图 ...

  6. 2017年10月WEB前端开发实习生面试题总结

    从大一开始学习前端,今年大三,10月份开始投简历,陆续收到很多家公司的面试,目前为止的面试通过率是百分之百,总结下面试题. 不定期更新中... 百度第一次 一面 1.AJAX流程 2.promise简 ...

  7. Daily Scrum7 11.11

    今日任务: 徐钧鸿:结束了SQL和Affairs的移植,修改了连接池,学习C#和java的正则表达式并且完成相关的移植 张艺:个人阅读作业 黄可嵩:完成高亮显示的移植,进一步移植搜索代码 徐方宇:继续 ...

  8. Java程序设计基础项目总结报告

    Java程序设计基础项目总结报告 20135313吴子怡 一.项目内容 运用所学Java知识,不调用Java类库,实现密码学相关算法的设计,并完成TDD测试,设计运行界面. 二.具体任务 1.要求实现 ...

  9. python service 服务没有及时响应启动或控制请求

    1053错误 代码运行没有问题后,安装服务,然而start 的时候出现错误 1053:服务没有及时响应启动或控制请求(Error 1053: The service did not respond t ...

  10. 第二阶段Sprint冲刺会议7

    进展:试着把视频录制功能加到时间提醒中,但是整合没有成功,今天没有进展.