1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. from sklearn.neighbors import KNeighborsClassifier
  6.  
  7. def build_data(dir_name):
  8. """
  9. 构建数据
  10. :param dir_name: 指定传入文件夹名称
  11. :return: 构建好的数据
  12. """
  13. # 获取文件名列表
  14. file_name_list = os.listdir(dir_name + "/")
  15. print("获取到的文件名列表:\n", file_name_list)
  16. # 进行读取文件
  17.  
  18. data = np.zeros(shape=(len(file_name_list), 1025))
  19.  
  20. # 循环读取文件
  21. for file_index, file_name in enumerate(file_name_list):
  22. # file_index 文本名称所对应的下标
  23. # file_name 文本名称
  24. # 加载数据
  25. file_data = np.loadtxt(dir_name + "/" + file_name, dtype=np.str)
  26.  
  27. # 构建一个列表
  28. arr = []
  29. for file_data_index, file_data_content in enumerate(file_data):
  30. # print(file_data_content)
  31. # print("*"*80)
  32. # 将 每一个元素转化为一个int 类型的列表
  33. arr_sigle_list = [int(tmp) for tmp in file_data_content]
  34. # print(arr)
  35. # 把每个元素添加到列表中
  36. arr.append(arr_sigle_list)
  37.  
  38. # print(arr)
  39. # 将一个样本转化为数组
  40. arr_single_sample = np.array(arr)
  41. # print(arr_single_sample)
  42. # np.savetxt("./hh.txt",arr_single_sample,fmt="%d")
  43. # 将二维数组展开为一维---特征值
  44. arr_single_sample = arr_single_sample.ravel()
  45. # print(arr_single_sample)
  46. # 目标值
  47. label = int(file_name[0])
  48. # print(res)
  49. # print(arr_single_sample.shape)
  50. # 将一个 完整的样本拼接起来,组成完整的样本
  51. arr_single_sample = np.concatenate((arr_single_sample, [label]), axis=0)
  52.  
  53. # print(arr_single_sample)
  54. # print(arr_single_sample.shape)
  55.  
  56. data[file_index, :] = arr_single_sample
  57.  
  58. # print(data)
  59. return data
  60.  
  61. def save_data(file_name, data):
  62. """
  63. 保存文件
  64. :param file_name: 保存的文件名称
  65. :param data: 保存的数组
  66. :return: None
  67. """
  68. if not os.path.exists("./data/"):
  69. os.makedirs("./data/")
  70.  
  71. np.save("./data/" + file_name, data)
  72.  
  73. def load_data(file_name):
  74. """
  75. 加载数据
  76. :param file_name:文件路径+ 名称
  77. :return: 数据
  78. """
  79. data = np.load(file_name, allow_pickle=True)
  80.  
  81. return data
  82.  
  83. def distance(v1, v2):
  84. """
  85. 计算距离
  86. :param v1: 点1
  87. :param v2: 点2
  88. :return: 距离
  89. """
  90. dist = np.sqrt(np.sum(np.power((v1 - v2), 2)))
  91.  
  92. return dist
  93.  
  94. def knn_owns(train, test, k):
  95. """
  96. 自定knn算法实现手写字识别
  97. :param train: 训练集数据
  98. :param test: 测试集数据
  99. :param k: 邻居个数
  100. :return: 准确率
  101. """
  102. # 设置计数器
  103. true_num = 0
  104. # 获取训练集的特征值 目标值
  105. train_x = train.iloc[:, :-1].values
  106. train_y = train.iloc[:, -1].values
  107. # 获取测试集的特征值 目标值
  108. test_x = test.iloc[:, :-1].values
  109. test_y = test.iloc[:, -1].values
  110. # 计算每一个测试样本特征与每一个训练样本特征的距离
  111. for i in range(test.shape[0]): # 循环每一个 测试样本
  112. for j in range(train.shape[0]):
  113. # 计算距离
  114. dist = distance(test_x[i,:],train_x[j,:])
  115. train.loc[j,'dist'] = dist
  116.  
  117. res = train.sort_values(by='dist')
  118.  
  119. mode = res.iloc[:,-2][:k].mode()[0]
  120.  
  121. if mode == test_y[i]:
  122. true_num += 1
  123. # print(test_y)
  124.  
  125. score = true_num / test.shape[0]
  126.  
  127. print(score)
  128.  
  129. return score
  130.  
  131. # train_data = build_data("./trainingDigits")
  132. # test_data = build_data("./testDigits")
  133. #
  134. # save_data("train_data",train_data)
  135. # save_data("test_data",test_data)
  136.  
  137. # 加载数据
  138. train = load_data("./data/train_data.npy")
  139. test = load_data("./data/test_data.npy")
  140.  
  141. train = pd.DataFrame(train)
  142. test = pd.DataFrame(test)
  143.  
  144. # print(train)
  145. # print("*"*80)
  146. # print(test)
  147. k_list = [5,6,7,8,9,10]
  148. score_list = []
  149. for k in k_list:
  150. # score = knn_owns(train, test, k)
  151. # score_list.append(score)
  152. knn = KNeighborsClassifier(n_neighbors=k)
  153. #训练数据
  154. knn.fit(train.iloc[:,:-1].values,train.iloc[:,-1].values)
  155. # 进行预测
  156. y_predict = knn.predict(test.iloc[:,:-1].values)
  157.  
  158. # 可以获取准确率
  159. score = knn.score(test.iloc[:,:-1].values,test.iloc[:,-1].values)
  160.  
  161. score_list.append(score)
  162. print(score_list)
  163.  
  164. #进行结果可视化
  165. # 1、创建画布
  166. plt.figure()
  167. # 默认不支持中文,需要配置RC 参数
  168. plt.rcParams['font.sans-serif']='SimHei'
  169. # 设置字体之后不支持负号,需要去设置RC参数更改编码
  170. plt.rcParams['axes.unicode_minus']=False
  171. # 2、绘图
  172. x = np.array(k_list)
  173. y = np.array(score_list)
  174.  
  175. plt.plot(x,y)
  176.  
  177. plt.title("k与准确率的关系走势图")
  178. plt.xlabel("k值")
  179. plt.ylabel("准确率")
  180. plt.savefig("./k值对准确率的影响.png")
  181. # 3、展示
  182.  
  183. plt.show()

  

knn算法手写字识别案例的更多相关文章

  1. 【Machine Learning】KNN算法虹膜图片识别

    K-近邻算法虹膜图片识别实战 作者:白宁超 2017年1月3日18:26:33 摘要:随着机器学习和深度学习的热潮,各种图书层出不穷.然而多数是基础理论知识介绍,缺乏实现的深入理解.本系列文章是作者结 ...

  2. 用TensorFlow教你手写字识别

    博主原文链接:用TensorFlow教你做手写字识别(准确率94.09%) 如需转载,请备注出处及链接,谢谢. 2012 年,Alex Krizhevsky, Geoff Hinton, and Il ...

  3. k最邻近算法——使用kNN进行手写识别

    上篇文章中提到了使用pillow对手写文字进行预处理,本文介绍如何使用kNN算法对文字进行识别. 基本概念 k最邻近算法(k-Nearest Neighbor, KNN),是机器学习分类算法中最简单的 ...

  4. 机器学习实战kNN之手写识别

    kNN算法算是机器学习入门级绝佳的素材.书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系.输入没有标签的新数据 ...

  5. python 实现 KNN 分类器——手写识别

    1 算法概述 1.1 优劣 优点:进度高,对异常值不敏感,无数据输入假定 缺点:计算复杂度高,空间复杂度高 应用:主要用于文本分类,相似推荐 适用数据范围:数值型和标称型 1.2 算法伪代码 (1)计 ...

  6. tensorflow卷积神经网络与手写字识别

    1.知识点 """ 基础知识: 1.神经网络(neural networks)的基本组成包括输入层.隐藏层.输出层.而卷积神经网络的特点在于隐藏层分为卷积层和池化层(po ...

  7. k-近邻算法-手写识别系统

    手写数字是32x32的黑白图像.为了能使用KNN分类器,我们需要把32x32的二进制图像转换为1x1024 1. 将图像转化为向量 from numpy import * # 导入科学计算包numpy ...

  8. tensorflow神经网络与单层手写字识别

    1.知识点 """ 1.基础知识: 1.神经网络结构:1.输入层 2.隐含层 3.全连接层(类别个数=全连接层神经元个数)+softmax函数 4.输出层 2.逻辑回归: ...

  9. 基于PyTorch实现MNIST手写字识别

    本篇不涉及模型原理,只是分享下代码.想要了解模型原理的可以去看网上很多大牛的博客. 目前代码实现了CNN和LSTM两个网络,整个代码分为四部分: Config:项目中涉及的参数: CNN:卷积神经网络 ...

随机推荐

  1. C#for(;;)是什么意思?

    一,正常for循环我们都接触过很多,如下,我们都理解 ,,,,, }; ; i < ; i++) { Console.WriteLine(tt[i]); } 二,但是for(;;)实际上它的含义 ...

  2. org.springframework.beans.factory.NoSuchBeanDefinitionException: No qualifying bean of type [dx.service.ItemService] found for dependency

    在整合ssm框架,测试service层的时候报错 Caused by: org.springframework.beans.factory.NoSuchBeanDefinitionException: ...

  3. webpack的一般性配置及说明

    1.webpack的常规配置 先给出一个示例: const path = require('path'); const HtmlWebpackPlugin = require('html-webpac ...

  4. windows 2012 R2 及 centos 7.X 禁用不必要服务

    8.windows 2012 R2 及 centos 7.X 禁用不必要服务 React VR 技术开发群 579149907 1.windows2012 R2 可以禁用以下不必要的服务,以下禁用的服 ...

  5. Simple GB28181 System

    I. Deployment  / Architecture Block Diagram II. Resources Used 1. freeswitch —— sip server and media ...

  6. Jupyter配置工作路径

    在修改之前,C:\Users\Administrator\ .jupyter 目录下面只有一个“migrated”文件. 打开命令窗口(运行->cmd),进入python的Script目录下输入 ...

  7. java并发学习--第六章 线程之间的通信

    一.等待通知机制wait()与notify() 在线程中除了线程同步机制外,还有一个最重要的机制就是线程之间的协调任务.比如说最常见的生产者与消费者模式,很明显如果要实现这个模式,我们需要创建两个线程 ...

  8. 将postgresql中的数据实时同步到kafka中

    参考地址:https://blog.csdn.net/weixin_33985507/article/details/92460419 参考地址:https://mp.weixin.qq.com/s/ ...

  9. 运行 tensorboard

    使用下面命令总是报错: tensorboard --logdir=mylogdir tensorboard --logdir='./mylogdir' 正确命令 tensorboard --logdi ...

  10. Django中ifequal 和ifnotequal的使用

    Django中{% ifequal A B %} 用来比较A和B两个值是否相等,{% ifnotequal A B %}` 用来比较A和B两个值是否不相等..如: {% ifequal user cu ...