Pytorch:使用GPU训练
1.模型转为cuda
gpus = [0] #使用哪几个GPU进行训练,这里选择0号GPU
cuda_gpu = torch.cuda.is_available() #判断GPU是否存在可用
net = Net(12288, 25, 16, 6)
if(cuda_gpu):
net = torch.nn.DataParallel(net, device_ids=gpus).cuda() #将模型转为cuda类型
2.数据转为cuda
(minibatchX, minibatchY) = minibatch
minibatchX = minibatchX.astype(np.float32).T
minibatchY = minibatchY.astype(np.float32).T
if(cuda_gpu):
b_x = Variable(torch.from_numpy(minibatchX).cuda()) #将数据转为cuda类型
b_y = Variable(torch.from_numpy(minibatchY).cuda())
else:
b_x = Variable(torch.from_numpy(minibatchX))
b_y = Variable(torch.from_numpy(minibatchY))
3.输出数据去cuda,转为numpy
correct_prediction = sum(torch.max(output, 1)[1].data.squeeze() == torch.max(b_y, 1)[1].data.squeeze())
if(cuda_gpu):
correct_prediction = correct_prediction.cpu().numpy() #.cpu将cuda转为tensor类型,.numpy将tensor转为numpy类型
else:
correct_prediction = correct_prediction.numpy()
linux输入nvidia-smi,可以看到调用GPU成功!
Pytorch:使用GPU训练的更多相关文章
- Pytorch多GPU训练
Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...
- pytorch 多GPU训练总结(DataParallel的使用)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/weixin_40087578/artic ...
- pytorch 指定GPU训练
# 1: torch.cuda.set_device(1) # 2: device = torch.device("cuda:1") # 3:(官方推荐)import os os. ...
- pytorch 多GPU训练过程中出现ap=0情况
原因可能是pytorch 自带的BN bug:安装nvidia apex 可以解决: $ git clone https://github.com/NVIDIA/apex $ cd apex $ pi ...
- Pytorch中多GPU训练指北
前言 在数据越来越多的时代,随着模型规模参数的增多,以及数据量的不断提升,使用多GPU去训练是不可避免的事情.Pytorch在0.4.0及以后的版本中已经提供了多GPU训练的方式,本文简单讲解下使用P ...
- PyTorch Tutorials 4 训练一个分类器
%matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...
- Pytorch使用分布式训练,单机多卡
pytorch的并行分为模型并行.数据并行 左侧模型并行:是网络太大,一张卡存不了,那么拆分,然后进行模型并行训练. 右侧数据并行:多个显卡同时采用数据训练网络的副本. 一.模型并行 二.数据并行 数 ...
- MinkowskiEngine多GPU训练
MinkowskiEngine多GPU训练 目前,MinkowskiEngine通过数据并行化支持Multi-GPU训练.在数据并行化中,有一组微型批处理,这些微型批处理将被送到到网络的一组副本中. ...
- 使用Deeplearning4j进行GPU训练时,出错的解决方法
一.问题 使用deeplearning4j进行GPU训练时,可能会出现java.lang.UnsatisfiedLinkError: no jnicudnn in java.library.path错 ...
- tensorflow使用多个gpu训练
关于多gpu训练,tf并没有给太多的学习资料,比较官方的只有:tensorflow-models/tutorials/image/cifar10/cifar10_multi_gpu_train.py ...
随机推荐
- python 列表反转
反转: 将原列表反转,返回None: li = [1, 2, 3]li.reverse()print(li)# [3, 2, 1]1234不改变原列表,返回反转后的新列表: li = [1, 2, 3 ...
- SSD训练网络参数计算
一个预测层的网络结构如下所示: 可以看到,是由三个分支组成的,分别是"PriorBox"层,以及conf.loc的预测层,其中,conf与loc的预测层的参数是由PriorBox的 ...
- LeetCode:182.查找重复的电子邮箱
题目链接:https://leetcode-cn.com/problems/duplicate-emails/ 题目 编写一个 SQL 查询,查找 Person 表中所有重复的电子邮箱. 示例: +- ...
- 409 Conflict - PUT https://registry.npm.taobao.org/-/user/org.couchdb.user:zphtown - [conflict] User xxx already exists
解决方法cmd执行 npm config set registry https://registry.npmjs.org/ 为什么,参考此文档:https://blog.csdn.net/adc_go ...
- CSS选择器(通配符选择器、标签选择器、类选择器、id选择器、群组选择器、后代选择器、子元素选择器和相邻元素选择器)
通配符选择器 * 与任何元素匹配 派生选择器: 后代选择器(包含选择器):后代选择器可以选择作为元素后代的元素 A B 对A元素中的B元素应用样式 后代选择器中两个元素间的层次间隔可以是无 ...
- 采购订单保存生成PO号后增强点。
EXIT_SAPMM06E_013 这个增强可用于生成的PO后,调用外部接口把变更或生成的PO信息下发出去. 这里面的参数 I_EKKO 是新的抬头 I_EKKO_OLD 是更改前的抬头 XEKPO ...
- Java注解【三、注解的分类】
按运行机制分 源码注解 只在源码中存在 编译时注解 在class中依然存在,如@Deprecated 运行时注解 运行阶段起作用,如@Autowired 按来源分 JDK自带注解 三方注解 最常见 自 ...
- 常见排序&查询算法Java代码实现
1. 排序算法代码实现 /** * ascending sort * 外层循环边界条件:总共需要冒泡的轮数--每一轮都将最大或最小的数冒泡到最后 * 内层循环边界条件:冒泡数字移动的边界--最终数字需 ...
- Mysql(四)-1:单表查询
一 单表查询的语法 SELECT 字段1,字段2... FROM 表名 WHERE 条件 GROUP BY field HAVING 筛选 ORDER BY field LIMIT 限制条数 二 关键 ...
- STM32WB AHB总线、APB总线与外设
方框图: 如图所示: 1)APB1外设 2)APB2外设 3)AHB1外设 4)AHB2外设 5)AHB3外设 6)AHB4外设(ABH共享总线外设) 内存映射关系图: