程序(有些不甚明白的地方改日修订):

 # _*_coding:utf-8_*_

 import inputdata
mnist = inputdata.read_data_sets('MNIST_data', one_hot=True) # mnist是一个以numpy数组形式存储训练、验证和测试数据的轻量级类 import tensorflow as tf
sess = tf.InteractiveSession() x = tf.placeholder("float",shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) sess.run(tf.initialize_all_variables()) y = tf.nn.softmax(tf.matmul(x,W)+b) # nn:neural network # 代价函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 最优化算法
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 会更新权值 for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x:batch[0], y_:batch[1]}) # 可以用feed_dict来替代任何张量,并不仅限于替换placeholder correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels}) # 构建多层卷积网络模型 # 初始化W,b的函数
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1) # truncated_normal表示的是截断的正态分布
return tf.Variable(initial) def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) # 卷积和池化
def conv2d(x, W): # 卷积用原版,1步长0边距
return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME') def max_pool_2x2(x): # 池化用传统的2*2模板做max polling
return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') # 第一层卷积
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32]) x_image = tf.reshape(x, [-1,28,28,1]) h_conv1= tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1) # 第二层卷积
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64]) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2) # 密集连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) # dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) # 输出层
W_fc2= weight_variable([1024, 10])
b_fc2 = bias_variable([10]) y_conv= tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) # 训练和评估模型
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.initialize_all_variables())
for i in range(20000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g" %(i, train_accuracy)
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) print "test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

运行结果:

0.9092
step 0, training accuracy 0.08
step 100, training accuracy 0.9
step 200, training accuracy 0.94
step 300, training accuracy 0.98
step 400, training accuracy 0.98
step 500, training accuracy 0.9
step 600, training accuracy 0.96
step 700, training accuracy 0.96
step 800, training accuracy 0.96
step 900, training accuracy 0.94
step 1000, training accuracy 0.98
step 1100, training accuracy 1
step 1200, training accuracy 0.92
step 1300, training accuracy 0.96
step 1400, training accuracy 0.92
step 1500, training accuracy 0.98
...明天早上跑出来再贴

TensorFlow——深入MNIST的更多相关文章

  1. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  2. Ubuntu16.04安装TensorFlow及Mnist训练

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com TensorFlow是Google开发的开源的深度学习框架,也是当前使用最广泛的深度学习框架. 一.安 ...

  3. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  6. TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架

    TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html

  7. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  8. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  9. Tensorflow之MNIST的最佳实践思路总结

    Tensorflow之MNIST的最佳实践思路总结   在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...

  10. TensorFlow训练MNIST报错ResourceExhaustedError

    title: TensorFlow训练MNIST报错ResourceExhaustedError date: 2018-04-01 12:35:44 categories: deep learning ...

随机推荐

  1. uvm_reg_fifo——寄存器模型(十五)

    当我们对寄存器register, 存储器memory, 都进行了建模,是时候对FIFO进行建模了 uvm_reg_fifo毫无旁贷底承担起了这个责任,包括:set, get, update, read ...

  2. ftp和sftp

    一.ftp ftp是文件传输协议,ftp协议包括两部分,一个是ftp客户端,另一个是ftp服务器. 原理:一般情况下,当使用FTP服务的时候,我们都知道默认是21号端口,其实还有一个20号端口.FTP ...

  3. userBean-作用范围session

    package com.java1234.model; public class Student { private String name;private int age; public Strin ...

  4. 51nod 1631 小鲨鱼在51nod小学

    基准时间限制:1 秒 空间限制:131072 KB 分值: 20 难度:3级算法题 鲨鱼巨巨2.0(以下简称小鲨鱼)以优异的成绩考入了51nod小学.并依靠算法方面的特长,在班里担任了许多职务.   ...

  5. mdns小结

    mdns的功能和普通DNS很类似,即提供主机名到IP地址的解析服务.   mdns一些基本特性: 1,mdns主要为小型私有网络(不存在DNS)提供名称解析. 2,mdns使用多播(Multicast ...

  6. codeforce Gym 100570B ShortestPath Query (最短路SPFA)

    题意:询问单源最短路径,每条边有一个颜色,要求路径上相邻边的颜色不能相同,无重边且边权为正. 题解:因为路径的合法性和边的颜色有关, 所以在做spfa的时候,把边丢到队列中去,松弛的时候注意判断一下颜 ...

  7. codeforce Gym 100500E IBM Chill Zone (SG函数)

    关于sg函数这篇blog讲得很详细http://blog.csdn.net/logic_nut/article/details/4711489. sg函数的价值在于把复杂的游戏拆分成简单的游戏,然后通 ...

  8. 总结一下自己脑海里的JavaScript吧(一)--DOM模型

    今天是2019年6月25日,闲来无事,写一篇文章来看看自己脑袋里装了多少JavaScript知识! 这儿就第一章: 说起JavaScript,它是什么?后端脚本语言?前端编程语言?还是在网站浏览器上运 ...

  9. Redis的安装以及spring整合Redis时出现Could not get a resource from the pool

    Redis的下载与安装 在Linux上使用wget http://download.redis.io/releases/redis-5.0.0.tar.gz下载源码到指定位置 解压:tar -xvf ...

  10. java从键盘输入三个整数,实现从小到大排序

    package study01; import java.util.Scanner; public class Sort { /** * 需求:由键盘输入三个整数分别存入变量a.b.c,对他们进行 排 ...