Spark中常见的三种分类模型:线性模型、决策树和朴素贝叶斯模型。

线性模型,简单而且相对容易扩展到非常大的数据集;线性模型又可以分成:1.逻辑回归;2.线性支持向量机

决策树是一个强大的非线性技术,训练过程计算量大并且较难扩展(幸运的是,MLlib会替我们考虑扩展性的问题),但是在很多情况下性能很好;

朴素贝叶斯模型简单、易训练,并且具有高效和并行的优点(实际中,模型训练只需要遍历所有数据集一次)。当采用合适的特征工程,这些模型在很多应用中都能达到不错的性能。而且,朴素贝叶斯模型可以作为一个很好的模型测试基准,用于比较其他模型的性能。

现在我们采用的数据集是stumbleupon,这个数据集是主要是一些网页的分类数据

内容样例:String = "http://www.bloomberg.com/news/2010-12-23/ibm-predicts-holographic-calls-air-breathing-batteries-by-2015.html"    "4042"    "{""title"":""IBM Sees Holographic Calls Air Breathing Batteries ibm sees holographic calls, air-breathing batteries"",""body"":""A sign stands outside the International Business Machines Corp IBM Almaden Research Center campus in San Jose California Photographer Tony Avelar Bloomberg Buildings stand at the International Business Machines Corp IBM Almaden Research Center campus in the Santa Teresa Hills of San Jose California Photographer Tony Avelar Bloomberg By 2015 your mobile phone will project a 3 D image of anyone who calls and your laptop will be powered by kinetic energy At least that s what International Business Machines Corp sees in its crystal ...

开始四列分别包含 URL 、页面的 ID 、原始的文本内容和分配给页面的类别。接下来 22 列包含各种各样的数值或者类属特征。最后一列为目标值, -1 为长久, 0 为短暂。

val rawData = sc.textFile("/user/common/stumbleupon/train_noheader.tsv")
val records = rawData.map(line => line.split("\t"))
records.first()

由于数据格式的问题,我们做一些数据清理的工作,在处理过程中把额外的( " )去掉。数据集中还有一些用 "?" 代替的缺失数据,本例中,我们直接用 0 替换那些缺失数据

在清理和处理缺失数据后,我们提取最后一列的标记变量以及第 5 列到第 25 列的特征矩阵。将标签变量转换为 Int 值,特征向量转换为 Double 数组。

最后,我们将标签和和特征向量转换为 LabeledPoint 实例,从而将特征向量存储到 MLlib 的 Vector 中。

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
val data = records.map { r =>
val trimmed = r.map(_.replaceAll("\"", ""))
val label = trimmed(r.size - 1).toInt
val features = trimmed.slice(4, r.size - 1).map(d =>
if (d =="?") 0.0 else d.toDouble)
LabeledPoint(label, Vectors.dense(features))
}

朴素贝叶斯特殊的数据处理)在对数据集做进一步处理之前,我们发现数值数据中包含负的特征值。我们知道,朴素贝叶斯模型要求特征值非负,否则碰到负的特征值程序会抛出错误。因此,需要为朴素贝叶斯模型构建一份输入特征向量的数据,将负特征值设为 0

val nbData = records.map { r =>
val trimmed = r.map(_.replaceAll("\"", ""))
val label = trimmed(r.size - 1).toInt
val features = trimmed.slice(4, r.size - 1).map(d =>
if (d =="?") 0.0 else d.toDouble).map(d =>
if (d < 0) 0.0 else d)
LabeledPoint(label, Vectors.dense(features))
}

分别训练逻辑回归、SVM、朴素贝叶斯模型和决策树

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy
val numIterations = 10
val maxTreeDepth = 5

训练逻辑回归模型

val lrModel = LogisticRegressionWithSGD.train(data, numIterations)

训练SVM模型

val svmModel = SVMWithSGD.train(data, numIterations)

训练朴素贝叶斯模型

val nbModel = NaiveBayes.train(nbData)

训练决策树模型

val dtModel = DecisionTree.train(data, Algo.Classification, Entropy, maxTreeDepth)

验证预测结果的正确性,以逻辑回归为例子,说明预测的结果是错误的

val dataPoint = data.first
val prediction = lrModel.predict(dataPoint.features)
# 输出 prediction: Double = 1.0
val trueLabel = dataPoint.label
# 输出 trueLabel: Double = 0.0

评估分类模型的性能

1.逻辑回归模型

val lrTotalCorrect = data.map { point =>
if (lrModel.predict(point.features) == point.label) 1 else 0
}.sum
val lrAccuracy = lrTotalCorrect / data.count
lrAccuracy: Double = 0.5146720757268425

2.SVM模型

val svmTotalCorrect = data.map { point =>
if (svmModel.predict(point.features) == point.label) 1 else 0
}.sum
val svmAccuracy = svmTotalCorrect / data.count
svmAccuracy: Double = 0.5146720757268425

3.贝叶斯模型

val nbTotalCorrect = nbData.map { point =>
if (nbModel.predict(point.features) == point.label) 1 else 0
}.sum
val nbAccuracy = nbTotalCorrect / data.count
nbAccuracy: Double = 0.5803921568627451

4.决策树模型

val dtTotalCorrect = data.map { point =>
val score = dtModel.predict(point.features)
val predicted = if (score > 0.5) 1 else 0
if (predicted == point.label) 1 else 0
}.sum
val dtAccuracy = dtTotalCorrect / data.count
dtAccuracy: Double = 0.6482758620689655

准确率和召回率

改进模型性能以及参数调优

1.特征标准化

研究特征是如何分布的,先将特征向量用 RowMatrix 类表示成 MLlib 中的分布矩阵。 RowMatrix 是一个由向量组成的 RDD ,其中每个向量是分布矩阵的一行。

RowMatrix 类中有一些方便操作矩阵的方法,其中一个方法可以计算矩阵每列的统计特性

import org.apache.spark.mllib.linalg.distributed.RowMatrix
val vectors = data.map(lp => lp.features)
val matrix = new RowMatrix(vectors)
val matrixSummary = matrix.computeColumnSummaryStatistics()  
#computeColumnSummaryStatistics 方法计算特征矩阵每列的不同统计数据,包括均值和方差,所有统计值按每列一项的方式存储在一个 Vector 中
println(matrixSummary.mean)  #输出矩阵每列的均值
println(matrixSummary.min)  #输出矩阵每列的最小值
println(matrixSummary.max) #输出矩阵每列的最大值
println(matrixSummary.variance) #输出矩阵每列的方差
println(matrixSummary.numNonzeros) #输出矩阵每列中非 0 项的数目

对特征矩阵进行归一化

import org.apache.spark.mllib.feature.StandardScaler
val scaler = new StandardScaler(withMean = true, withStd = true).fit(vectors)
# 传入两个参数,一个表示是否从数据中减去均值,另一个表示是否应用标准差缩放
val scaledData = data.map(lp => LabeledPoint(lp.label,scaler.transform(lp.features)))
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val lrModelScaled = LogisticRegressionWithSGD.train(scaledData, numIterations)
val lrTotalCorrectScaled = scaledData.map { point =>
if (lrModelScaled.predict(point.features) == point.label) 1 else 0
}.sum
val lrAccuracyScaled = lrTotalCorrectScaled / numData
val lrPredictionsVsTrue = scaledData.map { point =>
(lrModelScaled.predict(point.features), point.label)
}
val lrMetricsScaled = new BinaryClassificationMetrics(lrPredictionsVsTrue)
val lrPr = lrMetricsScaled.areaUnderPR
val lrRoc = lrMetricsScaled.areaUnderROC println(f"${lrModelScaled.getClass.getSimpleName}\nAccuracy:${lrAccuracyScaled * 100}%2.4f%%\nArea under PR: ${lrPr *100.0}%2.4f%%\nArea under ROC: ${lrRoc * 100.0}%2.4f%%")

可以看出,特征标准化提升了逻辑回归模型的准确率和AUC

LogisticRegressionModel
Accuracy:62.0419%
Area under PR: 72.7254%
Area under ROC: 61.9663%

Spark学习笔记——构建分类模型的更多相关文章

  1. Spark学习笔记——构建基于Spark的推荐引擎

    推荐模型 推荐模型的种类分为: 1.基于内容的过滤:基于内容的过滤利用物品的内容或是属性信息以及某些相似度定义,来求出与该物品类似的物品. 2.协同过滤:协同过滤是一种借助众包智慧的途径.它利用大量已 ...

  2. ArcGIS模型构建器案例学习笔记-字段处理模型集

    ArcGIS模型构建器案例学习笔记-字段处理模型集 联系方式:谢老师,135-4855-4328,xiexiaokui@qq.com 由四个子模型组成 子模型1:判断字段是否存在 方法:python工 ...

  3. spark学习笔记总结-spark入门资料精化

    Spark学习笔记 Spark简介 spark 可以很容易和yarn结合,直接调用HDFS.Hbase上面的数据,和hadoop结合.配置很容易. spark发展迅猛,框架比hadoop更加灵活实用. ...

  4. Spark学习笔记-GraphX-1

    Spark学习笔记-GraphX-1 标签: SparkGraphGraphX图计算 2014-09-29 13:04 2339人阅读 评论(0) 收藏 举报  分类: Spark(8)  版权声明: ...

  5. Spark学习笔记0——简单了解和技术架构

    目录 Spark学习笔记0--简单了解和技术架构 什么是Spark 技术架构和软件栈 Spark Core Spark SQL Spark Streaming MLlib GraphX 集群管理器 受 ...

  6. 操作系统学习笔记----进程/线程模型----Coursera课程笔记

    操作系统学习笔记----进程/线程模型----Coursera课程笔记 进程/线程模型 0. 概述 0.1 进程模型 多道程序设计 进程的概念.进程控制块 进程状态及转换.进程队列 进程控制----进 ...

  7. V-rep学习笔记:机器人模型创建3—搭建动力学模型

    接着之前写的V-rep学习笔记:机器人模型创建2—添加关节继续机器人创建流程.如果已经添加好关节,那么就可以进入流程的最后一步:搭建层次结构模型和模型定义(build the model hierar ...

  8. Spring实战第五章学习笔记————构建Spring Web应用程序

    Spring实战第五章学习笔记----构建Spring Web应用程序 Spring MVC基于模型-视图-控制器(Model-View-Controller)模式实现,它能够构建像Spring框架那 ...

  9. Spark学习笔记2——RDD(上)

    目录 Spark学习笔记2--RDD(上) RDD是什么? 例子 创建 RDD 并行化方式 读取外部数据集方式 RDD 操作 转化操作 行动操作 惰性求值 Spark学习笔记2--RDD(上) 笔记摘 ...

随机推荐

  1. Codeforces.911F.Tree Destruction(构造 贪心)

    题目链接 \(Description\) 一棵n个点的树,每次可以选择树上两个叶子节点并删去一个,得到的价值为两点间的距离 删n-1次,问如何能使最后得到的价值最大,并输出方案 \(Solution\ ...

  2. Wannafly挑战赛25游记

    Wannafly挑战赛25游记 A - 因子 题目大意: 令\(x=n!(n\le10^{12})\),给定一大于\(1\)的正整数\(p(p\le10000)\)求一个\(k\)使得\(p^k|x\ ...

  3. C++ 类模板基础知识

    类模板与模板类 为什么要引入类模板:类模板是对一批仅仅成员数据类型不同的类的抽象,程序员只要为这一批类所组成的整个类家族创建一个类模板,给出一套程序代码,就可以用来生成多种具体的类,(这类可以看作是类 ...

  4. NOIP2013 花匠解题报告

    //<NOIP2013> 花匠 /* 最优子结构性质,可以用动规.注意到存在30%的变态数据(1 ≤ n ≤ 100,000, 0 ≤ h_i ≤1,000,000),因此应当找到线性的算 ...

  5. centos 7 安装 php 5.5 5.6 7.0

    查看当前安装的PHP包 [root@node1 ~]# yum list installed | grep php php56w.x86_64 -.w7 @webtatic php56w-cli.x8 ...

  6. db2 系统表

    SYSIBM: 基本系统编目,不建议直接访问SYSCAT: 默认授权给Public组.只读编目视图,一般通过这个来获取编目信息SYSSTAT: 可更新编目视图,会影响优化器的优化策略SYSFUN: 用 ...

  7. 树莓派.Qt.Creator安装方法

    树莓派硬件: Raspberry Pi 3 B 树莓派系统: Linux version 4.9.59-v7+ (32位) Qt版本(x86版本--32位): 安装过程 可以查看软件仓库支持的版本: ...

  8. C调用lua的table里面的函数

    网上搜索C.C++调用lua函数,有一大堆复制粘贴的. 但是搜索<C调用lua的table里面的函数> 怎么就没几个呢? 经过探索,发现其实逻辑是这样的: 1.根据name获取table ...

  9. windows命令行下杀死进程的方法

    xp和win7下有两个好东东tasklist和tskill.tasklist能列出所有的进程,和相应的信息.tskill能查杀进程,语法很简单:tskill程序名!或者是tskill 进程id 例如: ...

  10. oradim新建服务后,登录数据库报ORA-12560错误

    > oradim -new -sid mydb 实例已创建. > sqlplus /nolog SQL*Plus: Release 11.2.0.4.0 Production on 星期二 ...