思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点。

1)随机生成1000个数据点,围绕在y=0.1x+0.3 周围,设置W=0.1,b=0.3,届时看构建的模型是否能学习到w和b的值。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
num_points=1000
vectors_set=[]
for i in range(num_points):
x1=np.random.normal(0.0,0.55) #横坐标,进行随机高斯处理化,以0为均值,以0.55为标准差
y1=x1*0.1+0.3+np.random.normal(0.0,0.03) #纵坐标,数据点在y1=x1*0.1+0.3上小范围浮动
vectors_set.append([x1,y1])
x_data=[v[0] for v in vectors_set]
y_data=[v[1] for v in vectors_set]
plt.scatter(x_data,y_data,c='r')
plt.show()

构造数据如下图

2)构造线性回归模型,学习上面数据图是符合一个怎么样的W和b

    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')  # 生成1维的W矩阵,取值是[-1,1]之间的随机数
b = tf.Variable(tf.zeros([1]), name='b') # 生成1维的b矩阵,初始值是0
y = W * x_data + b # 经过计算得出预估值y
loss = tf.reduce_mean(tf.square(y - y_data), name='loss') # 以预估值y和实际值y_data之间的均方误差作为损失
optimizer = tf.train.GradientDescentOptimizer(0.5) # 采用梯度下降法来优化参数 学习率为0.5
train = optimizer.minimize(loss, name='train') # 训练的过程就是最小化这个误差值
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss)) # 初始化的W和b是多少
for step in range(20): # 执行20次训练
sess.run(train)
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss)) # 输出训练好的W和b

打印每一次结果,如下图,随着迭代进行,训练的W、b越来越接近0.1、0.3,说明构建的回归模型确实学习到了之间建立的数据的规则。loss一开始很大,后来慢慢变小,说明模型表达效果随着迭代越来越好。

W = [-0.9676645] b = [0.] loss = 0.45196822

W = [-0.6281831] b = [0.29385352] loss = 0.17074569

W = [-0.39535886] b = [0.29584622] loss = 0.07962803

W = [-0.23685378] b = [0.2972129] loss = 0.03739688

W = [-0.12894464] b = [0.2981433] loss = 0.017823622

W = [-0.05548081] b = [0.29877672] loss = 0.008751821

W = [-0.00546716] b = [0.29920793] loss = 0.0045472304

W = [0.02858179] b = [0.2995015] loss = 0.0025984894

W = [0.05176209] b = [0.29970136] loss = 0.0016952885

W = [0.06754307] b = [0.29983744] loss = 0.0012766734

W = [0.07828666] b = [0.29993007] loss = 0.001082654

W = [0.08560082] b = [0.29999313] loss = 0.0009927301

W = [0.09058025] b = [0.30003607] loss = 0.0009510521

W = [0.09397022] b = [0.30006528] loss = 0.00093173544

W = [0.09627808] b = [0.3000852] loss = 0.00092278246

W = [0.09784925] b = [0.30009875] loss = 0.000918633

W = [0.09891889] b = [0.30010796] loss = 0.00091670983

W = [0.0996471] b = [0.30011424] loss = 0.0009158184

W = [0.10014286] b = [0.3001185] loss = 0.00091540517

W = [0.10048037] b = [0.30012143] loss = 0.0009152137

W = [0.10071015] b = [0.3001234] loss = 0.0009151251

注:以上内容为我学习唐宇迪老师的Tensorflow课程所做的笔记

用Tensorflow完成简单的线性回归模型的更多相关文章

  1. tensorflow入门(1):构造线性回归模型

    今天让我们一起来学习如何用TF实现线性回归模型.所谓线性回归模型就是y = W * x + b的形式的表达式拟合的模型. 我们先假设一条直线为 y = 0.1x + 0.3,即W = 0.1,b = ...

  2. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

  3. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  4. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  5. 【学习笔记】tensorflow实现一个简单的线性回归

    目录 准备知识 Tensorflow运算API 梯度下降API 简单的线性回归的实现 建立事件文件 变量作用域 增加变量显示 模型的保存与加载 自定义命令行参数 准备知识 Tensorflow运算AP ...

  6. 机器学习与Tensorflow(1)——机器学习基本概念、tensorflow实现简单线性回归

    一.机器学习基本概念 1.训练集和测试集 训练集(training set/data)/训练样例(training examples): 用来进行训练,也就是产生模型或者算法的数据集 测试集(test ...

  7. TensorFlow从0到1之TensorFlow实现简单线性回归(15)

    本节将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.edu/datasets/bos ...

  8. 线性回归模型的 MXNet 与 TensorFlow 实现

    本文主要探索如何使用深度学习框架 MXNet 或 TensorFlow 实现线性回归模型?并且以 Kaggle 上数据集 USA_Housing 做线性回归任务来预测房价. 回归任务,scikit-l ...

  9. TensorFlow简要教程及线性回归算法示例

    TensorFlow是谷歌推出的深度学习平台,目前在各大深度学习平台中使用的最广泛. 一.安装命令 pip3 install -U tensorflow --default-timeout=1800 ...

随机推荐

  1. HTML基础代码

    <!--注释内容,在浏览时不会显示--><!DOCTYPE HTML> <!--声明文档类型--><html> <!--头部内容:--> & ...

  2. orcal 数据库 maven架构 ssh框架 的全xml环境模版 及常见异常解决

    创建maven项目后,毫不犹豫,超简单傻瓜式搞定dependencies(pom.xml 就是maven的依赖管理),这样你就有了所有你要的包 <project xmlns="http ...

  3. C++ C# VC VC.net以及VC++有什么区别和联系?

    C/C++是编程语言,C是C++的爸爸,也就是说C++从C发展而来,而C++完全兼容C的语法.国际上有一个专门管理C++的机构,它们负责C++的标准制定. VC++是微软公司的C++编译环境,使用它可 ...

  4. 校内胡策 T9270 mjt树

    题目背景 从前森林里有一棵很大的mjt树,树上有很多小动物. 题目描述 mjt树上有 n 个房间,第 i 个房间住着 ai 只第bi 种小动物. 这n个房间用n-1条路连接起来,其中房间1位mjt树的 ...

  5. 双硬盘双系统win10+manjaro-kde搭建

    电脑sdd+hdd双硬盘,默认win10装在了sdd分区,uefi+gpt引导.现在想要在hdd中划分出一个分区安装manjaro,并在开机多重引导. 1. 制作安装盘 先去下载最新的镜像,最好在国内 ...

  6. 【Linux】文件、目录权限及归属

    访问权限: 可读(read):允许查看文件内容.显示目录列表 可写(write):允许修改文件内容,允许在目录中新建.移动.删除文件或子目录 可执行(execute):允许运行程序.切换目录 归属: ...

  7. PTA(BasicLevel)-1012 数字分类

    一 题目描述    给定一系列正整数,请按要求对数字进行分类,并输出以下 5 个数字: ​​ = 能被 5 整除的数字中所有偶数的和: ​​ = 将被 5 除后余 1 的数字按给出顺序进行交错求和,即 ...

  8. Centos7 安装 Python 的笔记

    Centos7 安装 Python 的笔记 注意:系统自带的Python2.7不要改动,最好也不要出错,不然yum之类的工具可能会出错. 安装Python3.7.0 TensorFlow对Python ...

  9. hadoop errors

    1.taskTracker和jobTracker 启动失败 2011-01-05 12:44:42,144 ERROR org.apache.hadoop.mapred.TaskTracker: Ca ...

  10. 20155202 实验四 Android开发基础

    20155202 实验四 Android开发基础 实验内容 1.基于Android Studio开发简单的Android应用并部署测试; 2.了解Android.组件.布局管理器的使用: 3.掌握An ...