KNN实现手写数字识别
KNN实现手写数字识别
博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb
1 - 导入模块
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits
%matplotlib inline
2 - 导入数据及数据预处理
import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz
数据维度
print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)
mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。
x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
展示手写数字
nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")
3 - 构建模型
class Knn():
def __init__(self,k):
self.k = k
self.distance = {}
def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance
def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat
def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672
准确率略高。
KNN实现手写数字识别的更多相关文章
- 机器学习(二)-kNN手写数字识别
一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...
- 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!
1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...
- kaggle 实战 (1): PCA + KNN 手写数字识别
文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...
- Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维
引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- K近邻实战手写数字识别
1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...
- C#中调用Matlab人工神经网络算法实现手写数字识别
手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化 投影 矩阵 目标定位 Matlab 手写数字图像识别简介: 手写 ...
- CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
随机推荐
- Linux内核分析— —计算机是如何工作的(20135213林涵锦)
实验部分 (以下命令为实验楼64位Linux虚拟机环境下适用,32位Linux环境可能会稍有不同) 使用 gcc –S –o main.s main.c -m32 命令编译成汇编代码, int g(i ...
- 《Linux内核分析》第五周:分析system_call中断处理过程
实验 分析system_call中断处理过程 使用gdb跟踪分析一个系统调用内核函数(您上周选择那一个系统调用),系统调用列表参见http://codelab.shiyanlou.com/xref/l ...
- C#代码分析--阅读程序,回答问题
阅读下面程序,请回答如下问题: 问题1:这个程序要找的是符合什么条件的数? 问题2:这样的数存在么?符合这一条件的最小的数是什么? 问题3:在电脑上运行这一程序,你估计多长时间才能输出第一个结果?时间 ...
- mosquitto集群配置
--------------------------------------------------------前言------------------------------------------ ...
- CodeM Qualifying Match Q2
问题描述: 组委会正在为美团点评CodeM大赛的决赛设计新赛制. 比赛有 n 个人参加(其中 n 为2的幂),每个参赛者根据资格赛和预赛.复赛的成绩,会有不同的积分. 比赛采取锦标赛赛制,分轮次进行, ...
- [Delphi]实现使用TIdHttp控件向https地址Post请求[转]
开篇:公司之前一直使用http协议进行交互(比如登录等功能),但是经常被爆安全性不高,所以准备改用https协议.百度了一下资料,其实使用IdHttp控件实现https交互的帖子并不少,鉴于这次成功实 ...
- 51nod 1476 括号序列的最小代价(贪心+优先队列)
题意 我们这有一种仅由"(",")"和"?"组成的括号序列,你必须将"?"替换成括号,从而得到一个合法的括号序列. 对于 ...
- Mobile Phone Network CodeForces - 1023F(并查集lca+修改环)
题意: 就是有几个点,你掌控了几条路,你的商业对手也掌控了几条路,然后你想让游客都把你的所有路都走完,那么你就有钱了,但你又想挣的钱最多,真是的过分..哈哈 游客肯定要对比一下你的对手的路 看看那个便 ...
- 【刷题】LOJ 6002 「网络流 24 题」最小路径覆盖
题目描述 给定有向图 \(G = (V, E)\) .设 \(P\) 是 \(G\) 的一个简单路(顶点不相交)的集合.如果 \(V\) 中每个顶点恰好在 \(P\) 的一条路上,则称 \(P\) 是 ...
- 【转】STM32 - 程序跳转、中断、开关总中断
程序跳转注意: 1.如果跳转之前的程序A里有些中断没有关,在跳转之后程序B的中断触发,但程序B里没有定义中断响应函数,找不到地址会导致死机. 2.程序跳转前关总中断,程序跳转后开总中断(关总中断,只是 ...