示例:

  1. import torch
  2. import torch.nn as nn
  3. from torch import optim
  4.  
  5. class MyModel(nn.Module):
  6.  
  7. def __init__(self):
  8. super(MyModel,self).__init__()
  9. self.lr = nn.Linear(1,1)
  10.  
  11. def forward(self,x):
  12. return self.lr(x)
  13.  
  14. #准备数据
  15.  
  16. x= torch.rand([500,1])
  17. y_true = 3*x+0.8
  18. #1.实例化模型
  19. model = MyModel()
  20. #2.实例化优化器
  21. optimizer = optim.Adam(model.parameters(),lr=0.1)
  22. #3.实例化损失函数
  23. loss_fn = nn.MSELoss()
  24.  
  25. for i in range(500):
  26. #4.梯度置为0
  27. optimizer.zero_grad()
  28. #5.调用模型得到预测值
  29. y_predict = model(x)
  30. #6.通过损失函数,计算得到损失
  31. loss = loss_fn(y_predict,y_true)
  32. #7.反向传播,计算梯度
  33. loss.backward()
  34. #8.更新参数
  35. optimizer.step()
  36.  
  37. #打印部分数据
  38. if i%10 ==0:
  39. print(i,loss.item())
  40.  
  41. for param in model.parameters():
  42. print(param.item())

  

使用英伟达显卡CUDA模式加速计算:

  1. import torch
    import torch.nn as nn
    from torch import optim
    import time
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  2.  
  3. class MyModel(nn.Module):
  4.  
  5. def __init__(self):
    super(MyModel,self).__init__()
    self.lr = nn.Linear(1,1)
  6.  
  7. def forward(self,x):
    return self.lr(x)
  8.  
  9. #准备数据
  10.  
  11. x= torch.rand([500,1]).to(device=device)
    y_true = 3*x+0.8
    #1.实例化模型
    model = MyModel().to(device)
    #2.实例化优化器
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    #3.实例化损失函数
    loss_fn = nn.MSELoss()
    start = time.time()
    for i in range(500):
    #4.梯度置为0
    optimizer.zero_grad()
    #5.调用模型得到预测值
    y_predict = model(x)
    #6.通过损失函数,计算得到损失
    loss = loss_fn(y_predict,y_true)
    #7.反向传播,计算梯度
    loss.backward()
    #8.更新参数
    optimizer.step()
  12.  
  13. #打印部分数据
    if i%10 ==0:
    print(i,loss.item())
  14.  
  15. for param in model.parameters():
    print(param.item())
  16.  
  17. end = time.time()
  18.  
  19. print(end-start)

  

pytorch-API实现线性回归的更多相关文章

  1. Pytorch手写线性回归

    pytorch手写线性回归 import torch import matplotlib.pyplot as plt from matplotlib.animation import FuncAnim ...

  2. Pytorch 实现简单线性回归

    Pytorch 实现简单线性回归 问题描述: 使用 pytorch 实现一个简单的线性回归. 受教育年薪与收入数据集 单变量线性回归 单变量线性回归算法(比如,$x$ 代表学历,$f(x)$ 代表收入 ...

  3. Spark(十一) -- Mllib API编程 线性回归、KMeans、协同过滤演示

    本文测试的Spark版本是1.3.1 在使用Spark的机器学习算法库之前,需要先了解Mllib中几个基础的概念和专门用于机器学习的数据类型 特征向量Vector: Vector的概念是和数学中的向量 ...

  4. pytorch API中sgd.py的学习记录

    参考:PyTorch与caffe中SGD算法实现的一点小区别 其中公式(3)(4)的符号有问题 变量对应表 程序 参考文章 buf v momentum μ d_p Δf(θ) lr ξ p θ

  5. python pytorch numpy DNN 线性回归模型

    1.直接奉献代码,后期有入门更新,之前一直在学的是TensorFlow, import torch from torch.autograd import Variable import torch.n ...

  6. pytorch实现手动线性回归

    import torch import matplotlib.pyplot as plt learning_rate = 0.1 #准备数据 #y = 3x +0.8 x = torch.randn( ...

  7. 【深度学习 01】线性回归+PyTorch实现

    1. 线性回归 1.1 线性模型 当输入包含d个特征,预测结果表示为: 记x为样本的特征向量,w为权重向量,上式可表示为: 对于含有n个样本的数据集,可用X来表示n个样本的特征集合,其中行代表样本,列 ...

  8. 分别基于TensorFlow、PyTorch、Keras的深度学习动手练习项目

    ×下面资源个人全都跑了一遍,不会出现仅是字符而无法运行的状况,运行环境: Geoffrey Hinton在多次访谈中讲到深度学习研究人员不要仅仅只停留在理论上,要多编程.个人在学习中也体会到单单的看理 ...

  9. Keras vs. PyTorch

    We strongly recommend that you pick either Keras or PyTorch. These are powerful tools that are enjoy ...

  10. ELMo解读(论文 + PyTorch源码)

    ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...

随机推荐

  1. MFC之登录框的问题处理

    1.在做登录框的时候,把登录框做成模态对话框,并且放在 主界面程序所在窗口打开之前.也就是放在主界面类的InitInstance()里.这样做就会在弹出主界面之前被登录框弹出模态框出来阻塞住. 1.但 ...

  2. Python IDE ——Anaconda+PyCharm的安装与配置

    一 前言 最近莫名其妙地想学习一下Python,想着利用业余时间学习一下机器学习(或许仅仅是脑子一热吧).借着研究生期间对于PyCharm安装的印象,在自己的电脑上重新又安装了一遍.利用周末的一点时间 ...

  3. WeChat-SmallProgram:微信小程序中使用Async-await方法异步请求变为同步请求

    微信小程序中有些 Api 是异步的,无法直接进行同步处理.例如:wx.request.wx.showToast.wx.showLoading 等.如果需要同步处理,可以使用如下方法: 提示:Async ...

  4. 非常诡异的IIS下由配置文件加上svg的mime头导致整个网站的静态文件访问报错误

    调试了两天遇到一个非常诡异的问题 一个系统稳定运行了很多年,是用mvc5+WIN2008R2  + .NET 4.5 +IIS环境下运行,非常稳定,最近想迁移到一台新的服务器,为了少麻烦在阿里云上买了 ...

  5. Kubernets中获取客户端真实IP总结

    1. 导言 绝大多数业务场景都是需要知道客户端IP的 在k8s中运行的业务项目,如何获取到客户端真实IP? 本文总结了通行的2种方式 要答案的直接看方式一.方式二和总结 SEO 关键字 nginx i ...

  6. Django-使用 include() 配置 URL

    如果项目非常庞大,应用非常多,应用的 URL 都写在根 urls.py 配置文件中的话,会显的非常杂乱,还会出现名称冲突之类的问题,这样对开发整个项目是非常不利的. 可以这样解决,把每个应用的 URL ...

  7. MyBatis 学习笔记(1)

    MyBatis 的基本构成 SqlSessionFactoryBuilder(构造器):它会根据配置信息或者代码来生成 SqlSessionFactory(工厂接口) SqlSessionFactor ...

  8. LeetCode48, 如何让矩阵原地旋转90度

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是LeetCode第29篇,我们来看一道简单的矩阵旋转问题. 题意 题目的要求很简单,给定一个二维方形矩阵,要求返回矩阵旋转90度之后的 ...

  9. 如何连接到Oracle数据库?

    如何连接到Oracle数据库?   使用SQL * Plus连接Oracle数据库服务器 SQL * Plus是交互式查询工具,我们在安装Oracle数据库服务器或客户端时会自动安装.SQL * Pl ...

  10. Hadoop(三)HDFS写数据的基本流程

    HDFS写数据的流程 HDFS shell上传文件a.txt,300M 对文件分块,默认每块128M. shell向NameNode发送上传文件请求 NameNode检测文件系统目录树,看能否上传 N ...