Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 MultiStringIndexer 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

本文缘由是想分析GBDT,发现GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。

0x01 概念

Alink的官方介绍是:MultiStringIndexer训练组件的作用是训练一个模型用于将多列字符串映射为整数。

具体来说,StringIndexer(字符串-索引变换)将标签的"字符串列"编码为"标签索引的列"。

  • 标签索引序列的取值范围是[0,numLabels(字符串中所有出现的单词去掉重复的词后的总和)],按照标签出现频率排序,出现最多的标签索引为0(具体为升序降序是可以配置的)。
  • 如果输入是数值型,我们先将数值映射到字符串,再对字符串进行索引化。
  • 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化后的标签序列,则需要将这个pipeline的输入列名字指定为索引化序列的名字。大部分情况下,通过setSelectedCols设置输入的列名。

以这些输入为例:

("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")

对于第一列,MultiStringIndexer 对数据集的label进行重新编号。按label出现的频次,转换成0 ~ numOfLabels - 1(分类个数)。如果是按照从高到低排序,则频次最高的转换为0,以此类推,比如:

  • football,出现次数最多,出现了3次,转换(编号)为0
  • 其次是basketball,出现了2次,编号为1,以此类推。

在应用StringIndexer对labels进行重新编号后,带着这些编号后的label对数据进行了训练,并接着对其他数据进行了预测,得到预测结果,预测结果的label也是重新编号过的,因此需要转换回来。

0x02 示例代码

示例代码如下,本示例代码中,是按照升序排列,即football总数为3,则其idx为3,tennis个数为1,其idx为0:

public class MultiStringIndexerExample {
static AlgoOperator getData(boolean isBatch) {
Row[] array = new Row[] {
Row.of("football", "can"),
Row.of("football", "hhh"),
Row.of("football", "zzz"),
Row.of("basketball", "zzz"),
Row.of("basketball", "can"),
Row.of("tennis", "can")
}; if (isBatch) {
return new MemSourceBatchOp(
Arrays.asList(array), new String[] {"a", "b"});
} else {
return new MemSourceStreamOp(
Arrays.asList(array), new String[] {"a", "b"});
}
} public static void main(String[] args) throws Exception {
BatchOperator data = (BatchOperator)getData(true);
MultiStringIndexer stringindexer = new MultiStringIndexer()
.setSelectedCols("a", "b")
.setOutputCols("a_indexed", "b_indexed")
.setStringOrderType("frequency_asc");
stringindexer.fit(data).transform(data).print();
}
}

输出如下:

a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2

转换成表格看的更清楚。

a b a_indexed b_indexed
football can 2 2
football hhh 2 0
football zzz 2 1
basketball zzz 1 1
basketball can 1 2
tennis can 0 2

0x03 总体逻辑

我们先给出一个流程图

老套路,我们从 MultiStringIndexerTrainBatchOp.linkFrom开始挖掘。

@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs); // 示例中有 .setSelectedCols("a", "b"),这里是取出具体列名字
final String[] selectedColNames = getSelectedCols();
// 获取列的类型
final String[] selectedColSqlType = new String[selectedColNames.length];
for (int i = 0; i < selectedColNames.length; i++) {
selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
} // runtime打印数据
selectedColNames = {String[2]@2536}
0 = "a"
1 = "b"
selectedColSqlType = {String[2]@2537}
0 = "VARCHAR"
1 = "VARCHAR" // 获取选取列对应的数据
DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
//
DataSet<Tuple3<Integer, String, Long>> indexedToken =
StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true); DataSet<Row> values = indexedToken
.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
@Override
public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
throws Exception {
Params meta = null;
if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
// 第一个task会做这个计算,就是把列名,列类型作为元数据传送
meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
.set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
} // runtime打印数据
meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
params = {HashMap@9316} size = 2 new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
}
})
.name("build_model"); this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
return this;
}

训练过程总体逻辑总结如下:

  • 取出具体列名字,列的类型;
  • 获取"选取列"对应的数据;
  • 把列名,列类型作为元数据传送;
  • StringIndexerUtil.indexTokens 给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关;
    • 调用到 indexSortedByFreq(data, startIndex, ignoreNull, true),作用是给各个列的不同字串赋予连续的indices,indices是按照字符串出现的频率排序;

      • 调用到 countTokens的作用是按照 "列idx","word" 来合并计算单词个数,得到<"列idx","word",单词个数>,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。

        • 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
        • 对上面结构进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ;
        • 按照 "列idx","word" 来分组;
        • 按照 "列idx","word" 来合并计算单词个数;
      • indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;
        • 首先按照 列idx 做分组;
        • 然后在上面结果基础上,按照单词个数排序;
        • 排序的index是以输入参数startIndex开始,startIndex在这里是0;
        • 最后得到 第一列的 (0,football,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
  • 把indexTokens的结果存储为模型,其中使用之前提到的 "把列名,列类型作为元数据"。

下面具体剖析后两个阶段。

0x04 Add Index to Token

这部分就是给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关。

具体是由StringIndexerUtil.indexTokens 做到的。

public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
final long startIndex, final boolean ignoreNull) {
case FREQUENCY_ASC:
return indexSortedByFreq(data, startIndex, ignoreNull, true);
}

4.1 合并计算单词个数

indexSortedByFreq会调用countTokens来计算单词个数,所以我们先看countTokens。

countTokens的作用是按照 "列idx","word" 来合并计算单词个数,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。

具体逻辑如下:

  • 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
  • 对上面结果进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ,这个是计数的常规操作。
  • 按照 "列idx","word" 来分组;
  • 按照 "列idx","word" 来合并计算单词个数,就是不停归并上面的 1L。

4.1.1 打散输入数据

其中 flattenTokens 的作用是把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token.。

比如对于 Row.of("football", "can") 这个输入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o))); 输出两个Tuple2。

value = {Row@9212} "football,can"
fields = {Object[2]@9215}
0 = "football"
1 = "can" 输出 <0, "football"> 和 <1, "can">

4.1.2 分组计算个数

这是通过flattenTokens的结果进行 map,groupBy,reduce的一系列操作完成的。

具体代码如下:

public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
return flattenTokens(data, ignoreNull) // 把输入数据 Row 给打散
.map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
return Tuple3.of(value.f0, value.f1, 1L); // 输出<column idx, word, 1L>,比如 <0, "football", 1L>
}
})
.groupBy(0, 1) // 按照 "列idx","word" 来分组
.reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
value1.f2 += value2.f2;
return value1; // 按照 "列idx","word" 来合并计算单词个数
}
})
.name("count_tokens");
} // reduce之后发出
value1 = {Tuple3@9284} "(0,football,3)"
f0 = {Integer@9226} 0
f1 = "football"
f2 = {Long@9295} 3

4.2 合并计算单词个数

前面 countTokens的 返回三元组是 <列idx","word" ,词频>,其中列的idx从0开始计算。

indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;

  • 首先按照 列idx 做分组;
  • 然后在上面结果基础上,按照单词个数排序;
  • 排序的index是以输入参数startIndex开始,startIndex在这里是0;
  • 最后得到 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);

具体代码如下:

public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
return countTokens(data, ignoreNull)
.groupBy(0) //按照 列idx 做分组
.sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照单词个数排序
.reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
@Override
public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
Collector<Tuple3<Integer, String, Long>> out) {
long id = startIndex;
for (Tuple3<Integer, String, Long> value : values) {
out.collect(Tuple3.of(value.f0, value.f1, id++)); // 归并
}
}
});
}

0x05 输出模型

这部分分为两部分:

  • 输出元数据,就是之前得到的 "把列名,列类型作为元数据"。
  • 输出具体每一列的每一个单词信息,比如 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
public class MultiStringIndexerModelDataConverter implements
ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
@Override
public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
if (modelData.f0 != null) {
collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
}
modelData.f1.forEach(tuple -> {
collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
});
}
} tuple = {Tuple3@9405} "(0,tennis,0)"
f0 = {Integer@9406} 0
f1 = "tennis"
f2 = {Long@9408} 0

0x06 预测

预测功能是在 ModelMapperAdapter 完成的。

public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
private final ModelMapper mapper;
private final ModelSource modelSource; @Override
public void open(Configuration parameters) throws Exception {
List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
this.mapper.loadModel(modelRows); //加载模型
} @Override
public Row map(Row row) throws Exception {
return this.mapper.map(row); //预测
}
}

6.1 加载模型

MultiStringIndexerModelDataConverter中我们会进行模型加载。

  • 首先会加载元信息
  • 其次会逐条加载模型信息
public MultiStringIndexerModelData load(List<Row> rows) {
MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
modelData.tokenAndIndex = new ArrayList<>();
modelData.tokenNumber = new HashMap<>();
for (Row row : rows) {
long colIndex = (Long) row.getField(0);
if (colIndex < 0L) { // 元数据
modelData.meta = Params.fromJson((String) row.getField(1));
} else { // 具体模型信息
int columnIndex = ((Long) row.getField(0)).intValue();
Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合并列数据个数
}
} // To ensure that every columns has token number.
int numFields = 0;
if (modelData.meta != null) {
numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
}
for (int i = 0; i < numFields; i++) {
modelData.tokenNumber.merge(i, 0L, Long::sum);
}
return modelData;
}

最后模型内容如下,其中 tokenNumber 表示每列的数据有几个,tokenAndIndex表示具体信息,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他们都是第一列的,basketball转换后的数据是 1:

modelData = {MultiStringIndexerModelData@9348}
meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
tokenAndIndex = {ArrayList@9360} size = 6
0 = {Tuple3@9472} "(0,football,2)"
1 = {Tuple3@9511} "(0,tennis,0)"
2 = {Tuple3@9512} "(1,zzz,1)"
3 = {Tuple3@9513} "(1,hhh,0)"
4 = {Tuple3@9514} "(0,basketball,1)"
5 = {Tuple3@9515} "(1,can,2)"
tokenNumber = {HashMap@9385} size = 2
{Integer@9507} 0 -> {Long@9508} 3
{Integer@9509} 1 -> {Long@9508} 3
numFields = 2

6.2 预测

预测是在 MultiStringIndexerModelMapper 完成的。

// 假设输入是:row = {Row@9309} "football,can"
// 选择的列是:selectedColNames = {String[2]@9314} 0 = "a" 1 = "b"
// 模型映射器是:
this = {MultiStringIndexerModelMapper@9309}
indexMapper = {HashMap@9318} size = 2
{Integer@9357} 0 -> {HashMap@9314} size = 3
key = {Integer@9357} 0
value = 0
value = {HashMap@9314} size = 3
"basketball" -> {Long@9386} 1
"football" -> {Long@9332} 2
"tennis" -> {Long@9384} 0
{Integer@9352} 1 -> {HashMap@9358} size = 3
key = {Integer@9352} 1
value = 1
value = {HashMap@9358} size = 3
"can" -> {Long@9332} 2
"hhh" -> {Long@9384} 0
"zzz" -> {Long@9386} 1

则经历过下列代码,最后就可以进行预测

public Row map(Row row) throws Exception {
Row result = new Row(selectedColNames.length);
for (int i = 0; i < selectedColNames.length; i++) {
Map<String, Long> mapper = indexMapper.get(i);
int colIdxInData = selectedColIndicesInData[i];
Object val = row.getField(colIdxInData);
String key = val == null ? null : String.valueOf(val);
Long index = mapper.get(key);
if (index != null) {
result.setField(i, index); // 我们主要执行在这里
} else {
}
} // 最后预测结果是:
row = {Row@9308} "football,can"
result = {Row@9313} "2,2" return outputColsHelper.getResultRow(row, result);
}

0xFF 参考

Spark之特征预处理

Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer的更多相关文章

  1. Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构

    Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 目录 Alink漫谈(二) : 从源码看机器学习平台Alink设计和架构 0x00 摘要 0x01 Alink设计原则 0x02 A ...

  2. Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

    Alink漫谈(十九) :源码解析 之 分位点离散化Quantile 目录 Alink漫谈(十九) :源码解析 之 分位点离散化Quantile 0x00 摘要 0x01 背景概念 1.1 离散化 1 ...

  3. Alink漫谈(二十) :卡方检验源码解析

    Alink漫谈(二十) :卡方检验源码解析 目录 Alink漫谈(二十) :卡方检验源码解析 0x00 摘要 0x01 背景概念 1.1 假设检验 1.2 H0和H1是什么? 1.3 P值 (P-va ...

  4. Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树

    Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树 目录 Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树 0x00 摘要 0x01 背景概念 1.1 词向量基础 ...

  5. springboot源码解析-管中窥豹系列之BeanPostProcessor(十二)

    一.前言 Springboot源码解析是一件大工程,逐行逐句的去研究代码,会很枯燥,也不容易坚持下去. 我们不追求大而全,而是试着每次去研究一个小知识点,最终聚沙成塔,这就是我们的springboot ...

  6. springboot源码解析-管中窥豹系列之bean如何生成?(十四)

    一.前言 Springboot源码解析是一件大工程,逐行逐句的去研究代码,会很枯燥,也不容易坚持下去. 我们不追求大而全,而是试着每次去研究一个小知识点,最终聚沙成塔,这就是我们的springboot ...

  7. 第三十六节,目标检测之yolo源码解析

    在一个月前,我就已经介绍了yolo目标检测的原理,后来也把tensorflow实现代码仔细看了一遍.但是由于这个暑假事情比较大,就一直搁浅了下来,趁今天有时间,就把源码解析一下.关于yolo目标检测的 ...

  8. 第十四章 Executors源码解析

    前边两章介绍了基础线程池ThreadPoolExecutor的使用方式.工作机理.参数详细介绍以及核心源码解析. 具体的介绍请参照: 第十二章 ThreadPoolExecutor使用与工作机理 第十 ...

  9. Netty 源码解析(八): 回到 Channel 的 register 操作

    原创申明:本文由公众号[猿灯塔]原创,转载请说明出处标注 今天是猿灯塔“365篇原创计划”第八篇. 接下来的时间灯塔君持续更新Netty系列一共九篇 Netty 源码解析(一): 开始 Netty 源 ...

随机推荐

  1. 用Python爬取双色球开奖信息,了解一下

    1工具     2具体方法 1.使用python2.7编写爬取脚本 这里除了正常的爬取操作,还增加了独立的参数设定.如果没有参数,爬取的数据就在当前目录下:如果有参数,可以设定保存目录.保存文件名后缀 ...

  2. MySQL(四)数据备份与还原

    数据备份与还原: 备份:将当前已有的数据或者记录保留 还原:将已经保留的数据恢复到对应的表中 为什么要做备份还原: 1.防止数据丢失:被盗.误操作 2.保护数据的记录 数据备份还原的方式很多:数据表备 ...

  3. python基础--14大内置模块(上)

    python的内置模块(重点掌握以下模块) 什么是模块 常见的场景:一个模块就是一个包含了python定义和声明的文件,文件名就是模块名字加上.py的后缀. 但其实import加载的模块分为四个通用类 ...

  4. Ansible部署zabbix-agent

    playbook目录 zabbix/ ├── hosts ##定义的主机列表 ├── install_zabbix_agent.yml ##安装入口文件 └── roles ├── install_z ...

  5. lua中 string.find(查找获取字符串) string.gsub(查找替换字符串) string.sub(截取字符串)

    > aaa='/p/v2/api/winapi/adapter/lgj'> print(string.find(aaa, "^/.+/adapter/(.*)"))1 ...

  6. Nginx配置各种响应头防止XSS,点击劫持,frame恶意攻击

    为什么要配置HTTP响应头? 不知道各位有没有被各类XSS攻击.点击劫持 (ClickJacking. frame 恶意引用等等方式骚扰过,百度联盟被封就有这些攻击的功劳在里面.为此一直都在搜寻相关防 ...

  7. CentOS6.5安装Oracle11g

    安装前必读: 1.      安装Oracle的虚拟机需要固定IP. 2.      注意安装过程中root用户与oracle用户的切换(su root/su oracle) 3.      环境变量 ...

  8. 在excel中如何给一列数据批量加上双引号

    在实际开发中,会遇到这样的需求,大量的数据,需要从配置文件里读取,客户给到的枚举值是字符串,而配置文件里的数据,是json格式,需要加上双引号,这样就需要使用Excel来批量格式化一下数据. 客户给到 ...

  9. Python简单的语句组

    Python简单的语句组: ''' if 条件1: 条件1满足时,需要运行的内容 ''' num = 10 if num % 6 == 4: print("num 对 6 的取模结果是 4& ...

  10. HTTP Request Method(十五种)

    序号 方法 描述 1 GET 请求指定的页面信息,并返回实体主体. 2 HEAD 类似于get请求,只不过返回的响应中没有具体的内容,用于获取报头 3 POST 向指定资源提交数据进行处理请求(例如提 ...