《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

P5--用Pytorch实现线性回归

建立模型四大步骤

一、Prepare dataset

mini-batch:x、y必须是矩阵

## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

二、Design model

1、重点是构造计算图

##Design Model

##构造类,继承torch.nn.Module类
class LinearModel(torch.nn.Module):
## 构造函数,初始化对象
def __init__(self):
##super调用父类
super(LinearModel, self).__init__()
##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
self.linear = torch.nn.Linear(1, 1) ## 构造函数,前馈运算
def forward(self, x):
## w*x+b
y_pred = self.linear(x)
return y_pred model = LinearModel()

2、设置w的维度,后一层的神经元数量 X 前一层神经元数量

三、Construct Loss and Optimizer

##Construct Loss and Optimizer

##损失函数,传入y和y_presd
criterion = torch.nn.MSELoss(size_average = False) ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

1、损失函数

2、优化器

可用不同的优化器进行测试对比

四、Training cycle

## Training cycle

for epoch in range(100):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss) ##梯度归零
optimizer.zero_grad()
##反向传播
loss.backward()
##更新
optimizer.step()

完整代码

import torch

## Prepare Dataset:mini-batch, X、Y是3X1的Tensor
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]]) ##Design Model ##构造类,继承torch.nn.Module类
class LinearModel(torch.nn.Module):
## 构造函数,初始化对象
def __init__(self):
##super调用父类
super(LinearModel, self).__init__()
##构造对象,Linear Unite,包含两个Tensor:weight和bias,参数(1, 1)是w的维度
self.linear = torch.nn.Linear(1, 1) ## 构造函数,前馈运算
def forward(self, x):
## w*x+b
y_pred = self.linear(x)
return y_pred model = LinearModel() ##Construct Loss and Optimizer ##损失函数,传入y和y_presd
criterion = torch.nn.MSELoss(size_average = False) ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) ## Training cycle for epoch in range(100):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss) ##梯度归零
optimizer.zero_grad()
##反向传播
loss.backward()
##更新
optimizer.step() ## Outpue weigh and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item()) ## Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

运行结果

训练100次后,得到的 weight and bias,还有预测的y

Pytorch实战学习(一):用Pytorch实现线性回归的更多相关文章

  1. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  2. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  3. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  4. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  5. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  6. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  7. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  8. pytorch的学习资源

    安装:https://github.com/pytorch/pytorch 文档:http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial ...

  9. 【pytorch】学习笔记(三)-激励函数

    [pytorch]学习笔记-激励函数 学习自:莫烦python 什么是激励函数 一句话概括 Activation: 就是让神经网络可以描述非线性问题的步骤, 是神经网络变得更强大 1.激活函数是用来加 ...

  10. 【pytorch】学习笔记(二)- Variable

    [pytorch]学习笔记(二)- Variable 学习链接自莫烦python 什么是Variable Variable就好像一个篮子,里面装着鸡蛋(Torch 的 Tensor),里面的鸡蛋数不断 ...

随机推荐

  1. 深入解读.NET MAUI音乐播放器项目(一):概述与架构

    系列文章将分步解读音乐播放器核心业务及代码: 深入解读.NET MAUI音乐播放器项目(一):概述与架构 深入解读.NET MAUI音乐播放器项目(二):播放内核 深入解读.NET MAUI音乐播放器 ...

  2. 计算机网络基础07 DNS概述

    1 什么是DNS Domain Name System(域名系统),它是一个应用层的服务.它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便地访问互联网.当前,对于每一级域名长度的限制是 ...

  3. 【ccc】为了ds的ccc2

    作业: #include <stdio.h> #include<string.h> int main(){ char s[100]; gets(s); int len; len ...

  4. Rpc-实现Zookeeper注册中心

    1.前言 本文章是笔主在声哥的手写RPC框架的学习下,对注册中心的一个拓展.因为声哥某些部分没有保留拓展性,所以本文章的项目与声哥的工程有部分区别,核心内容在Curator的注册发现与注销,思想看准即 ...

  5. Hugging Face 每周速递: 扩散模型课程完成中文翻译,有个据说可以教 ChatGPT 看图的模型开源了

    每一周,我们的同事都会向社区的成员们发布一些关于 Hugging Face 相关的更新,包括我们的产品和平台更新.社区活动.学习资源和内容更新.开源库和模型更新等,我们将其称之为「Hugging Ne ...

  6. RocketMQ - 生产者消息发送流程

    RocketMQ客户端的消息发送通常分为以下3层 业务层:通常指直接调用RocketMQ Client发送API的业务代码. 消息处理层:指RocketMQ Client获取业务发送的消息对象后,一系 ...

  7. 【NOIP2017提高组正式赛】列队

    题解 本题的解法是丰富多彩的! 线段树做法是极好的 代码非常之少 一个很显然的想法是维护 \(n+1\) 颗线段树 那要怎么维护才能不爆空间呢? 我们发现尽管 \(n \times m\) 那么大 但 ...

  8. Blender减面修改器

    推荐:使用 NSDT场景设计器 快速搭建 3D场景. 使用Decimate修改器的目的是减少雕刻或 3D 扫描模型的面数. 要使用抽取修改器,请转到对象模式并选择要减少面数的任何模型. 在对象模式中选 ...

  9. linux系统下,添加硬盘并挂载到操作系统的shell 脚本范例

    #!/bin/sh #新添加硬盘挂载到操作系统 pvcreate /dev/sdb   / / 一般新添加硬盘都是识别为sdb,当然,也不一定,要具体情况具体分析. vgcreate datavg / ...

  10. IntelliJ IDEA 程序运行的控制台乱码

    参考:https://blog.csdn.net/zp357252539/article/details/124614007 上方导航栏"Run→Edit Configurations-&q ...