pytorch如何使用GPU
在本文中,我将介绍简单如何使用GPU
pytorch是一个非常优秀的深度学习的框架,具有速度快,代码简洁,可读性强的优点。
我们使用pytorch做一个简单的回归。
首先准备数据

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torch.nn as nn
x = np.random.randn(1000, 1)*4
w = np.array([0.5,])
bias = -1.68

y_true = np.dot(x, w) + bias #真实数据
y = y_true + np.random.randn(x.shape[0])#加噪声的数据
#我们需要使用x和y,以及y_true回归出w和bias
1
2
3
4
5
6
7
8
9
10
11
12
定义回归网络的类

class LinearRression(nn.Module):
def __init__(self, input_size, out_size):
super(LinearRression, self).__init__()
self.x2o = nn.Linear(input_size, out_size)
#初始化
def forward(self, x):
return self.x2o(x)
#前向传递
1
2
3
4
5
6
7
8
接下来介绍将定义模型和优化器

batch_size = 10
model = LinearRression(1, 1)#回归模型
criterion = nn.MSELoss() #损失函数
#调用cuda
model.cuda()
criterion.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
losses = []
1
2
3
4
5
6
7
8
9
下面就是训(练)练(丹)了

for i in range(epoches):
loss = 0
optimizer.zero_grad()#清空上一步的梯度
idx = np.random.randint(x.shape[0], size=batch_size)
batch_cpu = Variable(torch.from_numpy(x[idx])).float()
batch = batch_cpu.cuda()#很重要

target_cpu = Variable(torch.from_numpy(y[idx])).float()
target = target_cpu.cuda()#很重要
output = model.forward(batch)
loss += criterion(output, target)
loss.backward()
optimizer.step()

if (i +1)%10 == 0:
print('Loss at epoch[%s]: %.3f' % (i, loss.data[0]))
losses.append(loss.data[0])

plt.plot(losses, '-or')
plt.xlabel("Epoch")
plt.xlabel("Loss")

plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
下面是训练结果

Loss at epoch[9]: 5.407
Loss at epoch[19]: 3.795
Loss at epoch[29]: 2.352
Loss at epoch[39]: 1.725
Loss at epoch[49]: 1.722
Loss at epoch[59]: 1.044
Loss at epoch[69]: 1.044
Loss at epoch[79]: 0.771
Loss at epoch[89]: 1.248
Loss at epoch[99]: 1.862
1
2
3
4
5
6
7
8
9
10
总结一下。要调用cuda执行代码需要一下步骤

model.cuda()
criterion.cuda()
1
2
3
以及

batch_cpu = Variable(torch.from_numpy(x[idx])).float()
batch = batch_cpu.cuda()
target_cpu = Variable(torch.from_numpy(y[idx])).float()
target = target_cpu.cuda()
1
2
3
4
就是将模型和输入数据变为cuda执行的
,简直超级方便,良心推荐一波pytorch

---------------------
作者:小川爱学习
来源:CSDN
原文:https://blog.csdn.net/wuichuan/article/details/66969315
版权声明:本文为博主原创文章,转载请附上博文链接!

Pytorch使用GPU的更多相关文章

  1. Pytorch多GPU训练

    Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...

  2. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/weixin_40087578/artic ...

  3. Pytorch指定GPU的方法总结

    Pytorch指定GPU的方法 改变系统变量 改变系统环境变量仅使目标显卡,编辑 .bashrc文件,添加系统变量 export CUDA_VISIBLE_DEVICES=0 #这里是要使用的GPU编 ...

  4. Ubuntu下安装pytorch(GPU版)

    我这里主要参考了:https://blog.csdn.net/yimingsilence/article/details/79631567 并根据自己在安装中遇到的情况做了一些改动. 先说明一下我的U ...

  5. [转] pytorch指定GPU

    查过好几次这个命令,总是忘,转一篇mark一下吧 转自:http://www.cnblogs.com/darkknightzh/p/6836568.html PyTorch默认使用从0开始的GPU,如 ...

  6. Pytorch多GPU并行处理

    可以参数2017coco detection 旷视冠军MegDet: MegDet 与 Synchronized BatchNorm PyTorch-Encoding官方文档对CGBN(cross g ...

  7. pytorch 多GPU处理过程

    多GPU的处理机制: 使用多GPU时,pytorch的处理逻辑是: 1.在各个GPU上初始化模型. 2.前向传播时,把batch分配到各个GPU上进行计算. 3.得到的输出在主GPU上进行汇总,计算l ...

  8. Pytorch 多 GPU 并行处理机制

    Pytorch 的多 GPU 处理接口是 torch.nn.DataParallel(module, device_ids),其中 module 参数是所要执行的模型,而 device_ids 则是指 ...

  9. 怎么用 pytorch 查看 GPU 信息

    如果你用的 Keras 或者 TensorFlow, 请移步 怎么查看keras 或者 tensorflow 正在使用的GPU In [1]: import torch In [2]: torch.c ...

随机推荐

  1. web前端学习(三)css学习笔记部分(7)-- 文字和字体相关样式、盒子相关样式、背景与边框相关样式

    12.  文字和字体相关样式 12.1  CSS3 给文字添加阴影 使用 text-shadow 属性给页面上的文字添加阴影效果,text-shadow 属性是在CSS2中定义的,在 CSS2.1 中 ...

  2. CSS-DOM的小知识(一)

    在DOM编程艺术中,CSS-DOM应用很广泛. 1.style属性 通过element.style.property可以获得元素的样式,但是style属性只能够返回内嵌样式,对于外部样式表的样式和he ...

  3. python 从数据库取回来的数据中文显示为乱码

    问题:从数据库取回来的数据,中文显示为乱码. 解决办法: 此处要指定charset为utf-8(一般数据库编码都是utf8),否则读取出的中文会乱码

  4. 2019-8-31-C#-对-byte-数组进行模式搜索

    title author date CreateTime categories C# 对 byte 数组进行模式搜索 lindexi 2019-08-31 16:55:58 +0800 2018-07 ...

  5. JQuery-- 获取元素的宽高、获取浏览器的宽高和垂直滚动距离

    * 能够使用jQuery设置尺寸 * .width() width * .innerWidth() width + padding * .outerWidth() width + padding + ...

  6. TP5动态路由配置好了但是报错was not found on this server的原因以及解决方法

    问题:The requested URL /xxxx.html was not found on this server 原因:apache的重写未开启,开启重写后,问题解决, 方法如下: apach ...

  7. OFBiz 16.11.03的直接部署、eclipse部署和IDEA部署

    一.在OFBiz官网下载最新的发行版本,也就是16.11.03版本. 下载地址:http://ofbiz.apache.org/download.html   点击页面Apache OFBiz 16. ...

  8. 某input元素值每隔三位添加逗号跟去掉逗号

    //每隔三位数字加一个逗号function moneyformat(s) {    var reg = /.*\..*/;    if (reg.test(s) == true) {        n ...

  9. hackerrank---Find a string

    题目链接 在字符串a中查找字符串b出现的次数...貌似不可以用a.count() 附上代码: a = raw_input().strip() b = raw_input().strip() cnt = ...

  10. 【Leetcode堆】数据流中的第K大元素(703)

    题目 设计一个找到数据流中第K大元素的类(class).注意是排序后的第K大元素,不是第K个不同的元素. 你的 KthLargest 类需要一个同时接收整数 k 和整数数组nums 的构造器,它包含数 ...