使用TensorFlow实现回归预测
这一节使用TF搭建一个简单的神经网络用于回归预测,首先随机生成一组数据
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.set_random_seed(42)
np.random.seed(42)
x = np.linspace(-1,1,100)[:,np.newaxis] #<==>x=x.reshape(100,1)
noise = np.random.normal(0,0.1,size = x.shape)
y=np.power(x,2) + x +noise #y=x^2 + x+噪音
plt.scatter(x,y)
plt.show()
随机生成了一组数据,模型为\(y=x^2+x\),看一下数据的分布
接下来搭建一个含有一个隐藏层的神经网络,损失选择使用均方差
误差
#模型部分
tf_X = tf.placeholder(tf.float32,x.shape) #=>X
tf_y = tf.placeholder(tf.float32,y.shape) #=>y
output = tf.layers.dense(tf_X,10,tf.nn.relu,name="hidden")#隐藏层10个节点
output = tf.layers.dense(output,1,name='output') #1个输出层
#loss = tf.losses.mean_squared_error(tf_y,output)
loss = tf.reduce_mean(tf.sqrt(tf.pow(tf_y-output,2)))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.2)
train_op = optimizer.minimize(loss)
其中tf.losses
中提供了常用的损失函数实现,也可以自己去实现,开始训练模型
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
plt.ion()
for step in range(100):
_,err,pred = sess.run([train_op,loss,output],feed_dict={tf_X:x,tf_y:y})
#cla() # Clear axis
#clf() # Clear figure
#close() # Close a figure window
plt.cla()#
plt.scatter(x,y)
plt.plot(x,pred,'r-',lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % err, fontdict={'size': 20, 'color': 'red'})
#plt.show()
plt.ioff()
plt.show()
看一看效果:
note:上面使用了plt.cla
方法,这是由于方便看到变化过程,将plot过程写入到了for循环中,为了避免发生意外错误将对象从内存中清空。
使用TensorFlow实现回归预测的更多相关文章
- Tensorflow 线性回归预测房价实例
在本节中将通过一个预测房屋价格的实例来讲解利用线性回归预测房屋价格,以及在tensorflow中如何实现 Tensorflow 线性回归预测房价实例 1.1. 准备工作 1.2. 归一化数据 1.3. ...
- TensorFlow笔记二:线性回归预测(Linear Regression)
代码: import tensorflow as tf import numpy as np import xlrd import matplotlib.pyplot as plt DATA_FILE ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- Python机器学习笔记:使用Keras进行回归预测
Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...
- TensorFlow入门(五)多层 LSTM 通俗易懂版
欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...
- [Tensorflow] RNN - 03. MultiRNNCell for Digit Prediction
Ref: http://blog.csdn.net/u014595019/article/details/52759104 Time: 2min Successfully downloaded tra ...
- Tensorflow实现LSTM识别MINIST
import tensorflow as tf import numpy as np from tensorflow.contrib import rnn from tensorflow.exampl ...
- TensorFlow笔记四:从生成和保存模型 -> 调用使用模型
TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...
- 深度学习笔记(十三)YOLO V3 (Tensorflow)
[代码剖析] 推荐阅读! SSD 学习笔记 之前看了一遍 YOLO V3 的论文,写的挺有意思的,尴尬的是,我这鱼的记忆,看完就忘了 于是只能借助于代码,再看一遍细节了. 源码目录总览 tens ...
随机推荐
- Hadoop序列化与Java序列化
序列化就是把内存中的对象的状态信息转换成字节序列,以便于存储(持久化)和网络传输 反序列化就是就将收到的字节序列或者是硬盘的持久化数据,转换成内存中的对象. 1.JDK的序列化 只要实现了serial ...
- RVDS4.0 + JLINK 调试 cortex-A9
1.RVDS4.0的安装与破解 参看http://blog.csdn.net/cp1300/article/details/7772645这位大神的帖子吧,写的很详细. 2.JLINK驱动的安装 这里 ...
- Linux 系统裁剪笔记 4 (内核配置选项及删改)
CDROM filesystem support(CONFIG_ISO9660_FS)[Y/m/n/?]有标准光驱的系统应该选Y.Minix fs support(CONFIG_MINIX_FS)[ ...
- Linux开机启动图片修改
Linux启动时会在屏幕上显示一个默认的开机图片,我们可以修改成为自己的图片,需要做以下工作 软件gimp下载地址:http://www.rayfile.com/zh-cn/files/0bb556b ...
- JavaScript去除日期中的“-”
JavaScript去除日期中的"-" 1.说明 经常会出现这样的情况,页面的日期格式是:YYYY-MM-DD,而数据库中的日期格式是:YYYYMMDD,两者之间需要转换一下,方能 ...
- 简单bfs(hdu2612)
#include<stdio.h>#include<string.h>#include<queue>#define INF 0x3f3f3f3fusing name ...
- hive查询结果输出到hdfs上
insert overwrite directory "/mapredOutput/UserYesterdayInterest/${hiveconf:day}"row format ...
- 使用PHPword中文乱码并且下载的方法
如果你的编码格式是utf-8的话就用这个 1.找到 Section.php 的 addText 函数 $givenText = utf8_encode($text); 改成 $givenText = ...
- 洛谷P3203 [HNOI2010]弹飞绵羊(LCT,Splay)
洛谷题目传送门 关于LCT的问题详见我的LCT总结 思路分析 首先分析一下题意.对于每个弹力装置,有且仅有一个位置可以弹到.把这样的一种关系可以视作边. 然后,每个装置一定会往后弹,这不就代表不存在环 ...
- CDQ分治嵌套模板:多维偏序问题
CDQ分治2 CDQ套CDQ:四维偏序问题 题目来源:COGS 2479 偏序 #define LEFT 0 #define RIGHT 1 struct Node{int a,b,c,d,bg;}; ...