K-近邻算法

一、算法概述

(1)采用测量不同特征值之间的距离方法进行分类

  • 优点: 精度高、对异常值不敏感、无数据输入假定。
  • 缺点: 计算复杂度高、空间复杂度高。

(2)KNN模型的三个要素

kNN算法模型实际上就是对特征空间的的划分。模型有三个基本要素:距离度量、K值的选择和分类决策规则的决定。

  • 距离度量

    距离定义为:

    \[L_p(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^p \right) ^{\frac{1}{p}}
    \]

    一般使用欧式距离:p = 2的个情况

    \[L_p(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^2 \right) ^{\frac{1}{2}}
    \]

  • K值的选择

    一般根据经验选择,需要多次选择对比才可以选择一个比较合适的K值。

    如果K值太小,会导致模型太复杂,容易产生过拟合现象,并且对噪声点非常敏感。

    如果K值太大,模型太过简单,忽略的大部分有用信息,也是不可取的。

  • 分类决策规则

    一般采用多数表决规则,通俗点说就是在这K个类别中,哪种类别最后就判别为哪种类型

二、实施kNN算法

2.1 伪代码

  • 计算法已经类别数据集中的点与当前点之间的距离
  • 按照距离递增次序排序
  • 选取与但前点距离最小的k个点
  • 确定前k个点所在类别的出现频率
  • 返回前k个点出现频率最高的类别作为当前点的预测分类

#### 2.2 实际代码

  1. def classify0(inX, dataSet, labels, k):
  2. dataSetSize = dataSet.shape[0]
  3. diffMat = tile(inX, (dataSetSize,1)) - dataSet
  4. sqDiffMat = diffMat**2
  5. sqDistances = sqDiffMat.sum(axis=1)
  6. distances = sqDistances**0.5
  7. sortedDistIndicies = distances.argsort()
  8. classCount={}
  9. for i in range(k):
  10. voteIlabel = labels[sortedDistIndicies[i]]
  11. classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
  12. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
  13. return sortedClassCount[0][0]

三、实际案例:使用kNN算法改进约会网站的配对效果

我的朋友阿J一直使用在线约会软件寻找约会对象,他曾经交往过三种类型的人:

  • 不喜欢的人
  • 感觉一般的人
  • 非常喜欢的人

步骤:

  • 收集数据
  • 准备数据:也就是读取数据的过程
  • 分析数据:使用Matplotlib画出二维散点图
  • 训练算法
  • 测试算法
  • 使用算法

3.1 准备数据

样本数据共有1000个,3个特征值,共有4列数据,最后一列表示标签分类(0:不喜欢的人;1:感觉一般的人;2:非常喜欢的人)

特征

  • 每年获得的飞行常客里程数
  • 玩视频游戏所好的时间百分比
  • 每周消费的冰淇淋公斤数

部分数据如下:

  1. 40920 8.326976 0.953952 3
  2. 14488 7.153469 1.673904 2
  3. 26052 1.441871 0.805124 1
  4. 75136 13.147394 0.428964 1
  5. 38344 1.669788 0.134296 1
  6. 72993 10.141740 1.032955 1
  7. 35948 6.830792 1.213192 3
  8. 42666 13.276369 0.543880 3
  9. 67497 8.631577 0.749278 1
  10. 35483 12.273169 1.508053 3

读取数据(读取txt文件)

  1. def file2matrix(filename):
  2. fr = open(filename)
  3. numberOfLines = len(fr.readlines()) #get the number of lines in the file
  4. returnMat = zeros((numberOfLines,3)) #prepare matrix to return
  5. classLabelVector = [] #prepare labels return
  6. fr = open(filename)
  7. index = 0
  8. for line in fr.readlines():
  9. line = line.strip()
  10. listFromLine = line.split('\t')
  11. returnMat[index,:] = listFromLine[0:3]
  12. classLabelVector.append(int(listFromLine[-1]))
  13. index += 1
  14. return returnMat,classLabelVector

3.2 分析数据:使用Matplotlib创建散点图

初步分析
  1. import matplotlib
  2. import matplotlib.pyplot as plt
  3. plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
  4. fig = plt.figure()
  5. ax = fig.add_subplot(111)
  6. ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
  7. ax.set_xlabel("玩视频游戏所耗时间百分比")
  8. ax.set_ylabel("每周消费的冰淇淋公斤数")
  9. plt.show()

因为有三种类型的分类,这样看的不直观,我们添加以下颜色

  1. fig = plt.figure()
  2. ax = fig.add_subplot(111)
  3. ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
  4. ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))
  5. ax.set_xlabel("玩视频游戏所耗时间百分比")
  6. ax.set_ylabel("每周消费的冰淇淋公斤数")
  7. plt.show()

通过都多次的尝试后发现,玩游戏时间和冰淇淋这个两个特征关系比较明显

具体的步骤:

  • 分别将标签为1,2,3的三种类型的数据分开
  • 使用matplotlib绘制,并使用不同的颜色加以区分
  1. datingDataType1 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==1])
  2. datingDataType2 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==2])
  3. datingDataType3 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==3])
  4. fig, axs = plt.subplots(2, 2, figsize = (15,10))
  5. axs[0,0].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
  6. axs[0,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
  7. axs[1,0].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
  8. type1 = axs[1,1].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
  9. type2 = axs[1,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
  10. type3 = axs[1,1].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
  11. axs[1,1].legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
  12. axs[1,1].set_xlabel("玩视频游戏所耗时间百分比")
  13. axs[1,1].set_ylabel("每周消费的冰淇淋公斤数")
  14. plt.show()

3.3 准备数据:数据归一化

通过上面的图形绘制,发现三个特征值的范围不一样,在使用KNN进行计算距离的时候,数值大的特征值就会对结果产生更大的影响。

数据归一化:就是将几组不同范围的数据,转换到同一个范围内。

公式: newValue = (oldValue - min)/(max - min)

  1. def autoNorm(dataSet):
  2. minVals = dataSet.min(0) # array([[1,20,3], [4,5,60], [7,8,9]]) min(0) = [1, 5, 3]
  3. maxVals = dataSet.max(0)
  4. ranges = maxVals - minVals
  5. normData = zeros(shape(dataSet))
  6. m = dataSet.shape[0]
  7. normData = (dataSet - tile(minVals, (m,1)))/tile(ranges,(m,1))
  8. return normData

3.4 测试算法

我们将原始样本保留20%作为测试集,剩余80%作为训练集

  1. def datingClassTest():
  2. hoRatio = 0.20
  3. datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file
  4. normMat = autoNorm(datingDataMat)
  5. m = normMat.shape[0]
  6. numTestVecs = int(m*hoRatio)
  7. errorCount = 0.0
  8. for i in range(numTestVecs):
  9. classifierResult = classify0(normMat[i,:],normMat[numTestVecs:,:],datingLabels[numTestVecs:],3)
  10. if (classifierResult != datingLabels[i]):
  11. errorCount += 1.0
  12. print ("the total error rate is: %f" % (errorCount/float(numTestVecs)))
  13. print (errorCount)

运行结果

  1. the total error rate is: 0.080000
  2. 16.0

四、源代码

  1. from numpy import *
  2. import operator
  3. from os import listdir
  4. import matplotlib
  5. import matplotlib.pyplot as plt
  6. ## KNN function
  7. def classify0(inX, dataSet, labels, k):
  8. dataSetSize = dataSet.shape[0]
  9. diffMat = tile(inX, (dataSetSize,1)) - dataSet
  10. sqDiffMat = diffMat**2
  11. sqDistances = sqDiffMat.sum(axis=1)
  12. distances = sqDistances**0.5
  13. sortedDistIndicies = distances.argsort()
  14. classCount={}
  15. for i in range(k):
  16. voteIlabel = labels[sortedDistIndicies[i]]
  17. classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
  18. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
  19. return sortedClassCount[0][0]
  20. # read txt data
  21. def file2matrix(filename):
  22. fr = open(filename)
  23. numberOfLines = len(fr.readlines()) #get the number of lines in the file
  24. returnMat = zeros((numberOfLines,3)) #prepare matrix to return
  25. classLabelVector = [] #prepare labels return
  26. fr = open(filename)
  27. index = 0
  28. for line in fr.readlines():
  29. line = line.strip()
  30. listFromLine = line.split('\t')
  31. returnMat[index,:] = listFromLine[0:3]
  32. classLabelVector.append(int(listFromLine[-1]))
  33. index += 1
  34. return returnMat,classLabelVector
  35. def autoNorm(dataSet):
  36. minVals = dataSet.min(0) # array([[1,20,3], [4,5,60], [7,8,9]]) min(0) = [1, 5, 3]
  37. maxVals = dataSet.max(0)
  38. ranges = maxVals - minVals
  39. normData = zeros(shape(dataSet))
  40. m = dataSet.shape[0]
  41. normData = (dataSet - tile(minVals, (m,1)))/tile(ranges,(m,1))
  42. return normData
  43. def drawScatter1(datingDataMat, datingLabels):
  44. plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
  45. fig = plt.figure()
  46. ax = fig.add_subplot(111)
  47. ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
  48. ax.set_xlabel("玩视频游戏所耗时间百分比")
  49. ax.set_ylabel("每周消费的冰淇淋公斤数")
  50. plt.show()
  51. def drawScatter2(datingDataMat, datingLabels):
  52. fig = plt.figure()
  53. ax = fig.add_subplot(111)
  54. ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
  55. ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))
  56. ax.set_xlabel("玩视频游戏所耗时间百分比")
  57. ax.set_ylabel("每周消费的冰淇淋公斤数")
  58. plt.show()
  59. def drawScatter3(datingDataMat, datingLabels):
  60. datingDataType1 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==1])
  61. datingDataType2 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==2])
  62. datingDataType3 = array([[x[0][0],x[0][1],x[0][2]] for x in zip(datingDataMat,datingLabels) if x[1]==3])
  63. fig, axs = plt.subplots(2, 2, figsize = (15,10))
  64. axs[0,0].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
  65. axs[0,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
  66. axs[1,0].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
  67. type1 = axs[1,1].scatter(datingDataType1[:,0], datingDataType1[:,1], s = 20, c = 'red')
  68. type2 = axs[1,1].scatter(datingDataType2[:,0], datingDataType2[:,1], s = 30, c = 'green')
  69. type3 = axs[1,1].scatter(datingDataType3[:,0], datingDataType3[:,1], s = 40, c = 'blue')
  70. axs[1,1].legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
  71. axs[1,1].set_xlabel("玩视频游戏所耗时间百分比")
  72. axs[1,1].set_ylabel("每周消费的冰淇淋公斤数")
  73. plt.show()
  74. def datingClassTest():
  75. hoRatio = 0.20
  76. datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file
  77. normMat = autoNorm(datingDataMat)
  78. m = normMat.shape[0]
  79. numTestVecs = int(m*hoRatio)
  80. errorCount = 0.0
  81. for i in range(numTestVecs):
  82. classifierResult = classify0(normMat[i,:],normMat[numTestVecs:,:],datingLabels[numTestVecs:],3)
  83. if (classifierResult != datingLabels[i]):
  84. errorCount += 1.0
  85. print ("the total error rate is: %f" % (errorCount/float(numTestVecs)))
  86. print (errorCount)
  87. datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
  88. drawScatter1(datingDataMat, datingLabels)
  89. drawScatter2(datingDataMat, datingLabels)
  90. drawScatter3(datingDataMat, datingLabels)
  91. datingClassTest()

[机器学习笔记]kNN进邻算法的更多相关文章

  1. 机器学习笔记(五) K-近邻算法

    K-近邻算法 (一)定义:如果一个样本在特征空间中的k个最相似的样本中的大多数属于某一个类别,则该样本也属于这个类别. (二)相似的样本,特征之间的值应该是相近的,使用k-近邻算法需要做标准化处理.否 ...

  2. kNN进邻算法

    一.算法概述 (1)采用测量不同特征值之间的距离方法进行分类 优点: 精度高.对异常值不敏感.无数据输入假定. 缺点: 计算复杂度高.空间复杂度高. (2)KNN模型的三个要素 kNN算法模型实际上就 ...

  3. 《机器学习实战》——k-近邻算法Python实现问题记录(转载)

    py2.7 : <机器学习实战> k-近邻算法 11.19 更新完毕 原文链接 <机器学习实战>第二章k-近邻算法,自己实现时遇到的问题,以及解决方法.做个记录. 1.写一个k ...

  4. Python机器学习笔记:异常点检测算法——LOF(Local Outiler Factor)

    完整代码及其数据,请移步小编的GitHub 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/MachineLearningNote 在数据挖掘方面,经常需 ...

  5. 机器学习实战读书笔记(二)k-近邻算法

    knn算法: 1.优点:精度高.对异常值不敏感.无数据输入假定 2.缺点:计算复杂度高.空间复杂度高. 3.适用数据范围:数值型和标称型. 一般流程: 1.收集数据 2.准备数据 3.分析数据 4.训 ...

  6. 机器学习实践之K-近邻算法实践学习

    关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2017年12月04日 22:54:26所撰写内容(http://blog.csdn.n ...

  7. 机器学习实战(一)k-近邻算法

    转载请注明源出处:http://www.cnblogs.com/lighten/p/7593656.html 1.原理 本章介绍机器学习实战的第一个算法——k近邻算法(k Nearest Neighb ...

  8. 吴裕雄--天生自然python机器学习:使用K-近邻算法改进约会网站的配对效果

    在约会网站使用K-近邻算法 准备数据:从文本文件中解析数据 海伦收集约会数据巳经有了一段时间,她把这些数据存放在文本文件(1如1^及抓 比加 中,每 个样本数据占据一行,总共有1000行.海伦的样本主 ...

  9. 机器学习实战笔记(1)——k-近邻算法

    机器学习实战笔记(1) 1. 写在前面 近来感觉机器学习,深度学习神马的是越来越火了,从AlphaGo到Master,所谓的人工智能越来越NB,而我又是一个热爱新潮事物的人,于是也来凑个热闹学习学习. ...

随机推荐

  1. idea+springboot+mybatis逆向工程

    前提:使用idea开发,基于springboot.用到了mybatis的逆向工程 因为之前用eclipse开发ssm比较多,现在转idea 使用springboot 踩了一些坑,在这记录一下~ 注意事 ...

  2. 11.Linux用户特殊权限

    1.特殊权限概述 前面我们已经学习过 r(读).w(写). x(执行)这三种普通权限,但是我们在査询系统文件权限时会发现出现了一些其他权限字母,比如: 2.特殊权限SUID set uid 简称sui ...

  3. oracle计算两个时间的差值(XX天XX时XX分XX秒)

    在工作中需要计算两个时间的差值,结束时间 - 开始时间,又不想在js里写function,也不想在java里去计算,干脆就在数据库做了一个函数来计算两个时间的差值.格式为XX天XX时XX分XX秒: 上 ...

  4. 自然语言处理(NLP)

    苹果语音助手Siri的工作流程: 听 懂 思考 组织语言 回答 这其中每一步骤涉及的流程为: 语音识别 自然语言处理 - 语义分析 逻辑分析 - 结合业务场景与上下文 自然语言处理 - 分析结果生成自 ...

  5. Redis 集群搭建(基于Linux)

    一.基础环境 1.虚拟机 VMware 15.x 2.Linux系统,用的是Centos7的Linux系统 3.Redis数据库版本 5.0.3 二.Redis集群简介 1.背景 Redis在3.0版 ...

  6. css布局两端固定中间自适应

    第一种:采用浮动 1.1首先来看一下网上一个哥们给的代码 <body> <div class="left">左</div> <div cl ...

  7. 通过一个生活中的案例场景,揭开并发包底层AQS的神秘面纱

    本文导读 生活中案例场景介绍 联想到 AQS 到底是什么 AQS 的设计初衷 揭秘 AQS 底层实现 最后的总结 当你在学习某一个技能的时候,是否曾有过这样的感觉,就是同一个技能点学完了之后,过了一段 ...

  8. 一般链表实现集合运算(C语言)

    最近在学习数据结构,遇到以下问题: 假设集合A = (c, b, e, g, f, d),B = (a, b, n, f),利用一般线性链表实现集合运算(A-B)∪(B-A). 分析: 上面的问题只要 ...

  9. 根据多个成对的cron表达式生成的时间段,合并

    场景:数据库一张表,有个startcron 和endcron 两个字段,根据表达式计算今天的所有时间段. 例:startcron :0 30 20 ? * * endcron :0 30 21 ? * ...

  10. 【IT教程-Oracle】尚观Oracle白金级入门教程

    链接: https://pan.baidu.com/s/1GMncQN6mpgaH3hZQjGelaA 提取码: qu6j