pytorch-API实现线性回归
示例:
- import torch
- import torch.nn as nn
- from torch import optim
- class MyModel(nn.Module):
- def __init__(self):
- super(MyModel,self).__init__()
- self.lr = nn.Linear(1,1)
- def forward(self,x):
- return self.lr(x)
- #准备数据
- x= torch.rand([500,1])
- y_true = 3*x+0.8
- #1.实例化模型
- model = MyModel()
- #2.实例化优化器
- optimizer = optim.Adam(model.parameters(),lr=0.1)
- #3.实例化损失函数
- loss_fn = nn.MSELoss()
- for i in range(500):
- #4.梯度置为0
- optimizer.zero_grad()
- #5.调用模型得到预测值
- y_predict = model(x)
- #6.通过损失函数,计算得到损失
- loss = loss_fn(y_predict,y_true)
- #7.反向传播,计算梯度
- loss.backward()
- #8.更新参数
- optimizer.step()
- #打印部分数据
- if i%10 ==0:
- print(i,loss.item())
- for param in model.parameters():
- print(param.item())
使用英伟达显卡CUDA模式加速计算:
- import torch
import torch.nn as nn
from torch import optim
import time
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")- class MyModel(nn.Module):
- def __init__(self):
super(MyModel,self).__init__()
self.lr = nn.Linear(1,1)- def forward(self,x):
return self.lr(x)- #准备数据
- x= torch.rand([500,1]).to(device=device)
y_true = 3*x+0.8
#1.实例化模型
model = MyModel().to(device)
#2.实例化优化器
optimizer = optim.Adam(model.parameters(),lr=0.1)
#3.实例化损失函数
loss_fn = nn.MSELoss()
start = time.time()
for i in range(500):
#4.梯度置为0
optimizer.zero_grad()
#5.调用模型得到预测值
y_predict = model(x)
#6.通过损失函数,计算得到损失
loss = loss_fn(y_predict,y_true)
#7.反向传播,计算梯度
loss.backward()
#8.更新参数
optimizer.step()- #打印部分数据
if i%10 ==0:
print(i,loss.item())- for param in model.parameters():
print(param.item())- end = time.time()
- print(end-start)
pytorch-API实现线性回归的更多相关文章
- Pytorch手写线性回归
pytorch手写线性回归 import torch import matplotlib.pyplot as plt from matplotlib.animation import FuncAnim ...
- Pytorch 实现简单线性回归
Pytorch 实现简单线性回归 问题描述: 使用 pytorch 实现一个简单的线性回归. 受教育年薪与收入数据集 单变量线性回归 单变量线性回归算法(比如,$x$ 代表学历,$f(x)$ 代表收入 ...
- Spark(十一) -- Mllib API编程 线性回归、KMeans、协同过滤演示
本文测试的Spark版本是1.3.1 在使用Spark的机器学习算法库之前,需要先了解Mllib中几个基础的概念和专门用于机器学习的数据类型 特征向量Vector: Vector的概念是和数学中的向量 ...
- pytorch API中sgd.py的学习记录
参考:PyTorch与caffe中SGD算法实现的一点小区别 其中公式(3)(4)的符号有问题 变量对应表 程序 参考文章 buf v momentum μ d_p Δf(θ) lr ξ p θ
- python pytorch numpy DNN 线性回归模型
1.直接奉献代码,后期有入门更新,之前一直在学的是TensorFlow, import torch from torch.autograd import Variable import torch.n ...
- pytorch实现手动线性回归
import torch import matplotlib.pyplot as plt learning_rate = 0.1 #准备数据 #y = 3x +0.8 x = torch.randn( ...
- 【深度学习 01】线性回归+PyTorch实现
1. 线性回归 1.1 线性模型 当输入包含d个特征,预测结果表示为: 记x为样本的特征向量,w为权重向量,上式可表示为: 对于含有n个样本的数据集,可用X来表示n个样本的特征集合,其中行代表样本,列 ...
- 分别基于TensorFlow、PyTorch、Keras的深度学习动手练习项目
×下面资源个人全都跑了一遍,不会出现仅是字符而无法运行的状况,运行环境: Geoffrey Hinton在多次访谈中讲到深度学习研究人员不要仅仅只停留在理论上,要多编程.个人在学习中也体会到单单的看理 ...
- Keras vs. PyTorch
We strongly recommend that you pick either Keras or PyTorch. These are powerful tools that are enjoy ...
- ELMo解读(论文 + PyTorch源码)
ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...
随机推荐
- MFC之登录框的问题处理
1.在做登录框的时候,把登录框做成模态对话框,并且放在 主界面程序所在窗口打开之前.也就是放在主界面类的InitInstance()里.这样做就会在弹出主界面之前被登录框弹出模态框出来阻塞住. 1.但 ...
- Python IDE ——Anaconda+PyCharm的安装与配置
一 前言 最近莫名其妙地想学习一下Python,想着利用业余时间学习一下机器学习(或许仅仅是脑子一热吧).借着研究生期间对于PyCharm安装的印象,在自己的电脑上重新又安装了一遍.利用周末的一点时间 ...
- WeChat-SmallProgram:微信小程序中使用Async-await方法异步请求变为同步请求
微信小程序中有些 Api 是异步的,无法直接进行同步处理.例如:wx.request.wx.showToast.wx.showLoading 等.如果需要同步处理,可以使用如下方法: 提示:Async ...
- 非常诡异的IIS下由配置文件加上svg的mime头导致整个网站的静态文件访问报错误
调试了两天遇到一个非常诡异的问题 一个系统稳定运行了很多年,是用mvc5+WIN2008R2 + .NET 4.5 +IIS环境下运行,非常稳定,最近想迁移到一台新的服务器,为了少麻烦在阿里云上买了 ...
- Kubernets中获取客户端真实IP总结
1. 导言 绝大多数业务场景都是需要知道客户端IP的 在k8s中运行的业务项目,如何获取到客户端真实IP? 本文总结了通行的2种方式 要答案的直接看方式一.方式二和总结 SEO 关键字 nginx i ...
- Django-使用 include() 配置 URL
如果项目非常庞大,应用非常多,应用的 URL 都写在根 urls.py 配置文件中的话,会显的非常杂乱,还会出现名称冲突之类的问题,这样对开发整个项目是非常不利的. 可以这样解决,把每个应用的 URL ...
- MyBatis 学习笔记(1)
MyBatis 的基本构成 SqlSessionFactoryBuilder(构造器):它会根据配置信息或者代码来生成 SqlSessionFactory(工厂接口) SqlSessionFactor ...
- LeetCode48, 如何让矩阵原地旋转90度
本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是LeetCode第29篇,我们来看一道简单的矩阵旋转问题. 题意 题目的要求很简单,给定一个二维方形矩阵,要求返回矩阵旋转90度之后的 ...
- 如何连接到Oracle数据库?
如何连接到Oracle数据库? 使用SQL * Plus连接Oracle数据库服务器 SQL * Plus是交互式查询工具,我们在安装Oracle数据库服务器或客户端时会自动安装.SQL * Pl ...
- Hadoop(三)HDFS写数据的基本流程
HDFS写数据的流程 HDFS shell上传文件a.txt,300M 对文件分块,默认每块128M. shell向NameNode发送上传文件请求 NameNode检测文件系统目录树,看能否上传 N ...