静态、多层RNN:

import numpy as np
import tensorflow as tf
# 导入 MINST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/", one_hot=True) #网络模型参数
n_input = 28 # MNIST data 输入 (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST 列别 (0-9 ,一共10类) #训练参数
batch_size = 128
learning_rate = 0.001
training_iters = 10000
display_step = 10 # tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes]) #构建网络
stacked_rnn = []
for _ in range(3):
stacked_rnn.append(tf.contrib.rnn.LSTMCell(n_hidden))
mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn) x1=tf.unstack(x,n_steps,1)#在axis=1进行解包分解。 outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)#inputs must be a sequence

#最后一层全连接 outputs[-1]
pred = tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn = None) # Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Reshape data to get 28 seq of 28 elements
batch_x = batch_x.reshape((batch_size, n_steps, n_input))
# Run optimization op (backprop)
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % display_step == 0:
# 计算批次数据的准确率
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
# Calculate batch loss
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print ("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
"{:.6f}".format(loss) + ", Training Accuracy= " + \
"{:.5f}".format(acc))
step += 1
print (" Finished!")

# 计算准确率 for 128 mnist test images
test_len = 100
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print ("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

在学习RNN这一章的时候,遇到static_rnn中输入数据 x 的格式:

[None, n_steps, n_input] 进行变换→ x1=tf.unstack(x,n_steps,1)

之后再传入:outputs, states = tf.contrib.rnn.static_rnn(mcell, x1, dtype=tf.float32)

很难理解,为什么要这样做,数据又进行了怎样的变换。


以下,为stack和unstack的详细举例:

  • tf.stack(values, axis=0, name=’stack’)
    以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将一组张量以指定的轴,提高一个维度。

假设要转变的张量数组values的长度为N,其中的每个张量的形状为(A, B, C)。
如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。

举例如下:
‘x’ is [1, 4], 形状是(2),维度为1
‘y’ is [2, 5], 形状是(2),维度为1
‘z’ is [3, 6], 形状是(2),维度为1
stack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # axis的值默认为0。输出的形状为(3, 2)
stack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] # axis的值为1。输出的形状为(2, 3)

‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)

axis可这样理解:stack就是要将一组相同形状的张量提高一个维度。axis就是这些张量里,将axis指定的维度用所有这些张量数组代替。如axis=2,表示指定在第2个维度,原来的元素用整个张量数组里的元素代替,即从(A, B, C)转变为(A, B, N, C)

参数:
values: 一个有相同形状与数据类型的张量数组。
axis: 以轴axis为中心来转变的整数。默认是第一个维度即axis=0。支持负数。取值范围为[-(R+1), R+1)
name: 这个操作的名字(可选)
返回:被提高一个维度后的张量
异常: ValueError: 如果轴axis超出范围[-(R+1), R+1).


  • tf.unstack()

tf.unstack(value, num=None, axis=0, name=’unstack’)
以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R-1的张量。即将一组张量以指定的轴,减少一个维度。正好和stack()相反。

将张量value分割成num个张量数组。如果num没有指定,则是根据张量value的形状来指定。如果value.shape[axis]不存在,则抛出ValueError的异常。

假如一个张量的形状是(A, B, C, D)。
如果axis == 0,则输出的张量是value[i, :, :, :],i取值为[0,A),每个输出的张量的形状为(B,C,D)。
如果axis == 1,则输出的张量是value[:, i, :, :],i取值为[0,B),每个输出的张量的形状为(A,C,D)。
如果axis == 2,则输出的张量是value[:, :, i, :],i取值为[0,C),每个输出的张量的形状为(A,B,D)。依次类推。

举例如下:
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]] # 形状是(3,4),维度为2
unstack(x,axis=0) =>以指定的维度0为轴,转变成3个形状为(4)张量[1,1,1,1],[2,2,2,2],[3,3,3,3]
unstack(x,axis=1) =>以指定的维度1为轴,转变成4个形状为(3)张量[1,2,3],[1,2,3],[1,2,4],[1,2,3]

axis可这样理解:unstack就是要将一个张量降低为低一个维度的张量数组。axis就是将axis指定的维度,用所有这个张量里同维度的数据代替。

参数:
value: 一个将要被降维的维度大于0的张量。
num: 整数。指定的维度axis的长度。如果设置为None(默认值),将自动求值。
axis: 整数.以轴axis指定的维度来转变 默认是第一个维度即axis=0。支持负数。取值范围为[-R, R)
name: 这个操作的名字(可选)
返回:
从张量value降维后的张量数组。
异常:
ValueError: 如果num没有指定并且无法求出来。
ValueError: 如果axis超出范围 [-R, R)。


经过下面的例子理解后,上面的1对应axis=1, nsteps对应函数中的num参数,表示axis=1的长度。该操作将数据 x 按照序列数目切开。我们传入的 x 是个3维tensor,将其按照序列数切开,得到了n_steps个 二维的tensor, [batchsize, n_input]

RNN静态与动态的更多相关文章

  1. Android中BroadcastReceiver的两种注册方式(静态和动态)详解

    今天我们一起来探讨下安卓中BroadcastReceiver组件以及详细分析下它的两种注册方式. BroadcastReceiver也就是"广播接收者"的意思,顾名思义,它就是用来 ...

  2. 生成lua的静态库.动态库.lua.exe和luac.exe

    前些日子准备学习下关于lua coroutine更为强大的功能,然而发现根据lua 5.1.4版本来运行一段代码的话也会导致 "lua: attempt to yield across me ...

  3. Delphi DLL的创建、静态及动态调用

    转载:http://blog.csdn.net/welcome000yy/article/details/7905463 结合这篇博客:http://www.cnblogs.com/xumenger/ ...

  4. 3D touch 静态、动态设置及进入APP的跳转方式

    申明Quick Action有两种方式:静态和动态 静态是在info.plist文件中申明,动态则是在代码中注册,系统支持两者同时存在. -系统限制每个app最多显示4个快捷图标,包括静态和动态 静态 ...

  5. C/C++ 跨平台交叉编译、静态库/动态库编译、MinGW、Cygwin、CodeBlocks使用原理及链接参数选项

    目录 . 引言 . 交叉编译 . Cygwin简介 . 静态库编译及使用 . 动态库编译及使用 . MinGW简介 . CodeBlocks简介 0. 引言 UNIX是一个注册商标,是要满足一大堆条件 ...

  6. RT-Thread创建静态、动态线程

    RT-Thread 实时操作系统核心是一个高效的硬实时核心,它具备非常优异的实时性.稳定性.可剪裁性,当进行最小配置时,内核体积可以到 3k ROM 占用. 1k RAM 占用. RT-Thread ...

  7. linux静态与动态库创建及使用实例

    一,gcc基础语法: 基本语法结构:(由以下四部分组成) gcc -o 可执行文件名 依赖文件集(*.c/*.o) 依赖库文件及其头文件集(由-I或-L与-l指明) gcc 依赖文件集(*.c/*.o ...

  8. MYSQL学习笔记2--mysql 静态和动态plugin

    mysql源码编译 .cmke 安装 yum install cmake .依赖的库下载机安装: yum -y install gcc* gcc-c++* autoconf* automake* zl ...

  9. Android SurfaceView实现静态于动态画图效果

    本文是基于Android的SurfaceView的动态画图效果,实现静态和动态下的正弦波画图,可作为自己做图的简单参考,废话不多说,先上图, 静态效果: 动态效果: 比较简单,代码注释的也比较详细,易 ...

随机推荐

  1. MariaDB + Visual Studio 2017 环境下的 ODBC 入门开发

    参考: Easysoft公司提供的ODBC教程 微软提供的ODBC文档 环境: Windows 10 x64 1803 MariaDB TX 10.2.14 x64 MariaDB ODBC Conn ...

  2. HDU-5551 Huatuo's Medicine

    Time Limit: 3000/1000 MS (Java/Others)    Memory Limit: 65535/65535 K (Java/Others)Total Submission( ...

  3. In-App Purchase Programming Guide----(四) ----Requesting Payment

    Requesting Payment In the second part of the purchase process, after the user has chosen to purchase ...

  4. E20180407-hm

    queue   n. (人或车辆) 行列,长队; 辫子;   vi. (人.车等) 排队等候;   vt. (使) 排队,列队等待; compatible  adj. 兼容的,相容的; 和谐的,协调的 ...

  5. bzoj 3109: [cqoi2013]新数独【dfs】

    按3x3的小块dfs,填数的时候直接满足所有条件即可 #include<iostream> #include<cstdio> #include<cstring> u ...

  6. firewall-cmd 使用总结

    firewalld的简要说明: firewalld .firewall-cmd .firewall-offline-cmd它们Python脚本,通过定义的在/usr/lib/firewalld下面的x ...

  7. 继续(3n+1)猜想 (25)

    #include <algorithm> #include <iostream> using namespace std; int main(){ ] = { }; ], nu ...

  8. iOS UITextView 设置圆角边框线

    textView.layer.borderColor = UIColor.lightGray.cgColor textView.layer.cornerRadius = 4 textView.laye ...

  9. ssh公私密钥的生成

    ssh密钥的生成 root账号密钥的生成: 这里我们切换到root账号下,执行ssh-keygen命令: ssh-keygen -t dsa 然后一路回车即可 """ [ ...

  10. 18.3.1获得Class对象

    package d18_3_1; /** * Java中的java.lang.Class,简单理解就是为每个java对象的类型标识的类, * 虚拟机使用运行时类型信息选择正确的执行方法,用来保存这些运 ...