'''
Created on Nov 06, 2017
kNN: k Nearest Neighbors Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number) Output: the most popular class label @author: Liu Chuanfeng
'''
import operator
import numpy as np
import matplotlib.pyplot as plt
from os import listdir def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0] #数据预处理,将文件中数据转换为矩阵类型
def file2matrix(filename):
fr = open(filename)
arrayLines = fr.readlines()
numberOfLines = len(arrayLines)
returnMat = np.zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector #数据归一化处理:由于矩阵各列数据取值范围的巨大差异导致各列对计算结果的影响大小不一,需要归一化以保证相同的影响权重
def autoNorm(dataSet):
maxVals = dataSet.max(0)
minVals = dataSet.min(0)
ranges = maxVals - minVals
m = dataSet.shape[0]
normDataSet = (dataSet - np.tile(minVals, (m, 1))) / np.tile(ranges, (m, 1))
return normDataSet, ranges, minVals #约会网站测试代码
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifyResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print('theclassifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
if ( classifyResult != datingLabels[i]):
errorCount += 1.0
print ('the total error rate is: %.1f%%' % (errorCount/float(numTestVecs) * 100)) #约会网站预测函数
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input("percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, percentTats, iceCream])
classifyResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
print ("You will probably like this persoon:", resultList[classifyResult - 1]) #手写识别系统#============================================================================================================
#数据预处理:输入图片为32*32的文本类型,将其形状转换为1*1024
def img2vector(filename):
returnVect = np.zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
return returnVect #手写数字识别系统测试代码
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits')
m = len(trainingFileList)
trainingMat = np.zeros((m, 1024))
for i in range(m): #|
fileNameStr = trainingFileList[i] #|
fileName = fileNameStr.split('.')[0] #| 获取训练集路径下每一个文件,分割文件名,将第一个数字作为标签存储在hwLabels中
classNumber = int(fileName.split('_')[0]) #|
hwLabels.append(classNumber) #|
trainingMat[i,:] = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits\\%s' % fileNameStr) #变换矩阵形状: from 32*32 to 1*1024
testFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest): #同训练集
fileNameStr = testFileList[i]
fileName = fileNameStr.split('.')[0]
classNumber = int(fileName.split('_')[0])
vectorUnderTest = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\%s' % fileNameStr)
classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) #计算欧氏距离并分类,返回计算结果
print ('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
if (classifyResult != classNumber):
errorCount += 1.0
print ('The total number of errors is: %d' % (errorCount))
print ('The total error rate is: %.1f%%' % (errorCount/float(mTest) * 100)) # Simple unit test of func: file2matrix()
#datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
#print (datingDataMat)
#print (datingLabels) # Usage of figure construction of matplotlib
#fig=plt.figure()
#ax = fig.add_subplot(111)
#ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
#plt.show() # Simple unit test of func: autoNorm()
#normMat, ranges, minVals = autoNorm(datingDataMat)
#print (normMat)
#print (ranges)
#print (minVals) # Simple unit test of func: img2vector
#testVect = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\0_13.txt')
#print (testVect[0, 32:63] ) #约会网站测试
datingClassTest() #约会网站预测
classifyPerson() #手写数字识别系统预测
handwritingClassTest()

Output:

theclassifier came back with: 3, the real answer is: 3
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 0.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 0.0%

...

theclassifier came back with: 2, the real answer is: 2
the total error rate is: 4.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 4.0%
theclassifier came back with: 3, the real answer is: 1
the total error rate is: 5.0%

percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this persoon: in small doses

...

The classifier came back with: 9, the real answer is: 9
The total number of errors is: 27
The total error rate is: 6.8%

 Reference:

《机器学习实战》

k近邻算法python实现 -- 《机器学习实战》的更多相关文章

  1. AI小记-K近邻算法

    K近邻算法和其他机器学习模型比,有个特点:即非参数化的局部模型. 其他机器学习模型一般都是基于训练数据,得出一般性知识,这些知识的表现是一个全局性模型的结构和参数.模型你和好了后,不再依赖训练数据,直 ...

  2. K近邻 Python实现 机器学习实战(Machine Learning in Action)

    算法原理 K近邻是机器学习中常见的分类方法之间,也是相对最简单的一种分类方法,属于监督学习范畴.其实K近邻并没有显式的学习过程,它的学习过程就是测试过程.K近邻思想很简单:先给你一个训练数据集D,包括 ...

  3. 机器学习实战 - python3 学习笔记(一) - k近邻算法

    一. 使用k近邻算法改进约会网站的配对效果 k-近邻算法的一般流程: 收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据.一般来讲,数据放在txt文本文件中,按照一定的格式进 ...

  4. 02机器学习实战之K近邻算法

    第2章 k-近邻算法 KNN 概述 k-近邻(kNN, k-NearestNeighbor)算法是一种基本分类与回归方法,我们这里只讨论分类问题中的 k-近邻算法. 一句话总结:近朱者赤近墨者黑! k ...

  5. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

  6. python 机器学习(二)分类算法-k近邻算法

      一.什么是K近邻算法? 定义: 如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. 来源: KNN算法最早是由Cover和Hart提 ...

  7. 用Python从零开始实现K近邻算法

    KNN算法的定义: KNN通过测量不同样本的特征值之间的距离进行分类.它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别.K通 ...

  8. 《机实战》第2章 K近邻算法实战(KNN)

    1.准备:使用Python导入数据 1.创建kNN.py文件,并在其中增加下面的代码: from numpy import * #导入科学计算包 import operator #运算符模块,k近邻算 ...

  9. 机器学习之K近邻算法(KNN)

    机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...

随机推荐

  1. 纯CSS弹出层,城市切换效果

    <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/ ...

  2. 开启ss-libev多用户

    原理:通过查看进程,得到命令及需要的参数,然后,在制作一个配置文件,pid文件随意写. 1.首先正常开启一个: /etc/init.d/shadowsocks-libev start 2.然后:利用查 ...

  3. JNI调用实例

    1. 环境 Windows7-64Bit VS2010-32Bit JDK1.8-64Bit 2. 步骤 2.1 创建NativePrint类 public class NativePrint { p ...

  4. FileZilla Server-Can’t access file错误解决方法

    在某服务器上用FileZilla Server搭建了一个FTP服务器.开始使用没有发现任何问题,后来在向服务器传送大文件的时候,发现总是传输到固定的百分比的时候出现 ”550 can’t access ...

  5. MapReduce编程实例4

    MapReduce编程实例: MapReduce编程实例(一),详细介绍在集成环境中运行第一个MapReduce程序 WordCount及代码分析 MapReduce编程实例(二),计算学生平均成绩 ...

  6. Java实现ZIP解压功能

    1.引入依赖 <dependency> <groupId>org.apache.ant</groupId> <artifactId>ant</ar ...

  7. php 计算时间添加

    $Date_1=date("Y-m-d");$Date_2="2015-10-11";$d1=strtotime($Date_1);$d2=strtotime( ...

  8. "回车"(carriage return)和"换行"(line feed)

    “回车”(carriage return)和“换行”(line feed)这两个概念的来历: 在计算机还没有出现之前,有一种叫做电传打字机(Teletype Model 33)的玩意,每秒钟可以打10 ...

  9. 在线生成条形码的解决方案(39码、EAN-13)

    感谢博主:转自:http://xoyozo.eyuyao.com/blog/barcode.html public partial class ReceivablesFormView : System ...

  10. 理解WCF中的Contracts

    WCF中的Contracts WCF通过Contract来说明服务和操作,一般包含五种类型的Contract:ServiceContract,OperationContract,FaultContra ...