运行代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # set random seed for comparing the two result calculations
tf.set_random_seed(1) # this is data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # hyperparameters
lr = 0.001
training_iters = 100000
batch_size = 128 n_inputs = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # time steps
n_hidden_units = 128 # neurons in hidden layer
n_classes = 10 # MNIST classes (0-9 digits) # tf Graph input
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes]) # Define weights
weights = {
# (28, 128)
'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
# (128, 10)
'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
# (128, )
'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
# (10, )
'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
} def RNN(X, weights, biases):
# hidden layer for input to cell # transpose the inputs shape from
# X ==> (128 batch * 28 steps, 28 inputs)
X = tf.reshape(X, [-1, n_inputs]) # into hidden
# X_in = (128 batch * 28 steps, 128 hidden)
X_in = tf.matmul(X, weights['in']) + biases['in']
# X_in ==> (128 batch, 28 steps, 128 hidden)
X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units]) # cell
########################################## # basic LSTM Cell.
cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units)
# lstm cell is divided into two parts (c_state, h_state)
init_state = cell.zero_state(batch_size, dtype=tf.float32) outputs, final_state = tf.nn.dynamic_rnn(cell, X_in, initial_state=init_state, time_major=False) # unpack to list [(batch, outputs)..] * steps
outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))
results = tf.matmul(outputs[-1], weights['out']) + biases['out'] # shape = (128, 10) return results pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost) correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
step = 0
while step * batch_size < training_iters:
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
sess.run([train_op], feed_dict={
x: batch_xs,
y: batch_ys,
})
if step % 20 == 0:
print(sess.run(accuracy, feed_dict={
x: batch_xs,
y: batch_ys,
}))
step += 1

运行结果:

TensorFlow从入门到理解(四):你的第一个循环神经网络RNN(分类例子)的更多相关文章

  1. TensorFlow从入门到理解(五):你的第一个循环神经网络RNN(回归例子)

    运行代码: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIM ...

  2. TensorFlow从入门到理解

    一.<莫烦Python>学习笔记: TensorFlow从入门到理解(一):搭建开发环境[基于Ubuntu18.04] TensorFlow从入门到理解(二):你的第一个神经网络 Tens ...

  3. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  4. 基于TensorFlow的循环神经网络(RNN)

    RNN适用场景 循环神经网络(Recurrent Neural Network)适合处理和预测时序数据 RNN的特点 RNN的隐藏层之间的节点是有连接的,他的输入是输入层的输出向量.extend(上一 ...

  5. TensorFlow从入门到理解(六):可视化梯度下降

    运行代码: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.m ...

  6. TensorFlow从入门到理解(三):你的第一个卷积神经网络(CNN)

    运行代码: from __future__ import print_function import tensorflow as tf from tensorflow.examples.tutoria ...

  7. TensorFlow从入门到理解(二):你的第一个神经网络

    运行代码: from __future__ import print_function import tensorflow as tf import numpy as np import matplo ...

  8. TensorFlow从入门到理解(一):搭建开发环境【基于Ubuntu18.04】

    *注:教程及本文章皆使用Python3+语言,执行.py文件都是用终端(如果使用Python2+和IDE都会和本文描述有点不符) 一.安装,测试,卸载 TensorFlow官网介绍得很全面,很完美了, ...

  9. 循环神经网络-RNN入门

    首先学习RNN需要一定的基础,即熟悉普通的前馈神经网络,特别是BP神经网络,最好能够手推. 所谓前馈,并不是说信号不能反向传递,而是网络在拓扑结构上不存在回路和环路. 而RNN最大的不同就是存在环路. ...

随机推荐

  1. wildfly access log 开启

    对于一个网站来说,访问日志,即access_log,是一项很重要的功能.利用它,我们可以统计出很多有用的信息,我们可以利用log来对整个网站的运行做有效的监控和分析,从而提升网站的性能. 在WildF ...

  2. A1139. First Contact

    Unlike in nowadays, the way that boys and girls expressing their feelings of love was quite subtle i ...

  3. POJ 1971 Parallelogram Counting (Hash)

          Parallelogram Counting Time Limit: 5000MS   Memory Limit: 65536K Total Submissions: 6895   Acc ...

  4. 【简单易用的傻瓜式图标设计工具】Logoist 3.1 for Mac

    [简介] Logoist 是一款Mac上强大易用的傻瓜式图标设计制作工具,通过使用内置模板和预设效果,您可以立即创建高质量的图形内容和艺术作品.通过使用该应用程序,可用于制作图标LOGO. 一款用于创 ...

  5. Word 测试下发布博客

    目录 语法.    3 NULL,TRUE,FALSE    3 大小端存储    4 类型转换    4 转义字符    5 运算符的优先级    5 表达式(a=b=c)    7 *pa++=* ...

  6. Reference-TMB

    Paper Name:Targeted Next Generation Sequencing Identifies Markers of Response to PD-1 Blockade Adress ...

  7. Xenserver之设置Xenserver和VM机开机自动启动

    一.设置Xenserver开机自动启动 [root@xenserver-DS-TestServer09 ~]# xe pool-list uuid ( RO) : b1c803a6-88cf-7b24 ...

  8. JDK1.5以后的版本特性

    一.JDK1.5新特性 1.泛型:泛型的本质是参数化类型,也就是说所操作的数据类型被指定为一个参数.这种参数类型可以用在类.接口和方法的创建中,分别称为泛型类.泛型接口.泛型方法.可以在编译的时候就能 ...

  9. spring XML配置参数替代properties文件

    xml中配置BEAN与参数 <bean id="beanXXX" class="com.benXXXX" init-method="initia ...

  10. SpringBoot笔记十三:引入webjar资源和国际化处理

    目录 什么是webjar 怎么使用webjar 国际化 新建国际化配置文件 配置配置文件 使用配置文件 我们先来看一个html,带有css的,我们就以这个为准来讲解. 资源可以去我网盘下载 链接:ht ...