Finding the optimal discretization points with information gain or the Gini index
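The scoring rule in one screen, before the full Estimator listing below: every candidate split point is scored by how much it reduces impurity, i.e. the impurity of the parent segment minus the instance-weighted impurities of the two halves. A minimal, self-contained sketch of that rule (the object and method names here are illustrative only and are not part of the listing):

object SplitGainSketch {
  // Impurity of a class-count vector: Gini index, or entropy in nats.
  def impurity(counts: Seq[Long], useGini: Boolean): Double = {
    val total = counts.sum.toDouble
    counts.filter(_ > 0).map { c =>
      val p = c / total
      if (useGini) p * (1 - p) else -p * math.log(p)
    }.sum
  }

  // Gain of a binary split = parent impurity - weighted impurity of the two halves.
  def gain(left: Seq[Long], right: Seq[Long], useGini: Boolean): Double = {
    val parent = left.zip(right).map { case (l, r) => l + r }
    val n = parent.sum.toDouble
    impurity(parent, useGini) -
      left.sum / n * impurity(left, useGini) -
      right.sum / n * impurity(right, useGini)
  }

  def main(args: Array[String]): Unit = {
    // Two classes; the candidate split sends 3 of class 0 and 1 of class 1 to the left.
    println(gain(Seq(3L, 1L), Seq(1L, 3L), useGini = true))
  }
}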

package org.apache.spark.ml.feature

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.attribute._

/**
 * Discretizer for continuous, ordered features.
 *
 * The parameters mirror the decision-tree parameters of the same names.
 *
 * Splitting is binary: each round bisects one or more existing segments at the split
 * point with the largest information gain (or Gini-based gain).
 *
 * Stopping criteria: the requested number of buckets (numBuckets) is reached, or the
 * gain of the best remaining split does not exceed the threshold (minInfoGain).
 */
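// Self-contained smoke test: fits the Discretizer on a small two-feature DataFrame and
// prints the discretized test rows together with the learned split points per column.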
private object DiscretizerTest {
  def main(args: Array[String]): Unit = {
    val time1 = System.currentTimeMillis()
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val inputCol1 = "f1"
    val inputCol2 = "f2"
    val labelCol = "label"
    val outputCol1 = "discretizer1"
    val outputCol2 = "discretizer2"

    // NOTE: the f1 values, the labels, the test rows and the bucket counts below are
    // illustrative placeholders; any small numeric sample works here.
    val train = spark.createDataFrame(
      List(
        (1.0, 2.3, 0),
        (2.0, 8.1, 1),
        (3.0, 1.1, 0),
        (4.0, 2.2, 1),
        (5.0, 3.3, 0),
        (6.0, 7.0, 1))).toDF(inputCol1, inputCol2, labelCol)

    val test = spark.createDataFrame(
      List(
        (1.5, 2.5),
        (4.5, 7.5))).toDF(inputCol1, inputCol2)

    val discretizer = new Discretizer().
      setInputCols(Array(inputCol1, inputCol2)).
      setOutputCols(Array(outputCol1, outputCol2)).
      setNumBucketsArray(Array(3, 2)).
      setLabelCol(labelCol).
      setMinInstancesPerBucket(1)

    val model = discretizer.fit(train)
    model.transform(test).show()
    model.getSplitsArray.foreach {
      arr => println(arr.mkString(","))
    }

    val time2 = System.currentTimeMillis()
    println(time2 - time1)
  }
}
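// Shared parameter surface: impurity criterion, minimum gain, bucket counts (single value
// or per column), minimum instances per bucket, invalid-value handling, plus the usual
// input/output/label column params with their getters, setters and defaults.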
private[feature] trait DiscretizerBase extends Params
  with HasHandleInvalid with HasInputCol with HasOutputCol
  with HasInputCols with HasOutputCols with HasLabelCol {

  final val impurity: Param[String] = new Param[String](this, "impurity",
    "Criterion used for information gain calculation (case-insensitive). " +
      "Supported: \"entropy\" and \"gini\". (default = gini)",
    ParamValidators.inArray(Array("gini", "entropy")))

  final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
    "Minimum information gain for a split to be kept (exclusive); must be non-negative. (default = 0.0)",
    ParamValidators.gtEq(0.0))

  final val numBuckets: IntParam = new IntParam(this, "numBuckets",
    "Number of buckets to discretize into; must be a positive integer.",
    ParamValidators.gtEq(1))

  val numBucketsArray = new IntArrayParam(this, "numBucketsArray", "Array of number of buckets " +
    "(quantiles, or categories) into which data points are grouped. This is for multiple " +
    "columns input. If transforming multiple columns and numBucketsArray is not set, but " +
    "numBuckets is set, then numBuckets will be applied across all columns.",
    (arrayOfNumBuckets: Array[Int]) => arrayOfNumBuckets.forall(ParamValidators.gtEq(1)))

  final val minInstancesPerBucket: IntParam = new IntParam(this, "minInstancesPerBucket",
    "Minimum number of instances per bucket (inclusive). (default = 1)")

  def getImpurity() = $(impurity)
  def getMinInfoGain() = $(minInfoGain)
  def getNumBuckets() = $(numBuckets)
  def getNumBucketsArray: Array[Int] = $(numBucketsArray)
  def getMinInstancesPerBucket() = $(minInstancesPerBucket)

  override val handleInvalid: Param[String] = new Param[String](
    this,
    "handleInvalid",
    "how to handle invalid entries. Options are skip (which will filter out rows with bad values), " +
      "or error (which will throw an error) or keep (keep invalid values in a special additional bucket).",
    ParamValidators.inArray(Array("skip", "error", "keep")))

  def setImpurity(value: String): this.type = set(impurity, value)
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
  def setLabelCol(value: String): this.type = set(labelCol, value)
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
  def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
  def setNumBuckets(value: Int): this.type = set(numBuckets, value)
  def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value)
  def setMinInstancesPerBucket(value: Int): this.type = set(minInstancesPerBucket, value)

  setDefault(minInfoGain -> 0.0, labelCol -> "label", minInstancesPerBucket -> 1,
    handleInvalid -> "error", impurity -> "gini")

  protected def getInOutCols: (Array[String], Array[String]) = {
    require(
      (isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
        (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
      "Discretizer only supports setting either inputCol/outputCol or " +
        "inputCols/outputCols.")
    if (isSet(inputCol)) {
      (Array($(inputCol)), Array($(outputCol)))
    } else {
      require(
        $(inputCols).length == $(outputCols).length,
        "The number of inputCols must match the number of outputCols.")
      ($(inputCols), $(outputCols))
    }
  }
}

class Discretizer(override val uid: String) extends Estimator[Bucketizer]
  with DiscretizerBase with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("Discretizer"))

  override def copy(extra: ParamMap): this.type = defaultCopy(extra)

  override def fit(dataset: Dataset[_]): Bucketizer = {
    transformSchema(dataset.schema, true)
    val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
    val (inputColNames, outputColNames) = getInOutCols

    // Per-column bucket counts: use numBucketsArray when given, otherwise repeat numBuckets.
    val numBucketsArray_t = if (isSet(numBucketsArray) && isSet(inputCols)) {
      $(numBucketsArray)
    } else {
      Array.fill[Int](inputColNames.size)($(numBuckets))
    }

    // Compute the split points for every input column.
    val splitsArrayBuffer = new ArrayBuffer[Array[Double]]()
    inputColNames.zip(numBucketsArray_t).foreach {
      case (inputColName, numBuckets_t) =>
        val splits = discretizeCol(dataset, inputColName, numBuckets_t)
        splitsArrayBuffer += splits.sorted
    }

    // Single-column mode uses setSplits, multi-column mode uses setSplitsArray.
    if (splitsArrayBuffer.size == 1) {
      val splits = splitsArrayBuffer.head
      bucketizer.setSplits(splits)
    } else {
      val splitsArray = splitsArrayBuffer.toArray
      splitsArray.foreach(f => f.foreach(println))
      bucketizer.setSplitsArray(splitsArray)
    }
    copyValues(bucketizer.setParent(this))
  }

  override def transformSchema(schema: StructType): StructType = {
    val (inputColNames, outputColNames) = getInOutCols
    val existingFields = schema.fields
    var outputFields = existingFields
    inputColNames.zip(outputColNames).foreach {
      case (inputColName, outputColName) =>
        require(
          existingFields.exists(_.name == inputColName),
          s"Input column ${inputColName} does not exist.")
        require(
          existingFields.forall(_.name != outputColName),
          s"Output column ${outputColName} already exists.")
        val inputColType = schema(inputColName).dataType
        require(
          inputColType.isInstanceOf[NumericType],
          s"The input column $inputColName must be numeric type, " +
            s"but got $inputColType.")
        // The output column is nominal: each bucket index is a category.
        val attr = NominalAttribute.defaultAttr.withName(outputColName)
        outputFields :+= attr.toStructField()
    }
    StructType(outputFields)
  }

  def discretizeCol(dataset: Dataset[_], inputColName: String, numBuckets_t: Int): Array[Double] = {
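    // Greedy top-down splitting of a single column:
    //  1. Start with one segment covering all distinct values of the column.
    //  2. Score every unscored segment with getBestPoint, then drop segments whose best
    //     gain does not exceed minInfoGain.
    //  3. Split the segment(s) with the maximum gain, record the split point, and put the
    //     two halves back into the work list.
    //  4. Stop when numBuckets_t buckets exist or no segment has a qualifying split.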
    // Distinct, sorted values of the column; these are the candidate split points.
    val input_arr = dataset.select(col(inputColName).cast(DoubleType)).distinct().
      orderBy(inputColName).rdd.map(_.getDouble(0)).collect()

    val splits = new ArrayBuffer[Double]()
    splits.append(Double.MinValue)
    splits.append(Double.MaxValue)

    // Work list of segments; each segment keeps its values ("arr"), whether its right end
    // is inclusive ("closure") and the cached best split for it ("node").
    var split_map_arr = new ArrayBuffer[scala.collection.mutable.Map[String, Any]]()
    split_map_arr.append(scala.collection.mutable.Map(
      "arr" -> input_arr,
      "closure" -> true,
      "node" -> null))

    var flag = true
    while (flag) {
      // Score every segment that has not been scored yet.
      for (split_map <- split_map_arr) {
        if (split_map("node") == null) {
          getBestPoint(split_map, dataset, inputColName)
        }
      }

      // Keep only segments whose best split still clears the minimum gain.
      split_map_arr = split_map_arr.filter {
        split_map => split_map("node").asInstanceOf[Map[String, Double]]("value") > $(minInfoGain)
      }

      if (split_map_arr.length > 0) {
        // Group segment indices by their best gain so all segments tied at the maximum
        // gain are split in the same round.
        val entropy_idxs = (Map[Double, Array[Int]]() /: split_map_arr.zipWithIndex) { (r, split_map_idx) =>
          val (split_map, idx) = split_map_idx
          val value = split_map("node").asInstanceOf[Map[String, Double]]("value")
          r + (value -> (r.get(value) match {
            case Some(arr: Array[Int]) => arr :+ idx
            case None => Array[Int](idx)
          }))
        }

        val split_map_arr_break = new ArrayBuffer[scala.collection.mutable.Map[String, Any]]()
        entropy_idxs(entropy_idxs.keys.max).zipWithIndex.foreach {
          case (idx, i) => {
            split_map_arr_break.append(split_map_arr.remove(idx - i))
          }
        }

        split_map_arr_break.foreach(
          split_map => {
            val point = split_map("node").asInstanceOf[Map[String, Double]]("point")
            splits.append(point)
            val arr = split_map("arr").asInstanceOf[Array[Double]]
            val closure = split_map("closure").asInstanceOf[Boolean]

            // Partition the segment's values around the chosen split point.
            var left_arr = Array[Double]()
            var right_arr = Array[Double]()
            for (e <- arr) {
              if (e < point) {
                left_arr :+= e
              } else {
                right_arr :+= e
              }
            }
            left_arr :+= point

            val left_split_map = scala.collection.mutable.Map(
              "arr" -> left_arr,
              "closure" -> false,
              "node" -> null)
            val right_split_map = scala.collection.mutable.Map(
              "arr" -> right_arr,
              "closure" -> closure,
              "node" -> null)
            split_map_arr.append(left_split_map)
            split_map_arr.append(right_split_map)
          })

        if (splits.length - 1 >= numBuckets_t) {
          flag = false
        }
      } else {
        flag = false
      }
    }
    splits.toArray
  }

  /*
   * Find the split point with the maximum information gain for one segment.
   */
  def getBestPoint(split_map: scala.collection.mutable.Map[String, Any], dataset: Dataset[_], inputColName: String): Unit = {
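    // Approach: restrict the dataset to this segment's value range, pivot on the label to
    // get per-value class counts, then sweep the candidate points left to right keeping a
    // running count of classes on the left side. For each candidate the gain is the
    // segment impurity minus the weighted impurity of the two sides; candidates that would
    // leave fewer than minInstancesPerBucket rows on either side are discarded.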
    val arr: Array[Double] = split_map("arr").asInstanceOf[Array[Double]]
    if (arr.length <= 1) {
      // A segment with a single value cannot be split any further.
      split_map("node") = Map("point" -> arr(0), "value" -> 0.0)
      return
    }

    val start = arr(0)
    val end = arr(arr.length - 1)
    val closure = split_map("closure").asInstanceOf[Boolean]
    var ds = dataset.filter(col(inputColName) >= start and col(inputColName) < end)
    if (closure) {
      ds = dataset.filter(col(inputColName) >= start and col(inputColName) <= end)
    }

    // One row per distinct value, with one count column per label class.
    val point_set = ds.select(col(inputColName).cast(DoubleType), col($(labelCol))).
      groupBy(col(inputColName)).
      pivot($(labelCol)).count.
      orderBy(col(inputColName)).collect()

    val classNums = point_set(0).size - 1
    // Total class counts over the whole segment.
    val all_seq = (new Array[Long](classNums) /: point_set) { (arr, p) =>
      (1 to (p.size - 1)).foreach {
        i =>
          arr(i - 1) += p(i).asInstanceOf[Long]
      }
      arr
    }

    val info_entropy = entropy(all_seq)
    val all_count = all_seq.sum
    val left_seq = new Array[Long](classNums)
    val gains = arr.zipWithIndex.map {
      case (point, idx) =>
        if (idx == 0) {
          (point, 0.0)
        } else {
          // Accumulate the class counts of the value just left of the candidate point.
          left_seq.zipWithIndex.foreach {
            case (e, i) =>
              left_seq(i) = e + point_set(idx - 1)(i + 1).asInstanceOf[Long]
          }
          val right_seq = all_seq.zip(left_seq).map {
            case (all_e, left_e) =>
              all_e - left_e
          }
          val left_count = left_seq.sum
          val right_count = right_seq.sum
          if (left_count < $(minInstancesPerBucket) || right_count < $(minInstancesPerBucket)) {
            (point, -1.0)
          } else {
            val conditional_entropy = left_count * 1.0 / all_count * entropy(left_seq) +
              right_count * 1.0 / all_count * entropy(right_seq)
            val gain = info_entropy - conditional_entropy
            (point, gain)
          }
        }
    }

    val gain_max = gains.map(_._2).max
    val point = gains.filter(_._2 >= gain_max).map(_._1).apply(0)
    val node = Map("point" -> point, "value" -> gain_max)
    split_map("node") = node
  }

  /*
   * Impurity of a vector of class counts: Gini index, or entropy in nats,
   * depending on the impurity parameter.
   */
  def entropy(groupCounts: Seq[Long]): Double = {
    val count = groupCounts.sum
    if ($(impurity) == "gini") {
      (0.0 /: groupCounts) { (sum, groupCount) =>
        if (groupCount == 0) {
          sum
        } else {
          val p = groupCount * 1.0 / count
          sum + (p * (1 - p))
        }
      }
    } else if ($(impurity) == "entropy") {
      (0.0 /: groupCounts) { (sum, groupCount) =>
        if (groupCount == 0) {
          sum
        } else {
          val p = groupCount * 1.0 / count
          sum + (-p * math.log(p))
        }
      }
    } else { // place for additional criteria
      Double.MaxValue
    }
  }
}
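Because Discretizer extends Estimator[Bucketizer], the fitted result is an ordinary Bucketizer, so the estimator drops straight into a Spark ML Pipeline. A minimal usage sketch under assumed column names (f1, f2, label) and an illustrative bucket count:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{Discretizer, VectorAssembler}
import org.apache.spark.sql.SparkSession

object DiscretizerPipelineSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    // Illustrative training data: two numeric features and a binary label.
    val train = spark.createDataFrame(Seq(
      (1.0, 2.3, 0), (2.0, 8.1, 1), (3.0, 1.1, 0), (4.0, 7.0, 1)
    )).toDF("f1", "f2", "label")

    val discretizer = new Discretizer()
      .setInputCols(Array("f1", "f2"))
      .setOutputCols(Array("f1_bin", "f2_bin"))
      .setNumBuckets(2) // applied to both columns when numBucketsArray is unset
      .setLabelCol("label")

    val assembler = new VectorAssembler()
      .setInputCols(Array("f1_bin", "f2_bin"))
      .setOutputCol("features")

    // The fitted first stage is a Bucketizer; the pipeline output can feed any classifier.
    val model = new Pipeline().setStages(Array(discretizer, assembler)).fit(train)
    model.transform(train).show()
  }
}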
