# MNIST Digit Prediction with k-Nearest Neighbors
#-----------------------------------------------
#
# This script will load the MNIST data, and split
# it into test/train and perform prediction with
# nearest neighbors
#
# For each test integer, we will return the
# closest image/integer.
#
# Integer images are represented as 28x8 matrices
# of floating point numbers import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import ops
ops.reset_default_graph() # Create graph
sess = tf.Session() # Load the data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) # Random sample
np.random.seed(13) # set seed for reproducibility
train_size = 1000
test_size = 102
rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.test.labels[rand_test_indices] # Declare k-value and batch size
k = 4
batch_size=6 # Placeholders
x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32) # Declare distance metric
# L1
distance = tf.reduce_sum(tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test,1))), axis=2) # L2
#distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x_data_train, tf.expand_dims(x_data_test,1))), reduction_indices=1)) # Predict: Get min distance index (Nearest neighbor)
top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)
prediction_indices = tf.gather(y_target_train, top_k_indices)
# Predict the mode category
count_of_predictions = tf.reduce_sum(prediction_indices, axis=1)
prediction = tf.argmax(count_of_predictions, axis=1) # Calculate how many loops over training data
num_loops = int(np.ceil(len(x_vals_test)/batch_size)) test_output = []
actual_vals = []
for i in range(num_loops):
min_index = i*batch_size
max_index = min((i+1)*batch_size,len(x_vals_train))
x_batch = x_vals_test[min_index:max_index]
y_batch = y_vals_test[min_index:max_index]
predictions = sess.run(prediction, feed_dict={x_data_train: x_vals_train, x_data_test: x_batch,
y_target_train: y_vals_train, y_target_test: y_batch})
test_output.extend(predictions)
actual_vals.extend(np.argmax(y_batch, axis=1)) accuracy = sum([1./test_size for i in range(test_size) if test_output[i]==actual_vals[i]])
print('Accuracy on test set: ' + str(accuracy)) # Plot the last batch results:
actuals = np.argmax(y_batch, axis=1) Nrows = 2
Ncols = 3
for i in range(len(actuals)):
plt.subplot(Nrows, Ncols, i+1)
plt.imshow(np.reshape(x_batch[i], [28,28]), cmap='Greys_r')
plt.title('Actual: ' + str(actuals[i]) + ' Pred: ' + str(predictions[i]),
fontsize=10)
frame = plt.gca()
frame.axes.get_xaxis().set_visible(False)
frame.axes.get_yaxis().set_visible(False) plt.show()

效果:

tensorflow knn mnist的更多相关文章

  1. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  2. Ubuntu16.04安装TensorFlow及Mnist训练

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...

  3. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  6. TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架

    TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  9. Tensorflow之MNIST的最佳实践思路总结

    Tensorflow之MNIST的最佳实践思路总结   在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...

随机推荐

  1. 管理voting disks

     管理voting disks 一.关于voting disk 的一些必需要知道的东西: 11g 曾经我们能够使用dd 命令来备份voting disk ,可是在11g 以后 oracle 不再支 ...

  2. ck-reset css(2016/5/13)

    /**rest by 2016/05/04 */ * {box-sizing: border-box;} *:before,*:after {box-sizing: border-box;} body ...

  3. 【Java】Spring Web MVC注意事项

    本文内容可能是书上没有的,至少是<Java Web整合开发实践>这本书上没有的.这是初学Spring的笔者走过的弯路,谨记以自勉. 这两天学习Spring WebMVC,照着书依葫芦画瓢写 ...

  4. 有一个投篮游戏。球场有p个篮筐,编号为0,1...,p-1。每个篮筐下有个袋子,每个袋子最多装一个篮球。有n个篮球,每个球编号xi 。规则是将数字为xi 的篮球投到xi 除p的余数为编号的袋里。若袋里已有篮球则球弹出游戏结束输出i,否则重复至所有球都投完。输出-1。问游戏最终的输出是什么?

    // ConsoleApplication5.cpp : 定义控制台应用程序的入口点. // #include "stdafx.h" #include<vector> ...

  5. 用GetTickCount()计算一段代码执行耗费的时间的小例子

    var aNow,aThen,aTime:Longint; begin aThen := GetTickCount(); Sleep();//代码段 aNow := GetTickCount(); a ...

  6. 11 redis之rdb快照持久化

    一:Redis持久化配置 Redis的持久化有2种方式[快照,是日志] 二:Rdb快照的配置选项 save 900 1 // 900内,有1条写入,则产生快照 save 300 1000 // 如果3 ...

  7. Windows+VS+SVN实现版本控制

    Subversion已经是一个热门话题,下面介绍一下Windows下Subversion和TortoiseSVN构建SVN版本控制 问题. 首先看一些基础知识: Subversion是架设一个SVN ...

  8. IDEA下使用Jetty进行Debug模式调试

    过程例如以下: (1)找到选项卡中的 –Run– 然后找到 –Edit Configurations (2)点击下图中绿色的plus–找到Maven点进去 (3)依照下边的方式在Command lin ...

  9. Android 六大存储

    Android平台进行存储的方式: 一.使用SharedPreferences存储 二.文件存储数据 三.SQLite数据库存储 四.使用ContentProvider存储数据 五.网络存储数据 今天 ...

  10. with(nolock) 与 with(readpast) 与不加此2个的区别

    调试窗口一: 或者查询窗口一: 总之:事务没有结束 查询窗口二: