#load MNIST data
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #start tensorflow interactiveSession
import tensorflow as tf
sess = tf.InteractiveSession() #weight initilization
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape= shape)
return tf.Variable(initial) #convolution
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') #pooling
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1,2,2,1],strides=[1,2,2,1], padding='SAME') #Create the model
#placeholder
x = tf.placeholder("float",[None, 784])
y_ = tf.placeholder("float", [None, 10]) #variable
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) +b) #first convolutional layer
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x,[-1,28,28,1]) h_conv1 =tf.nn.relu(conv2d(x_image,w_conv1) + b_conv1)
h_pool1 =max_pool_2x2(h_conv1) #second convolutional layer
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64]) h_conv2 =tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 =max_pool_2x2(h_conv2) #densely connected layer
w_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) #dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) #readout layer
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10]) y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2) #train and evaluate the model
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
#train_step = tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy)
train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(5000):
batch = mnist.train.next_batch(50)
if i%100 ==0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1], keep_prob:1.0})
print "step %d, train accuracy %g " %(i,train_accuracy)
train_step.run(feed_dict={x:batch[0],y_:batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(fedd_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0})

同样是极客学院的课程,其实也是翻译的国外的robot-ai博客上的内容,但是这个博客,现在打不开了,可能是墙的问题?没有太深究。

按照作者的说法,是采用自适应下降的方式,在train阶段能达到99%的正确率,但是,我的结果只有93%左右,修改梯度步长到1e-4也只有94% 。因此尝试换用原来的梯度下降方式,反而能获得97.61%的正确率,在训练中还达到过98%,这个问题比较无奈,修改步长的结果提升也并不明显。有人在评论中说在不同的平台上测试的值不同,比如在纯CPU环境,和我的结果比较相似。在K20环境中能达到99%,这个问题留待以后探索。代码参考至:文章链接: http://blog.csdn.net/yhl_leo/article/details/50624471

TensorFlow入门——MNIST深入的更多相关文章

  1. TensorFlow入门——MNIST初探

    import tensorflow.examples.tutorials.mnist.input_data as input_data import tensorflow as tf mnist = ...

  2. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  3. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  4. TensorFlow入门之MNIST最佳实践

    在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...

  5. TensorFlow入门之MNIST最佳实践-深度学习

    在上一篇<TensorFlow入门之MNIST样例代码分析>中,我们讲解了如果来用一个三层全连接网络实现手写数字识别.但是在实际运用中我们需要更有效率,更加灵活的代码.在TensorFlo ...

  6. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  7. TensorFlow 入门之手写识别(MNIST) softmax算法 二

    TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  8. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  9. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

随机推荐

  1. linux 之 pthread_create 实现类的成员函数做参数

    在C++的类中,普通成员函数不能作为pthread_create的线程函数,如果要作为pthread_create中的线程函数,必须是static ! 在C语言中,我们使用pthread_create ...

  2. Selenium2Library测试web

    Selenium 定位元素 ▲ Locator 可以id或name来用定位界面元素 也可以使用XPath或Dom,但是,必须用XPath=或Dom=来开头 ▲ 最好使用id来定位,强烈建议强制要求开发 ...

  3. JxBrowser开启调试模式,JxBrowser debug

    原文: 一.问题描述 像一般的浏览器都带了调试功能,按F12就能打开,在JxBrowser中如何开启调试模式了. 二.解决方法 以下代码就能开启调试模式: import com.teamdev.jxb ...

  4. PMML辅助机器学习算法上线

    在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...

  5. vue 默认展开详情页

    { path: '/Tree', component: Tree, children: [ { path: '/', component: Come } ] }

  6. Maven 默认 SpringMVC-servlet.xml 基本配置

    <?xml version="1.0" encoding="UTF-8"?> <beans xmlns="http://www.sp ...

  7. vs code的local history插件

    使用vs code来编写前端代码,内存的使用比webstrom要少很多. vs code可以下载中文及debug插件,使用起来会方便很多. vs code不像idea或者webstrom有local ...

  8. linux常用、常见错误

    1.md5加密使用 oppnssl md5 加密字符串的方法 [root@lab3 ~]# openssl //在终端中输入openssl后回车. OpenSSL> md5 //输入md5后回车 ...

  9. 使用robotframework做接口测试三——保持登录状态

    调用登录接口登录了,其他的接口怎么保持登录状态呢?  首先来看一看,web端或者说客户端是怎么样用cookie/token等保持登录状态的.一般来说,cookie都会在登录接口由服务端返回,而且会是在 ...

  10. elasticsearch-head-master下运行npm install报npm WARN elasticsearch-head@0.0.0 license should be a valid SPDX license expression

    2个月没有启动es和es配套服务,今天运行时,发现如下问题: 运行npm install 出现npm WARN elasticsearch-head@0.0.0 license should be a ...