欢迎访问我的GitHub

https://github.com/zq2599/blog_demos

内容:所有原创文章分类汇总及配套源码,涉及Java、Docker、Kubernetes、DevOPS等;

本篇概览

  • 本文是《DL4J》实战的第二篇,前面做好了准备工作,接下来进入正式实战,本篇内容是经典的入门例子:鸢尾花分类
  • 下图是一朵鸢尾花,我们可以测量到它的四个特征:花瓣(petal)的宽和高,花萼(sepal)的 宽和高:

  • 鸢尾花有三种:Setosa、Versicolor、Virginica
  • 今天的实战是用前馈神经网络Feed-Forward Neural Network (FFNN)就行鸢尾花分类的模型训练和评估,在拿到150条鸢尾花的特征和分类结果后,我们先训练出模型,再评估模型的效果:

源码下载

名称 链接 备注
项目主页 https://github.com/zq2599/blog_demos 该项目在GitHub上的主页
git仓库地址(https) https://github.com/zq2599/blog_demos.git 该项目源码的仓库地址,https协议
git仓库地址(ssh) git@github.com:zq2599/blog_demos.git 该项目源码的仓库地址,ssh协议
  • 这个git项目中有多个文件夹,《DL4J实战》系列的源码在dl4j-tutorials文件夹下,如下图红框所示:

  • dl4j-tutorials文件夹下有多个子工程,本次实战代码在dl4j-tutorials目录下,如下图红框:

编码

  • 在dl4j-tutorials工程下新建子工程classifier-iris,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>dlfj-tutorials</artifactId>
<groupId>com.bolingcavalry</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion> <artifactId>classifier-iris</artifactId> <properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties> <dependencies>
<dependency>
<groupId>com.bolingcavalry</groupId>
<artifactId>commons</artifactId>
<version>${project.version}</version>
</dependency> <dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency> <dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
</dependency> <dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
</dependencies>
</project>
  • 上述pom.xml有一处需要注意的地方,就是${nd4j.backend}参数的值,该值在决定了后端线性代数计算是用CPU还是GPU,本篇为了简化操作选择了CPU(因为个人的显卡不同,代码里无法统一),对应的配置就是nd4j-native;

  • 源码全部在Iris.java文件中,并且代码中已添加详细注释,就不再赘述了:

package com.bolingcavalry.classifier;

import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File; /**
* @author will (zq2599@gmail.com)
* @version 1.0
* @description: 鸢尾花训练
* @date 2021/6/13 17:30
*/
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris { public static void main(String[] args) throws Exception { //第一阶段:准备 // 跳过的行数,因为可能是表头
int numLinesToSkip = 0;
// 分隔符
char delimiter = ','; // CSV读取工具
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); // 下载并解压后,得到文件的位置
String dataPathLocal = DownloaderUtility.IRISDATA.Download(); log.info("鸢尾花数据已下载并解压至 : {}", dataPathLocal); // 读取下载后的文件
recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt"))); // 每一行的内容大概是这样的:5.1,3.5,1.4,0.2,0
// 一共五个字段,从零开始算的话,标签在第四个字段
int labelIndex = 4; // 鸢尾花一共分为三类
int numClasses = 3; // 一共150个样本
int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets) // 加载到数据集迭代器中
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next(); // 洗牌(打乱顺序)
allData.shuffle(); // 设定比例,150个样本中,百分之六十五用于训练
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training // 训练用的数据集
DataSet trainingData = testAndTrain.getTrain(); // 验证用的数据集
DataSet testData = testAndTrain.getTest(); // 指定归一化器:独立地将每个特征值(和可选的标签值)归一化为0平均值和1的标准差。
DataNormalization normalizer = new NormalizerStandardize(); // 先拟合
normalizer.fit(trainingData); // 对训练集做归一化
normalizer.transform(trainingData); // 对测试集做归一化
normalizer.transform(testData); // 每个鸢尾花有四个特征
final int numInputs = 4; // 共有三种鸢尾花
int outputNum = 3; // 随机数种子
long seed = 6; //第二阶段:训练
log.info("开始配置...");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.activation(Activation.TANH) // 激活函数选用标准的tanh(双曲正切)
.weightInit(WeightInit.XAVIER) // 权重初始化选用XAVIER:均值 0, 方差为 2.0/(fanIn + fanOut)的高斯分布
.updater(new Sgd(0.1)) // 更新器,设置SGD学习速率调度器
.l2(1e-4) // L2正则化配置
.list() // 配置多层网络
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3) // 隐藏层
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3) // 隐藏层
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // 损失函数:负对数似然
.activation(Activation.SOFTMAX) // 输出层指定激活函数为:SOFTMAX
.nIn(3).nOut(outputNum).build())
.build(); // 模型配置
MultiLayerNetwork model = new MultiLayerNetwork(conf); // 初始化
model.init(); // 每一百次迭代打印一次分数(损失函数的值)
model.setListeners(new ScoreIterationListener(100)); long startTime = System.currentTimeMillis(); log.info("开始训练");
// 训练
for(int i=0; i<1000; i++ ) {
model.fit(trainingData);
}
log.info("训练完成,耗时[{}]ms", System.currentTimeMillis()-startTime); // 第三阶段:评估 // 在测试集上评估模型
Evaluation eval = new Evaluation(numClasses);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output); log.info("评估结果如下\n" + eval.stats());
}
}
  • 编码完成后,运行main方法,可见顺利完成训练并输出了评估结果,还有混淆矩阵用于辅助分析:

  • 至此,咱们的第一个实战就完成了,通过经典实例体验的DL4J训练和评估的常规步骤,对重要API也有了初步认识,接下来会继续实战,接触到更多的经典实例;

你不孤单,欣宸原创一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 数据库+中间件系列
  6. DevOps系列

欢迎关注公众号:程序员欣宸

微信搜索「程序员欣宸」,我是欣宸,期待与您一同畅游Java世界...

https://github.com/zq2599/blog_demos

DL4J实战之二:鸢尾花分类的更多相关文章

  1. DL4J实战之一:准备

    欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...

  2. DL4J实战之五:矩阵操作基本功

    欢迎访问我的GitHub https://github.com/zq2599/blog_demos 内容:所有原创文章分类汇总及配套源码,涉及Java.Docker.Kubernetes.DevOPS ...

  3. python机器学习实战(二)

    python机器学习实战(二) 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7159775.html 前言 这篇noteboo ...

  4. 【深度学习系列】PaddlePaddle垃圾邮件处理实战(二)

    PaddlePaddle垃圾邮件处理实战(二) 前文回顾   在上篇文章中我们讲了如何用支持向量机对垃圾邮件进行分类,auc为73.3%,本篇讲继续讲如何用PaddlePaddle实现邮件分类,将深度 ...

  5. [Python]基于K-Nearest Neighbors[K-NN]算法的鸢尾花分类问题解决方案

    看了原理,总觉得需要用具体问题实现一下机器学习算法的模型,才算学习深刻.而写此博文的目的是,网上关于K-NN解决此问题的博文很多,但大都是调用Python高级库实现,尤其不利于初级学习者本人对模型的理 ...

  6. (转载)Android项目实战(二十七):数据交互(信息编辑)填写总结

    Android项目实战(二十七):数据交互(信息编辑)填写总结   前言: 项目中必定用到的数据填写需求.比如修改用户名的文字编辑对话框,修改生日的日期选择对话框等等.现总结一下,方便以后使用. 注: ...

  7. (转载)Android项目实战(二十八):Zxing二维码实现及优化

    Android项目实战(二十八):Zxing二维码实现及优化   前言: 多年之前接触过zxing实现二维码,没想到今日项目中再此使用竟然使用的还是zxing,百度之,竟是如此牛的玩意. 当然,项目中 ...

  8. 02-15 Logistic回归(鸢尾花分类)

    目录 Logistic回归(鸢尾花分类) 一.导入模块 二.获取数据 三.构建决策边界 四.训练模型 4.1 C参数与权重系数的关系 五.可视化 更新.更全的<机器学习>的更新网站,更有p ...

  9. 02-19 k近邻算法(鸢尾花分类)

    [TOC] 更新.更全的<机器学习>的更新网站,更有python.go.数据结构与算法.爬虫.人工智能教学等着你:https://www.cnblogs.com/nickchen121/ ...

随机推荐

  1. (int)a、&a、(int)&a、(int&)a的区别,很偏僻的题

    (int)a.&a.(int)&a.(int&)a的区别,很偏僻的题 #include <iostream> #include <stdio.h> #i ...

  2. flex布局中flex属性运用在随机发红包的算法上

    flex布局是现在前端基本上都会运用的一种布局,基本上用到比较多的是父元素设置display:flex,两个子元素,一个设置固定宽度,另一个设置为flex:1(这里都指flex-direction为r ...

  3. go语言 切片表达式

    切片表达式 切片的底层就是一个数组,所以我们可以基于数组通过切片表达式得到切片. 切片表达式中的low和high表示一个索引范围(左包含,右不包含),得到的切片长度=high-low,容量等于得到的切 ...

  4. [leetcode]1109. 航班预订统计(击败100%用户算法-差分数组的详解)

    执行用时2ms,击败100%用户 内存消耗52.1MB,击败91%用户 这也是我第一次用差分数组,之前从来没有碰到过,利用差分数组就是利用了差分数组在某一区间内同时加减情况,只会改变最左边和最右边+1 ...

  5. 关于python使用的那些事儿

    时间:2019-04-11 整理:PangYuaner 标题:Python获取并输出当前日期时间 地址:https://www.cnblogs.com/kerwinC/p/5760811.html 实 ...

  6. asp.NetCore3.1系统自带Imemcache缓存-滑动/绝对/文件依赖的缓存使用测试

    个人测试环境为:Asp.net coe 3.1 WebApi 1:封装自定义的cacheHelper帮助类,部分代码 1 public static void SetCacheByFile<T& ...

  7. 记一次 .NET 某新能源汽车锂电池检测程序 UI挂死分析

    更多高质量干货:参见我的 GitHub: dotnetfly 一:背景 1. 讲故事 这世间事说来也奇怪,近两个月有三位朋友找到我,让我帮忙分析下他的程序hangon现象,这三个dump分别涉及: 医 ...

  8. Java基础之类加载器

    Java类加载器是用户程序和JVM虚拟机之间的桥梁,在Java程序中起了至关重要的作用,理解它有利于我们写出更优雅的程序.本文首先介绍了Java虚拟机加载程序的过程,简述了Java类加载器的加载方式( ...

  9. [Elasticsearch] ES更新问题踩坑记录

    问题描述 我们有个系统设计的时候针对Hive创建表.删除表, 需要更新ES中的一个状态,标记是否删除,在几乎同时执行两条下面的语句的时候,发现在ES 中出现表即使被创建了还是无法被查询到的情况,针对该 ...

  10. noip模拟36

    \(\color{white}{\mathbb{荷花映日,莲叶遮天,名之以:残荷}}\) 今天再次翻车掉出前十 开题看错 \(t1\) 以为操作2的值固定发现是个简单题,然后 \(t2\) 开始大力 ...