莫烦theano学习自修第七天【回归结果可视化】
1.代码实现
from __future__ import print_function import theano import theano.tensor as T import numpy as np import matplotlib.pyplot as plt class Layer(object): def __init__(self, inputs, in_size, out_size, activation_function=None): self.W = theano.shared(np.random.normal(0, 1, (in_size, out_size))) self.b = theano.shared(np.zeros((out_size, )) + 0.1) self.Wx_plus_b = T.dot(inputs, self.W) + self.b self.activation_function = activation_function if activation_function is None: self.outputs = self.Wx_plus_b else: self.outputs = self.activation_function(self.Wx_plus_b) # Make up some fake data x_data = np.linspace(-1, 1, 300)[:, np.newaxis] noise = np.random.normal(0, 0.05, x_data.shape) y_data = np.square(x_data) - 0.5 + noise # y = x^2 - 0.5 # show the fake data #plt.scatter(x_data, y_data) plt.show() # determine the inputs dtype x = T.dmatrix("x") y = T.dmatrix("y") # add layers l1 = Layer(x, 1, 10, T.nnet.relu) l2 = Layer(l1.outputs, 10, 1, None) # compute the cost cost = T.mean(T.square(l2.outputs - y)) # compute the gradients gW1, gb1, gW2, gb2 = T.grad(cost, [l1.W, l1.b, l2.W, l2.b]) # apply gradient descent learning_rate = 0.05 train = theano.function( inputs=[x, y], outputs=[cost], updates=[(l1.W, l1.W - learning_rate * gW1), (l1.b, l1.b - learning_rate * gb1), (l2.W, l2.W - learning_rate * gW2), (l2.b, l2.b - learning_rate * gb2)]) # prediction predict = theano.function(inputs=[x], outputs=l2.outputs) # plot the real data fig = plt.figure() ax = fig.add_subplot(1,1,1) ax.scatter(x_data, y_data) plt.ion() plt.show() for i in range(1000): # training err = train(x_data, y_data) if i % 50 == 0: # to visualize the result and improvement try: ax.lines.remove(lines[0]) except Exception: pass prediction_value = predict(x_data) # plot the prediction lines = ax.plot(x_data, prediction_value, 'r-', lw=5) plt.pause(.5)
结果:
莫烦theano学习自修第七天【回归结果可视化】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...
- 莫烦theano学习自修第八天【分类问题】
1. 代码实现 from __future__ import print_function import numpy as np import theano import theano.tensor ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第一天【常量和矩阵的运算】
1. 代码实现如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ # 导入numpy模块,因为numpy是常用的计算模块 import numpy as ...
- 莫烦sklearn学习自修第七天【交叉验证】
1. 什么是交叉验证 所谓交叉验证指的是将样本分为两组,一组为训练样本,一组为测试样本:对于哪些数据分为训练样本,哪些数据分为测试样本,进行多次拆分,每次将整个样本进行不同的拆分,对这些不同的拆分每个 ...
随机推荐
- 【转】编写微信聊天机器人4《聊天精灵WeChatGenius》:实时获取到微信聊天消息,hook数据库插入操作。
接上篇,使用Xposed来hook微信,找到微信进程:https://blog.csdn.net/weixin_42127613/article/details/81839537 既然已经找到了微信进 ...
- 使用python进行utf9编码和解码
在2005年4月1日(也就是愚人节),IEEE的rfc4042文件规定了utf9和utf18这2个所谓的Unicode的高效转换格式. 具体的格式说明,有兴趣的话点击上面的rfc4042链接去观看. ...
- Java中的并发工具类(CountDownLatch、CyclicBarrier、Semaphore、Exchanger)
在JDK的并发包里提供了很多有意思的并发工具类.CountDownLatch.CyclicBarrier和Semaphore 工具类提供了一种并发流程控制的手段,Exchanger 工具类则提供了在线 ...
- A2D JS框架 - Web API CSRF保护实现
这次自己实现了类似jQuery中ajax调用的方法,并且针对RESTFul进行了改造和集成,实现的A2D AJAX接口如下: $.ajax.RESTFulGetCollection("/ap ...
- 使用C#创建SQLite控制台应用程序
本文属于原创,转载请注明出处,谢谢! 一.开发环境 操作系统:Windows 10 X64 开发环境:VS2015 编程语言:C# .NET版本:.NET Framework 4.0 目标平台:X86 ...
- eclipse 常用配置
一.内置tomcat配置 解决eclipse 内置tomcat 与本地tomcat 端口冲突 传送门:http://www.cnblogs.com/tweet/p/7568979.html 二.字体设 ...
- js中布尔值为false的六种情况
下面6种值转化为布尔值时为false,其他转化都为true 1.undefined(未定义,找不到值时出现) 2.null(代表空值) 3.false(布尔值的false,字符串"false ...
- SNMP 获取交换机端口相关信息
原文地址:https://blog.csdn.net/ysdaniel/article/details/37927541 我们想用snmpwalk查看网络设备的端口,MIB库中相关定义的信息如下: [ ...
- i++ 相比 ++i 哪个更高效?为什么?
++i的效率高些,++i在运算过程中不产生临时对象,返回的就是i,是个左值,类似++i=1这样的表达式是合法的,而i++在运算的过程中会产生临时对象,返回的是零时对象的值,是个右值,像i++=1这样的 ...
- UVA - 12716 - 异或序列
求满足GCD(a,b) = a XOR b; 其中1<=b <=a<=n. 首先做这道题需要知道几个定理: 异或:a XOR b = c 那么 a XOR c = b; 那么我们令G ...