手写数字是32x32的黑白图像。为了能使用KNN分类器,我们需要把32x32的二进制图像转换为1x1024

1. 将图像转化为向量

  1. from numpy import *
  2. # 导入科学计算包numpy和运算符模块operator
  3. import operator
  4. from os import listdir
  1. def img2vector(filename):
  2. """
  3. 将图像数据转换为向量
  4. :param filename: 图片文件 因为我们的输入数据的图片格式是 32 * 32的
  5. :return: 一维矩阵
  6. 该函数将图像转换为向量:该函数创建 1 * 1024 的NumPy数组,然后打开给定的文件,
  7. 循环读出文件的前32行,并将每行的头32个字符值存储在NumPy数组中,最后返回数组。
  8. """
  9. returnVect = zeros((1, 1024))
  10. fr = open(filename)
  11. for i in range(32):
  12. lineStr = fr.readline()
  13. for j in range(32):
  14. returnVect[0, 32 * i + j] = int(lineStr[j])
  15. return returnVect

测试:

  1. testVector = img2vector('F:/迅雷下载/machinelearninginaction/Ch02/testDigits/0_13.txt')
  2. testVector[0, 0:31]
  1. array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
  2. 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

2. KNN分类器

  1. def classify0(inX, dataSet, labels, k):
  2. """
  3. inX: 用于分类的输入向量
  4. dataSet: 输入的训练样本集
  5. labels: 标签向量
  6. k: 选择最近邻居的数目
  7. 注意:labels元素数目和dataSet行数相同;程序使用欧式距离公式.
  8. """
  9. # 求出数据集的行数
  10. dataSetSize = dataSet.shape[0]
  11. # tile生成和训练样本对应的矩阵,并与训练样本求差
  12. """
  13. tile: 列: 3表示复制的行数, 行:1/2 表示对inx的重复的次数
  14. 例:In []: inX = [1, 2, 3]
  15. tile(inx, (3, 1))
  16.  
  17. Out[]: array([[1, 2, 3],
  18. [1, 2, 3],
  19. [1, 2, 3]])
  20. """
  21. # 用inx(输入向量)生成和dataSet类型一样的矩阵,在减去dataSet
  22. diffMat = tile(inX, (dataSetSize, 1)) - dataSet
  23. # 取平方
  24. sqDiffMat = diffMat ** 2
  25. # 将矩阵的每一行相加
  26. sqDistances = sqDiffMat.sum(axis=1)
  27. # 开方
  28. distances = sqDistances ** 0.5
  29. # 根据距离排序从小到大的排序,返回对应的索引位置
  30. # argsort() 是将x中的元素从小到大排列,提取其对应的index(索引),然后输出到y。
  31. """
  32. In [] : y = argsort([3, 0, 2, -1, 4, 5])
  33. print(y[0])
  34. print(y[5])
  35. Out[] : 3
  36. 5
  37. 由于最小的数是-1,它的序号是3,因此y[0] = 3, 最大的数是5,它的序号是5,因此y[5] = 5
  38. """
  39. sortedDistIndicies = distances.argsort()
  40. # 2. 选择距离最小的k个点
  41. classCount = {}
  42. for i in range(k):
  43. voteIlabel = labels[sortedDistIndicies[i]]
  44. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
  45. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
  46. return sortedClassCount[0][0]

3. 手写数字识别系统的测试代码

  1. def handwritingClassTest():
  2. # 1. 导入数据
  3. hwLabels = []
  4. trainingFileList = listdir('F:/迅雷下载/machinelearninginaction/Ch02/trainingDigits') # load the training set
  5. m = len(trainingFileList)
  6. trainingMat = zeros((m, 1024))
  7. # hwLabels存储0~9对应的index位置, trainingMat存放的每个位置对应的图片向量
  8. for i in range(m):
  9. fileNameStr = trainingFileList[i]
  10. fileStr = fileNameStr.split('.')[0] # take off .txt
  11. classNumStr = int(fileStr.split('_')[0])
  12. hwLabels.append(classNumStr)
  13. # 将 32*32的矩阵->1*1024的矩阵
  14. trainingMat[i, :] = img2vector('F:/迅雷下载/machinelearninginaction/Ch02/trainingDigits/%s' % fileNameStr)
  15.  
  16. # 2. 导入测试数据
  17. testFileList = listdir('F:/迅雷下载/machinelearninginaction/Ch02/testDigits') # iterate through the test set
  18. errorCount = 0.0
  19. mTest = len(testFileList)
  20. for i in range(mTest):
  21. fileNameStr = testFileList[i]
  22. fileStr = fileNameStr.split('.')[0] # take off .txt
  23. classNumStr = int(fileStr.split('_')[0])
  24. vectorUnderTest = img2vector('F:/迅雷下载/machinelearninginaction/Ch02/testDigits/%s' % fileNameStr)
  25. classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
  26. print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
  27. if (classifierResult != classNumStr): errorCount += 1.0
  28. print("\nthe total number of errors is: %d" % errorCount)
  29. print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
  1. handwritingClassTest()
  1. the classifier came back with: 0, the real answer is: 0
  2. the classifier came back with: 0, the real answer is: 0
    ...
  1. the classifier came back with: 9, the real answer is: 9
  2. the classifier came back with: 9, the real answer is: 9
  3.  
  4. the total number of errors is: 10
  5.  
  6. the total error rate is: 0.010571
    k-近邻算法识别手写数字,错误率在1.1%.改变k的值、修改函数 handwritingClassTest 随机选取训练样本、改变训练样本的数目,都会对k-近邻算法的错误率产生影响。
    实际上,这个算法的执行效率并不高。因为每个算法需要为每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计执行900次。
    K决策树就是k-近邻的优化版。

4. 总结

k-近邻算法的特点:

1. 是分类数据最简单最有效的算法

2. 必须保存全部数据集,会使用大量存储空间

3. 必须对每个数据计算距离值,非常耗时

  1.  

k-近邻算法-手写识别系统的更多相关文章

  1. 第三篇:基于K-近邻分类算法的手写识别系统

    前言 本文将继续讲解K-近邻算法的项目实例 - 手写识别系统. 该系统在获取用户的手写输入后,判断用户写的是什么. 为了突出核心,简化细节,本示例系统中的输入为32x32矩阵,分类结果也均为数字.但对 ...

  2. 机器学习实战一:kNN手写识别系统

    实战一:kNN手写识别系统 本文将一步步地构造使用K-近邻分类器的手写识别系统.由于能力有限,这里构造的系统只能识别0-9.需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:32像素*3 ...

  3. 吴裕雄--天生自然python机器学习:KNN-近邻算法在手写识别系统上的应用

    需要识别的数字已经使用图形处理软件,处理成具有相同的色 彩和大小® : 宽髙是32像 素 *32像素的黑白图像.尽管采用文本格式存储图像不能有效地利用内 存空间,但是为了方便理解,我们还是将图像转换为 ...

  4. 【Machine Learning in Action --2】K-近邻算法构造手写识别系统

    为了简单起见,这里构造的系统只能识别数字0到9,需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素的黑白图像.尽管采用文本格式存储图像不能有效地利用内存空间,但是为了方便理 ...

  5. K近邻实战手写数字识别

    1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...

  6. 《机器学习实战》之k-近邻算法(手写识别系统)

    这个玩意和改进约会网站的那个差不多,它是提前把所有数字转换成了32*32像素大小的黑白图,然后转换成字符图(用0,1表示),将所有1024个像素点用一维矩阵保存下来,这样就可以通过knn计算欧几里得距 ...

  7. 《机器学习实战》-k近邻算法

    目录 K-近邻算法 k-近邻算法概述 解析和导入数据 使用 Python 导入数据 实施 kNN 分类算法 测试分类器 使用 k-近邻算法改进约会网站的配对效果 收集数据 准备数据:使用 Python ...

  8. 机器学习实战kNN之手写识别

    kNN算法算是机器学习入门级绝佳的素材.书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系.输入没有标签的新数据 ...

  9. python 实现 KNN 分类器——手写识别

    1 算法概述 1.1 优劣 优点:进度高,对异常值不敏感,无数据输入假定 缺点:计算复杂度高,空间复杂度高 应用:主要用于文本分类,相似推荐 适用数据范围:数值型和标称型 1.2 算法伪代码 (1)计 ...

随机推荐

  1. 查看oracle 用户执行的sql语句历史记录

      select * from v$sqlarea t order by t.LAST_ACTIVE_TIME desc

  2. 【转】fnmatch模块的使用——主要作用是文件名称的匹配,并且匹配的模式使用的unix shell风格

    [转]fnmatch模块的使用 fnmatch模块的使用 此模块的主要作用是文件名称的匹配,并且匹配的模式使用的unix shell风格.fnmatch比较简单就4个方法分别是:fnmatch,fnm ...

  3. HDFS-put: unexpected URISyntaxException

    目的:将某zip上传到HDFS某目录 [hdfs@mr1 jars]$ hadoop fs -put "20180720_155245 label.zip" /user/File/ ...

  4. 用户态使用 glibc/backtrace 追踪函数调用堆栈定位段错误【转】

    转自:https://blog.csdn.net/gatieme/article/details/84189280 版权声明:本文为博主原创文章 && 转载请著名出处 @ http:/ ...

  5. 全面接触PDF:最好用的PDF软件汇总(转)

    全面接触PDF:最好用的PDF软件汇总(2010-12-07更新): http://xbeta.info/pdf-software.htm 比较全面的c#帮助类,各种功能性代码: https://gi ...

  6. RocketMQ 简单梳理 及 集群部署笔记【转】

    一.RocketMQ 基础知识介绍Apache RocketMQ是阿里开源的一款高性能.高吞吐量.队列模型的消息中间件的分布式消息中间件. 上图是一个典型的消息中间件收发消息的模型,RocketMQ也 ...

  7. select2使用方法总结

    官网:http://select2.github.io/ 调用 <link href="~/Content/select2.min.css" rel="styles ...

  8. pycharm 中自动补全代码提示前符号 p,m ,c,v, f 是什么意思

    是自动补全的变量的类别p:parameter 参数m:method 方法c:class 类 v:variable 变量f:function 函数

  9. PHP--php时间差8个小时的问题

    自PHP5.0开始,用PHP获取系统时间时,时间比当前时间少8个小时.原因是PHP.ini中没有设置timezone时,PHP是使用的UTC时间,所以在中国时间要少8小时. 解决办法: 1.在PHP. ...

  10. 【原创】大数据基础之Spark(9)spark部署方式yarn/mesos

    1 下载解压 https://spark.apache.org/downloads.html $ wget http://mirrors.shu.edu.cn/apache/spark/spark-2 ...