手写数字识别经典案例,目标是:

1. 掌握tf编写RNN的方法

2. 剖析RNN网络结构

tensorflow编程

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data ### 注意
# init_state = tf.zeros(shape=[batch_size,rnn_cell.state_size])
# init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) mnist=input_data.read_data_sets("./data",one_hot=True) # 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100 # rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数, 时序
hidden_num=100 # 隐层神经元个数
n_classes=10 # 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) ### 注意
# init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32) output,states=tf.nn.dynamic_rnn(rnn_cell,x,initial_state=init_state,dtype=tf.float32)
return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
print()
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss) step+=1

如果你非常熟悉rnn,代码整体上还是比较好理解的,但是里面涉及许多次的shape设置,比较让人头大,特别是后期写各种rnn时,很容易迷糊,所以每个模型都要理解透彻。

以上代码涉及到shape的变量有

x y w b x变形 init_state

其中比较难理解的是 x x变形 init_state

网络结构

首先回顾一下RNN网络,以便对上个问题进行深入分析。

公式简写如下:

h1 = f(x1w1 + h0w2)

o1 = h1w3  输出层就是简单的全连接,这里不做讨论

shape分析

我们把每个时刻的输入看做向量或者矩阵,因为如果只是一个数,没有shape可言,而且也很简单,没有讨论的必要。

首先有如下思考:

1.  h是隐层的输出,也就是x传进去得到的输出,因此传一个x就有一个h(但这并不足以说明什么)

其次从公式层面考虑

从公式可以看出,x和h的行必须相同,列不必相同

图形表示

这是单节点隐层,那么多节点呢?

首先一个神经元节点对应一组weight,多个神经元就是多组weight

其次从公式层面考虑

从公式看出,h和x行相同,h列和神经元个数相同。

图形表示

综上所述,h0的shape是行为 x的行,即batch,列为神经元个数

也就是说一个神经元对应一个h0

对应到上述代码

init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size])         # rnn_cell.state_size 100,100为节点数
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32)

对于输入x的shape,把代码转化成图

根据图来理解:

每次输入n张图片,也就是一次性输入所有时序的x,所有x的shape 为 [None,sequence_length*frame_size]

在rnn模型中因为要与权重相乘,所以需要转化为 [-1,sequence_length,frame_size]    [ 样本数,时序数,特征数 ],把特征划分出来,

然后特征乘以权重,然后按时序向上传递,得到输出

结合其他代码分析,对应图片而言,rnn包括LSTM的输入必须是 一次性输入所有时序的x,即 [ 样本数,时序数,特征数 ]

其实这个网络应该是这样

我的理解:像图像这种所有时序的特征结合起来才能确定y的模型用多对一RNN,且每次输入所有时序的特征,而词语预测不然。

rnn-手写数字识别-网络结构-shape的更多相关文章

  1. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

  2. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  3. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

  4. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  5. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  6. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  7. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  8. 深度学习面试题12:LeNet(手写数字识别)

    目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...

  9. Mnist手写数字识别 Tensorflow

    Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...

随机推荐

  1. AT2112 Non-redundant Drive

    题目:https://www.luogu.org/problemnew/show/AT2112 对于这种找路径的就直接上点分治就好. 分治时,算出每一个点到分治重心的后能剩多少油,从分治重心走到每个点 ...

  2. pycharm安装步骤

    python环境配置教程 https://jingyan.baidu.com/article/c45ad29c05c208051653e270.html 由于安装Pycharm时忘记截图了,所以详细安 ...

  3. laravel中常用的获取路径的函数

    1. app_path() // 获取app目录的路径 2. base_path() // 根目录的路径 3. config_path() // config目录的路径 4. public_path( ...

  4. python:extend (扩展) 与 append (追加) 之间的天与地

    >>> li = ['a', 'b', 'c'] >>> li.extend(['d', 'e', 'f']) >>> li ['a', 'b', ...

  5. python-day73--django-分页

    ''' 批量导入数据:bulk_create Booklist=[] for i in range(100): Booklist.append(Book(title="book"+ ...

  6. Leetcode 144

    /** * Definition for a binary tree node. * struct TreeNode { * int val; * TreeNode *left; * TreeNode ...

  7. SQL - 数据定义

    SQL 的数据定义功能包括模式定义.表定义.视图和索引的定义: 操作对象 操作方式 创建 删除 修改 模式  create schema drop schema   表  create table d ...

  8. Oracle外部表详解

    外部表概述 外部表只能在Oracle 9i之后来使用.简单地说,外部表,是指不存在于数据库中的表.通过向Oracle提供描述外部表的元数据,我们可以把一个操作系统文件当成一个只读的数据库表,就像这些数 ...

  9. Oracle Log Block Size

    Although the size of redo entries is measured in bytes, LGWR writes the redo to the log files on dis ...

  10. zookeeper 的心跳

    假定:主机 A, B 通过 tcp 连接发送数据,如果拔掉 A 主机的网线,B 是无法感知到的.但是如果 A 定时给 B 发送心跳,则能根据心跳的回复来判断连接的状态. 以 zookeeper 为例: ...