KNN实现手写数字识别
KNN实现手写数字识别
博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb
1 - 导入模块
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits
%matplotlib inline
2 - 导入数据及数据预处理
import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz
数据维度
print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)
mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。
x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
展示手写数字
nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")
3 - 构建模型
class Knn():
def __init__(self,k):
self.k = k
self.distance = {}
def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance
def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat
def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672
准确率略高。
KNN实现手写数字识别的更多相关文章
- 机器学习(二)-kNN手写数字识别
一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...
- 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!
1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...
- kaggle 实战 (1): PCA + KNN 手写数字识别
文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...
- Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维
引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- K近邻实战手写数字识别
1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...
- C#中调用Matlab人工神经网络算法实现手写数字识别
手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化 投影 矩阵 目标定位 Matlab 手写数字图像识别简介: 手写 ...
- CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
随机推荐
- PHP学习笔记2
PHP Switch语句 用于根据多个不同条件执行不同动作.如果不在每个条件后加break,将会输出所有结果. <?php $language="java"; switch( ...
- JAVA每日一旅2
1.关于类型转换 两个数值进行二元操作时,会有如下的转换操作: 如果两个操作数其中有一个是double类型,另一个操作就会转换为double类型. 否则,如果其中一个操作数是float类型,另一个将会 ...
- one team
Double H Team 1.队员 王熙航211606379(队长) 李冠锐211606364 曾磊鑫211606350 戴俊涵211606359 聂寒冰211606324 杨艺勇211606342 ...
- Fibbing以让虚结点的设置更简单为目的优化网络需求
- Beta阶段冲刺-5
一. 每日会议 1. 照片 2. 昨日完成工作 3. 今日完成工作 4. 工作中遇到的困难 杨晨露:现在我过的某种意义上挺滋润的,没啥事了都.......咳,困难就是前端每天都在想砸电脑,我要怎么阻止 ...
- JS基础(三)语句
一.判断语句(PS:一般情况下判断条件最终应该是一个布尔值.) 1.if语句 1)基本格式 if(判断条件){ 如果判断条件成立则执行的语句 }else{ 如果判断条件不成立则执行的语句 } 2)扩展 ...
- Linux基础二(挂载、关机重启与系统等级)
一.Linux 基础之挂载 1. 挂载和查询 1.1 挂载 什么叫挂载?装系统的时候要给硬盘分区,在 Windows 中要分 C 盘 D 盘 DEF 盘,这个操作我们叫做分配盘符,分配盘符之后我们就可 ...
- 暑假学习笔记(一)——初识Neo4j和APICloud入门
暑假学习笔记(一)--初识Neo4j和APICloud入门 20180719笔记 1.Neo4j 接了学姐的系统测试报告任务,感觉工作很繁重,但是自己却每天挥霍时光.9月份就要提交系统测试报告了,但是 ...
- 使用 TListView 控件(4)
本例效果图: 代码文件: unit Unit1; interface uses Windows, Messages, SysUtils, Variants, Classes, Graphics, ...
- 如何将img垂直居中?
方法一: 这种方法可实现图片超出frame尺寸时,自动选择水平.垂直居中,效果如下 <div class="frame"> <img src="foo& ...