【机器学习】k-近邻算法以及算法实例
机器学习中常常要用到分类算法,在诸多的分类算法中有一种算法名为k-近邻算法,也称为kNN算法。
一、kNN算法的工作原理
二、适用情况
三、算法实例及讲解
---1.收集数据
---2.准备数据
---3.设计算法分析数据
---4.测试算法
一、kNN算法的工作原理
官方解释:存在一个样本数据集,也称作训练样本集,并且样本中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系,输入没有标签的新数据后,将新数据的每个特征与样本集中的数据对应的特征进行比较,然后算法提取样本集中特征最相似的数据(最近邻)的分类标签。一般来说,我们只选择样本集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数,最后,选择k个最相似的数据中出现次数最多的分类,作为新数据的分类。
我的理解:k-近邻算法就是根据“新数据的分类取决于它的邻居”进行的,比如邻居中大多数都是退伍军人,那么这个人也极有可能是退伍军人。而算法的目的就是先找出它的邻居,然后分析这几位邻居大多数的分类,极有可能就是它本省的分类。
二、适用情况
优点:精度高,对异常数据不敏感(你的类别是由邻居中的大多数决定的,一个异常邻居并不能影响太大),无数据输入假定;
缺点:计算发杂度高(需要计算新的数据点与样本集中每个数据的“距离”,以判断是否是前k个邻居),空间复杂度高(巨大的矩阵);
适用数据范围:数值型(目标变量可以从无限的数值集合中取值)和标称型(目标变量只有在有限目标集中取值)。
三、算法实例及讲解
例子中的案例摘《机器学习实战》一书中的,代码例子是用python编写的(需要matplotlib和numpy库),不过重在算法,只要算法明白了,用其他语言都是可以写出来的:
海伦一直使用在线约会网站寻找合适自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:1.不喜欢的人(以下简称1);2.魅力一般的人(以下简称2) ;3.极具魅力的人(以下简称3)
尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的分类。她觉得可以在周一到周五约会哪些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好的帮助她将匹配对象划分到确切的分类中。此外海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的归类。
我们先提取一下这个案例的目标:根据一些数据信息,对指定人选进行分类(1或2或3)。为了使用kNN算法达到这个目标,我们需要哪些信息?前面提到过,就是需要样本数据,仔细阅读我们发现,这些样本数据就是“海伦还收集了一些约会网站未曾记录的数据信息”。好的,下面我们就开始吧!
----1.收集数据
海伦收集的数据是记录一个人的三个特征:每年获得的飞行常客里程数;玩视频游戏所消耗的时间百分比;每周消费的冰淇淋公升数。数据是txt格式文件,如下图,前三列依次是三个特征,第四列是分类(1:不喜欢的人,2:魅力一般的人,3:极具魅力的人),每一行代表一个人。
数据文档的下载链接是:http://pan.baidu.com/s/1jG7n4hS
----2.准备数据
何为准备数据?之前收集到了数据,放到了txt格式的文档中了,看起来也比较规整,但是计算机并不认识啊。计算机需要从txt文档中读取数据,并把数据进行格式化,也就是说存到矩阵中,用矩阵来承装这些数据,这样才能使用计算机处理。
需要两个矩阵:一个承装三个特征数据,一个承装对应的分类。于是,我们定义一个函数,函数的输入时数据文档(txt格式),输出为两个矩阵。
代码如下:
def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines())
returnMat = zeros((numberOfLines, 3))
classLabelVector = []
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
简要解读代码:首先打开文件,读取文件的行数,然后初始化之后要返回的两个矩阵(returnMat、classLabelsVector),然后进入循环,将每行的数据各就各位分配给returnMat和classLabelsVector。
----3.设计算法分析数据
k-近邻算法的目的就是找到新数据的前k个邻居,然后根据邻居的分类来确定该数据的分类。
首先要解决的问题,就是什么是邻居?当然就是“距离”近的了,不同人的距离怎么确定?这个有点抽象,不过我们有每个人的3个特征数据。每个人可以使用这三个特征数据来代替这个人——三维点。比如样本的第一个人就可以用(40920, 8.326976, 0.953952)来代替,并且他的分类是3。那么此时的距离就是点的距离:
A点(x1, x2, x3),B点(y1, y2, y3),这两个点的距离就是:(x1-y1)^2+(x2-y2)^2+(x3-y3)^2的平方根。求出新数据与样本中每个点的距离,然后进行从小到大排序,前k位的就是k-近邻,然后看看这k位近邻中占得最多的分类是什么,也就获得了最终的答案。
这个处理过程也是放到一个函数里的,代码如下:
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
简要解读代码:该函数的4个参数分别为新数据的三个特征inX、样本数据特征集(上一个函数的返回值)、样本数据分类(上一个函数的返回值)、k,函数返回位新数据的分类。第二行dataSetSize获取特征集矩阵的行数,第三行为新数据与样本各个数据的差值,第四行取差值去平方,之后就是再取和,然后平方根。代码中使用的排序函数都是python自带的。
好了,现在我们可以分析数据了,不过,有一点不知道大家有没有注意,我们回到那个数据集,第一列代表的特征数值远远大于其他两项特征,这样在求距离的公式中就会占很大的比重,致使两点的距离很大程度上取决于这个特征,这当然是不公平的,我们需要的三个特征都均平地决定距离,所以我们要对数据进行处理,希望处理之后既不影响相对大小又可以不失公平:
这种方法叫做,归一化数值,通过这种方法可以把每一列的取值范围划到0~1或-1~1:,处理的公式为:
newValue = (oldValue - min)/(max - min)
归一化数值的函数代码为:
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1))
return normDataSet, ranges, minVals
---4.测试算法
经过了格式化数据、归一化数值,同时我们也已经完成kNN核心算法的函数,现在可以测试了,测试代码为:
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is: %f" % (errorCount / float(numTestVecs))
通过测试代码我们可以在回忆一下这个例子的整体过程:
- 读取txt文件,提取里面的数据到datingDataMat、datingLabels;
- 归一化数据,得到归一化的数据矩阵;
- 测试数据不止一个,这里需要一个循环,依次对每个测试数据进行分类。
代码中大家可能不太明白hoRatio是什么。注意,这里的测试数据并不是另外一批数据而是之前的数据集里的一部分,这样我们可以把算法得到的结果和原本的分类进行对比,查看算法的准确度。在这里,海伦提供的数据集又1000行,我们把前100行作为测试数据,后900行作为样本数据集,现在大家应该可以明白hoRatio是什么了吧。
整体的代码:
from numpy import *
import operator def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0] def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines())
returnMat = zeros((numberOfLines, 3))
classLabelVector = []
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1))
return normDataSet, ranges, minVals def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is: %f" % (errorCount / float(numTestVecs))
运行一下代码,这里我使用的是ipython:
最后的错误率为0.05。
【机器学习】k-近邻算法以及算法实例的更多相关文章
- 机器学习-K近邻(KNN)算法详解
一.KNN算法描述 KNN(K Near Neighbor):找到k个最近的邻居,即每个样本都可以用它最接近的这k个邻居中所占数量最多的类别来代表.KNN算法属于有监督学习方式的分类算法,所谓K近 ...
- [机器学习] k近邻算法
算是机器学习中最简单的算法了,顾名思义是看k个近邻的类别,测试点的类别判断为k近邻里某一类点最多的,少数服从多数,要点摘录: 1. 关键参数:k值 && 距离计算方式 &&am ...
- 机器学习--K近邻 (KNN)算法的原理及优缺点
一.KNN算法原理 K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法. 它的基本思想是: 在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对 ...
- Python3入门机器学习 - k近邻算法
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一.所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代 ...
- [机器学习]-K近邻-最简单的入门实战例子
本篇文章分为两个部分,前一部分主要简单介绍K近邻,后一部分是一个例子 第一部分--K近邻简介 从字面意思就可以容易看出,所谓的K近邻,就是找到某个样本距离(这里的距离可以是欧式距离,曼哈顿距离,切比雪 ...
- 机器学习—K近邻
一.算法原理 还是图片格式~ 二.sklearn实现 import pandas as pd import numpy as np import matplotlib.pyplot as plt im ...
- 机器学习(1)——K近邻算法
KNN的函数写法 import numpy as np from math import sqrt from collections import Counter def KNN_classify(k ...
- 机器学习 Python实践-K近邻算法
机器学习K近邻算法的实现主要是参考<机器学习实战>这本书. 一.K近邻(KNN)算法 K最近邻(k-Nearest Neighbour,KNN)分类算法,理解的思路是:如果一个样本在特征空 ...
- SIGAI机器学习第七集 k近邻算法
讲授K近邻思想,kNN的预测算法,距离函数,距离度量学习,kNN算法的实际应用. KNN是有监督机器学习算法,K-means是一个聚类算法,都依赖于距离函数.没有训练过程,只有预测过程. 大纲: k近 ...
- 1.K近邻算法
(一)K近邻算法基础 K近邻(KNN)算法优点 思想极度简单 应用数学知识少(近乎为0) 效果好 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 图解K近邻算法 上图是以 ...
随机推荐
- 在CentOS 6.5 中安装JDK 1.7 + Eclipse并配置opencv的java开发环境(二)
一.安装JDK 1.7 1. 卸载OpenJDK rpm -qa | grep java rpm -e --nodeps java-1.6.0-openjdk-1.6.0.0-1.50.1.11.5. ...
- beego入门小坑
刚接触beego,按照官网的文档操作,始终发现在orm操作数据的时候提示表不存在,数据库连接设置都没问题 "0 Error 1146: Table 'beego.archives' does ...
- JAVA 枚举单例模式
1.枚举单例模式的实现 public enum Singleton { INSTANCE { @Override protected void read() { System.out.println ...
- Android实现自动定位城市并获取天气信息
定位实现代码: <span style="font-size:14px;">import java.io.IOException; import java.util.L ...
- 基于Mysql数据库亿级数据下的分库分表方案
移动互联网时代,海量的用户数据每天都在产生,基于用户使用数据的用户行为分析等这样的分析,都需要依靠数据都统计和分析,当数据量小时,问题没有暴露出来,数据库方面的优化显得不太重要,一旦数据量越来越大时, ...
- Java 异常(Java Exception)
Java异常 异常指不期而至的各种状况,如:文件找不到.网络连接失败.非法参数等.异常是一个事件,它发生在程序运行期间,干扰了正常的指令流程.Java通 过API中Throwable类的众多子类 ...
- C# 从串口读取数据
最近要做系统集成,需要从串口读取数据,随学习一下相关知识: 以下是从串口读取数据 public static void Main() { SerialPort mySerialPort = new S ...
- IIS 搭建
1. 在打开程序功能里面,点击IIS安装.注意要选择适当的各种有用的服务.例如默认文档就需要安装非IIS下面的选项. 2. IIS部署网站可以参考网上的步骤.会遇到500处理程序“Extensionl ...
- 【BZOJ1926】【SDOI2010】粟粟的书架 [主席树]
粟粟的书架 Time Limit: 30 Sec Memory Limit: 552 MB[Submit][Status][Discuss] Description 幸福幼儿园 B29 班的粟粟是一 ...
- 每个 Java 开发者都应该知道的 5 个注解
自 JDK5 推出以来,注解已成为Java生态系统不可缺少的一部分.虽然开发者为Java框架(例如Spring的@Autowired)开发了无数的自定义注解,但编译器认可的一些注解非常重要. 在本文中 ...