在机器学习中,一般都会按照下面几个步骤:特征提取、数据预处理、特征选择、模型训练、检验优化。那么特征的选择就很关键了,一般模型最后效果的好坏往往都是跟特征的选择有关系的,因为模型本身的参数并没有太多优化的点,反而特征这边有时候多加一个或者少加一个,最终的结果都会差别很大。

在SparkMLlib中为我们提供了几种特征选择的方法,分别是VectorSlicerRFormulaChiSqSelector

下面就介绍下这三个方法的使用,强烈推荐有时间的把参考的文献都阅读下,会有所收获!

VectorSlicer

这个转换器可以支持用户自定义选择列,可以基于下标索引,也可以基于列名。

  • 如果是下标都可以使用setIndices方法
  • 如果是列名可以使用setNames方法。使用这个方法的时候,vector字段需要通过AttributeGroup设置每个向量元素的列名。

注意1:可以同时使用setInices和setName

object VectorSlicer {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("VectorSlicer-Test").setMaster("local[2]")
val sc = new SparkContext(conf)
sc.setLogLevel("WARN")
var sqlContext = new SQLContext(sc) val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0,1.0,2.0))) val defaultAttr = NumericAttribute.defaultAttr
val attrs = Array("f1", "f2", "f3","f4","f5").map(defaultAttr.withName)
val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) val dataRDD = sc.parallelize(data)
val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") slicer.setIndices(Array(0)).setNames(Array("f2"))
val output = slicer.transform(dataset)
println(output.select("userFeatures", "features").first())
}
}

注意2:如果下标和索引重复,会报重复的错:

比如:

slicer.setIndices(Array(1)).setNames(Array("f2"))

那么会遇到报错

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: VectorSlicer requires indices and names to be disjoint sets of features, but they overlap. indices: [1]. names: [1:f2]
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.feature.VectorSlicer.getSelectedFeatureIndices(VectorSlicer.scala:137)
at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:108)
at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:497)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意3:如果下标不存在

slicer.setIndices(Array(6))

如果数组越界也会报错

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 6
at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3$$anonfun$apply$2.apply(VectorSlicer.scala:110)
at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3$$anonfun$apply$2.apply(VectorSlicer.scala:110)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:156)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
at scala.collection.mutable.ArrayOps$ofInt.map(ArrayOps.scala:156)
at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3.apply(VectorSlicer.scala:110)
at org.apache.spark.ml.feature.VectorSlicer$$anonfun$3.apply(VectorSlicer.scala:109)
at scala.Option.map(Option.scala:145)
at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:109)
at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:497)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意4:如果名称不存在也会报错

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: getFeatureIndicesFromNames found no feature with name f8 in column StructField(userFeatures,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,false).
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.util.MetadataUtils$$anonfun$getFeatureIndicesFromNames$2.apply(MetadataUtils.scala:89)
at org.apache.spark.ml.util.MetadataUtils$$anonfun$getFeatureIndicesFromNames$2.apply(MetadataUtils.scala:88)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:108)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:108)
at org.apache.spark.ml.util.MetadataUtils$.getFeatureIndicesFromNames(MetadataUtils.scala:88)
at org.apache.spark.ml.feature.VectorSlicer.getSelectedFeatureIndices(VectorSlicer.scala:129)
at org.apache.spark.ml.feature.VectorSlicer.transform(VectorSlicer.scala:108)
at xingoo.mllib.VectorSlicer$.main(VectorSlicer.scala:35)
at xingoo.mllib.VectorSlicer.main(VectorSlicer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:497)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)

注意5:经过特征选择后,特征的顺序与索引和名称的顺序相同

RFormula

这个转换器可以帮助基于R模型,自动生成feature和label。比如说最常用的线性回归,在先用回归中,我们需要把一些离散化的变量变成哑变量,即转变成onehot编码,使之数值化,这个我之前的文章也介绍过,这里就不多说了。

如果不是用这个RFormula,我们可能需要经过几个步骤:

StringIndex...OneHotEncoder...

而且每个特征都要经过这样的变换,非常繁琐。有了RFormula,几乎可以一键把所有的特征问题解决。

id | coutry | hour | clicked | my_test

--- | --- | --- | ---

7| US|18|1.0|a

8|CA|12|0.0|b

9|NZ|15|0.0|a

然后我们只要写一个类似这样的公式clicked ~ country + hour + my_test,就代表clickedlabelcoutry、hour、my_test是三个特征

比如下面的代码:

object RFormulaTest {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("RFormula-Test").setMaster("local[2]")
val sc = new SparkContext(conf)
sc.setLogLevel("WARN")
var sqlContext = new SQLContext(sc) val dataset = sqlContext.createDataFrame(Seq(
(7, "US", 18, 1.0,"a"),
(8, "CA", 12, 0.0,"b"),
(9, "NZ", 15, 0.0,"a")
)).toDF("id", "country", "hour", "clicked","my_test")
val formula = new RFormula()
.setFormula("clicked ~ country + hour + my_test")
.setFeaturesCol("features")
.setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.show()
output.select("features", "label").show()
}
}

得到的结果

+---+-------+----+-------+-------+------------------+-----+
| id|country|hour|clicked|my_test| features|label|
+---+-------+----+-------+-------+------------------+-----+
| 7| US| 18| 1.0| a|[0.0,0.0,18.0,1.0]| 1.0|
| 8| CA| 12| 0.0| b|[1.0,0.0,12.0,0.0]| 0.0|
| 9| NZ| 15| 0.0| a|[0.0,1.0,15.0,1.0]| 0.0|
+---+-------+----+-------+-------+------------------+-----+ +------------------+-----+
| features|label|
+------------------+-----+
|[0.0,0.0,18.0,1.0]| 1.0|
|[1.0,0.0,12.0,0.0]| 0.0|
|[0.0,1.0,15.0,1.0]| 0.0|
+------------------+-----+

ChiSqSelector

这个选择器支持基于卡方检验的特征选择,卡方检验是一种计算变量独立性的检验手段。具体的可以参考维基百科,最终的结论就是卡方的值越大,就是我们越想要的特征。因此这个选择器就可以理解为,再计算卡方的值,最后按照这个值排序,选择我们想要的个数的特征。

代码也很简单

object ChiSqSelectorTest {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("ChiSqSelector-Test").setMaster("local[2]")
val sc = new SparkContext(conf)
sc.setLogLevel("WARN")
var sqlContext = new SQLContext(sc) val data = Seq(
(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
) val beanRDD = sc.parallelize(data).map(t3 => Bean(t3._1,t3._2,t3._3))
val df = sqlContext.createDataFrame(beanRDD) val selector = new ChiSqSelector()
.setNumTopFeatures(2)
.setFeaturesCol("features")
.setLabelCol("clicked")
.setOutputCol("selectedFeatures") val result = selector.fit(df).transform(df)
result.show()
} case class Bean(id:Double,features:org.apache.spark.mllib.linalg.Vector,clicked:Double){}
}

这样得到的结果:

+---+------------------+-------+----------------+
| id| features|clicked|selectedFeatures|
+---+------------------+-------+----------------+
|7.0|[0.0,0.0,18.0,1.0]| 1.0| [18.0,1.0]|
|8.0|[0.0,1.0,12.0,0.0]| 0.0| [12.0,0.0]|
|9.0|[1.0,0.0,15.0,0.1]| 0.0| [15.0,0.1]|
+---+------------------+-------+----------------+

总结

下面总结一下三种特征选择的使用场景:

  • VectorSilcer,这个选择器适合那种有很多特征,并且明确知道自己想要哪个特征的情况。比如你有一个很全的用户画像系统,每个人有成百上千个特征,但是你指向抽取用户对电影感兴趣相关的特征,因此只要手动选择一下就可以了。
  • RFormula,这个选择器适合在需要做OneHotEncoder的时候,可以一个简单的代码把所有的离散特征转化成数值化表示。
  • ChiSqSelector,卡方检验选择器适合在你有比较多的特征,但是不知道这些特征哪个有用,哪个没用,想要通过某种方式帮助你快速筛选特征,那么这个方法很适合。

以上的总结纯属个人看法,不代表官方做法,如果有其他的见解可以留言~ 多交流!

参考

1 Spark特征处理

2 Spark官方文档

3 如何优化逻辑回归

4 数据挖掘中的VI和WOE

5 Spark卡方选择器

6 卡方分布

7 皮尔逊卡方检验

8 卡方检验原理

推荐系统那点事 —— 基于Spark MLlib的特征选择的更多相关文章

  1. 基于Spark Mllib的文本分类

    基于Spark Mllib的文本分类 文本分类是一个典型的机器学习问题,其主要目标是通过对已有语料库文本数据训练得到分类模型,进而对新文本进行类别标签的预测.这在很多领域都有现实的应用场景,如新闻网站 ...

  2. 【spark】spark应用(分布式估算圆周率+基于Spark MLlib的贷款风险预测)

    注:本章不涉及spark和scala原理的探讨,详情见其他随笔 一.分布式估算圆周率 计算原理:假设正方形的面积S等于x²,而正方形的内切圆的面积C等于Pi×(x/2)²,因此圆面积与正方形面积之比C ...

  3. Spark 实践——基于 Spark MLlib 和 YFCC 100M 数据集的景点推荐系统

    1.前言 上接 YFCC 100M数据集分析笔记 和 使用百度地图api可视化聚类结果, 在对 YFCC 100M 聚类出的景点信息的基础上,使用 Spark MLlib 提供的 ALS 算法构建推荐 ...

  4. 基于Spark Mllib,SparkSQL的电影推荐系统

    本文测试的Spark版本是1.3.1 本文将在Spark集群上搭建一个简单的小型的电影推荐系统,以为之后的完整项目做铺垫和知识积累 整个系统的工作流程描述如下: 1.某电影网站拥有可观的电影资源和用户 ...

  5. 基于spark Mllib(ML)聚类实战

        写在前面的话:由于spark2.0.0之后ML中才包括LDA,GaussianMixture 模型,这里k-means用的是ML模块做测试,LDA,GaussianMixture 则用的是ML ...

  6. 基于Spark Mllib的Spark NLP库

    SparkNLP的官方文档 1>sbt引入: scala为2.11时 libraryDependencies += "com.johnsnowlabs.nlp" %% &qu ...

  7. 使用 Spark MLlib 做 K-means 聚类分析[转]

    原文地址:https://www.ibm.com/developerworks/cn/opensource/os-cn-spark-practice4/ 引言 提起机器学习 (Machine Lear ...

  8. Spark MLlib 之 StringIndexer、IndexToString使用说明以及源码剖析

    最近在用Spark MLlib进行特征处理时,对于StringIndexer和IndexToString遇到了点问题,查阅官方文档也没有解决疑惑.无奈之下翻看源码才明白其中一二...这就给大家娓娓道来 ...

  9. 基于Spark机器学习和实时流计算的智能推荐系统

    概要: 随着电子商务的高速发展和普及应用,个性化推荐的推荐系统已成为一个重要研究领域. 个性化推荐算法是推荐系统中最核心的技术,在很大程度上决定了电子商务推荐系统性能的优劣,决定着是否能够推荐用户真正 ...

随机推荐

  1. iOS安全攻防之使用 Frida 绕过越狱设备检测

    Frida 是 一款有趣的手机应用安全分析工具. 文章参考:Bypass Jailbreak Detection with Frida in iOS applications 在 Mac Termin ...

  2. Thinkphp5使用阿里大于短信验证

    现在各种平台登录验证很多时候会使用短信验证,快捷安全,有很多平台提供短信验证服务,相比较而言阿里大于价格比较便宜,快捷,所以在在千锋日常的php教学中多以此为例来说明短信验证的使用.下面我们在tp5中 ...

  3. Oracle数据库------体系结构

    ORACLE体系结构包括:实例(Instance),数据库文件,用户进程(User process),服务器进程以及其他文件. 1.ORACLE实例(instance)     1).要访问数据库必须 ...

  4. 图解Git命令

    上面的四条命令在工作目录.暂存目录(也叫做索引)和仓库之间复制文件. ·git add files把当前文件放入暂存区域. ·git commit 给暂存区域生成快照并提交. ·git reset - ...

  5. DOM知识梳理

    DOM 我们知道,JavaScript是由ECMAScript + DOM + BOM组成的.ECMAScript是JS中的一些语法,而BOM主要是浏览器对象(window)对象的一些相关知识的集合. ...

  6. git rebase -i命令修改commit历史

    [TOC] 修改commit历史的前提 修改历史的提交是可能有风险的,是否有风险取决于commit是否已经推送远程分支,未推送,无风险,如果已推送,就千万不要修改commit了. 修改commit历史 ...

  7. 转发:Ubuntu软件卸载安装的命令

    说明:由于图形化界面方法(如Add/Remove... 和Synaptic Package Manageer)比较简单,所以这里主要总结在终端通过命令行方式进行的软件包安装.卸载和删除的方法. 一.U ...

  8. 原型----------prototype详细解答

    function ren(name,age){ this.name=name; this.age=age; this.fa=function(){ alert('我喜欢吃'); } } var p1= ...

  9. ionic 项目中创建侧边栏的具体流程分4步简单学会

    这是在学习ionic时,当时遇到的一些问题,觉得很难,就记笔记下来了,现在觉得如果可以拿来分享,有可能会帮助到遇到同样问题的人 ionic slidemenu 项目流程: cd pretices(自己 ...

  10. js实现数据流(日志流,报警信息等)滚动展示,并分页(含实现demo)

    在项目中有遇到,后台向前端推送数据,前端以数据流的形式展示,即来一条我增加一条,类似于日志,报警等信息展示,想必大部分人都有遇到过,本来出于想找一个好的展示方式的心态,因为感觉自己设计的不太好看,结果 ...