RNN静态与动态
静态、多层RNN:
import numpy as np
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True) #网络模型参数
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类) #训练参数
batch_size = 128
learning_rate = 0.001
training_iters = 10000
display_step = 10 # tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) #构建网络
stacked_rnn = []
for _ in range(3):
stacked_rnn.append(tf.contrib.rnn.LSTMCell(n_hidden))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn) x1=tf.unstack(x,n_steps,1)#在axis=1进行解包分解。 outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)#inputs must be a sequence
#最后一层全连接 outputs[-1]
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) # Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!")
# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
在学习RNN这一章的时候,遇到static_rnn中输入数据 x 的格式:
[None, n_steps, n_input] 进行变换→ x1=tf.unstack(x,n_steps,1)
之后再传入:outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)
很难理解,为什么要这样做,数据又进行了怎样的变换。
以下,为stack和unstack的详细举例:
- tf.stack(values, axis=0, name=’stack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将一组张量以指定的轴,提高一个维度。
假设要转变的张量数组values的长度为N,其中的每个张量的形状为(A, B, C)。
如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。
举例如下:
‘x’ is [1, 4], 形状是(2),维度为1
‘y’ is [2, 5], 形状是(2),维度为1
‘z’ is [3, 6], 形状是(2),维度为1
stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # axis的值默认为0。输出的形状为(3, 2)
stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] # axis的值为1。输出的形状为(2, 3)
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)
axis可这样理解:stack就是要将一组相同形状的张量提高一个维度。axis就是这些张量里,将axis指定的维度用所有这些张量数组代替。如axis=2,表示指定在第2个维度,原来的元素用整个张量数组里的元素代替,即从(A, B, C)转变为(A, B, N, C)
参数:
values: 一个有相同形状与数据类型的张量数组。
axis: 以轴axis为中心来转变的整数。默认是第一个维度即axis=0。支持负数。取值范围为[-(R+1), R+1)
name: 这个操作的名字(可选)
返回:被提高一个维度后的张量
异常: ValueError: 如果轴axis超出范围[-(R+1), R+1).
- tf.unstack()
tf.unstack(value, num=None, axis=0, name=’unstack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R-1的张量。即将一组张量以指定的轴,减少一个维度。正好和stack()相反。
将张量value分割成num个张量数组。如果num没有指定,则是根据张量value的形状来指定。如果value.shape[axis]不存在,则抛出ValueError的异常。
假如一个张量的形状是(A, B, C, D)。
如果axis == 0,则输出的张量是value[i, :, :, :],i取值为[0,A),每个输出的张量的形状为(B,C,D)。
如果axis == 1,则输出的张量是value[:, i, :, :],i取值为[0,B),每个输出的张量的形状为(A,C,D)。
如果axis == 2,则输出的张量是value[:, :, i, :],i取值为[0,C),每个输出的张量的形状为(A,B,D)。依次类推。
举例如下:
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]] # 形状是(3,4),维度为2
unstack(x,axis=0) =>以指定的维度0为轴,转变成3个形状为(4)张量[1,1,1,1],[2,2,2,2],[3,3,3,3]
unstack(x,axis=1) =>以指定的维度1为轴,转变成4个形状为(3)张量[1,2,3],[1,2,3],[1,2,4],[1,2,3]
axis可这样理解:unstack就是要将一个张量降低为低一个维度的张量数组。axis就是将axis指定的维度,用所有这个张量里同维度的数据代替。
参数:
value: 一个将要被降维的维度大于0的张量。
num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
axis: 整数.以轴axis指定的维度来转变 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
name: 这个操作的名字(可选)
返回:
从张量value降维后的张量数组。
异常:
ValueError: 如果num没有指定并且无法求出来。
ValueError: 如果axis超出范围 [-R, R)。
经过下面的例子理解后,上面的1对应axis=1, nsteps对应函数中的num参数,表示axis=1的长度。该操作将数据 x 按照序列数目切开。我们传入的 x 是个3维tensor,将其按照序列数切开,得到了n_steps个 二维的tensor, [batchsize, n_input]
RNN静态与动态的更多相关文章
- Android中BroadcastReceiver的两种注册方式(静态和动态)详解
今天我们一起来探讨下安卓中BroadcastReceiver组件以及详细分析下它的两种注册方式. BroadcastReceiver也就是"广播接收者"的意思,顾名思义,它就是用来 ...
- 生成lua的静态库.动态库.lua.exe和luac.exe
前些日子准备学习下关于lua coroutine更为强大的功能,然而发现根据lua 5.1.4版本来运行一段代码的话也会导致 "lua: attempt to yield across me ...
- Delphi DLL的创建、静态及动态调用
转载:http://blog.csdn.net/welcome000yy/article/details/7905463 结合这篇博客:http://www.cnblogs.com/xumenger/ ...
- 3D touch 静态、动态设置及进入APP的跳转方式
申明Quick Action有两种方式:静态和动态 静态是在info.plist文件中申明,动态则是在代码中注册,系统支持两者同时存在. -系统限制每个app最多显示4个快捷图标,包括静态和动态 静态 ...
- C/C++ 跨平台交叉编译、静态库/动态库编译、MinGW、Cygwin、CodeBlocks使用原理及链接参数选项
目录 . 引言 . 交叉编译 . Cygwin简介 . 静态库编译及使用 . 动态库编译及使用 . MinGW简介 . CodeBlocks简介 0. 引言 UNIX是一个注册商标,是要满足一大堆条件 ...
- RT-Thread创建静态、动态线程
RT-Thread 实时操作系统核心是一个高效的硬实时核心,它具备非常优异的实时性.稳定性.可剪裁性,当进行最小配置时,内核体积可以到 3k ROM 占用. 1k RAM 占用. RT-Thread ...
- linux静态与动态库创建及使用实例
一,gcc基础语法: 基本语法结构:(由以下四部分组成) gcc -o 可执行文件名 依赖文件集(*.c/*.o) 依赖库文件及其头文件集(由-I或-L与-l指明) gcc 依赖文件集(*.c/*.o ...
- MYSQL学习笔记2--mysql 静态和动态plugin
mysql源码编译 .cmke 安装 yum install cmake .依赖的库下载机安装: yum -y install gcc* gcc-c++* autoconf* automake* zl ...
- Android SurfaceView实现静态于动态画图效果
本文是基于Android的SurfaceView的动态画图效果,实现静态和动态下的正弦波画图,可作为自己做图的简单参考,废话不多说,先上图, 静态效果: 动态效果: 比较简单,代码注释的也比较详细,易 ...
随机推荐
- linux-3.0内核移植到fl2440开发板(以s3c2410为模板)
1.新建kernel文件夹,用于存放内核文件 [weishusheng@localhost ~]$ mkdir kernel 2.进入kernel,上传压并解压压缩文件 [weishusheng@lo ...
- View Programming Guide for iOS ---- iOS 视图编程指南(二)---View and Window Architecture
View and Window Architecture 视图和窗口架构 Views and windows present your application’s user interface and ...
- k8s认证及ServiceAccount-十五
一.ServiceAccount (1)简介 https://www.kubernetes.org.cn/service-account Service account是为了方便Pod里面的进程调用K ...
- (水题)洛谷 - P2089 - 烤鸡
https://www.luogu.org/problemnew/show/P2089 非常暴力的dfs,不知道不剪枝会怎么样,但是其实最多也就 $3^{10}$ ,大不到哪里去.还有一个细节就是大于 ...
- SequoiaDB、SequoiaSQL、Cloudera Manager4.8.0、Cloudera CDH4.5 详细安装教程
1安装SequoaiDB集群 1.1配置信任关系 以root用户执行下面的操作 1 执行命令 ssh-keygen 然后一直回车确定即可 2 每台机器都打开id_rsa.pub文件 vi ~/.ssh ...
- 如何在VS2010的VC++ 基于对话框的MFC程序中添加菜单
方法1:亲测 成功 转载自https://social.msdn.microsoft.com/Forums/vstudio/zh-CN/48338f6b-e5d9-4c0c-8b17-05ca3ef ...
- c++ 头文件路径选择
单文件引用头文件./ 当前目录 ../ 父级目录 / 根目录 多文件引用头文件多文件引用头文件 定义单独放在cpp文件里面 ,声明放在().h)里面
- Django framework
1. Django 的内置web server是如何实现的 2. Django 的WSGI是如何实现的 3. Django middle ware是如何实现的 4. Django framework的 ...
- ui自动化测试的意义与理解
分层测试的思想 分层测试(有的也叫测试金字塔)是最近几年慢慢流行.火热起来的,也逐渐得到了大家的认可,大家应该已经比较熟悉分层测试的思想了,不太了解的可以自行找一些相应的渠道去补充一下上下文的知识. ...
- AJPFX总结java开发常用类(包装,数字处理集合等)(三)
4.Map是一种把键对象和值对象进行关联的容器,而一个值对象又可以是一个Map,依次类推,这样就可形成一个多级映射.对于键对象来说,像Set一样,一 个Map容器中的键对象不允许重复,这是为了保持查找 ...