K-近邻算法python实现
内容主要来源于机器学习实战这本书。加上自己的理解。
1.KNN算法的简单描写叙述
K近期邻(k-Nearest Neighbor。KNN)分类算法能够说是最简单的机器学习算法了。
它採用測量不同特征值之间的距离方法进行分类。
它的思想非常easy:假设一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别。则该样本也属于这个类别。
下图是大家引用的一个最经典演示样例图。
比方上面这个图,我们有两类数据,各自是蓝色方块和红色三角形,他们分布在一个上图的二维中间中。
那么假如我们有一个绿色圆圈这个数据,须要推断这个数据是属于蓝色方块这一类。还是与红色三角形同类。怎么做呢?我们先把离这个绿色圆圈近期的几个点找到。因为我们觉得离绿色圆圈近期的才对它的类别有推断的帮助。那究竟要用多少个来推断呢?这个个数就是k了。
假设k=3。就表示我们选择离绿色圆圈近期的3个点来推断,因为红色三角形所占比例为2/3。所以我们觉得绿色圆是和红色三角形同类。假设k=5。因为蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。从这里能够看到。k的值选取非常重要的。
KNN算法中。所选择的邻居都是已经正确分类的对象。
该方法在定类决策上仅仅根据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
因为KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的。因此对于类域的交叉或重叠较多的待分样本集来说。KNN方法较其它方法更为适合。
该算法在分类时有个基本的不足是。当样本不平衡时。如一个类的样本容量非常大。而其它类样本容量非常小时。有可能导致当输入一个新样本时。该样本的K个邻居中大容量类的样本占多数。
因此能够採用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的还有一个不足之处是计算量较大。由于对每个待分类的文本都要计算它到全体已知样本的距离。才干求得它的K个近期邻点。
眼下经常使用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。
该算法比較适用于样本容量比較大的类域的自己主动分类,而那些样本容量较小的类域採用这样的算法比較easy产生误分。
总的来说就是我们已经存在了一个带标签的数据比对库,然后输入没有标签的新数据后。将新数据的每一个特征与样本集中数据相应的特征进行比較。然后算法提取样本集中特征最相似(近期邻)的分类标签。一般来说,仅仅选择样本数据库中前k个最相似的数据。
最后,选择k个最相似数据中出现次数最多的分类。
其算法描写叙述例如以下:
1)计算已知类别数据集中的点与当前点之间的距离;
2)依照距离递增次序排序;
3)选取与当前点距离最小的k个点;
4)确定前k个点所在类别的出现频率;
5)返回前k个点出现频率最高的类别作为当前点的预測分类。
二:python程序部分
2.1 python导入数据
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
创建了数据集和标签。
依据上面说到的算法描写叙述中五个步骤K-近邻算法核心部分程序:
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet # tile :construct array by repeating inX dataSetSize times
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5 # get distance
sortedDistIndicies = distances.argsort() # return ordered array's index
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
不知道是不是编码设置问题,凝视没法写成中文,仅仅能是英文。
K-近邻算法书上应用到了改进约会站点的配对效果上面详细流程:
准备数据部分:从文本文件里解析数据,文本中说到3种特征:飞行里程、玩游戏时间、消费冰淇淋数量。我不知道作者为什么选择这三种特征,好像跟约会配对没什么毛关系。
这部分用到非常多numpy中处理矩阵的函数。
def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines()) #get the number of lines in the file
returnMat = zeros((numberOfLines,3)) #prepare matrix to return
classLabelVector = [] #prepare labels return
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip() # delete character like tab or backspace
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3] # get 3 features
classLabelVector.append(int(listFromLine[-1])) # get classify result
index += 1
return returnMat,classLabelVector
处理数据中涉及到数据值的归一化。
意思就是说上面约会配对有三个特征,可是会发现飞行距离这个数值远远大于其他两个,为了体现3个特征同样的影响力,对数据进行归一化。
def autoNorm(dataSet):
minVals = dataSet.min(0) # select least value in column
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide
return normDataSet, ranges, minVals
另外一个应用是在手写识别系统。
类似于前面约会站点应用,准备数据时须要进行图像到向量转换,然后调用K-近邻的核心算法实现。
以下是全部的代码综合和測试代码:主函数里加入了一些matplotlib绘图測试代码
'''
kNN: k Nearest Neighbors Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number) Output: the most popular class label '''
from numpy import *
import operator
from os import listdir
import matplotlib
import matplotlib.pyplot as plt def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet # tile :construct array by repeating inX dataSetSize times
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5 # get distance
sortedDistIndicies = distances.argsort() # return ordered array's index
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0] def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines()) #get the number of lines in the file
returnMat = zeros((numberOfLines,3)) #prepare matrix to return
classLabelVector = [] #prepare labels return
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip() # delete character like tab or backspace
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3] # get 3 features
classLabelVector.append(int(listFromLine[-1])) # get classify result
index += 1
return returnMat,classLabelVector def autoNorm(dataSet):
minVals = dataSet.min(0) # select least value in column
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide
return normDataSet, ranges, minVals def datingClassTest():
hoRatio = 0.50 #hold out 10%
datingDataMat,datingLabels = file2matrix('E:\PythonMachine Learning in Action\datingTestSet2.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
print m
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is: %f" % (errorCount/float(numTestVecs))
print errorCount
def classifyperson():
resultList = ['not at all','in small doses','in large doses']
percentTats = float(raw_input('percentage time spent on games ?'))
ffmiles = float(raw_input('frequent flier miles per year? '))
iceCream = float(raw_input('liters of ice cream consumed each year?'))
datingDataMat,datingLabels = file2matrix('E:\PythonMachine Learning in Action\datingTestSet2.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffmiles,percentTats,iceCream])
classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
print "your probably like this person :" ,\
resultList[classifierResult-1]
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('E:/PythonMachine Learning in Action/trainingDigits') #load the training set
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('E:/PythonMachine Learning in Action/trainingDigits/%s' % fileNameStr)
testFileList = listdir('E:/PythonMachine Learning in Action/testDigits') #iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('E:/PythonMachine Learning in Action/testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount/float(mTest)) if __name__=='__main__':
#classifyperson()
datingClassTest()
dataSet, labels = createDataSet()
testX = array([1.2, 1.0])
k = 3
outputLabel = classify0(testX, dataSet, labels, 3)
print "Your input is:", testX, "and classified to class: ", outputLabel testX = array([0.1, 0.3])
outputLabel = classify0(testX, dataSet, labels, 3)
print "Your input is:", testX, "and classified to class: ", outputLabel
handwritingClassTest()
datingDataMat,datingLabels = file2matrix('E:\PythonMachine Learning in Action\datingTestSet2.txt')
print datingDataMat
print datingLabels[0:20]
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show()
这里要注意:
trainingFileList = listdir('E:/PythonMachine Learning in Action/trainingDigits')
调用这个函数时路径写法。假设不想复杂指定路径简单就把目录和knn.py文件放在一起。
K-近邻算法python实现的更多相关文章
- k近邻算法python实现 -- 《机器学习实战》
''' Created on Nov 06, 2017 kNN: k Nearest Neighbors Input: inX: vector to compare to existing datas ...
- 用Python从零开始实现K近邻算法
KNN算法的定义: KNN通过测量不同样本的特征值之间的距离进行分类.它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别.K通 ...
- python 机器学习(二)分类算法-k近邻算法
一.什么是K近邻算法? 定义: 如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. 来源: KNN算法最早是由Cover和Hart提 ...
- 机器学习实战笔记--k近邻算法
#encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...
- 机器学习之K近邻算法(KNN)
机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...
- 机器学习03:K近邻算法
本文来自同步博客. P.S. 不知道怎么显示数学公式以及排版文章.所以如果觉得文章下面格式乱的话请自行跳转到上述链接.后续我将不再对数学公式进行截图,毕竟行内公式截图的话排版会很乱.看原博客地址会有更 ...
- 机器学习——KNN算法(k近邻算法)
一 KNN算法 1. KNN算法简介 KNN(K-Nearest Neighbor)工作原理:存在一个样本数据集合,也称为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分 ...
- [机器学习] k近邻算法
算是机器学习中最简单的算法了,顾名思义是看k个近邻的类别,测试点的类别判断为k近邻里某一类点最多的,少数服从多数,要点摘录: 1. 关键参数:k值 && 距离计算方式 &&am ...
- 机器学习实战 - python3 学习笔记(一) - k近邻算法
一. 使用k近邻算法改进约会网站的配对效果 k-近邻算法的一般流程: 收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据.一般来讲,数据放在txt文本文件中,按照一定的格式进 ...
- 机器学习:k-NN算法(也叫k近邻算法)
一.kNN算法基础 # kNN:k-Nearest Neighboors # 多用于解决分裂问题 1)特点: 是机器学习中唯一一个不需要训练过程的算法,可以别认为是没有模型的算法,也可以认为训练数据集 ...
随机推荐
- (step6.1.3)hdu 1875(畅通工程再续——最小生成树)
题目大意:本题是中文题,可以直接在OJ上看 解题思路:最小生成树 1)本题的关键在于把二维的点转化成一维的点 for (i = 0; i < n; ++i) { scanf("%d%d ...
- Java中取小数点后两位(四种方法)
摘自http://irobot.iteye.com/blog/285537 Java中取小数点后两位(四种方法) 一 Long是长整型,怎么有小数,是double吧 java.text.D ...
- [转]使用Navicat for Oracle工具连接oracle的
使用Navicat for Oracle工具连接oracle的 这是一款oracle的客户端的图形化管理和开发工具,对于许多的数据库都有支持.之前用过 Navicat for sqlserver,感觉 ...
- C#主要字典集合性能对比[转]
A post I made a couple days ago about the side-effect of concurrency (the concurrent collections in ...
- Android - 使用Intent来启动Activity
本文地址: http://blog.csdn.net/caroline_wendy/article/details/21455141 Intent 的用途是 绑定 应用程序组件, 并在应用程序之间进行 ...
- Oracle SQL函数之转换函数To_char汇总
TO_CHAR(x[[,c2],C3])[功能]将日期或数据转换为char数据类型[参数]x是一个date或number数据类型.c2为格式参数c3为NLS设置参数如果x为日期nlsparm=NLS_ ...
- filezilla无法连接linux服务器
问题描述: 响应: 220 (vsFTPd 2.2.2)命令: AUTH TLS错误: 无法连接到服务器状态: 已从服务器断开 排查步骤: 1 检查服务器IP地址.用户名.密码是否正确 2 在控制面板 ...
- OC准备知识
#import 与 #include区别 include完成头文件的导入,可能会导致头文件的相互引用和函数或变量的重复定义 为了解决这个问题 我们必须这样做 #ifndef Student_h #de ...
- 对获取config文件的appSettings节点简单封装
转:http://www.cnblogs.com/marvin/archive/2011/07/29/EfficiencyAppSetting.html C#的开发中,无论你是winform开发还是w ...
- MySQL 初学笔记 ② -- MySQL安装
1. Ubuntu安装 sudo apt-get install mysql-server //安装mysql服务 sudo apt-get install mysql-client // sudo ...