"""
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

# Import MNIST data (labels are one-hot encoded because of one_hot=True).
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# To classify images using a bidirectional recurrent neural network, we
# consider every image row as a sequence element. Because the MNIST image
# shape is 28*28 px, every sample is handled as ONE sequence of 28 time
# steps, each step carrying the 28 pixels of one image row.
# ---------------------------------------------------------------------------
# Training parameters
# ---------------------------------------------------------------------------
learning_rate = 0.001    # step size handed to the Adam optimizer
training_iters = 100000  # training stops once step * batch_size reaches this many samples
batch_size = 128         # number of samples per optimization step
display_step = 10        # print minibatch loss/accuracy every this many steps

# ---------------------------------------------------------------------------
# Network parameters
# ---------------------------------------------------------------------------
# n_steps * n_input is exactly one image: each image row is fed to one
# time step of the RNN.
n_input = 28    # pixels per row (MNIST img shape: 28*28)
n_steps = 28    # time steps, i.e. rows per image
n_hidden = 128  # hidden units in each LSTM direction
n_classes = 10  # MNIST total classes (0-9 digits)

# tf Graph input. In [None, n_steps, n_input] the None means the batch
# dimension is left unspecified.
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Output-layer weights. The input size is 2*n_hidden because the forward
# and backward cell outputs are concatenated.
weights = dict(out=tf.Variable(tf.random_normal([2 * n_hidden, n_classes])))
# Output-layer bias, one entry per class.
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM over ``x`` and return the class logits.

    Args:
        x: input tensor of shape (batch_size, n_steps, n_input), such as the
            ``x`` placeholder defined at module level.
        weights: dict whose ``'out'`` entry is the [2*n_hidden, n_classes]
            output projection.
        biases: dict whose ``'out'`` entry is the [n_classes] output bias.

    Returns:
        Logits tensor of shape (batch_size, n_classes). It is a linear layer
        applied to the last time step's concatenated forward/backward output.
    """
    # static_bidirectional_rnn expects a Python list of n_steps tensors, each
    # of shape (batch_size, n_input). Unstack along the time axis (axis 1) to
    # turn (batch_size, n_steps, n_input) into n_steps * (batch_size, n_input).
    x = tf.unstack(x, n_steps, 1)

    # Forward and backward direction LSTM cells.
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Build the bidirectional RNN exactly once. Retrying the call inside a
    # broad ``except`` would construct the cells' variables and ops a second
    # time and hide any unrelated error.
    result = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                          dtype=tf.float32)
    # Current TensorFlow returns (outputs, output_state_fw, output_state_bw),
    # where outputs is a sequence of per-step tensors. Old TensorFlow
    # versions returned only the outputs sequence.
    if (isinstance(result, tuple) and len(result) == 3
            and isinstance(result[0], (list, tuple))):
        outputs = result[0]
    else:
        outputs = result

    # Linear activation, using the rnn inner loop's last output.
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


pred = BiRNN(x, weights, biases)
# Define loss and optimizer.
# softmax_cross_entropy_with_logits yields one cross-entropy value per batch
# row, for classes that are mutually exclusive. reduce_mean (no axis given)
# averages all of them into a scalar loss.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)
)
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Evaluate model: the fraction of rows whose argmax over the logits equals
# the argmax over the label vector.
correct_pred = tf.equal(tf.argmax(pred, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))

# Op that initializes all global variables; it is run first in the session.
init = tf.global_variables_initializer()
# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations (measured in samples seen).
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements: one image row per step.
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        feed = {x: batch_x, y: batch_y}
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict=feed)
        if step % display_step == 0:
            # Fetch minibatch loss and accuracy in one run, so the forward
            # pass over this batch is computed once rather than twice.
            loss, acc = sess.run([cost, accuracy], feed_dict=feed)
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

官方的 BiLSTM 示例写得很清楚。不过因为是第一次看，还是要查许多东西，尤其是数据处理方面。

数据的处理(https://segmentfault.com/a/1190000008793389)

拼接

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
tf.stack([t1, t2], 0) ==> [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
tf.stack([t1, t2], 1) ==> [[[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]]]
tf.stack([t1, t2], 2) ==> [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]

从shape的角度看（带 * 的维度是 tf.stack 新增的那一维）:

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [2,3] + [2,3] ==> [4, 3]
tf.concat([t1, t2], 1) # [2,3] + [2,3] ==> [2, 6]
tf.stack([t1, t2], 0) # [2,3] + [2,3] ==> [2*,2,3]
tf.stack([t1, t2], 1) # [2,3] + [2,3] ==> [2,2*,3]
tf.stack([t1, t2], 2) # [2,3] + [2,3] ==> [2,3,2*]

抽取:

input = [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
[[5, 5, 5], [6, 6, 6]]]

tensorflow bilstm官方示例的更多相关文章

  1. DotNetBar for Windows Forms 12.7.0.10_冰河之刃重打包版原创发布-带官方示例程序版

    关于 DotNetBar for Windows Forms 12.7.0.10_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版------------- ...

  2. DotNetBar for Windows Forms 12.5.0.2_冰河之刃重打包版原创发布-带官方示例程序版

    关于 DotNetBar for Windows Forms 12.5.0.2_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版-------------- ...

  3. DotNetBar for Windows Forms 12.2.0.7_冰河之刃重打包版原创发布-带官方示例程序版

    关于 DotNetBar for Windows Forms 12.2.0.7_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版-------------- ...

  4. html5游戏引擎phaser官方示例学习

    首发:个人博客,更新&纠错&回复 phaser官方示例学习进行中,把官方示例调整为简明的目录结构,学习过程中加了点中文注释,代码在这里. 目前把官方的完整游戏示例看了一大半, brea ...

  5. 将百度坐标转换的javascript api官方示例改写成传统的回调函数形式

    改写前: 百度地图中坐标转换的JavaScript API示例官方示例如下: var points = [new BMap.Point(116.3786889372559,39.90762965106 ...

  6. ngRx 官方示例分析 - 3. reducers

    上一篇:ngRx 官方示例分析 - 2. Action 管理 这里我们讨论 reducer. 如果你注意的话,会看到在不同的 Action 定义文件中,导出的 Action 类型名称都是 Action ...

  7. ngRx 官方示例分析 - 2. Action 管理

    我们从 Action 名称开始. 解决 Action 名称冲突问题 在 ngRx 中,不同的 Action 需要一个 Action Type 进行区分,一般来说,这个 Action Type 是一个字 ...

  8. ngRx 官方示例分析 - 1. 介绍

    ngRx 的官方示例演示了在具体的场景中,如何使用 ngRx 管理应用的状态. 示例介绍 示例允许用户通过查询 google 的 book  API  来查询图书,并保存自己的精选书籍列表. 菜单有两 ...

  9. Ionic 2 官方示例程序 Super Starter

    原文发表于我的技术博客 本文分享了 Ionic 2 官方示例程序 Super Starter 的简要介绍与安装运行的方法,最好的学习示例代码,项目共包含了 14 个通用的页面设计,如:引导页.主页面详 ...

随机推荐

  1. 理解First Chance和Second Chance避免单步调试

    原文链接地址:http://blog.csdn.net/Donjuan/article/details/3859160 在现在C++.Java..Net代码大行其道的时候,很多代码错误(Bug)都是通 ...

  2. oracle的sql语句训练

    --查询工资最高的人的名字select ename ,sal from emp where sal=(select max(sal) from emp );--求出员工的工资在所有人的平均工资之上的人 ...

  3. 使用命令wsimport生成WebService客户端

    使用命令wsimport生成WebService客户端 wsimport命令有几个重要的参数: -keep:是否生成java源文件    -d:指定输出目录    -s:指定源代码输出目录    -p ...

  4. 记录一发wm_concat()函数排序的问题

    需求:需要将列转行之后的工序按照待执行工序号排序,如果一样按工序号排 解决方法如下: select part_no, max(ywggx) ywggx from(select mt.part_no , ...

  5. Tomcat部署时war和war exploded区别及验证

    war和war exploded的区别 在使用IDEA开发项目的时候,部署Tomcat的时候通常会出现下边的情况: 是选择war还是war exploded 这里首先看一下他们两个的区别: war模式 ...

  6. [BZOJ2190&BZOJ2705]欧拉函数应用两例

    欧拉函数phi[n]是表示1~n中与n互质的数个数. 可以用公式phi[n]=n*(1-1/p1)*(1-1/p2)*(1-1/p3)...*(1-1/pk)来表示.(p为n的质因子) 求phi[p] ...

  7. bzoj 1044 贪心二分+DP

    原题传送门http://www.lydsy.com/JudgeOnline/problem.php?id=1044 首先对于第一问,我们可以轻易的用二分答案来搞定,对于每一个二分到的mid值 我们从l ...

  8. myeclipse打断点进入后无法查看变量的值的解决方法

    myeclipse打断点进入后无法查看变量的值,打开myeclipse菜单选项:“Window” - “Preferences” - “Java” - “Editor” - “Hovers” ...

  9. [ Python - 6 ] 正则表达式实现计算器功能

    要求:禁止使用eval函数.参考网上代码如下: #!_*_coding:utf-8_*_ """用户输入计算表达式,显示计算结果""" im ...

  10. Jxl、JxCell图表导出功能的实现

    最近接触过许多报表导出功能,也用过多种工具进行导出功能的实现,但对于图表的导出一直没有仔细的去展开研究和探讨,直到最近略微整理了下这方面的需求和技术攻克. 首先导出excel功能的实现主要有JXL.J ...