【pytorch学习笔记】-搭建神经网络进行关系拟合

学习自莫烦python

目标

1.创建一些围绕y=x^2+噪声这个函数的散点

2.用神经网络模型来建立一个可以代表他们关系的线条

建立数据集

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维,x从-1到1,切分为100份
y=x.pow(2)+0.2*torch.rand(x.size())#创建一些围绕着这y=x^2的随机点的散点 # plt.scatter(x.data.numpy(),y.data.numpy())#画图
# plt.show() x,y=Variable(x),Variable(y)#构造神经网络要使用Variable类型

建立神经网络

1.继承torch.nn.Module模块

2.定义__init__函数,在初始化函数中定义输入层到隐藏层,从隐藏层再到输出层各个层的神经元个数

3.再一层层搭建(forward(x))层于层的关系链接

class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
super(Net, self).__init__()
self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出 def forward(self,x):#前向传递的过程
#正向传播输入值,神经网络输出预测值
x=F.relu(self.hidden(x))#激励函数加工一下
x=self.predict(x)#输出值预测值
return x

训练神经网络

1.定义训练工具optimizer,输入神经网络参数和学习效率

2.定义误差函数,使用均方差来计算实际值y和训练输出值之间的误差

3.每次训练向神经网络输入x,得到预测值,计算误差

4.注意要清空上一步的残余更新参数值

5.误差反向传播, 计算参数更新值

6.将参数更新值施加到 net 的 parameters 上

for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度

可视化训练过程

for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度
# 接着上面来
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)

完整代码

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#一维变二维
y=x.pow(2)+0.2*torch.rand(x.size()) # plt.scatter(x.data.numpy(),y.data.numpy())
# plt.show() x,y=Variable(x),Variable(y)#构造神经网络的是琥珀要使用Variable类型的 class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_ouput):#初始化信息
super(Net, self).__init__()
self.hidden=torch.nn.Linear(n_feature,n_hidden,n_ouput)#隐藏层线性输出
self.predict=torch.nn.Linear(n_hidden,n_ouput)#输出层线性输出 def forward(self,x):#前向传递的过程
#正向传播输入值,神经网络输出预测值
x=F.relu(self.hidden(x))#激励函数加工一下
x=self.predict(x)#输出值预测值
return x net=Net(n_feature=1,n_hidden=10,n_ouput=1)#输入值是一个,隐藏层有10个神经元,输出值为y值
print(net) optimizer=torch.optim.SGD(net.parameters(),lr=0.5)#输入神经网络的所有参数,学习效率,这个是训练工具
loss_func=torch.nn.MSELoss()#误差处理均方差 plt.ion() # 画图
plt.show() for t in range(200):#训练200次
prediction=net(x)#输入输入值
loss=loss_func(prediction,y)#计算误差预测值和真实值之间的误差,注意位置
optimizer.zero_grad()#梯度清零
loss.backward()#反向传递
optimizer.step()#优化梯度
# 接着上面来
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)

过程结果

中间过程省略一部分...

【pytorch】学习笔记(四)-搭建神经网络进行关系拟合的更多相关文章

  1. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  2. ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试   刚开始学习tf时,我们从 ...

  3. Go语言学习笔记四: 运算符

    Go语言学习笔记四: 运算符 这章知识好无聊呀,本来想跨过去,但没准有初学者要学,还是写写吧. 运算符种类 与你预期的一样,Go的特点就是啥都有,爱用哪个用哪个,所以市面上的运算符基本都有. 算术运算 ...

  4. kvm虚拟化学习笔记(四)之kvm虚拟机日常管理与配置

    KVM虚拟化学习笔记系列文章列表----------------------------------------kvm虚拟化学习笔记(一)之kvm虚拟化环境安装http://koumm.blog.51 ...

  5. MySql学习笔记四

    MySql学习笔记四 5.3.数据类型 数值型 整型 小数 定点数 浮点数 字符型 较短的文本:char, varchar 较长的文本:text, blob(较长的二进制数据) 日期型 原则:所选择类 ...

  6. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  7. ZooKeeper学习笔记四:使用ZooKeeper实现一个简单的分布式锁

    作者:Grey 原文地址: ZooKeeper学习笔记四:使用ZooKeeper实现一个简单的分布式锁 前置知识 完成ZooKeeper集群搭建以及熟悉ZooKeeperAPI基本使用 需求 当多个进 ...

  8. C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻

    前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...

  9. IOS学习笔记(四)之UITextField和UITextView控件学习

    IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...

随机推荐

  1. 关于项目在网页中运行部分jsp出现乱码(由request.getRequestDispatcher("XXX.jsp").forward(request, response)造成)的解决方法

    在写jsp的时候发现部分的jsp在浏览器预览时出现乱码,为一堆问号,如图: 当时问了同学,只有部分jsp会出现乱码,因为重新建一个jsp在运行就没有错误,可以显示出来,所以发现是jsp头部的错误,当新 ...

  2. mysql gis基本使用

    # 插入空间数据 INSERT INTO `t_pot` VALUES ('1', '北京', POINT(116.401394,39.916042)); INSERT INTO `t_pot` VA ...

  3. Inter IPP 绘图 ippi/ipps

    IPP的资料网上比较少,主要还是参考Inter官网和文档 官方文档ipps.pdf主要是对数据做处理,包括加减乘除.FFT.DFT等 文档ippi.pdf只要是对图像做处理,包括通道转换.图片处理等 ...

  4. CSS效果篇--这里有你想要的CSS3漂亮的自定义Checkbox各种复选框

    在原来有一篇文章写到了<CSS效果篇--纯CSS+HTML实现checkbox的思路与实例>.这篇文章主要写各种自定义的checkbox复选框,实现如图所示的复选框: 大致的html代码都 ...

  5. java编程出现的错误对应的解决方法

    error: could not open D:\java\jre1.8\lib\amd64\jvm.cfg 解决方法:把java的环境变量%JAVA_HOME%/bin上移到最上面 优化 查看网页源 ...

  6. better-scroll 介绍

    碰到一个项目,应该遵守两大规则: 1. 不要让项目产生过多的第三方依赖 2. 增强组件的应用率 尽可能的将东西写在组件里面,尽可能的将数据写活,通过组件通信来进行数据转换,用到的依赖处理,我们可以通过 ...

  7. Python全栈开发第5天作业

    作业一:1) 将用户信息数据库文件和组信息数据库文件纵向合并为一个文件/1.txt(覆盖) 2) 将用户信息数据库文件和用户密码数据库文件纵向合并为一个文件/2.txt(追加) 3) 将/1.txt. ...

  8. SDK location not found. Define location with sdk.dir in the local.properties file or with an ANDROID_HOME environment variable.

    问题描述: 已经安装了android-sdk 和gradle环境,并配置了环境变量,如下所示: android环境 root@wangju-HP--G4:/home/wangju/Desktop/5i ...

  9. flask_sqlalchemy的session线程安全源码解读

    flask_sqlalchemy是如何在多线程中对数据库操作不相互影响 数据库操作隔离 结论:使用scoped_session实现数据库操作隔离 flask的api.route()接收一个请求,就会创 ...

  10. 转: Android 设备的远程调试入门

    从 Windows.Mac 或 Linux 计算机远程调试 Android 设备上的实时内容. 本教程将向您展示如何: 设置您的 Android 设备进行远程调试,并从开发计算机上发现设备. 从您的开 ...