1. TPOT介绍

一般来讲,创建一个机器学习模型需要经历以下几步:

  • 数据预处理
  • 特征工程
  • 模型选择
  • 超参数调整
  • 模型保存

本文介绍一个基于遗传算法的快速模型选择及调参的方法,TPOT:一种基于Python的自动机器学习开发工具。项目源代码位于:https://github.com/EpistasisLab/tpot

下图是一个机器学习模型开发图,其中灰色部分代表TPOT将要做的事情:即通过利用遗传算法,分析数千种可能的组合,为模型、参数找到最佳的组合,从而自动化机器学习中的模型选择及调参部分。

使用TPOT(版本0.9.5)开发模型需要把握以下几点:

  1. 在使用TPOT进行建模前需要对数据进行必要的清洗和特征工程操作。
  2. TPOT目前只能做有监督学习。
  3. TPOT目前支持的分类器主要有贝叶斯、决策树、集成树、SVM、KNN、线性模型、xgboost。
  4. TPOT目前支持的回归器主要有决策树、集成树、线性模型、xgboost。
  5. TPOT会对输入的数据做进一步处理操作,例如二值化、聚类、降维、标准化、正则化、独热编码操作等。
  6. 根据模型效果,TPOT会对输入特征做特征选择操作,包括基于树模型、基于方差、基于F-值的百分比。
  7. 可以通过export()方法把训练过程导出为形式为sklearn pipeline的.py文件

2. TPOT实现模型训练

下面是一个使用TPOT对MNIST数据集进行模型训练的例子:

# -*- coding: utf-8 -*-
"""
@author: wangkang
@file: start_tpot.py
@time: 2018/11/9 11:21
@desc: TPOT 实践
"""
import time
from tpot import TPOTClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split # 载入数据集
digits = load_digits() X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target,
train_size=0.75, test_size=0.25)
start = time.time() """
generations:运行管道优化过程的迭代次数
population_size:在遗传进化中每一代要保留的个体数量
verbosity: TPOT运行时能传递多少信息
"""
# 使用TPOT初始化分类器模型
tpot = TPOTClassifier(generations=5, population_size=20, verbosity=0) # 模型训练
tpot.fit(X_train, y_train)
print(tpot.score(X_test, y_test))
print('找到最优模型与超参数耗时:', time.time() - start) # 分类器其模型保存为 .py
tpot.export('tpot_mnist_pipeline.py')

运行结果如下所示:

可以观察到,经过5次遗传进化,找到了此范围内得分最高的模型及参数组合!但观察代码耗时发现,在i5-7500 CPU @ 3.40GHz条件下,这5次迭代,共耗时1297 S。

我们可以打开生成的 tpot_mnist_pipeline.py 文件,如下所示:

import numpy as np
import pandas as pd
from sklearn.ensemble import ExtraTreesClassifier, GradientBoostingClassifier
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline, make_union
from tpot.builtins import StackingEstimator """
# NOTE: Make sure that the class is labeled 'target' in the data file
tpot_data = pd.read_csv('PATH/TO/DATA/FILE', sep='COLUMN_SEPARATOR', dtype=np.float64)
features = tpot_data.drop('target', axis=1).values
training_features, testing_features, training_target, testing_target = \
train_test_split(features, tpot_data['target'].values, random_state=None)
""" # 以上代码需修改为下面形式以供正确运行
from sklearn.datasets import load_digits digits = load_digits()
X = digits.data
y = digits.target training_features, testing_features, training_target, testing_target = \
train_test_split(X, y, random_state=None) # 此为由TPOT遗传算法得到的最优模型及参数组合
# Average CV score on the training set was:0.9792963424938936
exported_pipeline = make_pipeline(
PolynomialFeatures(degree=2, include_bias=False, interaction_only=False),
ZeroCount(),
LinearSVC(C=0.5, dual=True, loss="squared_hinge", penalty="l2", tol=0.001)
)
exported_pipeline.fit(training_features, training_target) results = exported_pipeline.predict(testing_features) 

print(results)

可以发现,训练好的模型以pipeline的形式保存(未进行持久化保存)。这样,整个关于MNIST数据集的分类器就训练完成了。

3. 总结

1、通过简单浏览源码发现,TPOT是在sklearn的基础之上做的封装库。其主要封装了sklearn的模型相关模块、processesing模块和feature_selection模块,所以TPOT的主要功能是集中在使用pipeline的方式完成模型的数据预处理、特征选择和模型选择方面。此外,我们还发现了TPOT已经对xgboost进行了支持。

2、虽然TPOT使用遗传算法代替了传统的网格搜索进行超参数选择,但由于默认初始值的随机性,在少量的进化(迭代)次数下,TPOT最终选择的模型往往并不相同。

3、计算效率问题。作者在代码中写道:进化(迭代)次数和每一代保留的个体数量值越多,最终得模型得分会越高。但这同样也会导致耗时很长。

初识TPOT:一个基于Python的自动化机器学习开发工具的更多相关文章

  1. psutil一个基于python的跨平台系统信息跟踪模块

    受益于这个模块的帮助,在这里我推荐一手. https://pythonhosted.org/psutil/#processes psutil是一个基于python的跨平台系统信息监视模块.在pytho ...

  2. 基于python的互联网软件测试开发(自动化测试)-全集合

    基于python的互联网软件测试开发(自动化测试)-全集合 1   关键字 为了便于搜索引擎收录本文,特别将本文的关键字给强调一下: python,互联网,自动化测试,测试开发,接口测试,服务测试,a ...

  3. 《Flask Web开发——基于Python的Web应用开发实践》一字一句上机实践(上)

    目录 前言 第1章 安装 第2章 程序的基本结构 第3章 模板 第4章 Web表单 第5章 数据库 第6章 电子邮件 第7章 大型程序的结构   前言 学习Python也有一个半月时间了,学到现在感觉 ...

  4. 基于Python的Web应用开发实践总结

    基于Python的Web应用开发学习总结 项目地址   本次学习采用的是Flask框架.根据教程开发个人博客系统.博客界面如图所示. 整个学习过程收获很多,以下是学习总结. 1.virtualenv ...

  5. 学习参考《Flask Web开发:基于Python的Web应用开发实战(第2版)》中文PDF+源代码

    在学习python Web开发时,我们会选择使用Django.flask等框架. 在学习flask时,推荐学习看看<Flask Web开发:基于Python的Web应用开发实战(第2版)> ...

  6. 一行导出所有任意微软SQL server数据脚本-基于Python的微软官方mssql-scripter工具使用全讲解

    文章标题: 一行导出所有任意微软SQL serer数据脚本-基于Python的微软官方mssql-scripter工具使用全讲解 关键字 : mssql-scripter,SQL Server 文章分 ...

  7. 转: Orz是一个基于Ogre思想的游戏开发架构

    Orz是一个基于Ogre思想的游戏开发架构,好的结构可以带来更多的功能.Orz和其他的商业以及非商业游戏开发架构不同.Orz更专著于开发者的感受,简化开发者工作.Orz可以用于集成其他Ogre3D之外 ...

  8. 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架 - LinFx

    LinFx 一个基于 .NET Core 2.0 开发的简单易用的快速开发框架,遵循领域驱动设计(DDD)规范约束,提供实现事件驱动.事件回溯.响应式等特性的基础设施.让开发者享受到正真意义的面向对象 ...

  9. 基于Python的WEB接口开发与自动化测试 pdf(内含书签)

    基于Python的WEB接口开发与自动化测试 目录 目 录O V目 录章 Python 学习必知 ................................................... ...

随机推荐

  1. 数据结构入门之链表(C语言实现)

    这篇文章主要是根据<数据结构与算法分析--C语言描述>一书的链表章节内容所写,该书作者给出了链表ADT的一些方法,但是并没有给出所有方法的实现.在学习的过程中将练习的代码记录在文章中,并添 ...

  2. Hibernate Criteria用法大全

    1.标准查询简介 2.比较运算符 3.分页使用标准 4.排序结果 5.预测与聚合 6.关联 7. 动态关联抓取 8.查询示例 9.投影(Projections).聚合(aggregation)和分组( ...

  3. mssql 监控随笔

    性能监控列表: •    Memory: Pages/sec   ( 从硬盘上读取或写入硬盘的页数(参考值:00~20) •    Physical Disk: % Disk time 或 Physi ...

  4. HBase的写事务,MVCC及新的写线程模型

    MVCC是实现高性能数据库的关键技术,主要为了读不影响写.几乎所有数据库系统都用这技术,比如Spanner,看这里.Percolator,看这里.当然还有mysql.本文说HBase的MVCC和0.9 ...

  5. ZooKeeper 集群的安装部署

    0. 说明 ZooKeeper 安装在 s102.s103.s104上,这三个节点同时是 Hadoop 的 DataNode 1. ZooKeeper 本地模式安装配置 1.0 在 s101 上进行安 ...

  6. js fetch处理异步请求

    以往一直认为异步请求只能使用原生js的XMLHttpRequest或jQuery的$.ajax().$.post()等框架封装的异步请求方法 原来js还提供fetch来替代XMLHttpRequest ...

  7. Mysql查询缓存Query_cache的功用

    MySQL的查询缓存并非缓存执行计划,而是查询及其结果集,这就意味着只有相同的查询操作才能命中缓存,因此MySQL的查询缓存命中率很低,另一方面,对于大结果集的查询,其查询结果可以从cache中直接读 ...

  8. beta冲刺————第二天(2/5)

    完善具体内容: 前端: (1)添加了更多设置 (2)点击后出现底栏,分别可以进行字体背景设置.收藏.分享等操作,同时可以看出对文章的排版进行了完善 后端: 对阿里云服务器中的环境进行配置,同时熟悉阿里 ...

  9. PyQt5--GridLayout

    # -*- coding:utf-8 -*- ''' Created on Sep 13, 2018 @author: SaShuangYiBing ''' import sys from PyQt5 ...

  10. ES6中变量解构的用途—遍历Map结构