RNN的一种类型模型被称为长短期记忆网络(LSTM)。我觉得这是一个有趣的名字。它听起来也意味着:短期模式长期不会被遗忘。

LSTM的精确实现细节不在本文的范围之内。相信我,如果只学习LSTM模型会分散我们的注意力,因为它还没有确定的标准

 

所示。

:导入相关库

import numpy as np

import tensorflow as tf

from tensorflow.contrib import rnn

所示,构造函数里面设置模型超参数,权重和成本函数。

:定义一个类及其构造函数

class SeriesPredictor:

def __init__(self, input_dim, seq_size, hidden_dim=10):

self.input_dim = input_dim //#A

self.seq_size = seq_size //#A

self.hidden_dim = hidden_dim //#A

self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]),name='W_out') //#B

self.b_out = tf.Variable(tf.random_normal([1]), name='b_out') //#B

self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim]) //#B

self.y = tf.placeholder(tf.float32, [None, seq_size]) //#B

self.cost = tf.reduce_mean(tf.square(self.model() - self.y)) //#C

self.train_op = tf.train.AdamOptimizer().minimize(self.cost) //#C

self.saver = tf.train.Saver() //#D

#A超参数。

#B权重变量和输入占位符。

#C成本优化器(cost optimizer)。

#D辅助操作

详细介绍了如何使用TensorFlow来实现使用LSTM的预测模型。

:定义RNN模型

def model(self):

"""

:param x: inputs of size [T, batch_size, input_size]

:param W: matrix of fully-connected output layer weights

:param b: vector of fully-connected output layer biases

"""

cell = rnn.BasicLSTMCell(self.hidden_dim) #A

outputs, states = tf.nn.dynamic_rnn(cell, self.x, dtype=tf.float32) #B

num_examples = tf.shape(self.x)[0]

W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1])#C

out = tf.matmul(outputs, W_repeated) + self.b_out

out = tf.squeeze(out)

return out

#A创建一个LSTM单元。

#B运行输入单元,获取输出和状态的张量。

#C将输出层计算为完全连接的线性函数。

所示,你打开会话并重复运行优化器。

另外,你可以使用交叉验证来确定训练模型的迭代次数。在这里我们假设固定数量的epocs。

训练后,将模型保存到文件中,以便稍后加载使用。

:在一个数据集上训练模型

def train(self, train_x, train_y):

with tf.Session() as sess:

tf.get_variable_scope().reuse_variables()

sess.run(tf.global_variables_initializer())

for i in range(1000): #A

mse = sess.run([self.train_op, self.cost], feed_dict={self.x: train_x, self.y: train_y})

if i % 100 == 0:

print(i, mse)

save_path = self.saver.save(sess, 'model.ckpt')

print('Model saved to {}'.format(save_path))

加载已保存的模型,并通过馈送一些测试数据以此来运行模型。如果学习的模型在测试数据上表现不佳,那么我们可以尝试调整LSTM单元格的隐藏维数

:测试学习的模型

def test(self, test_x):

with tf.Session() as sess:

tf.get_variable_scope().reuse_variables()

self.saver.restore(sess, './model.ckpt')

output = sess.run(self.model(), feed_dict={self.x: test_x})

print(output)

中,我们将创建输入序列,称为train_x,和相应的输出序列,称为train_y。

训练并测试一些虚拟数据

if __name__ == '__main__':

predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)

train_x = [[[1], [2], [5], [6]],

[[5], [7], [7], [8]],

[[3], [4], [5], [7]]]

train_y = [[1, 3, 7, 11],

[5, 12, 14, 15],

[3, 7, 9, 12]]

predictor.train(train_x, train_y)

test_x = [[[1], [2], [3], [4]], #A

[[4], [5], [6], [7]]] #B

predictor.test(test_x)

,3,5,7。

,9,11,13。

你可以将此预测模型视为黑盒子,并用现实世界的时间数据进行测试。

Tensorflow RNN_LSTM实例的更多相关文章

  1. Mac tensorflow mnist实例

    Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...

  2. 深度学习之卷积神经网络CNN及tensorflow代码实例

    深度学习之卷积神经网络CNN及tensorflow代码实例 什么是卷积? 卷积的定义 从数学上讲,卷积就是一种运算,是我们学习高等数学之后,新接触的一种运算,因为涉及到积分.级数,所以看起来觉得很复杂 ...

  3. Forward-backward梯度求导(tensorflow word2vec实例)

    考虑不可分的例子         通过使用basis functions 使得不可分的线性模型变成可分的非线性模型 最常用的就是写出一个目标函数 并且使用梯度下降法 来计算     梯度的下降法的梯度 ...

  4. TensorFlow 简单实例

    TF 手写体识别简单实例: TensorFlow很适合用来进行大规模的数值计算,其中也包括实现和训练深度神经网络模型.下面将介绍TensorFlow中模型的基本组成部分,同时将构建一个CNN模型来对M ...

  5. 条件随机场(crf)及tensorflow代码实例

    对于条件随机场的学习,我觉得应该结合HMM模型一起进行对比学习.首先浏览HMM模型:https://www.cnblogs.com/pinking/p/8531405.html 一.定义 条件随机场( ...

  6. 关于深度学习之TensorFlow简单实例

    1.对TensorFlow的基本操作 import tensorflow as tf import os os.environ[" a=tf.constant(2) b=tf.constan ...

  7. TensorFlow 基本使用

    使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使用 ten ...

  8. [学习笔记] TensorFlow 入门之基本使用

    整体介绍 使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使 ...

  9. 【Tensorflow】Tensorflow入门教程

    基本使用 使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使 ...

随机推荐

  1. linux命令学习之:sed

    sed:Stream Editor文本流编辑,sed是一个“非交互式的”面向字符流的编辑器.能同时处理多个文件多行的内容,可以不对原文件改动,把整个文件输入到屏幕,可以把只匹配到模式的内容输入到屏幕上 ...

  2. 编程,计算data段中的第一组数据的3次方,结果保存在后面一组dword单元中

    assume cs:code data segment dw ,,,,,,, dd ,,,,,,, data ends code segment start: mov ax,data mov ds,a ...

  3. Java 对象 引用,equal == string

    以前确实一直没注意这个概念,这次看了帖子才知道. 转载于:https://zwmf.iteye.com/blog/1738574 Java对象及其引用 关于对象与引用之间的一些基本概念. 初学Java ...

  4. 关于map::erase的使用说明

    C++ 中经常使用的容器类有vector,list,map.其中vector和list的erase都是返回迭代器,但是map就比较不一样. 当在循环体中使用map::erase语句时,为了能够在任何机 ...

  5. stark组件前戏之项目启动前加载指定文件

    1. django项目启动时, 自定制执行某个py文件 dajngo 启动时.会将所有 路由加载到内存中. 我的目的就是在 路由加载之前,执行某个py文件. 每个app中都有一个 apps.py fr ...

  6. JwtBearerAppBuilderExtensions.UseJwtBearerAuthentication(IApplicationBuilder

    netcore从1.1升级到2.0时,出的错,因为使用的时Jwt token参考https://github.com/aspnet/Security/issues/1310#issuecomment- ...

  7. python 面向对象编程 之 单例模式

    单例模式三种实现方式: 单例模式:单例模式是解决系统资源浪费的一种方案,是指一个类实例化后可以多次使用此对象. 单例模式应用场景:数据库操作.日志.后台打印 # settings.py# Host=' ...

  8. stl之容器、迭代器、算法几者之间的关系

    转自:https://blog.csdn.net/bobodem/article/details/49386131 stl包括容器.迭代器和算法: 容器 用于管理一些相关的数据类型.每种容器都有它的优 ...

  9. Invalid character found in the request target. The valid characters are defined in RFC 7230 and RFC 3986

    最近在Tomcat上配置一个项目,在点击一个按钮,下载一个文件的时候,老是会报上面的错误.试了很多方法,如对server.xml文件中,增加MaxHttpHeaderSize的大小,改端口,改Tomc ...

  10. 简单使用DESeq做差异分析

    简单使用DESeq做差异分析 Posted: 五月 06, 2017  Under: Transcriptomics  By Kai  no Comments DESeq这个R包主要针对count d ...