统计学习方法 | 第3章 k邻近法
第3章 k近邻法
1.近邻法是基本且简单的分类与回归方法。近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的个最近邻训练实例点,然后利用这个训练实例点的类的多数来预测输入实例点的类。
2.近邻模型对应于基于训练数据集对特征空间的一个划分。近邻法中,当训练集、距离度量、值及分类决策规则确定后,其结果唯一确定。
3.近邻法三要素:距离度量、值的选择和分类决策规则。常用的距离度量是欧氏距离及更一般的pL距离。值小时,近邻模型更复杂;值大时,近邻模型更简单。值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的。
常用的分类决策规则是多数表决,对应于经验风险最小化。
4.近邻法的实现需要考虑如何快速搜索k个最近邻点。kd树是一种便于对k维空间中的数据进行快速检索的数据结构。kd树是二叉树,表示对维空间的一个划分,其每个结点对应于维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。
距离度量
设特征空间是维实数向量空间 ,,, ,则:,的距离定义为:
- 曼哈顿距离
- 欧氏距离
- 闵式距离minkowski_distance
import math
from itertools import combinations
def L(x, y, p=2):
# x1 = [1, 1], x2 = [5,1]
if len(x) == len(y) and len(x) > 1:
sum = 0
for i in range(len(x)):
sum += math.pow(abs(x[i] - y[i]), p)
return math.pow(sum, 1 / p)
else:
return 0
课本例3.1
x1 = [1, 1]
x2 = [5, 1]
x3 = [4, 4]
# x1, x2
for i in range(1, 5):
r = {'1-{}'.format(c): L(x1, c, p=i) for c in [x2, x3]}
print(min(zip(r.values(), r.keys())))
(4.0, '1-[5, 1]')
(4.0, '1-[5, 1]')
(3.7797631496846193, '1-[4, 4]')
(3.5676213450081633, '1-[4, 4]')
python实现,遍历所有数据点,找出个距离最近的点的分类情况,少数服从多数
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
# data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# data = np.array(df.iloc[:100, [0, 1, -1]])
df
sepal length | sepal width | petal length | petal width | label | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
5 | 5.4 | 3.9 | 1.7 | 0.4 | 0 |
6 | 4.6 | 3.4 | 1.4 | 0.3 | 0 |
7 | 5.0 | 3.4 | 1.5 | 0.2 | 0 |
8 | 4.4 | 2.9 | 1.4 | 0.2 | 0 |
9 | 4.9 | 3.1 | 1.5 | 0.1 | 0 |
10 | 5.4 | 3.7 | 1.5 | 0.2 | 0 |
11 | 4.8 | 3.4 | 1.6 | 0.2 | 0 |
12 | 4.8 | 3.0 | 1.4 | 0.1 | 0 |
13 | 4.3 | 3.0 | 1.1 | 0.1 | 0 |
14 | 5.8 | 4.0 | 1.2 | 0.2 | 0 |
15 | 5.7 | 4.4 | 1.5 | 0.4 | 0 |
16 | 5.4 | 3.9 | 1.3 | 0.4 | 0 |
17 | 5.1 | 3.5 | 1.4 | 0.3 | 0 |
18 | 5.7 | 3.8 | 1.7 | 0.3 | 0 |
19 | 5.1 | 3.8 | 1.5 | 0.3 | 0 |
20 | 5.4 | 3.4 | 1.7 | 0.2 | 0 |
21 | 5.1 | 3.7 | 1.5 | 0.4 | 0 |
22 | 4.6 | 3.6 | 1.0 | 0.2 | 0 |
23 | 5.1 | 3.3 | 1.7 | 0.5 | 0 |
24 | 4.8 | 3.4 | 1.9 | 0.2 | 0 |
25 | 5.0 | 3.0 | 1.6 | 0.2 | 0 |
26 | 5.0 | 3.4 | 1.6 | 0.4 | 0 |
27 | 5.2 | 3.5 | 1.5 | 0.2 | 0 |
28 | 5.2 | 3.4 | 1.4 | 0.2 | 0 |
29 | 4.7 | 3.2 | 1.6 | 0.2 | 0 |
... | ... | ... | ... | ... | ... |
120 | 6.9 | 3.2 | 5.7 | 2.3 | 2 |
121 | 5.6 | 2.8 | 4.9 | 2.0 | 2 |
122 | 7.7 | 2.8 | 6.7 | 2.0 | 2 |
123 | 6.3 | 2.7 | 4.9 | 1.8 | 2 |
124 | 6.7 | 3.3 | 5.7 | 2.1 | 2 |
125 | 7.2 | 3.2 | 6.0 | 1.8 | 2 |
126 | 6.2 | 2.8 | 4.8 | 1.8 | 2 |
127 | 6.1 | 3.0 | 4.9 | 1.8 | 2 |
128 | 6.4 | 2.8 | 5.6 | 2.1 | 2 |
129 | 7.2 | 3.0 | 5.8 | 1.6 | 2 |
130 | 7.4 | 2.8 | 6.1 | 1.9 | 2 |
131 | 7.9 | 3.8 | 6.4 | 2.0 | 2 |
132 | 6.4 | 2.8 | 5.6 | 2.2 | 2 |
133 | 6.3 | 2.8 | 5.1 | 1.5 | 2 |
134 | 6.1 | 2.6 | 5.6 | 1.4 | 2 |
135 | 7.7 | 3.0 | 6.1 | 2.3 | 2 |
136 | 6.3 | 3.4 | 5.6 | 2.4 | 2 |
137 | 6.4 | 3.1 | 5.5 | 1.8 | 2 |
138 | 6.0 | 3.0 | 4.8 | 1.8 | 2 |
139 | 6.9 | 3.1 | 5.4 | 2.1 | 2 |
140 | 6.7 | 3.1 | 5.6 | 2.4 | 2 |
141 | 6.9 | 3.1 | 5.1 | 2.3 | 2 |
142 | 5.8 | 2.7 | 5.1 | 1.9 | 2 |
143 | 6.8 | 3.2 | 5.9 | 2.3 | 2 |
144 | 6.7 | 3.3 | 5.7 | 2.5 | 2 |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
<matplotlib.legend.Legend at 0x2c56f7f64e0>
data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:,:-1], data[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
class KNN:
def __init__(self, X_train, y_train, n_neighbors=3, p=2):
"""
parameter: n_neighbors 临近点个数
parameter: p 距离度量
"""
self.n = n_neighbors
self.p = p
self.X_train = X_train
self.y_train = y_train def predict(self, X):
# 取出n个点
knn_list = []
for i in range(self.n):
dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
knn_list.append((dist, self.y_train[i])) for i in range(self.n, len(self.X_train)):
max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
if knn_list[max_index][0] > dist:
knn_list[max_index] = (dist, self.y_train[i]) # 统计
knn = [k[-1] for k in knn_list]
count_pairs = Counter(knn)
# max_count = sorted(count_pairs, key=lambda x: x)[-1]
max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]
return max_count def score(self, X_test, y_test):
right_count = 0
n = 10
for X, y in zip(X_test, y_test):
label = self.predict(X)
if label == y:
right_count += 1
return right_count / len(X_test)
clf = KNN(X_train, y_train)
clf.score(X_test, y_test)
0.95
test_point = [6.0, 3.0]
print('Test Point: {}'.format(clf.predict(test_point)))
Test Point: 1.0
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.plot(test_point[0], test_point[1], 'bo', label='test_point')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
<matplotlib.legend.Legend at 0x2c56f8b0d68>
scikit-learn实例
from sklearn.neighbors import KNeighborsClassifier
clf_sk = KNeighborsClassifier()
clf_sk.fit(X_train, y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=5, p=2,
weights='uniform')
clf_sk.score(X_test, y_test)
0.95
sklearn.neighbors.KNeighborsClassifier
- n_neighbors: 临近点个数
- p: 距离度量
- algorithm: 近邻算法,可选{'auto', 'ball_tree', 'kd_tree', 'brute'}
- weights: 确定近邻的权重
kd树
kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
kd树是二叉树,表示对维空间的一个划分(partition)。构造kd树相当于不断地用垂直于坐标轴的超平面将维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个维超矩形区域。
构造kd树的方法如下:
构造根结点,使根结点对应于维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对维空间进行切分,生成子结点。在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域 (子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。
通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数 (median)为切分点,这样得到的kd树是平衡的。注意,平衡的kd树搜索时的效率未必是最优的。
构造平衡kd树算法
输入:维空间数据集,
其中 ,;
输出:kd树。
(1)开始:构造根结点,根结点对应于包含的维空间的超矩形区域。
选择为坐标轴,以T中所有实例的坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴垂直的超平面实现。
由根结点生成深度为1的左、右子结点:左子结点对应坐标小于切分点的子区域, 右子结点对应于坐标大于切分点的子区域。
将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为的结点,选择为切分的坐标轴,,以该结点的区域中所有实例的坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴垂直的超平面实现。
由该结点生成深度为的左、右子结点:左子结点对应坐标小于切分点的子区域,右子结点对应坐标大于切分点的子区域。
将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。
# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):
def __init__(self, dom_elt, split, left, right):
self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)
self.split = split # 整数(进行分割维度的序号)
self.left = left # 该结点分割超平面左子空间构成的kd-tree
self.right = right # 该结点分割超平面右子空间构成的kd-tree class KdTree(object):
def __init__(self, data):
k = len(data[0]) # 数据维度 def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNode
if not data_set: # 数据集为空
return None
# key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
# operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
#data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
data_set.sort(key=lambda x: x[split])
split_pos = len(data_set) // 2 # //为Python中的整数除法
median = data_set[split_pos] # 中位数分割点
split_next = (split + 1) % k # cycle coordinates # 递归的创建kd树
return KdNode(
median,
split,
CreateNode(split_next, data_set[:split_pos]), # 创建左子树
CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树 self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点 # KDTree的前序遍历
def preorder(root):
print(root.dom_elt)
if root.left: # 节点不为空
preorder(root.left)
if root.right:
preorder(root.right)
# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
from math import sqrt
from collections import namedtuple # 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple",
"nearest_point nearest_dist nodes_visited") def find_nearest(tree, point):
k = len(point) # 数据维度 def travel(kd_node, target, max_dist):
if kd_node is None:
return result([0] * k, float("inf"),
0) # python中用float("inf")和float("-inf")表示正负无穷 nodes_visited = 1 s = kd_node.split # 进行分割的维度
pivot = kd_node.dom_elt # 进行分割的“轴” if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
nearer_node = kd_node.left # 下一个访问节点为左子树根节点
further_node = kd_node.right # 同时记录下右子树
else: # 目标离右子树更近
nearer_node = kd_node.right # 下一个访问节点为右子树根节点
further_node = kd_node.left temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域 nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
dist = temp1.nearest_dist # 更新最近距离 nodes_visited += temp1.nodes_visited if dist < max_dist:
max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内 temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离
if max_dist < temp_dist: # 判断超球体是否与超平面相交
return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断 #----------------------------------------------------------------------
# 计算目标点与分割点的欧氏距离
temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target))) if temp_dist < dist: # 如果“更近”
nearest = pivot # 更新最近点
dist = temp_dist # 更新最近距离
max_dist = dist # 更新超球体半径 # 检查另一个子结点对应的区域是否有更近的点
temp2 = travel(further_node, target, max_dist) nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
nearest = temp2.nearest_point # 更新最近点
dist = temp2.nearest_dist # 更新最近距离 return result(nearest, dist, nodes_visited) return travel(tree.root, point, float("inf")) # 从根节点开始递归
例3.2
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
preorder(kd.root)
[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]
from time import clock
from random import random # 产生一个k维随机向量,每维分量值在0~1之间
def random_point(k):
return [random() for _ in range(k)] # 产生n个k维随机向量
def random_points(k, n):
return [random_point(k) for _ in range(n)]
ret = find_nearest(kd, [3,4.5])
print (ret)
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
N = 400000
t0 = clock()
kd2 = KdTree(random_points(3, N)) # 构建包含四十万个3维空间样本点的kd树
ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十万个样本点中寻找离目标最近的点
t1 = clock()
print ("time: ",t1-t0, "s")
print (ret2)
time: 5.4623788 s
Result_tuple(nearest_point=[0.09929288205798159, 0.4954936771850429, 0.8005722800665575], nearest_dist=0.004597223680778027, nodes_visited=42)
统计学习方法 | 第3章 k邻近法的更多相关文章
- 统计学习方法 | 第3章 k邻近法 | 补充
namedtuple 不必再通过索引值进行访问,你可以把它看做一个字典通过名字进行访问,只不过其中的值是不能改变的. sorted()适用于任何可迭代容器,list.sort()仅支持list(本身就 ...
- 《统计学习方法》笔记三 k近邻法
本系列笔记内容参考来源为李航<统计学习方法> k近邻是一种基本分类与回归方法,书中只讨论分类情况.输入为实例的特征向量,输出为实例的类别.k值的选择.距离度量及分类决策规则是k近邻法的三个 ...
- 统计学习方法(三)——K近邻法
/*先把标题给写了.这样就能经常提醒自己*/ 1. k近邻算法 k临近算法的过程,即对一个新的样本,找到特征空间中与其最近的k个样本,这k个样本多数属于某个类,就把这个新的样本也归为这个类. 算法 ...
- 《统计学习方法(李航)》讲义 第03章 k近邻法
k 近邻法(k-nearest neighbor,k-NN) 是一种基本分类与回归方法.本书只讨论分类问题中的k近邻法.k近邻法的输入为实例的特征向量,对应于特征空间的点;输出为实例的类别,可以取多类 ...
- 统计学习三:1.k近邻法
全文引用自<统计学习方法>(李航) K近邻算法(k-nearest neighbor, KNN) 是一种非常简单直观的基本分类和回归方法,于1968年由Cover和Hart提出.在本文中, ...
- 统计学习三:2.K近邻法代码实现(以最近邻法为例)
通过上文可知k近邻算法的基本原理,以及算法的具体流程,kd树的生成和搜索算法原理.本文实现了kd树的生成和搜索算法,通过对算法的具体实现,我们可以对算法原理有进一步的了解.具体代码可以在我的githu ...
- 第三章 K近邻法(k-nearest neighbor)
书中存在的一些疑问 kd树的实现过程中,为何选择的切分坐标轴要不断变换?公式如:x(l)=j(modk)+1.有什么好处呢?优点在哪?还有的实现是通过选取方差最大的维度作为划分坐标轴,有何区别? 第一 ...
- 统计学习方法 | 第1章 统计学习方法概论 | Scipy中的Leastsq()
Scipy是一个用于数学.科学.工程领域的常用软件包,可以处理插值.积分.优化.图像处理.常微分方程数值解的求解.信号处理等问题.它用于有效计算Numpy矩阵,使Numpy和Scipy协同工作,高效解 ...
- 统计学习方法——第四章朴素贝叶斯及c++实现
1.名词解释 贝叶斯定理,自己看书,没啥说的,翻译成人话就是,条件A下的bi出现的概率等于A和bi一起出现的概率除以A出现的概率. 记忆方式就是变后验概率为先验概率,或者说,将条件与结果转换. 先验概 ...
随机推荐
- Eclipse 的 CheckStyle 插件
Eclipse 的 CheckStyle 插件 1.简介 Checkstyle 是 SourceForge 下的一个开源项目,提供了一个帮助 JAVA 开发人员遵守某些编码规范的工具.它能进行自动化代 ...
- freemarker页面静态化
1.工程结构 2. Student public class Student { private int id; private String name; private String address ...
- Battle ships HDU - 5093二分匹配
Battle shipsHDU - 5093 题目大意:n*m的地图,*代表海洋,#代表冰山,o代表浮冰,海洋上可以放置船舰,但是每一行每一列只能有一个船舰(类似象棋的車),除非同行或者同列的船舰中间 ...
- 灰度图像--图像增强 非锐化掩蔽 (Unsharpening Mask)
学习DIP第35天 转载请标明本文出处:http://blog.csdn.net/tonyshengtan,欢迎大家转载,发现博客被某些论坛转载后,图像无法正常显示,无法正常表达本人观点,对此表示很不 ...
- Java线程之Callable、Future
简述 在多线程中有时候我们希望一个线程执行完毕后可以返回一些值,在java5中引入了java.util.concurrent.Callable接口,它类似于Runnable接口,但是Callable可 ...
- Nginx一个server配置多个location
在配置文件中增加多个location,每个location对应一个项目 比如使用8066端口,location / 访问官网: location /demo访问培训管理系统配置多个站点我选择了配置多个 ...
- 连接数据库出现The server time zone value '�й���ʱ��' is unrecogni等问题的解决方案
使用JDBC连接数据库出现The server time zone value '�й���ʱ��' is解决方案 ** 将jdbc.properties中url后加入?serverTimezone ...
- HDX Insight Installation & Configuration
NetScaler Insight Center 11.1 Installation & Configuration NetScaler Insight Center 11.0 Insta ...
- GLSL语法入门
变量 GLSL的变量命名方式与C语言类似.变量的名称可以使用字母,数字以及下划线,但变量名不能以数字开头,还有变量名不能以gl_作为前缀,这个是GLSL保留的前缀,用于GLSL的内部变量.当然还有一些 ...
- LC 425. Word Squares 【lock,hard】
Given a set of words (without duplicates), find all word squares you can build from them. A sequence ...