6.MNIST数据集分类简单版本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 载入数据集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 批次大小
batch_size = 64
# 计算一个周期一共有多少个批次
n_batch = mnist.train.num_examples // batch_size # 定义两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10]) # 创建一个简单的神经网络:784-10
W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(x,W)+b) # 二次代价函数
loss = tf.losses.mean_squared_error(y, prediction)
# 使用梯度下降法
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss) # 结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# 求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 周期epoch:所有数据训练一次,就是一个周期
for epoch in range(21):
for batch in range(n_batch):
# 获取一个批次的数据和标签
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
# 每训练一个周期做一次测试
acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
6.MNIST数据集分类简单版本的更多相关文章
- MNIST数据集分类简单版本
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- 深度学习(一)之MNIST数据集分类
任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...
- Tensorflow学习教程------普通神经网络对mnist数据集分类
首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...
- 神经网络MNIST数据集分类tensorboard
今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...
- 卷积神经网络应用于MNIST数据集分类
先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...
- MNIST数据集
一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...
随机推荐
- 【转】Jquery ajax与asp.net MVC前后端各种交互
本文转载自:https://www.cnblogs.com/fengyeqingxiang/p/11169218.html 1.Jquery通过ajaxSubmit提交表单 if (jQuery(&q ...
- 综合对比 Kafka、RabbitMQ、RocketMQ、ActiveMQ 四个分布式消息队列
来源:http://t.cn/RVDWcfe 一.资料文档 Kafka:中.有kafka作者自己写的书,网上资料也有一些.rabbitmq:多.有一些不错的书,网上资料多.zeromq:少.没有专门写 ...
- smoothscroll
smoothscroll是一款jQuery插件,可以平滑地滚动到指定的地方. 可以解决chrome锚点失效的问题. 官方网站 http://iamdustan.com/smoothscroll/ gi ...
- 安装 maven
1.打开http://maven.apache.org/index.html 2.选择USE,点击下载 3.下移窗口到File点击红框内的链接 4.将下载的压缩包解压到c:\apps 5.将maven ...
- JS字符串格式化~欢迎来搂~~
/* 函数:格式化字符串 参数:str:字符串模板: data:数据 调用方式:formatString("api/values/{id}/{name}",{id:101,name ...
- rowkey散列和预分区设计解决hbase热点问题(数据倾斜)
Hbase的表会被划分为1....n个Region,被托管在RegionServer中.Region二个重要的属性:Startkey与EndKey表示这个Region维护的rowkey的范围,当我们要 ...
- Linux下通过ssh上传下载文件
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/jun8148/article/deta ...
- Java IO与NIO的总结、比较
一.IO流总结 1.Java I/O主要包括如下3层次: 流式部分——最主要的部分.如:OutputStream.InputStream.Writer.Reader等 非流式部分——如:File类.R ...
- Java Web ActiveMQ与WebService的异同
Webservice 和MQ(MessageQueue)都是解决跨平台通信的常用手段 一.WebService:用来远程调用服务,达到打通系统.服务复用的目的.是SOA系统架构——面向服务架构的体现. ...
- 无障碍开发(六)之ARIA在HTML中的使用规则
ARIA使用规则一 如果你使用的元素( HTML5 )具有语义化,应该使用这些元素,而不应该重新定义一个添加ARIA的角色.状态或属性的元素. 浏览器的语义化标签已经默认隐含ARIA语义,像nav,a ...