lstm的前向结构,不迭代

最基本的lstm结构。不涉及损失值和bp过程

import tensorflow as tf
import numpy as np inputs = tf.placeholder(np.float32, shape=(32,40,5)) # 32 是 batch_size
lstm_cell_1 = tf.nn.rnn_cell.LSTMCell(num_units=128) #实例话一个lstm单元,输出是128单元 print("output_size:",lstm_cell_1.output_size)
print("state_size:",lstm_cell_1.state_size)
print(lstm_cell_1.state_size.h)
print(lstm_cell_1.state_size.c) output,state=tf.nn.dynamic_rnn(
cell=lstm_cell_1,
inputs=inputs,
dtype=tf.float32
)
# 根据inputs输入的维度迭代rnn,并将输出和隐层态,push进output和state里面。
(inputs是三个维度,第一维,是batch_size,第二维:数据切片为面,第三维:切片面的具体数据) print("第一个输入的最后一个序列的预测输出:",output[1,-1,:])
print("output.shape:",output.shape)
print("len of state tuple",len(state))
print("state.h.shape:",state.h.shape)
print("state.c.shape:",state.c.shape) #>>>
output_size: 128
state_size: LSTMStateTuple(c=128, h=128)
128
128
第一个输入的最后一个序列的预测输出: Tensor("strided_slice:0", shape=(128,), dtype=float32)
output.shape: (32, 40, 128)
len of state tuple 2
state.h.shape: (32, 128)
state.c.shape: (32, 128)

用lstm对mnist数据分类

#引包和加载mnist数据

import tensorflow as tf
import input_data
import numpy as np
import matplotlib.pyplot as plt mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
diminput=28
dimhidden=128
dimoutput=nclasses
nsteps=28
weights={
'hidden':tf.Variable(tf.random_normal([diminput,dimhidden])),
'out':tf.Variable(tf.random_normal([dimhidden,dimoutput]))
}
biases={
'hidden':tf.Variable(tf.random_normal([dimhidden])),
'out':tf.Variable(tf.random_normal([dimoutput]))
}
def RNN(X,W,B,nsteps,name):
print(X.shape,'---')
X=tf.reshape(X,[-1,diminput])
X = tf.matmul(X, W['hidden']) + B['hidden']
X=tf.reshape(X,[-1,diminput,dimhidden])
print(X.shape)
with tf.variable_scope(name) as scope:
#scope.reuse_variables()
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(dimhidden,forget_bias=1.0)
lstm_o,lstm_s=tf.nn.dynamic_rnn(cell=lstm_cell,inputs=X,dtype=tf.float32)
resultOut=tf.matmul(lstm_o[:,-1,:],W['out'])+B['out']
return {
'X':X,
'lstm_o':lstm_o,'lstm_s':lstm_s,'resultOut':resultOut
}
learning_rate=0.001
x=tf.placeholder('float',[None,nsteps,diminput]) y=tf.placeholder('float',[None,dimoutput]) myrnn=RNN(x,weights,biases,nsteps,'basic')
pred=myrnn['resultOut']
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optm=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
accr=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1),tf.argmax(y,1)),tf.float32))
init=tf.global_variables_initializer()
training_epochs=33
batch_size=16
display_step=1
sess=tf.Session()
sess.run(init) for epoch in range(training_epochs):
avg_cost=100
total_batch=100
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
Epoch: 000/033 cost: 101.797383542
Training accuracy: 0.688
Test accuracy: 0.461
Epoch: 001/033 cost: 101.269138204
Training accuracy: 0.438
Test accuracy: 0.549
Epoch: 002/033 cost: 101.139203327
Training accuracy: 0.688
Test accuracy: 0.614
Epoch: 003/033 cost: 100.965362185
Training accuracy: 0.938
Test accuracy: 0.619
Epoch: 004/033 cost: 100.914383653
Training accuracy: 0.875
Test accuracy: 0.648
Epoch: 005/033 cost: 100.813317066
Training accuracy: 0.625
Test accuracy: 0.656
Epoch: 006/033 cost: 100.781623098
Training accuracy: 0.875
Test accuracy: 0.708
Epoch: 007/033 cost: 100.710710035
Training accuracy: 1.000
Test accuracy: 0.716
Epoch: 008/033 cost: 100.684573339
Training accuracy: 1.000
Test accuracy: 0.745
Epoch: 009/033 cost: 100.635698693
Training accuracy: 0.875
Test accuracy: 0.751
Epoch: 010/033 cost: 100.622099145
Training accuracy: 0.938
Test accuracy: 0.763
Epoch: 011/033 cost: 100.562925613
Training accuracy: 0.750
Test accuracy: 0.763
Epoch: 012/033 cost: 100.592214927
Training accuracy: 0.812
Test accuracy: 0.771
Epoch: 013/033 cost: 100.544024273
Training accuracy: 0.938
Test accuracy: 0.769
Epoch: 014/033 cost: 100.516522627
Training accuracy: 0.875
Test accuracy: 0.791
Epoch: 015/033 cost: 100.479632292
Training accuracy: 0.938
Test accuracy: 0.801
Epoch: 016/033 cost: 100.471150137
Training accuracy: 0.938
Test accuracy: 0.816
Epoch: 017/033 cost: 100.431061392
Training accuracy: 0.875
Test accuracy: 0.807
Epoch: 018/033 cost: 100.464853102
Training accuracy: 0.812
Test accuracy: 0.798
Epoch: 019/033 cost: 100.445183915
Training accuracy: 0.750
Test accuracy: 0.828
Epoch: 020/033 cost: 100.399013084
Training accuracy: 1.000
Test accuracy: 0.804
Epoch: 021/033 cost: 100.393008129
Training accuracy: 0.938
Test accuracy: 0.833
Epoch: 022/033 cost: 100.413909222
Training accuracy: 0.812
Test accuracy: 0.815

RNN(二)——基于tensorflow的LSTM的实现的更多相关文章

  1. 如何基于TensorFlow使用LSTM和CNN实现时序分类任务

    https://www.jiqizhixin.com/articles/2017-09-12-5 By 蒋思源2017年9月12日 09:54 时序数据经常出现在很多领域中,如金融.信号处理.语音识别 ...

  2. 学习Tensorflow的LSTM的RNN例子

    学习Tensorflow的LSTM的RNN例子 基于TensorFlow一次简单的RNN实现 极客学院-递归神经网络 如何使用TensorFlow构建.训练和改进循环神经网络

  3. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  4. 个基于TensorFlow的简单故事生成案例:带你了解LSTM

    https://medium.com/towards-data-science/lstm-by-example-using-tensorflow-feb0c1968537 在深度学习中,循环神经网络( ...

  5. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人。

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 两种开源聊天机器人的性能测试(二)——基于tensorflow的chatbot

    http://blog.csdn.net/hfutdog/article/details/78155676 开源项目链接:https://github.com/dennybritz/chatbot-r ...

  8. 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(二)

    前言 已完成数据预处理工作,具体参照: 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(一) 设置配置文件 新建目录face_faster_rcn ...

  9. 基于Tensorflow + Opencv 实现CNN自定义图像分类

    摘要:本篇文章主要通过Tensorflow+Opencv实现CNN自定义图像分类案例,它能解决我们现实论文或实践中的图像分类问题,并与机器学习的图像分类算法进行对比实验. 本文分享自华为云社区< ...

随机推荐

  1. AtCoder练习

    1. 3721 Smuggling Marbles 大意: 给定$n+1$节点树, $0$为根节点, 初始在一些节点放一个石子, 然后按顺序进行如下操作. 若$0$节点有石子, 则移入盒子 所有石子移 ...

  2. (三)Activiti之第一个程序以及Activiti插件的使用和Activiti表的解释

    一.案例 1.1 建立Activiti Diagram图 new -> activiti ->Activiti Diagram,创建一个HelloWorld文件,后缀自动为bpmn,如下图 ...

  3. Keras 训练 inceptionV3 并移植到OpenCV4.0 in C++

    1. 训练 # --coding:utf--- import os import sys import glob import argparse import matplotlib.pyplot as ...

  4. 【导出导入】IMPDP table_exists_action 参数的应用

    转自:https://yq.aliyun.com/articles/29337 当使用IMPDP完成数据库导入时,如遇到表已存在时,Oracle提供给我们如下四种处理方式:a.忽略(SKIP,默认行为 ...

  5. Go 方法使用

    方法的定义 在 Go 语言里,方法和函数只差了一个,那就是方法在 func 和标识符之间多了一个参数. type user struct { name string, email string, } ...

  6. 【Zookeeper】应用场景概述

    一.数据发布与订阅(配置中心) 二.负载均衡 三.命名服务(Naming Service) 四.分布式通知/协调 五.集群管理与Master选举 六.分布式锁 七.分布式事务 一.数据发布与订阅(配置 ...

  7. 怎么处理Win10系统更新提示代码0x80070057的错误?

    在使用好系统重装助手重装了Win10系统后,由于每个用户的电脑配置不同,有些用户会在更新时出现0x80070057的错误代码.下面就教大家Win10系统更新出现0x80070057错误该怎么解决. W ...

  8. Python内存数据序列化到硬盘上哪家强

    1. 闲扯一下:文件 磁盘上的数据,我们一般称为 “文件” ,一般不同的文件都有各自的后缀名,比如 .txt .docx .xlsx .jpg .mp3 .avi .这些不同类型的文件一般分为两大类: ...

  9. flash多进程写操作

    1 应用场景介绍   硬件条件:使用stm32 MCU   软件条件:协议栈应用   协议栈简单介绍如下:   类似于OSI七层模型,所涉及的协议栈包括应用层,网络层,链路层,物理层,如下图:   在 ...

  10. 【转】5种网络IO模型

    5种网络IO模型(有图,很清楚) IO多路复用—由Redis的IO多路复用yinch Linux中对文件描述符的操作(FD_ZERO.FD_SET.FD_CLR.FD_ISSET