用PMML实现机器学习模型的跨平台上线
在机器学习用于产品的时候,我们经常会遇到跨平台的问题。比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环境比如Java,为了上一个机器学习模型去大动干戈修改环境配置很不划算,此时我们就可以考虑用预测模型标记语言(Predictive Model Markup Language,以下简称PMML)来实现跨平台的机器学习模型部署了。
1. PMML概述
PMML是数据挖掘的一种通用的规范,它用统一的XML格式来描述我们生成的机器学习模型。这样无论你的模型是sklearn,R还是Spark MLlib生成的,我们都可以将其转化为标准的XML格式来存储。当我们需要将这个PMML的模型用于部署的时候,可以使用目标环境的解析PMML模型的库来加载模型,并做预测。
可以看出,要使用PMML,需要两步的工作,第一块是将离线训练得到的模型转化为PMML模型文件,第二块是将PMML模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。
2. PMML模型的生成和加载相关类库
PMML模型的生成相关的库需要看我们使用的离线训练库。如果我们使用的是sklearn,那么可以使用sklearn2pmml这个python库来做模型文件的生成,这个库安装很简单,使用"pip install sklearn2pmml"即可,相关的使用我们后面会有一个demo。如果使用的是Spark MLlib, 这个库有一些模型已经自带了保存PMML模型的方法,可惜并不全。如果是R,则需要安装包"XML"和“PMML”。此外,JAVA库JPMML可以用来生成R,SparkMLlib,xgBoost,Sklearn的模型对应的PMML文件。github地址是:https://github.com/jpmml/jpmml。
加载PMML模型需要目标环境支持PMML加载的库,如果是JAVA,则可以用JPMML来加载PMML模型文件。相关的使用我们后面会有一个demo。
3. PMML模型生成和加载示例
下面我们给一个示例,使用sklearn生成一个决策树模型,用sklearn2pmml生成模型文件,用JPMML加载模型文件,并做预测。
完整代码参见我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/sklearn-jpmml
首先是用用sklearn生成一个决策树模型,由于我们是需要保存PMML文件,所以最好把模型先放到一个Pipeline数组里面。这个数组里面除了我们的决策树模型以外,还可以有归一化,降维等预处理操作,这里作为一个示例,我们Pipeline数组里面只有决策树模型。代码如下:
- import numpy as np
- import matplotlib.pyplot as plt
- %matplotlib inline
- import pandas as pd
- from sklearn import tree
- from sklearn2pmml.pipeline import PMMLPipeline
- from sklearn2pmml import sklearn2pmml
- import os
- os.environ["PATH"] += os.pathsep + 'C:/Program Files/Java/jdk1.8.0_171/bin'
- X=[[1,2,3,1],[2,4,1,5],[7,8,3,6],[4,8,4,7],[2,5,6,9]]
- y=[0,1,0,2,1]
- pipeline = PMMLPipeline([("classifier", tree.DecisionTreeClassifier(random_state=9))]);
- pipeline.fit(X,y)
- sklearn2pmml(pipeline, ".\demo.pmml", with_repr = True)
上面这段代码做了一个非常简单的决策树分类模型,只有5个训练样本,特征有4个,输出类别有3个。实际应用时,我们需要将模型调参完毕后才将其放入PMMLPipeline进行保存。运行代码后,我们在当前目录会得到一个PMML的XML文件,可以直接打开看,内容大概如下:
- <?xml version="1.0" encoding="UTF-8" standalone="yes"?>
- <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
- <Header>
- <Application name="JPMML-SkLearn" version="1.5.3"/>
- <Timestamp>2018-06-24T05:47:17Z</Timestamp>
- </Header>
- <MiningBuildTask>
- <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
- max_features=None, max_leaf_nodes=None,
- min_impurity_decrease=0.0, min_impurity_split=None,
- min_samples_leaf=1, min_samples_split=2,
- min_weight_fraction_leaf=0.0, presort=False, random_state=9,
- splitter='best'))])</Extension>
- </MiningBuildTask>
- <DataDictionary>
- <DataField name="y" optype="categorical" dataType="integer">
- <Value value="0"/>
- <Value value="1"/>
- <Value value="2"/>
- </DataField>
- <DataField name="x3" optype="continuous" dataType="float"/>
- <DataField name="x4" optype="continuous" dataType="float"/>
- </DataDictionary>
- <TransformationDictionary>
- <DerivedField name="double(x3)" optype="continuous" dataType="double">
- <FieldRef field="x3"/>
- </DerivedField>
- <DerivedField name="double(x4)" optype="continuous" dataType="double">
- <FieldRef field="x4"/>
- </DerivedField>
- </TransformationDictionary>
- <TreeModel functionName="classification" missingValueStrategy="nullPrediction" splitCharacteristic="multiSplit">
- <MiningSchema>
- <MiningField name="y" usageType="target"/>
- <MiningField name="x3"/>
- <MiningField name="x4"/>
- </MiningSchema>
- <Output>
- <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/>
- <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/>
- <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/>
- </Output>
- <Node>
- <True/>
- <Node>
- <SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/>
- <Node score="1" recordCount="1.0">
- <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/>
- <ScoreDistribution value="0" recordCount="0.0"/>
- <ScoreDistribution value="1" recordCount="1.0"/>
- <ScoreDistribution value="2" recordCount="0.0"/>
- </Node>
- <Node score="0" recordCount="2.0">
- <True/>
- <ScoreDistribution value="0" recordCount="2.0"/>
- <ScoreDistribution value="1" recordCount="0.0"/>
- <ScoreDistribution value="2" recordCount="0.0"/>
- </Node>
- </Node>
- <Node score="2" recordCount="1.0">
- <SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/>
- <ScoreDistribution value="0" recordCount="0.0"/>
- <ScoreDistribution value="1" recordCount="0.0"/>
- <ScoreDistribution value="2" recordCount="1.0"/>
- </Node>
- <Node score="1" recordCount="1.0">
- <True/>
- <ScoreDistribution value="0" recordCount="0.0"/>
- <ScoreDistribution value="1" recordCount="1.0"/>
- <ScoreDistribution value="2" recordCount="0.0"/>
- </Node>
- </Node>
- </TreeModel>
- </PMML>
可以看到里面就是决策树模型的树结构节点的各个参数,以及输入值。我们的输入被定义为x1-x4,输出定义为y。
有了PMML模型文件,我们就可以写JAVA代码来读取加载这个模型并做预测了。
我们创建一个Maven或者gradle工程,加入JPMML的依赖,这里给出maven在pom.xml的依赖,gradle的结构是类似的。
- <dependency>
- <groupId>org.jpmml</groupId>
- <artifactId>pmml-evaluator</artifactId>
- <version>1.4.1</version>
- </dependency>
- <dependency>
- <groupId>org.jpmml</groupId>
- <artifactId>pmml-evaluator-extension</artifactId>
- <version>1.4.1</version>
- </dependency>
接着就是读取模型文件并预测的代码了,具体代码如下:
- import org.dmg.pmml.FieldName;
- import org.dmg.pmml.PMML;
- import org.jpmml.evaluator.*;
- import org.xml.sax.SAXException;
- import javax.xml.bind.JAXBException;
- import java.io.FileInputStream;
- import java.io.IOException;
- import java.io.InputStream;
- import java.util.HashMap;
- import java.util.LinkedHashMap;
- import java.util.List;
- import java.util.Map;
- /**
- * Created by 刘建平Pinard on 2018/6/24.
- */
- public class PMMLDemo {
- private Evaluator loadPmml(){
- PMML pmml = new PMML();
- InputStream inputStream = null;
- try {
- inputStream = new FileInputStream("D:/demo.pmml");
- } catch (IOException e) {
- e.printStackTrace();
- }
- if(inputStream == null){
- return null;
- }
- InputStream is = inputStream;
- try {
- pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
- } catch (SAXException e1) {
- e1.printStackTrace();
- } catch (JAXBException e1) {
- e1.printStackTrace();
- }finally {
- //关闭输入流
- try {
- is.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
- Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
- pmml = null;
- return evaluator;
- }
- private int predict(Evaluator evaluator,int a, int b, int c, int d) {
- Map<String, Integer> data = new HashMap<String, Integer>();
- data.put("x1", a);
- data.put("x2", b);
- data.put("x3", c);
- data.put("x4", d);
- List<InputField> inputFields = evaluator.getInputFields();
- //过模型的原始特征,从画像中获取数据,作为模型输入
- Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>();
- for (InputField inputField : inputFields) {
- FieldName inputFieldName = inputField.getName();
- Object rawValue = data.get(inputFieldName.getValue());
- FieldValue inputFieldValue = inputField.prepare(rawValue);
- arguments.put(inputFieldName, inputFieldValue);
- }
- Map<FieldName, ?> results = evaluator.evaluate(arguments);
- List<TargetField> targetFields = evaluator.getTargetFields();
- TargetField targetField = targetFields.get(0);
- FieldName targetFieldName = targetField.getName();
- Object targetFieldValue = results.get(targetFieldName);
- System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue);
- int primitiveValue = -1;
- if (targetFieldValue instanceof Computable) {
- Computable computable = (Computable) targetFieldValue;
- primitiveValue = (Integer)computable.getResult();
- }
- System.out.println(a + " " + b + " " + c + " " + d + ":" + primitiveValue);
- return primitiveValue;
- }
- public static void main(String args[]){
- PMMLDemo demo = new PMMLDemo();
- Evaluator model = demo.loadPmml();
- demo.predict(model,1,8,99,1);
- demo.predict(model,111,89,9,11);
- }
- }
代码里有两个函数,第一个loadPmml是加载模型的,第二个predict是读取预测样本并返回预测值的。我的代码运行结果如下:
target: y value: {result=2, probability_entries=[0=0.0, 1=0.0, 2=1.0], entityId=5, confidence_entries=[]}
1 8 99 1:2
target: y value: {result=1, probability_entries=[0=0.0, 1=1.0, 2=0.0], entityId=6, confidence_entries=[]}
111 89 9 11:1
也就是样本(1,8,99,1)被预测为类别2,而(111,89,9,11)被预测为类别1。
以上就是PMML生成和加载的一个示例,使用起来其实门槛并不高,也很简单。
4. PMML总结与思考
PMML的确是跨平台的利器,但是是不是就没有缺点呢?肯定是有的!
第一个就是PMML为了满足跨平台,牺牲了很多平台独有的优化,所以很多时候我们用算法库自己的保存模型的API得到的模型文件,要比生成的PMML模型文件小很多。同时PMML文件加载速度也比算法库自己独有格式的模型文件加载慢很多。
第二个就是PMML加载得到的模型和算法库自己独有的模型相比,预测会有一点点的偏差,当然这个偏差并不大。比如某一个样本,用sklearn的决策树模型预测为类别1,但是如果我们把这个决策树落盘为一个PMML文件,并用JAVA加载后,继续预测刚才这个样本,有较小的概率出现预测的结果不为类别1.
第三个就是对于超大模型,比如大规模的集成学习模型,比如xgboost, 随机森林,或者tensorflow,生成的PMML文件很容易得到几个G,甚至上T,这时使用PMML文件加载预测速度会非常慢,此时推荐为模型建立一个专有的环境,就没有必要去考虑跨平台了。
此外,对于TensorFlow,不推荐使用PMML的方式来跨平台。可能的方法一是TensorFlow serving,自己搭建预测服务,但是会稍有些复杂。另一个方法就是将模型保存为TensorFlow的模型文件,并用TensorFlow独有的JAVA库加载来做预测。
我们在下一篇会讨论用python+tensorflow训练保存模型,并用tensorflow的JAVA库加载做预测的方法和实例。
(欢迎转载,转载请注明出处。欢迎沟通交流: liujianping-ok@163.com)
用PMML实现机器学习模型的跨平台上线的更多相关文章
- tensorflow机器学习模型的跨平台上线
在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...
- 用PMML实现python机器学习模型的跨平台上线
python信用评分卡(附代码,博主录制) https://study.163.com/course/introduction.htm?courseId=1005214003&utm_camp ...
- 使用pmml实现跨平台部署机器学习模型
一.概述 对于由Python训练的机器学习模型,通常有pickle和pmml两种部署方式,pickle方式用于在python环境中的部署,pmml方式用于跨平台(如Java环境)的部署,本文叙述的 ...
- 使用pmml跨平台部署机器学习模型Demo——房价预测
基于房价数据,在python中训练得到一个线性回归的模型,在JavaWeb中加载模型完成房价预测的功能. 一. 训练.保存模型 工具:PyCharm-2017.Python-39.sklearn2 ...
- PMML辅助机器学习算法上线
在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...
- 使用ML.NET + ASP.NET Core + Docker + Azure Container Instances部署.NET机器学习模型
本文将使用ML.NET创建机器学习分类模型,通过ASP.NET Core Web API公开它,将其打包到Docker容器中,并通过Azure Container Instances将其部署到云中. ...
- GMIS 2017 大会陈雨强演讲:机器学习模型,宽与深的大战
https://blog.csdn.net/starzhou/article/details/72819374 2017-05-27 19:15:36 GMIS 2017 10 0 5 ...
- Python 3 利用 Dlib 19.7 和 sklearn机器学习模型 实现人脸微笑检测
0.引言 利用机器学习的方法训练微笑检测模型,给一张人脸照片,判断是否微笑: 使用的数据集中69张没笑脸,65张有笑脸,训练结果识别精度在95%附近: 效果: 图1 示例效果 工程利用pytho ...
- Python 3 利用机器学习模型 进行手写体数字识别
0.引言 介绍了如何生成数据,提取特征,利用sklearn的几种机器学习模型建模,进行手写体数字1-9识别. 用到的四种模型: 1. LR回归模型,Logistic Regression 2. SGD ...
随机推荐
- CF719E. Sasha and Array [线段树维护矩阵]
CF719E. Sasha and Array 题意: 对长度为 n 的数列进行 m 次操作, 操作为: a[l..r] 每一项都加一个常数 C, 其中 0 ≤ C ≤ 10^9 求 F[a[l]]+ ...
- Do Now 一个让你静心学习的APP——团队博客
Do Now 一个让你静心学习的APP 来自油条只要半根团队的智慧凝聚的产物! 团队博客总目录: 团队作业第一周 团队作业第二周 Do Now -- 团队冲刺博客一 Do-Now-团队Scrum 冲刺 ...
- CoreException: Could not get the value for parameter compilerId for plugin execution default-compile: PluginResolutionException: Plugin org.apache.maven.plugins:maven-compiler-plugin:3.1
今天遇到一个奇怪的问题, 之前写好的代码, 更换环境后, 重新搭建的nexus, maven私服总是报错, 各种clean/update都不管用 原来是没写版本号, 后来加上3.1版本, 还是报错, ...
- MyBatis 缓存机制
Mybatis 有两级缓存: 一级缓存: 也称为本地缓存,SqlSession级别的缓存.一级缓存是一直开启的: 与数据库同一次会话期间查询到的数据会放在本地缓存中,以后如果需要获取相同的数据,直接从 ...
- netcore应用程序部署程序到ubuntu
运维需求:获取服务器的运行情况,是否CPU.内存较高等,上报到运维系统 环境:ubuntu16.04 工具::netcore2.1.supervisor 程序实现(代码就不贴了)参考:https:// ...
- linux configure 应用
linux下configure命令详细介绍 2018年01月11日 15:02:20 冷月霜 阅读数:705 标签: configure 更多 个人分类: 数据库技术 Linux环境下的软件安装, ...
- mysql学习3
1.索引 索引是表的目录,在查找内容之前可以先在目录中查找索引位置,以此快速定位查询数据.对于索引, 会保存在额外的文件中. 作用: 约束 加速查找 1.1.建立索引 a.额外的文件保存特殊的数据结构 ...
- Java后期拓展(一)之Redis
1.NoSQL数据库简介 2.Redis的介绍及安装启动 3.Redis的五大数据类型 4.Redis的相关配置 5.Redis的Java客户端Jedis 6.Redis的事务 7.Redis的持久化 ...
- 高德地图JS API获取经纬度,根据经纬度获取城市
<!DOCTYPE HTML> <html> <head> <meta http-equiv="Content-Type" content ...
- Tomcat线程池配置
简介 线程池作为提高程序处理数据能力的一种方案,应用非常广泛.大量的服务器都或多或少的使用到了线程池技术,不管是用Java还是C++实现,线程池都有如下的特点:线程池一般有三个重要参数: 最大线程数 ...