一、任务基础

导入所需要的库

  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3.  
  4. %matplotlib inline

加载sklearn内置数据集 ,查看数据描述

  1. from sklearn.datasets.california_housing import fetch_california_housing
  2. housing = fetch_california_housing()
  3. print(housing.DESCR)

数据集包含房价价格以及影响房价的一些因素

  1. .. _california_housing_dataset:
  2.  
  3. California Housing dataset
  4. --------------------------
  5.  
  6. **Data Set Characteristics:**
  7.  
  8. :Number of Instances: 20640
  9.  
  10. :Number of Attributes: 8 numeric, predictive attributes and the target
  11.  
  12. :Attribute Information:
  13. - MedInc median income in block
  14. - HouseAge median house age in block
  15. - AveRooms average number of rooms
  16. - AveBedrms average number of bedrooms
  17. - Population block population
  18. - AveOccup average house occupancy
  19. - Latitude house block latitude
  20. - Longitude house block longitude
  21.  
  22. :Missing Attribute Values: None
  23.  
  24. This dataset was obtained from the StatLib repository.
  25. http://lib.stat.cmu.edu/datasets/
  26.  
  27. The target variable is the median house value for California districts.
  28.  
  29. This dataset was derived from the 1990 U.S. census, using one row per census
  30. block group. A block group is the smallest geographical unit for which the U.S.
  31. Census Bureau publishes sample data (a block group typically has a population
  32. of 600 to 3,000 people).
  33.  
  34. It can be downloaded/loaded using the
  35. :func:`sklearn.datasets.fetch_california_housing` function.
  36.  
  37. .. topic:: References
  38.  
  39. - Pace, R. Kelley and Ronald Barry, Sparse Spatial Autoregressions,
  40. Statistics and Probability Letters, 33 (1997) 291-297

查看数据集维度

  1. housing.data.shape  
  1. (20640, 8)

查看第一条数据

  1. housing.data[0]
  1. array([ 8.3252 , 41. , 6.98412698, 1.02380952,
  2. 322. , 2.55555556, 37.88 , -122.23 ])

二、构造决策树模型

决策树模型参数:

(1)criterion gini or entropy   基尼系数或者熵
(2)splitter best or random 前者是在所有特征中找最好的切分点 后者是在部分特征中(数据量大的时候)
(3)max_features: None(所有),log2,sqrt,N。特征小于50的时候一般使用所有的特征
(4)max_depth 数据少或者特征少的时候可以不管这个值,如果模型样本量多,特征也多的情况下,可以尝试限制下这个决策树的深度。可以尝试遍历max_depth找出最佳。(最常用参数之一)
(5)min_samples_split 如果某节点的样本数少于min_samples_split,则不会继续再尝试选择最优特征来进行划分如果样本量不大,不需要管这个值。如果样本量数量级非常大,则推荐增大这个值。(最常用参数之一)
(6)min_samples_leaf 这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝,如果样本量不大,不需要管这个值,大些如10W可是尝试下
(7)min_weight_fraction_leaf 这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝默认是0,就是不考虑权重问题。一般来说,如果我们有较多样本有缺失值,或者分类树样本的分布类别偏差很大,就会引入样本权重,这时我们就要注意这个值了。
(8)max_leaf_nodes 通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。如果加了限制,算法会建立在最大叶子节点数内最优的决策树。如果特征不多,可以不考虑这个值,但是如果特征分成多的话,可以加以限制具体的值可以通过交叉验证得到。
(9)class_weight 指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多导致训练的决策树过于偏向这些类别。这里可以自己指定各个样本的权重如果使用“balanced”,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
(10)min_impurity_split 这个值限制了决策树的增长,如果某节点的不纯度(基尼系数,信息增益,均方差,绝对差)小于这个阈值则该节点不再生成子节点。即为叶子节点 。
(11)n_estimators:要建立树的个数

这些参数都是用来剪枝决策树,防止决策树太过庞大或者出现过拟合的现象。

这里只选择了Longitude(经度)和Latitude(纬度)两个特征来构造决策树模型

  1. from sklearn import tree # 导入指定模块
  2. dtr = tree.DecisionTreeRegressor(max_depth=2) # 决策分类
  3. dtr.fit(housing.data[:, [6, 7]], housing.target) # x,y值

可以看出有些参数只需要保持默认即可

  1. DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
  2. max_leaf_nodes=None, min_impurity_decrease=0.0,
  3. min_impurity_split=None, min_samples_leaf=1,
  4. min_samples_split=2, min_weight_fraction_leaf=0.0,
  5. presort=False, random_state=None, splitter='best')

决策树的好处之一就在于通过可视化显示可以直观的看到构造出来的决策树模型

  1. # 要可视化显示,首先需要安装 graphviz
  2. # https://graphviz.gitlab.io/_pages/Download/Download_windows.html 下载可视化软件
  3. # pip install graphviz
  4.  
  5. # 设置临时环境遍历
  6. import os
  7. os.environ["PATH"] += os.pathsep + 'D:/program files (x86)/Graphviz2.38/bin/' #注意修改你的路径
  8.  
  9. dot_data = tree.export_graphviz(dtr, # 注意这个参数为决策树对象名称
  10. out_file=None,
  11. feature_names=housing.feature_names[6:8], # 还需要指定特征名
  12. filled=True,
  13. impurity=False,
  14. rounded=True)

显示决策树模型

  1. # pip install pydotplus
  2. import pydotplus
  3.  
  4. graph = pydotplus.graph_from_dot_data(dot_data)
  5. graph.get_nodes()[7].set_fillcolor("#FFF2DD")
  6. from IPython.display import Image
  7. Image(graph.create_png())

  

将数据集划分为训练集测试集

  1. from sklearn.model_selection import train_test_split
  2. data_train, data_test, target_train, target_test = train_test_split(
  3. housing.data, housing.target, test_size=0.1, random_state=42)
  4. dtr = tree.DecisionTreeRegressor(random_state=42)
  5. dtr.fit(data_train, target_train)
  6.  
  7. dtr.score(data_test, target_test)

得到精度值

  1. 0.637355881715626

可以看到,精度值不太高的样子。

三、随机森林模型

导入集成算法里面的随机森林模型库,在这个实例里面只是简单的使用下随机森林模型。

  1. from sklearn.ensemble import RandomForestRegressor # Regressor 回归
  2. # random_state就是为了保证程序每次运行都分割一样的训练集和测试集
  3. rfr = RandomForestRegressor( random_state = 42)
  4. rfr.fit(data_train, target_train)
  5. rfr.score(data_test, target_test)

可以看到随机森林模型精度要好点

  1. 0.7910601348350835

GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。相当于循环遍历给出的所有的参数来得到最优的结果,十分的耗时。

  1. # from sklearn.grid_search import GridSearchCV
  2. # sklearn.grid_search模块在0.18版本中被弃用,它所支持的类转移到model_selection模块中。还要注意,
  3. # 新的CV迭代器的接口与这个模块的接口不同。sklearn.grid_search将在0.20中被删除。
  4. from sklearn.model_selection import GridSearchCV
  5. tree_param_grid = {'min_samples_split':list((3,6,9)),'n_estimators':list((10,50,100))}
  6.  
  7. # cv 交叉验证(Cross-validation)的简写 代表进行几次交叉验证
  8. grid = GridSearchCV(RandomForestRegressor(),param_grid=tree_param_grid,cv=5)
  9. grid.fit(data_train,target_train)
  10. # grid_scores_在sklearn0.20版本中已被删除,取而代之的是cv_results_。
  11. grid.cv_results_, grid.best_params_, grid.best_score_

可以看到最优化的参数值为'min_samples_split': 3, 'n_estimators': 100,得到精度为0.80多一点(在输出结果最后一行)

  1. ({'mean_fit_time': array([0.91196742, 4.46895003, 8.89996696, 0.90845881, 4.01207662,
  2. 9.11067271, 0.84911356, 4.16957936, 8.08404155]),
  3. 'std_fit_time': array([0.04628971, 0.19323399, 0.36771072, 0.07048984, 0.05280237,
  4. 0.55379083, 0.0599862 , 0.19719896, 0.34949627]),
  5. 'mean_score_time': array([0.00918159, 0.0467237 , 0.08795581, 0.00958099, 0.03958073,
  6. 0.08624392, 0.01018567, 0.03616033, 0.06846623]),
  7. 'std_score_time': array([0.00367907, 0.00559777, 0.00399863, 0.00047935, 0.00082726,
  8. 0.0135891 , 0.0003934 , 0.0052837 , 0.00697507]),
  9. 'param_min_samples_split': masked_array(data=[3, 3, 3, 6, 6, 6, 9, 9, 9],
  10. mask=[False, False, False, False, False, False, False, False,
  11. False],
  12. fill_value='?',
  13. dtype=object),
  14. 'param_n_estimators': masked_array(data=[10, 50, 100, 10, 50, 100, 10, 50, 100],
  15. mask=[False, False, False, False, False, False, False, False,
  16. False],
  17. fill_value='?',
  18. dtype=object),
  19. 'params': [{'min_samples_split': 3, 'n_estimators': 10},
  20. {'min_samples_split': 3, 'n_estimators': 50},
  21. {'min_samples_split': 3, 'n_estimators': 100},
  22. {'min_samples_split': 6, 'n_estimators': 10},
  23. {'min_samples_split': 6, 'n_estimators': 50},
  24. {'min_samples_split': 6, 'n_estimators': 100},
  25. {'min_samples_split': 9, 'n_estimators': 10},
  26. {'min_samples_split': 9, 'n_estimators': 50},
  27. {'min_samples_split': 9, 'n_estimators': 100}],
  28. 'split0_test_score': array([0.79254741, 0.80793267, 0.81163631, 0.78859073, 0.81211894,
  29. 0.81222231, 0.79241065, 0.80784586, 0.80958409]),
  30. 'split1_test_score': array([0.77856084, 0.80047265, 0.80266101, 0.78474831, 0.79898533,
  31. 0.80203702, 0.77912397, 0.79714354, 0.80029259]),
  32. 'split2_test_score': array([0.78105784, 0.80063538, 0.8052804 , 0.78584898, 0.8036029 ,
  33. 0.80240046, 0.78148243, 0.79955117, 0.80072995]),
  34. 'split3_test_score': array([0.79582001, 0.80687008, 0.8100583 , 0.79947207, 0.80958334,
  35. 0.80851996, 0.78633104, 0.80797192, 0.81129754]),
  36. 'split4_test_score': array([0.79103059, 0.8071791 , 0.81016989, 0.78011578, 0.80719335,
  37. 0.81117408, 0.79282783, 0.8064226 , 0.8085679 ]),
  38. 'mean_test_score': array([0.78780359, 0.80461815, 0.80796138, 0.78775522, 0.80629709,
  39. 0.80727103, 0.7864355 , 0.80378723, 0.8060946 ]),
  40. 'std_test_score': array([0.00675437, 0.00333656, 0.00340769, 0.00646542, 0.00460913,
  41. 0.00429948, 0.00556 , 0.00453896, 0.00464337]),
  42. 'rank_test_score': array([7, 5, 1, 8, 3, 2, 9, 6, 4])},
  43. {'min_samples_split': 3, 'n_estimators': 100},
  44. 0.8079613788142571)

设置上面得到的最优化的参数,构造随机森林模型

  1. rfr = RandomForestRegressor( min_samples_split=3,n_estimators = 100,random_state = 42)
  2. rfr.fit(data_train, target_train)
  3. rfr.score(data_test, target_test)

可以看出精度相比较上面有了提升。 

  1. 0.8088623476993486

四、总结

通过这次案例学习了决策树模型里面的一些参数对于决策树的意义,以及怎么样使用可视化库来比较方便的展示构造出来的决策树模型,最后学习到了如何使用模型调参利器GridSearchCV来得到一般机器学习模型里面最优的参数,达到提升模型精度的目的。

  

机器学习之使用sklearn构造决策树模型的更多相关文章

  1. 深入了解机器学习决策树模型——C4.5算法

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第22篇文章,我们继续决策树的话题. 上一篇文章当中介绍了一种最简单构造决策树的方法--ID3算法,也就是每次选择一个特 ...

  2. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  3. chapter02 三种决策树模型:单一决策树、随机森林、GBDT(梯度提升决策树) 预测泰坦尼克号乘客生还情况

    单一标准的决策树:会根每维特征对预测结果的影响程度进行排序,进而决定不同特征从上至下构建分类节点的顺序.Random Forest Classifier:使用相同的训练样本同时搭建多个独立的分类模型, ...

  4. 机器学习笔记(四)--sklearn数据集

    sklearn数据集 (一)机器学习的一般数据集会划分为两个部分 训练数据:用于训练,构建模型. 测试数据:在模型检验时使用,用于评估模型是否有效. 划分数据的API:sklearn.model_se ...

  5. Spark2.0机器学习系列之3:决策树

    概述 分类决策树模型是一种描述对实例进行分类的树形结构. 决策树可以看为一个if-then规则集合,具有“互斥完备”性质 .决策树基本上都是 采用的是贪心(即非回溯)的算法,自顶向下递归分治构造. 生 ...

  6. 机器学习入门之sklearn介绍

    SKlearn简介 scikit-learn,又写作sklearn,是一个开源的基于python语言的机器学习工具包.它通过NumPy, SciPy和Matplotlib等python数值计算的库实现 ...

  7. sklearn CART决策树分类

    sklearn CART决策树分类 决策树是一种常用的机器学习方法,可以用于分类和回归.同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高. 理论部分 比较经典的决策树是ID3.C ...

  8. 【机器学习】多项式回归sklearn实现

    [机器学习]多项式回归原理介绍 [机器学习]多项式回归python实现 [机器学习]多项式回归sklearn实现 使用sklearn框架实现多项式回归.使用框架更方便,可以少写很多代码. 使用一个简单 ...

  9. 【机器学习笔记】ID3构建决策树

    好多算法之类的,看理论描述,让人似懂非懂,代码走一走,现象就了然了. 引: from sklearn import tree names = ['size', 'scale', 'fruit', 'b ...

随机推荐

  1. 09 audio和vedio标签

    <!DOCTYPE html> <html lang="zh-CN"> <head> <meta charset="UTF-8& ...

  2. 05 div的嵌套

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  3. 写在Logg SAP项目上线之际

    根据大环境大行业的惯用做法,公司建立Logg品牌是在意料之中.毫无意外的,Logg也要上到SAP系统中. 其实按它的业务模式来说上SAP系统并不困难,早在几年前就已经有做过了.无非就是接单不生产,外包 ...

  4. 使用回调的方式实现中间件-laravel

    $app = function ($request) { echo $request . "\n"; return "项目运行中....."; }; // 现在 ...

  5. CI框架中的奇葩

    今天在win下开发,使用ci框架,本来是没有任何问题,然后转向了mac上开发,结果出现了个奇葩的问题,就是在ci框架中,控制器命名以"Admin_"为前缀的,在url中,控制器也必 ...

  6. Spring Bean 生命周期之“我从哪里来?” 懂得这个很重要

    Spring bean 的生命周期很容易理解.实例化 bean 时,可能需要执行一些初始化以使其进入可用 (Ready for Use)状态.类似地,当不再需要 bean 并将其从容器中移除时,可能需 ...

  7. Mac上使用brew update会卡住的问题

    Mac上使用brew update会卡住的问题 brew默认的源是Github,会非常慢,建议换为国内的源.推荐中科大的镜像源,比较全面. 解决方案 Homebrew Homebrew源代码仓库 替换 ...

  8. C# 画箭头

    绘制箭头   1,直接用平台库 Pen arrowPen = new Pen(Color.Blue);            arrowPen.Width = 4;            arrowP ...

  9. Smobiler控件的使用:ListView的数据绑定及实现多选

    环境 SmobilerDesigner 4.7 Visual Studio 2010以上 正文 listview绑定数据 打开Visual Studio ,新建一个SmobilerApplicatio ...

  10. 成功解决 org.mybatis.spring.MyBatisSystemException问题!!

    org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.binding.BindingExce ...