DecisionTreeClassifier

from sklearn.datasets import load_wine # wine dataset
from sklearn.tree import DecisionTreeClassifier, export_graphviz # decision tree, DOT export for drawing
from sklearn.model_selection import train_test_split # train/test split
import graphviz
import matplotlib.pyplot as plt
# Load the wine dataset
wine = load_wine()
# Split into training and test sets
x_train, x_test, y_train, y_test = train_test_split(wine.data, wine.target, test_size=0.25)
# Instantiate the decision tree
clf = DecisionTreeClassifier(
    criterion="entropy",
    random_state=30,
    splitter="random",
    max_depth=4,
)
clf.fit(x_train, y_train)
score = clf.score(x_test, y_test)
score
0.9333333333333333
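Because train_test_split above is not seeded, this score will vary from run to run. A minimal sketch of a reproducible split, where random_state=420 is just an illustrative choice:

# Fixing random_state makes the split, and hence the score, repeatable
x_train, x_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.25, random_state=420
)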
# Inspect the importance of each feature
feature_names = ['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']
[*zip(feature_names, clf.feature_importances_)]
[('alcohol', 0.2251130582973216),
('malic_acid', 0.0),
('ash', 0.02596756412075755),
('alcalinity_of_ash', 0.0),
('magnesium', 0.0),
('total_phenols', 0.0),
('flavanoids', 0.43464628982715003),
('nonflavanoid_phenols', 0.03292950151904385),
('proanthocyanins', 0.02494017691000391),
('color_intensity', 0.0),
('hue', 0.03635605431269296),
('od280/od315_of_diluted_wines', 0.17795967993642653),
('proline', 0.04208767507660348)]
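To read the importances at a glance, they can be sorted in descending order; a minimal sketch using only the names and model defined above:

# Rank features from most to least important
for name, importance in sorted(
    zip(feature_names, clf.feature_importances_), key=lambda p: p[1], reverse=True
):
    print(f"{name}: {importance:.4f}")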
# Draw the tree
data_dot = export_graphviz(
    clf,
    feature_names=feature_names,
    class_names=["class_0", "class_1", "class_2"],  # the three cultivars in load_wine
    filled=True,
    rounded=True,
)
grap = graphviz.Source(data_dot)
grap
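In a notebook the Source object renders inline; from a plain script the graph has to be written out. A minimal sketch, where the "wine_tree" filename is just an illustrative choice:

# Writes wine_tree.pdf next to the script (requires the Graphviz binaries)
grap.render("wine_tree")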

# Plot how test accuracy varies with max_depth

test = []
for i in range(10):
    clf = DecisionTreeClassifier(
        criterion="entropy", random_state=30, splitter="random", max_depth=i + 1
    )
    clf = clf.fit(x_train, y_train)
    score = clf.score(x_test, y_test)
    test.append(score)
plt.plot(range(1, 11), test, color="red", label="max_depth")
plt.legend()
plt.show()
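max_depth is only one of the pre-pruning knobs; min_samples_leaf (and min_samples_split) can be swept the same way. A minimal sketch under the same setup, where the 1-10 grid is just an illustrative choice:

# Sweep min_samples_leaf with the remaining hyperparameters fixed
scores = []
for leaf in range(1, 11):
    clf = DecisionTreeClassifier(
        criterion="entropy", random_state=30, splitter="random",
        max_depth=4, min_samples_leaf=leaf,
    ).fit(x_train, y_train)
    scores.append(clf.score(x_test, y_test))
plt.plot(range(1, 11), scores, color="blue", label="min_samples_leaf")
plt.legend()
plt.show()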

DecisionTreeRegressor

import pandas as pd # data handling
from sklearn.tree import DecisionTreeRegressor # regression tree
from sklearn.model_selection import cross_val_score # cross-validation
# Load the data
df = pd.read_csv("./data/boston_house_prices.csv")
df.head()

CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT MEDV
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2
# Feature matrix
data = df.iloc[:, :-1]
data

CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33
... ... ... ... ... ... ... ... ... ... ... ... ... ...
501 0.06263 0.0 11.93 0 0.573 6.593 69.1 2.4786 1 273 21.0 391.99 9.67
502 0.04527 0.0 11.93 0 0.573 6.120 76.7 2.2875 1 273 21.0 396.90 9.08
503 0.06076 0.0 11.93 0 0.573 6.976 91.0 2.1675 1 273 21.0 396.90 5.64
504 0.10959 0.0 11.93 0 0.573 6.794 89.3 2.3889 1 273 21.0 393.45 6.48
505 0.04741 0.0 11.93 0 0.573 6.030 80.8 2.5050 1 273 21.0 396.90 7.88

506 rows × 13 columns

# Target values
target = df.iloc[:, -1:]
target

MEDV
0 24.0
1 21.6
2 34.7
3 33.4
4 36.2
... ...
501 22.4
502 20.6
503 23.9
504 22.0
505 11.9

506 rows × 1 columns
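Note that df.iloc[:, -1:] keeps target as a single-column DataFrame; some scikit-learn versions warn about the (n, 1) shape and expect a 1-D y. A minimal sketch of the flattened alternative:

# df.iloc[:, -1] (no trailing colon) returns a 1-D Series instead of a column
target = df.iloc[:, -1]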

# Instantiate the regression tree
clr = DecisionTreeRegressor(random_state=0)
# Score it with 10-fold cross-validation
cross = cross_val_score(clr, data, target, scoring="neg_mean_squared_error", cv=10)
cross
array([-18.08941176, -10.61843137, -16.31843137, -44.97803922,
       -17.12509804, -49.71509804, -12.9986    , -88.4514    ,
       -55.7914    , -25.0816    ])
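scoring="neg_mean_squared_error" returns the negated MSE so that higher is still better; flipping the sign and taking the square root gives an error in the target's own units (thousands of dollars). A minimal sketch, assuming numpy is imported as np:

import numpy as np
# Average RMSE across the 10 folds
rmse = np.sqrt(-cross).mean()
print(rmse)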

Plotting a One-Dimensional Regression

import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
rng = np.random.RandomState(1)
rng
RandomState(MT19937) at 0x7FC5EEAAAF40
# 80 random points in [0, 5), sorted so they can be plotted as a curve
x = np.sort(5 * rng.rand(80, 1), axis=0)
x
array([[5.71874087e-04],
       [9.14413867e-02],
       [9.68347894e-02],
       ...,
       [4.78944765e+00],
       [4.84130788e+00],
       [4.94430544e+00]])
y = np.sin(x).ravel()

# Add noise to every fifth target value (80 / 5 = 16 noisy points)
y[::5] += 3 * (0.5 - rng.rand(16))
y
array([-1.1493464 ,  0.09131401,  0.09668352,  0.13651039,  0.19403525,
       -0.12383814,  0.26365828,  0.41252216,  0.44546446,  0.47215529,
       -0.26319138,  0.49351799,  0.60530013,  0.63450933,  0.64144608,
        1.09900119,  0.66957978,  0.66968122,  0.73574834,  0.75072053,
        1.4926134 ,  0.8363043 ,  0.8532893 ,  0.87144496,  0.97060533,
       -0.20183403,  0.99131122,  0.99472837,  0.99825213,  0.99999325,
        1.21570343,  0.98769965,  0.98591565,  0.9159044 ,  0.91406986,
       -0.51669013,  0.8775346 ,  0.87063055,  0.86993408,  0.86523559,
        0.37007575,  0.78464608,  0.63168655,  0.53722799,  0.45801971,
        0.08075119,  0.43272116,  0.34115328,  0.26769953,  0.20730318,
        1.34959235, -0.17645185, -0.20918837, -0.24990778, -0.28068224,
       -1.63529379, -0.31247075, -0.31458595, -0.32442911, -0.34965155,
       -0.29371122, -0.46921115, -0.56401144, -0.57215326, -0.57488849,
       -0.95586361, -0.75923066, -0.78043659, -0.85808859, -0.94589863,
       -0.6730775 , -0.94870673, -0.97149093, -0.98097408, -0.98568417,
       -0.20828128, -0.99994398, -0.99703245, -0.99170146, -0.9732277 ])
reg1 = DecisionTreeRegressor(max_depth=2)
reg2 = DecisionTreeRegressor(max_depth=5)
reg1.fit(x, y)
reg2.fit(x, y)
DecisionTreeRegressor(max_depth=5)
# Dense prediction grid; np.newaxis turns the 1-D range into an (n, 1) column
x_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
x_test
array([[0.  ],
       [0.01],
       [0.02],
       ...,
       [4.97],
       [4.98],
       [4.99]])
y1 = reg1.predict(x_test)
y2 = reg2.predict(x_test)
plt.figure()
plt.scatter(x, y, s=20, edgecolors="black", c="darkorange", label="data")
plt.plot(x_test, y1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(x_test, y2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regressor")
plt.legend()
plt.show()
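The plot shows the depth-5 tree chasing the injected noise while the depth-2 tree underfits; this can be quantified against the noise-free sine curve. A minimal sketch using the variables defined above:

# Mean squared error of each tree against the clean sin(x) curve
true_y = np.sin(x_test).ravel()
print("max_depth=2:", np.mean((y1 - true_y) ** 2))
print("max_depth=5:", np.mean((y2 - true_y) ** 2))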


