运行代码:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D LR = 0.1
REAL_PARAMS = [1.2, 2.5]
INIT_PARAMS = [[5, 4],
[5, 1],
[2, 4.5]][2] x = np.linspace(-1, 1, 200, dtype=np.float32) # x data y_fun = lambda a, b: np.sin(b*np.cos(a*x))
tf_y_fun = lambda a, b: tf.sin(b*tf.cos(a*x)) noise = np.random.randn(200)/10
y = y_fun(*REAL_PARAMS) + noise # target # tensorflow graph
a, b = [tf.Variable(initial_value=p, dtype=tf.float32) for p in INIT_PARAMS]
pred = tf_y_fun(a, b)
mse = tf.reduce_mean(tf.square(y-pred))
train_op = tf.train.GradientDescentOptimizer(LR).minimize(mse) a_list, b_list, cost_list = [], [], []
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for t in range(400):
a_, b_, mse_ = sess.run([a, b, mse])
a_list.append(a_); b_list.append(b_); cost_list.append(mse_) # record parameter changes
result, _ = sess.run([pred, train_op]) # training # visualization codes:
print('a=', a_, 'b=', b_)
plt.figure(1)
plt.scatter(x, y, c='b') # plot data
plt.plot(x, result, 'r-', lw=2) # plot line fitting
# 3D cost figure
fig = plt.figure(2); ax = Axes3D(fig)
a3D, b3D = np.meshgrid(np.linspace(-2, 7, 30), np.linspace(-2, 7, 30)) # parameter space
cost3D = np.array([np.mean(np.square(y_fun(a_, b_) - y)) for a_, b_ in zip(a3D.flatten(), b3D.flatten())]).reshape(a3D.shape)
ax.plot_surface(a3D, b3D, cost3D, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'), alpha=0.5)
ax.scatter(a_list[0], b_list[0], zs=cost_list[0], s=300, c='r') # initial parameter place
ax.set_xlabel('a'); ax.set_ylabel('b')
ax.plot(a_list, b_list, zs=cost_list, zdir='z', c='r', lw=3) # plot 3D gradient descent
plt.show()

运行结果:

TensorFlow从入门到理解(六):可视化梯度下降的更多相关文章

  1. TensorFlow从入门到理解

    一.<莫烦Python>学习笔记: TensorFlow从入门到理解(一):搭建开发环境[基于Ubuntu18.04] TensorFlow从入门到理解(二):你的第一个神经网络 Tens ...

  2. TensorFlow从入门到理解(二):你的第一个神经网络

    运行代码: from __future__ import print_function import tensorflow as tf import numpy as np import matplo ...

  3. TensorFlow从入门到理解(五):你的第一个循环神经网络RNN(回归例子)

    运行代码: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIM ...

  4. TensorFlow从入门到理解(四):你的第一个循环神经网络RNN(分类例子)

    运行代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # set rando ...

  5. TensorFlow从入门到理解(三):你的第一个卷积神经网络(CNN)

    运行代码: from __future__ import print_function import tensorflow as tf from tensorflow.examples.tutoria ...

  6. TensorFlow从入门到理解(一):搭建开发环境【基于Ubuntu18.04】

    *注:教程及本文章皆使用Python3+语言,执行.py文件都是用终端(如果使用Python2+和IDE都会和本文描述有点不符) 一.安装,测试,卸载 TensorFlow官网介绍得很全面,很完美了, ...

  7. NN优化方法对照:梯度下降、随机梯度下降和批量梯度下降

    1.前言 这几种方法呢都是在求最优解中常常出现的方法,主要是应用迭代的思想来逼近.在梯度下降算法中.都是环绕下面这个式子展开: 当中在上面的式子中hθ(x)代表.输入为x的时候的其当时θ參数下的输出值 ...

  8. 超简单tensorflow入门优化程序&&tensorboard可视化

    程序1 任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解. 使用tensorflow编程实现: #-*- coding: utf-8 -*-) im ...

  9. TensorFlow学习——入门篇

    本文主要通过一个简单的 Demo 介绍 TensorFlow 初级 API 的使用方法,因为自己也是初学者,因此本文的目的主要是引导刚接触 TensorFlow 或者 机器学习的同学,能够从第一步开始 ...

随机推荐

  1. spring boot下MultipartHttpServletRequest如何提高上传文件大小的默认值

    前言: 上传下载功能算是一个非常常见的功能,如果使用MultipartHttpServletRequest来做上传功能. 不配置上传大小的话,默认是2M.在有些场景,这个肯定不能满足条件. 上传代码: ...

  2. Gym 101911E "Painting the Fence"(线段树区间更新+双端队列)

    传送门 题意: 庭院中有 n 个围栏,每个围栏上都被涂上了不同的颜色(数字表示): 有 m 条指令,每条指令给出一个整数 x ,你要做的就是将区间[ x第一次出现的位置 , x最后出现的位置 ]中的围 ...

  3. SCU-4527 NightMare2(Dijkstra+BFS) !!!错误题解!!!

    错解警告!!! 描述 可怜的RunningPhoton又做噩梦了..但是这次跟上次不大一样,虽然他又被困在迷宫里,又被装上了一个定时炸弹,但是值得高兴的是,他发现他身边有数不清的财宝,所以他如果能带着 ...

  4. css换行

    1. word-break:break-all;只对英文起作用,以字母作为换行依据 2. word-wrap:break-word; 只对英文起作用,以单词作为换行依据 3. white-space: ...

  5. 【3D动画建模设计工具】Maxon Cinema 4D Studio for Mac 20.0

    图标 Icon   软件介绍 Description Maxon Cinema 4D Studio R20 ,是由德国公司Maxon Computer一款适用于macOS系统的3D动画建模设计工具,是 ...

  6. nginx配置模板问题404

    nginx配置模板问题 一.nginx主配置文件如下 cat /etc/nginx/nginx.conf user nginx; worker_processes ; #error_log logs/ ...

  7. 读写锁ReadWriteLock

    为了提高性能,Java提供了读写锁,在读的地方使用读锁,在写的地方使用写锁,灵活控制,如果没有写锁的情况下,读是无阻塞的,在一定程度上提高了程序的执行效率. Java中读写锁有个接口java.util ...

  8. Redis 高可用分布式集群

    一,高可用 高可用(High Availability),是当一台服务器停止服务后,对于业务及用户毫无影响. 停止服务的原因可能由于网卡.路由器.机房.CPU负载过高.内存溢出.自然灾害等不可预期的原 ...

  9. mysql关联模糊查询他表字段

    如下:订单表关联了用户的id(多个),要根据用户名模糊查询订单信息,但是订单表只有id.创建视图用不着,咱也没权限.于是如下 SELECT * FROM ( SELECT cu.id AS 'id', ...

  10. 阿里云部署Web项目

    1.首先最基本的购买服务器和域名(学生党可以享受每月9块钱的优惠,不知道为什么,pc端不能购买,只能下载阿里云APP购买)  下载APP后打开:学生专区-学生特权-购买(我选择的是ubuntu,这个随 ...