机器学习(2) - KNN识别MNIST
代码
https://github.com/s055523/MNISTTensorFlowSharp
数据的获得
数据可以由http://yann.lecun.com/exdb/mnist/下载。之后,储存在trainDir中,下次就不需要下载了。
- /// <summary>
- /// 如果文件不存在就去下载
- /// </summary>
- /// <param name="urlBase">下载地址</param>
- /// <param name="trainDir">文件目录地址</param>
- /// <param name="file">文件名</param>
- /// <returns></returns>
- public static Stream MaybeDownload(string urlBase, string trainDir, string file)
- {
- if (!Directory.Exists(trainDir))
- {
- Directory.CreateDirectory(trainDir);
- }
- var target = Path.Combine(trainDir, file);
- if (!File.Exists(target))
- {
- var wc = new WebClient();
- wc.DownloadFile(urlBase + file, target);
- }
- return File.OpenRead(target);
- }
数据格式处理
下载下来的文件共有四个,都是扩展名为gz的压缩包。
train-images-idx3-ubyte.gz 55000张训练图片和5000张验证图片
train-labels-idx1-ubyte.gz 训练图片对应的数字标签(即答案)
t10k-images-idx3-ubyte.gz 10000张测试图片
t10k-labels-idx1-ubyte.gz 测试图片对应的数字标签(即答案)
处理图片数据压缩包
每个压缩包的格式为:
偏移量 |
类型 |
值 |
意义 |
0 |
Int32 |
2051或2049 |
一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049) |
4 |
Int32 |
60000或10000 |
压缩包的图片数 |
8 |
Int32 |
28 |
每个图片的行数 |
12 |
Int32 |
28 |
每个图片的列数 |
16 |
Unsigned byte |
0 - 255 |
第一张图片的第一个像素 |
17 |
Unsigned byte |
0 - 255 |
第一张图片的第二个像素 |
… |
… |
… |
… |
因此,我们可以使用一个统一的方式将数据处理。我们只需要那些图片像素。
- /// <summary>
- /// 从数据流中读取下一个int32
- /// </summary>
- /// <param name="s"></param>
- /// <returns></returns>
- int Read32(Stream s)
- {
- var x = new byte[];
- s.Read(x, , );
- return DataConverter.BigEndian.GetInt32(x, );
- }
- /// <summary>
- /// 处理图片数据
- /// </summary>
- /// <param name="input"></param>
- /// <param name="file"></param>
- /// <returns></returns>
- MnistImage[] ExtractImages(Stream input, string file)
- {
- //文件是gz格式的
- using (var gz = new GZipStream(input, CompressionMode.Decompress))
- {
- //不是2051说明下载的文件不对
- if (Read32(gz) != )
- {
- throw new Exception("不是2051说明下载的文件不对: " + file);
- }
- //图片数
- var count = Read32(gz);
- //行数
- var rows = Read32(gz);
- //列数
- var cols = Read32(gz);
- Console.WriteLine($"准备读取{count}张图片。");
- var result = new MnistImage[count];
- for (int i = ; i < count; i++)
- {
- //图片的大小(每个像素占一个bit)
- var size = rows * cols;
- var data = new byte[size];
- //从数据流中读取这么大的一块内容
- gz.Read(data, , size);
- //将读取到的内容转换为MnistImage类型
- result[i] = new MnistImage(cols, rows, data);
- }
- return result;
- }
- }
准备一个MnistImage类型:
- /// <summary>
- /// 图片类型
- /// </summary>
- public struct MnistImage
- {
- public int Cols, Rows;
- public byte[] Data;
- public float[] DataFloat;
- public MnistImage(int cols, int rows, byte[] data)
- {
- Cols = cols;
- Rows = rows;
- Data = data;
- DataFloat = new float[data.Length];
- for (int i = ; i < data.Length; i++)
- {
- //数据归一化(这里将0-255除255变成了0-1之间的小数)
- //也可以归一为-0.5到0.5之间
- DataFloat[i] = Data[i] / 255f;
- }
- }
- }
这样一来,图片数据就处理完成了。
处理数字标签数据压缩包
数字标签数据压缩包和图片数据压缩包的格式类似。
偏移量 |
类型 |
值 |
意义 |
0 |
Int32 |
2051或2049 |
一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049) |
4 |
Int32 |
60000或10000 |
压缩包的数字标签数 |
5 |
Unsigned byte |
0 - 9 |
第一张图片对应的数字 |
6 |
Unsigned byte |
0 - 9 |
第二张图片对应的数字 |
… |
… |
… |
… |
它的处理更加简单。
- /// <summary>
- /// 处理标签数据
- /// </summary>
- /// <param name="input"></param>
- /// <param name="file"></param>
- /// <returns></returns>
- byte[] ExtractLabels(Stream input, string file)
- {
- using (var gz = new GZipStream(input, CompressionMode.Decompress))
- {
- //不是2049说明下载的文件不对
- if (Read32(gz) != )
- {
- throw new Exception("不是2049说明下载的文件不对:" + file);
- }
- var count = Read32(gz);
- var labels = new byte[count];
- gz.Read(labels, , count);
- return labels;
- }
- }
将数字标签转化为二维数组:one-hot编码
由于我们的数字为0-9,所以,可以视为有十个class。此时,为了后续的处理方便,我们将数字标签转化为数组。因此,一组标签就转换为了一个二维数组。
例如,标签0变成[1,0,0,0,0,0,0,0,0,0]
标签1变成[0,1,0,0,0,0,0,0,0,0]
以此类推。
- /// <summary>
- /// 将数字标签一维数组转为一个二维数组
- /// </summary>
- /// <param name="labels"></param>
- /// <param name="numClasses">多少个类别,这里是10(0到9)</param>
- /// <returns></returns>
- byte[,] OneHot(byte[] labels, int numClasses)
- {
- var oneHot = new byte[labels.Length, numClasses];
- for (int i = ; i < labels.Length; i++)
- {
- oneHot[i, labels[i]] = ;
- }
- return oneHot;
- }
到此为止,数据格式处理就全部结束了。下面的代码展示了数据处理的全过程。
- /// <summary>
- /// 处理数据集
- /// </summary>
- /// <param name="trainDir">数据集所在文件夹</param>
- /// <param name="numClasses"></param>
- /// <param name="validationSize">拿出多少做验证?</param>
- public void ReadDataSets(string trainDir, int numClasses = , int validationSize = )
- {
- const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
- const string TrainImagesName = "train-images-idx3-ubyte.gz";
- const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
- const string TestImagesName = "t10k-images-idx3-ubyte.gz";
- const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";
- //获得训练数据,然后处理训练数据和测试数据
- TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
- TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
- TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
- TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);
- //拿出前面的一部分做验证
- ValidationImages = Pick(TrainImages, , validationSize);
- ValidationLabels = Pick(TrainLabels, , validationSize);
- //拿出剩下的做训练(输入0意味着拿剩下所有的)
- TrainImages = Pick(TrainImages, validationSize, );
- TrainLabels = Pick(TrainLabels, validationSize, );
- //将数字标签转换为二维数组
- //例如,标签3 =》 [0,0,0,1,0,0,0,0,0,0]
- //标签0 =》 [1,0,0,0,0,0,0,0,0,0]
- if (numClasses != -)
- {
- OneHotTrainLabels = OneHot(TrainLabels, numClasses);
- OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
- OneHotTestLabels = OneHot(TestLabels, numClasses);
- }
- }
- /// <summary>
- /// 获得source集合中的一部分,从first开始,到last结束
- /// </summary>
- /// <typeparam name="T"></typeparam>
- /// <param name="source"></param>
- /// <param name="first"></param>
- /// <param name="last"></param>
- /// <returns></returns>
- T[] Pick<T>(T[] source, int first, int last)
- {
- if (last == )
- {
- last = source.Length;
- }
- var count = last - first;
- var ret = source.Skip(first).Take(count).ToArray();
- return ret;
- }
- public static Mnist Load()
- {
- var x = new Mnist();
- x.ReadDataSets(@"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data");
- return x;
- }
在这里,数据共有下面几部分:
- 训练图片数据55000 TrainImages及对应标签TrainLabels
- 验证图片数据5000 ValidationImages及对应标签ValidationLabels
- 测试图片数据10000 TestImages及对应标签TestLabels
KNN算法的实现
现在,我们已经有了所有的数据在手。需要实现的是:
- 拿出数据中的一部分(例如,5000张图片)作为KNN的训练数据,然后,再从数据中的另一部分拿一张图片A
- 对这张图片A,求它和5000张训练图片的距离,并找出一张训练图片B,它是所有训练图片中,和A距离最小的那张(这意味着K=1)
- 此时,就认为A所代表的数字等同于B所代表的数字b
- 重复1-3,N次
首先进行数据的收集:
- //三个Reader分别从总的数据库中获得数据
- public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
- public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
- public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);
- /// <summary>
- /// 数据的一部分,包括了所有的有用信息
- /// </summary>
- public class BatchReader
- {
- int start = ;
- //图片库
- MnistImage[] source;
- //数字标签
- byte[] labels;
- //oneHot之后的数字标签
- byte[,] oneHotLabels;
- internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
- {
- this.source = source;
- this.labels = labels;
- this.oneHotLabels = oneHotLabels;
- }
- /// <summary>
- /// 返回两个浮点二维数组(C# 7的新语法)
- /// </summary>
- /// <param name="batchSize"></param>
- /// <returns></returns>
- public (float[,], float[,]) NextBatch(int batchSize)
- {
- //一张图
- var imageData = new float[batchSize, ];
- //标签
- var labelData = new float[batchSize, ];
- int p = ;
- for (int item = ; item < batchSize; item++)
- {
- Buffer.BlockCopy(source[start + item].DataFloat, , imageData, p, * sizeof(float));
- p += * sizeof(float);
- for (var j = ; j < ; j++)
- labelData[item, j] = oneHotLabels[item + start, j];
- }
- start += batchSize;
- return (imageData, labelData);
- }
- }
然后,在算法中,获取数据:
- static void KNN()
- {
- //取得数据
- var mnist = Mnist.Load();
- //拿5000个训练数据,200个测试数据
- const int trainCount = ;
- const int testCount = ;
- //获得的数据有两个
- //一个是图片,它们都是28*28的
- //一个是one-hot的标签,它们都是1*10的
- (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
- (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount);
- Console.WriteLine($"MNIST 1NN");
下面进行计算。这里使用了K=1的L1距离。这是最简单的情况。
- //建立一个图表示计算任务
- using (var graph = new TFGraph())
- {
- var session = new TFSession(graph);
- //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1
- TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-, ));
- //xte表示一张用来测试的图片
- TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape());
- //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和
- var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const());
- //这里只是用了最近的那个数据
- //也就是说,最近的那个数据是什么,那pred(预测值)就是什么
- TFOutput pred = graph.ArgMin(distance, graph.Const());
最后是开启Session计算的过程:
- var accuracy = 0f;
- //开始循环进行计算,循环trainCount次
- for (int i = ; i < testCount; i++)
- {
- var runner = session.GetRunner();
- //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张
- var result = runner.Fetch(pred).Fetch(distance)
- //trainCount张训练图(数据是trainingImages)
- .AddInput(trainingInput, trainingImages)
- //testCount张测试图(数据是从testImages中拿出来的)
- .AddInput(xte, Extract(testImages, i))
- .Run();
- //最近的点的序号
- var nn_index = (int)(long)result[].GetValue();
- //从trainingLabels中找到答案(这是预测值)
- var prediction = ArgMax(trainingLabels, nn_index);
- //正确答案位于testLabels[i]中
- var real = ArgMax(testLabels, i);
- //PrintImage(testImages, i);
- Console.WriteLine($"测试 {i}: " +
- $"预测: {prediction} " +
- $"正确答案: {real} (最近的点的序号={nn_index})");
- //Console.WriteLine(testImages);
- if (prediction == real)
- {
- accuracy += 1f / testCount;
- }
- }
- Console.WriteLine("准确率: " + accuracy);
对KNN的改进
本文只是对KNN识别MNIST数据集进行了一个非常简单的介绍。在实现了最简单的K=1的L1距离计算之后,正确率约为91%。大家可以试着将算法进行改进,例如取K=2或者其他数,或者计算L2距离等。L2距离的结果比L1好一些,可以达到93-94%的正确率。
机器学习(2) - KNN识别MNIST的更多相关文章
- 机器学习(1) - TensorflowSharp 简单使用与KNN识别MNIST流程
机器学习是时下非常流行的话题,而Tensorflow是机器学习中最有名的工具包.TensorflowSharp是Tensorflow的C#语言表述.本文会对TensorflowSharp的使用进行一个 ...
- KNN识别图像上的数字及python实现
领导让我每天手工录入BI系统中的数据并判断数据是否存在异常,若有异常点,则检测是系统问题还是业务问题.为了解放双手,我决定写个程序完成每天录入管理驾驶舱数据的任务.首先用按键精灵录了一套脚本把系统中的 ...
- 使用KNN对MNIST数据集进行实验
由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢.这里只是想测试观察一下KNN的效果而已,不调参. K选择之前看过貌似最好不要超过2 ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- TensorFlow 入门之手写识别(MNIST) 数据处理 一
TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...
- 机器学习-SVM-手写识别问题
机器学习-SVM-手写识别问题 这里我们解决的还是之前用KNN曾经解决过的手写识别问题(https://www.cnblogs.com/jiading/p/11622019.html),但相比于KNN ...
- 机器学习算法·KNN
机器学习算法应用·KNN算法 一.问题描述 验证码目前在互联网上非常常见,从学校的教务系统到12306购票系统,充当着防火墙的功能.但是随着OCR技术的发展,验证码暴露出的安全问题越来越严峻.目前对验 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法
TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- matlab练习程序(神经网络识别mnist手写数据集)
记得上次练习了神经网络分类,不过当时应该有些地方写的还是不对. 这次用神经网络识别mnist手写数据集,主要参考了深度学习工具包的一些代码. mnist数据集训练数据一共有28*28*60000个像素 ...
随机推荐
- 【一天一道LeetCode】#73. Set Matrix Zeroes
一天一道LeetCode 本系列文章已全部上传至我的github,地址:ZeeCoder's Github 欢迎大家关注我的新浪微博,我的新浪微博 欢迎转载,转载请注明出处 (一)题目 Given a ...
- 十大常见Java String问题
翻译人员: 铁锚 翻译时间: 2013年11月7日 原文链接: Top 10 questions of Java Strings 本文介绍Java中关于String最常见的10个问题: 1. 字符串比 ...
- (三十三)UIApplicationDelegate和程序的启动过程
移动操作系统有个致命弱点,是app容易受到干扰(来电或者锁屏). 当app受到干扰时,会产生一系列的系统事件,这时UIApplication会通知其delegate对象,让delegate处理系统事件 ...
- Android进阶(十五)socket通信——聊天室
想做一个聊天室,花费了将近一天的时间,各种错误.讲解知识点之前,絮叨几句:动手能力还是很重要的,有时看似简单的一个问题,当你真正着手去解决的时候就有可能会遇到各种各样的问题,原因之一就是因为你的知识储 ...
- unity使用UGUI创建摇杆
1.现在unity做一个项目,各种插件各种包,于是项目资源就无限变大了,其实一些简单的功能可以自己写,这里就是试着使用UGUI编写一个摇杆功能 2.脚本如下: using UnityEngine; u ...
- Unity 5.X扩展编辑器之打包assetbundle
5.x的assetbundle与4.x以及之前的版本有些不同,不过本质是一样的,只不过5.x打包assetbundle更为简单和人性化了,总体来说只需要三个步骤: 第一步:创建打包资源 //这里是一个 ...
- AndroidBinder进程间通信系统-android学习之旅(86)
目录 前言及知识准备 Service组件结构 Clinet组件结构 与Binder驱动程序交互 总结 Binder进程间通信实例 问题 本次主要介绍Android平台下Binder进程间通信库.所谓通 ...
- AngularJS进阶(十七)在AngularJS应用中实现微信认证授权遇到的坑
在AngularJS应用中集成微信认证授权遇到的坑 注:请点击此处进行充电! 前言 项目开发过程中,移动端新近增加了一个功能"微信授权登录",由于自己不是负责移动端开发的,但最后他 ...
- 如何使用VS2013本地C++单元测试框架
在VS2013中,可以使用VS自带的C++单元测试框架. 在使用该框架前,需要先安装Unit Test Generator(可以通过菜单“工具->扩展和更新”搜索安装). 下边,就阐述一下利用该 ...
- linux下32位汇编调用规则
传递给系统调用的参数必须安装参数顺序一次放到寄存器中,当系统调用完成后,返回值放在eax中: 当系统调用参数<=5个时: eax中存放系统调用的功能号,传递给系统调用的参数顺序依次放到寄存器:e ...