K-近邻算法(k-Nearest Neighbor,简称kNN)采用测量不同特征值之间的距离方法进行分类,是一种常用的监督学习方法,其工作机制很简单:给定测试样本,基于某种距离亮度找出训练集中与其靠近的k个训练样本,然后基于这k个“邻居”的信息进行预测。kNN算法属于懒惰学习,此类学习技术在训练阶段仅仅是把样本保存起来,训练时间靠小为零,在收到测试样本后在进行处理,所以可知kNN算法的缺点是计算复杂度高、空间复杂度高。但其也有优点,精度高、对异常值不敏感、无数据输入设定。

  借张图来说:

当k = 1时目标点有一个class2邻居,根据kNN算法的原理,目标点也为class2。

当k = 5时目标点有两个class2邻居,有三个class1的邻居,根据其原理,目标点的类别为class2。

算法流程

总体来说,KNN分类算法包括以下4个步骤:

①准备数据,对数据进行预处理 。

②计算测试样本点(也就是待分类点)到其他每个样本点的距离。

③对每个距离进行排序,然后选择出距离最小的K个点 。

④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类 。

算法代码

  1. package com.top.knn;
  2.  
  3. import com.top.constants.OrderEnum;
  4. import com.top.matrix.Matrix;
  5. import com.top.utils.MatrixUtil;
  6.  
  7. import java.util.*;
  8.  
  9. /**
  10. * @program: top-algorithm-set
  11. * @description: KNN k-临近算法进行分类
  12. * @author: Mr.Zhao
  13. * @create: 2020-10-13 22:03
  14. **/
  15. public class KNN {
  16. public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
  17. if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
  18. throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
  19. }
  20. if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
  21. throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
  22. }
  23. if (dataSet.getMatrixRowCount() < k) {
  24. throw new IllegalArgumentException("训练集样本数小于k");
  25. }
  26. // 归一化
  27. int trainCount = dataSet.getMatrixRowCount();
  28. int testCount = input.getMatrixRowCount();
  29. Matrix trainAndTest = dataSet.splice(2, input);
  30. Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
  31. trainAndTest = (Matrix) normalize.get("res");
  32. dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
  33. input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());
  34.  
  35. // 获取标签信息
  36. List<Double> labelList = new ArrayList<>();
  37. for (int i = 0; i < labels.getMatrixRowCount(); i++) {
  38. if (!labelList.contains(labels.getValOfIdx(i, 0))) {
  39. labelList.add(labels.getValOfIdx(i, 0));
  40. }
  41. }
  42.  
  43. Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
  44. for (int i = 0; i < input.getMatrixRowCount(); i++) {
  45. // 求向量间的欧式距离
  46. Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
  47. Matrix var2 = dataSet.subtract(var1);
  48. Matrix var3 = var2.square();
  49. Matrix var4 = var3.sumRow();
  50. Matrix var5 = var4.pow(0.5);
  51. // 距离矩阵合并上labels矩阵
  52. Matrix var6 = var5.splice(1, labels);
  53. // 将计算出的距离矩阵按照距离升序排序
  54. var6.sort(0, OrderEnum.ASC);
  55. // 遍历最近的k个变量
  56. Map<Double, Integer> map = new HashMap<>();
  57. for (int j = 0; j < k; j++) {
  58. // 遍历标签种类数
  59. for (Double label : labelList) {
  60. if (var6.getValOfIdx(j, 1) == label) {
  61. map.put(label, map.getOrDefault(label, 0) + 1);
  62. }
  63. }
  64. }
  65. result.setValue(i, 0, getKeyOfMaxValue(map));
  66. }
  67. return result;
  68. }
  69.  
  70. /**
  71. * 取map中值最大的key
  72. *
  73. * @param map
  74. * @return
  75. */
  76. private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
  77. if (map == null)
  78. return null;
  79. Double keyOfMaxValue = 0.0;
  80. Integer maxValue = 0;
  81. for (Double key : map.keySet()) {
  82. if (map.get(key) > maxValue) {
  83. keyOfMaxValue = key;
  84. maxValue = map.get(key);
  85. }
  86. }
  87. return keyOfMaxValue;
  88. }
  89.  
  90. }

KNN

注:其中的矩阵方法请参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java

  升降序枚举类参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/constants/OrderEnum.java

该算法为本人github项目中的一部分,地址为https://github.com/ineedahouse/top-algorithm-set

如果对你有帮助可以点个star~

参考

《机器学习》-周志华

《机器学习实战》-Peter Harrington

K-近邻算法kNN的更多相关文章

  1. k近邻算法(KNN)

    k近邻算法(KNN) 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. from sklearn.model_selection ...

  2. 机器学习(四) 分类算法--K近邻算法 KNN (上)

    一.K近邻算法基础 KNN------- K近邻算法--------K-Nearest Neighbors 思想极度简单 应用数学知识少 (近乎为零) 效果好(缺点?) 可以解释机器学习算法使用过程中 ...

  3. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  4. 机器学习(四) 机器学习(四) 分类算法--K近邻算法 KNN (下)

    六.网格搜索与 K 邻近算法中更多的超参数 七.数据归一化 Feature Scaling 解决方案:将所有的数据映射到同一尺度 八.scikit-learn 中的 Scaler preprocess ...

  5. k近邻算法(knn)的c语言实现

    最近在看knn算法,顺便敲敲代码. knn属于数据挖掘的分类算法.基本思想是在距离空间里,如果一个样本的最接近的k个邻居里,绝大多数属于某个类别,则该样本也属于这个类别.俗话叫,"随大流&q ...

  6. 《机器学习实战》---第二章 k近邻算法 kNN

    下面的代码是在python3中运行, # -*- coding: utf-8 -*- """ Created on Tue Jul 3 17:29:27 2018 @au ...

  7. 最基础的分类算法-k近邻算法 kNN简介及Jupyter基础实现及Python实现

    k-Nearest Neighbors简介 对于该图来说,x轴对应的是肿瘤的大小,y轴对应的是时间,蓝色样本表示恶性肿瘤,红色样本表示良性肿瘤,我们先假设k=3,这个k先不考虑怎么得到,先假设这个k是 ...

  8. 07.k近邻算法kNN

    1.将数据分为测试数据和预测数据 2.数据分为data和target,data是矩阵,target是向量 3.将每条data(向量)绘制在坐标系中,就得到了一系列的点 4.根据每条data的targe ...

  9. 机器学习随笔01 - k近邻算法

    算法名称: k近邻算法 (kNN: k-Nearest Neighbor) 问题提出: 根据已有对象的归类数据,给新对象(事物)归类. 核心思想: 将对象分解为特征,因为对象的特征决定了事对象的分类. ...

  10. 机器学习(1)——K近邻算法

    KNN的函数写法 import numpy as np from math import sqrt from collections import Counter def KNN_classify(k ...

随机推荐

  1. thinkphp6.0.x 反序列化详记(一)

    前言 这几天算是进阶到框架类漏洞的学习了,首当其冲想到是thinkphp,先拿thinkphp6.0.x来学习一下,体验一下寻找pop链的快乐. 在此感谢楷师傅的帮忙~ 环境配置 用composer指 ...

  2. 编程体系结构(08):Spring.Mvc.Boot框架

    本文源码:GitHub·点这里 || GitEE·点这里 一.Spring框架 1.框架概述 Spring是一个开源框架,框架的主要优势之一就是其分层架构,分层架构允许使用者选择使用哪一个组件,同时为 ...

  3. MeteoInfoLab脚本示例:CloudSAT Swath HDF数据

    读取CloudSAT HDF Swath数据,绘图分上下两部分,上面是时间和高度维的Radar Reflectivity Factor二维图,下面是卫星轨迹图.示例程序: # Add file f = ...

  4. 【嵌入式】C语言高级编程▁▁▁嵌入式C语言入门编程学习!

    ✍  1.C 语言标准 什么是 C 语言标准呢? 我们生活的现实世界,就是由各种标准构成的,正是这些标准,我们的社会才会有条不紊的运行. 比如我们过马路,遵循的交通规则就是一个标准:红灯停,绿灯行,黄 ...

  5. 【Linux教程】Linux系统零基础编程入门,想当大神?这些你都要学

    ✍ 文件和文件系统 文件是Linux系统中最重要的抽象,大多数情况下你可以把linux系统中的任何东西都理解为文件,很多的交互操作其实都是通过文件的读写来实现的. 文件描述符 在Linux内核中,文件 ...

  6. centos8平台安装gitosis服务

    一,git服务器端:准备gitosis需要的各依赖软件 1,确认openssh是否存在?如不存在,以下列命令进行安装 [root@yjweb ~]# yum install openssh opens ...

  7. linux wget指定下载目录和重命名

    当我们在使用wget命令下载文件时,通常会需要将文件下载到指定的目录,这时就可以使用 -P 参数来指定目录,如果指定的目录不存在,则会自动创建. 示例: p.p1 { margin: 0; font: ...

  8. DateDiff() 方法语法 T-SQL语法

    表达式DateDiff(timeinterval,date1,date2 [, firstdayofweek [, firstweekofyear]]) 允许数据类型: timeinterval 表示 ...

  9. 【应用服务 App Service】发布到Azure上的应用显示时间不是本地时间的问题,修改应用服务的默认时区

    问题情形 应用程序发布到App Service后,时间显示不是北京时间,默认情况为UTC时间,比中国时间晚 8 个小时. 详细日志 无 问题原因 Azure 上所有的服务时间都采用了 UTC 时间. ...

  10. MySQL数据库基础-2范式

    数据库结构设计 范式 设计数据库的规范 第12345范式,凡是之间有依赖关系. 关系模型的发明者埃德加·科德最早提出这一概念,并于1970 年代初定义了第一范式.第二范式和第三范式的概念 设计关系数据 ...