## 导入所需的包

import pandas as pd

import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf

tf.reset_default_graph()

plt.rcParams['font.sans-serif'] = 'SimHei' ##设置字体为SimHei显示中文

plt.rcParams['axes.unicode_minus'] = False ##设置正常显示符号

## 导入所需数据

df = pd.read_csv('日元-人民币.csv',encoding='gbk',engine='python')

df['时间'] = pd.to_datetime(df['时间'],format='%Y/%m/%d')

df = df.sort_values(by='时间')

df.head()

## 用折线图展示数据

plt.figure(figsize=(12,8))

plt.title('1999年1月1日到2018年8月21日最高价数据曲线')

plt.plot(df['time'],df['高'])

plt.show()

### 提取测试数据

data = df.loc[:,['time','高']]

## 标准化数据

data['高'] = (data['高']-np.mean(data['高']))/np.std(data['高'])

data['高(预)'] = data['高'].shift(-1)

data = data.iloc[:data.shape[0]-1]

data.columns = ['时间','x','y']

data.head()

#获取最高价序列

data=np.array(df['高'])

normalize_data=(data-np.mean(data))/np.std(data) #标准化

normalize_data=normalize_data[:,np.newaxis] #增加维度

#———————————————形成训练集——————————————————

#设置常量

time_step=20 #时间步

rnn_unit=10 #hidden layer units

batch_size=60 #每一批次训练多少个样例

input_size=1 #输入层维度

output_size=1 #输出层维度

lr=0.0006 #学习率

train_x,train_y=[],[] #训练集

for i in range(len(normalize_data)-time_step-1):

x=normalize_data[i:i+time_step]

y=normalize_data[i+1:i+time_step+1]

train_x.append(x.tolist())

train_y.append(y.tolist())

test_x = train_x[len(train_x)-31:len(train_x)-1]

test_y = train_y[len(train_y)-31:len(train_y)-1]

X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor

Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签

#输入层、输出层权重、偏置

weights={

'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),

'out':tf.Variable(tf.random_normal([rnn_unit,1]))

}

biases={

'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),

'out':tf.Variable(tf.constant(0.1,shape=[1,]))

}

def lstm(batch): #参数:输入网络批次数目

w_in=weights['in']

b_in=biases['in']

input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入

input_rnn=tf.matmul(input,w_in)+b_in

input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入

cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)

init_state=cell.zero_state(batch,dtype=tf.float32)

output_rnn,final_states=tf.nn.dynamic_rnn(cell,input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果

output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入

w_out=weights['out']

b_out=biases['out']

pred=tf.matmul(output,w_out)+b_out

return pred,final_states

def train_lstm():

global batch_size

pred,_=lstm(batch_size)

#损失函数
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))

train_op=tf.train.AdamOptimizer(lr).minimize(loss)

saver=tf.train.Saver(tf.global_variables())

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

#重复训练100次

for i in range(100):

step=0

start=0

end=start+batch_size

while(end<len(train_x)): _,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})

start+=batch_size

end=start+batch_size

#每10步保存一次参数

if step%10==0:

print(i,step,loss_)

print("保存模型:",saver.save(sess,'.\stock.model'))

step+=1

def prediction():

pred,_=lstm(1) #预测时只输入[1,time_step,input_size]的测试数据

saver=tf.train.Saver(tf.global_variables())

with tf.Session() as sess:

#参数恢复

module_file = tf.train.latest_checkpoint('./')

saver.restore(sess, module_file)

#取训练集最后一行为测试样本。shape=[1,time_step,input_size]

prev_seq=train_x[-31]

predict=[]

#得到之后100个预测结果

for i in range(100):

next_seq=sess.run(pred,feed_dict={X:[prev_seq]})

predict.append(next_seq[-1])

#每次得到最后一个时间步的预测结果,与之前的数据加在一起,形成新的测试样本

prev_seq=np.vstack((prev_seq[1:],next_seq[-1]))

#以折线图表示结果

plt.figure()

plt.plot(list(range(len(normalize_data))), normalize_data, color='b')

plt.plot(list(range(len(normalize_data), len(normalize_data) + len(predict))), predict, color='r')

plt.show()

with tf.variable_scope('train'):

train_lstm()

with tf.variable_scope('train',reuse=True):

prediction()

基于python的机器学习实现日元币对人民币汇率预测的更多相关文章

  1. 基于Python的机器学习实战:KNN

    1.KNN原理: 存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系.输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应 ...

  2. 基于python的机器学习开发环境安装(最简单的初步开发环境)

    一.安装Python 1.下载安装python3.6 https://www.python.org/getit/ 2.配置环境变量(2个) 略...... 二.安装Python算法库 安装顺序:Num ...

  3. 基于Python的机器学习实战:Apriori

    目录: 1.关联分析 2. Apriori 原理 3. 使用 Apriori 算法来发现频繁集 4.从频繁集中挖掘关联规则 5. 总结 1.关联分析  返回目录 关联分析是一种在大规模数据集中寻找有趣 ...

  4. 基于Python的机器学习实战:AadBoost

    目录: 1. Boosting方法的简介 2. AdaBoost算法 3.基于单层决策树构建弱分类器 4.完整的AdaBoost的算法实现 5.总结 1. Boosting方法的简介 返回目录 Boo ...

  5. 搭建基于python +opencv+Beautifulsoup+Neurolab机器学习平台

    搭建基于python +opencv+Beautifulsoup+Neurolab机器学习平台 By 子敬叔叔 最近在学习麦好的<机器学习实践指南案例应用解析第二版>,在安装学习环境的时候 ...

  6. 初识TPOT:一个基于Python的自动化机器学习开发工具

    1. TPOT介绍 一般来讲,创建一个机器学习模型需要经历以下几步: 数据预处理 特征工程 模型选择 超参数调整 模型保存 本文介绍一个基于遗传算法的快速模型选择及调参的方法,TPOT:一种基于Pyt ...

  7. 【Machine Learning】决策树案例:基于python的商品购买能力预测系统

    决策树在商品购买能力预测案例中的算法实现 作者:白宁超 2016年12月24日22:05:42 摘要:随着机器学习和深度学习的热潮,各种图书层出不穷.然而多数是基础理论知识介绍,缺乏实现的深入理解.本 ...

  8. 从Theano到Lasagne:基于Python的深度学习的框架和库

    从Theano到Lasagne:基于Python的深度学习的框架和库 摘要:最近,深度神经网络以“Deep Dreams”形式在网站中如雨后春笋般出现,或是像谷歌研究原创论文中描述的那样:Incept ...

  9. 基于Python使用SVM识别简单的字符验证码的完整代码开源分享

    关键字:Python,SVM,字符验证码,机器学习,验证码识别 1   概述 基于Python使用SVM识别简单的验证字符串的完整代码开源分享. 因为目前有了更厉害的新技术来解决这类问题了,但是本文作 ...

随机推荐

  1. Netty入门(五)ChanneHandler

    本节主要讨论了 Netty 的数据处理组件 ChannelHandler. 一.Channel 生命周期 Channel 有个简单但强大的状态模型,下面是 Channel 的四个状态: Channel ...

  2. CentOS7+Nginx设置Systemctl restart nginx.service服务

    centos 7上是用Systemd进行系统初始化的,Systemd 是 Linux 系统中最新的初始化系统(init),它主要的设计目标是克服 sysvinit 固有的缺点,提高系统的启动速度.关于 ...

  3. 【转】使用Chrome Frame,彻底解决浏览器兼容问题

    本文转自http://www.ryanbay.com/?p=269,感谢该作者的总结 X-UA-Compatible是自从IE8新加的一个设置,对于IE8以下的浏览器是不识别的. 通过在meta中设置 ...

  4. OpenCV——直方图计算、寻早最值位置和对比匹配(判断两幅图的相似程度)

  5. Kafka设计解析(二十三)关于Kafka监控方案的讨论

    转载自 huxihx,原文链接 关于Kafka监控方案的讨论 目前Kafka监控方案看似很多,然而并没有一个“大而全”的通用解决方案.各家框架也是各有千秋,以下是我了解到的一些内容: 一.Kafka ...

  6. 未能从程序集“System.Transactions, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089”中加载类型“System.Transactions.TransactionScopeAsyncFlowOption”

    项目发布到IIS以后,报以下错误 出现以上问题的原因是,我的项目是在Framework 4.5.2下开发的,而发布程序的服务器FM版本是4.5 .我解决办法是安装Framework 4.6.2 具体办 ...

  7. Scala--数组相关操作

    一.定长数组 Array定长数组,访问数组元素需要通过()  数组长度是固定的,但是内容可以修改 val nums = new Array[Int](10) //长度为10的int数组 初始化为0 v ...

  8. Postman无法正常启动解决办法

    问题描述: 应用程序窗口能够打开,但就是这样一直空白,什么都不显示.接下来,主窗口以纯白色加载,不显示任何其他内容. 接下来主窗口背景米色加载和菜单栏加载和工作.应用程序将永远保持这样, 有时界面会变 ...

  9. cli 开发记录

    最近要开发一个 cli,主要作用是方便同事生成前端项目,做了一天半,基本参考的是 vue-cli. cli 要实现的功能: 用 cnpm install zt-cli -g 全局安装,这个就要将你做的 ...

  10. 7、Class文件的格式

    Class文件的格式 1.magic(魔数) 身份标识,用来标记这是不是一个CLASS文件 CLASS的魔数比较有浪漫气息,是0xCAFEBABE(咖啡宝贝),也标识着将来JAVA咖啡商标: 2.之后 ...