TensorFlow线性回归
目录
数据可视化
梯度下降
结果可视化
|
数据可视化 |
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt # 随机生成1000个点,围绕在y=0.1x+0.3的直线周围
num_points = 1000
vectors_set = []
for i in range(num_points):
x1 = np.random.normal(0.0, 0.55)
y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
vectors_set.append([x1, y1]) # 生成一些样本
x_data = [v[0] for v in vectors_set]
y_data = [v[1] for v in vectors_set] plt.scatter(x_data,y_data,c='r')
plt.show()

|
梯度下降 |
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt # 随机生成1000个点,围绕在y=0.1x+0.3的直线周围
num_points = 1000
vectors_set = []
for i in range(num_points):
x1 = np.random.normal(0.0, 0.55)
y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
vectors_set.append([x1, y1]) # 生成一些样本
x_data = [v[0] for v in vectors_set]
y_data = [v[1] for v in vectors_set] # 生成1维的W矩阵,取值是[-1,1]之间的随机数
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
# 生成1维的b矩阵,初始值是0
b = tf.Variable(tf.zeros([1]), name='b')
# 经过计算得出预估值y
y = W * x_data + b # 以预估值y和实际值y_data之间的均方误差作为损失
loss = tf.reduce_mean(tf.square(y - y_data), name='loss')
# 采用梯度下降法来优化参数
optimizer = tf.train.GradientDescentOptimizer(0.5) #参数是学习率
# 训练的过程就是最小化这个误差值
train = optimizer.minimize(loss, name='train') sess = tf.Session() init = tf.global_variables_initializer()
sess.run(init) # 初始化的W和b是多少
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss))
# 执行20次训练
for step in range(20):
sess.run(train)
# 输出训练好的W和b
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss))
'''
W = [ 0.72134733] b = [ 0.] loss = 0.204532
W = [ 0.54246926] b = [ 0.31014919] loss = 0.0552976
W = [ 0.41924465] b = [ 0.30693138] loss = 0.029155
W = [ 0.33045709] b = [ 0.30471471] loss = 0.0155833
W = [ 0.26648441] b = [ 0.30311754] loss = 0.00853772
W = [ 0.22039121] b = [ 0.30196676] loss = 0.00488007
W = [ 0.18718043] b = [ 0.3011376] loss = 0.00298124
W = [ 0.16325161] b = [ 0.30054021] loss = 0.00199547
W = [ 0.14601055] b = [ 0.30010974] loss = 0.00148373
W = [ 0.13358814] b = [ 0.29979959] loss = 0.00121806
W = [ 0.12463761] b = [ 0.29957613] loss = 0.00108014
W = [ 0.11818863] b = [ 0.29941514] loss = 0.00100854
W = [ 0.11354206] b = [ 0.29929912] loss = 0.000971367
W = [ 0.11019413] b = [ 0.29921553] loss = 0.00095207
W = [ 0.10778191] b = [ 0.29915532] loss = 0.000942053
W = [ 0.10604387] b = [ 0.29911193] loss = 0.000936852
W = [ 0.10479159] b = [ 0.29908064] loss = 0.000934153
W = [ 0.1038893] b = [ 0.29905814] loss = 0.000932751
W = [ 0.10323919] b = [ 0.2990419] loss = 0.000932023
W = [ 0.10277078] b = [ 0.29903021] loss = 0.000931646
W = [ 0.10243329] b = [ 0.29902178] loss = 0.00093145
'''
|
结果可视化 |
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt # 随机生成1000个点,围绕在y=0.1x+0.3的直线周围
num_points = 1000
vectors_set = []
for i in range(num_points):
x1 = np.random.normal(0.0, 0.55)
y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
vectors_set.append([x1, y1]) # 生成一些样本
x_data = [v[0] for v in vectors_set]
y_data = [v[1] for v in vectors_set] # 生成1维的W矩阵,取值是[-1,1]之间的随机数
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
# 生成1维的b矩阵,初始值是0
b = tf.Variable(tf.zeros([1]), name='b')
# 经过计算得出预估值y
y = W * x_data + b # 以预估值y和实际值y_data之间的均方误差作为损失
loss = tf.reduce_mean(tf.square(y - y_data), name='loss')
# 采用梯度下降法来优化参数
optimizer = tf.train.GradientDescentOptimizer(0.5) #参数是学习率
# 训练的过程就是最小化这个误差值
train = optimizer.minimize(loss, name='train') sess = tf.Session() init = tf.global_variables_initializer()
sess.run(init) # 初始化的W和b是多少
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss))
# 执行20次训练
for step in range(20):
sess.run(train)
# 输出训练好的W和b
print ("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss)) plt.scatter(x_data,y_data,c='r')
plt.plot(x_data,sess.run(W)*x_data+sess.run(b))
plt.show()

TensorFlow线性回归的更多相关文章
- [tensorflow] 线性回归模型实现
在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...
- python,tensorflow线性回归Django网页显示Gif动态图
1.工程组成 2.urls.py """Django_machine_learning_linear_regression URL Configuration The ` ...
- tensorflow 线性回归解决 iris 2分类
# Combining Everything Together #---------------------------------- # This file will perform binary ...
- 1.tensorflow——线性回归
tensorflow 1.一切都要tf. 2.只有sess.run才能生效 import tensorflow as tf import numpy as np import matplotlib.p ...
- tensorflow 线性回归 iris
线性拟合
- TensorFlow简要教程及线性回归算法示例
TensorFlow是谷歌推出的深度学习平台,目前在各大深度学习平台中使用的最广泛. 一.安装命令 pip3 install -U tensorflow --default-timeout=1800 ...
- TensorFlow API 汉化
TensorFlow API 汉化 模块:tf 定义于tensorflow/__init__.py. 将所有公共TensorFlow接口引入此模块. 模块 app module:通用入口点脚本. ...
- tfboys——tensorflow模块学习(三)
tf.estimator模块 定义在:tensorflow/python/estimator/estimator_lib.py 估算器(Estimator): 用于处理模型的高级工具. 主要模块 ex ...
- TensorFlow — 相关 API
TensorFlow — 相关 API TensorFlow 相关函数理解 任务时间:时间未知 tf.truncated_normal truncated_normal( shape, mean=0. ...
随机推荐
- linux主机之间的SSH链接
一.什么是SSH连接 SSH为Secyre Shell的缩写,SSH 为建立在应用层基础上的安全协议.SSH 是目前较可靠,专为远程登录会话和其他网络服务提供安全性的协议.且SSH连接可以通过多种平台 ...
- linux 用户及文件权限管理
Linux 是一个可以实现多用户登陆的操作系统,比如“李雷”和“韩梅梅”都可以同时登陆同一台主机,他们共享一些主机的资源,但他们也分别有自己的用户空间,用于存放各自的文件.但实际上他们的文件都是放在同 ...
- cmd完成拷贝文件,并生成两个快捷脚本
@echo off@echo ------------------------------ @echo 正在创建目录 color 03if exist y:\00程序数据备份 ( md y:\00程序 ...
- button 文字图片上下/左右经常会用到,记录一下
上下: self.button.contentHorizontalAlignment = UIControlContentHorizontalAlignmentCenter;//使图片和文字水平 ...
- CentOS 7 安装 metasploit-framework
1 一键安装metasploit-framework apt-get install curl,wgetcurl https://raw.githubusercontent.com/rapid7/me ...
- 关于jq中input的value值clone的问题
如果想将input进行克隆,然后在后面显示出来并修改input里面的文字,这时就会发现一个问题,就是你克隆出来的value值始终是你克隆时的value,检查页面元素你就会发现,这时需要对克隆之后的in ...
- LVS+Heartbeat安装部署文档
LVS+Heartbeat安装部署文档 发表回复 所需软件: ipvsadm-1.24-10.x86_64.rpmheartbeat-2.1.3-3.el5.centos.x86_64.rpmhear ...
- mysql创建函数槽点
上机环境 mysql8.0 navicat for mysql 很有那么一批软件程序,要不做点脱了裤子放屁的事儿就觉得自己不够二进制似的,今儿写了一下午mysql函数,怎么都通过不了,上网一看 mys ...
- js/html 判断ie浏览器版本
1.html判断浏览器:<!--[if !IE]><!-->除ie外都可以识别<!--<![endif]--><!--[if IE]>所有ie可以 ...
- jsp九大内置对象及四个作用域【转】
1.Request对象 该对象封装了用户提交的信息,通过调用该对象相应的方法可以获取封装的信息,即使用该对象可以 获取用户提交的信息. 当Request对象获取客户提交的汉字字符时,会出现乱码问题,必 ...