数据集

数据集下载MNIST

首先读取数据集, 并打印相关信息

包括

  • 图像的数量, 形状
  • 像素的最大, 最小值
  • 以及看一下第一张图片
  1. path = 'MNIST/mnist.npz'
  2. with np.load(path, allow_pickle=True) as f:
  3. x_train, y_train = f['x_train'], f['y_train']
  4. x_test, y_test = f['x_test'], f['y_test']
  5. print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
  6. print(f'dataset info: max: {x_train.max()}')
  7. print(f'dataset info: min: {x_train.min()}')
  8. print("A sample:")
  9. print("y_train: ", y_train[0])
  10. # print("x_train: \n", x_train[0])
  11. show_pic = x_train[0].copy()
  12. show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
  13. cv2.imshow("A image sample", show_pic)
  14. key = cv2.waitKey(0)
  15. # 按 q 退出
  16. if key == ord('q'):
  17. cv2.destroyAllWindows()
  18. print("show demo over")

转换为tf 数据集的格式, 并进行归一化

  1. # convert to tf tensor
  2. x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) // 255.
  3. x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) // 255.
  4. dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  5. dataset_train = dataset_train.batch(batch_size).repeat(class_num)

定义网络

在这里定义一个简单的全连接网络

  1. def build_simple_net():
  2. net = Sequential([
  3. layers.Dense(256, activation='relu'),
  4. layers.Dense(256, activation='relu'),
  5. layers.Dense(256, activation='relu'),
  6. layers.Dense(class_num)
  7. ])
  8. net.build(input_shape=(None, 28 * 28))
  9. # net.summary()
  10. return net

训练

使用 SGD 优化器进行训练

  1. def train(print_info_step=250):
  2. net = build_simple_net()
  3. # 优化器
  4. optimizer = optimizers.SGD(lr=0.01)
  5. # 计算准确率
  6. acc = metrics.Accuracy()
  7. for step, (x, y) in enumerate(dataset_train):
  8. with tf.GradientTape() as tape:
  9. # [b, 28, 28] => [b, 784]
  10. x = tf.reshape(x, (-1, 28 * 28))
  11. # [b, 784] => [b, 10]
  12. out = net(x)
  13. # [b] => [b, 10]
  14. y_onehot = tf.one_hot(y, depth=class_num)
  15. # [b, 10]
  16. loss = tf.square(out - y_onehot)
  17. # [b]
  18. loss = tf.reduce_sum(loss) / batch_size
  19. # 反向传播
  20. acc.update_state(tf.argmax(out, axis=1), y)
  21. grads = tape.gradient(loss, net.trainable_variables)
  22. optimizer.apply_gradients(zip(grads, net.trainable_variables))
  23. if acc.result() >= 0.90:
  24. net.save_weights(save_path)
  25. print(f'final acc: {acc.result()}, total step: {step}')
  26. break
  27. if step % print_info_step == 0:
  28. print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
  29. acc.reset_states()
  30. if step % 500 == 0 and step != 0:
  31. print('save model')
  32. net.save_weights(save_path)

验证

验证在测试集的模型效果, 这里仅取出第一张进行验证

  1. def test_dataset():
  2. net = build_simple_net()
  3. # 加载模型
  4. net.load_weights(save_path)
  5. # 拿到测试集第一张图片
  6. pred_image = x_test[0]
  7. pred_image = tf.reshape(pred_image, (-1, 28 * 28))
  8. pred = net.predict(pred_image)
  9. # print(pred)
  10. print(f'pred: {tf.argmax(pred, axis=1).numpy()}, label: {y_test[0]}')

应用

分割手写数字, 并进行逐一识别

  • 先将图像二值化
  • 找到轮廓
  • 得到数字的坐标
  • 转为模型的需要的输入格式, 并进行识别
  • 显示
  1. def split_number(img):
  2. result = []
  3. net = build_simple_net()
  4. # 加载模型
  5. net.load_weights(save_path)
  6. image = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2GRAY)
  7. ret, thresh = cv2.threshold(image, 127, 255, 0)
  8. contours, hierarchy = cv2.findContours(thresh, 1, 2)
  9. for cnt in contours[:-1]:
  10. x, y, w, h = cv2.boundingRect(cnt)
  11. image = img[y:y+h, x:x+w]
  12. image = cv2.resize(image, (28, 28))
  13. pred_image = tf.convert_to_tensor(image, dtype=tf.float32) / 255.
  14. pred_image = tf.reshape(pred_image, (-1, 28 * 28))
  15. pred = net.predict(pred_image)
  16. out = tf.argmax(pred, axis=1).numpy()
  17. result = [out[0]] + result
  18. img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
  19. cv2.imshow("demo", img)
  20. print(result)
  21. k = cv2.waitKey(0)
  22. # 按 q 退出
  23. if k == ord('q'):
  24. pass
  25. cv2.destroyAllWindows()

效果

单数字



多数字



附录

所有代码, 文件 tf2_mnist.py

  1. import os
  2. import cv2
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow.keras import layers, Sequential, optimizers, metrics
  6. # 屏蔽通知信息和警告信息
  7. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  8. # 每批几张图片
  9. batch_size = 2
  10. # 类别数
  11. class_num = 10
  12. # 保存模型的路径
  13. save_path = "./models/mnist.ckpt"
  14. # 展示样例
  15. show_demo = False
  16. # 验证测试集
  17. evaluate_dataset = False
  18. # 是否训练
  19. run_train = False
  20. # 图片路径, 仅用于 detect_image(), 当为False时不识别
  21. image_path = 'images/36.png'
  22. path = 'MNIST/mnist.npz'
  23. with np.load(path, allow_pickle=True) as f:
  24. x_train, y_train = f['x_train'], f['y_train']
  25. x_test, y_test = f['x_test'], f['y_test']
  26. if show_demo:
  27. print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
  28. print(f'dataset info: max: {x_train.max()}')
  29. print(f'dataset info: min: {x_train.min()}')
  30. print("A sample:")
  31. print("y_train: ", y_train[0])
  32. # print("x_train: \n", x_train[0])
  33. show_pic = x_train[0].copy()
  34. show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
  35. cv2.imshow("A image sample", show_pic)
  36. key = cv2.waitKey(0)
  37. if key == ord('q'):
  38. cv2.destroyAllWindows()
  39. print("show demo over")
  40. # convert to tf tensor
  41. x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) // 255.
  42. x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) // 255.
  43. dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  44. dataset_train = dataset_train.batch(batch_size).repeat(class_num)
  45. def build_simple_net():
  46. net = Sequential([
  47. layers.Dense(256, activation='relu'),
  48. layers.Dense(256, activation='relu'),
  49. layers.Dense(256, activation='relu'),
  50. layers.Dense(class_num)
  51. ])
  52. net.build(input_shape=(None, 28 * 28))
  53. # net.summary()
  54. return net
  55. def train(print_info_step=250):
  56. net = build_simple_net()
  57. # 优化器
  58. optimizer = optimizers.SGD(lr=0.01)
  59. # 计算准确率
  60. acc = metrics.Accuracy()
  61. for step, (x, y) in enumerate(dataset_train):
  62. with tf.GradientTape() as tape:
  63. # [b, 28, 28] => [b, 784]
  64. x = tf.reshape(x, (-1, 28 * 28))
  65. # [b, 784] => [b, 10]
  66. out = net(x)
  67. # [b] => [b, 10]
  68. y_onehot = tf.one_hot(y, depth=class_num)
  69. # [b, 10]
  70. loss = tf.square(out - y_onehot)
  71. # [b]
  72. loss = tf.reduce_sum(loss) / batch_size
  73. # 反向传播
  74. acc.update_state(tf.argmax(out, axis=1), y)
  75. grads = tape.gradient(loss, net.trainable_variables)
  76. optimizer.apply_gradients(zip(grads, net.trainable_variables))
  77. if acc.result() >= 0.90:
  78. net.save_weights(save_path)
  79. print(f'final acc: {acc.result()}, total step: {step}')
  80. break
  81. if step % print_info_step == 0:
  82. print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
  83. acc.reset_states()
  84. if step % 500 == 0 and step != 0:
  85. print('save model')
  86. net.save_weights(save_path)
  87. def test_dataset():
  88. net = build_simple_net()
  89. # 加载模型
  90. net.load_weights(save_path)
  91. # 拿到测试集第一张图片
  92. pred_image = x_test[0]
  93. pred_image = tf.reshape(pred_image, (-1, 28 * 28))
  94. pred = net.predict(pred_image)
  95. # print(pred)
  96. print(f'pred: {tf.argmax(pred, axis=1).numpy()}, label: {y_test[0]}')
  97. def split_number(img):
  98. result = []
  99. net = build_simple_net()
  100. # 加载模型
  101. net.load_weights(save_path)
  102. image = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2GRAY)
  103. ret, thresh = cv2.threshold(image, 127, 255, 0)
  104. contours, hierarchy = cv2.findContours(thresh, 1, 2)
  105. for cnt in contours[:-1]:
  106. x, y, w, h = cv2.boundingRect(cnt)
  107. image = img[y:y+h, x:x+w]
  108. image = cv2.resize(image, (28, 28))
  109. pred_image = tf.convert_to_tensor(image, dtype=tf.float32) / 255.
  110. pred_image = tf.reshape(pred_image, (-1, 28 * 28))
  111. pred = net.predict(pred_image)
  112. out = tf.argmax(pred, axis=1).numpy()
  113. result = [out[0]] + result
  114. img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
  115. cv2.imshow("demo", img)
  116. print(result)
  117. k = cv2.waitKey(0)
  118. if k == ord('q'):
  119. pass
  120. cv2.destroyAllWindows()
  121. if __name__ == '__main__':
  122. if run_train:
  123. train()
  124. elif evaluate_dataset:
  125. test_dataset()
  126. elif image_path:
  127. image = cv2.imread(image_path)
  128. # detect_image(image)
  129. split_number(image)

linux-基于tensorflow2.x的手写数字识别-基于MNIST数据集的更多相关文章

  1. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  2. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  3. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  4. 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容: 1.基于多层感知器的mnist手写数字识别(代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  5. TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)

    从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作.这次我们要解决机器学习的经典问题,MNIST手写数字识别. 首先介绍一下数据集.请首先解压:TF_Net\Asset\mnist_png. ...

  6. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  7. 基于多层感知机的手写数字识别(Tensorflow实现)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  8. 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

    from numpy import * def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i ...

  9. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

随机推荐

  1. 微信小程序实战,用vue3实现每日浪漫情话推荐~

    之前做了个恋爱话术微信小程序,实现高情商的恋爱聊天. 但最近突然发现,每天早上给女朋友发一段优美情话可以让她开心一整天,但无奈自己的语言水平确实有限,不能随手拈来,着实让人有点不爽. 不过办法总比困难 ...

  2. 《手把手教你》系列基础篇(八十五)-java+ selenium自动化测试-框架设计基础-TestNG自定义日志-下篇(详解教程)

    1.简介 TestNG为日志记录和报告提供的不同选项.现在,宏哥讲解分享如何开始使用它们.首先,我们将编写一个示例程序,在该程序中我们将使用 ITestListener方法进行日志记录. 2.Test ...

  3. Java学习day39

    类加载的作用:将class文件字节码内容加载到内存中,并将这些静态数据转换成方法区的运行时数据结构,然后在堆中生成一个代表这个类的java.lang.Class对象,作为方法区中类数据的访问入口. 类 ...

  4. http协议 知识点

    前端工程师,也叫Web前端开发工程师.他是随着web发展,细分出来的行业.第一步要学好HTML.CSS和JavaScript!接着就要学习交互,HTTP协议.Tomcat服务器.PHP服务器端技术是必 ...

  5. RecyclerView + SQLite 简易备忘录-----中(1)

    在上一节讲完了登录界面的内容,现在随着Activity的跳转,来到MainActivity. 1.主界面activity_main.xml 由上图,activity_main.xml的内容很简单. 首 ...

  6. python基础练习题(题目 矩阵对角线之和)

    day25 --------------------------------------------------------------- 实例038:矩阵对角线之和 题目 求一个3*3矩阵主对角线元 ...

  7. 开发并发布npm包,支持TypeScript提示,rollup构建打包

    前言: 工作了几年,想把一些不好找现成的库的常用方法整理一下,发布成npm包,方便使用.也学习一下开发发布流程. 主要用到的工具:npm. 开发库:babel.typescript.rollup.es ...

  8. 架构师必备:Redis的几种集群方案

    结论 有以下几种Redis集群方案,先说结论: Redis cluster:应当优先考虑使用Redis cluster. codis:旧项目如果仍在使用codis,可继续使用,但也推荐迁移到Redis ...

  9. HashMap中红黑树插入节点的调整过程

    如果有对红黑树的定义及调整过程有过研究,其实很容易理解HashMap中的红黑树插入节点的调整过程. "红黑树定义及调整过程"的参考文章:<红黑树原理.查找效率.插入及变化规则 ...

  10. 搭建PWN学习环境

    环境清单 系统环境 Ubuntu22.04 编写脚本 pwntools ZIO 调试 IDA PRO gdb pwndbg ROP工具 checksec ROPgadget one_gadget Li ...