kNN算法python实现和简单数字识别
kNN算法
算法优缺点:
- 优点:精度高、对异常值不敏感、无输入数据假定
- 缺点:时间复杂度和空间复杂度都很高
- 适用数据范围:数值型和标称型
算法的思路:
KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。
函数解析:
库函数
tile()
如
tile(A,n)
就是将A重复n次
a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],[3, 4],[1, 2],[3, 4]])`
自己实现的函数
createDataSet()
生成测试数组kNNclassify(inputX, dataSet, labels, k)
分类函数
- inputX 输入的参数
- dataSet 训练集
- labels 训练集的标号
- k 最近邻的数目
- #coding=utf-8
- from numpy import *
- import operator
- def createDataSet():
- group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
- labels = ['A','A','B','B']
- return group,labels
- #inputX表示输入向量(也就是我们要判断它属于哪一类的)
- #dataSet表示训练样本
- #label表示训练样本的标签
- #k是最近邻的参数,选最近k个
- def kNNclassify(inputX, dataSet, labels, k):
- dataSetSize = dataSet.shape[0]#计算有几个训练数据
- #开始计算欧几里得距离
- diffMat = tile(inputX, (dataSetSize,1)) - dataSet
- sqDiffMat = diffMat ** 2
- sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
- distances = sqDistances ** 0.5
- #欧几里得距离计算完毕
- sortedDistance = distances.argsort()
- classCount = {}
- for i in xrange(k):
- voteLabel = labels[sortedDistance[i]]
- classCount[voteLabel] = classCount.get(voteLabel,0) + 1
- res = max(classCount)
- return res
- def main():
- group,labels = createDataSet()
- t = kNNclassify([0,0],group,labels,3)
- print t
- if __name__=='__main__':
- main()
- #coding=utf-8
kNN应用实例
手写识别系统的实现
数据集:
两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:
方法:
kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。
速度:
速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)
k=3的时候要32s+
- #coding=utf-8
- from numpy import *
- import operator
- import os
- import time
- def createDataSet():
- group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
- labels = ['A','A','B','B']
- return group,labels
- #inputX表示输入向量(也就是我们要判断它属于哪一类的)
- #dataSet表示训练样本
- #label表示训练样本的标签
- #k是最近邻的参数,选最近k个
- def kNNclassify(inputX, dataSet, labels, k):
- dataSetSize = dataSet.shape[0]#计算有几个训练数据
- #开始计算欧几里得距离
- diffMat = tile(inputX, (dataSetSize,1)) - dataSet
- #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
- sqDiffMat = diffMat ** 2
- sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
- distances = sqDistances ** 0.5
- #欧几里得距离计算完毕
- sortedDistance = distances.argsort()
- classCount = {}
- for i in xrange(k):
- voteLabel = labels[sortedDistance[i]]
- classCount[voteLabel] = classCount.get(voteLabel,0) + 1
- res = max(classCount)
- return res
- def img2vec(filename):
- returnVec = zeros((1,1024))
- fr = open(filename)
- for i in range(32):
- lineStr = fr.readline()
- for j in range(32):
- returnVec[0,32*i+j] = int(lineStr[j])
- return returnVec
- def handwritingClassTest(trainingFloder,testFloder,K):
- hwLabels = []
- trainingFileList = os.listdir(trainingFloder)
- m = len(trainingFileList)
- trainingMat = zeros((m,1024))
- for i in range(m):
- fileName = trainingFileList[i]
- fileStr = fileName.split('.')[0]
- classNumStr = int(fileStr.split('_')[0])
- hwLabels.append(classNumStr)
- trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
- testFileList = os.listdir(testFloder)
- errorCount = 0.0
- mTest = len(testFileList)
- for i in range(mTest):
- fileName = testFileList[i]
- fileStr = fileName.split('.')[0]
- classNumStr = int(fileStr.split('_')[0])
- vectorUnderTest = img2vec(testFloder+'/'+fileName)
- classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
- #print classifierResult,' ',classNumStr
- if classifierResult != classNumStr:
- errorCount +=1
- print 'tatal error ',errorCount
- print 'error rate',errorCount/mTest
- def main():
- t1 = time.clock()
- handwritingClassTest('trainingDigits','testDigits',3)
- t2 = time.clock()
- print 'execute ',t2-t1
- if __name__=='__main__':
- main()
- #coding=utf-8
kNN算法python实现和简单数字识别的更多相关文章
- KNN算法--python实现
邻近算法 或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一.所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代 ...
- 【机器学习】k-近邻算法应用之手写数字识别
上篇文章简要介绍了k-近邻算法的算法原理以及一个简单的例子,今天再向大家介绍一个简单的应用,因为使用的原理大体差不多,就没有没有过多的解释. 为了具有说明性,把手写数字的图像转换为txt文件,如下图所 ...
- KNN算法python实现
1 KNN 算法 knn,k-NearestNeighbor,即寻找与点最近的k个点. 2 KNN numpy实现 效果: k=1 k=2 3 numpy 广播,聚合操作. 这里求距离函数,求某点和集 ...
- KNN算法——python实现
二.Python实现 对于机器学习而已,Python需要额外安装三件宝,分别是Numpy,scipy和Matplotlib.前两者用于数值计算,后者用于画图.安装很简单,直接到各自的官网下载回来安装即 ...
- 神经网络(BP)算法Python实现及简单应用
首先用Python实现简单地神经网络算法: import numpy as np # 定义tanh函数 def tanh(x): return np.tanh(x) # tanh函数的导数 def t ...
- KNN算法python实现小样例
K近邻算法概述优点:精度高.对异常数据不敏感.无数据输入假定缺点:计算复杂度高.空间复杂度高适用数据范围:数值型和标称型工作原理:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签 ...
- [Solution] 简单数字识别之Tesseract
图像识别涉及的理论:傅里叶变换,图形形态学,滤波,矩阵变换等等. Tesseract的出现为了解决在没有这些复杂的理论基础,快速识别图像的框架. 准备: 1.样本图像学习,预处理 (平均每1个元素出现 ...
- 深度学习(一):Python神经网络——手写数字识别
声明:本文章为阅读书籍<Python神经网络编程>而来,代码与书中略有差异,书籍封面: 源码 若要本地运行,请更改源码中图片与数据集的位置,环境为 Python3.6x. 1 import ...
- 基于sk_learn的k近邻算法实现-mnist手写数字识别且要求97%以上精确率
1. 导入需要的库 from sklearn.datasets import fetch_openml import numpy as np from sklearn.neighbors import ...
随机推荐
- 关于Wireshark "The NPF driver isn’t running……"解决办法
启动Wireshark软件时出现了如下图所示的错误,就搜索了一下解决方法,特总结如下: 这个错误是因为没有开启NPF服务造成的.简要说一下NPF吧. NPF即网 络数据包过滤器(Netgroup Pa ...
- 【BZOJ-3553】三叉神经树 树链剖分
3553: [Shoi2014]三叉神经树 Time Limit: 160 Sec Memory Limit: 256 MBSubmit: 347 Solved: 112[Submit][Stat ...
- java单例的几种实现方法
java单例的几种实现方法: 方式1: public class Something { private Something() {} private static class LazyHolder ...
- [C#] 图文解说调用WebServer实例
本文旨在实现如何在.NET环境下调用WebServer,以天气接口为例进行说明. WebServer地址:http://www.webxml.com.cn/WebServices/WeatherWeb ...
- 文件夹锁定(Source)
文件夹锁定(Source)private void Lock(string folderPath){ try { string adminUserName = Environ ...
- Ubuntu操作系统下软件的卸载
1.查找安装文件列表 $ dpkg --list 2. 将列表名录复制粘贴到文本文件中 3. 搜索关键词,找到准确的名称 4. 在终端中执行命令: $ sudo apt-get --purge rem ...
- GNURadio 使用问题
- tyvj1463 智商问题
背景 各种数据结构帝~各种小姊妹帝~各种一遍AC帝~ 来吧! 描述 某个同学又有很多小姊妹了他喜欢聪明的小姊妹 所以经常用神奇的函数来估算小姊妹的智商他得出了自己所有小姊妹的智商小姊妹的智商都是非负整 ...
- JavaScript -- 小试牛刀
//var a = parseInt(window.prompt("请输入一个数字!","")); //switch(a) { // case 1 : // c ...
- UVa2521
理解:max 记录的是有大牌的个数 mid 是有中断 而造成的不确定 我理解是一个间断点以下的 数和一个间断点抵消 在前面没有间断的情况下 才能确定这张牌稳赢 #include<iostrea ...