K近邻法

1基本概念

K近邻法,是一种基本分类和回归规则。根据已有的训练数据集(含有标签),对于新的实例,根据其最近的k个近邻的类别,通过多数表决的方式进行预测。

2模型相关

2.1 距离的度量方式

定义距离

(1)欧式距离:p=2。

(2)曼哈顿距离:p=1。

(3)各坐标的最大值:p=∞。

2.2 K值的选择

通常使用交叉验证法来选取最优的k值。

k值大小的影响:

k越小,只有距该点较近的实例才会起作用,学习的近似误差会较小。但此时又会对这些近邻的实例很敏感,如果紧邻点存在噪声,预测就会出错,即学习的估计误差大,泛化能力不好。

K越大,距该点较远的实例也会起作用,造成近似误差增大,使预测发生错误。

2.3 k近邻法的实现:kd树

  Kd树是二叉树。kd树是一种对K维空间中的实例点进行存储以便对其进行快速检索的树形数据结构.

  Kd树是二叉树, 表示对K维空间的一个划分( partition).构造Kd树相 当于不断地用垂直于坐标轴的超平面将k维空间切分, 构成一系列的k维超矩形区域.Kd树的每个结点对应于一个k维超矩形区域

  其中,创建kd树时,垂直于坐标轴的超平面垂直的坐标轴选择是:

  L=(J mod k)+1。其中,j为当前节点的节点深度,k为k维空间(给定实例点的k个维度)。根节点的节点深度为0.此公式可看为:依次循环实例点的k个维所对应的坐标轴。

  Kd树的节点(分割点)为L维上所有实例点的中位数。

2.4 Kd树的实现

  别处代码实现基于其他博客,但是纠正了其中的错误,能够返回前k个近邻。如果要求最近邻,只需要将k=1即可。

  

  1. public class BinaryTreeOrder {
  2.  
  3. public void preOrder(Node root) {
  4. if(root!= null){
  5. System.out.print(root.toString());
  6. preOrder(root.left);
  7. preOrder(root.right);
  8. }
  9. }
  10. }
  1. public class kd_main {
  2.  
  3. public static void main(String[] args) {
  4. List<Node> nodeList=new ArrayList<Node>();
  5.  
  6. nodeList.add(new Node(new double[]{5,4}));
  7. nodeList.add(new Node(new double[]{9,6}));
  8.  
  9. nodeList.add(new Node(new double[]{8,1}));
  10. nodeList.add(new Node(new double[]{7,2}));
  11. nodeList.add(new Node(new double[]{2,3}));
  12. nodeList.add(new Node(new double[]{4,7}));
  13. nodeList.add(new Node(new double[]{4,3}));
  14. nodeList.add(new Node(new double[]{1,3}));
  15.  
  16. kd_main kdTree=new kd_main();
  17. Node root=kdTree.buildKDTree(nodeList,0);
  18. new BinaryTreeOrder().preOrder(root);
  19. for (Node node : nodeList) {
  20. System.out.println(node.toString()+"-->"+node.left.toString()+"-->"+node.right.toString());
  21. }
  22. System.out.println(root);
  23. System.out.println(kdTree.searchKNN(root,new Node(new double[]{2.1,3.1}),2));
  24. System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),1));
  25. System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),3));
  26. System.out.println(kdTree.searchKNN(root,new Node(new double[]{6,1}),5));
  27.  
  28. }
  29.  
  30. /**
  31. * 构建kd树 返回根节点
  32. * @param nodeList
  33. * @param index
  34. * @return
  35. */
  36. public Node buildKDTree(List<Node> nodeList,int index)
  37. {
  38. if(nodeList==null || nodeList.size()==0)
  39. return null;
  40. quickSortForMedian(nodeList,index,0,nodeList.size()-1);//中位数排序
  41. Node root=nodeList.get(nodeList.size()/2);//中位数 当做根节点
  42. root.dim=index;
  43. List<Node> leftNodeList=new ArrayList<Node>();//放入左侧区域的节点 包括包含与中位数等值的节点-_-
  44. List<Node> rightNodeList=new ArrayList<Node>();
  45.  
  46. for(Node node:nodeList)
  47. {
  48. if(root!=node)
  49. {
  50. if(node.getData(index)<=root.getData(index))
  51. leftNodeList.add(node);//左子区域 包含与中位数等值的节点
  52. else
  53. rightNodeList.add(node);
  54. }
  55. }
  56.  
  57. //计算从哪一维度切分
  58. int newIndex=index+1;//进入下一个维度
  59. if(newIndex>=root.data.length)
  60. newIndex=0;//从0维度开始再算
  61.  
  62. root.left=buildKDTree(leftNodeList,newIndex);//添加左右子区域
  63. root.right=buildKDTree(rightNodeList,newIndex);
  64.  
  65. if(root.left!=null)
  66. root.left.parent=root;//添加父指针
  67. if(root.right!=null)
  68. root.right.parent=root;//添加父指针
  69. return root;
  70. }
  71.  
  72. /**
  73. * 查询最近邻
  74. * @param root kd树
  75. * @param q 查询点
  76. * @param k
  77. * @return
  78. */
  79. public List<Node> searchKNN(Node root,Node q,int k)
  80. {
  81. List<Node> knnList=new ArrayList<Node>();
  82. searchBrother(knnList,root,q,k);
  83. return knnList;
  84. }
  85.  
  86. /**
  87. * searhchBrother
  88. * @param knnList
  89. * @param k
  90. * @param q
  91. */
  92. public void searchBrother(List<Node> knnList, Node root, Node q, int k) {
  93. // Node almostNNode=root;//近似最近点
  94. Node leafNNode=searchLeaf(root,q);
  95. double curD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
  96. leafNNode.distance=curD;
  97. maintainMaxHeap(knnList,leafNNode,k);
  98. System.out.println("leaf1"+leafNNode.getData(leafNNode.parent.dim));
  99. while(leafNNode!=root)
  100. {
  101. if (getBrother(leafNNode)!=null) {
  102. Node brother=getBrother(leafNNode);
  103. System.out.println("brother1"+brother.getData(brother.parent.dim));
  104. if(curD>Math.abs(q.getData(leafNNode.parent.dim)-leafNNode.parent.getData(leafNNode.parent.dim))||knnList.size()<k)
  105. {
  106. //这样可能在另一个子区域中存在更加近似的点
  107. searchBrother(knnList,brother, q, k);
  108. }
  109. }
  110. System.out.println("leaf2"+leafNNode.getData(leafNNode.parent.dim));
  111. leafNNode=leafNNode.parent;//返回上一级
  112. double rootD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
  113. leafNNode.distance=rootD;
  114. maintainMaxHeap(knnList,leafNNode,k);
  115. }
  116. }
  117.  
  118. /**
  119. * 获取兄弟节点
  120. * @param node
  121. * @return
  122. */
  123. public Node getBrother(Node node)
  124. {
  125. if(node==node.parent.left)
  126. return node.parent.right;
  127. else
  128. return node.parent.left;
  129. }
  130.  
  131. /**
  132. * 查询到叶子节点
  133. * @param root
  134. * @param q
  135. * @return
  136. */
  137. public Node searchLeaf(Node root,Node q)
  138. {
  139. Node leaf=root,next=null;
  140. int index=0;
  141. while(leaf.left!=null || leaf.right!=null)
  142. {
  143. if(q.getData(index)<leaf.getData(index))
  144. {
  145. next=leaf.left;//进入左侧
  146. }else if(q.getData(index)>leaf.getData(index))
  147. {
  148. next=leaf.right;
  149. }else{
  150. //当取到中位数时 判断左右子区域哪个更加近
  151. if(q.computeDistance(leaf.left)<q.computeDistance(leaf.right))
  152. next=leaf.left;
  153. else
  154. next=leaf.right;
  155. }
  156. if(next==null)
  157. break;//下一个节点是空时 结束了
  158. else{
  159. leaf=next;
  160. if(++index>=root.data.length)
  161. index=0;
  162. }
  163. }
  164.  
  165. return leaf;
  166. }
  167.  
  168. /**
  169. * 维护一个k的最大堆
  170. * @param listNode
  171. * @param newNode
  172. * @param k
  173. */
  174. public void maintainMaxHeap(List<Node> listNode,Node newNode,int k)
  175. {
  176. if(listNode.size()<k)
  177. {
  178. maxHeapFixUp(listNode,newNode);//不足k个堆 直接向上修复
  179. }else if(newNode.distance<listNode.get(0).distance){
  180. //比堆顶的要小 还需要向下修复 覆盖堆顶
  181. maxHeapFixDown(listNode,newNode);
  182. }
  183. }
  184.  
  185. /**
  186. * 从上往下修复 将会覆盖第一个节点
  187. * @param listNode
  188. * @param newNode
  189. */
  190. private void maxHeapFixDown(List<Node> listNode,Node newNode)
  191. {
  192. listNode.set(0, newNode);
  193. int i=0;
  194. int j=i*2+1;
  195. while(j<listNode.size())
  196. {
  197. if(j+1<listNode.size() && listNode.get(j).distance<listNode.get(j+1).distance)
  198. j++;//选出子结点中较大的点,第一个条件是要满足右子树不为空
  199.  
  200. if(listNode.get(i).distance>=listNode.get(j).distance)
  201. break;
  202.  
  203. Node t=listNode.get(i);
  204. listNode.set(i, listNode.get(j));
  205. listNode.set(j, t);
  206.  
  207. i=j;
  208. j=i*2+1;
  209. }
  210. }
  211.  
  212. private void maxHeapFixUp(List<Node> listNode,Node newNode)
  213. {
  214. listNode.add(newNode);
  215. int j=listNode.size()-1;
  216. int i=(j+1)/2-1;//i是j的parent节点
  217. while(i>=0)
  218. {
  219.  
  220. if(listNode.get(i).distance>=listNode.get(j).distance)
  221. break;
  222.  
  223. Node t=listNode.get(i);
  224. listNode.set(i, listNode.get(j));
  225. listNode.set(j, t);
  226.  
  227. j=i;
  228. i=(j+1)/2-1;
  229. }
  230. }
  231.  
  232. /**
  233. * 使用快排进进行一个中位数的查找 完了之后返回的数组size/2即中位数
  234. * @param nodeList
  235. * @param index
  236. * @param left
  237. * @param right
  238. */
  239. @Test
  240. private void quickSortForMedian(List<Node> nodeList,int index,int left,int right)
  241. {
  242. if(left>=right || nodeList.size()<=0)
  243. return ;
  244.  
  245. Node kn=nodeList.get(left);
  246. double k=kn.getData(index);//取得向量指定索引的值
  247.  
  248. int i=left,j=right;
  249.  
  250. //控制每一次遍历的结束条件,i与j相遇
  251. while(i<j)
  252. {
  253. //从右向左找一个小于i处值的值,并填入i的位置
  254. while(nodeList.get(j).getData(index)>=k && i<j)
  255. j--;
  256. nodeList.set(i, nodeList.get(j));
  257. //从左向右找一个大于i处值的值,并填入j的位置
  258. while(nodeList.get(i).getData(index)<=k && i<j)
  259. i++;
  260. nodeList.set(j, nodeList.get(i));
  261. }
  262.  
  263. nodeList.set(i, kn);
  264.  
  265. if(i==nodeList.size()/2)
  266. return ;//完成中位数的排序了,但并不是完成了所有数的排序,这个终止条件只是保证中位数是正确的。去掉该条件,可以保证在递归的作用下,将所有的树
  267. //将所有的数进行排序
  268.  
  269. else if(i<nodeList.size()/2)
  270. {
  271. quickSortForMedian(nodeList,index,i+1,right);//只需要排序右边就可以了
  272. }else{
  273. quickSortForMedian(nodeList,index,left,i-1);//只需要排序左边就可以了
  274. }
  275.  
  276. // for (Node node : nodeList) {
  277. // System.out.println(node.getData(index));
  278. // }
  279. }
  280. }
  1. public class Node implements Comparable<Node>{
  2. public double[] data;//树上节点的数据 是一个多维的向量
  3. public double distance;//与当前查询点的距离 初始化的时候是没有的
  4. public Node left,right,parent;//左右子节点 以及父节点
  5. public int dim=-1;//维度 建立树的时候判断的维度
  6.  
  7. public Node(double[] data)
  8. {
  9. this.data=data;
  10. }
  11.  
  12. /**
  13. * 返回指定索引上的数值
  14. * @param index
  15. * @return
  16. */
  17. public double getData(int index)
  18. {
  19. if(data==null || data.length<=index)
  20. return Integer.MIN_VALUE;
  21. return data[index];
  22. }
  23.  
  24. @Override
  25. public int compareTo(Node o) {
  26. if(this.distance>o.distance)
  27. return 1;
  28. else if(this.distance==o.distance)
  29. return 0;
  30. else return -1;
  31. }
  32.  
  33. /**
  34. * 计算距离 这里返回欧式距离
  35. * @param that
  36. * @return
  37. */
  38. public double computeDistance(Node that)
  39. {
  40. if(this.data==null || that.data==null || this.data.length!=that.data.length)
  41. return Double.MAX_VALUE;//出问题了 距离最远
  42. double d=0;
  43. for(int i=0;i<this.data.length;i++)
  44. {
  45. d+=Math.pow(this.data[i]-that.data[i], 2);
  46. }
  47.  
  48. return Math.sqrt(d);
  49. }
  50.  
  51. public String toString()
  52. {
  53. if(data==null || data.length==0)
  54. return null;
  55. StringBuilder sb=new StringBuilder();
  56. for(int i=0;i<data.length;i++)
  57. sb.append(data[i]+" ");
  58. sb.append(" d:"+this.distance);
  59. return sb.toString();
  60. }
  61. }

  参考文献:

    [1]李航.统计学习方法

  

统计学习方法学习(四)--KNN及kd树的java实现的更多相关文章

  1. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  2. 统计学习方法笔记(KNN)

    k近邻法(k-nearest neighbor,k-NN) 输入:实例的特征向量,对应于特征空间的点:输出:实例的类别,可以取多类. 分类时,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预 ...

  3. 【分类算法】K近邻(KNN) ——kd树(转载)

    K近邻(KNN)的核心算法是kd树,转载如下几个链接: [量化课堂]一只兔子帮你理解 kNN [量化课堂]kd 树算法之思路篇 [量化课堂]kd 树算法之详细篇

  4. 统计学习方法——第四章朴素贝叶斯及c++实现

    1.名词解释 贝叶斯定理,自己看书,没啥说的,翻译成人话就是,条件A下的bi出现的概率等于A和bi一起出现的概率除以A出现的概率. 记忆方式就是变后验概率为先验概率,或者说,将条件与结果转换. 先验概 ...

  5. KNN算法与Kd树

    最近邻法和k-近邻法 下面图片中只有三种豆,有三个豆是未知的种类,如何判定他们的种类? 提供一种思路,即:未知的豆离哪种豆最近就认为未知豆和该豆是同一种类.由此,我们引出最近邻算法的定义:为了判定未知 ...

  6. 02-17 kd树

    目录 kd树 一.kd树学习目标 二.kd树引入 三.kd树详解 3.1 构造kd树 3.1.1 示例 3.2 kd树搜索 3.2.1 示例 四.kd树流程 4.1 输入 4.2 输出 4.3 流程 ...

  7. 统计学习方法——KD树最近邻搜索

    李航老师书上的的算法说明没怎么看懂,看了网上的博客,悟出一套循环(建立好KD树以后的最近邻搜索),我想应该是这样的(例子是李航<统计学习算法>第三章56页:例3.3): 步骤 结点查询标记 ...

  8. 李航统计学习方法(第二版)(六):k 近邻算法实现(kd树(kd tree)方法)

    1. kd树简介 构造kd树的方法如下:构造根结点,使根结点对应于k维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对k维空间进行切分,生成子结点.在超矩形区域(结点)上选择一个坐标轴和 ...

  9. 统计学习方法:KNN

    作者:桂. 时间:2017-04-19  21:20:09 链接:http://www.cnblogs.com/xingshansi/p/6736385.html 声明:欢迎被转载,不过记得注明出处哦 ...

随机推荐

  1. 以KeyValue形式构建Lua Table

    Key为字符串 -- 定义一个key,value形式的table local kv = {fruit = "apple", bread = "french", ...

  2. java-生成任意格式的json数据

    最近研究java的东西.之前靠着自己的摸索,实现了把java对象转成json格式的数据的功能,返回给前端.当时使用的是 JSONObject.fromObject(object) 方法把java对象换 ...

  3. Structured Streaming从Kafka 0.8中读取数据的问题

    众所周知,Structured Streaming默认支持Kafka 0.10,没有提供针对Kafka 0.8的Connector,但这对高手来说不是事儿,于是有个Hortonworks的邵大牛(前段 ...

  4. MongoDB学习笔记(一)

    最近有些时间,就抽空研究了一下MongoDB,我以前经常使用关系型数据库,如Oracle.MySQL,对MongoDB只是有些很肤浅的了解,最近下决心要好好研究一下,主要的参考书有两本:<Mon ...

  5. 【ANT】时间戳

    属性 说明 举例 DSTAMP 设置为当前日期,默认格式:yyyymmdd 20170309 TSTAMP 设置为当前时间,默认格式:hhmm 2007 TODAY 设置为当前日期,带完整的月份 Ma ...

  6. vue vuex vue-rouert后台项目——权限路由(超详细简单版)

    项目地址:vue-simple-template共三个角色:adan barbara carrie 密码全是:123456 adan 拥有 最高权限A 他可以看到 red , yellow 和 blu ...

  7. ES6之遍历器(Iterator)

    什么是Iterator?他是一种接口,为各种不同的数据结构提供统一的访问机制,任何数据结构只要部署上Iterator接口就可以完成遍历操作(PS:个人认为他的这个遍历就是c语言里面的指针),他的作用有 ...

  8. mvn命令笔记

    #发布到本地仓库 mvn deploy -DaltDeploymentRepository=snapshots::default::http://mvnrepo.xxx.com/mvn/snapsho ...

  9. [置顶] MVC输出缓存(OutputCache参数详解)

    1.学习之前你应该知道这些 几乎每个项目都会用到缓存,这是必然的.以前在学校时做的网站基本上的一个标准就是1.搞定增删改查2.页面做的不要太差3.能运行(ps真的有这种情况,答辩验收的时候几个人在讲台 ...

  10. signalr中Group 分组群发消息的简单使用

    前一段时间写了几篇关于signalr的文章 1.MVC中使用signalR入门教程 2.mvc中signalr实现一对一的聊天 3.Xamarin android中使用signalr实现即时通讯 在平 ...