在机器学习用于产品的时候,我们经常会遇到跨平台的问题。比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环境比如Java,为了上一个机器学习模型去大动干戈修改环境配置很不划算,此时我们就可以考虑用预测模型标记语言(Predictive Model Markup Language,以下简称PMML)来实现跨平台的机器学习模型部署了。

1. PMML概述

    PMML是数据挖掘的一种通用的规范,它用统一的XML格式来描述我们生成的机器学习模型。这样无论你的模型是sklearn,R还是Spark MLlib生成的,我们都可以将其转化为标准的XML格式来存储。当我们需要将这个PMML的模型用于部署的时候,可以使用目标环境的解析PMML模型的库来加载模型,并做预测。

    可以看出,要使用PMML,需要两步的工作,第一块是将离线训练得到的模型转化为PMML模型文件,第二块是将PMML模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。

2. PMML模型的生成和加载相关类库

    PMML模型的生成相关的库需要看我们使用的离线训练库。如果我们使用的是sklearn,那么可以使用sklearn2pmml这个python库来做模型文件的生成,这个库安装很简单,使用"pip install sklearn2pmml"即可,相关的使用我们后面会有一个demo。如果使用的是Spark MLlib, 这个库有一些模型已经自带了保存PMML模型的方法,可惜并不全。如果是R,则需要安装包"XML"和“PMML”。此外,JAVA库JPMML可以用来生成R,SparkMLlib,xgBoost,Sklearn的模型对应的PMML文件。

    加载PMML模型需要目标环境支持PMML加载的库,如果是JAVA,则可以用JPMML来加载PMML模型文件。相关的使用我们后面会有一个demo。

3. PMML模型生成和加载示例

    下面我们给一个示例,使用sklearn生成一个决策树模型,用sklearn2pmml生成模型文件,用JPMML加载模型文件,并做预测。

    完整代码https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/sklearn-jpmml

    首先是用用sklearn生成一个决策树模型,由于我们是需要保存PMML文件,所以最好把模型先放到一个Pipeline数组里面。这个数组里面除了我们的决策树模型以外,还可以有归一化,降维等预处理操作,这里作为一个示例,我们Pipeline数组里面只有决策树模型。代码如下:

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
from sklearn import tree
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files/Java/jdk1.8.0_171/bin' X=[[1,2,3,1],[2,4,1,5],[7,8,3,6],[4,8,4,7],[2,5,6,9]]
y=[0,1,0,2,1]
pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))]);
pipeline.fit(X,y) sklearn2pmml(pipeline, ".\demo.pmml", with_repr = True) #

上面代码报错,可下载java端的jar包,jpmml-sklearn-executable
https://github.com/jpmml/jpmml-sklearn/releases

可参考博客:https://www.jianshu.com/p/fcb7256fcbd5修改代码。

上面这段代码做了一个非常简单的决策树分类模型,只有5个训练样本,特征有4个,输出类别有3个。实际应用时,我们需要将模型调参完毕后才将其放入PMMLPipeline进行保存。运行代码后,我们在当前目录会得到一个PMML的XML文件,可以直接打开看,内容大概如下:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
<Header>
<Application name="JPMML-SkLearn" version="1.5.3"/>
<Timestamp>2018-06-24T05:47:17Z</Timestamp>
</Header>
<MiningBuildTask>
<Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=9,
splitter='best'))])</Extension>
</MiningBuildTask>
<DataDictionary>
<DataField name="y" optype="categorical" dataType="integer">
<Value value="0"/>
<Value value="1"/>
<Value value="2"/>
</DataField>
<DataField name="x3" optype="continuous" dataType="float"/>
<DataField name="x4" optype="continuous" dataType="float"/>
</DataDictionary>
<TransformationDictionary>
<DerivedField name="double(x3)" optype="continuous" dataType="double">
<FieldRef field="x3"/>
</DerivedField>
<DerivedField name="double(x4)" optype="continuous" dataType="double">
<FieldRef field="x4"/>
</DerivedField>
</TransformationDictionary>
<TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit">
<MiningSchema>
<MiningField name="y" usageType="target"/>
<MiningField name="x3"/>
<MiningField name="x4"/>
</MiningSchema>
<Output>
<OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
<OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
<OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
</Output>
<Node>
<True/>
<Node>
<SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/>
<Node score="1" recordCount="1.0">
<SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/>
<ScoreDistribution value="0" recordCount="0.0"/>
<ScoreDistribution value="1" recordCount="1.0"/>
<ScoreDistribution value="2" recordCount="0.0"/>
</Node>
<Node score="0" recordCount="2.0">
<True/>
<ScoreDistribution value="0" recordCount="2.0"/>
<ScoreDistribution value="1" recordCount="0.0"/>
<ScoreDistribution value="2" recordCount="0.0"/>
</Node>
</Node>
<Node score="2" recordCount="1.0">
<SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/>
<ScoreDistribution value="0" recordCount="0.0"/>
<ScoreDistribution value="1" recordCount="0.0"/>
<ScoreDistribution value="2" recordCount="1.0"/>
</Node>
<Node score="1" recordCount="1.0">
<True/>
<ScoreDistribution value="0" recordCount="0.0"/>
<ScoreDistribution value="1" recordCount="1.0"/>
<ScoreDistribution value="2" recordCount="0.0"/>
</Node>
</Node>
</TreeModel>
</PMML>

    可以看到里面就是决策树模型的树结构节点的各个参数,以及输入值。我们的输入被定义为x1-x4,输出定义为y。

    有了PMML模型文件,我们就可以写JAVA代码来读取加载这个模型并做预测了。

    我们创建一个Maven或者gradle工程,加入JPMML的依赖,这里给出maven在pom.xml的依赖,gradle的结构是类似的。

    <dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.1</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.1</version>
</dependency>

    接着就是读取模型文件并预测的代码了,具体代码如下:

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException; import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; /**
* Created by 刘建平Pinard on 2018/6/24.
*/
public class PMMLDemo {
private Evaluator loadPmml(){
PMML pmml = new PMML();
InputStream inputStream = null;
try {
inputStream = new FileInputStream("D:/demo.pmml");
} catch (IOException e) {
e.printStackTrace();
}
if(inputStream == null){
return null;
}
InputStream is = inputStream;
try {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
} catch (SAXException e1) {
e1.printStackTrace();
} catch (JAXBException e1) {
e1.printStackTrace();
}finally {
//关闭输入流
try {
is.close();
} catch (IOException e) {
e.printStackTrace();
}
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
pmml = null;
return evaluator;
}
private int predict(Evaluator evaluator,int a, int b, int c, int d) {
Map<String, Integer> data = new HashMap<String, Integer>();
data.put("x1", a);
data.put("x2", b);
data.put("x3", c);
data.put("x4", d);
List<InputField> inputFields = evaluator.getInputFields();
//过模型的原始特征,从画像中获取数据,作为模型输入
Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object rawValue = data.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rawValue);
arguments.put(inputFieldName, inputFieldValue);
} Map<FieldName, ?> results = evaluator.evaluate(arguments);
List<TargetField> targetFields = evaluator.getTargetFields(); TargetField targetField = targetFields.get(0);
FieldName targetFieldName = targetField.getName(); Object targetFieldValue = results.get(targetFieldName);
System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
int primitiveValue = -1;
if (targetFieldValue instanceof Computable) {
Computable computable = (Computable) targetFieldValue;
primitiveValue = (Integer)computable.getResult();
}
System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue);
return primitiveValue;
}
public static void main(String args[]){
PMMLDemo demo = new PMMLDemo();
Evaluator model = demo.loadPmml();
demo.predict(model,1,8,99,1);
demo.predict(model,111,89,9,11); }
}

    代码里有两个函数,第一个loadPmml是加载模型的,第二个predict是读取预测样本并返回预测值的。我的代码运行结果如下:

target: y value: {result=2, probability_entries=[0=0.0, 1=0.0, 2=1.0], entityId=5, confidence_entries=[]}
1 8 99 1:2
target: y value: {result=1, probability_entries=[0=0.0, 1=1.0, 2=0.0], entityId=6, confidence_entries=[]}
111 89 9 11:1

    也就是样本(1,8,99,1)被预测为类别2,而(111,89,9,11)被预测为类别1。

    以上就是PMML生成和加载的一个示例,使用起来其实门槛并不高,也很简单。

4. PMML总结与思考

    PMML的确是跨平台的利器,但是是不是就没有缺点呢?肯定是有的!

    第一个就是PMML为了满足跨平台,牺牲了很多平台独有的优化,所以很多时候我们用算法库自己的保存模型的API得到的模型文件,要比生成的PMML模型文件小很多。同时PMML文件加载速度也比算法库自己独有格式的模型文件加载慢很多。

    第二个就是PMML加载得到的模型和算法库自己独有的模型相比,预测会有一点点的偏差,当然这个偏差并不大。比如某一个样本,用sklearn的决策树模型预测为类别1,但是如果我们把这个决策树落盘为一个PMML文件,并用JAVA加载后,继续预测刚才这个样本,有较小的概率出现预测的结果不为类别1.

    第三个就是对于超大模型,比如大规模的集成学习模型,比如xgboost, 随机森林,或者tensorflow,生成的PMML文件很容易得到几个G,甚至上T,这时使用PMML文件加载预测速度会非常慢,此时推荐为模型建立一个专有的环境,就没有必要去考虑跨平台了。

    此外,对于TensorFlow,不推荐使用PMML的方式来跨平台。可能的方法一是TensorFlow serving,自己搭建预测服务,但是会稍有些复杂。另一个方法就是将模型保存为TensorFlow的模型文件,并用TensorFlow独有的JAVA库加载来做预测。

PMML辅助机器学习算法上线的更多相关文章

  1. 机器学习算法实践:Platt SMO 和遗传算法优化 SVM

    机器学习算法实践:Platt SMO 和遗传算法优化 SVM 之前实现了简单的SMO算法来优化SVM的对偶问题,其中在选取α的时候使用的是两重循环通过完全随机的方式选取,具体的实现参考<机器学习 ...

  2. Python机器学习算法 — 关联规则(Apriori、FP-growth)

    关联规则 -- 简介 关联规则挖掘是一种基于规则的机器学习算法,该算法可以在大数据库中发现感兴趣的关系.它的目的是利用一些度量指标来分辨数据库中存在的强规则.也即是说关联规则挖掘是用于知识发现,而非预 ...

  3. 机器学习算法的基本知识(使用Python和R代码)

    本篇文章是原文的译文,然后自己对其中做了一些修改和添加内容(随机森林和降维算法).文章简洁地介绍了机器学习的主要算法和一些伪代码,对于初学者有很大帮助,是一篇不错的总结文章,后期可以通过文中提到的算法 ...

  4. AI技术原理|机器学习算法

    摘要 机器学习算法分类:监督学习.半监督学习.无监督学习.强化学习 基本的机器学习算法:线性回归.支持向量机(SVM).最近邻居(KNN).逻辑回归.决策树.k平均.随机森林.朴素贝叶斯.降维.梯度增 ...

  5. 机器学习&数据挖掘笔记_16(常见面试之机器学习算法思想简单梳理)

    前言: 找工作时(IT行业),除了常见的软件开发以外,机器学习岗位也可以当作是一个选择,不少计算机方向的研究生都会接触这个,如果你的研究方向是机器学习/数据挖掘之类,且又对其非常感兴趣的话,可以考虑考 ...

  6. 建模分析之机器学习算法(附python&R代码)

    0序 随着移动互联和大数据的拓展越发觉得算法以及模型在设计和开发中的重要性.不管是现在接触比较多的安全产品还是大互联网公司经常提到的人工智能产品(甚至人类2045的的智能拐点时代).都基于算法及建模来 ...

  7. 【R】如何确定最适合数据集的机器学习算法 - 雪晴数据网

          [R]如何确定最适合数据集的机器学习算法 [R]如何确定最适合数据集的机器学习算法 抽查(Spot checking)机器学习算法是指如何找出最适合于给定数据集的算法模型.本文中我将介绍八 ...

  8. 在opencv3中的机器学习算法

    在opencv3.0中,提供了一个ml.cpp的文件,这里面全是机器学习的算法,共提供了这么几种: 1.正态贝叶斯:normal Bayessian classifier    我已在另外一篇博文中介 ...

  9. paper 19 :机器学习算法(简介)

    本来看了一天的分类器方面的代码,乱乱的,索性再把最基础的概念拿过来,现总结一下机器学习的算法吧! 1.机器学习算法简述 按照不同的分类标准,可以把机器学习的算法做不同的分类. 1.1 从机器学习问题角 ...

随机推荐

  1. html页面嵌入php代码不显示内容

    新建一个html文件,内容如下: <html> <head> <title>Example</title> </head> <body ...

  2. python zip用法

    import requests url = "https://magi.com/search" querystring = {"q":"堕却乡&quo ...

  3. Java语言基础(14)

    1 访问控制修饰符(二) 1)public:公共的,可以用来修饰类,属性,构造方法以及方法,被public修饰的类,属性,构造方法以及方法,可以任意的进行访问. 2)private:私有的,可以用来修 ...

  4. for(auto count:counts)

    c++中for(auto count : counts) 这是C++11中的语法,即:Range-based for loop.其中counts应满足:begin(counts), end(count ...

  5. netty-3.客户端与服务端通信

    (原) 第三篇,客户端与服务端通信 以下例子逻辑: 如果客户端连上服务端,服务端控制台就显示,XXX个客户端地址连接上线. 第一个客户端连接成功后,客户端控制台不显示信息,再有其它客户端再连接上线,则 ...

  6. 《Python基础教程》第三章:使用字符串

    find方法可以在一个较长的字符串中查找子字符串.它返回子串所在位置的最左端索引.如果没有找到则返回-1 join方法用来在队列中添加元素,需要添加的队列元素都必须是字符串 >>> ...

  7. 虚拟机vmware的连接方式以及IP端口,协议等概念

    1.NAT虚拟机相当于小弟,宿主机相当于大哥,宿主机虚拟出一个网段供虚拟机上网用 2.Bridge桥接,虚拟机和宿主机相当于局域网中的两台机器 3.Host-Only虚拟机只和宿主机通信,无法上网 3 ...

  8. Hadoop-No.2之标准文件格式

    标准文件格式可以指文本格式,也可以指二进制文件类型.前者包括逗号分隔值(Comma-Separated Value,CSV和可扩展的标记语言文本(Extensible Markup Language. ...

  9. [CCTF] pwn350

    0x00: 之前打了CCTF,在CCTF的过程中遇到一个比较有意思的思路,记录一下. 0x01: 可以看到,这是一个 fmt 的漏洞,不过很简单,接收的输入都在stack中,可以确定输入在栈中的位置, ...

  10. java输出txt文件到桌面

    private static void outputTxt(String ExportFailStudentMsg){ FileSystemView fsv = FileSystemView.getF ...