手写数字识别-使用TensorFlow构建和训练一个简单的神经网络
下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。
MNIST
数据集的训练集有60000个样本:
Python代码
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os
# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
# 构建神经网络模型
def create_model():
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
# 训练模型并保存
def train_and_save_model():
model = create_model()
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
model.save('mnist_model.h5')
# 保存训练历史记录
with open('training_history.json', 'w') as f:
json.dump(history.history, f)
# 加载模型和历史记录
def load_model_and_history():
model = tf.keras.models.load_model('mnist_model.h5')
with open('training_history.json', 'r') as f:
history = json.load(f)
return model, history
# 评估模型
def evaluate_model(model):
test_loss, test_acc = model.evaluate(test_images, test_labels)
print("Test accuracy: {}".format(test_acc))
# 可视化训练过程
def plot_training_history(history):
plt.plot(history['accuracy'], label='accuracy')
plt.plot(history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
# 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
train_and_save_model()
model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)
代码解释
加载MNIST数据集:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
预处理数据:
- 将图像数据调整为
(28, 28, 1)
的形状。 - 将像素值标准化为
[0, 1]
之间。
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
- 将图像数据调整为
构建神经网络模型:
- 使用
Sequential
模型,按顺序添加层。 - 添加卷积层、池化层、全连接层。
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
- 使用
编译模型:
- 使用
adam
优化器,损失函数为sparse_categorical_crossentropy
,评估指标为accuracy
。
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
- 使用
训练模型:
- 训练模型5个epochs,并使用验证数据集评估模型性能。
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
评估模型:
- 在测试集上评估模型性能,并打印测试准确率。
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")
可视化训练过程:
- 绘制训练和验证准确率随epoch变化的曲线。
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。
手写数字识别-使用TensorFlow构建和训练一个简单的神经网络的更多相关文章
- Softmax用于手写数字识别(Tensorflow实现)-个人理解
softmax函数的作用 对于分类方面,softmax函数的作用是从样本值计算得到该样本属于各个类别的概率大小.例如手写数字识别,softmax模型从给定的手写体图片像素值得出这张图片为数字0~9 ...
- 基于卷积神经网络的手写数字识别分类(Tensorflow)
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dat ...
- TensorFlow 之 手写数字识别MNIST
官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 第三节,TensorFlow 使用CNN实现手写数字识别(卷积函数tf.nn.convd介绍)
上一节,我们已经讲解了使用全连接网络实现手写数字识别,其正确率大概能达到98%,这一节我们使用卷积神经网络来实现手写数字识别, 其准确率可以超过99%,程序主要包括以下几块内容 [1]: 导入数据,即 ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
- python-积卷神经网络全面理解-tensorflow实现手写数字识别
首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- Tensorflow实战 手写数字识别(Tensorboard可视化)
一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...
随机推荐
- ISCC 2024 练武题 misc趣题记录
Number_is_the_key 题目 The answers to the questions are hidden in the numbers. 文件是空白的xlsx文件 我的解答: 乱点发现 ...
- Vue cli之路由router
一.安装路由 Vue-router用于提供给vue项目在开发中用于绑定url和组件页面的关系的核心插件. 默认情况下,vue没有提供路由的功能,所以我们使用vue-router,并需要在项目根目录. ...
- Win11 LTSC 中文版来了,丝般顺滑,极速响应
最近网络上出现了泄露的Win11的LTSC版本,版本号为Build 26100.1,据息,该泄露版是微软提供给OEM厂商测试用的,是今年下半年的Windows 11 LTSC RTM版的正式版本,却被 ...
- 工作中常用Less知识点实践总结
我所理解的Less的一些好处 函数式编程css 自定义变量用于整体主题调整 嵌套语法简化开发复杂度 mixin的写法 .defaultBorder(@width: 10px, @style: soli ...
- jquery的树状菜单
<body> <ul> <li>一级菜单 <ol> <li ...
- 升级babel7后,直接引用element-ui中没有暴露出来的组件image-viewer.vue导致的打包错误
问题 & 解决方案 升级babel7后,原先代码中像这样直接引用element-ui组件的地方,出现了报错 Module parse failed: Unexpected token (1:0 ...
- 两个Excel表格核对 excel表格中# DIV/0 核对两个表格的差异,合并运算VS高级筛选
两个Excel表格核对 excel表格中# DIV/0 核对两个表格的差异,合并运算VS高级筛选 1.两列顺序一样的数据核对 方法1:加一个辅助列,=B2=C2 结果为FALSE的就是不相同的 方 ...
- fastjson对接口参数的某个字段不打印输出,如文件的base64字符串
fastjson对接口参数的某个字段不打印输出,如文件的base64字符串 package com.example.core.mydemo.json5; import com.alibaba.fast ...
- @Valid + BindingResult 拦截接口错误信息
@Valid + BindingResult 拦截接口错误信息###测试发现: HttpServletRequest request, HttpServletResponse response, 需要 ...
- DHorse的配置文件
首先看一下DHorse的配置文件,如下: #============================================================================== ...