投票法(voting)是集成学习里面针对分类问题的一种结合策略。基本思想是选择所有机器学习算法当中输出最多的那个类。

分类的机器学习算法输出有两种类型:一种是直接输出类标签,另外一种是输出类概率,使用前者进行投票叫做硬投票(Majority/Hard voting),使用后者进行分类叫做软投票(Soft voting)。 sklearn中的VotingClassifier是投票法的实现。

硬投票

硬投票是选择算法输出最多的标签,如果标签数量相等,那么按照升序的次序进行选择。下面是一个例子:

from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier iris = datasets.load_iris()
X, y = iris.data[:,1:3], iris.target
clf1 = LogisticRegression(random_state=1)
clf2 = RandomForestClassifier(random_state=1)
clf3 = GaussianNB() eclf = VotingClassifier(estimators=[('lr',clf1),('rf',clf2),('gnb',clf3)], voting='hard')
#使用投票法将三个模型结合在以前,estimotor采用 [(name1,clf1),(name2,clf2),...]这样的输入,和Pipeline的输入相同 voting='hard'表示硬投票 for clf, clf_name in zip([clf1, clf2, clf3, eclf],['Logistic Regrsssion', 'Random Forest', 'naive Bayes', 'Ensemble']):
scores = cross_val_score(clf, X, y, cv=5, scoring='accuracy')
print('Accuracy: {:.2f} (+/- {:.2f}) [{}]'.format(scores.mean(), scores.std(), clf_name))

输出结果如下:

Accuracy: 0.90 (+/- 0.05) [Logistic Regrsssion]
Accuracy: 0.93 (+/- 0.05) [Random Forest]
Accuracy: 0.91 (+/- 0.04) [naive Bayes]
Accuracy: 0.95 (+/- 0.05) [Ensemble] 实际当中会报:DeprecationWarning

软投票

软投票是使用各个算法输出的类概率来进行类的选择,输入权重的话,会得到每个类的类概率的加权平均值,值大的类会被选择。

from itertools import product

import numpy as np
import matplotlib.pyplot as plt from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.ensemble import VotingClassifier iris = datasets.load_iris()
X = iris.data[:,[0,2]] #取两列,方便绘图
y = iris.target clf1 = DecisionTreeClassifier(max_depth=4)
clf2 = KNeighborsClassifier(n_neighbors=7)
clf3 = SVC(kernel='rbf', probability=True)
eclf = VotingClassifier(estimators=[('dt',clf1),('knn',clf2),('svc',clf3)], voting='soft', weights=[2,1,1])
#weights控制每个算法的权重, voting=’soft' 使用了软权重 clf1.fit(X,y)
clf2.fit(X,y)
clf3.fit(X,y)
eclf.fit(X,y) x_min, x_max = X[:,0].min() -1, X[:,0].max() + 1
y_min, y_max = X[:,1].min() -1, X[:,1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01)) #创建网格 fig, axes = plt.subplots(2, 2, sharex='col', sharey='row', figsize=(10, 8)) #共享X轴和Y轴 for idx, clf, title in zip(product([0, 1],[0, 1]),
[clf1, clf2, clf3, eclf],
['Decision Tree (depth=4)', 'KNN (k=7)',
'Kernel SVM', 'Soft Voting']):
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #起初我以为是预测的X的值,实际上是预测了上面创建的网格的值,以这些值来进行描绘区域
Z = Z.reshape(xx.shape)
axes[idx[0], idx[1]].contourf(xx, yy, Z, alpha=0.4)
axes[idx[0], idx[1]].scatter(X[:, 0],X[:, 1], c=y, s=20, edgecolor='k')
axes[idx[0], idx[1]].set_title(title)
plt.show()

输出结果如下:

参考:

Voting Classifier

sklearn中的投票法的更多相关文章

  1. 剑指 Offer 39. 数组中出现次数超过一半的数字 + 摩尔投票法

    剑指 Offer 39. 数组中出现次数超过一半的数字 Offer_39 题目描述 方法一:使用map存储数字出现的次数 public class Offer_39 { public int majo ...

  2. LeetCode题解-----Majority Element II 摩尔投票法

    题目描述: Given an integer array of size n, find all elements that appear more than ⌊ n/3 ⌋ times. The a ...

  3. sklearn中LinearRegression使用及源码解读

    sklearn中的LinearRegression 函数原型:class sklearn.linear_model.LinearRegression(fit_intercept=True,normal ...

  4. Leetcode Majority Element系列 摩尔投票法

    先看一题,洛谷2397: 题目背景 自动上次redbag用加法好好的刁难过了yyy同学以后,yyy十分愤怒.他还击给了redbag一题,但是这题他惊讶的发现自己居然也不会,所以只好找你 题目描述 [h ...

  5. sklearn中调用集成学习算法

    1.集成学习是指对于同一个基础数据集使用不同的机器学习算法进行训练,最后结合不同的算法给出的意见进行决策,这个方法兼顾了许多算法的"意见",比较全面,因此在机器学习领域也使用地非常 ...

  6. sklearn中的多项式回归算法

    sklearn中的多项式回归算法 1.多项式回归法多项式回归的思路和线性回归的思路以及优化算法是一致的,它是在线性回归的基础上在原来的数据集维度特征上增加一些另外的多项式特征,使得原始数据集的维度增加 ...

  7. 【笔记】多项式回归的思想以及在sklearn中使用多项式回归和pipeline

    多项式回归以及在sklearn中使用多项式回归和pipeline 多项式回归 线性回归法有一个很大的局限性,就是假设数据背后是存在线性关系的,但是实际上,具有线性关系的数据集是相对来说比较少的,更多时 ...

  8. 机器学习——sklearn中的API

    import matplotlib.pyplot as pltfrom sklearn.svm import SVCfrom sklearn.model_selection import Strati ...

  9. 【Warrior刷题笔记】力扣169. 多数元素 【排序 || 哈希 || 随机算法 || 摩尔投票法】详细注释 不断优化 极致压榨

    题目 来源:力扣(LeetCode) 链接:https://leetcode-cn.com/problems/majority-element/ 注意,该题在LC中被标注为easy,所以我们更多应该关 ...

随机推荐

  1. 一个TCP报文段的数据部分最多为多少个字节,为什么

    IP数据报的最大长度=2^16-1=65535(字节)TCP报文段的数据部分=IP数据报的最大长度-IP数据报的首部-TCP报文段的首部=65535-20-20=65495(字节) 一个tcp报文段的 ...

  2. [MySQL] gap lock/next-key lock浅析

    当InnoDB在判断行锁是否冲突的时候, 除了最基本的IS/IX/S/X锁的冲突判断意外, InnoDB还将锁细分为如下几种子类型: record lock (RK) 记录锁, 仅仅锁住索引记录的一行 ...

  3. Mysql存储之原生语句操作(pymysql)

    Mysql存储之原生语句操作(pymysql) 关系型数据库是基于关系模型的数据库,而关系模型是通过二维表时实现的,于是构成了行列的表结构. 表可以看作是某个实体的集合,而实体之间存在联系,这个就需要 ...

  4. maven工程的建立

    /* 我曾经接触过一个Java web项目,在进行部署时,发现这个项目涉及了maven 没有接触过maven项目的我,发现了如果需要导入maven工程,需要先在eclipse里面对maven进行配置, ...

  5. maven的初步理解

    [情景] 在进行JAVA项目开发的过程中,代码写好后,需要经过编译.打包.运行.测试.部署等过程. 在JAVA项目的开发阶段,就会根据业务的需要引入许多jar包来实现功能,但我们需求的jar包本身可能 ...

  6. hashCode()与equals()区别

    这两个方法均是超类Object自带的成员方法.Object类是所有Java类的祖先.每个类都使用 Object 作为超类.所有对象(包括数组)都实现这个类的方法.在不明确给出超类的情况下,Java会自 ...

  7. tomcat打开gzip、配置utf-8

    在部署描述文件中配置如下内容:(web.xml) 打开gzip compression="on"配置utf-8 URIEncoding="UTF-8" < ...

  8. SP 页面缓存以及清除缓存

    JSP 页面缓存以及清除缓存 一.概述 缓存的思想可以应用在软件分层的各个层面.它是一种内部机制,对外界而言,是不可感知的. 数据库本身有缓存,持久层也可以缓存.(比如:hibernate,还分1级和 ...

  9. Delphi 中的自动释放策略

    来自万一老师的博客:http://www.cnblogs.com/del/archive/2011/12/21/2295794.html ------------------------------- ...

  10. 深入理解 WordPress 数据库中的用户数据 wp_user

    WordPress 使用 wp_users 数据表存储用户的主要数据,该数据表结构类似于wp_posts 和 wp_comments 数据表,存储的是需要经常访问的用户数据,该数据表的结构以及该数据表 ...