统计学习方法——实现AdaBoost
Adaboost
适用问题:二分类问题
- 模型:加法模型
\]
- 策略:损失函数为指数函数
\]
- 算法:前向分步算法
\]
特点:AdaBoost算法的特点是通过迭代每次学习一个基本分类器。每次迭代中,提高那些被前一轮分类器错误分类数据的权值,而降低那些被正确分类的数据的权值。最后,AdaBoost将基本分类器的线性组合作为强分类器,其中给分类误差率小的基本分类器以大的权值,给分类误差率大的基本分类器以小的权值。
算法步骤:
1)给每个训练样本(\(x_{1},x_{2},….,x_{N}\))分配权重,初始权重\(w_{1}\)均为1/N。
2)针对带有权值的样本进行训练,得到模型\(G_m\)(初始模型为G1)。
3)计算模型\(G_m\)的误分率\(e_m=\sum_{i=1}^Nw_iI(y_i\not= G_m(x_i))\) (误分率应小于0.5,否则将预测结果翻转即可得到误分率小于0.5的分类器)
4)计算模型\(G_m\)的系数\(\alpha_m=0.5\log[(1-e_m)/e_m]\)
5)根据误分率e和当前权重向量\(w_m\)更新权重向量\(w_{m+1}\)。
6)计算组合模型\(f(x)=\sum_{m=1}^M\alpha_mG_m(x_i)\)的误分率。
7)当组合模型的误分率或迭代次数低于一定阈值,停止迭代;否则,回到步骤2)
提升树
提升树是以分类树或回归树为基本分类器的提升方法。提升树被认为是统计学习中最有效的方法之一。
提升方法:将弱可学习算法提升为强可学习算法。提升方法通过反复修改训练数据的权值分布,构建一系列基本分类器(弱分类器),并将这些基本分类器线性组合,构成一个强分类器。AdaBoost算法是提升方法的一个代表。
AdaBoost源码实现
假设弱分类器由 \(x < v\) 或 \(x > v\) 产生,阈值\(v\)使该分类器在训练集上分类误差率最低。
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
def create_data():
iris = load_iris() # 鸢尾花数据集
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前一百个数据,只保留前两个特征
for d in data:
if d[-1] == 0:
d[-1] = -1
return data[:, :2], data[:, -1].astype(np.int)
class AdaBoost:
def __init__(self, num_classifier, increment=0.5):
"""
num_classifier: 弱分类器的数量
increment: 在特征上寻找最优切分点时,搜索时每次的增加值(数据稀疏时建议根据样本点来选择)
"""
self.num_classifier = num_classifier
self.increment = increment
def fit(self, X, Y):
self._init_args(X, Y)
# 逐个训练分类器
for m in range(self.num_classifier):
min_error, v_optimal, preds = float('INF'), None, None
direct_split = None
feature_idx = None # 选定的特征的列索引
# 遍历选择特征和切分点使得分类误差最小
for j in range(self.num_feature):
feature_values = self.X[:, j] # 第j个特征对应的所有取值
_ret = self._get_optimal_split(feature_values)
v_split, _direct_split, error, pred_labels = _ret
if error < min_error:
min_error = error
v_optimal = v_split
preds = pred_labels
direct_split = _direct_split
feature_idx = j
# 计算分类型权重alpha
alpha = self._cal_alpha(min_error)
self.alphas.append(alpha)
# 记录当前分类器G(x)
self.classifiers.append((feature_idx, v_optimal, direct_split))
# 更新样本集合权值分布
self._update_weights(alpha, preds)
def predict(self, x):
res = 0.0
for i in range(len(self.classifiers)):
idx, v, direct = self.classifiers[i]
# 输入弱分类器进行分类
if direct == '>':
output = 1 if x[idx] > v else -1
else: # direct == '<'
output = -1 if x[idx] > v else 1
res += self.alphas[i] * output
return 1 if res > 0 else -1 # sign(res)
def score(self, X_test, Y_test):
cnt = 0
for i, x in enumerate(X_test):
if self.predict(x) == Y_test[i]:
cnt += 1
return cnt / len(X_test)
def _init_args(self, X, Y):
self.X = X
self.Y = Y
self.N, self.num_feature = X.shape # N:样本数,num_feature:特征数量
# 初始时每个样本的权重均相同
self.weights = [1/self.N] * self.N
# 弱分类器集合
self.classifiers = []
# 每个分类器G(x)的权重
self.alphas = []
def _update_weights(self, alpha, pred_labels):
# 计算规范化因子Z
Z = self._cal_norm_factor(alpha, pred_labels)
for i in range(self.N):
self.weights[i] = (self.weights[i] *
np.exp(-1*alpha*self.Y[i]*pred_labels[i]) / Z)
def _cal_alpha(self, error):
return 0.5 * np.log((1-error)/error)
def _cal_norm_factor(self, alpha, pred_labels):
return sum([self.weights[i] * np.exp(-1*alpha*self.Y[i]*pred_labels[i])
for i in range(self.N)])
def _get_optimal_split(self, feature_values):
error = float('INF') # 分类误差
pred_labels = [] # 分类结果
v_split_optimal = None # 当前特征的最优切割点
direct_split = None # 最优切割点的判别方向
max_v = max(feature_values)
min_v = min(feature_values)
num_step = (max_v - min_v + self.increment)/self.increment
for i in range(int(num_step)):
# 选取分割点
v_split = min_v + i * self.increment
judge_direct = '>'
preds = [1 if feature_values[k] > v_split else -1
for k in range(len(feature_values))]
# 错误样本加权误差
weight_error = sum([self.weights[k] for k in range(self.N)
if preds[k] != self.Y[k]])
# 计算分类标签翻转后的误差
preds_inv = [-p for p in preds]
weight_error_inv = sum([self.weights[k] for k in range(self.N)
if preds_inv[k] != self.Y[k]])
# 取较小误差的判别方向作为分类器的判别方向
if weight_error_inv < weight_error:
preds = preds_inv
weight_error = weight_error_inv
judge_direct = '<'
if weight_error < error:
error = weight_error
pred_labels = preds
v_split_optimal = v_split
direct_split = judge_direct
return v_split_optimal, direct_split, error, pred_labels
测试模型准确率:
X, Y = create_data()
res = []
for i in range(10):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
clf = AdaBoost(num_classifier=50)
clf.fit(X_train, Y_train)
res.append(clf.score(X_test, Y_test))
print('My AdaBoost: {}次的平均准确率: {:.3f}'.format(len(res), sum(res)/len(res)))
My AdaBoost: 10次的平均准确率: 0.970
sklearn库的AdaBoost实例
from sklearn.ensemble import AdaBoostClassifier
res = []
for i in range(10):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
clf_sklearn = AdaBoostClassifier(n_estimators=50, learning_rate=0.5)
clf_sklearn.fit(X_train, Y_train)
res.append(clf_sklearn.score(X_test, Y_test))
print('sklearn AdaBoostClassifier: {}次的平均准确率: {:.3f}'.format(
len(res), sum(res)/len(res)))
sklearn AdaBoostClassifier: 10次的平均准确率: 0.945
统计学习方法——实现AdaBoost的更多相关文章
- Adaboost算法的一个简单实现——基于《统计学习方法(李航)》第八章
最近阅读了李航的<统计学习方法(第二版)>,对AdaBoost算法进行了学习. 在第八章的8.1.3小节中,举了一个具体的算法计算实例.美中不足的是书上只给出了数值解,这里用代码将它实现一 ...
- 【NLP】基于统计学习方法角度谈谈CRF(四)
基于统计学习方法角度谈谈CRF 作者:白宁超 2016年8月2日13:59:46 [摘要]:条件随机场用于序列标注,数据分割等自然语言处理中,表现出很好的效果.在中文分词.中文人名识别和歧义消解等任务 ...
- 统计学习方法 --- 感知机模型原理及c++实现
参考博客 Liam Q博客 和李航的<统计学习方法> 感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而 ...
- 统计学习方法笔记--EM算法--三硬币例子补充
本文,意在说明<统计学习方法>第九章EM算法的三硬币例子,公式(9.5-9.6如何而来) 下面是(公式9.5-9.8)的说明, 本人水平有限,怀着分享学习的态度发表此文,欢迎大家批评,交流 ...
- 统计学习方法:KNN
作者:桂. 时间:2017-04-19 21:20:09 链接:http://www.cnblogs.com/xingshansi/p/6736385.html 声明:欢迎被转载,不过记得注明出处哦 ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
- 统计学习方法:核函数(Kernel function)
作者:桂. 时间:2017-04-26 12:17:42 链接:http://www.cnblogs.com/xingshansi/p/6767980.html 前言 之前分析的感知机.主成分分析( ...
- 统计学习方法学习(四)--KNN及kd树的java实现
K近邻法 1基本概念 K近邻法,是一种基本分类和回归规则.根据已有的训练数据集(含有标签),对于新的实例,根据其最近的k个近邻的类别,通过多数表决的方式进行预测. 2模型相关 2.1 距离的度量方式 ...
- 李航《统计学习方法》CH01
CH01 统计学方法概论 前言 章节目录 统计学习 监督学习 基本概念 问题的形式化 统计学习三要素 模型 策略 算法 模型评估与模型选择 训练误差与测试误差 过拟合与模型选择 正则化与交叉验证 正则 ...
随机推荐
- 手把手教你Centos7 部署 gitlab社区版
一.前置说明: 操作系统:Centos 7 物理内存:>=2G 本人亲测,如果安装低版本的gitlab,比如我这里所使用的v8.17.0,物理内存1G,swap 2G虚拟内存即可部署.高版本的所 ...
- 微信小程序:事件绑定
小程序中绑定事件,通过bind关键字来实现.如bindinput,bindtap(绑定点击事件),bindchange等. 什么是事件 事件是视图层到逻辑层的通讯方式. 事件可以将用户的行为反馈到逻辑 ...
- Java基础自学小项目
实现一个基于文本界面的<家庭记账软件> 需求:能够记录家庭的收入,支出,并能够收支明细表 主要涉及一下知识点: - 局部变量和基本数据类型 - 循环语句 - 分支语句 - 方法调用和返回值 ...
- MySQL索引由浅入深
索引是SQL优化中最重要的手段之一,本文从基础到原理,带你深度掌握索引. 一.索引基础 1.什么是索引 MySQL官方对索引的定义为:索引(Index)是帮助MySQL高效获取数据的数据结构,索引对于 ...
- 解决springboot项目打成jar包部署到linux服务器后上传图片无法访问的问题
前言:目前大三,自己也在学习和摸索的阶段.在和学校的同学一起做前后端分离项目的时候,我们发现将后端打包成jar,然后部署到服务器中通过java -jar xxx.jar运行项目以后,项目中存在文件上传 ...
- POJ-3080(KMP+多个字符串的最长公共子串)
Blue Jeans HDOJ-3080 本题使用的是KMP算法加暴力解决 首先枚举第一个字符串的所有子串,复杂度为O(60*60),随后再将每个子串和所有剩下的m-1个字符串比较,看是否存在这个子串 ...
- 【粉丝问答10】C语言关键字static的使用详解
视频地址:https://www.ixigua.com/6935761378816819748 粉丝提问 粉丝问题,总结一下: 关键字static的使用方法. 要想搞清楚关键字static的使用方法, ...
- Hi3559AV100外接UVC/MJPEG相机实时采图设计(四):VDEC_Send_Stream线程分析
下面随笔将对Hi3559AV100外接UVC/MJPEG相机实现实时采图设计的关键点-VDEC_Send_Stream线程进行分析,一两个星期前我写了有三篇系列随笔,已经实现了项目功能,大家可以参考下 ...
- Prometheus时序数据库-数据的插入
Prometheus时序数据库-数据的插入 前言 在之前的文章里,笔者详细的阐述了Prometheus时序数据库在内存和磁盘中的存储结构.有了前面的铺垫,笔者就可以在本篇文章阐述下数据的插入过程. 监 ...
- Apache配置 6. 访问日记切割
日志一直记录总有一天会把整个磁盘占满,所以有必要让它自动切割,并删除老的日志文件 (1)配置 (1)配置 # vim /usr/local/apache2 .4/conf/extra/httpd-vh ...