本篇博客主要讲述如何利用spark的mliib构建机器学习模型并预测新的数据,具体的流程如下图所示:

加载数据 对于数据的加载或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的数据是采用spark中提供的数据sample_libsvm_data.txt,其有一百个数据样本,658个特征。具体的数据形式如图所示:

加载libsvm

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();

LabeledPoint数据类型是对应与libsvmfile格式文件, 具体格式为: Lable(double类型),vector(Vector类型) 转化dataFrame数据类型

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);

DataFrame:DataFrame是一个以命名列方式组织的分布式数据集。在概念上,它跟关系型数据库中的一张表或者1个Python(或者R)中的data frame一样,但是比他们更优化。DataFrame可以根据结构化的数据文件、hive表、外部数据库或者已经存在的RDD构造。 SQLContext:spark sql所有功能的入口是SQLContext类,或者SQLContext的子类。为了创建一个基本的SQLContext,需要一个SparkContext。 特征提取 特征归一化处理

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
DataFrame scalerDF = scaler.fit(df).transform(df);
scaler.save(this.scalerModelPath);

利用卡方统计做特征提取

ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures().setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
chiModel.save(this.featureSelectedModelPath);

训练机器学习模型(以SVM为例)

//转化为LabeledPoint数据类型, 训练模型
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel()); //训练SVM模型, 并保存
int numIteration = ;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
model.clearThreshold();
model.save(sc, this.mlModelPath); // LabeledPoint数据类型转化为Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> { public Row call(LabeledPoint p) throws Exception {
double label = p.label();
Vector vector = p.features();
return RowFactory.create(label, vector);
}
} //Rows数据类型转化为LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> { public LabeledPoint call(Row r) throws Exception {
Vector features = r.getAs();
double label = r.getDouble();
return new LabeledPoint(label, features);
}
}

测试新的样本 测试新的样本前,需要将样本做数据的转化和特征提取的工作,所有刚刚训练模型的过程中,除了保存机器学习模型,还需要保存特征提取的中间模型。具体代码如下:

//初始化spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "");
SparkContext sc = new SparkContext(conf); //加载测试数据
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD(); //转化DataFrame数据类型
JavaRDD<Row> jrow =testData.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema); //数据规范化
StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
DataFrame scalerDF = scaler.fit(df).transform(df); //特征选取
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

测试数据集

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ;
predictResult.collect(); static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
SVMModel model;
public Prediction(SVMModel model){
this.model = model;
}
public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
Double score = model.predict(p.features());
return new Tuple2<Double , Double>(score, p.label());
}
}

计算准确率

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy); static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
public Boolean call(Tuple2<Double, Double> t) throws Exception {
double score = t._1();
double label = t._2();
System.out.print("score:" + score + ", label:"+ label);
if(score >= 0.0 && label >= 0.0) return true;
else if(score < 0.0 && label < 0.0) return true;
else return false;
}
}

关于spark的mllib学习总结(Java版)的更多相关文章

  1. spark Using MLLib in Scala/Java/Python

    Using MLLib in ScalaFollowing code snippets can be executed in spark-shell. Binary ClassificationThe ...

  2. 布隆过滤器(Bloom Filter)-学习笔记-Java版代码(挖坑ing)

    布隆过滤器解决"面试题: 如何建立一个十亿级别的哈希表,限制内存空间" "如何快速查询一个10亿大小的集合中的元素是否存在" 如题 布隆过滤器确实很神奇, 简单 ...

  3. spark读文件写mysql(java版)

    package org.langtong.sparkdemo; import com.fasterxml.jackson.databind.ObjectMapper; import org.apach ...

  4. 20165234 [第二届构建之法论坛] 预培训文档(Java版) 学习总结

    [第二届构建之法论坛] 预培训文档(Java版) 学习总结 我通读并学习了此文档,并且动手实践了一遍.以下是我学习过程的记录~ Part1.配置环境 配置JDK 原文中提到了2个容易被混淆的概念 JD ...

  5. Java基础及JavaWEB以及SSM框架学习笔记Xmind版

    Java基础及JavaWEB以及SSM框架学习笔记Xmind版 转行做程序员也1年多了,最近开始整理以前学习过程中记录的笔记,以及一些容易犯错的内容.现在分享给网友们.笔记共三部分. JavaSE 目 ...

  6. Spark中的各种action算子操作(java版)

    在我看来,Spark编程中的action算子的作用就像一个触发器,用来触发之前的transformation算子.transformation操作具有懒加载的特性,你定义完操作之后并不会立即加载,只有 ...

  7. PetaPojo —— JAVA版的PetaPoco

    背景 由于工作的一些原因,需要从C#转成JAVA.之前PetaPoco用得真是非常舒服,在学习JAVA的过程中熟悉了一下JAVA的数据组件: MyBatis 非常流行,代码生成也很成熟,性能也很好.但 ...

  8. python实现文章或博客的自动摘要(附java版开源项目)

    python实现文章或博客的自动摘要(附java版开源项目) 写博客的时候,都习惯给文章加入一个简介.现在可以自动完成了!TF-IDF与余弦相似性的应用(三):自动摘要 - 阮一峰的网络日志http: ...

  9. 复利计算--4.0 单元测试之JAVA版-软件工程

    复利计算--4.0 单元测试-软件工程 前言:由于本人之前做的是C语言版的复利计算,所以为了更好地学习单元测试,于是将C语言版的复利计算修改为JAVA版的. 一.主要的功能需求细分: 1.本金为100 ...

随机推荐

  1. Elasticsearch 学习之 节点重启

    ElasticSearch集群的高可用和自平衡方案会在节点挂掉(重启)后自动在别的结点上复制该结点的分片,这将导致了大量的IO和网络开销.如果离开的节点重新加入集群,elasticsearch为了对数 ...

  2. [原]openstack-kilo--issue(二十一) instance can't get ip 虚拟机不能得到ip(2)

    ===问题点==== 在使用vlan模式部署compute节点的时候出现了下面的错误:在controller节点的dhcp-agent.log中 2017-01-22 20:19:34.178 241 ...

  3. mysql格式化小数保留小数点后两位(小数点格式化)

    格式化浮点数的问题,用format(col,2)保留两位小数点,出现一个问题,例如下面的语句,后面我们给出解决方法 SELECT FORMAT(12562.6655,2); 结果:12,562.67 ...

  4. Arduino基本数据类型

    基本数据类型简介 常见的Arduino是基于ATmega的8位 AVR单片机,例如Arduino UNO ,Arduino Nano,Arduino mega2560等.还有高级点 32位的,如Ard ...

  5. onems设备管理系统(TR-069和OMA)

    onems设备管理系统(TR-069和OMA) 沃克斯科技OneMS设备管理套件是一个全面的为服务提供商和企业提供自动配置和远程管理功能的设备管理解决方案.它利用现有的网络基础设施来自动化订购,预配置 ...

  6. MyEclipse启动Tomcat缓慢的原因及解决办法

    不知道朋友们是否有一种烦恼:有时候使用MyEclipse启动Tomcat十分缓慢,可能在几分钟前20秒以内,但现在却需要200秒开外:其间内存和CPU都被占用地厉害,而控制台的输出似乎有重复的迹象:而 ...

  7. spring整合Jersey 无法注入service的问题

    现象: action中的@autowired注入service或dao失败,报空指针异常 原因: 造成该问题的原因是你并没有做好spring和jersey的整合工作,检查你的web.xml文件,jer ...

  8. Python2安装igraph

    前言 igraph是一个进行图计算和社交网络分析的软件包,支持python语言,打算学习igraph,然后应用在自己的项目中. 系统环境 64位win10系统,同时安装了python3.6和pytho ...

  9. ubuntu登录时出现“一闪之后回到登录界面”的现象

    ubuntu登录时出现“一闪之后回到登录界面”的现象 虚拟机vmware 12.5.6 build-5528349 操作系统ubuntu 18.04 问题:登录时出现一闪之后回到登录界面的现象 解决方 ...

  10. [No0000DB]C# FtpClientHelper Ftp客户端上传下载重命名 类封装

    using System; using System.Diagnostics; using System.IO; using System.Text; using Shared; namespace ...