rnn-手写数字识别-网络结构-shape
手写数字识别经典案例,目标是:
1. 掌握tf编写RNN的方法
2. 剖析RNN网络结构
tensorflow编程
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data ### 注意
# init_state = tf.zeros(shape=[batch_size,rnn_cell.state_size])
# init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) mnist=input_data.read_data_sets("./data",one_hot=True) # 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100 # rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数, 时序
hidden_num=100 # 隐层神经元个数
n_classes=10 # 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) ### 注意
# init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32) output,states=tf.nn.dynamic_rnn(rnn_cell,x,initial_state=init_state,dtype=tf.float32)
return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
print()
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss) step+=1
如果你非常熟悉rnn,代码整体上还是比较好理解的,但是里面涉及许多次的shape设置,比较让人头大,特别是后期写各种rnn时,很容易迷糊,所以每个模型都要理解透彻。
以上代码涉及到shape的变量有
x y w b x变形 init_state
其中比较难理解的是 x x变形 init_state
网络结构
首先回顾一下RNN网络,以便对上个问题进行深入分析。
公式简写如下:
h1 = f(x1w1 + h0w2)
o1 = h1w3 输出层就是简单的全连接,这里不做讨论
shape分析
我们把每个时刻的输入看做向量或者矩阵,因为如果只是一个数,没有shape可言,而且也很简单,没有讨论的必要。
首先有如下思考:
1. h是隐层的输出,也就是x传进去得到的输出,因此传一个x就有一个h(但这并不足以说明什么)
其次从公式层面考虑
从公式可以看出,x和h的行必须相同,列不必相同
图形表示
这是单节点隐层,那么多节点呢?
首先一个神经元节点对应一组weight,多个神经元就是多组weight
其次从公式层面考虑
从公式看出,h和x行相同,h列和神经元个数相同。
图形表示
综上所述,h0的shape是行为 x的行,即batch,列为神经元个数
也就是说一个神经元对应一个h0
对应到上述代码
init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100,100为节点数
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32)
对于输入x的shape,把代码转化成图
根据图来理解:
每次输入n张图片,也就是一次性输入所有时序的x,所有x的shape 为 [None,sequence_length*frame_size]
在rnn模型中因为要与权重相乘,所以需要转化为 [-1,sequence_length,frame_size] [ 样本数,时序数,特征数 ],把特征划分出来,
然后特征乘以权重,然后按时序向上传递,得到输出
结合其他代码分析,对应图片而言,rnn包括LSTM的输入必须是 一次性输入所有时序的x,即 [ 样本数,时序数,特征数 ]
其实这个网络应该是这样
我的理解:像图像这种所有时序的特征结合起来才能确定y的模型用多对一RNN,且每次输入所有时序的特征,而词语预测不然。
rnn-手写数字识别-网络结构-shape的更多相关文章
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 深度学习面试题12:LeNet(手写数字识别)
目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...
- Mnist手写数字识别 Tensorflow
Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...
随机推荐
- AT2112 Non-redundant Drive
题目:https://www.luogu.org/problemnew/show/AT2112 对于这种找路径的就直接上点分治就好. 分治时,算出每一个点到分治重心的后能剩多少油,从分治重心走到每个点 ...
- pycharm安装步骤
python环境配置教程 https://jingyan.baidu.com/article/c45ad29c05c208051653e270.html 由于安装Pycharm时忘记截图了,所以详细安 ...
- laravel中常用的获取路径的函数
1. app_path() // 获取app目录的路径 2. base_path() // 根目录的路径 3. config_path() // config目录的路径 4. public_path( ...
- python:extend (扩展) 与 append (追加) 之间的天与地
>>> li = ['a', 'b', 'c'] >>> li.extend(['d', 'e', 'f']) >>> li ['a', 'b', ...
- python-day73--django-分页
''' 批量导入数据:bulk_create Booklist=[] for i in range(100): Booklist.append(Book(title="book"+ ...
- Leetcode 144
/** * Definition for a binary tree node. * struct TreeNode { * int val; * TreeNode *left; * TreeNode ...
- SQL - 数据定义
SQL 的数据定义功能包括模式定义.表定义.视图和索引的定义: 操作对象 操作方式 创建 删除 修改 模式 create schema drop schema 表 create table d ...
- Oracle外部表详解
外部表概述 外部表只能在Oracle 9i之后来使用.简单地说,外部表,是指不存在于数据库中的表.通过向Oracle提供描述外部表的元数据,我们可以把一个操作系统文件当成一个只读的数据库表,就像这些数 ...
- Oracle Log Block Size
Although the size of redo entries is measured in bytes, LGWR writes the redo to the log files on dis ...
- zookeeper 的心跳
假定:主机 A, B 通过 tcp 连接发送数据,如果拔掉 A 主机的网线,B 是无法感知到的.但是如果 A 定时给 B 发送心跳,则能根据心跳的回复来判断连接的状态. 以 zookeeper 为例: ...