【机器学习】k-近邻算法以及算法实例
机器学习中常常要用到分类算法,在诸多的分类算法中有一种算法名为k-近邻算法,也称为kNN算法。
一、kNN算法的工作原理
二、适用情况
三、算法实例及讲解
---1.收集数据
---2.准备数据
---3.设计算法分析数据
---4.测试算法
一、kNN算法的工作原理
官方解释:存在一个样本数据集,也称作训练样本集,并且样本中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系,输入没有标签的新数据后,将新数据的每个特征与样本集中的数据对应的特征进行比较,然后算法提取样本集中特征最相似的数据(最近邻)的分类标签。一般来说,我们只选择样本集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数,最后,选择k个最相似的数据中出现次数最多的分类,作为新数据的分类。
我的理解:k-近邻算法就是根据“新数据的分类取决于它的邻居”进行的,比如邻居中大多数都是退伍军人,那么这个人也极有可能是退伍军人。而算法的目的就是先找出它的邻居,然后分析这几位邻居大多数的分类,极有可能就是它本省的分类。
二、适用情况
优点:精度高,对异常数据不敏感(你的类别是由邻居中的大多数决定的,一个异常邻居并不能影响太大),无数据输入假定;
缺点:计算发杂度高(需要计算新的数据点与样本集中每个数据的“距离”,以判断是否是前k个邻居),空间复杂度高(巨大的矩阵);
适用数据范围:数值型(目标变量可以从无限的数值集合中取值)和标称型(目标变量只有在有限目标集中取值)。
三、算法实例及讲解
例子中的案例摘《机器学习实战》一书中的,代码例子是用python编写的(需要matplotlib和numpy库),不过重在算法,只要算法明白了,用其他语言都是可以写出来的:
海伦一直使用在线约会网站寻找合适自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:1.不喜欢的人(以下简称1);2.魅力一般的人(以下简称2) ;3.极具魅力的人(以下简称3)
尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的分类。她觉得可以在周一到周五约会哪些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好的帮助她将匹配对象划分到确切的分类中。此外海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的归类。
我们先提取一下这个案例的目标:根据一些数据信息,对指定人选进行分类(1或2或3)。为了使用kNN算法达到这个目标,我们需要哪些信息?前面提到过,就是需要样本数据,仔细阅读我们发现,这些样本数据就是“海伦还收集了一些约会网站未曾记录的数据信息”。好的,下面我们就开始吧!
----1.收集数据
海伦收集的数据是记录一个人的三个特征:每年获得的飞行常客里程数;玩视频游戏所消耗的时间百分比;每周消费的冰淇淋公升数。数据是txt格式文件,如下图,前三列依次是三个特征,第四列是分类(1:不喜欢的人,2:魅力一般的人,3:极具魅力的人),每一行代表一个人。
数据文档的下载链接是:http://pan.baidu.com/s/1jG7n4hS
----2.准备数据
何为准备数据?之前收集到了数据,放到了txt格式的文档中了,看起来也比较规整,但是计算机并不认识啊。计算机需要从txt文档中读取数据,并把数据进行格式化,也就是说存到矩阵中,用矩阵来承装这些数据,这样才能使用计算机处理。
需要两个矩阵:一个承装三个特征数据,一个承装对应的分类。于是,我们定义一个函数,函数的输入时数据文档(txt格式),输出为两个矩阵。
代码如下:
- def file2matrix(filename):
- fr = open(filename)
- numberOfLines = len(fr.readlines())
- returnMat = zeros((numberOfLines, 3))
- classLabelVector = []
- fr = open(filename)
- index = 0
- for line in fr.readlines():
- line = line.strip()
- listFromLine = line.split('\t')
- returnMat[index, :] = listFromLine[0:3]
- classLabelVector.append(int(listFromLine[-1]))
- index += 1
- return returnMat, classLabelVector
简要解读代码:首先打开文件,读取文件的行数,然后初始化之后要返回的两个矩阵(returnMat、classLabelsVector),然后进入循环,将每行的数据各就各位分配给returnMat和classLabelsVector。
----3.设计算法分析数据
k-近邻算法的目的就是找到新数据的前k个邻居,然后根据邻居的分类来确定该数据的分类。
首先要解决的问题,就是什么是邻居?当然就是“距离”近的了,不同人的距离怎么确定?这个有点抽象,不过我们有每个人的3个特征数据。每个人可以使用这三个特征数据来代替这个人——三维点。比如样本的第一个人就可以用(40920, 8.326976, 0.953952)来代替,并且他的分类是3。那么此时的距离就是点的距离:
A点(x1, x2, x3),B点(y1, y2, y3),这两个点的距离就是:(x1-y1)^2+(x2-y2)^2+(x3-y3)^2的平方根。求出新数据与样本中每个点的距离,然后进行从小到大排序,前k位的就是k-近邻,然后看看这k位近邻中占得最多的分类是什么,也就获得了最终的答案。
这个处理过程也是放到一个函数里的,代码如下:
- def classify0(inX, dataSet, labels, k):
- dataSetSize = dataSet.shape[0]
- diffMat = tile(inX, (dataSetSize,1)) - dataSet
- sqDiffMat = diffMat**2
- sqDistances = sqDiffMat.sum(axis=1)
- distances = sqDistances**0.5
- sortedDistIndicies = distances.argsort()
- classCount={}
- for i in range(k):
- voteIlabel = labels[sortedDistIndicies[i]]
- classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
- sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
- return sortedClassCount[0][0]
简要解读代码:该函数的4个参数分别为新数据的三个特征inX、样本数据特征集(上一个函数的返回值)、样本数据分类(上一个函数的返回值)、k,函数返回位新数据的分类。第二行dataSetSize获取特征集矩阵的行数,第三行为新数据与样本各个数据的差值,第四行取差值去平方,之后就是再取和,然后平方根。代码中使用的排序函数都是python自带的。
好了,现在我们可以分析数据了,不过,有一点不知道大家有没有注意,我们回到那个数据集,第一列代表的特征数值远远大于其他两项特征,这样在求距离的公式中就会占很大的比重,致使两点的距离很大程度上取决于这个特征,这当然是不公平的,我们需要的三个特征都均平地决定距离,所以我们要对数据进行处理,希望处理之后既不影响相对大小又可以不失公平:
这种方法叫做,归一化数值,通过这种方法可以把每一列的取值范围划到0~1或-1~1:,处理的公式为:
newValue = (oldValue - min)/(max - min)
归一化数值的函数代码为:
- def autoNorm(dataSet):
- minVals = dataSet.min(0)
- maxVals = dataSet.max(0)
- ranges = maxVals - minVals
- normDataSet = zeros(shape(dataSet))
- m = dataSet.shape[0]
- normDataSet = dataSet - tile(minVals, (m, 1))
- normDataSet = normDataSet / tile(ranges, (m, 1))
- return normDataSet, ranges, minVals
---4.测试算法
经过了格式化数据、归一化数值,同时我们也已经完成kNN核心算法的函数,现在可以测试了,测试代码为:
- def datingClassTest():
- hoRatio = 0.10
- datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
- normMat, ranges, minVals = autoNorm(datingDataMat)
- m = normMat.shape[0]
- numTestVecs = int(m * hoRatio)
- errorCount = 0.0
- for i in range(numTestVecs):
- classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
- print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
- if (classifierResult != datingLabels[i]): errorCount += 1.0
- print "the total error rate is: %f" % (errorCount / float(numTestVecs))
通过测试代码我们可以在回忆一下这个例子的整体过程:
- 读取txt文件,提取里面的数据到datingDataMat、datingLabels;
- 归一化数据,得到归一化的数据矩阵;
- 测试数据不止一个,这里需要一个循环,依次对每个测试数据进行分类。
代码中大家可能不太明白hoRatio是什么。注意,这里的测试数据并不是另外一批数据而是之前的数据集里的一部分,这样我们可以把算法得到的结果和原本的分类进行对比,查看算法的准确度。在这里,海伦提供的数据集又1000行,我们把前100行作为测试数据,后900行作为样本数据集,现在大家应该可以明白hoRatio是什么了吧。
整体的代码:
- from numpy import *
- import operator
- def classify0(inX, dataSet, labels, k):
- dataSetSize = dataSet.shape[0]
- diffMat = tile(inX, (dataSetSize,1)) - dataSet
- sqDiffMat = diffMat**2
- sqDistances = sqDiffMat.sum(axis=1)
- distances = sqDistances**0.5
- sortedDistIndicies = distances.argsort()
- classCount={}
- for i in range(k):
- voteIlabel = labels[sortedDistIndicies[i]]
- classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
- sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
- return sortedClassCount[0][0]
- def file2matrix(filename):
- fr = open(filename)
- numberOfLines = len(fr.readlines())
- returnMat = zeros((numberOfLines, 3))
- classLabelVector = []
- fr = open(filename)
- index = 0
- for line in fr.readlines():
- line = line.strip()
- listFromLine = line.split('\t')
- returnMat[index, :] = listFromLine[0:3]
- classLabelVector.append(int(listFromLine[-1]))
- index += 1
- return returnMat, classLabelVector
- def autoNorm(dataSet):
- minVals = dataSet.min(0)
- maxVals = dataSet.max(0)
- ranges = maxVals - minVals
- normDataSet = zeros(shape(dataSet))
- m = dataSet.shape[0]
- normDataSet = dataSet - tile(minVals, (m, 1))
- normDataSet = normDataSet / tile(ranges, (m, 1))
- return normDataSet, ranges, minVals
- def datingClassTest():
- hoRatio = 0.10
- datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
- normMat, ranges, minVals = autoNorm(datingDataMat)
- m = normMat.shape[0]
- numTestVecs = int(m * hoRatio)
- errorCount = 0.0
- for i in range(numTestVecs):
- classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
- print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
- if (classifierResult != datingLabels[i]): errorCount += 1.0
- print "the total error rate is: %f" % (errorCount / float(numTestVecs))
运行一下代码,这里我使用的是ipython:
最后的错误率为0.05。
【机器学习】k-近邻算法以及算法实例的更多相关文章
- 机器学习-K近邻(KNN)算法详解
一.KNN算法描述 KNN(K Near Neighbor):找到k个最近的邻居,即每个样本都可以用它最接近的这k个邻居中所占数量最多的类别来代表.KNN算法属于有监督学习方式的分类算法,所谓K近 ...
- [机器学习] k近邻算法
算是机器学习中最简单的算法了,顾名思义是看k个近邻的类别,测试点的类别判断为k近邻里某一类点最多的,少数服从多数,要点摘录: 1. 关键参数:k值 && 距离计算方式 &&am ...
- 机器学习--K近邻 (KNN)算法的原理及优缺点
一.KNN算法原理 K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法. 它的基本思想是: 在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对 ...
- Python3入门机器学习 - k近邻算法
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一.所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代 ...
- [机器学习]-K近邻-最简单的入门实战例子
本篇文章分为两个部分,前一部分主要简单介绍K近邻,后一部分是一个例子 第一部分--K近邻简介 从字面意思就可以容易看出,所谓的K近邻,就是找到某个样本距离(这里的距离可以是欧式距离,曼哈顿距离,切比雪 ...
- 机器学习—K近邻
一.算法原理 还是图片格式~ 二.sklearn实现 import pandas as pd import numpy as np import matplotlib.pyplot as plt im ...
- 机器学习(1)——K近邻算法
KNN的函数写法 import numpy as np from math import sqrt from collections import Counter def KNN_classify(k ...
- 机器学习 Python实践-K近邻算法
机器学习K近邻算法的实现主要是参考<机器学习实战>这本书. 一.K近邻(KNN)算法 K最近邻(k-Nearest Neighbour,KNN)分类算法,理解的思路是:如果一个样本在特征空 ...
- SIGAI机器学习第七集 k近邻算法
讲授K近邻思想,kNN的预测算法,距离函数,距离度量学习,kNN算法的实际应用. KNN是有监督机器学习算法,K-means是一个聚类算法,都依赖于距离函数.没有训练过程,只有预测过程. 大纲: k近 ...
- 1.K近邻算法
(一)K近邻算法基础 K近邻(KNN)算法优点 思想极度简单 应用数学知识少(近乎为0) 效果好 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 图解K近邻算法 上图是以 ...
随机推荐
- 洛谷P3396 哈希冲突 (分块)
洛谷P3396 哈希冲突 题目背景 此题约为NOIP提高组Day2T2难度. 题目描述 众所周知,模数的hash会产生冲突.例如,如果模的数p=7,那么4和11便冲突了. B君对hash冲突很感兴趣. ...
- maven的tomcat插件如何进行debug调试
利用maven来部署工程时,一般采用的是tomcat插件,使项目在tomcat上面运行,那么这个debug调试是如何进行呢? 我们在调试的时候问题: 会提示找不到资源,那么如何进行修改呢,方法两个: ...
- 关于我之前写的修改Windows系统Dos下显示的用户名之再修改测试
最近看到蛮多网友反映,自己修改Dos下用户名后出现了很多的问题--今天抽了时间,再次修改测试... ================= 提前说明:我自己修改了很多次没发现任何问题,<为避免修改可 ...
- ExecuteNonQuery,ExecuteReader,ExecuteScalar 区别
ExecuteNonQuery方法 :执行非查询SQL操作,包括增insert.删delete.改update ExcuteReader方法 :执行查询,返回DataReader,通过DataRead ...
- Spring Filter过滤器,Spring拦截未登录用户权限限制
转载自:http://pouyang.iteye.com/blog/695429 实现的功能:判断用户是否已登录,未登录用户禁止访问任何页面或action,自动跳转到登录页面. 比较好的做法是不管什 ...
- Fiddler进行模拟POST、PUT提交数据注意点
1.请求头要加 Accept: application/xml Content-Type: application/json 2.地址栏url地址后不要忘记加“/” 3.POST和PUT的对象参数都是 ...
- 搜索:BFS
如果题目真的要考察宽度优先搜索,那么这类题目往往具有比较大的编码难度,换个说法,就是细枝末节特别多,状态特别复杂.. 剥茧抽丝,这里以一个比较“裸”的BFS作为例子,了解一下实现BFS的一些规范. 直 ...
- C11内存管理之道:智能指针
1.shared_ptr共享智能指针 std::shared_ptr使用引用计数,每个shared_ptr的拷贝都指向相同的内存,在最后一个shared_ptr析构的时候,内存才会释放. 1.1 基本 ...
- struts2常用标签之数据标签
数据标签1 property标签 property标签的主要属性: value:用来获取值的OGNL表达式,如果value属性值没有指定,那么将会被设定为top,也就是返回位于值栈最顶端的对象. ...
- sscanf的用法
sscanf也太好用了8我竟然一直都不知道qaq #include<cstdio> #include<cstdlib> #include<cstring> #inc ...