最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却。

  首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应该使用类似如下代码进行保存:

  torch.save({
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')

  对应的在加载模型参数时,使用如下代码进行加载是没有问题的:

  checkpoint = torch.load('./results/checkpoint_net.pth')
       model.load_state_dict(checkpoint['model'])
  一般情况下,在保存模型时我们不会发现会有什么不对,而是在需要加载模型参数时,才发现加载报错了。比如:
 
  这时我们需要回头检查我们在保存模型参数时,是否有哪里不对。比如我这次就是这样的,写代码的时候并没有考虑到多GPU的情况,所以保存代码如下:
  

  torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')
  

  请注意红圈的地方缺了“module”关键字,导致在保存模型参数时,参数保存成了这样(模型参数是以key-value的形式保存的),即stat_dict(key),对应的value每个值都多了一个module:

  接下来在加载模型参数时,如果直接使用代码 model.load_state_dict(torch.load('模型参数文件存放路径')['state_dict'])就会出现问题。报错如下:

  好了,既然知道了出问题的原因在哪里,那就来考虑下如何处理了,两种方案:

  第一,修改保存模型的代码(加上"module")后,把模型重新训练一次,重新加载即可。但我们大家都知道,这样的深度模型训练,时间一般都是以小时或者天计的,我们等不了那么久。(如果时间允许,可以这么干。哈哈!)

  第二,在加载模型参数之前,写代码将模型参数里的"module"关键字给去掉。比如可以这么写:

  

 实话实说,这个代码并不是我的原创,网上给出这个解决方案的地方很多。但我这里有一点不同的时,我加了个“[state_dict]”,我看到的很多地方是没有这个的,直接就是ckpt.items()。因为我并不知道他们保存模型参数的代码是怎么写的,所以也并不好评论对错。但总之一句话,我们是要通过这段代码,去掉状态字典里的"module"关键字的所以大家可以通过debug,查看这里的k取到的是什么值,应该要是取到下图所示红色框里的值,然后通过“name=k[7:]”去掉前面的"module",然后再加载就可以了。

 文中提到一个词“[state_dict]”,大家不用太在意,有的人在保存模型参数时,用的是“model”,只要在保存和读取的时候,保持一致就可以了。

 欢迎大家对描述不清楚或者不准确的地方提出批评意见和建议!

使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题的更多相关文章

  1. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  2. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  3. 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别

    [源码下载] 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别 作者:webabcd 介绍背水一战 Windows 10 之 控件(媒体类) ...

  4. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  5. 完美实现保存和加载easyui datagrid自定义调整列宽位置隐藏属性功能

    需求&场景 例表查询是业务系统中使用最多也是最基础功能,但也是调整最平凡,不同的用户对数据的要求也不一样,所以在系统正式使用后,做为开发恨不得坐在业务边上,根据他们的要求进行调整,需要调整最多 ...

  6. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  7. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  8. tensorflow模型持久化保存和加载

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  9. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

随机推荐

  1. 关于javascript 的reduce方法

    作为一个前端菜鸟,觉得资料比较好,特地分享一下~~ reduce() 方法接收一个函数作为累加器(accumulator),数组中的每个值(从左到右)开始缩减,最终为一个值. 你一定也和我一样看的有点 ...

  2. java算法--普通队列

    数据结构队列 首先明确一下队列的概念. 队列是一种有序列表,使用数组的结构来存储队列的数据. 队列是一种先进先出的算法.由前端加入,由后端输出. 如下图: ​ 第一个图 第二个图 第三个图 这就是队列 ...

  3. 2019-2020-2 20175226 王鹏雲 网络对抗技术 Exp2 后门原理与实践

    2019-2020-2 20175226 王鹏雲 网络对抗技术 Exp2 后门原理与实践 实验内容 使用netcat获取主机操作Shell,cron启动: 使用socat获取主机操作Shell, 任务 ...

  4. Java继承中构造器的调用原理

    Java的继承是比较重要的特性,也是比较容易出错的地方,下面这个例子将展示如果父类构造器中调用被子类重写的方法时会出现的情况: 首先是父类: public class test { void fun( ...

  5. PHP8年开发经验原创开发文档教程

    订阅微信公众号: gzgwgas 每天为你分享PHP开发经验,坚决不踩坑,坚决不入坑. 微信扫码,关注公众号有惊喜!

  6. Java集合04——fail-fast&fail-safe 详解

    在前几个回合中,我们已经详细了解过了 Java 集合中的List.Set 和 Map,对这部分内容感兴趣的朋友可以关注我的公众号「Java面典」了解.今天我们将为各位介绍集合的失败机制--fail-f ...

  7. Natas13 Writeup(文件上传,绕过图片签名检测)

    Natas13: 与上一关页面类似,还是文件上传,只是多了提示“出于安全原因,我们现在仅接受图像文件!”.源码如下 function genRandomString() { $length = 10; ...

  8. ECharts的使用与总结

    ECharts的使用与总结 一,介绍与需求 1.1,介绍 ECharts商业级数据图表,一个纯Javascript的图表库,可以流畅的运行在PC和移动设备上,兼容当前绝大部分浏览器(IE6/7/8/9 ...

  9. Spring Boot框架——快速入门

    Spring Boot是Spring 全家桶非常重要的一个模块,通过 Spring Boot 可以快速搭建一个基于 Spring 的 Java 应用程序,Spring Boot 对常用的第三方库提供了 ...

  10. 【开源】使用Angular9和TypeScript开发RPG游戏

    RPG系统构造 通过对于斗罗大陆小说的游戏化过程,熟悉Angular的结构以及使用TypeScript的面向对象开发方法. 项目地址 人物 和其他RPG游戏类似,游戏里面的人物角色大致有这样的一些属性 ...