最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却。

  首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应该使用类似如下代码进行保存:

  torch.save({
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')

  对应的在加载模型参数时,使用如下代码进行加载是没有问题的:

  checkpoint = torch.load('./results/checkpoint_net.pth')
       model.load_state_dict(checkpoint['model'])
  一般情况下,在保存模型时我们不会发现会有什么不对,而是在需要加载模型参数时,才发现加载报错了。比如:
 
  这时我们需要回头检查我们在保存模型参数时,是否有哪里不对。比如我这次就是这样的,写代码的时候并没有考虑到多GPU的情况,所以保存代码如下:
  

  torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'results/checkpoint_net.pth')
  

  请注意红圈的地方缺了“module”关键字,导致在保存模型参数时,参数保存成了这样(模型参数是以key-value的形式保存的),即stat_dict(key),对应的value每个值都多了一个module:

  接下来在加载模型参数时,如果直接使用代码 model.load_state_dict(torch.load('模型参数文件存放路径')['state_dict'])就会出现问题。报错如下:

  好了,既然知道了出问题的原因在哪里,那就来考虑下如何处理了,两种方案:

  第一,修改保存模型的代码(加上"module")后,把模型重新训练一次,重新加载即可。但我们大家都知道,这样的深度模型训练,时间一般都是以小时或者天计的,我们等不了那么久。(如果时间允许,可以这么干。哈哈!)

  第二,在加载模型参数之前,写代码将模型参数里的"module"关键字给去掉。比如可以这么写:

  

 实话实说,这个代码并不是我的原创,网上给出这个解决方案的地方很多。但我这里有一点不同的时,我加了个“[state_dict]”,我看到的很多地方是没有这个的,直接就是ckpt.items()。因为我并不知道他们保存模型参数的代码是怎么写的,所以也并不好评论对错。但总之一句话,我们是要通过这段代码,去掉状态字典里的"module"关键字的所以大家可以通过debug,查看这里的k取到的是什么值,应该要是取到下图所示红色框里的值,然后通过“name=k[7:]”去掉前面的"module",然后再加载就可以了。

 文中提到一个词“[state_dict]”,大家不用太在意,有的人在保存模型参数时,用的是“model”,只要在保存和读取的时候,保持一致就可以了。

 欢迎大家对描述不清楚或者不准确的地方提出批评意见和建议!

使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题的更多相关文章

  1. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  2. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tf w1 = tf.Variable(tf.constant(2.0, shape=[1]), name= ...

  3. 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别

    [源码下载] 背水一战 Windows 10 (62) - 控件(媒体类): InkCanvas 保存和加载, 手写识别 作者:webabcd 介绍背水一战 Windows 10 之 控件(媒体类) ...

  4. keras中的模型保存和加载

    tensorflow中的模型常常是protobuf格式,这种格式既可以是二进制也可以是文本.keras模型保存和加载与tensorflow不同,keras中的模型保存和加载往往是保存成hdf5格式. ...

  5. 完美实现保存和加载easyui datagrid自定义调整列宽位置隐藏属性功能

    需求&场景 例表查询是业务系统中使用最多也是最基础功能,但也是调整最平凡,不同的用户对数据的要求也不一样,所以在系统正式使用后,做为开发恨不得坐在业务边上,根据他们的要求进行调整,需要调整最多 ...

  6. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  7. 三、TensorFlow模型的保存和加载

    1.模型的保存: import tensorflow as tf v1 = tf.Variable(1.0,dtype=tf.float32) v2 = tf.Variable(2.0,dtype=t ...

  8. tensorflow模型持久化保存和加载

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

  9. tensorflow模型持久化保存和加载--深度学习-神经网络

    模型文件的保存 tensorflow将模型保持到本地会生成4个文件: meta文件:保存了网络的图结构,包含变量.op.集合等信息 ckpt文件: 二进制文件,保存了网络中所有权重.偏置等变量数值,分 ...

随机推荐

  1. 峰哥说技术:10-Spring Boot静态资源处理

    Spring Boot深度课程系列 峰哥说技术—2020庚子年重磅推出.战胜病毒.我们在行动 10  峰哥说技术:Spring Boot静态资源处理 今天我们聊聊关于 Spring Boot 中关于静 ...

  2. 怎么查看linux文件夹下有多少个文件(mac同样)

    查看目录下有多少个文件及文件夹,在终端输入 ls | wc -w 查看目录下有多少个文件,在终端输入 ls | wc -c 查看文件夹下有多少个文件,多少个子目录,在终端输入 ls -l |wc -l ...

  3. Sublime text 3 运行python3

    要在Sublime text3编译器中成功运行 python3,需要在编译器设置中将python3添加至编译器中 新建编译系统 编辑弹出的文件,添加如下内容: { "cmd":[& ...

  4. Python数据科学手册(2) NumPy入门

    NumPy(Numerical Python 的简称)提供了高效存储和操作密集数据缓存的接口.在某些方面,NumPy 数组与 Python 内置的列表类型非常相似.但是随着数组在维度上变大,NumPy ...

  5. 痞子衡嵌入式:恩智浦SDK驱动代码风格检查工具预览版

    大家好,我是痞子衡,是正经搞技术的痞子. 接上文 <恩智浦SDK驱动代码风格.模板.检查工具> 继续聊,是的,过去的三天里我花了一些时间做了一个基于 PyQt5 的 GUI 工具,可以帮助 ...

  6. 【Python】2.12学习笔记 变量

    变量 关于变量我有一个不能理解的,关于全局变量作用域与地址的问题,学函数的时候我可能会搞懂它并且写下来 另外,其实昨天说的是有些不准确的,\(Python\)里的变量不是不用声明类型,只是声明方式特殊 ...

  7. What is the difference between shades and shadows?

    Shade is the darkness of an object not in direct light, while shadows are the silhouette of an objec ...

  8. MVC设计模式简介

    刚刚学习了MVC相关知识,在这里进行一下总结MVC设计模式提高了Java开发中的代码可读性,提高了开发效率,实乃开发利器 1在MVC中由客户端发送一个带参数的请求,经过servlet处理后做出相应的处 ...

  9. DS博客作业02--栈和队列

    0.PTA得分截图 1.本周学习总结 1.1总结栈和队列内容 栈的存储结构及操作 栈的顺序存储结构 typedef struct { ElemType data[MaxSize]: int top: ...

  10. 《Python学习手册 第五版》 -第15章 文档

    本章主要介绍Python中的文档,会通过多种方式来说明,如果查看Python自带文档和其他参考的资料 本章重点内容 1.#注释:源文件文档 2.dir函数:以列表显示对象中可用的属性 3.文档字符串 ...