【机器学*】k-*邻算法(kNN) 学*笔记

标签(空格分隔): 机器学*


kNN简介

kNN算法是做分类问题的。思想如下:

KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

  1. 计算测试数据与各个训练数据之间的距离;
  2. 按照距离的递增关系进行排序;
  3. 选取距离最小的K个点;
  4. 确定前K个点所在类别的出现频率;
  5. 返回前K个点中出现频率最高的类别作为测试数据的预测分类。

更为详细的介绍见这个博客:机器学*(一)——K-*邻(KNN)算法
kNN的优缺点见:KNN算法理解
这个博客的内容来自《机器学*实战》一书。

这个博客主要讲解kNN的python实现,把每行的代码都弄明白。

kNN代码实现

下面classify0()就是kNN,这些代码做了对一个点的分类。

  1. # coding=utf-8
  2. import operator
  3. from numpy import *
  4. def createDataSet():
  5. group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
  6. labels = ['A', 'A', 'B', 'B']
  7. return group, labels
  8. def classify0(inX, dataSet, labels, k):
  9. dataSetSize = dataSet.shape[0]
  10. # 矩阵有一个shape属性,是一个(行,列)形式的元组
  11. diffMat = tile(inX, (dataSetSize, 1)) - dataSet
  12. # 输入的点到每个点的横纵坐标差
  13. # tile是把矩阵重复多次
  14. sqDiffMat = diffMat ** 2
  15. # 横纵坐标差的平方
  16. sqDistances = sqDiffMat.sum(axis=1)
  17. # axis=0, 表示列。axis=1, 表示行。
  18. distances = sqDistances ** 0.5
  19. # 开方
  20. sortedDistIndicies = distances.argsort()
  21. # argsort函数返回的是数组值从小到大的索引值
  22. classCount = {}
  23. # 保存A,B出现次数的字典
  24. for i in range(k):
  25. voteIlabel = labels[sortedDistIndicies[i]]
  26. # 获取索引值对应的是A还是B
  27. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
  28. # 在字典中保存A,B出现的次数
  29. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  30. # 按照A,B出现的次数排序
  31. return sortedClassCount[0][0] # 返回A,B出现最多的那个
  32. group, labels = createDataSet()
  33. answer = classify0([0, 0], group, labels, 3)
  34. print answer

关于tile()函数,可以见文章:【python】tile函数简单介绍
关于sorted()函数:

  1. sorted函数sorted(iterable, cmp=None, key=None, reverse=False)
  2. iterable:是可迭代类型;
  3. cmp:用于比较的函数,比较什么由key决定;
  4. key:用列表元素的某个属性或函数进行作为关键字,有默认值,迭代集合中的一项;
  5. operator.itemgetter(1)表示用第2个数据项排序
  6. reverse:排序规则. reverse = True 降序 或者 reverse = False 升序,有默认值。

kNN实战一 改进约会网站配对效果

我只给出代码和每行代码的解释,这个实战项目的更具体介绍见:机器学*(一)——K-*邻(KNN)算法

  1. # coding=utf-8
  2. import operator
  3. from numpy import *
  4. import matplotlib
  5. import matplotlib.pyplot as plt
  6. def classify0(inX, dataSet, labels, k):
  7. """
  8. :param inX: 样本点
  9. :param dataSet: 初始样本集合
  10. :param labels: 样本集合对应的标签集合
  11. :param k: 选取的k
  12. :return: kNN分类结果
  13. """
  14. dataSetSize = dataSet.shape[0]
  15. # 矩阵有一个shape属性,是一个(行,列)形式的元组
  16. diffMat = tile(inX, (dataSetSize, 1)) - dataSet # 输入的点到每个点的横纵坐标差
  17. # tile是把矩阵重复多次
  18. sqDiffMat = diffMat ** 2 # 横纵坐标差的平方
  19. sqDistances = sqDiffMat.sum(axis=1) # axis=0, 表示列。axis=1, 表示行。
  20. distances = sqDistances ** 0.5 # 开方
  21. sortedDistIndicies = distances.argsort() # argsort函数返回的是数组值从小到大的索引值
  22. classCount = {} # 保存A,B出现次数的字典
  23. for i in range(k):
  24. voteIlabel = labels[sortedDistIndicies[i]] # 获取索引值对应的是A还是B
  25. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 # 在字典中保存A,B出现的次数
  26. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  27. # 按照A,B出现的次数排序
  28. # sorted函数sorted(iterable, cmp=None, key=None, reverse=False)
  29. '''
  30. iterable:是可迭代类型;
  31. cmp:用于比较的函数,比较什么由key决定;
  32. key:用列表元素的某个属性或函数进行作为关键字,有默认值,迭代集合中的一项;
  33. operator.itemgetter(1)表示用第2个数据项排序
  34. reverse:排序规则. reverse = True 降序 或者 reverse = False 升序,有默认值。
  35. '''
  36. return sortedClassCount[0][0] # 返回A,B出现最多的那个
  37. group, labels = createDataSet()
  38. answer = classify0([0, 0], group, labels, 3)
  39. print answer
  40. def file2matrix(filename):
  41. """
  42. :param filename: 文件名称
  43. :return: 文件中的数据和标签
  44. """
  45. fr = open(filename)
  46. arrayOLines = fr.readlines()
  47. numberOfLines = len(arrayOLines)
  48. # 获取文件的行数
  49. returnMat = zeros((numberOfLines, 3))
  50. # 创建返回的NumPy矩阵,二维矩阵
  51. # zeros函数功能是创建给定类型的矩阵,并初始化为0
  52. classLabelVector = []
  53. # 创建返回的标签
  54. index = 0 # index
  55. for line in arrayOLines:
  56. # 循环每列
  57. line = line.strip()
  58. # 去除每行回车字符
  59. listFromLine = line.split('\t')
  60. # 分割
  61. returnMat[index, :] = listFromLine[0:3]
  62. # 把数据的前三列都放到要返回的矩阵中,3这个索引是不包括的
  63. classLabelVector.append(int(listFromLine[-1]))
  64. # 把数据的每列最后一个元素转换成整数放到标签list里
  65. index += 1
  66. # index自增
  67. return returnMat, classLabelVector
  68. def autoNorm(dataSet):
  69. """
  70. :param dataSet: 数据集
  71. :return: 归一化结果
  72. """
  73. minVals = dataSet.min(0) # 0代表列
  74. maxVals = dataSet.max(0)
  75. ranges = maxVals - minVals
  76. normDataSet = zeros(shape(dataSet)) # 创建了行列数与dataSet一致的全0矩阵
  77. m = dataSet.shape[0] # 行数
  78. normDataSet = dataSet - tile(minVals, (m, 1)) # 每个元素都减去该列最小值
  79. normDataSet = normDataSet / tile(ranges, (m, 1)) # 具体数值的除,归一化;不是矩阵相除
  80. return normDataSet, ranges, minVals
  81. def datingClassTest():
  82. """
  83. 测试算法的函数
  84. :return:
  85. """
  86. hoRatio = 0.10 # hold out 10%
  87. datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
  88. normMat, ranges, minVals = autoNorm(datingDataMat) # 归一化
  89. m = normMat.shape[0] # 行数
  90. numTestVecs = int(m * hoRatio) # 抽出的行数
  91. print "numTestVecs", numTestVecs
  92. errorCount = 0.0 # 错误率
  93. for i in range(numTestVecs):
  94. classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
  95. print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
  96. if classifierResult != datingLabels[i]:
  97. errorCount += 1.0
  98. print "the total error rate is: %f" % (errorCount / float(numTestVecs))
  99. print errorCount
  100. def classifyPerson():
  101. """
  102. 用户输入点作为测试点
  103. :return: 无
  104. """
  105. resultList = ['not at all', 'in small doses', 'in large doses']
  106. percentTats = float(raw_input("percentage of time spent playing video games?"))
  107. ffMiles = float(raw_input("frequent flier miles earned per year?"))
  108. iceCream = float(raw_input("liters of ice cream consumed per year?"))
  109. datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
  110. normMat, ranges, minVals = autoNorm(datingDataMat) # 归一化
  111. inArr = array([ffMiles, percentTats, iceCream]) # 把用户输出的点当做要求点
  112. classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3) # 用kNN做分类
  113. print "you weill probably like this person:", resultList[classifierResult - 1] # 转换成真名
  114. datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
  115. fig = plt.figure()
  116. ax = fig.add_subplot(111)
  117. ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
  118. plt.show()
  119. datingClassTest()
  120. classifyPerson()

上述程序将把数据画出图来,然后计算kNN判断准确率,并且最后让用户输入数据,对该数据进行分类。

kNN实战二:手写体识别

数据集下载: http://www.ituring.com.cn/book/download/0019ab9d-0fda-4c17-941b-afe639fcccac

  1. def img2vector(filename):
  2. returnVect = zeros((1, 1024)) # 每个文件一行结果,1024个0
  3. fr = open(filename) # 打开文件
  4. for i in range(32): # 遍历32行
  5. lineStr = fr.readline() # 读行
  6. for j in range(32): # 读每个字符
  7. returnVect[0, 32 * i + j] = int(lineStr[j]) # 把字符放到结果中
  8. return returnVect
  9. def handwritingClassTest():
  10. hwLabels = [] # 保存标签
  11. trainingFileList = listdir('digits/trainingDigits') # 加载训练集
  12. m = len(trainingFileList) # 训练集文件个数
  13. trainingMat = zeros((m, 1024)) # 训练集数据矩阵
  14. for i in range(m): # 遍历文件
  15. fileNameStr = trainingFileList[i] # 训练集文件名
  16. fileStr = fileNameStr.split('.')[0] # 去掉文件名结尾的.txt
  17. classNumStr = int(fileStr.split('_')[0]) # 把文件名分割之后,获得前半部分,即这个文件表示的字符标签
  18. hwLabels.append(classNumStr) # 把文件表示的字符标签放到标签list中
  19. trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
  20. # 把每个文件中的字符画转成行向量
  21. testFileList = listdir('digits/testDigits') # 得到测试集所有文件目录
  22. errorCount = 0.0 # 错误率
  23. mTest = len(testFileList) # 测试集长度
  24. for i in range(mTest): # 遍历测试集
  25. fileNameStr = testFileList[i] # 测试集文件名
  26. fileStr = fileNameStr.split('.')[0] # 去掉文件名结尾的.txt
  27. classNumStr = int(fileStr.split('_')[0]) # 把文件名分割之后,获得前半部分,即这个文件表示的字符标签
  28. vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
  29. # 把每个文件中的字符画转成行向量
  30. classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) # 做分类
  31. print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
  32. if classifierResult != classNumStr: errorCount += 1.0
  33. print "\nthe total number of errors is: %d" % errorCount
  34. print "\nthe total error rate is: %f" % (errorCount / float(mTest))
  35. handwritingClassTest()

这个算法的执行效率并不高,每个测试向量做2000次的距离计算,每个距离计算包括了1024个维度浮点计算,总计要执行900次。

最后运行结果,错误率1.2%:

  1. the total number of errors is: 11
  2. the total error rate is: 0.011628

kNN是分类问题最有效最简单的算法,但是要保存全部数据集,对每个数据计算距离值。实际使用很耗时。而且无法给出任何数据的基础结构信息,无法知晓平均实例样本和典型实例样本之间具有什么特征。

这篇博客是对《机器学*实战》一书的学*笔记,如有不明白之处,请阅读该书。

【机器学*】k-*邻算法(kNN) 学*笔记的更多相关文章

  1. 机器学*——K*邻算法(KNN)

    1 前言 Kjin邻法(k-nearest neighbors,KNN)是一种基本的机器学*方法,采用类似"物以类聚,人以群分"的思想.比如,判断一个人的人品,只需观察他来往最密切 ...

  2. 【机器学*与R语言】2-懒惰学*K*邻(kNN)

    目录 1.理解使用KNN进行分类 KNN特点 KNN步骤 1)计算距离 2)选择合适的K 3)数据准备 2.用KNN诊断乳腺癌 1)收集数据 2)探索和准备数据 3)训练模型 4)评估模型的性能 5) ...

  3. 【机器学*】k*邻算法-02

    k邻*算法具体应用:2-2约会网站配对 心得体会: 1.对所有特征值进行归一化处理:将特征值单位带来的距离影响消除,使所有特征同权重--然后对不同的特征进行加权2.对于相互独立的特征,可以通过建立(特 ...

  4. k近邻算法(KNN)

    k近邻算法(KNN) 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. from sklearn.model_selection ...

  5. 什么是机器学习的分类算法?【K-近邻算法(KNN)、交叉验证、朴素贝叶斯算法、决策树、随机森林】

    1.K-近邻算法(KNN) 1.1 定义 (KNN,K-NearestNeighbor) 如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类 ...

  6. 【机器学*】k*邻算法-01

    k临*算法(解决分类问题): 已知数据集,以及该数据对应类型 给出一个数据x,在已知数据集中选择最接*x的k条数据,根据这k条数据的类型判断x的类型 具体实现: from numpy import * ...

  7. <机器学习实战>读书笔记--k邻近算法KNN

    k邻近算法的伪代码: 对未知类别属性的数据集中的每个点一次执行以下操作: (1)计算已知类别数据集中的点与当前点之间的距离: (2)按照距离递增次序排列 (3)选取与当前点距离最小的k个点 (4)确定 ...

  8. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  9. K-近邻算法kNN

    K-近邻算法(k-Nearest Neighbor,简称kNN)采用测量不同特征值之间的距离方法进行分类,是一种常用的监督学习方法,其工作机制很简单:给定测试样本,基于某种距离亮度找出训练集中与其靠近 ...

  10. k邻近算法(KNN)实例

    一 k近邻算法原理 k近邻算法是一种基本分类和回归方法. 原理:K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类,就把该输入实 ...

随机推荐

  1. 通过yum安装 memcache

    . 通过yum安装 复制代码代码如下: yum -y install memcached#安装完成后执行:memcached -h#出现memcached帮助信息说明安装成功 2. 加入启动服务 复制 ...

  2. Oracle基础入门

    说明:钓鱼君昨天在网上找到一份oracle项目实战的文档,粗略看了一下大致内容,感觉自己很多知识不够扎实,便跟着文档敲了一遍,目前除了机械性代码没有实现外,主要涉及知识:创建表空间.创建用户.给用户赋 ...

  3. 云原生PaaS平台通过插件整合SkyWalking,实现APM即插即用

    一. 简介 SkyWalking 是一个开源可观察性平台,用于收集.分析.聚合和可视化来自服务和云原生基础设施的数据.支持分布式追踪.性能指标分析.应用和服务依赖分析等:它是一种现代 APM,专为云原 ...

  4. 日常Java 2021/10/18

    Vecter类实现了一个动态数组,不同于ArrayList的是,Vecter是同步访问的, Vecter主要用在事先不知道数组的大小或可以改变大小的数组 Vecter类支持多种构造方法:Vecter( ...

  5. vue-baidu-map相关随笔

    一,使用vue-baidu-map 1.下载相关包依赖 npm i vue-baidu-map   2.在main.js中import引入相关包依赖,在main.js中添加如下代码: import B ...

  6. MySQL压力测试工具

    一.MySQL自带的压力测试工具--Mysqlslap mysqlslap是mysql自带的基准测试工具,该工具查询数据,语法简单,灵活容易使用.该工具可以模拟多个客户端同时并发的向服务器发出查询更新 ...

  7. 【Elasticsearch-Java】Java客户端搭建

    Elasticsearch Java高级客户端   1.  概述 Java REST Client 有两种风格: Java Low Level REST Client :用于Elasticsearch ...

  8. Output of C++ Program | Set 8

    Predict the output of following C++ programs. Question 1 1 #include<iostream> 2 using namespac ...

  9. 增大Oracle Virtualbox的磁盘空间

    https://blog.csdn.net/hiyachen/article/details/102131823 背景 在virtualbox中装好Linux以及Application之后,发现硬盘空 ...

  10. 接口测试 python+PyCharm 环境搭建

    1.配置Python环境变量 a:我的电脑->属性->高级系统设置->环境变量->系统变量中的PATH变量. 变量名:PATH      修改变量值为:;C:\Python27 ...