公号:码农充电站pro

主页:https://codeshellme.github.io

上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字

1,手写数字数据集

手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算法来识别这些数字。

MNIST 是完整的手写数字数据集,其中包含了60000 个训练样本和10000 个测试样本。

sklearn 中也有一个自带的手写数字数据集

  • 共包含 1797 个数据样本,每个样本描绘了一个 8*8 像素的 [0, 9] 的数字。
  • 每个样本由 65 个数字组成:
    • 前 64 个数字是特征数据,特征数据的范围是 [0, 16]
    • 最后一个数字是目标数据,目标数据的范围是 [0, 9]

我们抽出 5 个样本来看下:

  1. 0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0
  2. 0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1
  3. 0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2
  4. 0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3
  5. 0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4

使用该数据集,需要先加载:

  1. >>> from sklearn.datasets import load_digits
  2. >>> digits = load_digits()

查看第一个图像数据:

  1. >>> digits.images[0]
  2. array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
  3. [ 0., 0., 13., 15., 10., 15., 5., 0.],
  4. [ 0., 3., 15., 2., 0., 11., 8., 0.],
  5. [ 0., 4., 12., 0., 0., 8., 8., 0.],
  6. [ 0., 5., 8., 0., 0., 9., 8., 0.],
  7. [ 0., 4., 11., 0., 1., 12., 7., 0.],
  8. [ 0., 2., 14., 5., 10., 12., 0., 0.],
  9. [ 0., 0., 6., 13., 10., 0., 0., 0.]])

我们可以用 matplotlib 将该图像画出来:

  1. >>> import matplotlib.pyplot as plt
  2. >>> plt.imshow(digits.images[0])
  3. >>> plt.show()

画出来的图像如下,代表 0

2,sklearn 对 KNN 算法的实现

sklearn 库的 neighbors 模块实现了KNN 相关算法,其中:

  • KNeighborsClassifier 类用于分类问题
  • KNeighborsRegressor 类用于回归问题

这两个类的构造方法基本一致,这里我们主要介绍 KNeighborsClassifier 类,原型如下:

  1. KNeighborsClassifier(
  2. n_neighbors=5,
  3. weights='uniform',
  4. algorithm='auto',
  5. leaf_size=30,
  6. p=2,
  7. metric='minkowski',
  8. metric_params=None,
  9. n_jobs=None,
  10. **kwargs)

来看下几个重要参数的含义:

  • n_neighbors:即 KNN 中的 K 值,一般使用默认值 5。
  • weights:用于确定邻居的权重,有三种方式:
    • weights=uniform,表示所有邻居的权重相同。
    • weights=distance,表示权重是距离的倒数,即与距离成反比。
    • 自定义函数,可以自定义不同距离所对应的权重,一般不需要自己定义函数。
  • algorithm:用于设置计算邻居的算法,它有四种方式:
    • algorithm=auto,根据数据的情况自动选择适合的算法。
    • algorithm=kd_tree,使用 KD 树 算法。
      • KD 树是一种多维空间的数据结构,方便对数据进行检索。
      • KD 树适用于维度较少的情况,一般维数不超过 20,如果维数大于 20 之后,效率会下降。
    • algorithm=ball_tree,使用球树算法。
      • KD 树一样都是多维空间的数据结构。
      • 球树更适用于维度较大的情况。
    • algorithm=brute,称为暴力搜索
      • 它和 KD 树相比,采用的是线性扫描,而不是通过构造树结构进行快速检索。
      • 缺点是,当训练集较大的时候,效率很低。
    • leaf_size:表示构造 KD 树球树时的叶子节点数,默认是 30。
      • 调整 leaf_size 会影响树的构造和搜索速度。

3,构造 KNN 分类器

首先加载数据集:

  1. from sklearn.datasets import load_digits
  2. digits = load_digits()
  3. data = digits.data # 特征集
  4. target = digits.target # 目标集

将数据集拆分为训练集(75%)和测试集(25%),

  1. from sklearn.model_selection import train_test_split
  2. train_x, test_x, train_y, test_y = train_test_split(
  3. data, target, test_size=0.25, random_state=33)

构造KNN 分类器:

  1. from sklearn.neighbors import KNeighborsClassifier
  2. # 采用默认参数
  3. knn = KNeighborsClassifier()

拟合模型:

  1. knn.fit(train_x, train_y)

预测数据:

  1. predict_y = knn.predict(test_x)

计算模型准确度:

  1. from sklearn.metrics import accuracy_score
  2. score = accuracy_score(test_y, predict_y)
  3. print score # 0.98

最终计算出来模型的准确度是 98%,准确度还是不错的。

4,总结

本篇文章使用KNN 算法处理了一个实际的分类问题,主要介绍了以下几点:

  • 介绍了sklearn 中自带的手写数字集,并用 matplotlib 模块画出了数字图像。
  • 介绍了sklearnneighbors.KNeighborsClassifier 类的用法。
  • 使用 KNeighborsClassifier 来识别手写数字。

(本节完。)


推荐阅读:

KNN 算法-理论篇-如何给电影进行分类

决策树算法-理论篇-如何计算信息纯度

决策树算法-实战篇-鸢尾花及波士顿房价预测

朴素贝叶斯分类-理论篇-如何通过概率解决分类问题

朴素贝叶斯分类-实战篇-如何进行文本分类


欢迎关注作者公众号,获取更多技术干货。

KNN 算法-实战篇-如何识别手写数字的更多相关文章

  1. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  2. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  3. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  4. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  5. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  6. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  7. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  8. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  9. KNN (K近邻算法) - 识别手写数字

    KNN项目实战——手写数字识别 1. 介绍 k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法.它的工作原理是:存在一个 ...

随机推荐

  1. Linux 系统编程 学习:07-基于socket的网络编程2:基于 UDP 的通信

    Linux 系统编程 学习:07-基于socket的网络编程2:基于 UDP 的通信 背景 上一讲我们介绍了网络编程的一些概念.socket的网络编程的有关概念 这一讲我们来看UDP 通信. 知识 U ...

  2. (三)URI、URL和URN/GET与POST的区别

    (一)URI.URL.URN HTTP使用统一资源标识符(Uniform Resource Identifiers,URI)来传输数据和建立连接. URL是一种特殊类型的URI,包含了用于查找某个资源 ...

  3. 【Redis】Redis 持久化之 RDB 与 AOF 详解

    一.Redis 持久化 我们知道Redis的数据是全部存储在内存中的,如果机器突然GG,那么数据就会全部丢失,因此需要有持久化机制来保证数据不会一位宕机而丢失.Redis 为我们提供了两种持久化方案, ...

  4. 基于synchronized锁的深度解析

    1. 问题引入 小伙伴们都接触过线程,也都会使用线程,今天我们要讲的是线程安全相关的内容,在这之前我们先来看一个简单的代码案例. 代码案例: /** * @url: i-code.online * @ ...

  5. MySQL全面瓦解7:查询的过滤条件

    概述 在实际的业务场景应用中,我们经常要根据业务条件获取并筛选出我们的目标数据.这个过程我们称之为数据查询的过滤.而过滤过程使用的各种条件(比如日期时间.用户.状态)是我们获取精准数据的必要步骤, 这 ...

  6. tp3.2关闭debug save方法执行失败

    解决该问题需要 清除缓存文件 将retime下的文件删除

  7. Vuex原理详解

    一.Vuex是什么 Vuex是专门为Vuejs应用程序设计的状态管理工具.它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生改变.它集中于MVC模式中的Model层 ...

  8. idea开发工具下,进行多个线程切换调试

  9. python文件操作与编解码

    1 # 文件操作 2 3 ''' 4 1.文件路径:要知道文件的路径 5 6 2.编码方式:要知道文件是什么编码的.utf-8 gbk...... 7 8 3.操作方式:要以什么样的方式进行打开这个文 ...

  10. Spring源码理论

    Spring Bean的创建过程: Spring容器获取Bean和创建Bean都会调用getBean()方法. getBean()方法 1)getBean()方法内部最终调用doGetBean()方法 ...