这一节使用TF搭建一个简单的神经网络用于回归预测,首先随机生成一组数据

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.set_random_seed(42)
np.random.seed(42)
x = np.linspace(-1,1,100)[:,np.newaxis] #<==>x=x.reshape(100,1)
noise = np.random.normal(0,0.1,size = x.shape)
y=np.power(x,2) + x +noise #y=x^2 + x+噪音
plt.scatter(x,y)
plt.show()

随机生成了一组数据,模型为\(y=x^2+x\),看一下数据的分布

接下来搭建一个含有一个隐藏层的神经网络,损失选择使用均方差误差

#模型部分
tf_X = tf.placeholder(tf.float32,x.shape) #=>X
tf_y = tf.placeholder(tf.float32,y.shape) #=>y output = tf.layers.dense(tf_X,10,tf.nn.relu,name="hidden")#隐藏层10个节点
output = tf.layers.dense(output,1,name='output') #1个输出层
#loss = tf.losses.mean_squared_error(tf_y,output)
loss = tf.reduce_mean(tf.sqrt(tf.pow(tf_y-output,2)))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.2)
train_op = optimizer.minimize(loss)

其中tf.losses中提供了常用的损失函数实现,也可以自己去实现,开始训练模型

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
plt.ion()
for step in range(100):
_,err,pred = sess.run([train_op,loss,output],feed_dict={tf_X:x,tf_y:y})
#cla() # Clear axis
#clf() # Clear figure
#close() # Close a figure window
plt.cla()#
plt.scatter(x,y)
plt.plot(x,pred,'r-',lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % err, fontdict={'size': 20, 'color': 'red'})
#plt.show()
plt.ioff()
plt.show()

看一看效果:

note:上面使用了plt.cla方法,这是由于方便看到变化过程,将plot过程写入到了for循环中,为了避免发生意外错误将对象从内存中清空。

使用TensorFlow实现回归预测的更多相关文章

  1. Tensorflow 线性回归预测房价实例

    在本节中将通过一个预测房屋价格的实例来讲解利用线性回归预测房屋价格,以及在tensorflow中如何实现 Tensorflow 线性回归预测房价实例 1.1. 准备工作 1.2. 归一化数据 1.3. ...

  2. TensorFlow笔记二:线性回归预测(Linear Regression)

    代码: import tensorflow as tf import numpy as np import xlrd import matplotlib.pyplot as plt DATA_FILE ...

  3. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  4. Python机器学习笔记:使用Keras进行回归预测

    Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...

  5. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

  6. [Tensorflow] RNN - 03. MultiRNNCell for Digit Prediction

    Ref: http://blog.csdn.net/u014595019/article/details/52759104 Time: 2min Successfully downloaded tra ...

  7. Tensorflow实现LSTM识别MINIST

    import tensorflow as tf import numpy as np from tensorflow.contrib import rnn from tensorflow.exampl ...

  8. TensorFlow笔记四:从生成和保存模型 -> 调用使用模型

    TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...

  9. 深度学习笔记(十三)YOLO V3 (Tensorflow)

    [代码剖析]   推荐阅读! SSD 学习笔记 之前看了一遍 YOLO V3 的论文,写的挺有意思的,尴尬的是,我这鱼的记忆,看完就忘了  于是只能借助于代码,再看一遍细节了. 源码目录总览 tens ...

随机推荐

  1. Disruptor3.0的实现细节

    本文旨在介绍Disruptor3.0的实现细节,首先从整体上描述了Disruptor3.0的核心类图,Disruptor3.0 DSL(领域专用语言)的实现类图,并以Disruptor官方列举的几大特 ...

  2. MinnowBoard MAX 硬件开发板

    Minnowboard MAX MinnowBoard MAX是一款紧凑型,经济实惠,而且功能强大的开发板为专业人士和制造商.开放式的硬件设计使无尽的定制和集成的潜力.它采用64位英特尔®凌动™E38 ...

  3. 图解MBR分区无损转换GPT分区+UEFI引导安装WIN8.1

    确定你的主板支持UEFI引导.1,前期准备,WIN8.1原版系统一份(坛子里很多,自己下载个),U盘2个其中大于4G一个(最好 准备两个U盘)2,大家都知道WIN8系统只支持GPT分区,传统的MBR分 ...

  4. RAID卡技术简析

    经过一段时间的折腾,工作的事终于解决了,新工作一上来的第一件事就要熟悉RAID卡存储机制,先简单了解下RAID卡吧. 提到RAID卡就不得不提什么是RAID,RAID是英文Redundant Arra ...

  5. raid功能中spanning和striping模式有什么区别?

    RAID 0 又称为Stripe(条带化,串列)或Striping 它代表了所有RAID级别中最高的存储性能.RAID 0提高存储性能的原理是把连续的数据分散到多个磁盘上存取,这样,系统有数据请求就可 ...

  6. Java JVM使用哪种编码格式

    Java JVM使用哪种编码格式 A ASCII characters  B Unicode characters C Cp1252 D UTF-8 E GBK F GBK2312 答案:B   在J ...

  7. java.lang.ArrayIndexOutOfBoundsException

    1.错误描述 Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: 0 at com.you.m ...

  8. jquery.lazyload.js实现图片懒加载

    个人理解:将需要延迟加载的图片的src属性全部设置为一张相同尽可能小(目的是尽可能的少占宽带,节省流量,由于缓存机制,当浏览器加载了一张图片之后,相同的图片就会在缓存中拿,不会重新到服务器上拿)的图片 ...

  9. 双刃剑MongoDB的学习和避坑

    双刃剑MongoDB的学习和避坑 MongoDB 是一把双刃剑,它对数据结构的要求并不高.数据通过key-value的形式存储,而value的值可以是字符串,也可以是文档.所以我们在使用的过程中非常方 ...

  10. Vue01 Vue介绍、Vue使用、Vue实例的创建、数据绑定、Vue实例的生命周期、差值与表达式、指令与事件、语法糖

    1 Vue介绍 1.1 官方介绍 vue是一个简单小巧的渐进式的技术栈,它提供了Web开发中常用的高级功能:视图和数据的解耦.组件的服用.路由.状态管理.虚拟DOM 说明:简单小巧 -> 压缩后 ...