TF:TF分类问题之MNIST手写50000数据集实现87.4%准确率识别:SGD法+softmax法+cross_entropy法—Jason niu
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) def add_layer(inputs, in_size, out_size, activation_function=None,):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1,)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b,)
return outputs def compute_accuracy(v_xs, v_ys): global prediction
y_pre = sess.run(prediction, feed_dict={xs: v_xs})
correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
return result # define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10]) # add output layer
prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax) # the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()
# important step
sess.run(tf.global_variables_initializer()) for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
if i % 50 == 0:
print(compute_accuracy(
mnist.test.images, mnist.test.labels))
TF:TF分类问题之MNIST手写50000数据集实现87.4%准确率识别:SGD法+softmax法+cross_entropy法—Jason niu的更多相关文章
- 吴裕雄 PYTHON 神经网络——TENSORFLOW 无监督学习处理MNIST手写数字数据集
# 导入模块 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 加载数据 from tensor ...
- 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集
#加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...
- MNIST手写数字数据集
下载python源代码之后,使用: import input_data mnist = input_data.read_data_sets('MNIST_data/',one_hot=True) 下载 ...
- keras实现mnist手写数字数据集的训练
网络:两层卷积,两层全连接,一层softmax 代码: import numpy as np from keras.utils import to_categorical from keras imp ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- MNIST手写数字分类simple版(03-2)
simple版本nn模型 训练手写数字处理 MNIST_data数据 百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
随机推荐
- js——this
每个函数的this是在调用时绑定的,完全取决于函数的调用位置 1. 绑定规则总结 一般情况下,按下列顺序从下至上来判断this的绑定对象(绑定的优先级从下至上递减) 默认:在严格模式下绑定到undef ...
- Confluence 6 恢复一个空间
你可以导出一个空间 – 包括页面,评论和附件到一个压缩的 XML 文件中,可选的你可以在 XML 文件中包括所有空间使用的附件.希望导入空间到其他的 Confluence 站点中,请按照下面的方法进行 ...
- Confluence 6 用户目录图例 - 连接 Jira
上面的图:Confluence 连接到 JIRA 为用户管理. https://www.cwiki.us/display/CONFLUENCEWIKI/Diagrams+of+Possible+Con ...
- Flask框架(一)
Flask框架 Flask本身想当于一个内核,其自身几乎所有功能都依靠扩展(邮件扩展Flask-Mail.用户认证Flask-Login),都需要用第三方的扩展来实现.其WSGI工具箱采用Werkze ...
- Spark Streaming通过JDBC操作数据库
本文记录了学习使用Spark Streaming通过JDBC操作数据库的过程,源数据从Kafka中读取. Kafka从0.10版本提供了一种新的消费者API,和0.8不同,因此Spark Stream ...
- shell中的ps命令详解
ps简介:Linux中的ps命令是Process Status的缩写.ps命令用来列出系统中当前运行的那些进程.ps命令列出的是当前那些进程的快照,就是执行ps命令的那个时刻的那些进程,如果想要动态的 ...
- Java并发编程-并发工具类及线程池
JUC中提供了几个比较常用的并发工具类,比如CountDownLatch.CyclicBarrier.Semaphore. CountDownLatch: countdownlatch是一个同步工具类 ...
- cf1110F 离线+树上操作+线段树区间更新
自己搞的算法超时了..但是思路没什么问题:用线段树维护每个点到叶子节点的距离即可 /* 线段树维护区间最小值,每次向下访问,就把访问到的点对应的区间段减去边权 到另一颗子树访问时,向上回溯时加上减去的 ...
- CF 1051F
题意:给定一张n个点,m条边的无向联通图,其中m-n<=20,共q次询问,每次询问求给定两点u,v间的最短路长度 第一眼看见这题的时候,以为有什么神奇的全图最短路算法,满心欢喜的去翻了题解,发现 ...
- 使用powerdesigner导入sql脚本,生成物理模型
有些时候我们的powerdesigner以jdbc的形式链接本地数据库可能会失败,这时候我觉得从sql文件中生成物理模型是个很不错的方法 1.打开powerdesigner,文件->->r ...