Pytorch实战学习(一):用Pytorch实现线性回归
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
P5--用Pytorch实现线性回归
建立模型四大步骤
一、Prepare dataset
mini-batch:x、y必须是矩阵
- ## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
- x_data = torch.Tensor([[1.0], [2.0], [3.0]])
- y_data = torch.Tensor([[2.0], [4.0], [6.0]])
二、Design model
1、重点是构造计算图
- ##Design Model
- ##构造类,继承torch.nn.Module类
- class LinearModel(torch.nn.Module):
- ## 构造函数,初始化对象
- def __init__(self):
- ##super调用父类
- super(LinearModel, self).__init__()
- ##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
- self.linear = torch.nn.Linear(1, 1)
- ## 构造函数,前馈运算
- def forward(self, x):
- ## w*x+b
- y_pred = self.linear(x)
- return y_pred
- model = LinearModel()
2、设置w的维度,后一层的神经元数量 X 前一层神经元数量
三、Construct Loss and Optimizer
- ##Construct Loss and Optimizer
- ##损失函数,传入y和y_presd
- criterion = torch.nn.MSELoss(size_average = False)
- ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
1、损失函数
2、优化器
可用不同的优化器进行测试对比
四、Training cycle
- ## Training cycle
- for epoch in range(100):
- y_pred = model(x_data)
- loss = criterion(y_pred, y_data)
- print(epoch, loss)
- ##梯度归零
- optimizer.zero_grad()
- ##反向传播
- loss.backward()
- ##更新
- optimizer.step()
完整代码
- import torch
- ## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
- x_data = torch.Tensor([[1.0], [2.0], [3.0]])
- y_data = torch.Tensor([[2.0], [4.0], [6.0]])
- ##Design Model
- ##构造类,继承torch.nn.Module类
- class LinearModel(torch.nn.Module):
- ## 构造函数,初始化对象
- def __init__(self):
- ##super调用父类
- super(LinearModel, self).__init__()
- ##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
- self.linear = torch.nn.Linear(1, 1)
- ## 构造函数,前馈运算
- def forward(self, x):
- ## w*x+b
- y_pred = self.linear(x)
- return y_pred
- model = LinearModel()
- ##Construct Loss and Optimizer
- ##损失函数,传入y和y_presd
- criterion = torch.nn.MSELoss(size_average = False)
- ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
- ## Training cycle
- for epoch in range(100):
- y_pred = model(x_data)
- loss = criterion(y_pred, y_data)
- print(epoch, loss)
- ##梯度归零
- optimizer.zero_grad()
- ##反向传播
- loss.backward()
- ##更新
- optimizer.step()
- ## Outpue weigh and bias
- print('w = ', model.linear.weight.item())
- print('b = ', model.linear.bias.item())
- ## Test Model
- x_test = torch.Tensor([[4.0]])
- y_test = model(x_test)
- print('y_pred = ', y_test.data)
运行结果
训练100次后,得到的 weight and bias,还有预测的y
Pytorch实战学习(一):用Pytorch实现线性回归的更多相关文章
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL
参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...
- pytorch的学习资源
安装:https://github.com/pytorch/pytorch 文档:http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial ...
- 【pytorch】学习笔记(三)-激励函数
[pytorch]学习笔记-激励函数 学习自:莫烦python 什么是激励函数 一句话概括 Activation: 就是让神经网络可以描述非线性问题的步骤, 是神经网络变得更强大 1.激活函数是用来加 ...
- 【pytorch】学习笔记(二)- Variable
[pytorch]学习笔记(二)- Variable 学习链接自莫烦python 什么是Variable Variable就好像一个篮子,里面装着鸡蛋(Torch 的 Tensor),里面的鸡蛋数不断 ...
随机推荐
- spring cloud alibaba - Nacos 下载安装
1.关于名字 前四个字母分别为Naming和Configuration的前两个字母,最后的s为Service 2.是什么 一个更易于构建云原生应用的动态服务发现,配置管理和服务管理中心.是注册中心和配 ...
- Activiti02流程基本功能使用
主要分为一下几个步骤: 1.画图 2.部署流程-把图的信息转入到数据表格中 3.创建流程实例-开始一个流程-实际发起了一个流程 4.执行任务:获取任务+完成任务 1.画图 画了一个简单的流程图,图形文 ...
- vuluhub_jangow-01-1.0.1
前言 靶机:jangow-01-1.0.1 攻击机:kali linux2022.4 靶机描述 打靶ing 靶机探测 使用nmap扫描网段 点击查看代码 ┌──(root㉿kali)-[/home/k ...
- 【WinForm】窗体之间传值的几种方式
方法1:设置公共静态变量传值 eg: 1 public partial class mianForm 2 { 3 //声明i 为公共静态变量 4 public static string i = &q ...
- Spring Boot整合JSP --CRUD
Springboot整合JSP spring boot与视图层次的整合: JSP 效率低 Thymeleaf java Server page 是Java提供的一种动态的网页技术,低层是Servlet ...
- Spring框架-AOP核心
Spring AOP AOP (Aspect Oriented Programming) 面向切面编程 OOP(Object Oriented Programming)面向对象编程,用对象的思想来完善 ...
- Python核对遥感影像批量下载情况的方法
本文介绍批量下载遥感影像时,利用Python实现已下载影像文件的核对,并自动生成未下载影像的下载链接列表的方法. 批量下载大量遥感影像数据对于GIS学生与从业人员可谓十分常见.然而,对于动辄成 ...
- .net 中的几种事务
在一个MIS系统中,没有用事务那就绝对是有问题的,要么就只有一种情况:你的系统实在是太小了,业务业务逻辑有只要一步执行就可以完成了.因此掌握事务处理的方法是很重要,进我的归类在.net中大致有以下4种 ...
- Git远程提交的冲突解决
先本地直接提交代码:git push origin master 如果别人在自己之前提交了修改,git会提示push失败,需要先pull远程代码:git pull origin/master (拉取远 ...
- PostgresSql更改字段位置后,数据库异常
SQL server的studio有一个功能,可以随意拖拽表字段,更改其位置并使之重新排序,有同事问起,Postgres是否也可以.Postgres每个字段的顺序是在系统表pg_attribute里面 ...