TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)
从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作。这次我们要解决机器学习的经典问题,MNIST手写数字识别。
首先介绍一下数据集。请首先解压:TF_Net\Asset\mnist_png.tar.gz文件
文件夹内包括两个文件夹:training和validation,其中training文件夹下包括60000个训练图片validation下包括10000个评估图片,图片为28*28像素,分别放在0~9十个文件夹中。
程序总体流程和上一篇文章介绍的BMI分析程序基本一致,毕竟都是多元分类,有几点不一样。
1、BMI程序的特征数据(输入)为一维数组,包含两个数字,MNIST的特征数据为28*28的二位数组;
2、BMI程序的输出为3个,MNIST的输出为10个;
网络模型构建如下:
private readonly int img_rows = 28;
private readonly int img_cols = 28;
private readonly int num_classes = 10; // total classes
/// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 128; // 1st layer number of neurons.
int n_hidden_2 = 128; // 2nd layer number of neurons.
float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer((img_rows,img_cols)),
keras.layers.Flatten(),
keras.layers.Rescaling(scale),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
}
这个网络里用到了两个新方法,需要解释一下:
1、Flatten方法:这里表示拉平,把28*28的二维数组拉平为含784个数据的一维数组,因为二维数组无法进行运算;
2、Rescaling 方法:就是对每个数据乘以一个系数,因为我们从图片获取的数据为每一个位点的灰度值,其取值范围为0~255,所以乘以一个系数将数据缩小到1以内,以免后面运算时溢出。
其它基本和上一篇文章介绍的差不多,全部代码如下:

/// <summary>
/// 通过神经网络来实现多元分类
/// </summary>
public class NN_MultipleClassification_BMI
{
private readonly Random random = new Random(1); // 网络参数
int num_features = 2; // data features
int num_classes = 3; // total output . public void Run()
{
var model = BuildModel();
model.summary(); Console.WriteLine("Press any key to continue...");
Console.ReadKey(); (NDArray train_x, NDArray train_y) = PrepareData(1000);
model.compile(optimizer: keras.optimizers.Adam(0.001f),
loss: keras.losses.SparseCategoricalCrossentropy(),
metrics: new[] { "accuracy" });
model.fit(train_x, train_y, batch_size: 128, epochs: 300); test(model);
} /// <summary>
/// 构建网络模型
/// </summary>
private Model BuildModel()
{
// 网络参数
int n_hidden_1 = 64; // 1st layer number of neurons.
int n_hidden_2 = 64; // 2nd layer number of neurons. var model = keras.Sequential(new List<ILayer>
{
keras.layers.InputLayer(num_features),
keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
}); return model;
} /// <summary>
/// 加载训练数据
/// </summary>
/// <param name="total_size"></param>
private (NDArray, NDArray) PrepareData(int total_size)
{
float[,] arrx = new float[total_size, num_features];
int[] arry = new int[total_size]; for (int i = 0; i < total_size; i++)
{
float weight = (float)random.Next(30, 100) / 100;
float height = (float)random.Next(140, 190) / 100;
float bmi = (weight * 100) / (height * height); arrx[i, 0] = weight;
arrx[i, 1] = height; switch (bmi)
{
case var x when x < 18.0f:
arry[i] = 0;
break; case var x when x >= 18.0f && x <= 28.0f:
arry[i] = 1;
break; case var x when x > 28.0f:
arry[i] = 2;
break;
}
} return (np.array(arrx), np.array(arry));
} /// <summary>
/// 消费模型
/// </summary>
private void test(Model model)
{
int test_size = 20;
for (int i = 0; i < test_size; i++)
{
float weight = (float)random.Next(40, 90) / 100;
float height = (float)random.Next(145, 185) / 100;
float bmi = (weight * 100) / (height * height); var test_x = np.array(new float[1, 2] { { weight, height } });
var pred_y = model.Apply(test_x); Console.WriteLine($"{i}:weight={(float)weight} \theight={height} \tBMI={bmi:0.0} \tPred:{pred_y[0].numpy()}");
}
}
}
另有两点说明:
1、由于对图片的读取比较耗时,所以我采用了一个方法,就是把读取到的数据序列化到一个二进制文件中,下次直接从二进制文件反序列化即可,大大加快处理速度。
2、我没有采用validation图片进行评估,只是简单选了20个样本测试了一下。
【相关资源】
源码:Git: https://gitee.com/seabluescn/tf_not.git
项目名称:NN_MultipleClassification_MNIST
TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)的更多相关文章
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 【机器学习】BP神经网络实现手写数字识别
最近用python写了一个实现手写数字识别的BP神经网络,BP的推导到处都是,但是一动手才知道,会理论推导跟实现它是两回事.关于BP神经网络的实现网上有一些代码,可惜或多或少都有各种问题,在下手写了一 ...
- BP神经网络(手写数字识别)
1实验环境 实验环境:CPU i7-3770@3.40GHz,内存8G,windows10 64位操作系统 实现语言:python 实验数据:Mnist数据集 程序使用的数据库是mnist手写数字数据 ...
随机推荐
- 强化学习实战 | 表格型Q-Learning玩井字棋(一)
在 强化学习实战 | 自定义Gym环境之井子棋 中,我们构建了一个井字棋环境,并进行了测试.接下来我们可以使用各种强化学习方法训练agent出棋,其中比较简单的是Q学习,Q即Q(S, a),是状态动作 ...
- 浅讲.Net 6 之 WebApplicationBuilder
介绍 .Net 6为我们带来的一种全新的引导程序启动的方式.与之前的拆分成Program.cs和Startup不同,整个引导启动代码都在Program.cs中. WebApplicationBuild ...
- android studio 使用 aidl(三)权限验证
这篇文章是基于android studio 使用 aidl (一) 和 android studio 使用 aidl(二) 异步回调 下面的代码都是简化的,如果看不懂请先移步上2篇文章 网上的东西太坑 ...
- 数组实现堆栈——Java实现
1 package struct; 2 3 4 //接口 5 interface IArrayStack{ 6 //栈的容量 7 int length(); 8 //栈中元素个数(栈大小) 9 int ...
- zabbix之修改中文
#在zabbix服务器安装中文名包 root@ubuntu:~# sudo apt-get install language-pack-zh* #:修改环境变量 root@ubuntu:~# sudo ...
- 【编程思想】【设计模式】【结构模式Structural】front_controller
Python版 https://github.com/faif/python-patterns/blob/master/structural/front_controller.py #!/usr/bi ...
- 【面试】【Linux】【Web】基础概念
1. HTTP https://en.wikipedia.org/wiki/Hypertext_Transfer_Protocol 2. TCP handshake https://en.wikipe ...
- jdk1.7源码之-hashMap源码解析
背景: 笔者最近这几天在思考,为什么要学习设计模式,学些设计模式无非是提高自己的开发技能,但是通过这一段时间来看,其实我也学习了一些设计模式,但是都是一些demo,没有具体的例子,学习起来不深刻,所以 ...
- 12月第二周bug总结
1.bug总结 复制 别人的依赖和依赖指定类型 报错 解决:依赖还没加载完成,你就指定了版本型号,所以报错,所以先让他加载依赖,后指定该型号 eureka(优瑞卡) 注册服务 控制台没有显示出来的话 ...
- Mybatis中对象关系映射
在实际开发中,实体类之间有一对一.一对多.多对多的关系,所以需要正确配置它们对应关系,Mybatis通过配置文件能够从数据库中获取列数据后自动封装成对象. 如:一个订单Orders类对应一个用户Use ...