1. 如何进行迁移

对模型和相应的数据进行.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。从而可以通过GPU来进行运算了。

1.1 判定使用GPU

下载了对应的GPU版本的Pytorch之后,要确保GPU是可以进行使用的,通过torch.cuda.is_available()的返回值来进行判断。
通过torch.cuda.device_count()可以获得能够使用的GPU数量。其他就不多赘述了。 
常常通过如下判定来写可以跑在GPU和CPU上的通用模型:

  1. if torch.cuda.is_available():
  2. ten1 = ten1.cuda()
  3. MyModel = MyModel.cuda()

2. 对应数据的迁移

2.1 将Tensor迁移到显存中去

不论是什么类型的Tensor(FloatTensor或者是LongTensor等等),一律直接使用方法.cuda()即可。 
例如:

  1. ten1 = torch.FloatTensor(2)
  2.  
  3. ten1_cuda = ten1.cuda()

如果要将显存中的数据复制到内存中,则对cuda数据类型使用.cpu()方法即可。

2.2 将Variable迁移到显存中去

在模型中,我们最常使用的是Variable这个容器来装载使用数据。主要是由于Variable可以进行反向传播来进行自动求导。 
同样地,要将Variable迁移到显存中,同样只需要使用.cuda()即可实现。

这里有一个小疑问,对Variable直接使用.cuda和对Tensor进行.cuda然后再放置到Variable中结果是否一致呢。答案是肯定的。

  1. ten1 = torch.FloatTensor(2)
  2. >>> 6.1101e+24
  3. 4.5659e-41
  4. [torch.FloatTensor of size 2]
  5. ten1_cuda = ten1.cuda()
  6. >>>> 6.1101e+24
  7. 4.5659e-41
  8. [torch.cuda.FloatTensor of size 2 (GPU 0)]
  9. V1_cpu = autograd.Variable(ten1)
  10. >>>> Variable containing:
  11. 6.1101e+24
  12. 4.5659e-41
  13. [torch.FloatTensor of size 2]
  14. V2 = autograd.Variable(ten1_cuda)
  15. >>>> Variable containing:
  16. 6.1101e+24
  17. 4.5659e-41
  18. [torch.cuda.FloatTensor of size 2 (GPU 0)]
  19. V1 = V1_cpu.cuda()
  20. >>>> Variable containing:
  21. 6.1101e+24
  22. 4.5659e-41
  23. [torch.cuda.FloatTensor of size 2 (GPU 0)]

最终我们能发现他们都能够达到相同的目的,但是他们完全一样了吗?我们使用V1 is V2发现,结果是否定的。

对于V1,我们是直接对Variable进行操作的,这样子V1的.grad_fn中会记录下创建的方式。因此这二者并不是完全相同的。

2.3 数据迁移小结

.cuda()操作默认使用GPU 0也就是第一张显卡来进行操作。当我们想要存储在其他显卡中时可以使用.cuda(<显卡号数>)来将数据存储在指定的显卡中。还有很多种方式,具体参考官方文档。

3. 模型迁移

模型的迁移这里指的是torch.nn下面的一些网络模型以及自己创建的模型迁移到GPU上去。

上面讲了使用.cuda()即可将数据从内存中移植到显存中去。 
对于模型来说,也是同样的方式,我们使用.cuda来将网络放到显存上去。

3.1 torch.nn下的基本模型迁移

我们很惊奇地发现对于模型来说,不像数据那样使用了.cuda()之后会改变其的数据类型。模型看起来没有任何的变化。 
但是他真的没有改变吗。 
我们将data1投入linear_cuda中去可以发现,系统会报错,而将.cuda之后的data2投入linear_cuda才能正常工作。并且输出的也是具有cuda的数据类型。

那是怎么一回事呢? 
这是因为这些所谓的模型,其实也就是对输入参数做了一些基本的矩阵运算。所以我们对模型.cuda()实际上也相当于将模型使用到的参数存储到了显存上去。

对于上面的例子,我们可以通过观察参数来发现区别所在。

  1. linear.weight
  2. >>>> Parameter containing:
  3. -0.6847 0.2149
  4. -0.5473 0.6863
  5. [torch.FloatTensor of size 2x2]
  6. linear_cuda.weight
  7. >>>> Parameter containing:
  8. -0.6847 0.2149
  9. -0.5473 0.6863
  10. [torch.cuda.FloatTensor of size 2x2 (GPU 0)]

3.2 自己模型的迁移

对于自己创建的模型类,由于继承了torch.nn.Module,则可同样使用.cuda()来将模型中用到的所有参数都存储到显存中去。

这里笔者曾经有一个疑问:当我们对模型存储到显存中去之后,那么这个模型中的方法后面所创建出来的Tensor是不是都会默认变成cuda的数据类型。答案是否定的。具体操作留给读者自己去实现。

3.3 模型小结

对于模型而言,我们可以将其看做是一种类似于Variable的容器。我们对它进行.cuda()处理,是将其中的参数放到显存上去(因为实际使用的时候也是通过这些参数做运算)。

https://blog.csdn.net/qq_28444159/article/details/78781201

将Pytorch模型从CPU转换成GPU的更多相关文章

  1. h5模型文件转换成pb模型文件

      本文主要记录Keras训练得到的.h5模型文件转换成TensorFlow的.pb文件 #*-coding:utf-8-* """ 将keras的.h5的模型文件,转换 ...

  2. 【tensorflow-v2.0】如何将模型转换成tflite模型

    前言 TensorFlow Lite 提供了转换 TensorFlow 模型,并在移动端(mobile).嵌入式(embeded)和物联网(IoT)设备上运行 TensorFlow 模型所需的所有工具 ...

  3. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  4. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

  5. 利用反射将Datatable、SqlDataReader转换成List模型

    1. DataTable转IList public class DataTableToList<T>whereT :new() { ///<summary> ///利用反射将D ...

  6. Linux下ffmpeg添加Facebook/transform代码块实现将全景视频的球模型转换成立方体模型

    Facebook事实上已开始在平台中支持360度全景视频的流播,但公司对此并不满足.其工程师更是基于锥体几何学设计出了一套全新的视频编码,号称最高能将全景视频的文件大小减少80%.(VR最新突破:全景 ...

  7. iOS swift HandyJSON组合Alamofire发起网络请求并转换成模型

    在swift开发中,发起网络请求大部分开发者应该都是使用Alamofire发起的网络请求,至于请求完成后JSON解析这一块有很多解决方案,我们今天这里使用HandyJSON来解析请求返回的数据并转化成 ...

  8. 「新手必看」Python+Opencv实现摄像头调用RGB图像并转换成HSV模型

    在ROS机器人的应用开发中,调用摄像头进行机器视觉处理是比较常见的方法,现在把利用opencv和python语言实现摄像头调用并转换成HSV模型的方法分享出来,希望能对学习ROS机器人的新手们一点帮助 ...

  9. .net 数据源DataSet 转换成模型

    /// <summary> /// DataSet转换成model 自动赋值返回集合 /// </summary> /// <typeparam name="T ...

随机推荐

  1. HTTP协议详解(真的很经典)(转载)

    HTTP协议详解(真的很经典):http://www.cnblogs.com/li0803/archive/2008/11/03/1324746.html 引言 HTTP是一个属于应用层的面向对象的协 ...

  2. CF 217A Ice Skating

    A. Ice Skating time limit per test 2 seconds memory limit per test 256 megabytes input standard inpu ...

  3. Node复制文件

    本人开发过程中,经常遇到,要去拷贝模板到当前文件夹,经常要去托文件,为了省事,解决这个问题,写了一个node复制文件. // js/app.js:指定确切的文件名.// js/*.js:某个目录所有后 ...

  4. mysql判断一个字符串是否包含某几个字符

    使用locate(substr,str)函数,如果包含,返回>0的数,否则返回0

  5. Windows Server 2008 R2之一活动目录服务部署

    测试环境: 服务器:计算机名Win2008R2CNDC,已安装Windows Server 2008 R2.IPV4:192.168.1.13,255.255.255.0,网关地址192.168.1. ...

  6. CAT偶现NPE的问题

    1.背景 我们公司的调用链系统是基于大众点评的CAT客户端改造的,服务端完全有自己设计开发的.在是用CAT客户端收集dubbo调用信息的时候,我们发现了一个CAT偶现NPE的bug,该bug隐藏的很深 ...

  7. FFT【快速傅里叶变换】FWT【快速沃尔什变换】

    实在是 美丽的数学啊 关于傅里叶变换的博客 讲的很细致 图片非常易于理解http://blog.jobbole.com/70549/ 大概能明白傅里叶变换是干吗的了 但是还是不能明白为什么用傅里叶变换 ...

  8. Spring Cloud微服务开发笔记5——Ribbon负载均衡策略规则定制

    上一篇文章单独介绍了Ribbon框架的使用,及其如何实现客户端对服务访问的负载均衡,但只是单独从Ribbon框架实现,没有涉及spring cloud.本文着力介绍Ribbon的负载均衡机制,下一篇文 ...

  9. ArcGIS API for javascript开发笔记(六)——REST详解及如何使用REST API调用GP服务

    感谢一路走来默默支持和陪伴的你~~~ -------------------欢迎来访,拒绝转载-------------------- 一.Rest API基础 ArcGIS 平台提供了丰富的REST ...

  10. Indexes (also called “keys” in MySQL)

    High Performance MySQL, Third Edition by Baron Schwartz, Peter Zaitsev, and Vadim Tkachenko   Is an ...