作者:韩信子@ShowMeAI

深度学习实战系列https://www.showmeai.tech/tutorials/42

TensorFlow 实战系列https://www.showmeai.tech/tutorials/43

本文地址https://www.showmeai.tech/article-detail/327

声明:版权所有,转载请联系平台与作者并注明出处

收藏ShowMeAI查看更多精彩内容

股票价格数据是一个时间序列形态的数据,诚然,股市的涨落和各种利好利空消息更相关,更多体现的是人们的信心状况,但是它的形态下,时序前后是有一定的相关性的,我们可以使用一种特殊类型的神经网络『循环神经网络 (RNN)』来对这种时序相关的数据进行建模和学习。

在本篇内容中,ShowMeAI将给大家演示,如何构建训练神经网络并将其应用在股票数据上进行预测。

对于循环神经网络的详细信息讲解,大家可以阅读ShowMeAI整理的系列教程和文章详细了解:

数据获取

在实际建模与训练之前,我们需要先获取股票数据。下面的代码使用 Ameritrade API 获取并生成数据,也可以使用其他来源。

  1. import matplotlib.pyplot as plt
  2. import mplfinance as mpl
  3. import pandas as pd
  4. td_consumer_key = 'YOUR-KEY-HERE'
  5. # 美国航空股票
  6. ticker = 'AAL'
  7. ##periodType - day, month, year, ytd
  8. ##period - number of periods to show
  9. ##frequencyTYpe - type of frequency for each candle - day, month, year, ytd
  10. ##frequency - the number of the frequency type in each candle - minute, daily, weekly
  11. endpoint = 'https://api.tdameritrade.com/v1/marketdata/{stock_ticker}/pricehistory?periodType={periodType}&period={period}&frequencyType={frequencyType}&frequency={frequency}'
  12. # 获取数据
  13. full_url = endpoint.format(stock_ticker=ticker,periodType='year',period=10,frequencyType='daily',frequency=1)
  14. page = requests.get(url=full_url,params={'apikey' : td_consumer_key})
  15. content = json.loads(page.content)
  16. # 转成pandas可处理格式
  17. df = pd.json_normalize(content['candles'])
  18. # 设置时间戳为索引
  19. df['timestamp'] = pd.to_datetime(df.datetime, unit='ms')
  20. df = df.set_index("timestamp")
  21. # 绘制数据
  22. plt.figure(figsize=(15, 6), dpi=80)
  23. plt.plot(df['close'])
  24. plt.legend(['Closing Price'])
  25. plt.show()
  26. # 存储前一天的数据
  27. df["previous_close"] = df["close"].shift(1)
  28. df = df.dropna() # 删除缺失值
  29. # 存储
  30. df.to_csv('../data/stock_'+ticker+'.csv', mode='w', index=True, header=True)

上面的代码查询 Ameritrade API 并返回 10 年的股价数据,例子中的股票为『美国航空公司』。 数据绘图结果如下所示:

数据处理

我们加载刚才下载的数据文件,并开始处理预测。

  1. # 读取数据
  2. ticker = 'AAL'
  3. df = pd.read_csv("../data/stock_"+ticker+".csv")
  4. # 设置索引
  5. df['DateIndex'] = pd.to_datetime(df['timestamp'], format="%Y/%m/%d")
  6. df = df.set_index('DateIndex')

下面我们对数据进幅度缩放,以便更好地送入神经网络和训练。(神经网络是一种对于输入数据幅度敏感的模型,不同字段较大的幅度差异,会影响网络的训练收敛速度和精度。)

  1. # 幅度缩放
  2. df2 = df
  3. cols = ['close', 'volume', 'previous_close']
  4. features = df2[cols]
  5. scaler = MinMaxScaler(feature_range=(0, 1)).fit(features.values)
  6. features = scaler.transform(features.values)
  7. df2[cols] = features

在这里,我们重点处理了收盘价成交量前几天收盘价列

数据切分

接下来我们将数据拆分为训练和测试数据集。

  1. # 收盘价设为目标字段
  2. X = df2.drop(['close','timestamp'], axis =1)
  3. y = df2['close']
  4. import math
  5. # 计算切分点(以80%的训练数据为例)
  6. train_percentage = 0.8
  7. split_point = math.floor(len(X) * train_percentage)
  8. # 时序切分
  9. train_x, train_y = X[:split_point], y[:split_point]
  10. test_x, test_y = X[split_point:], y[split_point:]

接下来,我们对数据进行处理,构建滑窗数据,沿时间序列创建数据样本。(因为我们需要基于历史信息对未来的数值进行预测)

  1. # 构建滑窗数据
  2. import numpy.lib
  3. from numpy.lib.stride_tricks import sliding_window_view
  4. def genWindows(X_in, y_in, window_size):
  5. X_out = []
  6. y_out = []
  7. length = X_in.shape[0]
  8. for i in range(window_size, length):
  9. X_out.append(X_in[i-window_size:i, 0:4])
  10. y_out.append(y_in[i-1])
  11. return np.array(X_out), np.array(y_out)
  12. # 窗口大小为5
  13. window_size = 5
  14. X_train_win, y_train_win = genWindows(np.array(train_x), np.array(train_y), window_size)
  15. X_test_win, y_test_win = genWindows(np.array(test_x), np.array(test_y), window_size)

模型构建&训练

构建完数据之后,我们就要构建 RNN 模型了,具体的代码如下所示。注意到下面使用了1个回调函数,模型会在验证集性能没有改善的情况下提前停止训练,防止模型过拟合影响泛化能力。

  1. from tensorflow.keras import callbacks
  2. # 早停止 回调函数
  3. callback_early_stopping = callbacks.EarlyStopping(
  4. monitor="loss",
  5. patience=10,#look at last 10 epochs
  6. min_delta=0.0001,#loss must improve by this amount
  7. restore_best_weights=True,
  8. )
  9. from tensorflow import keras
  10. from tensorflow.keras import layers
  11. from keras.models import Sequential
  12. # 构建RNN模型,结构为 输入-RNN-RNN-连续值输出
  13. input_shape=(X_train_win.shape[1],X_train_win.shape[2])
  14. print(input_shape)
  15. model = Sequential(
  16. [
  17. layers.Input(shape=input_shape),
  18. layers.SimpleRNN(units=128, return_sequences=True),
  19. layers.SimpleRNN(64, return_sequences=False),
  20. layers.Dense(1, activation="linear"),
  21. ]
  22. )
  23. # 优化器
  24. optimizer = keras.optimizers.Nadam(learning_rate=0.0001)
  25. model.compile(optimizer=optimizer, loss="mse")
  26. # 模型结构总结
  27. model.summary()
  28. # 模型训练
  29. batch_size = 20
  30. epochs = 50
  31. history = model.fit(X_train_win, y_train_win,
  32. batch_size=batch_size, epochs=epochs,
  33. callbacks=[
  34. callback_early_stopping
  35. ])

模型训练过程的损失函数(训练集上)的变化如下图所示。随着训练过程推进,模型损失不断优化,初期的优化和loss减小速度很快,后逐渐趋于平稳。

大约 10 个 epoch 后达到了最佳结果,训练好的模型就可以用于后续预测了,我们可以先对训练集进行预测,验证一下在训练集上学习的效果。

  1. # 训练集预测
  2. pred_train_y = model.predict(X_train_win)
  3. # 绘图
  4. plt.figure(figsize=(15, 6), dpi=80)
  5. plt.plot(np.array(train_y))
  6. plt.plot(pred_train_y)
  7. plt.legend(['Actual', 'Predictions'])
  8. plt.show()

模型在训练集上学习的效果还不错,大家可以看到预测结果和真实值对比绘图如下:

模型预测&应用

我们要评估模型的真实表现,需要在它没有见过的测试数据上评估,大家记得我们在数据切分的时候预留了 20% 的数据,下面我们用模型在这部分数据上预测并评估。

  1. # 测试集预测
  2. pred_test_y = model.predict(X_test_win)
  3. # 预测结果绘制
  4. plt.figure(figsize=(15, 6), dpi=80)
  5. plt.plot(np.array(test_y))
  6. plt.plot(pred_test_y)
  7. plt.legend(['Actual', 'Predictions'])
  8. plt.show()

相对训练集来说,大家看到测试集上的效果稍有偏差,但是总体趋势还是预测得不错。

我们要考察这个模型对于时间序列预测的泛化能力,可以进行更严格一点的建模预测,比如将训练得到的模型应用与另一支完全没见过的股票上进行预测。如下为我们训练得到的模型对 Microsoft/微软股票价格的预测:

我们从图上可以看到,模型表现良好(预测存在一定程度的噪音,但它对总体趋势的预测比较准确)。

参考资料

TensorFlow深度学习!构建神经网络预测股票价格!⛵的更多相关文章

  1. 没有博士学位,照样玩转TensorFlow深度学习

    教程 | 没有博士学位,照样玩转TensorFlow深度学习 机器之心2017-01-24 12:32:22 程序设计 谷歌 操作系统 阅读(362)评论(0) 选自Codelabs 机器之心编译 参 ...

  2. 针对深度学习(神经网络)的AI框架调研

    针对深度学习(神经网络)的AI框架调研 在我们的AI安全引擎中未来会使用深度学习(神经网络),后续将引入AI芯片,因此重点看了下业界AI芯片厂商和对应芯片的AI框架,包括Intel(MKL CPU). ...

  3. (转) TensorFlow深度学习,一篇文章就够了

    TensorFlow深度学习,一篇文章就够了 2016/09/22 · IT技术 · TensorFlow, 深度学习 分享到:6   原文出处: 我爱计算机 (@tobe迪豪 )    作者: 陈迪 ...

  4. TensorFlow深度学习,一篇文章就够了

    http://blog.jobbole.com/105602/ 作者: 陈迪豪,就职小米科技,深度学习工程师,TensorFlow代码提交者. TensorFlow深度学习框架 Google不仅是大数 ...

  5. 问题集录--TensorFlow深度学习

    TensorFlow深度学习框架 Google不仅是大数据和云计算的领导者,在机器学习和深度学习上也有很好的实践和积累,在2015年年底开源了内部使用的深度学习框架TensorFlow. 与Caffe ...

  6. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  7. TensorFlow 深度学习中文第二版·翻译完成

    原文:Deep Learning with TensorFlow Second Edition 协议:CC BY-NC-SA 4.0 不要担心自己的形象,只关心如何实现目标.--<原则>, ...

  8. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  9. windows下Anaconda3配置TensorFlow深度学习库

    Anaconda3(python3.6)安装tensorflow Anaconda3中安装tensorflow3是非常简单的,仅需通过 pip install tensorflow 测试代码: imp ...

随机推荐

  1. Qt 创建按钮动画

    1 封装自定义按钮 myPushBttton 2 构造函数 (默认图片,按下后显示图片) 3 测试开始按钮 4 开始制作特效 5 zoom1 向下弹跳 6 zoom2 向上弹跳 代码如下 main.h ...

  2. KingbaseES 的行列转换

    目录 背景 行转列 数据准备 分组聚合函数+CASE 根据压缩数据的格式,横向展开数据列选取不同方式 crosstab函数 PIVOT 操作符 PIVOT 操作符的限制 工具 ksql 的元命令 \c ...

  3. .NET 纯原生实现 Cron 定时任务执行,未依赖第三方组件 (Timer 优化版)

    在上个月写过一篇 .NET 纯原生实现 Cron 定时任务执行,未依赖第三方组件 的文章,当时 CronSchedule 的实现是使用了,每个服务都独立进入到一个 while 循环中,进行定期扫描是否 ...

  4. K8S_三种Port区别总结

    nodePort: 外部流量访问K8S集群中Service入口的一种方式 比如外部用户要访问k8s集群中的一个Web应用,那么我们可以配置对应service的type=NodePort,nodePor ...

  5. JDK 自带的服务发现框架 ServiceLoader 好用吗?

    请点赞关注,你的支持对我意义重大. Hi,我是小彭.本文已收录到 Github · AndroidFamily 中.这里有 Android 进阶成长知识体系,有志同道合的朋友,关注公众号 [彭旭锐] ...

  6. QT学习(三)

    首先整理一下编码的方法.对于一个待解决的问题,首先应该将大问题分解成小问题,将小问题划分为小小问题... 然后再进行类的抽象,将划分成的问题和类进行对应.然后再对划分的小..问题进行具体的处理分析,划 ...

  7. MySQL数据备份 mysqldump 详解

    MySQL数据备份流程 1 打开cmd窗口 通过命令进行数据备份与恢复: 需要在Windows的命令行窗口中进行: l 开始菜单,在运行中输入cmd回车: l 或者win+R,然后输入cmd回车,即可 ...

  8. 海康摄像机使用GB28181接入SRS服务器的搭建步骤---源码安装的方式

    下载代码 地址:https://github.com/ossrs/srs-gb28181 https://github.com/ossrs/srs-gb28181.git 注意:使用的是含有gb281 ...

  9. k8s实际操作中的小知识点

    1.批量执行yaml文件 # 把所有要执行的yaml文件放在同一个目录下,并且切换到这个目录下 kubectl apply -f . 2.利用pod的亲和和反亲和功能把pod调度到不同的node上 亲 ...

  10. 转载---Beats:如何使用Filebeat将MySQL日志发送到Elasticsearch

    在今天的文章中,我们来详细地描述如果使用Filebeat把MySQL的日志信息传输到Elasticsearch中.为了说明问题的方便,我们的测试系统的配置是这样的: 我有一台MacOS机器.在上面我安 ...