1 MNIST数据集

MNIST数据集由70000张28x28像素的黑白图片组成,每一张图片都写有0~9中的一个数字,每个像素点的灰度值在0 ~ 255(0是黑色,255是白色)之间。



MINST数据集是由Yann LeCun教授提供的手写数字数据库文件,其官方下载地址THE MNIST DATABASE of handwritten digits



下载好MNIST数据集后,将其放在Spyder工作目录下(若使用Jupyter编程,则放在Jupyter工作目录下),如图:



G:\Anaconda\Spyder为笔者Spyder工作目录,MNIST_data为新建文件夹,读者也可以自行命名。

2 实验

为方便设计神经网络输入层,将每张28x28像素图片的像素值按行排成一行,故输入层设计28x28=784个神经元,隐藏层设计600个神经元,输出层设计10个神经元。使用read_data_sets()函数载入数据集,并返回一个类,这个类将MNIST数据集划分为train、validation、test 3个数据集,对应图片数分别为55000、5000、10000。本文采用交叉熵损失函数,并且为防止过拟合问题产生,引入正则化方法。

mnist.py

  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. #载入数据集
  4. mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
  5. #每批次的大小
  6. batch_size=100
  7. #总批次数
  8. batch_num=mnist.train.num_examples//batch_size
  9. #训练轮数
  10. training_step = tf.Variable(0,trainable=False)
  11. #定义两个placeholder
  12. x=tf.placeholder(tf.float32, [None,784])
  13. y=tf.placeholder(tf.float32, [None,10])
  14. #神经网络layer_1
  15. w1=tf.Variable(tf.random_normal([784,600]))
  16. b1=tf.Variable(tf.constant(0.1,shape=[600]))
  17. z1=tf.matmul(x,w1)+b1
  18. a1=tf.nn.tanh(z1)
  19. #神经网络layer_2
  20. w2=tf.Variable(tf.random_normal([600,10]))
  21. b2=tf.Variable(tf.constant(0.1,shape=[10]))
  22. z2=tf.matmul(a1,w2)+b2
  23. #交叉熵代价函数
  24. cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1),logits=z2)
  25. #cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=z2)
  26. #L2正则化函数
  27. regularizer=tf.contrib.layers.l2_regularizer(0.0001)
  28. #总损失
  29. loss=tf.reduce_mean(cross_entropy)+regularizer(w1)+regularizer(w2)
  30. #学习率(指数衰减法)
  31. laerning_rate = tf.train.exponential_decay(0.8,training_step,batch_num,0.999)
  32. #梯度下降法优化器
  33. train=tf.train.GradientDescentOptimizer(laerning_rate).minimize(loss,global_step=training_step)
  34. #预测精度
  35. correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(z2,1))
  36. accuracy=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  37. #初始化变量
  38. init=tf.global_variables_initializer()
  39. with tf.Session() as sess:
  40. sess.run(init)
  41. test_feed={x:mnist.test.images,y:mnist.test.labels}
  42. for epoch in range(51):
  43. for batch in range(batch_num):
  44. x_,y_=mnist.train.next_batch(batch_size)
  45. sess.run(train,feed_dict={x:x_,y:y_})
  46. acc=sess.run(accuracy,feed_dict=test_feed)
  47. if epoch%10==0:
  48. print("epoch:",epoch,"accuracy:",acc)



迭代50次后,精度达到97.68%。

​ 声明:本文转自使用TensorFlow实现MNIST数据集分类

使用TensorFlow实现MNIST数据集分类的更多相关文章

  1. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  2. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  3. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  4. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  5. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  6. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

  7. 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案

    最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...

  8. TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络

    1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...

  9. 基于 tensorflow 的 mnist 数据集预测

    1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...

  10. TensorFlow 下 mnist 数据集的操作及可视化

    from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...

随机推荐

  1. 解决windows系统电脑内存占用过高,一开机就是60%70%80%90%?

    1.问题 windows系统电脑内存占用过高,一开机就是60%70%80%90%? 2.解决方式 主要是虚拟内存一直没有及时释放导致的 先贴上B站视频链接:解决windows系统电脑内存占用过高 这里 ...

  2. ChatGPT-NextWeb部署和调试打造属于自己的GPT

    首先我关注这个项目有一段时间了,不得不说作者和他的社区真的很猛! 首先这个项目截至目前已经有了40.9K的Start了,Fork也已经有了38.1K了,这个数据真的超级牛批了. 那么我们来看一下这款号 ...

  3. ext4 磁盘扩容

    目录 ext4文件系统磁盘扩容 目标 途径 操作步骤 改变前的现状 操作和改变后的状态 ext4文件系统磁盘扩容 一个磁盘有多个分区,分别创建了物理卷.卷组.逻辑卷.通过虚拟机软件对虚拟机的磁盘/de ...

  4. [转帖]Linux-文本处理三剑客awk详解+企业真实案例(变量、正则、条件判断、循环、数组、分析日志)

    https://developer.aliyun.com/article/885607?spm=a2c6h.24874632.expert-profile.313.7c46cfe9h5DxWK 简介: ...

  5. 隐私集合求交(PSI)协议研究综述

    摘要 隐私集合求交(PSI)是安全多方计算(MPC)中的一种密码学技术,它允许参与计算的双方,在不获取对方额外信息(除交集外的其它信息)的基础上,计算出双方数据的交集.隐私集合求交在数据共享,广告转化 ...

  6. 限制input框中字数的输入maxlength

    今天产品提出一个需求就是.限制input框中的的值. 当用户超过10个字符时,用户再次输入的时,就不能够输入了. (最后就能够输入10个字符) maxlength=10 <input maxle ...

  7. 开源项目02-OSharp

    项目名称:OSharp 项目所用技术栈: osharp netstandard aspnetcore osharpns ng-alain angular等 项目简介: OSharp是一个基于.NetC ...

  8. C/C++ 实现FTP文件上传下载

    FTP(文件传输协议)是一种用于在网络上传输文件的标准协议.它属于因特网标准化的协议族之一,为文件的上传.下载和文件管理提供了一种标准化的方法,在Windows系统中操作FTP上传下载可以使用WinI ...

  9. Docker从认识到实践再到底层原理(二-1)|容器技术发展史+虚拟化容器概念和简介

    前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助. 高质量博客汇总 然后就是博主最近最花时间的一 ...

  10. 解决:docker开启mongo镜像

    首先通过docker pull mongo拉取mongo镜像 (如果带版本,拉取为响应版本,若不带版本则拉取最新版本) 开启 mongodb 容器 可以选择将宿主机的mongo工作目录进行共享,作为d ...