参考:

https://blog.csdn.net/u012735708/article/details/82769711

https://zybuluo.com/hanbingtao/note/581764

http://blog.sina.com.cn/s/blog_afc8730e0102xup1.html

https://blog.csdn.net/qq_30638831/article/details/80060045

执行代码:

  1. import pandas as pd
  2. from datetime import datetime
  3. from matplotlib import pyplot
  4. from sklearn.preprocessing import LabelEncoder,MinMaxScaler
  5. from sklearn.metrics import mean_squared_error
  6. from keras.models import Sequential
  7. from keras.layers import Dense
  8. from keras.layers import LSTM
  9. from numpy import concatenate
  10. from math import sqrt
  11.  
  12. # load data
  13. def parse(x):
  14. return datetime.strptime(x, '%Y %m %d %H')
  15.  
  16. def read_raw():
  17. dataset = pd.read_csv('C:/Users/cf_pc/Documents/jupyter/data/PRSA_data_2010.1.1-2014.12.31.csv', parse_dates = [['year', 'month', 'day', 'hour']], index_col=0, date_parser=parse)
  18. dataset.drop('No', axis=1, inplace=True)
  19. # manually specify column names
  20. dataset.columns = ['pollution', 'dew', 'temp', 'press', 'wnd_dir', 'wnd_spd', 'snow', 'rain']
  21. dataset.index.name = 'date'
  22. # mark all NA values with 0
  23. dataset['pollution'].fillna(0, inplace=True)
  24. # drop the first 24 hours
  25. dataset = dataset[24:]
  26. # summarize first 5 rows
  27. print(dataset.head(5))
  28. # save to file
  29. dataset.to_csv('C:/Users/cf_pc/Documents/jupyter/data/pollution.csv')
  30.  
  31. def drow_pollution():
  32. dataset = pd.read_csv('C:/Users/cf_pc/Documents/jupyter/data/pollution.csv', header=0, index_col=0)
  33. values = dataset.values
  34. # specify columns to plot
  35. groups = [0, 1, 2, 3, 5, 6, 7]
  36. i = 1
  37. # plot each column
  38. pyplot.figure(figsize=(10,10))
  39. for group in groups:
  40. pyplot.subplot(len(groups), 1, i)
  41. pyplot.plot(values[:, group])
  42. pyplot.title(dataset.columns[group], y=0.5, loc='right')
  43. i += 1
  44. pyplot.show()
  45.  
  46. def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):
  47. # convert series to supervised learning
  48. n_vars = 1 if type(data) is list else data.shape[1]
  49. df = pd.DataFrame(data)
  50. cols, names = list(), list()
  51. # input sequence (t-n, ... t-1)
  52. for i in range(n_in, 0, -1):
  53. cols.append(df.shift(i))
  54. names += [('var%d(t-%d)' % (j+1, i)) for j in range(n_vars)]
  55. # forecast sequence (t, t+1, ... t+n)
  56. for i in range(0, n_out):
  57. cols.append(df.shift(-i))
  58. if i == 0:
  59. names += [('var%d(t)' % (j+1)) for j in range(n_vars)]
  60. else:
  61. names += [('var%d(t+%d)' % (j+1, i)) for j in range(n_vars)]
  62. # put it all together
  63. agg = pd.concat(cols, axis=1)
  64. agg.columns = names
  65. # drop rows with NaN values
  66. if dropnan:
  67. agg.dropna(inplace=True)
  68. return agg
  69.  
  70. def cs_to_sl():
  71. # load dataset
  72. dataset = pd.read_csv('C:/Users/cf_pc/Documents/jupyter/data/pollution.csv', header=0, index_col=0)
  73. values = dataset.values
  74. # integer encode direction
  75. encoder = LabelEncoder()
  76. values[:,4] = encoder.fit_transform(values[:,4])
  77. # ensure all data is float
  78. values = values.astype('float32')
  79. # normalize features
  80. scaler = MinMaxScaler(feature_range=(0, 1))
  81. scaled = scaler.fit_transform(values)
  82. # frame as supervised learning
  83. reframed = series_to_supervised(scaled, 1, 1)
  84. # drop columns we don't want to predict
  85. reframed.drop(reframed.columns[[9,10,11,12,13,14,15]], axis=1, inplace=True)
  86. print(reframed.head())
  87. return reframed,scaler
  88.  
  89. def train_test(reframed):
  90. # split into train and test sets
  91. values = reframed.values
  92. n_train_hours = 365 * 24
  93. train = values[:n_train_hours, :]
  94. test = values[n_train_hours:, :]
  95. # split into input and outputs
  96. train_X, train_y = train[:, :-1], train[:, -1]
  97. test_X, test_y = test[:, :-1], test[:, -1]
  98. # reshape input to be 3D [samples, timesteps, features]
  99. train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
  100. test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
  101. print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)
  102. return train_X,train_y,test_X,test_y
  103.  
  104. def fit_network(train_X,train_y,test_X,test_y,scaler):
  105. model = Sequential()
  106. model.add(LSTM(50, input_shape=(train_X.shape[1], train_X.shape[2])))
  107. model.add(Dense(1))
  108. model.compile(loss='mae', optimizer='adam')
  109. # fit network
  110. history = model.fit(train_X, train_y, epochs=50, batch_size=72, validation_data=(test_X, test_y), verbose=2, shuffle=False)
  111. # plot history
  112. pyplot.plot(history.history['loss'], label='train')
  113. pyplot.plot(history.history['val_loss'], label='test')
  114. pyplot.legend()
  115. pyplot.show()
  116. # make a prediction
  117. yhat = model.predict(test_X)
  118. test_X = test_X.reshape((test_X.shape[0], test_X.shape[2]))
  119. # invert scaling for forecast
  120. inv_yhat = concatenate((yhat, test_X[:, 1:]), axis=1)
  121. inv_yhat = scaler.inverse_transform(inv_yhat)
  122. inv_yhat = inv_yhat[:,0]
  123. # invert scaling for actual
  124. inv_y = scaler.inverse_transform(test_X)
  125. inv_y = inv_y[:,0]
  126. # calculate RMSE
  127. rmse = sqrt(mean_squared_error(inv_y, inv_yhat))
  128. print('Test RMSE: %.3f' % rmse)
  129.  
  130. if __name__ == '__main__':
  131. drow_pollution()
  132. reframed,scaler = cs_to_sl()
  133. train_X,train_y,test_X,test_y = train_test(reframed)
  134. fit_network(train_X,train_y,test_X,test_y,scaler)

返回信息:

  1. var1(t-1) var2(t-1) var3(t-1) var4(t-1) var5(t-1) var6(t-1) \
  2. 1 0.129779 0.352941 0.245902 0.527273 0.666667 0.002290
  3. 2 0.148893 0.367647 0.245902 0.527273 0.666667 0.003811
  4. 3 0.159960 0.426471 0.229508 0.545454 0.666667 0.005332
  5. 4 0.182093 0.485294 0.229508 0.563637 0.666667 0.008391
  6. 5 0.138833 0.485294 0.229508 0.563637 0.666667 0.009912
  7.  
  8. var7(t-1) var8(t-1) var1(t)
  9. 1 0.000000 0.0 0.148893
  10. 2 0.000000 0.0 0.159960
  11. 3 0.000000 0.0 0.182093
  12. 4 0.037037 0.0 0.138833
  13. 5 0.074074 0.0 0.109658
  14. (8760, 1, 8) (8760,) (35039, 1, 8) (35039,)
  15. WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
  16. Instructions for updating:
  17. Colocations handled automatically by placer.
  18. WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
  19. Instructions for updating:
  20. Use tf.cast instead.
  21. Train on 8760 samples, validate on 35039 samples
  22. Epoch 1/50
  23. - 2s - loss: 0.0578 - val_loss: 0.0562
  24. Epoch 2/50
  25. - 1s - loss: 0.0413 - val_loss: 0.0563
  26. Epoch 3/50
  27. - 1s - loss: 0.0254 - val_loss: 0.0454
  28. Epoch 4/50
  29. - 1s - loss: 0.0179 - val_loss: 0.0388
  30. Epoch 5/50
  31. - 1s - loss: 0.0158 - val_loss: 0.0237
  32. Epoch 6/50
  33. - 1s - loss: 0.0149 - val_loss: 0.0175
  34. Epoch 7/50
  35. - 1s - loss: 0.0148 - val_loss: 0.0163
  36. Epoch 8/50
  37. - 1s - loss: 0.0147 - val_loss: 0.0160
  38. Epoch 9/50
  39. - 1s - loss: 0.0148 - val_loss: 0.0155
  40. Epoch 10/50
  41. - 1s - loss: 0.0147 - val_loss: 0.0151
  42. Epoch 11/50
  43. - 1s - loss: 0.0146 - val_loss: 0.0148
  44. Epoch 12/50
  45. - 1s - loss: 0.0147 - val_loss: 0.0145
  46. Epoch 13/50
  47. - 1s - loss: 0.0146 - val_loss: 0.0143
  48. Epoch 14/50
  49. - 1s - loss: 0.0146 - val_loss: 0.0143
  50. Epoch 15/50
  51. - 1s - loss: 0.0145 - val_loss: 0.0141
  52. Epoch 16/50
  53. - 1s - loss: 0.0145 - val_loss: 0.0144
  54. Epoch 17/50
  55. - 1s - loss: 0.0147 - val_loss: 0.0140
  56. Epoch 18/50
  57. - 1s - loss: 0.0145 - val_loss: 0.0140
  58. Epoch 19/50
  59. - 1s - loss: 0.0145 - val_loss: 0.0138
  60. Epoch 20/50
  61. - 1s - loss: 0.0145 - val_loss: 0.0138
  62. Epoch 21/50
  63. - 1s - loss: 0.0144 - val_loss: 0.0138
  64. Epoch 22/50
  65. - 1s - loss: 0.0145 - val_loss: 0.0138
  66. Epoch 23/50
  67. - 1s - loss: 0.0146 - val_loss: 0.0137
  68. Epoch 24/50
  69. - 1s - loss: 0.0144 - val_loss: 0.0137
  70. Epoch 25/50
  71. - 1s - loss: 0.0144 - val_loss: 0.0137
  72. Epoch 26/50
  73. - 1s - loss: 0.0144 - val_loss: 0.0136
  74. Epoch 27/50
  75. - 1s - loss: 0.0144 - val_loss: 0.0136
  76. Epoch 28/50
  77. - 1s - loss: 0.0144 - val_loss: 0.0136
  78. Epoch 29/50
  79. - 1s - loss: 0.0145 - val_loss: 0.0137
  80. Epoch 30/50
  81. - 1s - loss: 0.0145 - val_loss: 0.0136
  82. Epoch 31/50
  83. - 1s - loss: 0.0144 - val_loss: 0.0137
  84. Epoch 32/50
  85. - 1s - loss: 0.0144 - val_loss: 0.0136
  86. Epoch 33/50
  87. - 1s - loss: 0.0144 - val_loss: 0.0136
  88. Epoch 34/50
  89. - 1s - loss: 0.0145 - val_loss: 0.0136
  90. Epoch 35/50
  91. - 1s - loss: 0.0144 - val_loss: 0.0135
  92. Epoch 36/50
  93. - 1s - loss: 0.0144 - val_loss: 0.0135
  94. Epoch 37/50
  95. - 1s - loss: 0.0144 - val_loss: 0.0135
  96. Epoch 38/50
  97. - 1s - loss: 0.0144 - val_loss: 0.0135
  98. Epoch 39/50
  99. - 1s - loss: 0.0144 - val_loss: 0.0135
  100. Epoch 40/50
  101. - 1s - loss: 0.0144 - val_loss: 0.0135
  102. Epoch 41/50
  103. - 1s - loss: 0.0143 - val_loss: 0.0135
  104. Epoch 42/50
  105. - 1s - loss: 0.0144 - val_loss: 0.0135
  106. Epoch 43/50
  107. - 1s - loss: 0.0144 - val_loss: 0.0135
  108. Epoch 44/50
  109. - 1s - loss: 0.0144 - val_loss: 0.0135
  110. Epoch 45/50
  111. - 1s - loss: 0.0144 - val_loss: 0.0137
  112. Epoch 46/50
  113. - 1s - loss: 0.0144 - val_loss: 0.0136
  114. Epoch 47/50
  115. - 1s - loss: 0.0143 - val_loss: 0.0135
  116. Epoch 48/50
  117. - 1s - loss: 0.0144 - val_loss: 0.0136
  118. Epoch 49/50
  119. - 1s - loss: 0.0143 - val_loss: 0.0135
  120. Epoch 50/50
  121. - 1s - loss: 0.0144 - val_loss: 0.0134

  1. Test RMSE: 4.401

参考:

https://www.cnblogs.com/tianrunzhi/p/7825671.html

https://www.cnblogs.com/king-lps/p/7846414.html

https://www.cnblogs.com/datablog/p/6127000.html

https://www.cnblogs.com/charlotte77/p/5622325.html

https://www.cnblogs.com/bawu/p/7701810.html

Keras入门——(6)长短期记忆网络LSTM(三)的更多相关文章

  1. 如何预测股票分析--长短期记忆网络(LSTM)

    在上一篇中,我们回顾了先知的方法,但是在这个案例中表现也不是特别突出,今天介绍的是著名的l s t m算法,在时间序列中解决了传统r n n算法梯度消失问题的的它这一次还会有令人杰出的表现吗? 长短期 ...

  2. Keras入门——(7)长短期记忆网络LSTM(四)

    数据准备:http://www.manythings.org/anki/cmn-eng.zip 源代码:https://github.com/pjgao/seq2seq_keras 参考:https: ...

  3. Keras入门——(5)长短期记忆网络LSTM(二)

    参考: https://blog.csdn.net/zwqjoy/article/details/80493341 https://blog.csdn.net/u012735708/article/d ...

  4. Keras入门——(4)长短期记忆网络LSTM(一)

    参考: https://blog.csdn.net/zwqjoy/article/details/80493341 https://blog.csdn.net/u012735708/article/d ...

  5. LSTM - 长短期记忆网络

    循环神经网络(RNN) 人们不是每一秒都从头开始思考,就像你阅读本文时,不会从头去重新学习一个文字,人类的思维是有持续性的.传统的卷积神经网络没有记忆,不能解决这一个问题,循环神经网络(Recurre ...

  6. 递归神经网络之理解长短期记忆网络(LSTM NetWorks)(转载)

    递归神经网络 人类并不是每时每刻都从头开始思考.正如你阅读这篇文章的时候,你是在理解前面词语的基础上来理解每个词.你不会丢弃所有已知的信息而从头开始思考.你的思想具有持续性. 传统的神经网络不能做到这 ...

  7. 理解长短期记忆网络(LSTM NetWorks)

    转自:http://www.csdn.net/article/2015-11-25/2826323 原文链接:Understanding LSTM Networks(译者/刘翔宇 审校/赵屹华 责编/ ...

  8. LSTMs 长短期记忆网络系列

    RNN的长期依赖问题 什么是长期依赖? 长期依赖是指当前系统的状态,可能受很长时间之前系统状态的影响,是RNN中无法解决的一个问题. 如果从(1) “ 这块冰糖味道真?”来预测下一个词,是很容易得出“ ...

  9. LSTM(Long Short-Term Memory)长短期记忆网络

    1. 摘要 对于RNN解决了之前信息保存的问题,例如,对于阅读一篇文章,RNN网络可以借助前面提到的信息对当前的词进行判断和理解,这是传统的网络是不能做到的.但是,对于RNN网络存在长期依赖问题,比如 ...

随机推荐

  1. go语言快速入门教程

    go快速入门指南 by 小强,2019-06-13 go语言是目前非常火热的语言,广泛应用于服务器端,云计算,kubernetes容器编排等领域.它是一种开源的编译型程序设计语言,支持并发.垃圾回收机 ...

  2. Linux - 查看所有服务状态

    ubuntu: service --status-all 例如可查看ssh, apache2等服务是否开启

  3. .NET中的字符串(2):你真的了解.NET中的String吗?

    概述 String在任何语言中,都有它的特殊性,在.NET中也是如此.它属于基本数据类型,也是基本数据类型中唯一的引用类型.字符串可以声明为常量,但是它却放在了堆中.希望通过本文能够使大家对.NET中 ...

  4. Linux学习:进入与退出系统

    进入Linux系统:必须要输入用户的账号,在系统安装过程中可以创建以下两种帐号: 1.root--超级用户帐号(系统管理员),使用这个帐号可以在系统中做任何事情. 2.普通用户--这个帐号供普通用户使 ...

  5. pytest-conftest.py作用范围

    1.conftest.py解释 conftest.py是pytest框架里面一个很重要的东西,它可以在这个文件里面编写fixture,而这个fixture的作用就相当于我们unittest框架里面的s ...

  6. string和stringBuffer,stringBuilder的区别

    1,String类的内容一旦声明后是不可改变的,改变的只是其内存的指向,而StringBuffer类的对象内容是可以改变的. 2,对于StringBuffer,不能像String那样直接通过赋值的方式 ...

  7. 思科交换机配置单播MAC地址过滤

    1.其他厂商: 在华为,华三等设备上,我们都有“黑洞MAC地址表项” 的配置,其特点是手动配置.不会老化,且重启后也不会丢失.例如如下示例: 黑洞表项是特殊的静态MAC地址表项,丢弃含有特定源MAC地 ...

  8. ConcurrentHashMap 实现缓存类

    参考:https://blog.csdn.net/woshilijiuyi/article/details/81335497 在规定时间内,使用 hashMap 实现一个缓存工具类,需要考虑一下几点 ...

  9. 阅读build to win的个人感想

    一个程序员要向各个方面学习,向市场.向用户学习等,不能局限于一方面.除此以外还要有自己的想法,要懂得创新,也需要在各个方面都有所突破,有所超越,实力才是取得胜利的根关键.

  10. 「CTSC2008」网络管理

    「CTSC2008」网络管理 传送门 整体二分做法,应该和这题一样的吧. 就是把序列换成树,第 \(k\) 小换成第 \(k\) 大. 然后就切了... 参考代码: #include <algo ...