import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf #——————————————————导入数据——————————————————————
f=open('./dataset/dataset_1.csv')
df=pd.read_csv(f) #读入股票数据
data=np.array(df['最高价']) #获取最高价序列
data=data[::-1] #反转,使数据按照日期先后顺序排列
#以折线图展示data
# plt.figure()
# plt.plot(data)
# plt.show()
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度 #生成训练集
#设置常量
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度
output_size=1 #输出层维度
lr=0.0006 #学习率
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist()) #——————————————————定义神经网络变量——————————————————
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
} #——————————————————定义神经网络变量——————————————————
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states #——————————————————训练模型——————————————————
def train_lstm():
global batch_size
pred,_=lstm(batch_size)
#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#重复训练10000次
for i in range(10000):
step=0
start=0
end=start+batch_size
while(end<len(train_x)):
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
start+=batch_size
end=start+batch_size
#每10步保存一次参数
if step%10==0:
print(i,step,loss_)
print("保存模型:",saver.save(sess,'./module2/stock.model'))
step+=1 #————————————————预测模型————————————————————
def prediction():
pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据
saver=tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
#参数恢复
module_file = tf.train.latest_checkpoint('./module2/')
saver.restore(sess, module_file) #取训练集最后一行为测试样本。shape=[1,time_step,input_size]
prev_seq=train_x[-1]
predict=[]
#得到之后100个预测结果
for i in range(100):
next_seq=sess.run(pred,feed_dict={X:[prev_seq]})
predict.append(next_seq[-1])
#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本
prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))
#以折线图表示结果
plt.figure()
plt.plot(list(range(len(normalize_data))), normalize_data, color='b')
plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')
plt.show() if __name__ == '__main__': # train_lstm()
prediction()
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

数据集格式:

时间           最高价
2015/12/11 3455.55
2015/12/10 3503.65
2015/12/9 3495.7
2015/12/8 3518.65
2015/12/7 3543.95
2015/12/4 3568.97
2015/12/3 3591.73
2015/12/2 3538.85
2015/12/1 3483.41
2015/11/30 3470.37
2015/11/27 3621.9
2015/11/26 3668.38
2015/11/25 3648.37
2015/11/24 3616.48
2015/11/23 3654.75
2015/11/20 3640.53
2015/11/19 3618.21
2015/11/18 3617.07
2015/11/17 3678.27

Tensflow预测股票实例的更多相关文章

  1. 基于Spark Streaming预测股票走势的例子(一)

    最近学习Spark Streaming,不知道是不是我搜索的姿势不对,总找不到具体的.完整的例子,一怒之下就决定自己写一个出来.下面以预测股票走势为例,总结了用Spark Streaming开发的具体 ...

  2. 通过机器学习的线性回归算法预测股票走势(用Python实现)

    在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...

  3. Tensorflow实例:利用LSTM预测股票每日最高价(一)

    RNN与LSTM 这一部分主要涉及循环神经网络的理论,讲的可能会比较简略. 什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神 ...

  4. 20岁少年小伙利用Python_SVM预测股票趋势月入十万!

      在做数据预处理的时候,超额收益率是股票行业里的一个专有名词,指大于无风险投资的收益率,在我国无风险投资收益率即是银行定期存款. pycharm + anaconda3.6开发,涉及到的第三方库有p ...

  5. 基于Spark Streaming预测股票走势的例子(二)

    上一篇博客中,已经对股票预测的例子做了简单的讲解,下面对其中的几个关键的技术点再作一些总结. 1.updateStateByKey 由于在1.6版本中有一个替代函数,据说效率比较高,所以作者就顺便研究 ...

  6. AI金融:LSTM预测股票

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  7. AI金融:利用LSTM预测股票每日最高价

    第一部分:从RNN到LSTM 1.什么是RNN RNN全称循环神经网络(Recurrent Neural Networks),是用来处理序列数据的.在传统的神经网络模型中,从输入层到隐含层再到输出层, ...

  8. 如何预测股票分析--先知(Prophet)

    在上一篇中,我们探讨了自动ARIMA,但是好像表现的还是不够完善,接下来看看先知的力量! 先知(Prophet) 有许多时间序列技术可以用在股票预测数据集上,但是大多数技术在拟合模型之前需要大量的数据 ...

  9. 《BI那点儿事》Microsoft 逻辑回归算法——预测股票的涨跌

    数据准备:一组股票历史成交数据(股票代码:601106 中国一重),起止日期:2011-01-04至今,其中变量有“开盘”.“最高”.“最低”.“收盘”.“总手”.“金额”.“涨跌”等 UPDATE ...

随机推荐

  1. Java_5.2 数组应用:*的打印

    1五行五列的* ************************* public static void main(String[] args) { for (int i = 1; i <= 5 ...

  2. spring boot 访问项目时加项目名称

    pringboot 项目一般直接地址加端口就可以访问了,不像放在tomcat里面还需要加上项目名. 现在,想访问的时候加上项目名用来区分,只要在配置文件里面加上 server.context-path ...

  3. node.js下载安装

    1.下载node.js在node中文网站,官方网站下载太慢 2.接着让我们点击下载链接,页面上呈现出你所需要下载的安装包,我们这里选择windows x64的安装包进行下载 3.安装node.js,一 ...

  4. Taxi

    /* After the lessons n groups of schoolchildren went outside and decided to visit Polycarpus to cele ...

  5. Android.Libraries

    1. Android Dependencies, Referenced Libraries, Android Private Libraries Android Private Libraries - ...

  6. (O)JS核心:call、apply和bind

    1. var func=function(a,b,c){ console.log([a,b,c]); }; func.apply(null,[1,2,3]); //[1,2,3] func.call( ...

  7. ubuntu12.04下Qt调试器的使用

    最近,我一直在用Qt编写C++程序,但在编写过程中遇到了问题,想用Qt Creator中的调试器调试一下,但调试时(在Qt Creator中已配置好相应的调试器)出现“ ptrace:Operatio ...

  8. linux学习第一天 (Linux就该这么学) 找到一本不错的Linux电子书,附《Linux就该这么学》章节目录

    本书是由全国多名红帽架构师(RHCA)基于最新Linux系统共同编写的高质量Linux技术自学教程,极其适合用于Linux技术入门教程或讲课辅助教材,目前是国内最值得去读的Linux教材,也是最有价值 ...

  9. default(T) 和 typeof 和 GetType()

    一.default(T) 在泛型编成中如果不限制T类型参数是值类型或引用类型的话 你程序内部可能会出现错误,因为值类型不允许NULL.所以default用来获取一个类型的默认值,对于值类型得到new ...

  10. 使用scrollTop返回顶部

    scrollTop属性表示被隐藏在内容区域上方的像素数.元素未滚动时,scrollTop的值为0,如果元素被垂直滚动了,scrollTop的值大于0,且表示元素上方不可见内容的像素宽度 由于scrol ...