1 为什么使用卷积神经网络

Softmax回归是一个比较简单的模型,预测的准确率在91%左右,而使用卷积神经网络将预测的准确率提高到99%。

2 卷积网络的流程

3 代码展示

  1. # -*- coding: utf-8 -*-
  2. import tensorflow as tf
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. #读入数据
  5. mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
  6. #x为训练图像的占位符,y_为训练图像标签的占位符
  7. x = tf.placeholder(tf.float32,[None,784])
  8. y_ = tf.placeholder(tf.float32,[None,10])
  9. #将单张图片从784维向量重新还原为28*28的矩阵图片
  10. x_image = tf.reshape(x,[-1,28,28,1]) #-1 表示任意的数,由实际输入的图像个数决定
  11. # 定义卷积过程中用到的函数
  12. def weight_variable(shape):
  13. initial = tf.truncated_normal(shape,stddev=0.1) #产生正太分布
  14. return tf.Variable(initial)
  15. def bias_variable(shape):
  16. initial = tf.constant(0.1,shape=shape)
  17. return tf.Variable(initial)
  18. def conv2d(x,w):
  19. return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding="SAME")
  20. def max_pool_2x2(x):
  21. return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
  22. # 第一层卷积
  23. w_conv1 = weight_variable([5,5,1,32])
  24. b_conv1 = bias_variable([32])
  25. h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
  26. h_pool1 = max_pool_2x2(h_conv1)
  27. # 第二层卷积
  28. w_conv2 = weight_variable([5,5,32,64])
  29. b_conv2 = bias_variable([64])
  30. h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
  31. h_pool2 = max_pool_2x2(h_conv2)
  32. # 第一层全连接层,输出1024维的向量
  33. w_fc1 = weight_variable([7*7*64,1024])
  34. b_fc1 = bias_variable([1024])
  35. h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
  36. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
  37. #使用Dropout ,keep_prob 是一个占位符,训练是0.5,测试时为1
  38. keep_prob = tf.placeholder(tf.float32)
  39. h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)
  40. # 第二层全连接层,输出1024维的向量
  41. w_fc2 = weight_variable([1024,10])
  42. b_fc2 = bias_variable([10])
  43. y_conv = tf.matmul(h_fc1_drop,w_fc2)+b_fc2
  44. # 不采用先softmax再计算交叉熵的办法
  45. #采用tf.nn.softmax_cross_entropy_with_logits直接计算
  46. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y_conv))
  47. #定义train_step
  48. train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  49. #定义准确率
  50. correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
  51. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  52. # 训练
  53. # 创建Session,对变量初始化
  54. sess = tf.InteractiveSession()
  55. sess.run(tf.global_variables_initializer())
  56. #训练2000步
  57. for i in range(2000):
  58. batch = mnist.train.next_batch(50)
  59. # 每一百步报告一次在验证集上的准确率
  60. if i % 100 == 0 :
  61. train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1})
  62. print("step %d,training accuracy %g" % (i,train_accuracy))
  63. train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
  64. # 训练结束后报告在测试集上的准确率
  65. print("test_accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

4 补充

步长stride是一个一维的向量,长度为4。形式是[a,x,y,z],分别代表[batch滑动步长,水平滑动步长,垂直滑动步长,通道滑动步长]。在tensorflow中,stride的一般形式是[1,x,y,1]

  • 第一个1表示:在batch维度上的滑动步长为1,即不跳过任何一个样本
  • x表示:卷积核的水平滑动步长
  • y表示:卷积核的垂直滑动步长
  • 最后一个1表示:在通道维度上的滑动步长为1,即不跳过任何一个颜色通道

利用TensorFlow识别手写的数字---基于两层卷积网络的更多相关文章

  1. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  2. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  3. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  4. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

  5. 利用Tensorflow实现手写字符识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  6. TensorFlow下利用MNIST训练模型并识别自己手写的数字

    最近一直在学习李宏毅老师的机器学习视频教程,学到和神经网络那一块知识的时候,我觉得单纯的学习理论知识过于枯燥,就想着自己动手实现一些简单的Demo,毕竟实践是检验真理的唯一标准!!!但是网上很多的与t ...

  7. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  8. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. ie6-8 avalon2 单页应用项目实战备忘

    坑爹的ie,作为小组leader,尼玛,小伙伴儿们不乐意做的事情,我来做好了..心累... 如果,各位同学有定制开发ie6-8版本的需求,还是尽量不要用单页应用模式了,也不要用avalon这类mvvm ...

  2. 关于vlfeat做vlad编码问题

    这里是官方文档,可以自己查看 在这里,只是想记录一下,我这几天学习vlfeat 做vlad编码的过程,便于以后整理 网上涉及到vlfeat做vlad编码资料较少,而官网上例子又相对简单,主要是那几个参 ...

  3. 用MapReduce实现关系的自然连接

  4. webpack英文文档

    https://github.com/webpack/docs/wiki/contents

  5. react中使用屏保

    1,默认路由路径为屏保组件 <HashRouter history={hashHistory}> <Switch> <Route exact path="/&q ...

  6. LoadRunner穿过防火墙运行Vuser和进行监控

    LoadRunner穿过防火墙运行Vuser和进行监控   LoadRunner穿过防火墙进行测试,总结下来是2个方法:1. 在controller和Vuser的LAN中的防火墙都打开54345端口即 ...

  7. [JZOJ4673] 【NOIP2016提高A组模拟7.20】LCS again

    题目 描述 题目大意 给你一个字符串和字符的取值范围,问和这个字符串的最长公共子串的长度为N−1N-1N−1的串的个数. 思考历程 一看就知道这是一个神仙题. 思考了一会儿,觉得AC是没有希望的了. ...

  8. HTML - 图片标签相关

    <html> <head></head> <body> <!-- src : 图片的路径 (本地资源路径, 网络资源路径) title : 图片的 ...

  9. 廖雪峰Java15JDBC编程-3JDBC接口-4JDBC事务

    1 数据库事务:Transaction 1.1 定义 若干SQL语句构成的一个操作序列 要么全部执行成功 要么全部执行不成功 1.2 数据库事务具有ACID特性: Atomicity:原子性 一个事务 ...

  10. LUOGU P2580 于是他错误的点名开始了(trie树)

    传送门 解题思路 trie树模板