1.基本数据绘制成图

数据有15天股票的开盘价格和收盘价格,可以通过比较当天开盘价格和收盘价格的大小来判断当天股票价格的涨跌情况,红色表示涨,绿色表示跌,测试代码如下:

 # encoding:utf-8

 import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
date = np.linspace(1, 15, 15)
# 当天的收盘价格
endPrice = np.array([2511.90,2538.26,2510.68,2591.66,2732.98,2701.69,2701.29,2678.67,2726.50,2681.50,2739.17,2715.07,2823.58,2864.90,2919.08]
)
# 当天的开盘价格
beginPrice = np.array([2438.71,2500.88,2534.95,2512.52,2594.04,2743.26,2697.47,2695.24,2678.23,2722.13,2674.93,2744.13,2717.46,2832.73,2877.40])
print(date) # 打印日期
plt.figure()
for i in range(0,15):
# 通过循环遍历数据画出柱状图
dateOne = np.zeros([2])
dateOne[0] = i
dateOne[1] = i
print(dateOne)
priceOne = np.zeros([2])
priceOne[0] = beginPrice[i]
priceOne[1] = endPrice[i]
if endPrice[i] > beginPrice[i]:
# 如果收盘价格大于开盘价格说明股票上涨 用红色表示 lw为线条粗细
plt.plot(dateOne, priceOne,'r',lw=8)
else:
# 如果收盘价格小于开盘价格说明股票下跌 用绿色表示 lw为线条粗细
plt.plot(dateOne, priceOne,'g',lw=5)
plt.show()

运行后的图如下:

2.人工神经网络进行预测

建立一个简单的三层人工神经网络。

循环的终止条件可以为预先设定的循环次数或者与真实值的差异百分比

功能实现,完整的测试代码如下:

 # encoding:utf-8

 import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
date = np.linspace(1, 15, 15)
# 当天的收盘价格
endPrice = np.array([2511.90,2538.26,2510.68,2591.66,2732.98,2701.69,2701.29,2678.67,2726.50,2681.50,2739.17,2715.07,2823.58,2864.90,2919.08]
)
# 当天的开盘价格
beginPrice = np.array([2438.71,2500.88,2534.95,2512.52,2594.04,2743.26,2697.47,2695.24,2678.23,2722.13,2674.93,2744.13,2717.46,2832.73,2877.40])
print(date) # 打印日期
plt.figure()
for i in range(0,15):
# 通过循环遍历数据画出柱状图
dateOne = np.zeros([2])
dateOne[0] = i
dateOne[1] = i
print(dateOne)
priceOne = np.zeros([2])
priceOne[0] = beginPrice[i]
priceOne[1] = endPrice[i]
if endPrice[i] > beginPrice[i]:
# 如果收盘价格大于开盘价格说明股票上涨 用红色表示 lw为线条粗细
plt.plot(dateOne, priceOne,'r',lw=8)
else:
# 如果收盘价格小于开盘价格说明股票下跌 用绿色表示 lw为线条粗细
plt.plot(dateOne, priceOne,'g',lw=5)
# plt.show()
# A(15x1)*w1(1x10)+b1(1*10) = B(15x10)
# B(15x10)*w2(10x1)+b2(15x1) = C(15x1)
# 1 A B C
dateNormal = np.zeros([15,1])
priceNormal = np.zeros([15,1])
# 日期和价格进行归一化处理
for i in range(0, 15):
dateNormal[i, 0] = i/14.0
priceNormal[i, 0] = endPrice[i]/3000.0
print(dateNormal)
print(priceNormal) x = tf.placeholder(tf.float32, [None, 1]) # 表明是N行1列的
y = tf.placeholder(tf.float32, [None, 1]) # 表明是N行1列的 # B
w1 = tf.Variable(tf.random_uniform([1, 10], 0, 1)) # 可变值 可以通过误差修改值 范围0-1
b1 = tf.Variable(tf.zeros([1, 10])) # 可变值 可以通过误差修改值
wb1 = tf.matmul(x, w1)+b1
layer1 = tf.nn.relu(wb1) # 激励函数 映射成另一个值
# 第一二层完毕 # C
w2 = tf.Variable(tf.random_uniform([10, 1], 0, 1)) # 可变值 可以通过误差修改值 范围0-1
b2 = tf.Variable(tf.zeros([15, 1]))
wb2 = tf.matmul(layer1, w2)+b2
layer2 = tf.nn.relu(wb2) # 激励函数 映射成另一个值
# 第二三层完毕 # 误差用loss表示 实际是一个标准差
loss = tf.reduce_mean(tf.square(y-layer2)) # y 真实 layer2 计算
# 每次调整的步长 梯度下降0.1 目的是缩小loss减小真实值与误差值的差异
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # 初始化
for i in range(0, 10000): # 训练次数为10000
sess.run(train_step, feed_dict={x: dateNormal, y: priceNormal})
# w1w2 b1b2 A + wb -->layer2
pred = sess.run(layer2, feed_dict={x: dateNormal})
predPrice = np.zeros([15, 1]) # 预测结果
for i in range(0, 15): # 还原数据需要*3000
predPrice[i, 0] = (pred*3000)[i, 0]
plt.plot(date, predPrice, 'b', lw=1)
plt.show()

运行结果如下:(图中蓝色的线表示股票的预测值)

Python实现人工神经网络逼近股票价格的更多相关文章

  1. 吴裕雄 python 机器学习——人工神经网络感知机学习算法的应用

    import numpy as np from matplotlib import pyplot as plt from sklearn import neighbors, datasets from ...

  2. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  3. python大战机器学习——人工神经网络

    人工神经网络是有一系列简单的单元相互紧密联系构成的,每个单元有一定数量的实数输入和唯一的实数输出.神经网络的一个重要的用途就是接受和处理传感器产生的复杂的输入并进行自适应性的学习,是一种模式匹配算法, ...

  4. 人工神经网络,支持任意数量隐藏层,多层隐藏层,python代码分享

    http://www.cnblogs.com/bambipai/p/7922981.html------误差逆传播算法讲解 人工神经网络包含多种不同的神经网络,此处的代码建立的是多层感知器网络,代码以 ...

  5. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  6. [DL学习笔记]从人工神经网络到卷积神经网络_1_神经网络和BP算法

    前言:这只是我的一个学习笔记,里边肯定有不少错误,还希望有大神能帮帮找找,由于是从小白的视角来看问题的,所以对于初学者或多或少会有点帮助吧. 1:人工全连接神经网络和BP算法 <1>:人工 ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. 开源的c语言人工神经网络计算库 FANN

    这年头机器学习非常的火,神经网络算是机器学习算法中的比较重要的一种.这段时间我也花了些功夫,学了点皮毛,顺便做点学习笔记. 介绍人工神经网络的基本理论的教科书很多.我正在看的是蒋宗礼教授写的<人 ...

  9. 人工神经网络(Artificial Neural Networks)

    人工神经网络的产生一定程度上受生物学的启发,因为生物的学习系统是由相互连接的神经元相互连接的神经元组成的复杂网络.而人工神经网络跟这个差不多,它是一系列简单的单元相互密集连接而成的.其中每个单元有一定 ...

随机推荐

  1. 献给即将35岁的初学者,焦虑 or 出路?

    导言:“对抗职场“35 岁焦虑”,也许唯一的方法是比这个瞬息万变的商业社会跑得更快!” 一直以来,都有许多人说“程序员或测试员是个吃青春饭的职业”,甚至还有说“35 岁混不到管理就等于失业”的言论. ...

  2. git系列之---工作中项目的常用git操作

    0.本地git的安装 官网下载 1.git 配置 git config user.name  查看 用户名 git config user.email   查看 邮箱 git config --glo ...

  3. Linux 简介、目录结构

    Linux是类 Unix 操作系统. 根据原生程度可分为: 内核版本 发行版本:一些公司.组织在内核版的基础上进行二次开发 根据市场需求可分为: 服务器版:没有好看的界面,在终端操作,类似于dos 桌 ...

  4. github无法访问的解决实践

    无废话版: ----------------------------- 1.复制下面内容,添加到hosts文件里(C:\Windows\System32\drivers\etc)不能修改的话,则把文件 ...

  5. Redis实现访问控制频率

    为什么限制访问频率 做服务接口时通常需要用到请求频率限制 Rate limiting,例如限制一个用户1分钟内最多可以范围100次 主要用来保证服务性能和保护数据安全 因为如果不进行限制,服务调用者可 ...

  6. linux cpp (接口与实现的分离)

    以下是 .h 文件,是接口. 以下是函数的实现 以下是主函数 首先是以上两个文件编译,不用编译头文件 g++ -c gradeBook.cpp g++ -c gradeBook.main.cpp 之后 ...

  7. MySQL 8 用户定义函数

    MySQL Server可以通过创建或者加载UDFs(User-Defined Functions)来扩展服务器功能. 通过CREATE FUNCTION语句加载 UDF,比如: CREATE FUN ...

  8. AE神奇插件TypeMonkey—抖音点赞100W+的文字视频特效是如何做出来的?

    现在最火的东西,短视频必须要拥有姓名啦,抖音这些短视频平台风头正盛,我们也常常在上面看到一些文字动画Vlog,看着并不复杂,但是有些却有上百万的点击量,今天介绍的一款神奇插件——TypeMonkey, ...

  9. MariaDB(MySQL)创建、删除、选择及数据类型使用详解

    一.MariaDB简介(MySQL简介略过) MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可 MariaDB的目的是完全兼容MySQL,包括API和命令行 ...

  10. sbt package报错:a bytes-like object is required, not 'str'

    Traceback (most recent call last): File , in <module> s.sendall(content) TypeError: a bytes-li ...