使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题
最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却。
首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应该使用类似如下代码进行保存:
对应的在加载模型参数时,使用如下代码进行加载是没有问题的:

请注意红圈的地方缺了“module”关键字,导致在保存模型参数时,参数保存成了这样(模型参数是以key-value的形式保存的),即stat_dict(key),对应的value每个值都多了一个module:
接下来在加载模型参数时,如果直接使用代码 model.load_state_dict(torch.load('模型参数文件存放路径')['state_dict'])就会出现问题。报错如下:
好了,既然知道了出问题的原因在哪里,那就来考虑下如何处理了,两种方案:
第一,修改保存模型的代码(加上"module")后,把模型重新训练一次,重新加载即可。但我们大家都知道,这样的深度模型训练,时间一般都是以小时或者天计的,我们等不了那么久。(如果时间允许,可以这么干。哈哈!)
第二,在加载模型参数之前,写代码将模型参数里的"module"关键字给去掉。比如可以这么写:

实话实说,这个代码并不是我的原创,网上给出这个解决方案的地方很多。但我这里有一点不同的时,我加了个“[state_dict]”,我看到的很多地方是没有这个的,直接就是ckpt.items()。因为我并不知道他们保存模型参数的代码是怎么写的,所以也并不好评论对错。但总之一句话,我们是要通过这段代码,去掉状态字典里的"module"关键字的所以大家可以通过debug,查看这里的k取到的是什么值,应该要是取到下图所示红色框里的值,然后通过“name=k[7:]”去掉前面的"module",然后再加载就可以了。
文中提到一个词“[state_dict]”,大家不用太在意,有的人在保存模型参数时,用的是“model”,只要在保存和读取的时候,保持一致就可以了。
欢迎大家对描述不清楚或者不准确的地方提出批评意见和建议!
使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题的更多相关文章
- 从头学pytorch(十二):模型保存和加载
模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...
- 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别
[源码下载] 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别 作者:webabcd 介绍背水一战 Windows 10 之 控件(媒体类) ...
- keras中的模型保存和加载
tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...
- 完美实现保存和加载easyui datagrid自定义调整列宽位置隐藏属性功能
需求&场景 例表查询是业务系统中使用最多也是最基础功能,但也是调整最平凡,不同的用户对数据的要求也不一样,所以在系统正式使用后,做为开发恨不得坐在业务边上,根据他们的要求进行调整,需要调整最多 ...
- 超详细的Tensorflow模型的保存和加载(理论与实战详解)
1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...
- 三、TensorFlow模型的保存和加载
1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...
- tensorflow模型持久化保存和加载
模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...
- tensorflow模型持久化保存和加载--深度学习-神经网络
模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...
随机推荐
- AlphaGo、人工智能、深度学习解读以及应用
经过比拼,AlphaGo最终还是胜出,创造了人机大战历史上的一个新的里程碑.几乎所有的人都在谈论这件事情,这使得把“人工智能”.“深度学习”的热潮推向了新的一个高潮.AlphaGo就像科幻电影里具有人 ...
- seo搜索优化教程09 - seo搜索优化外链优化
为了使大家更方便的了解及学习网络营销推广.seo搜索优化,星辉科技强势推出seo搜索优化教程.此为seo教程第九课 网络营销推广中有句行话,叫做"内容为王,外链为王",可见外链对于 ...
- 解决挖矿病毒【Xmrig miner 】CPU 100%服务器卡死问题
背景: 突然有一天,服务器访问很慢很慢,进程查看发现CPU是100%,而且没有任何降低的意思 收集: 打开任务管理器,进程查看中CPU排序,发现一个System的进程,第一想法以为是空闲利用,发现结束 ...
- 从当当客户端api抓取书评到词云生成
看了好几本大冰的书,感觉对自己的思维有不少的影响.想看看其他读者的评论.便想从当当下手抓取他们评论做个词云.想着网页版说不定有麻烦的反爬,干脆从手机客户端下手好了.果其不然,找到一个书评的api.发送 ...
- django 从零开始 13 返回文件
进行一些操作返回文件,flask和django差不多,基本都是在返回response 并且对其中的返回头部写入返回文件信息 # image def image(request): f = open(r ...
- DevOps - 持续集成
最近在担任公司部门的DevOps Champion的角色,一直觉得这个只是一个协调者的角色(而不是一个SME的角色),我的工作大概就是将每个项目的devops工具收集一下,然后用图表的形式去体现大家用 ...
- javaee作业
一.单选题(共5题,50.0分) 1 在SqlSession对象的openSession()方法中,不能作为参数executorType的可选值 的是( ). A. ExecutorTyp ...
- 深入理解JS引擎的执行机制
深入理解JS引擎的执行机制 1.灵魂三问 : JS为什么是单线程的? 为什么需要异步? 单线程又是如何实现异步的呢? 2.JS中的event loop(1) 3.JS中的event loop(2) 4 ...
- $props, $attrs,$listeners的具体使用例子
我在这使用属性重新render饿了么ui的tree: <el-tree ref="tree" icon-class="fa fa-caret-right" ...
- mysql那些事之索引篇
mysql那些事之索引篇 上一篇博客已经简单从广的方面介绍了一下mysql整体架构以及物理结构的内容. 本篇博客的内容是mysql的索引,索引无论是在面试还是我们日常工作中都是非常的重要一环. 索引是 ...