最近在用Spark MLlib进行特征处理时,对于StringIndexer和IndexToString遇到了点问题,查阅官方文档也没有解决疑惑。无奈之下翻看源码才明白其中一二...这就给大家娓娓道来。

更多内容参考我的大数据学习之路

文档说明

StringIndexer 字符串转索引

StringIndexer可以把字符串的列按照出现频率进行排序,出现次数最高的对应的Index为0。比如下面的列表进行StringIndexer

id category
0 a
1 b
2 c
3 a
4 a
5 c

就可以得到如下:

id category categoryIndex
0 a 0.0
1 b 2.0
2 c 1.0
3 a 0.0
4 a 0.0
5 c 1.0

可以看到出现次数最多的"a",索引为0;次数最少的"b"索引为2。

针对训练集中没有出现的字符串值,spark提供了几种处理的方法:

  • error,直接抛出异常
  • skip,跳过该样本数据
  • keep,使用一个新的最大索引,来表示所有未出现的值

下面是基于Spark MLlib 2.2.0的代码样例:

package xingoo.ml.features.tranformer

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer object StringIndexerTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
spark.sparkContext.setLogLevel("WARN") val df = spark.createDataFrame(
Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category") val df1 = spark.createDataFrame(
Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
).toDF("id", "category") val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.setHandleInvalid("keep") //skip keep error val model = indexer.fit(df) val indexed = model.transform(df1)
indexed.show(false)
}
}

得到的结果为:

+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0 |a |0.0 |
|1 |b |2.0 |
|2 |c |1.0 |
|3 |a |0.0 |
|4 |e |3.0 |
|5 |f |3.0 |
+---+--------+-------------+

IndexToString 索引转字符串

这个索引转回字符串要搭配前面的StringIndexer一起使用才行:

package xingoo.ml.features.tranformer

import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession object IndexToString2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
spark.sparkContext.setLogLevel("WARN") val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category") val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df)
val indexed = indexer.transform(df) println(s"Transformed string column '${indexer.getInputCol}' " +
s"to indexed column '${indexer.getOutputCol}'")
indexed.show() val inputColSchema = indexed.schema(indexer.getOutputCol)
println(s"StringIndexer will store labels in output column metadata: " +
s"${Attribute.fromStructField(inputColSchema).toString}\n") val converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory") val converted = converter.transform(indexed) println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
s"column '${converter.getOutputCol}' using labels in metadata")
converted.select("id", "categoryIndex", "originalCategory").show()
}
}

得到的结果如下:

Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
| 0| a| 0.0|
| 1| b| 2.0|
| 2| c| 1.0|
| 3| a| 0.0|
| 4| a| 0.0|
| 5| c| 1.0|
+---+--------+-------------+ StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"} Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
| 0| 0.0| a|
| 1| 2.0| b|
| 2| 1.0| c|
| 3| 0.0| a|
| 4| 0.0| a|
| 5| 1.0| c|
+---+-------------+----------------+

使用问题

假如处理的过程很复杂,重新生成了一个DataFrame,此时想要把这个DataFrame基于IndexToString转回原来的字符串怎么办呢? 先来试试看:

package xingoo.ml.features.tranformer

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession object IndexToString3 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
spark.sparkContext.setLogLevel("WARN") val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category") val df2 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index") val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df)
val indexed = indexer.transform(df) val converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory") val converted = converter.transform(df2)
converted.show()
}
}

运行后发现异常:

18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
at scala.collection.AbstractMap.getOrElse(Map.scala:59)
at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)

这是为什么呢?跟随源码来看吧!

源码剖析

首先我们创建一个DataFrame,获得原始数据:

val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category")

然后创建对应的StringIndexer:

val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.setHandleInvalid("skip")
.fit(df)

这里面的fit就是在训练转换器了,进入fit():

override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
// 这里针对需要转换的列先强制转换成字符串,然后遍历统计每个字符串出现的次数
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
// counts是一个map,里面的内容为{a->3, b->1, c->2}
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
// 按照个数大小排序,返回数组,[a, c, b]
// 把这个label保存起来,并返回对应的model(mllib里边的模型都是这个套路,跟sklearn学的)
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}

这样就得到了一个列表,列表里面的内容是[a, c, b],然后执行transform来进行转换:

val indexed = indexer.transform(df)

这个transform可想而知就是用这个数组对每一行的该列进行转换,但是它其实还做了其他的事情:

override def transform(dataset: Dataset[_]): DataFrame = {
...
// --------
// 通过label生成一个Metadata,这个很关键!!!
// metadata其实是一个map,内容为:
// {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
// --------
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata() // 如果是skip则过滤一些数据
... // 下面是针对不同的情况处理转换的列,逻辑很简单
val indexer = udf { label: String =>
...
if (labelToIndex.contains(label)) {
labelToIndex(label) //如果正常,就进行转换
} else if (keepInvalid) {
labels.length // 如果是keep,就返回索引的最大值(即数组的长度)
} else {
... // 如果是error,就抛出异常
}
} // 保留之前所有的列,新增一个字段,并设置字段的StructField中的Metadata!!!!
// 并设置字段的StructField中的Metadata!!!!
// 并设置字段的StructField中的Metadata!!!!
// 并设置字段的StructField中的Metadata!!!! filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}

看到了吗!关键的地方在这里,给新增加的字段的类型StructField设置了一个Metadata。这个Metadata正常都是空的{},但是这里设置了metadata之后,里面包含了label数组的信息。

接下来看看IndexToString是怎么用的,由于IndexToString是一个Transformer,因此只有一个trasform方法:

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata
// 关键是这里:
// 如果IndexToString设置了labels数组,就直接返回;
// 否则,就读取了传入的DataFrame的StructField中的Metadata
val values = if (!isDefined(labels) || $(labels).isEmpty) {
Attribute.fromStructField(inputColSchema)
.asInstanceOf[NominalAttribute].values.get
} else {
$(labels)
} // 基于这个values把index转成对应的值
val indexer = udf { index: Double =>
val idx = index.toInt
if (0 <= idx && idx < values.length) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
}
}
val outputColName = $(outputCol)
dataset.select(col("*"),
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}

了解StringIndexer和IndexToString的原理机制后,就可以作出如下的应对策略了。

1 增加StructField的MetaData信息

 val df2 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index").select(col("*"),col("index").as("formated_index", indexed.schema("categoryIndex").metadata)) val converter = new IndexToString()
.setInputCol("formated_index")
.setOutputCol("origin_col") val converted = converter.transform(df2)
converted.show(false)
+---+-----+--------------+----------+
|id |index|formated_index|origin_col|
+---+-----+--------------+----------+
|0 |2.0 |2.0 |b |
|1 |1.0 |1.0 |c |
|2 |1.0 |1.0 |c |
|3 |0.0 |0.0 |a |
+---+-----+--------------+----------+

2 获取之前StringIndexer后的DataFrame中的Label信息

    val df3 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index") val converter2 = new IndexToString()
.setInputCol("index")
.setOutputCol("origin_col")
.setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals")) val converted2 = converter2.transform(df3)
converted2.show(false)
+---+-----+----------+
|id |index|origin_col|
+---+-----+----------+
|0 |2.0 |b |
|1 |1.0 |c |
|2 |1.0 |c |
|3 |0.0 |a |
+---+-----+----------+

两种方法都能得到正确的输出。

完整的代码可以参考github链接:

https://github.com/xinghalo/spark-in-action/blob/master/src/xingoo/ml/features/tranformer/IndexToStringTest.scala

最终还是推荐详细阅读官方文档,不过官方文档真心有些粗糙,想要了解其中的原理,还是得静下心来看看源码。

Spark MLlib 之 StringIndexer、IndexToString使用说明以及源码剖析的更多相关文章

  1. Apache Spark源码剖析

    Apache Spark源码剖析(全面系统介绍Spark源码,提供分析源码的实用技巧和合理的阅读顺序,充分了解Spark的设计思想和运行机理) 许鹏 著   ISBN 978-7-121-25420- ...

  2. 《Apache Spark源码剖析》

    Spark Contributor,Databricks工程师连城,华为大数据平台开发部部长陈亮,网易杭州研究院副院长汪源,TalkingData首席数据科学家张夏天联袂力荐1.本书全面.系统地介绍了 ...

  3. (升级版)Spark从入门到精通(Scala编程、案例实战、高级特性、Spark内核源码剖析、Hadoop高端)

    本课程主要讲解目前大数据领域最热门.最火爆.最有前景的技术——Spark.在本课程中,会从浅入深,基于大量案例实战,深度剖析和讲解Spark,并且会包含完全从企业真实复杂业务需求中抽取出的案例实战.课 ...

  4. Spark源码剖析 - SparkContext的初始化(二)_创建执行环境SparkEnv

    2. 创建执行环境SparkEnv SparkEnv是Spark的执行环境对象,其中包括众多与Executor执行相关的对象.由于在local模式下Driver会创建Executor,local-cl ...

  5. Spark源码剖析 - SparkContext的初始化(三)_创建并初始化Spark UI

    3. 创建并初始化Spark UI 任何系统都需要提供监控功能,用浏览器能访问具有样式及布局并提供丰富监控数据的页面无疑是一种简单.高效的方式.SparkUI就是这样的服务. 在大型分布式系统中,采用 ...

  6. Spark jdbc postgresql数据库连接和写入操作源码解读

    概述:Spark postgresql jdbc 数据库连接和写入操作源码解读,详细记录了SparkSQL对数据库的操作,通过java程序,在本地开发和运行.整体为,Spark建立数据库连接,读取数据 ...

  7. Dream_Spark-----Spark 定制版:005~贯通Spark Streaming流计算框架的运行源码

    Spark 定制版:005~贯通Spark Streaming流计算框架的运行源码   本讲内容: a. 在线动态计算分类最热门商品案例回顾与演示 b. 基于案例贯通Spark Streaming的运 ...

  8. Node 进阶:express 默认日志组件 morgan 从入门使用到源码剖析

    本文摘录自个人总结<Nodejs学习笔记>,更多章节及更新,请访问 github主页地址.欢迎加群交流,群号 197339705. 章节概览 morgan是express默认的日志中间件, ...

  9. 豌豆夹Redis解决方案Codis源码剖析:Dashboard

    豌豆夹Redis解决方案Codis源码剖析:Dashboard 1.不只是Dashboard 虽然名字叫Dashboard,但它在Codis中的作用却不可小觑.它不仅仅是Dashboard管理页面,更 ...

随机推荐

  1. 单个 LINQ to Entities 查询中的两个结构上不兼容的初始化过程中出现类型“XXXX”

    最近在做一个报表的时候,用EF使用了Contact方法,但是程式运行一直出错.最近终于找到原因了,写下来提醒下自己.好了,进入正题: 现在我举个栗子,目前数据库中有ParentStudent表和Sub ...

  2. OneNET麒麟座应用开发之七:控制采样电机

    气体采样采用主动抽取气体的方式保证充足而平稳的气流,所以我们采用气泵抽取气体来完成. 1.设计概述 客户对这部分要求能够设定电机的速度,但并不需要动态调节.对电机的控制有很多方式,我们采用比较简单的方 ...

  3. window 连linux

    https://blog.csdn.net/ruanjianruanjianruan/article/details/46954681 https://blog.csdn.net/u013754317 ...

  4. laravel 列表搜索查询(when,with用法以及关联图像id处理图像路径)

    laravel中比较常规的列表查询: /** * 活动列表 * @param Request $request * @return \Illuminate\Http\JsonResponse */ p ...

  5. eclipse的操作

    IDEA至少在4G内存的电脑才能使用 eclipse中:项目名字小写 close project:关掉项目 删除未尽的项目导入eclipse中的步骤: 左边右键>>>import&l ...

  6. #2 codeforces 480 Parcels

    题意: 就是有一个用来堆放货物的板,承重力为S.现在有N件货物,每件货物有到达的时间,运走的时间,以及重量,承重,存放盈利.如果这件货物能再运达时间存放,并在指定时间取走的话,就能获得相应的盈利值.货 ...

  7. python常用内建模块--collections

    1.namedtuple #namedtuple是一个函数,它用来创建一个自定义的tuple对象,并且规定了tuple元素的个数,并可以用属性而不是索引来引用tuple的某个元素.#这样一来,我们用n ...

  8. python面试笔试题,你都会了吗?快来复习

    1.一行代码实现1--100之和 利用sum()函数求和 >>> sum(range(0,101)) 5050 2.如何在一个函数内部修改全局变量 利用global 修改全局变量 a ...

  9. C#连接数据库MD5数据库加密

    创建StringHelper类 首先数据库里的资料是加密了的. 创建将指定的字符串加密为MD5密文方法 public static string ToMD5(string source){ Strin ...

  10. 解决html5中video标签无法播放mp4问题的办法

    这篇文章主要给大家介绍了关于解决html5中video标签无法播放mp4问题的办法,文中介绍的非常详细,相信会对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面来一起看看吧. 最近发现了一个 ...