Tensorflow RNN_LSTM实例
RNN的一种类型模型被称为长短期记忆网络(LSTM)。我觉得这是一个有趣的名字。它听起来也意味着:短期模式长期不会被遗忘。
LSTM的精确实现细节不在本文的范围之内。相信我,如果只学习LSTM模型会分散我们的注意力,因为它还没有确定的标准
所示。
:导入相关库
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
所示,构造函数里面设置模型超参数,权重和成本函数。
:定义一个类及其构造函数
class SeriesPredictor:
def __init__(self, input_dim, seq_size, hidden_dim=10):
self.input_dim = input_dim //#A
self.seq_size = seq_size //#A
self.hidden_dim = hidden_dim //#A
self.W_out = tf.Variable(tf.random_normal([hidden_dim, 1]),name='W_out') //#B
self.b_out = tf.Variable(tf.random_normal([1]), name='b_out') //#B
self.x = tf.placeholder(tf.float32, [None, seq_size, input_dim]) //#B
self.y = tf.placeholder(tf.float32, [None, seq_size]) //#B
self.cost = tf.reduce_mean(tf.square(self.model() - self.y)) //#C
self.train_op = tf.train.AdamOptimizer().minimize(self.cost) //#C
self.saver = tf.train.Saver() //#D
#A超参数。
#B权重变量和输入占位符。
#C成本优化器(cost optimizer)。
#D辅助操作
详细介绍了如何使用TensorFlow来实现使用LSTM的预测模型。
:定义RNN模型
def model(self):
"""
:param x: inputs of size [T, batch_size, input_size]
:param W: matrix of fully-connected output layer weights
:param b: vector of fully-connected output layer biases
"""
cell = rnn.BasicLSTMCell(self.hidden_dim) #A
outputs, states = tf.nn.dynamic_rnn(cell, self.x, dtype=tf.float32) #B
num_examples = tf.shape(self.x)[0]
W_repeated = tf.tile(tf.expand_dims(self.W_out, 0), [num_examples, 1, 1])#C
out = tf.matmul(outputs, W_repeated) + self.b_out
out = tf.squeeze(out)
return out
#A创建一个LSTM单元。
#B运行输入单元,获取输出和状态的张量。
#C将输出层计算为完全连接的线性函数。
所示,你打开会话并重复运行优化器。
另外,你可以使用交叉验证来确定训练模型的迭代次数。在这里我们假设固定数量的epocs。
训练后,将模型保存到文件中,以便稍后加载使用。
:在一个数据集上训练模型
def train(self, train_x, train_y):
with tf.Session() as sess:
tf.get_variable_scope().reuse_variables()
sess.run(tf.global_variables_initializer())
for i in range(1000): #A
mse = sess.run([self.train_op, self.cost], feed_dict={self.x: train_x, self.y: train_y})
if i % 100 == 0:
print(i, mse)
save_path = self.saver.save(sess, 'model.ckpt')
print('Model saved to {}'.format(save_path))
次
加载已保存的模型,并通过馈送一些测试数据以此来运行模型。如果学习的模型在测试数据上表现不佳,那么我们可以尝试调整LSTM单元格的隐藏维数
:测试学习的模型
def test(self, test_x):
with tf.Session() as sess:
tf.get_variable_scope().reuse_variables()
self.saver.restore(sess, './model.ckpt')
output = sess.run(self.model(), feed_dict={self.x: test_x})
print(output)
中,我们将创建输入序列,称为train_x,和相应的输出序列,称为train_y。
训练并测试一些虚拟数据
if __name__ == '__main__':
predictor = SeriesPredictor(input_dim=1, seq_size=4, hidden_dim=10)
train_x = [[[1], [2], [5], [6]],
[[5], [7], [7], [8]],
[[3], [4], [5], [7]]]
train_y = [[1, 3, 7, 11],
[5, 12, 14, 15],
[3, 7, 9, 12]]
predictor.train(train_x, train_y)
test_x = [[[1], [2], [3], [4]], #A
[[4], [5], [6], [7]]] #B
predictor.test(test_x)
,3,5,7。
,9,11,13。
你可以将此预测模型视为黑盒子,并用现实世界的时间数据进行测试。
Tensorflow RNN_LSTM实例的更多相关文章
- Mac tensorflow mnist实例
Mac tensorflow mnist实例 前期主要需要安装好tensorflow的环境,Mac 如果只涉及到CPU的版本,推荐使用pip3,傻瓜式安装,一行命令!代码使用python3. 在此附上 ...
- 深度学习之卷积神经网络CNN及tensorflow代码实例
深度学习之卷积神经网络CNN及tensorflow代码实例 什么是卷积? 卷积的定义 从数学上讲,卷积就是一种运算,是我们学习高等数学之后,新接触的一种运算,因为涉及到积分.级数,所以看起来觉得很复杂 ...
- Forward-backward梯度求导(tensorflow word2vec实例)
考虑不可分的例子 通过使用basis functions 使得不可分的线性模型变成可分的非线性模型 最常用的就是写出一个目标函数 并且使用梯度下降法 来计算 梯度的下降法的梯度 ...
- TensorFlow 简单实例
TF 手写体识别简单实例: TensorFlow很适合用来进行大规模的数值计算,其中也包括实现和训练深度神经网络模型.下面将介绍TensorFlow中模型的基本组成部分,同时将构建一个CNN模型来对M ...
- 条件随机场(crf)及tensorflow代码实例
对于条件随机场的学习,我觉得应该结合HMM模型一起进行对比学习.首先浏览HMM模型:https://www.cnblogs.com/pinking/p/8531405.html 一.定义 条件随机场( ...
- 关于深度学习之TensorFlow简单实例
1.对TensorFlow的基本操作 import tensorflow as tf import os os.environ[" a=tf.constant(2) b=tf.constan ...
- TensorFlow 基本使用
使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使用 ten ...
- [学习笔记] TensorFlow 入门之基本使用
整体介绍 使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使 ...
- 【Tensorflow】Tensorflow入门教程
基本使用 使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使 ...
随机推荐
- swift - 3D 视图,截图,关键字搜索
1.xib 上的 3D效果 按钮 2. import UIKit //1.导入框架 import MapKit class ViewController: UIViewController { @IB ...
- python基础易错题
1.以下代码输入什么: class Person: a = 1 def __init__(self): pass def getAge(self): print(__name__) p = Perso ...
- 如何在Windows下安装Tomcat服务器
Tomcat 服务器是一个免费的开放源代码的Web 应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP 程序的首选服务器.在Windows下安装 ...
- FoxMail提示:请求的名称有效,但是找不到请求的类型的数据
FoxMail发送或者接收邮件的时候,提示如下信息: <错误信息:请求的名称有效,但是找不到请求的类型的数据> 一,DNS解析不稳定 解决办法:修改本地电脑上面本地连接中的DNS地址< ...
- Laravel Relationship Events
Laravel Relationship Events is a package by Viacheslav Ostrovskiy that adds extra model relationship ...
- MySQLdb与sqlalchemy的简单封装
一:MySQLdb # !/usr/bin/python # -*- coding: UTF-8 -*- import MySQLdb import MySQLdb.cursors import co ...
- js 光标位置处理
/** * 获取选中文字 * 返回selection,toString可拿到结果,selection含有起始光标位置信息等 **/ function getSelectText() { var tex ...
- android源码下载/查看地址
源码下载: http://git.omapzoom.org/ 高通平台android源码下载地址: https://www.codeaurora.org/xwiki/bin/QAEP/WebHome ...
- spring学习 十二 AspectJ-based的通知入门 带参数的通知
第一步:编写通知类 package com.airplan.pojo; import org.aspectj.lang.ProceedingJoinPoint; public class Advice ...
- codeforces C. Functions again
题意:给定了一个公式,让你找到一对(l,r),求解出公式给定的F值. 当时没有想到,我把(-1)^(i-l)看成(-1)^i,然后思路就完全错了.其实这道题是个简单的dp+最长连续子序列. O(n)求 ...