下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。

MNIST数据集的训练集有60000个样本:

Python代码

  1. import tensorflow as tf
  2. from tensorflow.keras import layers, models
  3. import matplotlib.pyplot as plt
  4. import json
  5. import os
  6. # 加载MNIST数据集
  7. mnist = tf.keras.datasets.mnist
  8. (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  9. # 预处理数据
  10. train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
  11. test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
  12. # 构建神经网络模型
  13. def create_model():
  14. model = models.Sequential()
  15. model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
  16. model.add(layers.MaxPooling2D((2, 2)))
  17. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  18. model.add(layers.MaxPooling2D((2, 2)))
  19. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  20. model.add(layers.Flatten())
  21. model.add(layers.Dense(64, activation='relu'))
  22. model.add(layers.Dense(10, activation='softmax'))
  23. # 编译模型
  24. model.compile(optimizer='adam',
  25. loss='sparse_categorical_crossentropy',
  26. metrics=['accuracy'])
  27. return model
  28. # 训练模型并保存
  29. def train_and_save_model():
  30. model = create_model()
  31. history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
  32. model.save('mnist_model.h5')
  33. # 保存训练历史记录
  34. with open('training_history.json', 'w') as f:
  35. json.dump(history.history, f)
  36. # 加载模型和历史记录
  37. def load_model_and_history():
  38. model = tf.keras.models.load_model('mnist_model.h5')
  39. with open('training_history.json', 'r') as f:
  40. history = json.load(f)
  41. return model, history
  42. # 评估模型
  43. def evaluate_model(model):
  44. test_loss, test_acc = model.evaluate(test_images, test_labels)
  45. print("Test accuracy: {}".format(test_acc))
  46. # 可视化训练过程
  47. def plot_training_history(history):
  48. plt.plot(history['accuracy'], label='accuracy')
  49. plt.plot(history['val_accuracy'], label='val_accuracy')
  50. plt.xlabel('Epoch')
  51. plt.ylabel('Accuracy')
  52. plt.ylim([0, 1])
  53. plt.legend(loc='lower right')
  54. plt.show()
  55. # 检查是否已经存在模型和历史记录
  56. if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
  57. train_and_save_model()
  58. model, training_history = load_model_and_history()
  59. evaluate_model(model)
  60. plot_training_history(training_history)

代码解释

  1. 加载MNIST数据集

    1. mnist = tf.keras.datasets.mnist
    2. (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  2. 预处理数据

    • 将图像数据调整为 (28, 28, 1) 的形状。
    • 将像素值标准化为 [0, 1] 之间。
    1. train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    2. test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
  3. 构建神经网络模型

    • 使用 Sequential 模型,按顺序添加层。
    • 添加卷积层、池化层、全连接层。
    1. model = models.Sequential()
    2. model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    3. model.add(layers.MaxPooling2D((2, 2)))
    4. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    5. model.add(layers.MaxPooling2D((2, 2)))
    6. model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    7. model.add(layers.Flatten())
    8. model.add(layers.Dense(64, activation='relu'))
    9. model.add(layers.Dense(10, activation='softmax'))
  4. 编译模型

    • 使用 adam 优化器,损失函数为 sparse_categorical_crossentropy,评估指标为 accuracy
    1. model.compile(optimizer='adam',
    2. loss='sparse_categorical_crossentropy',
    3. metrics=['accuracy'])
  5. 训练模型

    • 训练模型5个epochs,并使用验证数据集评估模型性能。
    1. history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
  6. 评估模型

    • 在测试集上评估模型性能,并打印测试准确率。
    1. test_loss, test_acc = model.evaluate(test_images, test_labels)
    2. print(f"Test accuracy: {test_acc}")
  7. 可视化训练过程

    • 绘制训练和验证准确率随epoch变化的曲线。
    1. plt.plot(history.history['accuracy'], label='accuracy')
    2. plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    3. plt.xlabel('Epoch')
    4. plt.ylabel('Accuracy')
    5. plt.ylim([0, 1])
    6. plt.legend(loc='lower right')
    7. plt.show()

通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。

手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章

  1. Softmax用于手写数字识别(Tensorflow实现)-个人理解

    softmax函数的作用   对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...

  2. 基于卷积神经网络的手写数字识别分类(Tensorflow)

    import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...

  3. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  4. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

  5. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  6. 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)

    上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...

  7. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

  8. python-积卷神经网络全面理解-tensorflow实现手写数字识别

    首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...

  9. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  10. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

随机推荐

  1. 使用爬虫利器 Playwright,轻松爬取抖查查数据

    使用爬虫利器 Playwright,轻松爬取抖查查数据 我们先分析登录的接口,其中 url 有一些非业务参数:ts.he.sign.secret. 然后根据这些参数作为关键词,定位到相关的 js 代码 ...

  2. this的二种使用方式

    package com.ht.TestThis; public class TestThisKey { public static void main(String[] args) { // TODO ...

  3. zkq 数学听课笔记

    线性代数 域 \(F\),OI 中常用的域是 \(\Z_{p^c}\). \(n\) 维向量 \(\vec x \in F^n\),其中 \(x_i \in F\),注意向量是列向量. \(F^n\) ...

  4. LeetCode 347. Top K Frequent Elements 前 K 个高频元素 (Java)

    题目: Given a non-empty array of integers, return the k most frequent elements. Example 1: Input: nums ...

  5. 架构与思维:了解Http 和 Https的区别(图文详解)

    1 介绍 随着 HTTPS 的不断普及和使用成本的下降,现阶段大部分的系统都已经开始用上 HTTPS 协议. HTTPS 与 HTTP 相比, 主打的就是安全概念,相关的知识如 SSL .非对称加密. ...

  6. 项目管理--PMBOK 读书笔记(8)【项目质量管理】

    1.数据表现-流程图: 流程图也称为过程图,用来显示在一个或者多个输入转化为一个或者多个输入出的过程. 2.质量工具图比较: 发现问题:控制图(七点规则等).趋势图 寻找原因:因果图.流程图 分析原因 ...

  7. JavaScript实现防抖节流函数

    review 防抖函数 防抖函数一般是短时间内多次触发,但是只有最后一次触发结束后的delay秒内会去执行相对应的处理函数. 相当于一个赛道里面一次只能跑一辆赛车,如果此时已经有一辆赛车在跑道里面跑, ...

  8. MySQL常见的后端面试题,你会几道?

    为什么分库分表 单表数据量过大,会出现慢查询,所以需要水平分表 可以把低频.高频的字段分开为多个表,低频的表作为附加表,且逻辑更加清晰,性能更优 随着系统的业务模块的增多,放到单库会增加其复杂度,逻辑 ...

  9. 浮点数格式:FP64, FP32, FP16, BFLOAT16, TF32之间的相互区别

    浮点数格式 (参考1,参考2) 浮点数是一种用二进制表示的实数,它由三个部分组成:sign(符号位).exponent(指数位)和fraction(小数位).不同的浮点数格式有不同的位数分配给这三个部 ...

  10. Linux内核:regmap机制

    背景 在学习SPI框架的时候,看到了有一个rtc驱动用到了regmap,本想通过传统方式访问spi接口的我,突然有点不适应,翻了整个驱动,愣是没有找到读写spi的范式:因此了解了regmap以后,才发 ...