支持连续变量和类别变量,类别变量就是某个属性有三个值,a,b,c,需要用Feature Transformers中的vectorindexer处理

上来是一堆参数

setMaxDepth:最大树深度

setMaxBins:最大装箱数,为了近似统计变量,比如变量有100个值,我只分成10段去做统计

setMinInstancesPerNode:每个节点最少实例

setMinInfoGain:最小信息增益

setMaxMemoryInMB:最大内存MB单位,这个值越大,一次处理的节点划分就越多

setCacheNodeIds:是否缓存节点id,缓存可以加速深层树的训练

setCheckpointInterval:检查点间隔,就是多少次迭代固化一次

setImpurity:随机森林有三种方式,entropy,gini,variance,回归肯定就是variance

setSubsamplingRate:设置采样率,就是每次选多少比例的样本构成新树

setSeed:采样种子,种子不变,采样结果不变

setNumTrees:设置森林里有多少棵树

setFeatureSubsetStrategy:设置特征子集选取策略,随机森林就是两个随机,构成树的样本随机,每棵树开始分裂的属性是随机的,其他跟决策树区别不大,注释这么写的

* The number of features to consider for splits at each tree node.
   * Supported options:
   *  - "auto": Choose automatically for task://默认策略
   *            If numTrees == 1, set to "all."     //决策树选择所有属性
   *            If numTrees > 1 (forest), set to "sqrt" for classification and //决策森林 分类选择属性数开平方,回归选择三分之一属性
   *              to "onethird" for regression.
   *  - "all": use all features
   *  - "onethird": use 1/3 of the features
   *  - "sqrt": use sqrt(number of features)
   *  - "log2": use log2(number of features) //还有取对数的
   * (default = "auto") 
   *
   * These various settings are based on the following references:
   *  - log2: tested in Breiman (2001)
   *  - sqrt: recommended by Breiman manual for random forests
   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
   *    package.

参数完毕,下面比较重要的是这段代码

val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

这个地比较蛋疼的是dataset.schema($(featuresCol))

/** An alias for [[getOrDefault()]]. */
  protected final def $[T](param: Param[T]): T = getOrDefault(param)

这段代码说明了$(featuresCol))只是求出一个字段名,实战中直接data.schema("features") ,data.schema("features")出来的是StructField,

case classStructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extendsProduct with Serializable

StructField包含四个内容,最好知道一下,机器学习代码很多都用

回头说下getCategoricalFeatures,这个方法是识别一个属性是二值变量还是名义变量,例如a,b就是二值变量,a,b,c就是名义变量,最终把属性索引和变量值的个数放到一个map

这个函数的功能和vectorindexer类似,但是一般都用vectorindexer,因为实战中我们大都从sql读数据,sql读出来的数据metadata是空,无法识别二值变量还是名义变量

后面是

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    val strategy =
      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    val trees =
      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
        .map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = oldDataset.first().features.size

new RandomForestRegressionModel(trees, numFeatures)

可以看出还是调的RDD的旧方法,run这个方法是核心有1000多行,后面会详细跟踪,最后返回的是RandomForestRegressionModel,里面有Array[DecisionTreeRegressionModel] ,就是生成的一组决策树模型,也就是决策森林,另外一个是属性数,我们继续看RandomForestRegressionModel

在1.6版本每棵树的权重都是1,里面还有这么一个方法

override protected def transformImpl(dataset: DataFrame): DataFrame = {
    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

可以看到把模型通过广播的形式传给exectors,搞一个udf预测函数,最后通过withColumn把预测数据粘到原数据后面,

注意这个写法dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) ,第一个参数是列名,第二个是计算出来的col,col是列类型,预测方法如下

override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
    // Predict average of tree predictions.
    // Ignore the weights since all are 1.0 for now.
    _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
  }

可见预测用的是每个树的跟节点,predictImpl(features)返回这个根节点分配的叶节点,这是一个递归调用的过程,关于如何递归,后面也会拿出来细说,最后再用.prediction方法把所有树预测的结果相加求平均

后面有一个计算各属性重要性的方法

lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)

实现如下

private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
    val totalImportances = new OpenHashMap[Int, Double]()
    trees.foreach { tree =>
      // Aggregate feature importance vector for this tree      先计算每棵树的属性重要性值
      val importances = new OpenHashMap[Int, Double]()
      computeFeatureImportance(tree.rootNode, importances)
      // Normalize importance vector for this tree, and add it to total.
      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
      val treeNorm = importances.map(_._2).sum
      if (treeNorm != 0) {
        importances.foreach { case (idx, impt) =>
          val normImpt = impt / treeNorm
          totalImportances.changeValue(idx, normImpt, _ + normImpt)
        }
      }
    }
    // Normalize importances
    normalizeMapValues(totalImportances)
    // Construct vector
    val d = if (numFeatures != -1) {
      numFeatures
    } else {
      // Find max feature index used in trees
      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
      maxFeatureIndex + 1
    }
    if (d == 0) {
      assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
        s" importance: No splits in forest, but some non-zero importances.")
    }
    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
    Vectors.sparse(d, indices.toArray, values.toArray)
  }

computeFeatureImportance的实现如下

/**
   * Recursive method for computing feature importances for one tree.
   * This walks down the tree, adding to the importance of 1 feature at each node.
   * @param node  Current node in recursion
   * @param importances  Aggregate feature importances, modified by this method
   */
  private[impl] def computeFeatureImportance(
      node: Node,
      importances: OpenHashMap[Int, Double]): Unit = {
    node match {
      case n: InternalNode =>
        val feature = n.split.featureIndex
        val scaledGain = n.gain * n.impurityStats.count
        importances.changeValue(feature, scaledGain, _ + scaledGain)
        computeFeatureImportance(n.leftChild, importances)
        computeFeatureImportance(n.rightChild, importances)
      case n: LeafNode =>
        // do nothing
    }
  }

由于属性重要性是由gain概念扩展而来,这里以gain来说明如何计算属性重要性。

这里首先可以看出为什么每次树的调用都回到rootnode的调用,因为要递归的沿着树的层深往下游走,这里游走到叶节点什么也不做,其他分裂节点也就是代码里的InternalNode ,先找到该节点划分的属性索引,然后该节点增益乘以该节点数据量,然后更新属性重要性值,这样继续递归左节点,右节点,直到结束

然后回到featureImportances方法,val treeNorm = importances.map(_._2).sum是把刚才计算的每棵树的属性重要性求和,然后计算每个属性重要性占这棵树总重要性的比值,这样整棵树就搞完了,foreach走完,所有树的属性重要性就累加到totalImportances里了,然后normalizeMapValues(totalImportances)再按刚才的方法算一遍,这样出来的属性值和就为1了,有了属性个数和排好序的属性重要性值,装入向量,就是最终输出的结果

入口方法就这些了

现在我们还有run方法的1000多行,还有如何递归分配节点这两个点需要讲,后面会继续

RF的特征子集选取策略(spark ml)的更多相关文章

  1. 使用spark ml pipeline进行机器学习

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  2. spark ml 的例子

    一.关于spark ml pipeline与机器学习 一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的 ...

  3. spark ml pipeline构建机器学习任务

    一.关于spark ml pipeline与机器学习一个典型的机器学习构建包含若干个过程 1.源数据ETL 2.数据预处理 3.特征选取 4.模型训练与验证 以上四个步骤可以抽象为一个包括多个步骤的流 ...

  4. Spark ML下实现的多分类adaboost+naivebayes算法在文本分类上的应用

    1. Naive Bayes算法 朴素贝叶斯算法算是生成模型中一个最经典的分类算法之一了,常用的有Bernoulli和Multinomial两种.在文本分类上经常会用到这两种方法.在词袋模型中,对于一 ...

  5. Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构.首先,以Decis ...

  6. Spark ML机器学习

    Spark提供了常用机器学习算法的实现, 封装于spark.ml和spark.mllib中. spark.mllib是基于RDD的机器学习库, spark.ml是基于DataFrame的机器学习库. ...

  7. Spark ML 几种 归一化(规范化)方法总结

    规范化,有关之前都是用 python写的,  偶然要用scala 进行写, 看到这位大神写的, 那个网页也不错,那个连接图做的还蛮不错的,那天也将自己的博客弄一下那个插件. 本文来源 原文地址:htt ...

  8. Spark ML Pipeline简介

    Spark ML Pipeline基于DataFrame构建了一套High-level API,我们可以使用MLPipeline构建机器学习应用,它能够将一个机器学习应用的多个处理过程组织起来,通过在 ...

  9. spark org.apache.spark.ml.linalg.DenseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector

    在使用 import org.apache.spark.ml.feature.VectorAssembler 转换特征后,想要放入 import org.apache.spark.mllib.clas ...

随机推荐

  1. Unit的各种断言

    今天遇到这个问题,就值得自己总结一下. 1.介绍 JUnit为我们提供了一些辅助函数,他们用来帮助我们确定被测试的方法是否按照预期的效果正常工作,通常,把这些辅助函数称为断言.下面我们来介绍一下JUn ...

  2. java线程中断的办法

    目录 中断线程相关的方法 中断线程 for循环标记退出 阻塞的退出线程 使用stop()方法停止线程 中断线程相关的方法 中断线程有一些相应的方法,这里列出来一下. 注意,如果是Thread.meth ...

  3. jQuery-Selectors(选择器)的使用(二、层次篇)(转载)

    原文:http://www.cnblogs.com/bynet/archive/2009/12/01/1614405.html 本系列文章导航 jQuery-Selectors(选择器)的使用(一.基 ...

  4. 【Ray Tracing in One Weekend 超详解】 光线追踪1-7 Dielectric 半径为负,实心球体镂空技巧

    今天讲这本书最后一种材质 Preface 水,玻璃和钻石等透明材料是电介质.当光线照射它们时,它会分裂成反射光线和折射(透射)光线. 处理方案:在反射或折射之间随机选择并且每次交互仅产生一条散射光线 ...

  5. ScrollView中嵌套GridView,Listview的办法

    按照android的标准,ScrollView中是不能嵌套具有滑动特性的View的,但是有时如果设计真的有这样做的需要,或者为了更方便简单的实现外观(比如在外在的大布局需要有滑动的特性,并且内部有类似 ...

  6. hdu 4460 第37届ACM/ICPC杭州赛区H题 STL+bfs

    题意:一些小伙伴之间有朋友关系,比如a和b是朋友,b和c是朋友,a和c不是朋友,则a和c之间存在朋友链,且大小为2,给出一些关系,求出这些关系中最大的链是多少? 求最短路的最大距离 #include& ...

  7. 如何使用PhoneGap打包Web App

    最近做了一款小游戏,定位是移动端访问,思来想去最后选择了jQuery mobile最为框架,制作差不多以后,是否可以打包成App,恰好以前对PhoneGap有耳闻,便想用这个来做打包,可以其中艰辛曲折 ...

  8. ios手机填坑总结

    1. 日期格式 ios系统.safari只能识别"2018/10/15 00:00:00",不能识别"2018-10-15 00:00:00",所以需要转换格式 ...

  9. 《Go语言实战》摘录:6.2 并发 - goroutine

    6.2 goroutine

  10. [置顶] android socket 聊天实现与调试

    网上很多基于Socket的聊天实现都是不完整的... 结合自己的经验给大家分享一下,完整代码可以在GitHub里获取https://github.com/zz7zz7zz/android-socket ...