scikit-learn系列之如何存储和导入机器学习模型

 
如何存储和导入机器学习模型

找到一个准确的机器学习模型,你的项目并没有完成。本文中你将学习如何使用scikit-learn来存储和导入机器学习模型。你可以把你的模型保持到文件中,然后再导入内存进行预测。

1. 用Pickle敲定你的模型

Pickle是python中一种标准的序列化对象的方法。你可以使用pickle操作来序列化你的机器学习算法,保存这种序列化的格式到一个文件中。稍后你可以导入这个文件反序列化你的模型,用它进行新的预测。以下的例子向你展示:如何使用Pima Indians onset of diabetes数据集,训练一个logistic回归模型,保存模型到文件,导入模型对未知数据进行预测。运行以下代码把模型存入你工作路径中的finalized_model.sav,导入模型,用未知数据评估模型的准确率。

  1. # Save Model Using Pickle
  2. import pandas
  3. from sklearn import model_selection
  4. from sklearn.linear_model import LogisticRegression
  5. import pickle
  6. url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"
  7. names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
  8. dataframe = pandas.read_csv(url, names=names)
  9. array = dataframe.values
  10. X = array[:,0:8]
  11. Y = array[:,8]
  12. test_size = 0.33
  13. seed = 7
  14. X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)
  15. # Fit the model on 33%
  16. model = LogisticRegression()
  17. model.fit(X_train, Y_train)
  18. # save the model to disk
  19. filename = 'finalized_model.sav'
  20. pickle.dump(model, open(filename, 'wb'))
  21. # some time later...
  22. # load the model from disk
  23. loaded_model = pickle.load(open(filename, 'rb'))
  24. result = loaded_model.score(X_test, Y_test)
  25. print(result)

2. 用joblib敲定你的模型

Joblib 是SciPy生态的一部分,为管道化python的工作提供的工具。它提供了存储和导入python对象的工具,可以对Numpy数据结构进行有效的利用。这对于要求很多参数和存储整个数据集的算法(比如K-Nearest Neighbors)很有帮助。以下代码向你展示:如何使用Pima Indians onset of diabetes数据集,训练一个logistic回归模型,使用joblib保存模型到文件,导入模型对未知数据进行预测。运行以下代码把模型存入你工作路径中的finalized_model.sav,也会创建一个文件保存Numpy数组,导入模型,用未知数据评估模型的准确率。

  1. # Save Model Using joblib
  2. import pandas
  3. from sklearn import model_selection
  4. from sklearn.linear_model import LogisticRegression
  5. from sklearn.externals import joblib
  6. url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"
  7. names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
  8. dataframe = pandas.read_csv(url, names=names)
  9. array = dataframe.values
  10. X = array[:,0:8]
  11. Y = array[:,8]
  12. test_size = 0.33
  13. seed = 7
  14. X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)
  15. # Fit the model on 33%
  16. model = LogisticRegression()
  17. model.fit(X_train, Y_train)
  18. # save the model to disk
  19. filename = 'finalized_model.sav'
  20. joblib.dump(model, filename)
  21. # some time later...
  22. # load the model from disk
  23. loaded_model = joblib.load(filename)
  24. result = loaded_model.score(X_test, Y_test)
  25. print(result)

3. 保存模型的几点提醒

当你存储你的机器学习模型时,需要考虑以下重要问题。一定要记住,记录下你的工具版本,以便于重构环境。

1. python的版本:记录下python的版本。需要相同大版本号的python来序列化和反序列化模型。
2. 库的版本:主要的库的版本要保持一致,不仅限于Numpy和scikit-learn的版本。
3. 手动序列化:你可能想要手动的输出你的模型参数以便于你可以直接把他们用在scikit-learn或者其他的平台。确实学习算法参数实现比算法本身实现要难得多。如果你有能力也可以自己写代码来导出参数。

4. 知识点:

  1. model_selection.train_test_split
  2. pickle.dump, pickle.load
  3. joblib.dump, joblib.load

原文链接:Save and Load Machine Learning Models in Python with scikit-learn

scikit-learn系列之如何存储和导入机器学习模型的更多相关文章

  1. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  2. (原创)(三)机器学习笔记之Scikit Learn的线性回归模型初探

    一.Scikit Learn中使用estimator三部曲 1. 构造estimator 2. 训练模型:fit 3. 利用模型进行预测:predict 二.模型评价 模型训练好后,度量模型拟合效果的 ...

  3. scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类 (python代码)

    scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups #-*- coding: UTF-8 -*- import ...

  4. (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探

    目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...

  5. 智能合约语言 Solidity 教程系列4 - 数据存储位置分析

    写在前面 Solidity 是以太坊智能合约编程语言,阅读本文前,你应该对以太坊.智能合约有所了解, 如果你还不了解,建议你先看以太坊是什么 这部分的内容官方英文文档讲的不是很透,因此我在参考Soli ...

  6. Scikit Learn

    Scikit Learn Scikit-Learn简称sklearn,基于 Python 语言的,简单高效的数据挖掘和数据分析工具,建立在 NumPy,SciPy 和 matplotlib 上.

  7. bullet物理引擎与OpenGL结合 导入3D模型进行碰撞检测 以及画三角网格的坑

    原文作者:aircraft 原文链接:https://www.cnblogs.com/DOMLX/p/11681069.html 一.初始化世界以及模型 /// 冲突配置包含内存的默认设置,冲突设置. ...

  8. TensorFlow系列专题(二):机器学习基础

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/ ,学习更多的机器学习.深度学习的知识! 目录: 数据预处理 归一化 标准化 离散化 二值化 哑编码 特征 ...

  9. PV3D学习笔记-导入DAE模型

      网上关于PV3D导入DAE模型的例子都非常多,可惜我研究了半天,一个都没成功,或者是破面问题,或者是贴图不显示,再或者贴图乱掉了.今天晚上终于搞定,心得发上来. 制作模型的软件是SketchUp ...

随机推荐

  1. IE不支持 ES6 Promise 对象的解决方案

    * 引入bluebird.js即可完美解决. /*ie兼容 Promise*/ isIE(); function isIE() { //ie? if ( !! window.ActiveXObject ...

  2. springmvc 中配置aop

    之前自己搭建了springmvc+spring+mybaits/hibernate 的框架,并在applicationcontext.xml中配置了aop,但 发现aop根本不生效,而不用框架的话则可 ...

  3. vc++创建Win32 Application窗体过程

    #include<windows.h>#include<stdio.h>LRESULT CALLBACK WinSunProc( HWND hwnd, UINT uMsg, W ...

  4. [SCOI2008]奖励关_状压动归_数学期望

    Code: #include<cstdio> #include<algorithm> using namespace std; const int maxn = 20; dou ...

  5. awk一次性分别赋值多个value给多个变量,速度对比

    方法 #方法1: echo "apple banana orange" | awk '{print $1,$2,$3}' | while read a b c do echo a= ...

  6. C语言提高 (5) 第五天 结构体,结构体对齐 文件

    1昨日回顾 2作业讲解 3 结构体的基本定义 //1 struct teacher { int id; char name[64]; }; struct teacher t5 = { 5, " ...

  7. geohash:用字符串实现附近地点搜索

    转自:http://blog.charlee.li/geohash-intro/ geohash:用字符串实现附近地点搜索 上回说到了用经纬度范围实现附近地点搜索.一些小型应用中这样做没问题,但在大型 ...

  8. vue 动态添加路由 require.context()

    之前的写法 'use strict' import Vue from 'vue' import MessageBroadcast from 'page/MessageBroadcast' import ...

  9. 第一章 JavaScript 简介

    1.1   JavaScript 的简史 JavaScript 诞生于1995年 ,后由 欧洲计算机制造商协会( ECMA,European Computer Manufacturers Associ ...

  10. jenkins 新增节点的3种方式

    1.通过ssh建立节点(在节点机子上要安装好jdk) (1)通过用户+密码建立ssh连接 (2)通过用户+密钥建立连接 2.通过jnlp,javaweb的方式连接 (1)创建好节点 (2)在节点的机子 ...