之前两篇随笔介绍了kd树的原理,并用python实现了kd树的构建和搜索,具体可以参考

  kd树的原理

  python kd树 搜索 代码

  kd树常与knn算法联系在一起,knn算法通常要搜索k近邻,而不仅仅是最近邻,下面的代码将利用kd树搜索目标点的k个近邻。

  首先还是创建一个类,用于保存结点的值,左右子树,以及用于划分左右子树的切分轴

class decisionnode:
def __init__(self,value=None,col=None,rb=None,lb=None):
self.value=value
self.col=col
self.rb=rb
self.lb=lb

  切分点为坐标轴上的中值,下面代码求得一个序列的中值

def median(x):
n=len(x)
x=list(x)
x_order=sorted(x)
return x_order[n//2],x.index(x_order[n//2])

  然后按照左子树大于切分点,右子树小于切分点的规则构造kd树,其中data是输入的数据

#以j列的中值划分数据,左小右大,j=节点深度%列数
def buildtree(x,j=0):
rb=[]
lb=[]
m,n=x.shape
if m==0: return None
edge,row=median(x[:,j].copy())
for i in range(m):
if x[i][j]>edge:
rb.append(i)
if x[i][j]<edge:
lb.append(i)
rb_x=x[rb,:]
lb_x=x[lb,:]
rightBranch=buildtree(rb_x,(j+1)%n)
leftBranch=buildtree(lb_x,(j+1)%n)
return decisionnode(x[row,:],j,rightBranch,leftBranch)

  接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程大致相同,需要创建一个字典knears,用于存储k近邻的点以及与目标点的距离(欧氏距离)

  搜索的过程为:

  (1)第一步还是遍历树,找到目标点所属区域对应的叶节点

  (2)从叶结点依次向上回退,按照寻找最近邻点的方法回退到父节点,并判断其另一个子节点对区域内是否可能存在k近邻点,具体的,在每个结点上进行以下操作:

    (a)如果字典中的成员个数不足k个,将该结点加入字典

    (b)如果字典中的成员不少于k个,判断该结点与目标结点之间的距离是否不大于字典中各结点所对应距离的的最大值,如果不大于,便将其加入到字典中

    (c)对于父节点来说,如果目标点与其切分轴之间的距离不大于字典中各结点所对应距离的的最大值,便需要访问该父节点的另一个子节点

  (3)每当字典中增加新成员,就按距离值对字典进行降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最大值

  (4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是目标点的k近邻

  代码如下:

#搜索树:输出目标点的近邻点
def traveltree(node,aim):
global pointlist #存储排序后的k近邻点和对应距离
if node==None: return
col=node.col
if aim[col]>node.value[col]:
traveltree(node.rb,aim)
if aim[col]<node.value[col]:
traveltree(node.lb,aim)
dis=dist(node.value,aim)
if len(knears)<k:
knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
elif dis<=pointlist[0][1]:
knears.setdefault(tuple(node.value.tolist()),dis)
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
if node.rb!=None or node.lb!=None:
if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
if aim[node.col]<node.value[node.col]:
traveltree(node.rb,aim)
if aim[node.col]>node.value[node.col]:
traveltree(node.lb,aim)
return pointlist

  完整代码在此处取

 import numpy as np
from numpy import array
class decisionnode:
def __init__(self,value=None,col=None,rb=None,lb=None):
self.value=value
self.col=col
self.rb=rb
self.lb=lb #读取数据并将数据转换为矩阵形式
def readdata(filename):
data=open(filename).readlines()
x=[]
for line in data:
line=line.strip().split('\t')
x_i=[]
for num in line:
num=float(num)
x_i.append(num)
x.append(x_i)
x=array(x)
return x #求序列的中值
def median(x):
n=len(x)
x=list(x)
x_order=sorted(x)
return x_order[n//2],x.index(x_order[n//2]) #以j列的中值划分数据,左小右大,j=节点深度%列数
def buildtree(x,j=0):
rb=[]
lb=[]
m,n=x.shape
if m==0: return None
edge,row=median(x[:,j].copy())
for i in range(m):
if x[i][j]>edge:
rb.append(i)
if x[i][j]<edge:
lb.append(i)
rb_x=x[rb,:]
lb_x=x[lb,:]
rightBranch=buildtree(rb_x,(j+1)%n)
leftBranch=buildtree(lb_x,(j+1)%n)
return decisionnode(x[row,:],j,rightBranch,leftBranch) #搜索树:输出目标点的近邻点
def traveltree(node,aim):
global pointlist #存储排序后的k近邻点和对应距离
if node==None: return
col=node.col
if aim[col]>node.value[col]:
traveltree(node.rb,aim)
if aim[col]<node.value[col]:
traveltree(node.lb,aim)
dis=dist(node.value,aim)
if len(knears)<k:
knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
elif dis<=pointlist[0][1]:
knears.setdefault(tuple(node.value.tolist()),dis)
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
if node.rb!=None or node.lb!=None:
if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
if aim[node.col]<node.value[node.col]:
traveltree(node.rb,aim)
if aim[node.col]>node.value[node.col]:
traveltree(node.lb,aim)
return pointlist def dist(x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5 knears={}
k=int(input('请输入k的值'))
if k<2: print('k不能是1')
global pointlist
pointlist=[]
file=input('请输入数据文件地址')
data=readdata(file)
tree=buildtree(data)
tmp=input('请输入目标点')
tmp=tmp.split(',')
aim=[]
for num in tmp:
num=float(num)
aim.append(num)
aim=tuple(aim)
pointlist=traveltree(tree,aim)
for point in pointlist[-k:]:
print(point)

kdtree

kd树 求k近邻 python 代码的更多相关文章

  1. K近邻python

    有一个带标签的数据集X,标签为y.我们想通过这个数据集预测目标点x0的所属类别. K近邻算法是指在X的特征空间中,把x0放进去,然后找到距离x0最近的K个点.通过这K个点所属类别,一般根据少数服从多数 ...

  2. 统计学习三:2.K近邻法代码实现(以最近邻法为例)

    通过上文可知k近邻算法的基本原理,以及算法的具体流程,kd树的生成和搜索算法原理.本文实现了kd树的生成和搜索算法,通过对算法的具体实现,我们可以对算法原理有进一步的了解.具体代码可以在我的githu ...

  3. K近邻 Python实现 机器学习实战(Machine Learning in Action)

    算法原理 K近邻是机器学习中常见的分类方法之间,也是相对最简单的一种分类方法,属于监督学习范畴.其实K近邻并没有显式的学习过程,它的学习过程就是测试过程.K近邻思想很简单:先给你一个训练数据集D,包括 ...

  4. 机器学习_K近邻Python代码详解

    k近邻优点:精度高.对异常值不敏感.无数据输入假定:k近邻缺点:计算复杂度高.空间复杂度高 import numpy as npimport operatorfrom os import listdi ...

  5. CodeForces - 1093G:Multidimensional Queries (线段树求K维最远点距离)

    题意:给定N个K维的点,Q次操作,或者修改点的坐标:或者问[L,R]这些点中最远的点. 思路:因为最后一定可以表示维+/-(x1-x2)+/-(y1-y2)+/-(z1-z2)..... 所以我们可以 ...

  6. kNN(k近邻)算法代码实现

    目标:预测未知数据(或测试数据)X的分类y 批量kNN算法 1.输入一个待预测的X(一维或多维)给训练数据集,计算出训练集X_train中的每一个样本与其的距离 2.找到前k个距离该数据最近的样本-- ...

  7. k临近法的实现:kd树

    # coding:utf-8 import numpy as np import matplotlib.pyplot as plt T = [[2, 3], [5, 4], [9, 6], [4, 7 ...

  8. k近邻法的C++实现:kd树

    1.k近邻算法的思想 给定一个训练集,对于新的输入实例,在训练集中找到与该实例最近的k个实例,这k个实例中的多数属于某个类,就把该输入实例分为这个类. 因为要找到最近的k个实例,所以计算输入实例与训练 ...

  9. 统计学习方法与Python实现(二)——k近邻法

    统计学习方法与Python实现(二)——k近邻法 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 k近邻法假设给定一个训练数据集,其中的实例类别已定 ...

随机推荐

  1. 如何打开linux内核中dev_dbg的开关

    比如要打开某个驱动中的dev_dbg,那么需要在驱动文件.c中这些行"<linux/device.h>"或者"<linux /platfom_devic ...

  2. JAVA基础补漏--泛型通配符

    泛型通配符只能用于方法的参数 不能用对象定义 public class Test { public static void main(String[] args) { ArrayList<Str ...

  3. python文件打开的几种访问模式

    文件打开的几种访问模式 访问模式 说明 r 以只读方式打开文件.文件的指针将会放在文件的开头.这是默认模式. w 打开一个文件只用于写入.如果该文件已存在则将其覆盖.如果该文件不存在,创建新文件. a ...

  4. LeetCode——Nth Digit

    Question Find the nth digit of the infinite integer sequence 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ... ...

  5. 在Linux Centos 7.2 上安装指定版本Docker。

    相关资料链接: https://docs.docker.com/install/linux/docker-ce/centos/#install-docker-ce 先清空下“历史” yum remov ...

  6. pyspider—爬取下载图片

    以第一ppt网站为例:http://www.1ppt.com/ from pyspider.libs.base_handler import * import urllib2,HTMLParser,r ...

  7. lucene介绍和存储介绍

    全文检索基础 1. Windows系统中的有搜索功能:打开“我的电脑”,按“F3”就可以使用查找的功能,查找指定的文件或文件夹.搜索的范围是整个电脑中的文件资源. 2. 在BBS.BLOG.新闻等系统 ...

  8. MySQL索引底层实现

    一.定义 索引定义:索引(Index)是帮助MySQL高效获取数据的数据结构.本质:索引是数据结构. 二.B-Tree m阶B-Tree满足以下条件: 每个节点至多可以拥有m棵子树. 根节点,只有至少 ...

  9. 解决:make:cc 命令未找到的解决方法

    安装Redis的时候报这个错误 原因:未安装gcc 解决方法:安装gcc 自动安装,包括依赖库[root@VM_220_111_centos redis-3.2.9]# yum -y install ...

  10. 一个不错的JavaScript解析浏览器路径方法

    JavaScript中有时需要用到当前的请求路径等涉及到url的情况,正常情况下我们可以使用location对象来获取我们需要的信息,本文从另外一个途径来解决这个问题,而且更加巧妙 方法如下: fun ...