将Pytorch模型从CPU转换成GPU
1. 如何进行迁移
对模型和相应的数据进行.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。从而可以通过GPU来进行运算了。
1.1 判定使用GPU
下载了对应的GPU版本的Pytorch之后,要确保GPU是可以进行使用的,通过torch.cuda.is_available()
的返回值来进行判断。
通过torch.cuda.device_count()
可以获得能够使用的GPU数量。其他就不多赘述了。
常常通过如下判定来写可以跑在GPU和CPU上的通用模型:
if torch.cuda.is_available():
ten1 = ten1.cuda()
MyModel = MyModel.cuda()
2. 对应数据的迁移
2.1 将Tensor迁移到显存中去
不论是什么类型的Tensor(FloatTensor或者是LongTensor等等),一律直接使用方法.cuda()即可。
例如:
ten1 = torch.FloatTensor(2) ten1_cuda = ten1.cuda()
如果要将显存中的数据复制到内存中,则对cuda数据类型使用.cpu()
方法即可。
2.2 将Variable迁移到显存中去
在模型中,我们最常使用的是Variable这个容器来装载使用数据。主要是由于Variable可以进行反向传播来进行自动求导。
同样地,要将Variable迁移到显存中,同样只需要使用.cuda()
即可实现。
这里有一个小疑问,对Variable直接使用.cuda
和对Tensor进行.cuda
然后再放置到Variable中结果是否一致呢。答案是肯定的。
ten1 = torch.FloatTensor(2)
>>> 6.1101e+24
4.5659e-41
[torch.FloatTensor of size 2]
ten1_cuda = ten1.cuda()
>>>> 6.1101e+24
4.5659e-41
[torch.cuda.FloatTensor of size 2 (GPU 0)]
V1_cpu = autograd.Variable(ten1)
>>>> Variable containing:
6.1101e+24
4.5659e-41
[torch.FloatTensor of size 2]
V2 = autograd.Variable(ten1_cuda)
>>>> Variable containing:
6.1101e+24
4.5659e-41
[torch.cuda.FloatTensor of size 2 (GPU 0)]
V1 = V1_cpu.cuda()
>>>> Variable containing:
6.1101e+24
4.5659e-41
[torch.cuda.FloatTensor of size 2 (GPU 0)]
最终我们能发现他们都能够达到相同的目的,但是他们完全一样了吗?我们使用V1 is V2
发现,结果是否定的。
对于V1,我们是直接对Variable进行操作的,这样子V1的.grad_fn
中会记录下创建的方式。因此这二者并不是完全相同的。
2.3 数据迁移小结
.cuda()
操作默认使用GPU 0也就是第一张显卡来进行操作。当我们想要存储在其他显卡中时可以使用.cuda(<显卡号数>)
来将数据存储在指定的显卡中。还有很多种方式,具体参考官方文档。
3. 模型迁移
模型的迁移这里指的是torch.nn下面的一些网络模型以及自己创建的模型迁移到GPU上去。
上面讲了使用.cuda()
即可将数据从内存中移植到显存中去。
对于模型来说,也是同样的方式,我们使用.cuda
来将网络放到显存上去。
3.1 torch.nn下的基本模型迁移
我们很惊奇地发现对于模型来说,不像数据那样使用了.cuda()
之后会改变其的数据类型。模型看起来没有任何的变化。
但是他真的没有改变吗。
我们将data1
投入linear_cuda
中去可以发现,系统会报错,而将.cuda
之后的data2投入linear_cuda
才能正常工作。并且输出的也是具有cuda的数据类型。
那是怎么一回事呢?
这是因为这些所谓的模型,其实也就是对输入参数做了一些基本的矩阵运算。所以我们对模型.cuda()
实际上也相当于将模型使用到的参数存储到了显存上去。
对于上面的例子,我们可以通过观察参数来发现区别所在。
linear.weight
>>>> Parameter containing:
-0.6847 0.2149
-0.5473 0.6863
[torch.FloatTensor of size 2x2]
linear_cuda.weight
>>>> Parameter containing:
-0.6847 0.2149
-0.5473 0.6863
[torch.cuda.FloatTensor of size 2x2 (GPU 0)]
3.2 自己模型的迁移
对于自己创建的模型类,由于继承了torch.nn.Module
,则可同样使用.cuda()
来将模型中用到的所有参数都存储到显存中去。
这里笔者曾经有一个疑问:当我们对模型存储到显存中去之后,那么这个模型中的方法后面所创建出来的Tensor是不是都会默认变成cuda的数据类型。答案是否定的。具体操作留给读者自己去实现。
3.3 模型小结
对于模型而言,我们可以将其看做是一种类似于Variable的容器。我们对它进行.cuda()
处理,是将其中的参数放到显存上去(因为实际使用的时候也是通过这些参数做运算)。
https://blog.csdn.net/qq_28444159/article/details/78781201
将Pytorch模型从CPU转换成GPU的更多相关文章
- h5模型文件转换成pb模型文件
本文主要记录Keras训练得到的.h5模型文件转换成TensorFlow的.pb文件 #*-coding:utf-8-* """ 将keras的.h5的模型文件,转换 ...
- 【tensorflow-v2.0】如何将模型转换成tflite模型
前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- DEX-6-caffe模型转成pytorch模型办法
在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...
- 利用反射将Datatable、SqlDataReader转换成List模型
1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...
- Linux下ffmpeg添加Facebook/transform代码块实现将全景视频的球模型转换成立方体模型
Facebook事实上已开始在平台中支持360度全景视频的流播,但公司对此并不满足.其工程师更是基于锥体几何学设计出了一套全新的视频编码,号称最高能将全景视频的文件大小减少80%.(VR最新突破:全景 ...
- iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型
在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...
- 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型
在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...
- .net 数据源DataSet 转换成模型
/// <summary> /// DataSet转换成model 自动赋值返回集合 /// </summary> /// <typeparam name="T ...
随机推荐
- mysql客户端工具
MySQL 数据库不仅提供了数据库的服务器端应用程序,同时还提供了大量的客户端工具程序,如 mysql,mysqladmin,mysqldump 等等,都是大家所熟悉的.虽然有些人对这些工具的功能都已 ...
- 用ChrootDirectory限制SFTP登录的用户只能访问指定目录且不能进行ssh登录
创建不能ssh登录的用户sftpuser1,密码用于sftp登录: sudo adduser sftpuser1 --home /sftp/sftpuser1 --shell /bin/false s ...
- POJ 1117 Pairs of Integers
Pairs of Integers Time Limit: 1000MS Memory Limit: 10000K Total Submissions: 4133 Accepted: 1062 Des ...
- SOA架构商城二 框架搭建
1.创建父工程 创建Maven工程pingyougou-parent,选择packaging类型为pom ,在pom.xml文件中添加锁定版本信息dependencyManagement与plugin ...
- 移动端mobiscroll无法滑动、无法划动选值的问题
mobiscroll配置 theme: 'ios',时.滑动取值无效: html的页面内容稍微长过手机屏幕,页面无法完全加载,允许稍微滑动,这时导致点击选择mobiscroll值时无法滑动取值.处理: ...
- vue--环境搭建(创建运行项目)
如何搭建vue环境: 1.安装之前必须要安装 node.js 2.搭建Vue环境,安装vue的脚手架工具 npm install --global vue-cli / cnpm install --g ...
- thinkCMF----使用自定义函数
thinkCMF使用自定义函数:app 下新建 common.php
- nginx配置虚拟主机之不同端口和不同IP地址
配置nginx虚拟主机不同端口和不同ip地址,和上编nginx基于域名配置虚拟主机博文类似,请先参考. zxl.com域名不同端口,配置文件内容如下: 1 2 3 4 5 6 7 8 9 10 11 ...
- Python 之反射和普通方式对比(模拟Web框架)
先模拟一个web页面的选择不同输出不同 vim day8-7.py #!/usr/bin/python # -*- coding:utf-8 -*- import home import accoun ...
- TOP100summit 2017:亚马逊Echo音箱能够语音识人,华人工程师揭秘设计原理
本文编辑:Cynthia 2017年,人工智能的消费产品落地聚焦在了智能音箱上,谷歌.亚马逊纷纷推出智能音箱产品,国内的阿里巴巴推出天猫精灵,小米推出小米AI音箱.智能音箱通过语音可以发出指令,未 ...