KNN实现手写数字识别


博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb

1 - 导入模块

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits %matplotlib inline

2 - 导入数据及数据预处理

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz

数据维度

print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)

mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。

x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

展示手写数字

nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")

3 - 构建模型

class Knn():

    def __init__(self,k):
self.k = k
self.distance = {} def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672

准确率略高。



MARSGGBO♥原创





2017-8-21

KNN实现手写数字识别的更多相关文章

  1. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  2. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  3. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  4. Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维

    引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...

  5. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  6. K近邻实战手写数字识别

    1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

随机推荐

  1. PHP学习笔记2

    PHP Switch语句 用于根据多个不同条件执行不同动作.如果不在每个条件后加break,将会输出所有结果. <?php $language="java"; switch( ...

  2. JAVA每日一旅2

    1.关于类型转换 两个数值进行二元操作时,会有如下的转换操作: 如果两个操作数其中有一个是double类型,另一个操作就会转换为double类型. 否则,如果其中一个操作数是float类型,另一个将会 ...

  3. one team

    Double H Team 1.队员 王熙航211606379(队长) 李冠锐211606364 曾磊鑫211606350 戴俊涵211606359 聂寒冰211606324 杨艺勇211606342 ...

  4. Fibbing以让虚结点的设置更简单为目的优化网络需求

  5. Beta阶段冲刺-5

    一. 每日会议 1. 照片 2. 昨日完成工作 3. 今日完成工作 4. 工作中遇到的困难 杨晨露:现在我过的某种意义上挺滋润的,没啥事了都.......咳,困难就是前端每天都在想砸电脑,我要怎么阻止 ...

  6. JS基础(三)语句

    一.判断语句(PS:一般情况下判断条件最终应该是一个布尔值.) 1.if语句 1)基本格式 if(判断条件){ 如果判断条件成立则执行的语句 }else{ 如果判断条件不成立则执行的语句 } 2)扩展 ...

  7. Linux基础二(挂载、关机重启与系统等级)

    一.Linux 基础之挂载 1. 挂载和查询 1.1 挂载 什么叫挂载?装系统的时候要给硬盘分区,在 Windows 中要分 C 盘 D 盘 DEF 盘,这个操作我们叫做分配盘符,分配盘符之后我们就可 ...

  8. 暑假学习笔记(一)——初识Neo4j和APICloud入门

    暑假学习笔记(一)--初识Neo4j和APICloud入门 20180719笔记 1.Neo4j 接了学姐的系统测试报告任务,感觉工作很繁重,但是自己却每天挥霍时光.9月份就要提交系统测试报告了,但是 ...

  9. 使用 TListView 控件(4)

    本例效果图: 代码文件: unit Unit1; interface uses   Windows, Messages, SysUtils, Variants, Classes, Graphics, ...

  10. 如何将img垂直居中?

    方法一: 这种方法可实现图片超出frame尺寸时,自动选择水平.垂直居中,效果如下 <div class="frame"> <img src="foo& ...