静态、多层RNN:

import numpy as np
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True) #网络模型参数
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类) #训练参数
batch_size = 128
learning_rate = 0.001
training_iters = 10000
display_step = 10 # tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) #构建网络
stacked_rnn = []
for _ in range(3):
stacked_rnn.append(tf.contrib.rnn.LSTMCell(n_hidden))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn) x1=tf.unstack(x,n_steps,1)#在axis=1进行解包分解。 outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)#inputs must be a sequence

#最后一层全连接 outputs[-1]
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) # Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!")

# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

在学习RNN这一章的时候,遇到static_rnn中输入数据 x 的格式:

[None, n_steps, n_input] 进行变换→ x1=tf.unstack(x,n_steps,1)

之后再传入:outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)

很难理解,为什么要这样做,数据又进行了怎样的变换。


以下,为stack和unstack的详细举例:

  • tf.stack(values, axis=0, name=’stack’)
    以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将一组张量以指定的轴,提高一个维度。

假设要转变的张量数组values的长度为N,其中的每个张量的形状为(A, B, C)。
如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。

举例如下:
‘x’ is [1, 4], 形状是(2),维度为1
‘y’ is [2, 5], 形状是(2),维度为1
‘z’ is [3, 6], 形状是(2),维度为1
stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # axis的值默认为0。输出的形状为(3, 2)
stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] # axis的值为1。输出的形状为(2, 3)

‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)

axis可这样理解:stack就是要将一组相同形状的张量提高一个维度。axis就是这些张量里,将axis指定的维度用所有这些张量数组代替。如axis=2,表示指定在第2个维度,原来的元素用整个张量数组里的元素代替,即从(A, B, C)转变为(A, B, N, C)

参数:
values: 一个有相同形状与数据类型的张量数组。
axis: 以轴axis为中心来转变的整数。默认是第一个维度即axis=0。支持负数。取值范围为[-(R+1), R+1)
name: 这个操作的名字(可选)
返回:被提高一个维度后的张量
异常: ValueError: 如果轴axis超出范围[-(R+1), R+1).


  • tf.unstack()

tf.unstack(value, num=None, axis=0, name=’unstack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R-1的张量。即将一组张量以指定的轴,减少一个维度。正好和stack()相反。

将张量value分割成num个张量数组。如果num没有指定,则是根据张量value的形状来指定。如果value.shape[axis]不存在,则抛出ValueError的异常。

假如一个张量的形状是(A, B, C, D)。
如果axis == 0,则输出的张量是value[i, :, :, :],i取值为[0,A),每个输出的张量的形状为(B,C,D)。
如果axis == 1,则输出的张量是value[:, i, :, :],i取值为[0,B),每个输出的张量的形状为(A,C,D)。
如果axis == 2,则输出的张量是value[:, :, i, :],i取值为[0,C),每个输出的张量的形状为(A,B,D)。依次类推。

举例如下:
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]] # 形状是(3,4),维度为2
unstack(x,axis=0) =>以指定的维度0为轴,转变成3个形状为(4)张量[1,1,1,1],[2,2,2,2],[3,3,3,3]
unstack(x,axis=1) =>以指定的维度1为轴,转变成4个形状为(3)张量[1,2,3],[1,2,3],[1,2,4],[1,2,3]

axis可这样理解:unstack就是要将一个张量降低为低一个维度的张量数组。axis就是将axis指定的维度,用所有这个张量里同维度的数据代替。

参数:
value: 一个将要被降维的维度大于0的张量。
num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
axis: 整数.以轴axis指定的维度来转变 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
name: 这个操作的名字(可选)
返回:
从张量value降维后的张量数组。
异常:
ValueError: 如果num没有指定并且无法求出来。
ValueError: 如果axis超出范围 [-R, R)。


经过下面的例子理解后,上面的1对应axis=1, nsteps对应函数中的num参数,表示axis=1的长度。该操作将数据 x 按照序列数目切开。我们传入的 x 是个3维tensor,将其按照序列数切开,得到了n_steps个 二维的tensor, [batchsize, n_input]

RNN静态与动态的更多相关文章

  1. Android中BroadcastReceiver的两种注册方式(静态和动态)详解

    今天我们一起来探讨下安卓中BroadcastReceiver组件以及详细分析下它的两种注册方式. BroadcastReceiver也就是"广播接收者"的意思,顾名思义,它就是用来 ...

  2. 生成lua的静态库.动态库.lua.exe和luac.exe

    前些日子准备学习下关于lua coroutine更为强大的功能,然而发现根据lua 5.1.4版本来运行一段代码的话也会导致 "lua: attempt to yield across me ...

  3. Delphi DLL的创建、静态及动态调用

    转载:http://blog.csdn.net/welcome000yy/article/details/7905463 结合这篇博客:http://www.cnblogs.com/xumenger/ ...

  4. 3D touch 静态、动态设置及进入APP的跳转方式

    申明Quick Action有两种方式:静态和动态 静态是在info.plist文件中申明,动态则是在代码中注册,系统支持两者同时存在. -系统限制每个app最多显示4个快捷图标,包括静态和动态 静态 ...

  5. C/C++ 跨平台交叉编译、静态库/动态库编译、MinGW、Cygwin、CodeBlocks使用原理及链接参数选项

    目录 . 引言 . 交叉编译 . Cygwin简介 . 静态库编译及使用 . 动态库编译及使用 . MinGW简介 . CodeBlocks简介 0. 引言 UNIX是一个注册商标,是要满足一大堆条件 ...

  6. RT-Thread创建静态、动态线程

    RT-Thread 实时操作系统核心是一个高效的硬实时核心,它具备非常优异的实时性.稳定性.可剪裁性,当进行最小配置时,内核体积可以到 3k ROM 占用. 1k RAM 占用. RT-Thread ...

  7. linux静态与动态库创建及使用实例

    一,gcc基础语法: 基本语法结构:(由以下四部分组成) gcc -o 可执行文件名 依赖文件集(*.c/*.o) 依赖库文件及其头文件集(由-I或-L与-l指明) gcc 依赖文件集(*.c/*.o ...

  8. MYSQL学习笔记2--mysql 静态和动态plugin

    mysql源码编译 .cmke 安装 yum install cmake .依赖的库下载机安装: yum -y install gcc* gcc-c++* autoconf* automake* zl ...

  9. Android SurfaceView实现静态于动态画图效果

    本文是基于Android的SurfaceView的动态画图效果,实现静态和动态下的正弦波画图,可作为自己做图的简单参考,废话不多说,先上图, 静态效果: 动态效果: 比较简单,代码注释的也比较详细,易 ...

随机推荐

  1. ASP.NET Core MVC 2.x 全面教程_ASP.NET Core MVC 15. 用户管理

    源码的github的地址 https://github.com/solenovex/ASP.NET-Core-MVC-Tutorial-Code 语雀上的人的地址: https://github.co ...

  2. 爱奇艺面试Python,竟然挂在第5轮……

    今天给大家分享我曾经在爱奇艺的面试,过程还是比较有意思的,可以给大家一些参考 聊骚阶段 嗲妹妹:你好,我是爱奇艺的HR,我们正在招聘运维开发岗位,请问您最近有在看工作机会吗? 我:(这声音也太酥了吧我 ...

  3. poj1724【最短路】

    题意: 给出n个城市,然后给出m条单向路,给出了每条路的距离和花费,问一个人有k coins,在不超过money的情况下从1到n最短路径路径. 思路: 我相信很多人在上面那道题的影响下,肯定会想想,在 ...

  4. bzoj 5496: [2019省队联测]字符串问题【SAM+拓扑】

    有一个想法就是暴力建图,把每个A向有和他相连的B前缀的A,然后拓扑一下,这样的图是n^2的: 考虑优化建图,因为大部分数据结构都是处理后缀的,所以把串反过来,题目中要求的前缀B就变成了后缀B 建立SA ...

  5. bzoj 3992: [SDOI2015]序列统计【原根+生成函数+NTT+快速幂】

    还是没有理解透原根--题目提示其实挺明显的,M是质数,然后1<=x<=M-1 这种计数就容易想到生成函数,但是生成函数是加法,而这里是乘法,所以要想办法变成加法 首先因为0和任何数乘都是0 ...

  6. linux之用户态和内核态

    一. Unix/Linux的体系架构 如上图所示,从宏观上来看,Linux操作系统的体系架构分为用户态和内核态(或者用户空间和内核).内核从本质上看是一种软件——控制计算机的硬件资源,并提供上层应用程 ...

  7. python之商品操作小程序

    要求:写一个添加商品的程序,商品信息写入txt文件中,以二维字典形式比如:{‘小米’:{‘价格’:‘1999元’,‘数量’:10}} 1.添加商品 #商品名称 #价格 #数量 2.查看商品 3.删除商 ...

  8. 无法获得VMCI 驱动程序的版本: 句柄无效的解决方法

    关闭虚拟机,找到安装路径,用记事本打开.vmx结尾的文件 将vmci0.present = "TRUE"改为vmci0.present = "FALSE"保存

  9. HDU6446(树上、排列的贡献计算)

    关键点在于:全排列中,任意两点u.v相邻的次数一定是(n - 1)! * 2次,即一个常数(可以由高中数学知识计算,将这两个点捏一起然后全排列然后乘二:或者用n! / C(2, n)). 这之后就好算 ...

  10. UvaLive6441(期望概率dp)

    1.涉及负数时同时维护最大和最小,互相转移. 2.考场上最大最小混搭转移WA,赛后发现如果是小的搭小的,大的搭大的就可过,类似这种: db a = (C[i] - W[i]) * dp1[i - ][ ...