代码

https://github.com/s055523/MNISTTensorFlowSharp

数据的获得

数据可以由http://yann.lecun.com/exdb/mnist/下载。之后,储存在trainDir中,下次就不需要下载了。

  1. /// <summary>
  2. /// 如果文件不存在就去下载
  3. /// </summary>
  4. /// <param name="urlBase">下载地址</param>
  5. /// <param name="trainDir">文件目录地址</param>
  6. /// <param name="file">文件名</param>
  7. /// <returns></returns>
  8. public static Stream MaybeDownload(string urlBase, string trainDir, string file)
  9. {
  10. if (!Directory.Exists(trainDir))
  11. {
  12. Directory.CreateDirectory(trainDir);
  13. }
  14.  
  15. var target = Path.Combine(trainDir, file);
  16. if (!File.Exists(target))
  17. {
  18. var wc = new WebClient();
  19. wc.DownloadFile(urlBase + file, target);
  20. }
  21. return File.OpenRead(target);
  22. }

数据格式处理

下载下来的文件共有四个,都是扩展名为gz的压缩包。

train-images-idx3-ubyte.gz  55000张训练图片和5000张验证图片

train-labels-idx1-ubyte.gz     训练图片对应的数字标签(即答案)

t10k-images-idx3-ubyte.gz   10000张测试图片

t10k-labels-idx1-ubyte.gz     测试图片对应的数字标签(即答案)

处理图片数据压缩包

每个压缩包的格式为:

偏移量

类型

意义

0

Int32

2051或2049

一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

4

Int32

60000或10000

压缩包的图片数

8

Int32

28

每个图片的行数

12

Int32

28

每个图片的列数

16

Unsigned byte

0 - 255

第一张图片的第一个像素

17

Unsigned byte

0 - 255

第一张图片的第二个像素

因此,我们可以使用一个统一的方式将数据处理。我们只需要那些图片像素。

  1. /// <summary>
  2. /// 从数据流中读取下一个int32
  3. /// </summary>
  4. /// <param name="s"></param>
  5. /// <returns></returns>
  6. int Read32(Stream s)
  7. {
  8. var x = new byte[];
  9. s.Read(x, , );
  10. return DataConverter.BigEndian.GetInt32(x, );
  11. }
  12.  
  13. /// <summary>
  14. /// 处理图片数据
  15. /// </summary>
  16. /// <param name="input"></param>
  17. /// <param name="file"></param>
  18. /// <returns></returns>
  19. MnistImage[] ExtractImages(Stream input, string file)
  20. {
  21. //文件是gz格式的
  22. using (var gz = new GZipStream(input, CompressionMode.Decompress))
  23. {
  24. //不是2051说明下载的文件不对
  25. if (Read32(gz) != )
  26. {
  27. throw new Exception("不是2051说明下载的文件不对: " + file);
  28. }
  29. //图片数
  30. var count = Read32(gz);
  31. //行数
  32. var rows = Read32(gz);
  33. //列数
  34. var cols = Read32(gz);
  35.  
  36. Console.WriteLine($"准备读取{count}张图片。");
  37.  
  38. var result = new MnistImage[count];
  39. for (int i = ; i < count; i++)
  40. {
  41. //图片的大小(每个像素占一个bit)
  42. var size = rows * cols;
  43. var data = new byte[size];
  44.  
  45. //从数据流中读取这么大的一块内容
  46. gz.Read(data, , size);
  47.  
  48. //将读取到的内容转换为MnistImage类型
  49. result[i] = new MnistImage(cols, rows, data);
  50. }
  51. return result;
  52. }
  53. }

准备一个MnistImage类型:

  1. /// <summary>
  2. /// 图片类型
  3. /// </summary>
  4. public struct MnistImage
  5. {
  6. public int Cols, Rows;
  7. public byte[] Data;
  8. public float[] DataFloat;
  9.  
  10. public MnistImage(int cols, int rows, byte[] data)
  11. {
  12. Cols = cols;
  13. Rows = rows;
  14. Data = data;
  15. DataFloat = new float[data.Length];
  16. for (int i = ; i < data.Length; i++)
  17. {
  18. //数据归一化(这里将0-255除255变成了0-1之间的小数)
  19. //也可以归一为-0.5到0.5之间
  20. DataFloat[i] = Data[i] / 255f;
  21. }
  22. }
  23. }

这样一来,图片数据就处理完成了。

处理数字标签数据压缩包

数字标签数据压缩包和图片数据压缩包的格式类似。

偏移量

类型

意义

0

Int32

2051或2049

一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

4

Int32

60000或10000

压缩包的数字标签数

5

Unsigned byte

0 - 9

第一张图片对应的数字

6

Unsigned byte

0 - 9

第二张图片对应的数字

它的处理更加简单。

  1. /// <summary>
  2. /// 处理标签数据
  3. /// </summary>
  4. /// <param name="input"></param>
  5. /// <param name="file"></param>
  6. /// <returns></returns>
  7. byte[] ExtractLabels(Stream input, string file)
  8. {
  9. using (var gz = new GZipStream(input, CompressionMode.Decompress))
  10. {
  11. //不是2049说明下载的文件不对
  12. if (Read32(gz) != )
  13. {
  14. throw new Exception("不是2049说明下载的文件不对:" + file);
  15. }
  16. var count = Read32(gz);
  17. var labels = new byte[count];
  18.  
  19. gz.Read(labels, , count);
  20.  
  21. return labels;
  22. }
  23. }

将数字标签转化为二维数组:one-hot编码

由于我们的数字为0-9,所以,可以视为有十个class。此时,为了后续的处理方便,我们将数字标签转化为数组。因此,一组标签就转换为了一个二维数组。

例如,标签0变成[1,0,0,0,0,0,0,0,0,0]

标签1变成[0,1,0,0,0,0,0,0,0,0]

以此类推。

  1. /// <summary>
  2. /// 将数字标签一维数组转为一个二维数组
  3. /// </summary>
  4. /// <param name="labels"></param>
  5. /// <param name="numClasses">多少个类别,这里是10(0到9)</param>
  6. /// <returns></returns>
  7. byte[,] OneHot(byte[] labels, int numClasses)
  8. {
  9. var oneHot = new byte[labels.Length, numClasses];
  10. for (int i = ; i < labels.Length; i++)
  11. {
  12. oneHot[i, labels[i]] = ;
  13. }
  14. return oneHot;
  15. }

到此为止,数据格式处理就全部结束了。下面的代码展示了数据处理的全过程。

  1. /// <summary>
  2. /// 处理数据集
  3. /// </summary>
  4. /// <param name="trainDir">数据集所在文件夹</param>
  5. /// <param name="numClasses"></param>
  6. /// <param name="validationSize">拿出多少做验证?</param>
  7. public void ReadDataSets(string trainDir, int numClasses = , int validationSize = )
  8. {
  9. const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
  10. const string TrainImagesName = "train-images-idx3-ubyte.gz";
  11. const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
  12. const string TestImagesName = "t10k-images-idx3-ubyte.gz";
  13. const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";
  14.  
  15. //获得训练数据,然后处理训练数据和测试数据
  16. TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
  17. TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
  18. TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
  19. TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);
  20.  
  21. //拿出前面的一部分做验证
  22. ValidationImages = Pick(TrainImages, , validationSize);
  23. ValidationLabels = Pick(TrainLabels, , validationSize);
  24.  
  25. //拿出剩下的做训练(输入0意味着拿剩下所有的)
  26. TrainImages = Pick(TrainImages, validationSize, );
  27. TrainLabels = Pick(TrainLabels, validationSize, );
  28.  
  29. //将数字标签转换为二维数组
  30. //例如,标签3 =》 [0,0,0,1,0,0,0,0,0,0]
  31. //标签0 =》 [1,0,0,0,0,0,0,0,0,0]
  32. if (numClasses != -)
  33. {
  34. OneHotTrainLabels = OneHot(TrainLabels, numClasses);
  35. OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
  36. OneHotTestLabels = OneHot(TestLabels, numClasses);
  37. }
  38. }
  39.  
  40. /// <summary>
  41. /// 获得source集合中的一部分,从first开始,到last结束
  42. /// </summary>
  43. /// <typeparam name="T"></typeparam>
  44. /// <param name="source"></param>
  45. /// <param name="first"></param>
  46. /// <param name="last"></param>
  47. /// <returns></returns>
  48. T[] Pick<T>(T[] source, int first, int last)
  49. {
  50. if (last == )
  51. {
  52. last = source.Length;
  53. }
  54.  
  55. var count = last - first;
  56. var ret = source.Skip(first).Take(count).ToArray();
  57. return ret;
  58. }
  59.  
  60. public static Mnist Load()
  61. {
  62. var x = new Mnist();
  63. x.ReadDataSets(@"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data");
  64. return x;
  65. }

在这里,数据共有下面几部分:

  1. 训练图片数据55000 TrainImages及对应标签TrainLabels
  2. 验证图片数据5000 ValidationImages及对应标签ValidationLabels
  3. 测试图片数据10000 TestImages及对应标签TestLabels

KNN算法的实现

现在,我们已经有了所有的数据在手。需要实现的是:

  1. 拿出数据中的一部分(例如,5000张图片)作为KNN的训练数据,然后,再从数据中的另一部分拿一张图片A
  2. 对这张图片A,求它和5000张训练图片的距离,并找出一张训练图片B,它是所有训练图片中,和A距离最小的那张(这意味着K=1)
  3. 此时,就认为A所代表的数字等同于B所代表的数字b
  4. 重复1-3,N次

首先进行数据的收集:

  1. //三个Reader分别从总的数据库中获得数据
  2. public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
  3. public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
  4. public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);
  5.  
  6. /// <summary>
  7. /// 数据的一部分,包括了所有的有用信息
  8. /// </summary>
  9. public class BatchReader
  10. {
  11. int start = ;
  12. //图片库
  13. MnistImage[] source;
  14. //数字标签
  15. byte[] labels;
  16. //oneHot之后的数字标签
  17. byte[,] oneHotLabels;
  18.  
  19. internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
  20. {
  21. this.source = source;
  22. this.labels = labels;
  23. this.oneHotLabels = oneHotLabels;
  24. }
  25.  
  26. /// <summary>
  27. /// 返回两个浮点二维数组(C# 7的新语法)
  28. /// </summary>
  29. /// <param name="batchSize"></param>
  30. /// <returns></returns>
  31. public (float[,], float[,]) NextBatch(int batchSize)
  32. {
  33. //一张图
  34. var imageData = new float[batchSize, ];
  35. //标签
  36. var labelData = new float[batchSize, ];
  37.  
  38. int p = ;
  39. for (int item = ; item < batchSize; item++)
  40. {
  41. Buffer.BlockCopy(source[start + item].DataFloat, , imageData, p, * sizeof(float));
  42. p += * sizeof(float);
  43. for (var j = ; j < ; j++)
  44. labelData[item, j] = oneHotLabels[item + start, j];
  45. }
  46.  
  47. start += batchSize;
  48. return (imageData, labelData);
  49. }
  50. }

然后,在算法中,获取数据:

  1. static void KNN()
  2. {
  3. //取得数据
  4. var mnist = Mnist.Load();
  5.  
  6. //拿5000个训练数据,200个测试数据
  7. const int trainCount = ;
  8. const int testCount = ;
  9.  
  10. //获得的数据有两个
  11. //一个是图片,它们都是28*28的
  12. //一个是one-hot的标签,它们都是1*10的
  13. (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
  14. (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount);
  15.  
  16. Console.WriteLine($"MNIST 1NN");

下面进行计算。这里使用了K=1的L1距离。这是最简单的情况。

  1. //建立一个图表示计算任务
  2. using (var graph = new TFGraph())
  3. {
  4. var session = new TFSession(graph);
  5.  
  6. //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1
  7. TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-, ));
  8.  
  9. //xte表示一张用来测试的图片
  10. TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape());
  11.  
  12. //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和
  13. var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const());
  14.  
  15. //这里只是用了最近的那个数据
  16. //也就是说,最近的那个数据是什么,那pred(预测值)就是什么
  17. TFOutput pred = graph.ArgMin(distance, graph.Const());

最后是开启Session计算的过程:

  1. var accuracy = 0f;
  2.  
  3. //开始循环进行计算,循环trainCount次
  4. for (int i = ; i < testCount; i++)
  5. {
  6. var runner = session.GetRunner();
  7.  
  8. //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张
  9. var result = runner.Fetch(pred).Fetch(distance)
  10. //trainCount张训练图(数据是trainingImages)
  11. .AddInput(trainingInput, trainingImages)
  12. //testCount张测试图(数据是从testImages中拿出来的)
  13. .AddInput(xte, Extract(testImages, i))
  14. .Run();
  15.  
  16. //最近的点的序号
  17. var nn_index = (int)(long)result[].GetValue();
  18.  
  19. //从trainingLabels中找到答案(这是预测值)
  20. var prediction = ArgMax(trainingLabels, nn_index);
  21.  
  22. //正确答案位于testLabels[i]中
  23. var real = ArgMax(testLabels, i);
  24.  
  25. //PrintImage(testImages, i);
  26.  
  27. Console.WriteLine($"测试 {i}: " +
  28. $"预测: {prediction} " +
  29. $"正确答案: {real} (最近的点的序号={nn_index})");
  30. //Console.WriteLine(testImages);
  31.  
  32. if (prediction == real)
  33. {
  34. accuracy += 1f / testCount;
  35. }
  36. }
  37. Console.WriteLine("准确率: " + accuracy);

对KNN的改进

本文只是对KNN识别MNIST数据集进行了一个非常简单的介绍。在实现了最简单的K=1的L1距离计算之后,正确率约为91%。大家可以试着将算法进行改进,例如取K=2或者其他数,或者计算L2距离等。L2距离的结果比L1好一些,可以达到93-94%的正确率。

机器学习(2) - KNN识别MNIST的更多相关文章

  1. 机器学习(1) - TensorflowSharp 简单使用与KNN识别MNIST流程

    机器学习是时下非常流行的话题,而Tensorflow是机器学习中最有名的工具包.TensorflowSharp是Tensorflow的C#语言表述.本文会对TensorflowSharp的使用进行一个 ...

  2. KNN识别图像上的数字及python实现

    领导让我每天手工录入BI系统中的数据并判断数据是否存在异常,若有异常点,则检测是系统问题还是业务问题.为了解放双手,我决定写个程序完成每天录入管理驾驶舱数据的任务.首先用按键精灵录了一套脚本把系统中的 ...

  3. 使用KNN对MNIST数据集进行实验

    由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢.这里只是想测试观察一下KNN的效果而已,不调参. K选择之前看过貌似最好不要超过2 ...

  4. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  5. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  6. 机器学习-SVM-手写识别问题

    机器学习-SVM-手写识别问题 这里我们解决的还是之前用KNN曾经解决过的手写识别问题(https://www.cnblogs.com/jiading/p/11622019.html),但相比于KNN ...

  7. 机器学习算法·KNN

    机器学习算法应用·KNN算法 一.问题描述 验证码目前在互联网上非常常见,从学校的教务系统到12306购票系统,充当着防火墙的功能.但是随着OCR技术的发展,验证码暴露出的安全问题越来越严峻.目前对验 ...

  8. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  9. matlab练习程序(神经网络识别mnist手写数据集)

    记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...

随机推荐

  1. 【一天一道LeetCode】#73. Set Matrix Zeroes

    一天一道LeetCode 本系列文章已全部上传至我的github,地址:ZeeCoder's Github 欢迎大家关注我的新浪微博,我的新浪微博 欢迎转载,转载请注明出处 (一)题目 Given a ...

  2. 十大常见Java String问题

    翻译人员: 铁锚 翻译时间: 2013年11月7日 原文链接: Top 10 questions of Java Strings 本文介绍Java中关于String最常见的10个问题: 1. 字符串比 ...

  3. (三十三)UIApplicationDelegate和程序的启动过程

    移动操作系统有个致命弱点,是app容易受到干扰(来电或者锁屏). 当app受到干扰时,会产生一系列的系统事件,这时UIApplication会通知其delegate对象,让delegate处理系统事件 ...

  4. Android进阶(十五)socket通信——聊天室

    想做一个聊天室,花费了将近一天的时间,各种错误.讲解知识点之前,絮叨几句:动手能力还是很重要的,有时看似简单的一个问题,当你真正着手去解决的时候就有可能会遇到各种各样的问题,原因之一就是因为你的知识储 ...

  5. unity使用UGUI创建摇杆

    1.现在unity做一个项目,各种插件各种包,于是项目资源就无限变大了,其实一些简单的功能可以自己写,这里就是试着使用UGUI编写一个摇杆功能 2.脚本如下: using UnityEngine; u ...

  6. Unity 5.X扩展编辑器之打包assetbundle

    5.x的assetbundle与4.x以及之前的版本有些不同,不过本质是一样的,只不过5.x打包assetbundle更为简单和人性化了,总体来说只需要三个步骤: 第一步:创建打包资源 //这里是一个 ...

  7. AndroidBinder进程间通信系统-android学习之旅(86)

    目录 前言及知识准备 Service组件结构 Clinet组件结构 与Binder驱动程序交互 总结 Binder进程间通信实例 问题 本次主要介绍Android平台下Binder进程间通信库.所谓通 ...

  8. AngularJS进阶(十七)在AngularJS应用中实现微信认证授权遇到的坑

    在AngularJS应用中集成微信认证授权遇到的坑 注:请点击此处进行充电! 前言 项目开发过程中,移动端新近增加了一个功能"微信授权登录",由于自己不是负责移动端开发的,但最后他 ...

  9. 如何使用VS2013本地C++单元测试框架

    在VS2013中,可以使用VS自带的C++单元测试框架. 在使用该框架前,需要先安装Unit Test Generator(可以通过菜单“工具->扩展和更新”搜索安装). 下边,就阐述一下利用该 ...

  10. linux下32位汇编调用规则

    传递给系统调用的参数必须安装参数顺序一次放到寄存器中,当系统调用完成后,返回值放在eax中: 当系统调用参数<=5个时: eax中存放系统调用的功能号,传递给系统调用的参数顺序依次放到寄存器:e ...