Alink Deep Dive (8): How Binary Classification Evaluation Implements AUC, K-S, PRC, Precision, Recall, and LiftChart

0x00 Abstract

Alink is a new-generation machine learning algorithm platform developed by Alibaba on top of the real-time computation engine Flink, and the first platform in the industry to support both batch and streaming algorithms. Binary classification evaluation measures the quality of a binary classifier's predictions. This article dissects the corresponding implementation in Alink.

0x01 Related Concepts

If some concepts in this article are unfamiliar, see my earlier article [白话解析] 通过实例来梳理概念 : 准确率 (Accuracy)、精准率(Precision)、召回率(Recall) 和 F值(F-Measure), which walks through Accuracy, Precision, Recall, and F-Measure with examples.

0x02 Example Code

    public class EvalBinaryClassExample {

        AlgoOperator getData(boolean isBatch) {
            Row[] rows = new Row[]{
                Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),
                Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
                Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
                Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
                Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")
            };
            String[] schema = new String[]{"label", "detailInput"};
            if (isBatch) {
                return new MemSourceBatchOp(rows, schema);
            } else {
                return new MemSourceStreamOp(rows, schema);
            }
        }

        public static void main(String[] args) throws Exception {
            EvalBinaryClassExample test = new EvalBinaryClassExample();
            BatchOperator batchData = (BatchOperator) test.getData(true);
            BinaryClassMetrics metrics = new EvalBinaryClassBatchOp()
                .setLabelCol("label")
                .setPredictionDetailCol("detailInput")
                .linkFrom(batchData)
                .collectMetrics();
            System.out.println("RocCurve:" + metrics.getRocCurve());
            System.out.println("AUC:" + metrics.getAuc());
            System.out.println("KS:" + metrics.getKs());
            System.out.println("PRC:" + metrics.getPrc());
            System.out.println("Accuracy:" + metrics.getAccuracy());
            System.out.println("Macro Precision:" + metrics.getMacroPrecision());
            System.out.println("Micro Recall:" + metrics.getMicroRecall());
            System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());
        }
    }

Program output:

    RocCurve:([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0],[0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 1.0, 1.0, 1.0])
    AUC:0.8333333333333333
    KS:0.6666666666666666
    PRC:0.9027777777777777
    Accuracy:0.6
    Macro Precision:0.3
    Micro Recall:0.6
    Weighted Sensitivity:0.6
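Before diving into the implementation, the printed AUC and KS can be sanity-checked directly from the RocCurve points. The sketch below (my own helper class, not Alink code) recomputes AUC with the trapezoidal rule and KS as the maximum vertical gap between TPR and FPR:

```java
public class RocCheck {
    // Trapezoidal area under a piecewise-linear curve given as (x[i], y[i]) points.
    static double auc(double[] fpr, double[] tpr) {
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
        }
        return area;
    }

    // KS statistic: the maximum vertical gap between TPR and FPR.
    static double ks(double[] fpr, double[] tpr) {
        double max = 0.0;
        for (int i = 0; i < fpr.length; i++) {
            max = Math.max(max, tpr[i] - fpr[i]);
        }
        return max;
    }

    public static void main(String[] args) {
        // The RocCurve points from the program output above.
        double[] fpr = {0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0};
        double[] tpr = {0.0, 1.0 / 3, 2.0 / 3, 2.0 / 3, 1.0, 1.0, 1.0};
        System.out.println(auc(fpr, tpr)); // 0.8333333333333333, as printed above
        System.out.println(ks(fpr, tpr));  // 0.6666666666666666
    }
}
```

Both values match the program output; PRC is obtained the same way from the recall-precision curve.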

Alink provides both a batch and a streaming implementation of binary classification evaluation; both are covered below. (One source of Alink's complexity is its many finely-grained data structures, so the following sections print program variables liberally to aid understanding.)

2.1 Main Idea

  • Divide [0, 1] into a fixed number of buckets (bins), say 100000, yielding two 100000-element arrays: positiveBin and negativeBin.

  • Populate positiveBin / negativeBin from the input. positiveBin holds TP + FP, negativeBin holds TN + FN. These arrays are the basis for all subsequent computation.

  • Traverse every meaningful point in the bins, accumulating totalTrue and totalFalse, and at each point compute the confusion matrix, the TPR, and the corresponding data points of rocCurve, recallPrecisionCurve, and liftChart.

  • From the curves, compute and store AUC / PRC / KS.

A detailed overview of the call chain follows later.
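These steps can be condensed into a small standalone sketch of the bin-based idea (no Flink, variable names are mine, not Alink's), applied to the five samples of the example:

```java
import java.util.ArrayList;
import java.util.List;

public class BinSweepSketch {
    static final int BIN_NUMBER = 100000;

    // Fill the bins from (positive-label probability, actual label) pairs, then
    // sweep the non-empty bins from the highest threshold down, accumulating
    // curTrue / curFalse and emitting one (FPR, TPR) point per bin.
    static List<double[]> rocPoints(double[] probs, boolean[] isPos) {
        long[] positiveBin = new long[BIN_NUMBER];
        long[] negativeBin = new long[BIN_NUMBER];
        long totalTrue = 0, totalFalse = 0;
        for (int i = 0; i < probs.length; i++) {
            int idx = (int) Math.floor(probs[i] * BIN_NUMBER);
            if (isPos[i]) { positiveBin[idx]++; totalTrue++; }
            else { negativeBin[idx]++; totalFalse++; }
        }
        long curTrue = 0, curFalse = 0;
        List<double[]> roc = new ArrayList<>();
        roc.add(new double[]{0.0, 0.0}); // the curve starts at (0, 0)
        for (int idx = BIN_NUMBER - 1; idx >= 0; idx--) {
            if (positiveBin[idx] == 0 && negativeBin[idx] == 0) continue;
            curTrue += positiveBin[idx];
            curFalse += negativeBin[idx];
            roc.add(new double[]{1.0 * curFalse / totalFalse, 1.0 * curTrue / totalTrue});
        }
        return roc;
    }

    public static void main(String[] args) {
        double[] probs  = {0.9, 0.8, 0.7, 0.75, 0.6}; // P(prefix1) for the five samples
        boolean[] isPos = {true, true, true, false, false};
        for (double[] p : rocPoints(probs, isPos)) {
            System.out.println(p[0] + ", " + p[1]);
        }
    }
}
```

This reproduces the RocCurve points of the program output, except the duplicated final (1.0, 1.0) point, which Alink adds because it always keeps the empty middle bin at threshold 0.5.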

0x03 Batch Processing

3.1 EvalBinaryClassBatchOp

EvalBinaryClassBatchOp implements binary classification evaluation; it computes the evaluation metrics of a binary classifier.

It accepts two kinds of input:

  • a label column and a predResult column
  • a label column and a predDetail column; if predDetail is present, predResult is ignored

In our example, "prefix1" is the label and "{\"prefix1\": 0.9, \"prefix0\": 0.1}" is the predDetail:

    Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")

An excerpt of the class:

    public class EvalBinaryClassBatchOp extends BaseEvalClassBatchOp<EvalBinaryClassBatchOp> implements BinaryEvaluationParams <EvalBinaryClassBatchOp>, EvaluationMetricsCollector<BinaryClassMetrics> {
        @Override
        public BinaryClassMetrics collectMetrics() {
            return new BinaryClassMetrics(this.collect().get(0));
        }
    }

As we can see, the real work happens in the base class BaseEvalClassBatchOp, so we look at that first.

3.2 BaseEvalClassBatchOp

As usual we start from the linkFrom function, which does a few things:

  • read the configuration
  • select the relevant columns from the input: "label" and "detailInput"
  • calLabelPredDetailLocal computes evaluation metrics per partition
  • reduce the partial results into a single summary
  • the SaveDataAsParams function writes the final values to the output table

The code:

    @Override
    public T linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = this.get(MultiEvaluationParams.LABEL_COL);
        String positiveValue = this.get(BinaryEvaluationParams.POS_LABEL_VAL_STR);
        // Judge the evaluation type from params.
        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
        DataSet<BaseMetricsSummary> res;
        switch (type) {
            case PRED_DETAIL: {
                String predDetailColName = this.get(MultiEvaluationParams.PREDICTION_DETAIL_COL);
                // Select the relevant columns from the input: "label", "detailInput"
                DataSet<Row> data = in.select(new String[] {labelColName, predDetailColName}).getDataSet();
                // Compute evaluation metrics per partition
                res = calLabelPredDetailLocal(data, positiveValue, binary);
                break;
            }
            ......
        }
        // Reduce the partial results
        DataSet<BaseMetricsSummary> metrics = res
            .reduce(new EvaluationUtil.ReduceBaseMetrics());
        // Write the final values to the output table
        this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
            new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});
        return (T)this;
    }

    // Some variables during execution:
    labelColName = "label"
    predDetailColName = "detailInput"
    type = {ClassificationEvaluationUtil$Type@2532} "PRED_DETAIL"
    binary = true
    positiveValue = null

3.2.0 Call-Chain Overview

Since the call relationships below get complicated, here is an overview first.

  • Select the relevant columns, "label" and "detailInput", via in.select(new String[] {labelColName, predDetailColName}).getDataSet(). The input may contain other columns, but only these are needed for the computation.
  • Compute evaluation metrics per partition, i.e. call calLabelPredDetailLocal(data, positiveValue, binary):
    • flatMap extracts all labels (note: the label names) from the label column and the prediction column and sends them downstream.
    • reduceGroup mainly deduplicates the label names via buildLabelIndexLabelArray, then assigns each label an ID, producing a <label, ID> map; it returns the tuple (map, labels), i.e. ({prefix1=0, prefix0=1}, [prefix1, prefix0]). As later code shows, the <label, ID> map is only used by multi-class evaluation; binary classification only uses labels.
    • mapPartition invokes CalLabelDetailLocal per partition to compute the confusion matrices, which mainly calls getDetailStatistics per partition; the tuple (map, labels) from the previous step is passed in.
      • getDetailStatistics iterates over the rows, extracts each item (e.g. "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"), and accumulates the data needed for the confusion matrix via updateBinaryMetricsSummary.

        • updateBinaryMetricsSummary divides [0, 1] into, say, 100000 buckets (bins), yielding the two 100000-element arrays positiveBin / negativeBin. positiveBin is TP + FP; negativeBin is TN + FN.

          • If a sample's probability of being the positive value is p, its bin index is p * 100000. If the sample's actual label is the positive value, then positiveBin[index]++;
          • otherwise its actual label is the negative value, and negativeBin[index]++.
  • Reduce the partial results: metrics = res.reduce(new EvaluationUtil.ReduceBaseMetrics());
    • The actual aggregation happens in BinaryMetricsSummary.merge, which merges the bins and adds the logLoss.
  • Write the final values to the output table: setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()..);
    • After all BaseMetrics are merged into the total BaseMetrics, the indexes are computed and stored into params: collector.collect(t.toMetrics().serialize());

      • The real work is in BinaryMetricsSummary.toMetrics, i.e. compute from the bin information, then store into params.

        • The extractMatrixThreCurve function extracts the non-empty bins and from them computes the ConfusionMatrix array (the confusion matrices), the threshold array, and rocCurve/recallPrecisionCurve/LiftChart.

          • Traverse every meaningful point in the bins, accumulating totalTrue and totalFalse, and at each point compute:
          • curTrue += positiveBin[index]; curFalse += negativeBin[index];
          • the confusion matrix at that point: new ConfusionMatrix(new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
          • tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
          • the data points of rocCurve, recallPrecisionCurve, and liftChart at that point.
        • From the curves, compute and store AUC/PRC/KS.
        • Sample the generated rocCurve/recallPrecisionCurve/LiftChart output.
        • Store RocCurve/RecallPrecisionCurve/LiftChart based on the sampled output.
        • Store the metrics of the positive samples.
        • Store the LogLoss.
        • Pick the middle point where the threshold is 0.5.

3.2.1 calLabelPredDetailLocal

This function computes the evaluation metrics per partition. Yes, the code is short, but one detail is easy to miss, and the simplest-looking places are often where oversights happen:

The result of the first statement, labels, is a parameter of the second statement, not its subject. The subject of the second statement is data, the same as that of the first.

    private static DataSet<BaseMetricsSummary> calLabelPredDetailLocal(DataSet<Row> data, final String positiveValue, boolean binary) {
        DataSet<Tuple2<Map<String, Integer>, String[]>> labels = data.flatMap(new FlatMapFunction<Row, String>() {
            @Override
            public void flatMap(Row row, Collector<String> collector) {
                TreeMap<String, Double> labelProbMap;
                if (EvaluationUtil.checkRowFieldNotNull(row)) {
                    labelProbMap = EvaluationUtil.extractLabelProbMap(row);
                    labelProbMap.keySet().forEach(collector::collect);
                    collector.collect(row.getField(0).toString());
                }
            }
        }).reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));
        return data
            .rebalance()
            .mapPartition(new CalLabelDetailLocal(binary))
            .withBroadcastSet(labels, LABELS);
    }

calLabelPredDetailLocal has three steps:

  • flatMap extracts all labels (the label names) from the label column and the prediction column and sends them downstream.
  • reduceGroup deduplicates the label names, then assigns each label an ID, producing a <label, ID> map.
  • mapPartition invokes CalLabelDetailLocal per partition to compute the confusion matrices.

Let's look at each in turn.

3.2.1.1 flatMap

flatMap mainly extracts all labels (note: the label names) from the label column and the prediction column and sends them to the downstream operator.

EvaluationUtil.extractLabelProbMap parses the input JSON to obtain the information inside detailInput.

The downstream operator is reduceGroup, so the Flink runtime automatically deduplicates these labels. If you are interested in this part, see my earlier articles on reduce: CSDN: [源码解析] Flink的groupBy和reduce究竟做了什么; 博客园: [源码解析] Flink的groupBy和reduce究竟做了什么.

Variables at runtime:

    row = {Row@8922} "prefix1,{"prefix1": 0.9, "prefix0": 0.1}"
      fields = {Object[2]@8925}
        0 = "prefix1"
        1 = "{"prefix1": 0.9, "prefix0": 0.1}"
    labelProbMap = {TreeMap@9008} size = 2
      "prefix0" -> {Double@9015} 0.1
      "prefix1" -> {Double@9017} 0.9
    labelProbMap.keySet().forEach(collector::collect); // emits "prefix0", "prefix1"
    collector.collect(row.getField(0).toString()); // emits "prefix1"
    // The next operation is reduceGroup, so these labels are deduplicated by the runtime

3.2.1.2 reduceGroup

Its main job is to deduplicate the labels via buildLabelIndexLabelArray, then assign each label an ID, producing a <label, ID> map.

    reduceGroup(new EvaluationUtil.DistinctLabelIndexMap(binary, positiveValue));

DistinctLabelIndexMap gets all the distinct labels from the label column and the prediction column and returns the map of labels and their IDs; judging from later code, this map is only used by multi-class evaluation.

As mentioned above, the rows parameter here has already been deduplicated.

    public static class DistinctLabelIndexMap implements
        GroupReduceFunction<String, Tuple2<Map<String, Integer>, String[]>> {
        ......
        @Override
        public void reduce(Iterable<String> rows, Collector<Tuple2<Map<String, Integer>, String[]>> collector) throws Exception {
            HashSet<String> labels = new HashSet<>();
            rows.forEach(labels::add);
            collector.collect(buildLabelIndexLabelArray(labels, binary, positiveValue));
        }
    }

    // Variables:
    labels = {HashSet@9008} size = 2
      0 = "prefix1"
      1 = "prefix0"
    binary = true

buildLabelIndexLabelArray assigns each label an ID and builds a <label, ID> map; it returns the tuple (map, labels), i.e. ({prefix1=0, prefix0=1}, [prefix1, prefix0]).

    // Give each label an ID, return a map of label and ID.
    public static Tuple2<Map<String, Integer>, String[]> buildLabelIndexLabelArray(HashSet<String> set, boolean binary, String positiveValue) {
        String[] labels = set.toArray(new String[0]);
        Arrays.sort(labels, Collections.reverseOrder());
        Map<String, Integer> map = new HashMap<>(labels.length);
        if (binary && null != positiveValue) {
            if (labels[1].equals(positiveValue)) {
                labels[1] = labels[0];
                labels[0] = positiveValue;
            }
            map.put(labels[0], 0);
            map.put(labels[1], 1);
        } else {
            for (int i = 0; i < labels.length; i++) {
                map.put(labels[i], i);
            }
        }
        return Tuple2.of(map, labels);
    }

    // Variables:
    labels = {String[2]@9013}
      0 = "prefix1"
      1 = "prefix0"
    map = {HashMap@9014} size = 2
      "prefix1" -> {Integer@9020} 0
      "prefix0" -> {Integer@9021} 1

3.2.1.3 mapPartition

The main job here is to invoke CalLabelDetailLocal per partition, preparing for the later confusion matrix computation.

    return data
        .rebalance()
        .mapPartition(new CalLabelDetailLocal(binary)) // the core logic lives here
        .withBroadcastSet(labels, LABELS);

The actual work is done by CalLabelDetailLocal, which calls getDetailStatistics per partition:

    // Calculate the confusion matrix based on the label and predResult.
    static class CalLabelDetailLocal extends RichMapPartitionFunction<Row, BaseMetricsSummary> {
        private Tuple2<Map<String, Integer>, String[]> map;
        private boolean binary;

        @Override
        public void open(Configuration parameters) throws Exception {
            List<Tuple2<Map<String, Integer>, String[]>> list = getRuntimeContext().getBroadcastVariable(LABELS);
            this.map = list.get(0); // the tuple (map, labels) built earlier
        }

        @Override
        public void mapPartition(Iterable<Row> rows, Collector<BaseMetricsSummary> collector) {
            // delegates to getDetailStatistics
            collector.collect(getDetailStatistics(rows, binary, map));
        }
    }

getDetailStatistics initializes the base classification evaluation metrics and accumulates the data needed for the confusion matrix. It mainly iterates over the rows, extracts each item (e.g. "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"), and accumulates the confusion-matrix data.

    // Initialize the base classification evaluation metrics. There are two cases: BinaryClassMetrics and MultiClassMetrics.
    private static BaseMetricsSummary getDetailStatistics(Iterable<Row> rows,
                                                          String positiveValue,
                                                          boolean binary,
                                                          Tuple2<Map<String, Integer>, String[]> tuple) {
        BinaryMetricsSummary binaryMetricsSummary = null;
        MultiMetricsSummary multiMetricsSummary = null;
        Tuple2<Map<String, Integer>, String[]> labelIndexLabelArray = tuple; // the tuple (map, labels) built earlier
        Iterator<Row> iterator = rows.iterator();
        Row row = null;
        while (iterator.hasNext() && !checkRowFieldNotNull(row)) {
            row = iterator.next();
        }
        Map<String, Integer> labelIndexMap = null;
        if (binary) {
            // the binary case
            binaryMetricsSummary = new BinaryMetricsSummary(
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                new long[ClassificationEvaluationUtil.DETAIL_BIN_NUMBER],
                labelIndexLabelArray.f1, 0.0, 0L);
        } else {
            // the multi-class case
            labelIndexMap = labelIndexLabelArray.f0; // the <label, ID> map built earlier is only used here, for multi-class
            multiMetricsSummary = new MultiMetricsSummary(
                new long[labelIndexMap.size()][labelIndexMap.size()],
                labelIndexLabelArray.f1, 0.0, 0L);
        }
        while (null != row) {
            if (checkRowFieldNotNull(row)) {
                TreeMap<String, Double> labelProbMap = extractLabelProbMap(row);
                String label = row.getField(0).toString();
                if (ArrayUtils.indexOf(labelIndexLabelArray.f1, label) >= 0) {
                    if (binary) {
                        // the binary case
                        updateBinaryMetricsSummary(labelProbMap, label, binaryMetricsSummary);
                    } else {
                        updateMultiMetricsSummary(labelProbMap, label, labelIndexMap, multiMetricsSummary);
                    }
                }
            }
            row = iterator.hasNext() ? iterator.next() : null;
        }
        return binary ? binaryMetricsSummary : multiMetricsSummary;
    }

    // Variables:
    tuple = {Tuple2@9252} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
      f0 = {HashMap@9257} size = 2
        "prefix1" -> {Integer@9264} 0
        "prefix0" -> {Integer@9266} 1
      f1 = {String[2]@9258}
        0 = "prefix1"
        1 = "prefix0"
    row = {Row@9271} "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"
      fields = {Object[2]@9276}
        0 = "prefix1"
        1 = "{"prefix1": 0.8, "prefix0": 0.2}"
    labelIndexLabelArray = {Tuple2@9240} "({prefix1=0, prefix0=1},[prefix1, prefix0])"
      f0 = {HashMap@9288} size = 2
        "prefix1" -> {Integer@9294} 0
        "prefix0" -> {Integer@9296} 1
      f1 = {String[2]@9242}
        0 = "prefix1"
        1 = "prefix0"
    labelProbMap = {TreeMap@9342} size = 2
      "prefix0" -> {Double@9378} 0.1
      "prefix1" -> {Double@9380} 0.9

First, recall the confusion matrix:

                 Predicted 0   Predicted 1
    Actual 0     TN            FP
    Actual 1     FN            TP

With respect to this matrix, BinaryMetricsSummary saves the evaluation data for binary classification. The computation works as follows:

  • Divide [0, 1] into ClassificationEvaluationUtil.DETAIL_BIN_NUMBER (100000) buckets (bins), so binaryMetricsSummary's positiveBin / negativeBin are two 100000-element arrays. If a sample's probability of being the positive value is p, its bin index is p * 100000. If the sample's actual label is the positive value, then positiveBin[index]++; otherwise its actual label is the negative value, and negativeBin[index]++. positiveBin is TP + FP; negativeBin is TN + FN.

  • The code therefore iterates over the input. Take ("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}") as an example: 0.9 is the probability of prefix1 (the positive value) and 0.1 the probability of prefix0 (the negative value).

    • Since this sample's label is prefix1 (positive), positiveBin[90000] is incremented.
    • Had the label been prefix0 (negative), negativeBin[90000] would have been incremented instead.

The five samples of our example are therefore classified as follows:

    Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}"),   positiveBin[90000]++
    Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),   positiveBin[80000]++
    Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),   positiveBin[70000]++
    Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"), negativeBin[75000]++
    Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}")    negativeBin[60000]++
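The bin indices above follow the index rule used by updateBinaryMetricsSummary. A tiny sketch of just that rule (my own helper, mirroring the computation, including the p == 1.0 edge case):

```java
public class BinIndexSketch {
    static final int DETAIL_BIN_NUMBER = 100000;

    // Bin index for a sample whose positive-label probability is p:
    // floor(p * 100000), with p == 1.0 clamped into the last bin.
    static int binIndex(double p) {
        return p == 1.0 ? DETAIL_BIN_NUMBER - 1 : (int) Math.floor(p * DETAIL_BIN_NUMBER);
    }

    public static void main(String[] args) {
        System.out.println(binIndex(0.9));  // 90000, sample 1 -> positiveBin
        System.out.println(binIndex(0.75)); // 75000, sample 4 -> negativeBin
        System.out.println(binIndex(1.0));  // 99999, the p == 1.0 edge case
    }
}
```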

The code:

    public static void updateBinaryMetricsSummary(TreeMap<String, Double> labelProbMap,
                                                  String label,
                                                  BinaryMetricsSummary binaryMetricsSummary) {
        binaryMetricsSummary.total++;
        binaryMetricsSummary.logLoss += extractLogloss(labelProbMap, label);
        double d = labelProbMap.get(binaryMetricsSummary.labels[0]);
        int idx = d == 1.0 ? ClassificationEvaluationUtil.DETAIL_BIN_NUMBER - 1 :
            (int)Math.floor(d * ClassificationEvaluationUtil.DETAIL_BIN_NUMBER);
        if (idx >= 0 && idx < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER) {
            if (label.equals(binaryMetricsSummary.labels[0])) {
                binaryMetricsSummary.positiveBin[idx] += 1;
            } else if (label.equals(binaryMetricsSummary.labels[1])) {
                binaryMetricsSummary.negativeBin[idx] += 1;
            } else {
                .....
            }
        }
    }

    private static double extractLogloss(TreeMap<String, Double> labelProbMap, String label) {
        Double prob = labelProbMap.get(label);
        prob = null == prob ? 0. : prob;
        return -Math.log(Math.max(Math.min(prob, 1 - LOG_LOSS_EPS), LOG_LOSS_EPS));
    }

    // Variables:
    ClassificationEvaluationUtil.DETAIL_BIN_NUMBER = 100000

    // For "prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}":
    labelProbMap = {TreeMap@9305} size = 2
      "prefix0" -> {Double@9331} 0.1
      "prefix1" -> {Double@9333} 0.9
    d = 0.9
    idx = 90000
    binaryMetricsSummary = {BinaryMetricsSummary@9262}
      labels = {String[2]@9242}
        0 = "prefix1"
        1 = "prefix0"
      total = 1
      positiveBin = {long[100000]@9263} // incremented at 90000
      negativeBin = {long[100000]@9264}
      logLoss = 0.10536051565782628

    // For "prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}":
    labelProbMap = {TreeMap@9514} size = 2
      "prefix0" -> {Double@9546} 0.4
      "prefix1" -> {Double@9547} 0.6
    d = 0.6
    idx = 60000
    binaryMetricsSummary = {BinaryMetricsSummary@9262}
      labels = {String[2]@9242}
        0 = "prefix1"
        1 = "prefix0"
      total = 2
      positiveBin = {long[100000]@9263}
      negativeBin = {long[100000]@9264} // incremented at 60000
      logLoss = 1.0216512475319812
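The logLoss values in the dumps can be reproduced from the definition used by extractLogloss: the negative log of the probability assigned to the true label. A sketch (my own helper; the LOG_LOSS_EPS clamping is omitted since the example probabilities are well inside (0, 1)):

```java
public class LogLossSketch {
    // Log loss contribution of one sample: -ln(probability of the true label).
    static double logLoss(double probOfTrueLabel) {
        return -Math.log(probOfTrueLabel);
    }

    public static void main(String[] args) {
        // First sample: true label prefix1, predicted with probability 0.9
        double first = logLoss(0.9);
        // Fifth sample: true label prefix0, predicted with probability 0.4
        double second = logLoss(0.4);
        System.out.println(first);          // ~0.10536, the logLoss after the first sample
        System.out.println(first + second); // ~1.02165, the accumulated logLoss in the dump above
    }
}
```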

3.2.2 ReduceBaseMetrics

ReduceBaseMetrics aggregates the locally computed BaseMetrics summaries.

    DataSet<BaseMetricsSummary> metrics = res
        .reduce(new EvaluationUtil.ReduceBaseMetrics());

ReduceBaseMetrics looks like this:

    public static class ReduceBaseMetrics implements ReduceFunction<BaseMetricsSummary> {
        @Override
        public BaseMetricsSummary reduce(BaseMetricsSummary t1, BaseMetricsSummary t2) throws Exception {
            return null == t1 ? t2 : t1.merge(t2);
        }
    }

The actual aggregation happens in BinaryMetricsSummary.merge, which merges the bins and adds the logLoss.

    @Override
    public BinaryMetricsSummary merge(BinaryMetricsSummary binaryClassMetrics) {
        for (int i = 0; i < this.positiveBin.length; i++) {
            this.positiveBin[i] += binaryClassMetrics.positiveBin[i];
        }
        for (int i = 0; i < this.negativeBin.length; i++) {
            this.negativeBin[i] += binaryClassMetrics.negativeBin[i];
        }
        this.logLoss += binaryClassMetrics.logLoss;
        this.total += binaryClassMetrics.total;
        return this;
    }

    // Variables:
    this = {BinaryMetricsSummary@9316}
      labels = {String[2]@9322}
        0 = "prefix1"
        1 = "prefix0"
      total = 2
      positiveBin = {long[100000]@9320}
      negativeBin = {long[100000]@9323}
      logLoss = 1.742969305058623

3.2.3 SaveDataAsParams

    this.setOutput(metrics.flatMap(new EvaluationUtil.SaveDataAsParams()),
        new String[] {DATA_OUTPUT}, new TypeInformation[] {Types.STRING});

After all BaseMetrics have been merged into the total BaseMetrics, the indexes are computed and stored into params.

    public static class SaveDataAsParams implements FlatMapFunction<BaseMetricsSummary, Row> {
        @Override
        public void flatMap(BaseMetricsSummary t, Collector<Row> collector) throws Exception {
            collector.collect(t.toMetrics().serialize());
        }
    }

The real work is completed in BinaryMetricsSummary.toMetrics: compute from the bin information to obtain the confusionMatrix array, the threshold array, rocCurve/recallPrecisionCurve/LiftChart and so on, then store them into params.

    public BinaryClassMetrics toMetrics() {
        Params params = new Params();
        // Generate the curves: rocCurve / recallPrecisionCurve / LiftChart
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> matrixThreCurve =
            extractMatrixThreCurve(positiveBin, negativeBin, total);
        // From the curves, compute and store AUC / PRC / KS
        setCurveAreaParams(params, matrixThreCurve.f2);
        // Sample the generated rocCurve / recallPrecisionCurve / LiftChart output
        Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> sampledMatrixThreCurve = sample(
            PROBABILITY_INTERVAL, matrixThreCurve);
        // Store RocCurve / RecallPrecisionCurve / LiftChart based on the sampled output
        setCurvePointsParams(params, sampledMatrixThreCurve);
        ConfusionMatrix[] matrices = sampledMatrixThreCurve.f0;
        // Store the metrics of the positive samples
        setComputationsArrayParams(params, sampledMatrixThreCurve.f1, sampledMatrixThreCurve.f0);
        // Store the LogLoss
        setLoglossParams(params, logLoss, total);
        // Pick the middle point where threshold is 0.5.
        int middleIndex = getMiddleThresholdIndex(sampledMatrixThreCurve.f1);
        setMiddleThreParams(params, matrices[middleIndex], labels);
        return new BinaryClassMetrics(params);
    }

extractMatrixThreCurve is the heart of the whole article. It extracts the non-empty bins (keeping the middle threshold 0.5), initializes the RocCurve, Recall-Precision Curve and Lift Curve, and computes the ConfusionMatrix array (the confusion matrices), the threshold array, and rocCurve/recallPrecisionCurve/LiftChart.

    /**
     * Extract the bins who are not empty, keep the middle threshold 0.5.
     * Initialize the RocCurve, Recall-Precision Curve and Lift Curve.
     * RocCurve: (FPR, TPR), starts with (0,0). Recall-Precision Curve: (recall, precision), starts with (0, p), p is the precision with the lowest. LiftChart: (TP+FP/total, TP), starts with (0,0). confusion matrix = [TP FP][FN TN].
     *
     * @param positiveBin positiveBins.
     * @param negativeBin negativeBins.
     * @param total sample number
     * @return ConfusionMatrix array, threshold array, rocCurve/recallPrecisionCurve/LiftChart.
     */
    static Tuple3<ConfusionMatrix[], double[], EvaluationCurve[]> extractMatrixThreCurve(long[] positiveBin, long[] negativeBin, long total) {
        ArrayList<Integer> effectiveIndices = new ArrayList<>();
        long totalTrue = 0, totalFalse = 0;
        // Compute totalTrue, totalFalse, effectiveIndices
        for (int i = 0; i < ClassificationEvaluationUtil.DETAIL_BIN_NUMBER; i++) {
            if (0L != positiveBin[i] || 0L != negativeBin[i]
                || i == ClassificationEvaluationUtil.DETAIL_BIN_NUMBER / 2) {
                effectiveIndices.add(i);
                totalTrue += positiveBin[i];
                totalFalse += negativeBin[i];
            }
        }

        // For our example this yields:
        effectiveIndices = {ArrayList@9273} size = 6
          0 = {Integer@9277} 50000 // the middle point, added unconditionally
          1 = {Integer@9278} 60000
          2 = {Integer@9279} 70000
          3 = {Integer@9280} 75000
          4 = {Integer@9281} 80000
          5 = {Integer@9282} 90000
        totalTrue = 3
        totalFalse = 2

        // Continue initializing; allocate the curves
        final int length = effectiveIndices.size();
        final int newLen = length + 1;
        final double m = 1.0 / ClassificationEvaluationUtil.DETAIL_BIN_NUMBER;
        EvaluationCurvePoint[] rocCurve = new EvaluationCurvePoint[newLen];
        EvaluationCurvePoint[] recallPrecisionCurve = new EvaluationCurvePoint[newLen];
        EvaluationCurvePoint[] liftChart = new EvaluationCurvePoint[newLen];
        ConfusionMatrix[] data = new ConfusionMatrix[newLen];
        double[] threshold = new double[newLen];
        long curTrue = 0;
        long curFalse = 0;

        // For our example:
        length = 6
        newLen = 7
        m = 1.0E-5

        // The main loop; rocCurve, recallPrecisionCurve, liftChart all follow directly from the code
        for (int i = 1; i < newLen; i++) {
            int index = effectiveIndices.get(length - i);
            curTrue += positiveBin[index];
            curFalse += negativeBin[index];
            threshold[i] = index * m;
            // the confusion matrix at this point
            data[i] = new ConfusionMatrix(
                new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});
            double tpr = (totalTrue == 0 ? 1.0 : 1.0 * curTrue / totalTrue);
            // At the 90000 point, for instance: curTrue = 1, curFalse = 0, i = 1, index = 90000, totalTrue = 3, totalFalse = 2.
            // Since TPR = TP / (TP + FN), tpr = 1 / 3 = 0.3333333333333333.
            rocCurve[i] = new EvaluationCurvePoint(totalFalse == 0 ? 1.0 : 1.0 * curFalse / totalFalse, tpr, threshold[i]);
            recallPrecisionCurve[i] = new EvaluationCurvePoint(tpr, curTrue + curFalse == 0 ? 1.0 : 1.0 * curTrue / (curTrue + curFalse), threshold[i]);
            liftChart[i] = new EvaluationCurvePoint(1.0 * (curTrue + curFalse) / total, curTrue, threshold[i]);
        }

        // For our example:
        curTrue = 3
        curFalse = 2
        threshold = {double[7]@9349}
          0 = 0.0
          1 = 0.9
          2 = 0.8
          3 = 0.7500000000000001
          4 = 0.7000000000000001
          5 = 0.6000000000000001
          6 = 0.5
        rocCurve = {EvaluationCurvePoint[7]@9315}
          1 = {EvaluationCurvePoint@9440} x = 0.0, y = 0.3333333333333333, p = 0.9
          2 = {EvaluationCurvePoint@9448} x = 0.0, y = 0.6666666666666666, p = 0.8
          3 = {EvaluationCurvePoint@9449} x = 0.5, y = 0.6666666666666666, p = 0.7500000000000001
          4 = {EvaluationCurvePoint@9450} x = 0.5, y = 1.0, p = 0.7000000000000001
          5 = {EvaluationCurvePoint@9451} x = 1.0, y = 1.0, p = 0.6000000000000001
          6 = {EvaluationCurvePoint@9452} x = 1.0, y = 1.0, p = 0.5
        recallPrecisionCurve = {EvaluationCurvePoint[7]@9320}
          1 = {EvaluationCurvePoint@9444} x = 0.3333333333333333, y = 1.0, p = 0.9
          2 = {EvaluationCurvePoint@9453} x = 0.6666666666666666, y = 1.0, p = 0.8
          3 = {EvaluationCurvePoint@9454} x = 0.6666666666666666, y = 0.6666666666666666, p = 0.7500000000000001
          4 = {EvaluationCurvePoint@9455} x = 1.0, y = 0.75, p = 0.7000000000000001
          5 = {EvaluationCurvePoint@9456} x = 1.0, y = 0.6, p = 0.6000000000000001
          6 = {EvaluationCurvePoint@9457} x = 1.0, y = 0.6, p = 0.5
        liftChart = {EvaluationCurvePoint[7]@9325}
          1 = {EvaluationCurvePoint@9458} x = 0.2, y = 1.0, p = 0.9
          2 = {EvaluationCurvePoint@9459} x = 0.4, y = 2.0, p = 0.8
          3 = {EvaluationCurvePoint@9460} x = 0.6, y = 2.0, p = 0.7500000000000001
          4 = {EvaluationCurvePoint@9461} x = 0.8, y = 3.0, p = 0.7000000000000001
          5 = {EvaluationCurvePoint@9462} x = 1.0, y = 3.0, p = 0.6000000000000001
          6 = {EvaluationCurvePoint@9463} x = 1.0, y = 3.0, p = 0.5
        data = {ConfusionMatrix[7]@9339}
          0 = {ConfusionMatrix@9486}
            longMatrix = {LongMatrix@9488}
              matrix = {long[2][]@9491} [[0, 0], [3, 2]]
              rowNum = 2
              colNum = 2
            labelCnt = 2
            total = 5
            actualLabelFrequency = {long[2]@9489} [3, 2]
            predictLabelFrequency = {long[2]@9490} [0, 5]
            tpCount = 2.0
            tnCount = 2.0
            fpCount = 3.0
            fnCount = 3.0
          1 = {ConfusionMatrix@9435}
            longMatrix = {LongMatrix@9469}
              matrix = {long[2][]@9472} [[1, 0], [2, 2]]
              rowNum = 2
              colNum = 2
            labelCnt = 2
            total = 5
            actualLabelFrequency = {long[2]@9470} [3, 2]
            predictLabelFrequency = {long[2]@9471} [1, 4]
            tpCount = 3.0
            tnCount = 3.0
            fpCount = 2.0
            fnCount = 2.0
          ......

        threshold[0] = 1.0;
        data[0] = new ConfusionMatrix(new long[][] {{0, 0}, {totalTrue, totalFalse}});
        rocCurve[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
        recallPrecisionCurve[0] = new EvaluationCurvePoint(0, recallPrecisionCurve[1].getY(), threshold[0]);
        liftChart[0] = new EvaluationCurvePoint(0, 0, threshold[0]);
        return Tuple3.of(data, threshold, new EvaluationCurve[] {new EvaluationCurve(rocCurve),
            new EvaluationCurve(recallPrecisionCurve), new EvaluationCurve(liftChart)});
    }

3.2.4 Computing the Confusion Matrix

Here is a closer look at how the confusion matrix is computed; the reasoning is somewhat convoluted.

3.2.4.1 The Raw Matrix

The call site is:

    // call site
    data[i] = new ConfusionMatrix(
        new long[][] {{curTrue, curFalse}, {totalTrue - curTrue, totalFalse - curFalse}});

    // values at the call
    i = 1
    index = 90000
    totalTrue = 3
    totalFalse = 2
    curTrue = 1
    curFalse = 0

This gives the raw matrix; the "cur" prefix means the value refers to the current point only:

    curTrue = 1               curFalse = 0
    totalTrue - curTrue = 2   totalFalse - curFalse = 2

3.2.4.2 Computing the Labels

Later in the ConfusionMatrix computation we obtain:

    actualLabelFrequency = longMatrix.getColSums();
    predictLabelFrequency = longMatrix.getRowSums();

    actualLabelFrequency = {long[2]@9322}
      0 = 3
      1 = 2
    predictLabelFrequency = {long[2]@9323}
      0 = 1
      1 = 4

As we can see, the Alink algorithm treats each column sum as related to the actual labels, and each row sum as related to the predicted labels.

This gives the following augmented matrix:

                                                                       predictLabelFrequency
    curTrue = 1                  curFalse = 0                          1 = curTrue + curFalse
    totalTrue - curTrue = 2      totalFalse - curFalse = 2             4 = total - curTrue - curFalse
    actualLabelFrequency         3 = totalTrue       2 = totalFalse

The subsequent computation is based on these values.

It uses the entries on the diagonal of longMatrix, i.e. longMatrix(0)(0) and longMatrix(1)(1). Keep firmly in mind that everything here refers to the current state:

longMatrix(0)(0): curTrue

longMatrix(1)(1): totalFalse - curFalse

totalFalse: (TN + FN)

totalTrue: (TP + FP)

    double numTrueNegative(Integer labelIndex) {
        // for labelIndex 0: return 1 + 5 - 1 - 3 = 2;
        // for labelIndex 1: return 2 + 5 - 4 - 2 = 1;
        return null == labelIndex ? tnCount : longMatrix.getValue(labelIndex, labelIndex) + total - predictLabelFrequency[labelIndex] - actualLabelFrequency[labelIndex];
    }

    double numTruePositive(Integer labelIndex) {
        // for labelIndex 0: return 1; this is curTrue, actual label True and predicted True, i.e. TP
        // for labelIndex 1: return 2; this is totalFalse - curFalse, i.e. all misjudged minus currently misjudged.
        // That means "wrong overall but not caught at the current point", so at the current state this also counts as TP
        return null == labelIndex ? tpCount : longMatrix.getValue(labelIndex, labelIndex);
    }

    double numFalseNegative(Integer labelIndex) {
        // for labelIndex 0: return 3 - 1;
        // actualLabelFrequency[0] = totalTrue, so return totalTrue - curTrue, i.e. the part of "all positives"
        // not "judged positive", which can be regarded as "judged wrongly, as negative"
        // for labelIndex 1: return 2 - 2;
        // actualLabelFrequency[1] = totalFalse, so return totalFalse - (totalFalse - curFalse) = curFalse
        return null == labelIndex ? fnCount : actualLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
    }

    double numFalsePositive(Integer labelIndex) {
        // for labelIndex 0: return 1 - 1;
        // predictLabelFrequency[0] = curTrue + curFalse,
        // so return curTrue + curFalse - curTrue = curFalse = current (TN + FN); this can be seen as judged wrong while actually correct
        // for labelIndex 1: return 4 - 2;
        // predictLabelFrequency[1] = total - curTrue - curFalse,
        // so return total - curTrue - curFalse - (totalFalse - curFalse) = totalTrue - curTrue = (TP + FP) - currentTP = currentFP
        return null == labelIndex ? fpCount : predictLabelFrequency[labelIndex] - longMatrix.getValue(labelIndex, labelIndex);
    }

    // final result:
    tpCount = 3.0
    tnCount = 3.0
    fpCount = 2.0
    fnCount = 2.0
3.2.4.3 Detailed code
// The actual computation
public ConfusionMatrix(LongMatrix longMatrix) {
    // Debugger view of the input:
    // longMatrix = {LongMatrix@9297}
    //   row 0 = [1, 0]
    //   row 1 = [2, 2]
    this.longMatrix = longMatrix;
    labelCnt = this.longMatrix.getRowNum();
    // This is where the frequencies are computed
    actualLabelFrequency = longMatrix.getColSums();
    predictLabelFrequency = longMatrix.getRowSums();
    // actualLabelFrequency  = [3, 2]
    // predictLabelFrequency = [1, 4]
    // labelCnt = 2, total = 5
    total = longMatrix.getTotal();
    for (int i = 0; i < labelCnt; i++) {
        tnCount += numTrueNegative(i);
        tpCount += numTruePositive(i);
        fnCount += numFalseNegative(i);
        fpCount += numFalsePositive(i);
    }
}
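The counting logic above can be sanity-checked with a minimal standalone sketch. This is not the Alink ConfusionMatrix class: LongMatrix is replaced by a plain long[][] and the four helper methods are inlined into one loop, using the example matrix [[1, 0], [2, 2]] from the debugger output.

```java
// Recompute tpCount/tnCount/fpCount/fnCount from the 2x2 count matrix:
// rows are predicted labels, columns are actual labels.
public class ConfusionCountsCheck {
    public static void main(String[] args) {
        long[][] m = {{1, 0}, {2, 2}};
        int labelCnt = m.length;
        long[] actual = new long[labelCnt];   // column sums = actualLabelFrequency
        long[] predict = new long[labelCnt];  // row sums    = predictLabelFrequency
        long total = 0;
        for (int i = 0; i < labelCnt; i++) {
            for (int j = 0; j < labelCnt; j++) {
                predict[i] += m[i][j];
                actual[j] += m[i][j];
                total += m[i][j];
            }
        }
        double tp = 0, tn = 0, fp = 0, fn = 0;
        for (int i = 0; i < labelCnt; i++) {
            tp += m[i][i];                                   // numTruePositive(i)
            tn += m[i][i] + total - predict[i] - actual[i];  // numTrueNegative(i)
            fn += actual[i] - m[i][i];                       // numFalseNegative(i)
            fp += predict[i] - m[i][i];                      // numFalsePositive(i)
        }
        System.out.println(tp + " " + tn + " " + fp + " " + fn); // 3.0 3.0 2.0 2.0
    }
}
```

Running it prints 3.0 3.0 2.0 2.0, matching the tpCount/tnCount/fpCount/fnCount values above.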

0x04 Stream processing

4.1 Example

In Alink's original Python sample code, the stream part produces no output: MemSourceStreamOp is not tied to time, and Alink provides no time-based StreamOperator. So I wrote one myself, modeled on MemSourceBatchOp. The code is a bit ugly, but at least it produces output, which makes debugging possible.

4.1.1 Main class

public class EvalBinaryClassExampleStream {
    AlgoOperator getData(boolean isBatch) {
        Row[] rows = new Row[]{
            Row.of("prefix1", "{\"prefix1\": 0.9, \"prefix0\": 0.1}")
        };
        String[] schema = new String[]{"label", "detailInput"};
        if (isBatch) {
            return new MemSourceBatchOp(rows, schema);
        } else {
            return new TimeMemSourceStreamOp(rows, schema, new EvalBinaryStreamSource());
        }
    }

    public static void main(String[] args) throws Exception {
        EvalBinaryClassExampleStream test = new EvalBinaryClassExampleStream();
        StreamOperator streamData = (StreamOperator) test.getData(false);
        StreamOperator sOp = new EvalBinaryClassStreamOp()
            .setLabelCol("label")
            .setPredictionDetailCol("detailInput")
            .setTimeInterval(1)
            .linkFrom(streamData);
        sOp.print();
        StreamOperator.execute();
    }
}

4.1.2 TimeMemSourceStreamOp

This one I put together myself, borrowing from MemSourceStreamOp.

public final class TimeMemSourceStreamOp extends StreamOperator<TimeMemSourceStreamOp> {
    public TimeMemSourceStreamOp(Row[] rows, String[] colNames, EvalBinaryStreamSource source) {
        super(null);
        init(source, Arrays.asList(rows), colNames);
    }

    private void init(EvalBinaryStreamSource source, List<Row> rows, String[] colNames) {
        Row first = rows.iterator().next();
        int arity = first.getArity();
        TypeInformation<?>[] types = new TypeInformation[arity];
        for (int i = 0; i < arity; ++i) {
            types[i] = TypeExtractor.getForObject(first.getField(i));
        }
        init(source, colNames, types);
    }

    private void init(EvalBinaryStreamSource source, String[] colNames, TypeInformation<?>[] colTypes) {
        DataStream<Row> dastr = MLEnvironmentFactory.get(getMLEnvironmentId())
            .getStreamExecutionEnvironment().addSource(source);
        this.setOutput(dastr, colNames, colTypes);
    }

    @Override
    public TimeMemSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
        return null;
    }
}

4.1.3 Source

It emits Rows periodically; a random number is mixed in so that the probabilities vary.

class EvalBinaryStreamSource extends RichSourceFunction[Row] {
  override def run(ctx: SourceFunction.SourceContext[Row]) = {
    while (true) {
      val rdm = Math.random() // mix in a random number so the probabilities vary
      val rows: Array[Row] = Array[Row](
        Row.of("prefix1", "{\"prefix1\": " + rdm + ", \"prefix0\": " + (1 - rdm) + "}"),
        Row.of("prefix1", "{\"prefix1\": 0.8, \"prefix0\": 0.2}"),
        Row.of("prefix1", "{\"prefix1\": 0.7, \"prefix0\": 0.3}"),
        Row.of("prefix0", "{\"prefix1\": 0.75, \"prefix0\": 0.25}"),
        Row.of("prefix0", "{\"prefix1\": 0.6, \"prefix0\": 0.4}"))
      for (row <- rows) {
        println(s"current value: $row")
        ctx.collect(row)
      }
      Thread.sleep(1000)
    }
  }

  override def cancel() = ???
}

4.2 BaseEvalClassStreamOp

Alink's stream-processing class is EvalBinaryClassStreamOp, but the main work happens in its base class BaseEvalClassStreamOp, so we focus on the latter.

public class BaseEvalClassStreamOp<T extends BaseEvalClassStreamOp<T>> extends StreamOperator<T> {
    @Override
    public T linkFrom(StreamOperator<?>... inputs) {
        StreamOperator<?> in = checkAndGetFirst(inputs);
        String labelColName = this.get(MultiEvaluationStreamParams.LABEL_COL);
        String positiveValue = this.get(BinaryEvaluationStreamParams.POS_LABEL_VAL_STR);
        Integer timeInterval = this.get(MultiEvaluationStreamParams.TIME_INTERVAL);
        ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(this.getParams());
        DataStream<BaseMetricsSummary> statistics;
        switch (type) {
            case PRED_RESULT: {
                ......
            }
            case PRED_DETAIL: {
                String predDetailColName = this.get(MultiEvaluationStreamParams.PREDICTION_DETAIL_COL);
                PredDetailLabel eval = new PredDetailLabel(positiveValue, binary);
                // Fetch the input data; the key point is timeWindowAll
                statistics = in.select(new String[] {labelColName, predDetailColName})
                    .getDataStream()
                    .timeWindowAll(Time.of(timeInterval, TimeUnit.SECONDS))
                    .apply(eval);
                break;
            }
        }
        // Accumulate the per-window data into totalStatistics; note that this is a new variable.
        DataStream<BaseMetricsSummary> totalStatistics = statistics
            .map(new EvaluationUtil.AllDataMerge())
            .setParallelism(1); // parallelism is set to 1
        // Compute & serialize based on the two kinds of bins: the current statistics
        DataStream<Row> windowOutput = statistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
        // Compute & serialize based on the bins: the accumulated totalStatistics
        DataStream<Row> allOutput = totalStatistics.map(
            new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
        // Union the "current" and "accumulated" streams; this is the final result
        DataStream<Row> union = windowOutput.union(allOutput);
        this.setOutput(union,
            new String[] {ClassificationEvaluationUtil.STATISTICS_OUTPUT, DATA_OUTPUT},
            new TypeInformation[] {Types.STRING, Types.STRING});
        return (T) this;
    }
}

The concrete logic is:

  • PredDetailLabel deduplicates the label names and accumulates the data needed for the confusion matrix.

    • buildLabelIndexLabelArray deduplicates the label names, then assigns each label an ID; the final result is a <label, ID> Map.
    • getDetailStatistics iterates over the rows, extracts each item (e.g. "prefix1,{"prefix1": 0.8, "prefix0": 0.2}"), then accumulates the confusion-matrix data via updateBinaryMetricsSummary.
  • Fetch the data from the window by label: statistics = in.select().getDataStream().timeWindowAll().apply(eval);
  • EvaluationUtil.AllDataMerge accumulates the per-window data into totalStatistics.
  • Produce windowOutput -------- EvaluationUtil.SaveDataStream processes the "current data" statistics. The real work is in BinaryMetricsSummary.toMetrics: compute from the bin information, store the results in params, then serialize and return a Row.
    • The extractMatrixThreCurve function takes the non-empty bins and from them computes the ConfusionMatrix array (confusion matrices), the threshold array, and rocCurve/recallPrecisionCurve/LiftChart.
    • Compute and store AUC/PRC/KS from the curves.
    • Sample the generated rocCurve/recallPrecisionCurve/LiftChart output.
    • Store RocCurve/RecallPrecisionCurve/LiftChart based on the sampled output.
    • Store the metrics for the positive samples.
    • Store the LogLoss.
    • Pick the middle point where the threshold is 0.5.
  • Produce allOutput -------- EvaluationUtil.SaveDataStream processes the accumulated totalStatistics.
    • The processing flow is the same as for windowOutput.
  • windowOutput and allOutput are unioned. The final return value is DataStream union = windowOutput.union(allOutput);
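One note on the "compute and store AUC/PRC/KS from the curves" step above: AUC is conventionally the trapezoid-rule area under the ROC curve, and K-S is the maximum vertical gap between the TPR and FPR curves. The sketch below applies both textbook formulas to the RocCurve arrays printed by the batch example at the top of this article; whether Alink's internal bin-based implementation sums the segments in exactly this way is an assumption.

```java
public class CurveMetricsSketch {
    // Trapezoid-rule area under an ROC curve given as parallel FPR/TPR arrays.
    static double auc(double[] fpr, double[] tpr) {
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
        }
        return area;
    }

    // K-S statistic: the maximum vertical gap between the TPR and FPR curves.
    static double ks(double[] fpr, double[] tpr) {
        double max = 0.0;
        for (int i = 0; i < fpr.length; i++) {
            max = Math.max(max, tpr[i] - fpr[i]);
        }
        return max;
    }

    public static void main(String[] args) {
        // The RocCurve arrays from the batch example output.
        double[] fpr = {0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0};
        double[] tpr = {0.0, 1.0 / 3, 2.0 / 3, 2.0 / 3, 1.0, 1.0, 1.0};
        System.out.printf("AUC = %.4f%n", auc(fpr, tpr)); // AUC = 0.8333
        System.out.printf("KS = %.4f%n", ks(fpr, tpr));   // KS = 0.6667
    }
}
```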

4.2.1 PredDetailLabel

static class PredDetailLabel implements AllWindowFunction<Row, BaseMetricsSummary, TimeWindow> {
    @Override
    public void apply(TimeWindow timeWindow, Iterable<Row> rows, Collector<BaseMetricsSummary> collector) throws Exception {
        HashSet<String> labels = new HashSet<>();
        // First, collect the label names as before
        for (Row row : rows) {
            if (EvaluationUtil.checkRowFieldNotNull(row)) {
                labels.addAll(EvaluationUtil.extractLabelProbMap(row).keySet());
                labels.add(row.getField(0).toString());
            }
        }
        // labels = {HashSet@9757} size = 2
        //   0 = "prefix1"
        //   1 = "prefix0"
        // As introduced before, buildLabelIndexLabelArray deduplicates the label names and assigns each label an ID,
        // yielding a <label, ID> Map.
        // getDetailStatistics iterates over the rows and accumulates the confusion-matrix data ("TP + FN" / "TN + FP").
        if (labels.size() > 0) {
            collector.collect(
                getDetailStatistics(rows, binary, buildLabelIndexLabelArray(labels, binary, positiveValue)));
        }
    }
}
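The source of buildLabelIndexLabelArray is not shown in this article. As a hypothetical sketch of what the description says it does (deduplicate the label names and assign each an ID, with the positive label mapped to index 0, which matches the matrices in the batch section), one could write the following; this is for illustration only and is not the Alink implementation:

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

public class LabelIndexSketch {
    // The positive label gets index 0; the remaining labels get the following IDs.
    static Map<String, Integer> buildLabelIndex(Set<String> labels, String positiveValue) {
        Map<String, Integer> map = new LinkedHashMap<>();
        map.put(positiveValue, 0);
        int next = 1;
        for (String label : labels) {
            if (!label.equals(positiveValue)) {
                map.put(label, next++);
            }
        }
        return map;
    }

    public static void main(String[] args) {
        Set<String> labels = new TreeSet<>();
        labels.add("prefix1");
        labels.add("prefix0");
        System.out.println(buildLabelIndex(labels, "prefix1")); // {prefix1=0, prefix0=1}
    }
}
```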

4.2.2 AllDataMerge

EvaluationUtil.AllDataMerge accumulates the data from the individual windows.

/**
 * Merge data from different windows.
 */
public static class AllDataMerge implements MapFunction<BaseMetricsSummary, BaseMetricsSummary> {
    private BaseMetricsSummary statistics;

    @Override
    public BaseMetricsSummary map(BaseMetricsSummary value) {
        this.statistics = (null == this.statistics ? value : this.statistics.merge(value));
        return this.statistics;
    }
}
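Stripped of Flink, AllDataMerge is just a stateful fold: each per-window summary is merged into a running total, and the running total is emitted every time. Below is a minimal sketch with a made-up Summary class (a single TP counter standing in for BaseMetricsSummary):

```java
import java.util.ArrayList;
import java.util.List;

public class AllDataMergeSketch {
    // Hypothetical stand-in for BaseMetricsSummary: just one mergeable counter.
    static class Summary {
        long tp;
        Summary(long tp) { this.tp = tp; }
        Summary merge(Summary other) { return new Summary(this.tp + other.tp); }
    }

    // Same pattern as EvaluationUtil.AllDataMerge: fold each value into the state.
    static class AllDataMerge {
        private Summary statistics;
        Summary map(Summary value) {
            this.statistics = (null == this.statistics) ? value : this.statistics.merge(value);
            return this.statistics;
        }
    }

    public static void main(String[] args) {
        AllDataMerge merge = new AllDataMerge();
        List<Long> cumulative = new ArrayList<>();
        for (long windowTp : new long[]{3, 2, 4}) { // three windows arriving in order
            cumulative.add(merge.map(new Summary(windowTp)).tp);
        }
        System.out.println(cumulative); // [3, 5, 9]
    }
}
```

This is also why linkFrom sets the parallelism of this map to 1: the running total lives in a single operator instance, so the merge must not run in parallel.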

4.2.3 SaveDataStream

The functions SaveDataStream calls were already covered in the batch-processing section: the real work is in BinaryMetricsSummary.toMetrics, which computes from the bin information and stores the results in params.

The difference from batch processing is that the constructed metrics are returned to the user directly.

public static class SaveDataStream implements MapFunction<BaseMetricsSummary, Row> {
    @Override
    public Row map(BaseMetricsSummary baseMetricsSummary) throws Exception {
        BaseMetricsSummary metrics = baseMetricsSummary;
        BaseMetrics baseMetrics = metrics.toMetrics();
        Row row = baseMetrics.serialize();
        return Row.of(funtionName, row.getField(0));
    }
}

// The resulting row is the metrics information finally returned to the user:
row = {Row@10008} "{"PRC":"0.9164636268708667","SensitivityArray":"[0.38461538461538464,0.6923076923076923,0.6923076923076923,1.0,1.0,1.0]","ConfusionMatrix":"[[13,8],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.0,0.5,0.5,1.0,1.0]" ...... (many more fields follow)

4.2.4 Union

DataStream<Row> windowOutput = statistics.map(
    new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.WINDOW.f0));
DataStream<Row> allOutput = totalStatistics.map(
    new EvaluationUtil.SaveDataStream(ClassificationEvaluationUtil.ALL.f0));
DataStream<Row> union = windowOutput.union(allOutput);

Finally, the two kinds of statistics are returned:

4.2.4.1 allOutput
all|{"PRC":"0.7341146115890359","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","ConfusionMatrix":"[[13,10],[2,0]]","MacroRecall":"0.43333333333333335","MacroSpecificity":"0.43333333333333335","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,0.7333333333333333,0.8,0.8,0.8666666666666667,0.8666666666666667,0.9333333333333333,1.0]","AUC":"0.5666666666666667","MacroAccuracy":"0.52", ......

4.2.4.2 windowOutput

window|{"PRC":"0.7638888888888888","SensitivityArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","ConfusionMatrix":"[[3,2],[0,0]]","MacroRecall":"0.5","MacroSpecificity":"0.5","FalsePositiveRateArray":"[0.0,0.5,0.5,0.5,1.0,1.0]","TruePositiveRateArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","AUC":"0.6666666666666666","MacroAccuracy":"0.6","RecallArray":"[0.3333333333333333,0.3333333333333333,0.6666666666666666,1.0,1.0,1.0]","KappaArray":"[0.28571428571428564,-0.15384615384615377,0.1666666666666666,0.5454545454545455,0.0,0.0]","MicroFalseNegativeRate":"0.4","WeightedRecall":"0.6","WeightedPrecision":"0.36","Recall":"1.0","MacroPrecision":"0.3",......

0xFF References

[白话解析] 通过实例来梳理概念 :准确率 (Accuracy)、精准率(Precision)、召回率(Recall) 和 F值(F-Measure)
