RNN教程之-2 LSTM实战
前言
说出来你们不敢相信,刚才码了半天的字,一个侧滑妈的全没了,都怪这Mac的触摸板太敏感沃日。好吧,不浪费时间了,前言一般都是废话,这个教程要解决的是一个LSTM的实战问题,很多人问我RNN是啥,有什么卵用,你可以看看我之前写的博客可以入门,但是如果你想实际操作代码,那么慢慢看这篇文章。本文章所有代码和数据集在我的Github Repository下载。
问题
给你一个数据集,只有一列数据,这是一个关于时间序列的数据,从这个时间序列中预测未来一年某航空公司的客运流量。
首先我们数据预览一下,用pandas读取数据,这里我们只需要使用后一列真实数据,如果你下载了数据,数据大概长这样:
time passengers
0 1949-01 112
1 1949-02 118
2 1949-03 132
3 1949-04 129
4 1949-05 121
5 1949-06 135
6 1949-07 148
7 1949-08 148
8 1949-09 136
9 1949-10 119
... ... ....
第一列是时间,第二列是客流量,为了看出这个我们要预测的客流量随时间的变化趋势,本大神教大家如何把趋势图画出来,接下来就非常牛逼了。用下面的代码来画图:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
df = pd.read_csv('international-airline-passengers.csv', sep=',')
df = df.set_index('time')
df['passengers'].plot()
plt.show()
这时候我们可以看到如下的趋势图:

可以看出,我们的数据存在一定的周期性,这个周期性并不是一个重复出现某个值,而是趋势的增长过程有一定的规律性,这个我们人肉眼就能看得出来,但是实际上计算机要识别这种规律就有一定的难度了,这时候就需要使用我们的LSTM大法。
好的,数据已经预览完了,接下来我们得思考一下怎么预测,怎么把数据处理为LSTM网络需要的格式。
LSTM数据预处理
这个过程非常重要,这也是很多水平不高的博客或者文章中没有具体阐述而导致普通读者不知道毛意思的过程,其实我可以这样简单的叙述,LSTM你不要以为各种时间序列搞的晕头转向,其实本质它还是神经网络,与普通的神经网络没有任何区别。我们接下来就用几行小代码把数据处理为我们需要的类似于神经网络输入的二维数据。
首先我们确确实实需要的只是一列数据:
df = pd.read_csv(file_name, sep=',', usecols=[1])
data_all = np.array(df).astype(float)
print(data_all)
输出是:
[[ 112.]
[ 118.]
[ 132.]
[ 129.]
[ 121.]
[ 135.]
[ 148.]
[ 148.]
[ 136.]
[ 119.]
[ 104.]
[ 118.]
[ 115.]
....
]
非常好,现在我们已经把我们需要的数据抠出来了,继续上面处理:
data = []
for i in range(len(data_all) - sequence_length - 1):
data.append(data_all[i: i + sequence_length + 1])
reshaped_data = np.array(data).astype('float64')
print(reshaped_data)
这时候你会发现好像结果看不懂,不知道是什么数据,如果你data_all处理时加ravel()(用来把数据最里面的中括号去掉),即:
df = pd.read_csv(file_name, sep=',', usecols=[1])
data_all = np.array(df).ravel().astype(float)
print(data_all)
那么数据输出一目了然:
[[ 112. 118. 132. ..., 136. 119. 104.]
[ 118. 132. 129. ..., 119. 104. 118.]
[ 132. 129. 121. ..., 104. 118. 115.]
...,
[ 362. 405. 417. ..., 622. 606. 508.]
[ 405. 417. 391. ..., 606. 508. 461.]
[ 417. 391. 419. ..., 508. 461. 390.]]
是的,没有错!一列数据经过我们这样不处理就可以作为LSTM网络的输入数据了,而且和神经网络没有什么两样!!牛逼吧?牛逼快去哥的Github Repo给个star,喊你们寝室的菜市场的大爷大妈都来赞!越多越好,快,哥的大牛之路就靠你们了!
然而这还是只是开始。。接下来要做的就是把数据切分为训练集和测试集:
split = 0.8
np.random.shuffle(reshaped_data)
x = reshaped_data[:, :-1]
y = reshaped_data[:, -1]
split_boundary = int(reshaped_data.shape[0] * split)
train_x = x[: split_boundary]
test_x = x[split_boundary:]
train_y = y[: split_boundary]
test_y = y[split_boundary:]
这些步骤相信聪明的你一点看得懂,我就不多废话了,我要说明的几点是,你运行时直接运行Github上的脚本代码,如果报错请私信我微信jintianiloveu
,我在代码中把过程包装成了函数所以文章中的代码可能不太一样。在实际代码中数据是需要归一化的,这个你应该知道,如何归一化代码中也有。
搭建LSTM模型
好,接下来是最牛逼的部分,也是本文章的核心内容(但实际内容并不多),数据有了,我们就得研究研究LSTM这个东东,不管理论上吹得多么牛逼,我只看它能不能解决问题,不管黑猫白猫,能抓到老鼠的就是好猫,像我们这样不搞伪学术注重经济效益的商人来说,这点尤为重要。搭建LSTM模型,我比较推荐使用keras,快速简单高效,分分钟,但是牺牲的是灵活性,不过话又说回来,真正的灵活性也是可以发挥的,只是要修改底层的东西那就有点麻烦了,我们反正是用它来解决问题的,更基础的部分我们就不研究了,以后有时间再慢慢深入。
在keras 的官方文档中,说了LSTM是整个Recurrent层实现的一个具体类,它需要的输入数据维度是:
形如
(samples,timesteps,input_dim)
的3D张量
发现没有,我们上面处理完了数据的格式就是(samples,timesteps)
这个time_step是我们采用的时间窗口,把一个时间序列当成一条长链,我们固定一个一定长度的窗口对这个长链进行采用,最终就得到上面的那个二维数据,那么我们缺少的是input_dim这个维度,实际上这个input_dim就是我们的那一列数据的数据,我们现在处理的是一列也有可能是很多列,一系列与时间有关的数据需要我们去预测,或者文本处理中会遇到。我们先不管那么多,先把数据处理为LSTM需要的格式:
train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
好的,这时候数据就是我们需要的啦。接下来搭建模型:
# input_dim是输入的train_x的最后一个维度,train_x的维度为(n_samples, time_steps, input_dim)
model = Sequential()
model.add(LSTM(input_dim=1, output_dim=50, return_sequences=True))
model.add(LSTM(100, return_sequences=False))
model.add(Dense(output_dim=1))
model.add(Activation('linear'))
model.compile(loss='mse', optimizer='rmsprop')
看到没,这个LSTM非常简单!!甚至跟输入的数据格式没有任何关系,只要输入数据的维度是1,就不需要修改模型的任何参数就可以把数据输入进去进行训练!
我们这里使用了两个LSTM进行叠加,第二个LSTM第一个参数指的是输入的维度,这和第一个LSTM的输出维度并不一样,这也是LSTM比较“随意”的地方。最后一层采用了线性层。
结果
预测的结果如下图所示:

这个结果还是非常牛逼啊,要知道我们的数据是打乱过得噢,也就是说泛化能力非常不错,厉害了word LSTM!
筒子们,本系列教程到此结束,欢迎再次登录老司机的飞船。。。。如果有不懂的私信我,想引起我的注意快去Github上给我star!!!
转:http://www.jianshu.com/p/5d6d5aac4dbd
RNN教程之-2 LSTM实战的更多相关文章
- NLP教程(5) - 语言模型、RNN、GRU与LSTM
作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/36 本文地址:http://www.showmeai.tech/article-det ...
- 【ASP.NET实战教程】ASP.NET实战教程大集合,各种项目实战集合
[ASP.NET实战教程]ASP.NET实战教程大集合,各种项目实战集合,希望大家可以好好学习教程中,有的比较老了,但是一直很经典!!!!论坛中很多小伙伴说.net没有实战教程学习,所以小编连夜搜集整 ...
- Docker最全教程——从理论到实战(八)
在本系列教程中,笔者希望将必要的知识点围绕理论.流程(工作流程).方法.实践来进行讲解,而不是单纯的为讲解知识点而进行讲解.也就是说,笔者希望能够让大家将理论.知识.思想和指导应用到工作的实际场景和实 ...
- Docker最全教程——从理论到实战(七)
在本系列教程中,笔者希望将必要的知识点围绕理论.流程(工作流程).方法.实践来进行讲解,而不是单纯的为讲解知识点而进行讲解.也就是说,笔者希望能够让大家将理论.知识.思想和指导应用到工作的实际场景和实 ...
- Docker最全教程——从理论到实战(六)
托管到腾讯云容器服务 托管到腾讯云容器服务,我们的公众号“magiccodes”已经发布了相关的录屏教程,大家可以结合本篇教程一起查阅. 自建还是托管? 在开始之前,我们先来讨论一个问题——是自建 ...
- Docker最全教程——从理论到实战(五)
往期内容链接 Docker最全教程——从理论到实战(一) Docker最全教程——从理论到实战(二) Docker最全教程——从理论到实战(三) Docker最全教程——从理论到实战(四) 本篇教程持 ...
- Docker最全教程——从理论到实战
Docker最全教程——从理论到实战(一) Docker最全教程——从理论到实战(二) Docker最全教程——从理论到实战(三) Docker最全教程——从理论到实战(四) Docker最全教程—— ...
- 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)
问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse) # 构建 ...
- 深度学习之循环神经网络RNN概述,双向LSTM实现字符识别
深度学习之循环神经网络RNN概述,双向LSTM实现字符识别 2. RNN概述 Recurrent Neural Network - 循环神经网络,最早出现在20世纪80年代,主要是用于时序数据的预测和 ...
随机推荐
- 【Android开发笔记】Android Splash Screen 启动界面
public class SplashActivity extends Activity { @Override protected void onCreate(Bundle savedInstanc ...
- Python学习笔记-day1(if流程控制)
在python中,流程控制语句为强制缩进(4空格) if username=='lmc' and password=='123456': print('Welcome User {name} logi ...
- [:space:]的用法(转)
转自:http://blog.itpub.net/27181165/viewspace-1061688/ 在linux中通常会使用shell结合正则表达式来过滤字符,本文将以一个简单的例子来说明+,* ...
- 本人常用的Phpstorm快捷键
我设置的是eclipse的按键风格(按键习惯),不是phpstorm的风格 1.添加TODO(这个不是快捷键)://TODO 后面是说明,换行写实现代码 2.选择相同单词做一次性修改:Alt+J+鼠标 ...
- mif文件生成方法
mif文件就是存储器初始化文件,即memory initialization file,用来配置RAM或ROM中的数据.常见生成方法: Quartus自带的mif编辑器生成 mif软件生成 高级编程语 ...
- tp3.2.3自定义全局函数的使用
全局函数的定义,好处就是我们可以跨文件使用,而且调用方式可以直接调用,十分方便,在这里做个小记录 1.在Application/Home/Common目录下面新建一个名为function.php的文件 ...
- 关于ffmpeg(libav)解码视频最后丢帧的问题
其实最初不是为了解决这个问题而来的,是Peter兄给我的提示解决另一个问题却让我误打误撞解决了另外一个问题之后也把这个隐藏了很久的bug找到(之前总是有一些特别短的视频产生不知所措还以为是视频素材本身 ...
- java—三大框架详解,其发展过程及掌握的Java技术慨括
Struts.Hibernate和Spring是我们Java开发中的常用关键,他们分别针对不同的应用场景给出最合适的解决方案.但你是否知道,这些知名框架最初是怎样产生的? 我们知道,传统的Java W ...
- Hbase各种查询总结
运用hbase好长时间了,今天利用闲暇时间把Hbase的各种查询总结下,以后有时间把协处理器和自定义File总结下. 查询条件分为: 1.统计表数据 2,hbase 简单分页 3,like 查询 4 ...
- python_27_多级字典嵌套及操作
#key-value 字典无下标 所以乱序,key值尽量不要取中文 person_log={ '大二':{ 'Ya Nan':['free','cute','soso'], 'Sha sha':['微 ...