rnn-手写数字识别-网络结构-shape
手写数字识别经典案例,目标是:
1. 掌握tf编写RNN的方法
2. 剖析RNN网络结构
tensorflow编程
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data ### 注意
# init_state = tf.zeros(shape=[batch_size,rnn_cell.state_size])
# init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) mnist=input_data.read_data_sets("./data",one_hot=True) # 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100 # rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数, 时序
hidden_num=100 # 隐层神经元个数
n_classes=10 # 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) ### 注意
# init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32) output,states=tf.nn.dynamic_rnn(rnn_cell,x,initial_state=init_state,dtype=tf.float32)
return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
print()
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss) step+=1
如果你非常熟悉rnn,代码整体上还是比较好理解的,但是里面涉及许多次的shape设置,比较让人头大,特别是后期写各种rnn时,很容易迷糊,所以每个模型都要理解透彻。
以上代码涉及到shape的变量有
x y w b x变形 init_state
其中比较难理解的是 x x变形 init_state
网络结构
首先回顾一下RNN网络,以便对上个问题进行深入分析。
公式简写如下:
h1 = f(x1w1 + h0w2)
o1 = h1w3 输出层就是简单的全连接,这里不做讨论
shape分析
我们把每个时刻的输入看做向量或者矩阵,因为如果只是一个数,没有shape可言,而且也很简单,没有讨论的必要。
首先有如下思考:
1. h是隐层的输出,也就是x传进去得到的输出,因此传一个x就有一个h(但这并不足以说明什么)
其次从公式层面考虑
从公式可以看出,x和h的行必须相同,列不必相同
图形表示
这是单节点隐层,那么多节点呢?
首先一个神经元节点对应一组weight,多个神经元就是多组weight
其次从公式层面考虑
从公式看出,h和x行相同,h列和神经元个数相同。
图形表示
综上所述,h0的shape是行为 x的行,即batch,列为神经元个数
也就是说一个神经元对应一个h0
对应到上述代码
init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100,100为节点数
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32)
对于输入x的shape,把代码转化成图
根据图来理解:
每次输入n张图片,也就是一次性输入所有时序的x,所有x的shape 为 [None,sequence_length*frame_size]
在rnn模型中因为要与权重相乘,所以需要转化为 [-1,sequence_length,frame_size] [ 样本数,时序数,特征数 ],把特征划分出来,
然后特征乘以权重,然后按时序向上传递,得到输出
结合其他代码分析,对应图片而言,rnn包括LSTM的输入必须是 一次性输入所有时序的x,即 [ 样本数,时序数,特征数 ]
其实这个网络应该是这样
我的理解:像图像这种所有时序的特征结合起来才能确定y的模型用多对一RNN,且每次输入所有时序的特征,而词语预测不然。
rnn-手写数字识别-网络结构-shape的更多相关文章
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 深度学习面试题12:LeNet(手写数字识别)
目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...
- Mnist手写数字识别 Tensorflow
Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...
随机推荐
- Linux中符号总结
常用符号~ 登陆用户当前的家目录 . 当前目录.. 当前目录的上一级目录cd - 返回上一次的目录; 命令分隔符# 表示注释 ? 通配符中表示任意一个字符* 通配符中表 ...
- Vue.js的后端数据支持:使用Express建立app, 并使用MongoDB数据库。
需要用到的backed tech stack: Node: JavaScript on the server/backend. That's basically what it is, but mor ...
- int __get_order(unsigned long size)
2017-12-6 10:26:504.9.51int __get_order(unsigned long size){ int order; size--; size >>= ...
- Spring Batch 体系结构
Spring Batch 设计的时候充分考虑了可扩展性和各类终端用户. 下图显示了 Spring Batch 的架构层次示意图,这种架构层次为终端用户开发者提供了很好的扩展性与易用性. 上图显示的是 ...
- SpringMVC的底层实现
SpringMVC的底层实现流程: SpringMVC的核心是DispatchServlet,它负责接收HTTP的请求和协调SpringMVC中各个组件来完成请求处理的任务,一个请求被截获后,Disp ...
- python 小练习4
给你一个整数list L, 如 L=[2,-3,3,50], 求L的一个连续子序列,使其和最大,输出最大子序列的和. 例如,对于L=[2,-3,3,50], 输出53(分析:很明显,该列表最大连续子序 ...
- 小Z的袜子(hose)
小Z的袜子(hose) 作为一个生活散漫的人,小Z每天早上都要耗费很久从一堆五颜六色的袜子中找出一双来穿.终于有一天,小Z再也无法忍受这恼人的找袜子过程,于是他决定听天由命……具体来说,小Z把这N只袜 ...
- [codechef July Challenge 2017] Calculator
CALC: 计算器题目描述大厨有一个计算器,计算器上有两个屏幕和两个按钮.初始时每个屏幕上显示的都是 0.每按一次第一个按钮,就会让第一个屏幕上显示的数字加 1,同时消耗 1 单位的能量.每按一次第二 ...
- 一、JDBC
一.CRUD(增删改查) C:增加Create R查询(Retrieve) U更新(Update) D删除(Delete) 在JDBC中,增加,删除,修改,操作都基本一样,只是传递不同SQL语句,最后 ...
- web.xml之env-entry
1.目的 定义一个jndi变量 2.schemas定义 2.web.xml中定义变量 <web-app> ... <env-entry> <env-entry-name& ...