GDBT 可以解决分类和回归问题

回归问题

  1. def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
  2. subsample=1.0, criterion='friedman_mse', min_samples_split=2,
  3. min_samples_leaf=1, min_weight_fraction_leaf=0.,
  4. max_depth=3, min_impurity_decrease=0.,
  5. min_impurity_split=None, init=None, random_state=None,
  6. max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
  7. warm_start=False, presort='auto')

示例

  1. import numpy as np
  2. from sklearn.metrics import mean_squared_error
  3. from sklearn.datasets import make_friedman1
  4. from sklearn.ensemble import GradientBoostingRegressor
  5.  
  6. X, y = make_friedman1(n_samples=1200, random_state=0, noise=1.0)
  7. X_train, X_test = X[:200], X[200:]
  8. y_train, y_test = y[:200], y[200:]
  9.  
  10. ### 损失函数
  11. # 如果损失函数为 误差绝对值,L=|y-f(x)|,负梯度为 sign(y-f(x)),即要么1,要么-1,sklearn 中对应为 loss='lad'
  12. # 如果损失函数为 huber,sklearn 中对应为 loss='huber'
  13. # 如果损失函数为 均方误差,sklearn 中对应为 loss='ls'
  14. est = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, loss='huber').fit(X_train, y_train)
  15.  
  16. pred = est.predict(X_test)
  17. error = mean_squared_error(pred, y_test)
  18.  
  19. print(max(y_test), min(y_test)) # (27.214332670044374, 0.8719243023544349)
  20. print(error)
  21. # loss='ls' 5.009154859960321
  22. # loss='lad' 5.817510629608294
  23. # loss='huber' 4.690823542377095

分类问题

  1. def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
  2. subsample=1.0, criterion='friedman_mse', min_samples_split=2,
  3. min_samples_leaf=1, min_weight_fraction_leaf=0.,
  4. max_depth=3, min_impurity_decrease=0.,
  5. min_impurity_split=None, init=None,
  6. random_state=None, max_features=None, verbose=0,
  7. max_leaf_nodes=None, warm_start=False,
  8. presort='auto')

示例

  1. from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
  2. from sklearn.model_selection import GridSearchCV, train_test_split
  3. from sklearn.preprocessing import StandardScaler
  4. from sklearn.metrics import accuracy_score, mean_squared_error
  5. from time import time
  6. import numpy as np
  7. import pandas as pd
  8. import mnist
  9.  
  10. if __name__ == "__main__":
  11. # 读取Mnist数据集, 测试GBDT的分类模型
  12. mnistSet = mnist.loadLecunMnistSet()
  13. train_X, train_Y, test_X, test_Y = mnistSet[0], mnistSet[1], mnistSet[2], mnistSet[3]
  14.  
  15. m, n = np.shape(train_X)
  16. idx = range(m)
  17. np.random.shuffle(idx)
  18.  
  19. # 使用PCA降维
  20. # num = 30000
  21. # pca = PCA(n_components=0.9, whiten=True, random_state=0)
  22. # for i in range(int(np.ceil(1.0 * m / num))):
  23. # minEnd = min((i + 1) * num, m)
  24. # sub_idx = idx[i * num:minEnd]
  25. # train_pca_X = pca.fit_transform(train_X[sub_idx])
  26. # print np.shape(train_pca_X)
  27.  
  28. print "\n**********测试GradientBoostingClassifier类**********"
  29. t = time()
  30. # param_grid1 = {"n_estimators": range(1000, 2001, 100)}
  31. # param_grid2 = {'max_depth': range(30, 71, 10), 'min_samples_split': range(4, 9, 2)}
  32. # param_grid3 = {'min_samples_split': range(4, 9, 2), 'min_samples_leaf': range(3, 12, 2)}
  33. # param_grid4 = {'subsample': np.arange(0.6, 1.0, 0.05)}
  34. # model = GridSearchCV(
  35. # estimator=GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, learning_rate=0.1,
  36. # n_estimators=1800),
  37. # param_grid=param_grid4, cv=3)
  38. # # 拟合训练数据集
  39. # model.fit(train_X, train_Y)
  40. # print "最好的参数是:%s, 此时的得分是:%0.2f" % (model.best_params_, model.best_score_)
  41. model = GradientBoostingClassifier(max_features=90, max_depth=40, min_samples_split=8, min_samples_leaf=3,
  42. n_estimators=1200, learning_rate=0.05, subsample=0.95)
  43. # 拟合训练数据集
  44. model.fit(train_X, train_Y)
  45. # 预测训练集
  46. train_Y_hat = model.predict(train_X[idx])
  47. print "训练集精确度: ", accuracy_score(train_Y[idx], train_Y_hat)
  48. # 预测测试集
  49. test_Y_hat = model.predict(test_X)
  50. print "测试集精确度: ", accuracy_score(test_Y, test_Y_hat)
  51. print "总耗时:", time() - t, "秒"

参考资料:

https://github.com/haidawyl/Mnist  各种模型的用法

sklearn-GDBT的更多相关文章

  1. sklearn的常用函数以及参数

    sklearn可实现的函数或者功能可分为如下几个方面 1.分类算法2.回归算法3.聚类算法4.降维算法5.模型优化6.文本预处理 其中分类算法和回归算法又叫监督学习,聚类算法和降维算法又叫非监督学习 ...

  2. 机器学习之sklearn——EM

    GMM计算更新∑k时,转置符号T应该放在倒数第二项(这样计算出来结果才是一个协方差矩阵) from sklearn.mixture import GMM    GMM中score_samples函数第 ...

  3. 机器学习之sklearn——聚类

    生成数据集方法:sklearn.datasets.make_blobs(n_samples,n_featurs,centers)可以生成数据集,n_samples表示个数,n_features表示特征 ...

  4. 机器学习之sklearn——SVM

    sklearn包对于SVM可输出支持向量,以及其系数和数目: print '支持向量的数目: ', clf.n_support_ print '支持向量的系数: ', clf.dual_coef_ p ...

  5. 使用sklearn做单机特征工程

    目录 1 特征工程是什么?2 数据预处理 2.1 无量纲化 2.1.1 标准化 2.1.2 区间缩放法 2.1.3 标准化与归一化的区别 2.2 对定量特征二值化 2.3 对定性特征哑编码 2.4 缺 ...

  6. 使用sklearn进行集成学习——实践

    系列 <使用sklearn进行集成学习——理论> <使用sklearn进行集成学习——实践> 目录 1 Random Forest和Gradient Tree Boosting ...

  7. 【原】关于使用sklearn进行数据预处理 —— 归一化/标准化/正则化

    一.标准化(Z-Score),或者去除均值和方差缩放 公式为:(X-mean)/std  计算时对每个属性/每列分别进行. 将数据按期属性(按列进行)减去其均值,并处以其方差.得到的结果是,对于每个属 ...

  8. sklearn 增量学习 数据量大

    问题 实际处理和解决机器学习问题过程中,我们会遇到一些"大数据"问题,比如有上百万条数据,上千上万维特征,此时数据存储已经达到10G这种级别.这种情况下,如果还是直接使用传统的方式 ...

  9. 使用sklearn优雅地进行数据挖掘【转】

    目录 1 使用sklearn进行数据挖掘 1.1 数据挖掘的步骤 1.2 数据初貌 1.3 关键技术2 并行处理 2.1 整体并行处理 2.2 部分并行处理3 流水线处理4 自动化调参5 持久化6 回 ...

  10. Sklearn库例子——决策树分类

    Sklearn上关于决策树算法使用的介绍:http://scikit-learn.org/stable/modules/tree.html 1.关于决策树:决策树是一个非参数的监督式学习方法,主要用于 ...

随机推荐

  1. CSP2019游(AFO?)记

    Day 1 不知道为啥一看到\(T1\)就想到\(longlong\)可能存不下,试了下果然. \(T2\)想了半个小时胡出个\(O(n)\)算法,但是假了.冷静了一下,做了前缀和之后,合法的子区间\ ...

  2. 开源!js实现微信/QQ直接跳转到支付宝APP打开口令领红包!附:demo

    最近支付宝的领红包可真是刷爆了各个微信群啊,满群都是支付宝口令. 可是这样推广可不是办法,又要复制又要打开支付宝又要点领取,太麻烦了. 于是乎,提出了一个疑问!是否可以在微信里面点一个链接然后直接打开 ...

  3. 关于PHP内部类的一些总结学习

    前言: 这篇文章主要对一些可以进行反序列化的php内置类的分析总结(膜lemon师傅之前的总结),当然不是所有的php内置类在存在反序列化漏洞时都能够直接利用,有些类不一定能够进行反序列化,php中使 ...

  4. Linux远程连接工具 Shell Xshell6 XFtp6 绿色破解免安装版

    百度云下载链接: https://pan.baidu.com/s/1HMkuxv1yaAM1yhtz09zpfQ 关注以下公众号,回复xshell,获取提取码 关注公众号githubcn,免费获取更多 ...

  5. 将Chrome中的缓存数据移出C盘

    Chrome浏览器会默认的将用户的缓存是数据存放于  C:\Users\你的用户名\AppData\Local\Google\Chrome\User Data文件夹内.用久了之后,就会积攒大量缓存数据 ...

  6. 【转】Microsoft SQL Server 2008 R2 官方简体中文正式版下载(附激活序列号密钥)

    原文: https://www.bensblog.cn/1238.html

  7. ccf 201703-4 地铁修建(95)(并查集)

    ccf 201703-4 地铁修建(95) 使用并查集,将路径按照耗时升序排列,依次加入路径,直到1和n连通,这时加入的最后一条路径,就是所需要修建的时间最长的路径. #include<iost ...

  8. 使用MyBatis的动态SQL表达式时遇到的“坑”(integer)

    现有一项目,ORM框架使用的MyBatis,在进行列表查询时,选择一状态(值为0)通过动态SQL拼接其中条件但无法返回正常的查询结果,随后进行排查. POJO private Integer stat ...

  9. C++输入输出流加速器,关闭同步流,ios::sync_with_stdio(false)和 cin.tie(0)

    leetcode练习时,总会发现运行时间短的代码都会有类似: static int x=[](){ std::ios::sync_with_stdio(false); cin.tie(NULL); ; ...

  10. Python排序搜索基本算法之归并排序实例分析

    Python排序搜索基本算法之归并排序实例分析 本文实例讲述了Python排序搜索基本算法之归并排序.分享给大家供大家参考,具体如下: 归并排序最令人兴奋的特点是:不论输入是什么样的,它对N个元素的序 ...