6.MNIST数据集分类简单版本
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- # 载入数据集
- mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
- # 批次大小
- batch_size = 64
- # 计算一个周期一共有多少个批次
- n_batch = mnist.train.num_examples // batch_size
- # 定义两个placeholder
- x = tf.placeholder(tf.float32,[None,784])
- y = tf.placeholder(tf.float32,[None,10])
- # 创建一个简单的神经网络:784-10
- W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
- b = tf.Variable(tf.zeros([10]) + 0.1)
- prediction = tf.nn.softmax(tf.matmul(x,W)+b)
- # 二次代价函数
- loss = tf.losses.mean_squared_error(y, prediction)
- # 使用梯度下降法
- train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
- # 结果存放在一个布尔型列表中
- correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
- # 求准确率
- accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
- with tf.Session() as sess:
- # 变量初始化
- sess.run(tf.global_variables_initializer())
- # 周期epoch:所有数据训练一次,就是一个周期
- for epoch in range(21):
- for batch in range(n_batch):
- # 获取一个批次的数据和标签
- batch_xs,batch_ys = mnist.train.next_batch(batch_size)
- sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
- # 每训练一个周期做一次测试
- acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
- print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
6.MNIST数据集分类简单版本的更多相关文章
- MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- python之selenium三种等待方法
前提: 我们在做Web自动化时,有的时候要等待元素加载出来,才能操作,不然会报错 1.强制等待 2.隐式等待 3.显示等待 内容: 一,强制等待 这个比较简单,就是利用time模块的sleep的方法来 ...
- Django 1.x版本与2.x版本 区别
django 1.x版本与2.x版本 URL区别 在django 1.x中的方式 导入的模块是'from django.conf.urls import url',urlpatterns中url对应的 ...
- PJzhang:CVE-2019-14287 sudo权限绕过漏洞复现
猫宁!!! 参考链接:Ms08067实验室公众号 sudo 1.8.28版本之前有漏洞. 更新完kali linux,deepin截图工具失效,只能用自带的,不能划重点. 看一下sudo版本,1.8. ...
- PowerDesigner通过SQL语句生成PDM文件并将name和comment进行互相转换
本篇文章主要介绍了PowerDesigner通过SQL语句生成PDM文件并将name和comment进行互相转换 超详细过程(图文),具有一定的参考价值,感兴趣的小伙伴们可以参考一下 1.软件准备 软 ...
- python之pandas学习笔记-pandas数据结构
pandas数据结构 pandas处理3种数据结构,它们建立在numpy数组之上,所以运行速度很快: 1.系列(Series) 2.数据帧(DataFrame) 3.面板(Panel) 关系: 数据结 ...
- vs2017安装过程中下载不动的一种情况
第一种可能:微软可能有不同的下载地址,某些地址下载速度快,某些慢.这种情况下,禁用连接,再启用.有几率速度飞速上升. 第二种可能:由于总所周知的原因,连接不了Google.但是如果需要下载Androi ...
- 桥接模式下访问虚拟机中的Django项目
首先需要保证主机和虚拟机能相互Ping通,如果Ping不通,请参考我上篇文章,这里演示的是桥接模式下的方法,如果是NAT模式连接,请参考别处. 1. 虚拟机Linux系统内的Django项目 sett ...
- 南昌网络赛C.Angry FFF Party
南昌网络赛C.Angry FFF Party Describe In ACM labs, there are only few members who have girlfriends. And th ...
- C语言中signed和unsigned理解
一直在学java,今天开始研究ACM的算法题,需要用到C语言,发现好多知识点都不清楚了,看来以后要多多总结~ signed意思为有符号的,也就是第一个位代表正负,剩余的代表大小,例如:signed i ...
- 学生管理系统利用arrayList第二次优化
package StuManage; public class Student { private String name;//姓名 private String stuNum;//学号 privat ...