需求:

利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别;

先验数据(训练数据)集:

♦数据维度比较大,样本数比较多。

♦ 数据集包括数字0-9的手写体。

♦每个数字大约有200个样本。

♦每个样本保持在一个txt文件中。

♦手写体图像本身的大小是32x32的二值图,转换到txt文件保存后,内容也是32x32个数字,0或者1,如下:

♦目录trainingDigits存放的是大约2000个训练数据

♦目录testDigits存放大约900个测试数据。

trainingDigits文件夹中为训练数据,里面存储的都是32*32的txt格式的数字图像数值矩阵。testDigits文件夹中为测试数据,存储格式与trainingDigits中相同。文件格式名例如:0_1.txt,0为数字的标签(即数字本身),1为表示数字0的第一个文件。训练数据是多张32*32手写图像的二维矩阵,所谓二维矩阵就是整个图像空白的地方使用0描述,写字的地方使用1描述,

代码python:https://github.com/kongxiaoshuang/KNN

  1. #-*- coding: utf-8 -*-
  2. from numpy import *
  3. import operator
  4.  
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7.  
  8. def createDataSet():
  9. group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
  10. labels = ['A', 'A', 'B', 'B']
  11. return group, labels
  12.  
  13. def classify0(inX, dataSet, labels, k): #inX为用于分类的输入向量,dataSet为输入的训练样本集, labels为训练标签,k表示用于选择最近的数目
  14. dataSetSize = dataSet.shape[0] #dataSet的行数
  15. diffMat = tile(inX, (dataSetSize, 1)) - dataSet #将inX数组复制成与dataSet相同行数,与dataSet相减,求坐标差
  16. sqDiffMat = diffMat**2 #diffMat的平方
  17. sqDistances = sqDiffMat.sum(axis=1) #将sqDiffMat每一行的所有数相加
  18. distances = sqDistances**0.5 #开根号,求点和点之间的欧式距离
  19. sortedDistIndicies = distances.argsort() #将distances中的元素从小到大排列,提取其对应的index,然后输出到sortedDistIndicies
  20. classCount = {} #创建字典
  21. for i in range(k):
  22. voteIlabel = labels[sortedDistIndicies[i]] #前k个标签数据
  23. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 #判断classCount中有没有对应的voteIlabel,
  24. # 如果有返回voteIlabel对应的值,如果没有则返回0,在最后加1。为了计算k个标签的类别数量
  25. sortedClassCount = sorted(classCount.items(),
  26. key=operator.itemgetter(1), reverse=True) #生成classCount的迭代器,进行排序,
  27. # operator.itemgetter(1)以标签的个数降序排序
  28. return sortedClassCount[0][0] #返回个数最多的标签
  29.  
  30. def file2matrix(filename):
  31. fr = open(filename)
  32. arrayOLines = fr.readlines() #读入所有行
  33. numberOfLines = len(arrayOLines) #行数
  34. returnMat = zeros((numberOfLines, 3)) #创建数组,数据集
  35. classLabelVector = [] #标签集
  36. index = 0
  37. for line in arrayOLines:
  38. line = line.strip() #移除所有的回车符
  39. listFromLine = line.split('\t') #把一个字符串按\t分割成字符串数组
  40. returnMat[index,:] = listFromLine[0:3] #取listFromLine的前三个元素放入returnMat
  41. classLabelVector.append(int(listFromLine[-1])) #选取listFromLine的最后一个元素依次存入classLabelVector列表中
  42. index += 1
  43. return returnMat, classLabelVector
  44.  
  45. def autoNorm(dataSet):
  46. minVals = dataSet.min(0) #0表示从列中选取最小值
  47. maxVals = dataSet.max(0) #选取最大值
  48. ranges = maxVals-minVals
  49. normDataSet = zeros(shape(dataSet)) #创建一个与dataSet大小相同的零矩阵
  50. m = dataSet.shape[0] #取dataSet得行数
  51. normDataSet = dataSet - tile(minVals, (m, 1)) #将minVals复制m行 与dataSet数据集相减
  52. #归一化相除
  53. normDataSet = normDataSet/tile(ranges, (m, 1)) #将最大值-最小值的值复制m行 与normDataSet相除,即归一化
  54. return normDataSet, ranges, minVals #normDataSet为归一化特征值,ranges为最大值-最小值
  55.  
  56. def datingClassTest():
  57. hoRatio = 0.10 #测试数据占总数据的百分比
  58. datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') #将文本信息转成numpy格式
  59. #datingDataMat为数据集,datingLabels为标签集
  60. normMat, ranges, minVals = autoNorm(datingDataMat) #将datingDataMat数据归一化
  61. #normMat为归一化数据特征值,ranges为特征最大值-最小值,minVals为最小值
  62. m = normMat.shape[0] #取normMat的行数
  63. numTestVecs = int(m*hoRatio) #测试数据的行数
  64. errorCount = 0.0 #错误数据数量
  65. for i in range(numTestVecs):
  66. classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
  67. #classify0为kNN分类器,normMat为用于分类的输入向量,normMat为输入的训练样本集(剩余的90%)
  68. #datingLabels为训练标签,3表示用于选择最近邻居的数目
  69. print("the classifier came back with: %d, the real answer is: %d" %(classifierResult, datingLabels[i]))
  70. if (classifierResult != datingLabels[i]):errorCount += 1.0 #分类器结果和原标签不一样,则errorCount加1
  71. print("the total error rate is : %f" %(errorCount/float(numTestVecs)))
  72.  
  73. # datingClassTest()
  74.  
  75. # datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
  76. #
  77. # normDataSet, ranges, minVals = autoNorm(datingDataMat)
  78.  
  79. # fig = plt.figure()
  80. # ax = fig.add_subplot(111) #一行一列一个
  81. # ax.scatter(datingDataMat[:,1], datingDataMat[:,2],
  82. # 15.0*array(datingLabels), 15.0*array(datingLabels)) #scatter画散点图,使用标签属性绘制不同颜色不同大小的点
  83. # plt.show()
  84.  
  85. # #测试分类器
  86. # group, labels = createDataSet()
  87. # label = classify0([1,1], group, labels, 3)
  88. # print(label)
  89.  
  90. from os import listdir
  91.  
  92. def img2vector (filename):
  93. returnVect = zeros((1, 1024)) #创建一个1*1024的数组
  94. fr = open(filename)
  95. for i in range(32):
  96. lineStr = fr.readline() #每次读入一行
  97. for j in range(32):
  98. returnVect[0, 32*i+j] = int(lineStr[j])
  99. return returnVect
  100.  
  101. def handwritingClassTest():
  102. hwLabels = [] #标签集
  103. trainingFileList = listdir('E:/digits/trainingDigits') #listdir获取训练集的文件目录
  104. m = len(trainingFileList) #文件数量
  105. trainingMat = zeros((m, 1024)) #一个数字1024个字符,创建m*1024的数组
  106. for i in range(m):
  107. fileNameStr = trainingFileList[i] #获取文件名
  108. fileStr = fileNameStr.split('.')[0] #以'.'将字符串分割,并取第一项,即0_0.txt取0_0
  109. classNumStr = int(fileStr.split('_')[0]) #以'_'将字符串分割,并取第一项
  110. hwLabels.append(classNumStr) #依次存入hwLabels标签集
  111. trainingMat[i, :] = img2vector('E:/digits/trainingDigits/%s' % fileNameStr) #将每个数字的字符值依次存入trainingMat
  112. testFileList = listdir('E:/digits/testDigits') #读入测试数据集
  113. errorCount = 0.0 #测试错误数量
  114. mTest = len(testFileList) #测试集的数量
  115. for i in range(mTest):
  116. fileNameStr = testFileList[i]
  117. fileStr = fileNameStr.split('.')[0]
  118. classNumStr = int(fileStr.split('_')[0]) #测试数据标签
  119. vectorUnderTest = img2vector('E:/digits/testDigits/%s' % fileNameStr) #读入测试数据
  120. classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) #分类器kNN算法,3为最近邻数目
  121. print("the calssifier came back with: %d, the real answer is : %d" %(classifierResult, classNumStr))
  122. if (classifierResult != classNumStr): errorCount +=1.0
  123. print("\nthe total number of errors is : %f" % errorCount)
  124. print("\nthe total error rate is :%f" % (errorCount/float(mTest)))
  125.  
  126. handwritingClassTest()

KNN算法识别手写数字的更多相关文章

  1. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  2. KNN (K近邻算法) - 识别手写数字

    KNN项目实战——手写数字识别 1. 介绍 k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法.它的工作原理是:存在一个 ...

  3. 机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

  4. KNN算法案例--手写数字识别

    import numpy as np import matplotlib .pyplot as plt import pandas as pd from sklearn.neighbors impor ...

  5. KNN算法实现手写数字

    from numpy import * import operator from os import listdir def classify0(inX, dataSet, labels, k): d ...

  6. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. 一个项目里,httpclient竟然出现了四种

    有网友提问: 今年年初,到一家互联网公司实习,该公司是国内行业龙头.不过技术和管理方面,却弱爆了.那里的程序员,每天都在看邮件,查问题工单.这些问题,多半是他们设计不当,造成的.代码写

  2. Java并发包异步执行器CompletableFuture

    前言 CompletableFuture是对Future的一种强有力的扩展,Future只能通过轮询isDone()方法或者调用get()阻塞等待获取一个异步任务的结果,才能继续执行下一步,当我们执行 ...

  3. Error setting null for parameter #10 with JdbcType

    转: Error setting null for parameter #10 with JdbcType OTHER . 2014年02月23日 11:00:33 厚积 阅读数 58535   my ...

  4. SpringMVC+Ajax实现文件批量上传和下载功能实例代码

    需求: 文件批量上传,支持断点续传. 文件批量下载,支持断点续传. 使用JS能够实现批量下载,能够提供接口从指定url中下载文件并保存在本地指定路径中. 服务器不需要打包. 支持大文件断点下载.比如下 ...

  5. IDEA中提示Error:java: Compilation failed: internal java compiler error

    解决办法:File-->Setting...-->Build,Execution,Deployment-->Compiler-->Java Compiler 设置相应Modul ...

  6. 【设计】PC Web端框架组件

    https://uedart.com/demo/templatesWebKit/index.html#g=1&p=%E4%BD%9C%E5%93%81%E9%A6%96%E9%A1%B5 移动 ...

  7. spark:neither spark.yarn.jars not spark.yarn.archive is set

    1.Spark启动警告:neither spark.yarn.jars not spark.yarn.archive is set,falling back to uploading librarie ...

  8. IDA7.2破解版本

    更新说明 https://www.hex-rays.com/products/ida/7.2/index.shtml 破解文章 作者阐述了一下对IDA安装密码的攻击方法,通过枚举多种语言默认的随机数发 ...

  9. iOS技术面试02:内存管理

    怎么保证多人开发进行内存泄露的检查. 如何定位内存泄露? 1> 使用Analyze进行代码的静态分析(检测有无潜在的内存泄露) 2> 通过leak检查在程序运行过程中有无内存泄露 3> ...

  10. MySQL 事务一览

    MySQL 中的事务? 对 MySQL 来说,事务通常是一组包含对数据库操作的集合.在执行时,只有在该组内的事务都执行成功,这个事务才算执行成功,否则就算失败.MySQL 中,事务支持是在引擎层实现的 ...