在 Pytorch 中一种模型保存和加载的方式如下:

# save
torch.save(model.state_dict(), PATH) # load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数,下面看看源代码如何实现的。

state_dict

# torch.nn.modules.module.py
class Module(object):
def state_dict(self, destination=None, prefix='', keep_vars=False):
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.data
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination

可以看到state_dict函数中遍历了4中元素,分别是_paramters,_buffers,_modules_state_dict_hooks,前面三者在之前的文章已经介绍区别,最后一种就是在读取state_dict时希望执行的操作,一般为空,所以不做考虑。另外有一点需要注意的是,在读取Module时采用的递归的读取方式,并且名字间使用.做分割,以方便后面load_state_dict读取参数。

class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
self.my_param = nn.Parameter(torch.randn(1))
self.fc = nn.Linear(2,2,bias=False)
self.conv = nn.Conv2d(2,1,1)
self.fc2 = nn.Linear(2,2,bias=False)
self.f3 = self.fc
def forward(self, x):
return x model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([-0.3052])), ('my_buffer', tensor([0.5583])), ('fc.weight', tensor([[ 0.6322, -0.0255],
[-0.4747, -0.0530]])), ('conv.weight', tensor([[[[ 0.3346]], [[-0.2962]]]])), ('conv.bias', tensor([0.5205])), ('fc2.weight', tensor([[-0.4949, 0.2815],
[ 0.3006, 0.0768]])), ('f3.weight', tensor([[ 0.6322, -0.0255],
[-0.4747, -0.0530]]))])

可以看到最后的确输出了三种参数。

load_state_dict

下面的代码中我们可以分成两个部分看,

  1. load(self)

这个函数会递归地对模型进行参数恢复,其中的_load_from_state_dict的源码附在文末。

首先我们需要明确state_dict这个变量表示你之前保存的模型参数序列,而_load_from_state_dict函数中的local_state 表示你的代码中定义的模型的结构。

那么_load_from_state_dict的作用简单理解就是假如我们现在需要对一个名为conv.weight的子模块做参数恢复,那么就以递归的方式先判断conv是否在staet__dictlocal_state中,如果不在就把conv添加到unexpected_keys中去,否则递归的判断conv.weight是否存在,如果都存在就执行param.copy_(input_param),这样就完成了conv.weight的参数拷贝。

  1. if strict:

这个部分的作用是判断上面参数拷贝过程中是否有unexpected_keys或者missing_keys,如果有就报错,代码不能继续执行。当然,如果strict=False,则会忽略这些细节。

def load_state_dict(self, state_dict, strict=True):
missing_keys = []
unexpected_keys = []
error_msgs = [] # copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.') load(self) if strict:
error_msg = ''
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))) if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
  • _load_from_state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
local_state = {k: v.data for k, v in local_name_params if v is not None} for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key] # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0] if input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue if isinstance(input_param, Parameter):
# backwards compatibility for serialized parameters
input_param = input_param.data
try:
param.copy_(input_param)
except Exception:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(key, param.size(), input_param.size()))
elif strict:
missing_keys.append(key) if strict:
for key, input_param in state_dict.items():
if key.startswith(prefix):
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com

如有意合作,欢迎私戳

邮箱:marsggbo@foxmail.com



2019-12-20 21:55:21

源码详解Pytorch的state_dict和load_state_dict的更多相关文章

  1. Spark Streaming揭秘 Day25 StreamingContext和JobScheduler启动源码详解

    Spark Streaming揭秘 Day25 StreamingContext和JobScheduler启动源码详解 今天主要理一下StreamingContext的启动过程,其中最为重要的就是Jo ...

  2. spring事务详解(三)源码详解

    系列目录 spring事务详解(一)初探事务 spring事务详解(二)简单样例 spring事务详解(三)源码详解 spring事务详解(四)测试验证 spring事务详解(五)总结提高 一.引子 ...

  3. 条件随机场之CRF++源码详解-预测

    这篇文章主要讲解CRF++实现预测的过程,预测的算法以及代码实现相对来说比较简单,所以这篇文章理解起来也会比上一篇条件随机场训练的内容要容易. 预测 上一篇条件随机场训练的源码详解中,有一个地方并没有 ...

  4. [转]Linux内核源码详解--iostat

    Linux内核源码详解——命令篇之iostat 转自:http://www.cnblogs.com/york-hust/p/4846497.html 本文主要分析了Linux的iostat命令的源码, ...

  5. saltstack源码详解一

    目录 初识源码流程 入口 1.grains.items 2.pillar.items 2/3: 是否可以用python脚本实现 总结pillar源码分析: @(python之路)[saltstack源 ...

  6. Shiro 登录认证源码详解

    Shiro 登录认证源码详解 Apache Shiro 是一个强大且灵活的 Java 开源安全框架,拥有登录认证.授权管理.企业级会话管理和加密等功能,相比 Spring Security 来说要更加 ...

  7. udhcp源码详解(五) 之DHCP包--options字段

    中间有很长一段时间没有更新udhcp源码详解的博客,主要是源码里的函数太多,不知道要不要一个一个讲下去,要知道讲DHCP的实现理论的话一篇博文也就可以大致的讲完,但实现的源码却要关心很多的问题,比如说 ...

  8. Activiti架构分析及源码详解

    目录 Activiti架构分析及源码详解 引言 一.Activiti设计解析-架构&领域模型 1.1 架构 1.2 领域模型 二.Activiti设计解析-PVM执行树 2.1 核心理念 2. ...

  9. 源码详解系列(六) ------ 全面讲解druid的使用和源码

    简介 druid是用于创建和管理连接,利用"池"的方式复用连接减少资源开销,和其他数据源一样,也具有连接数控制.连接可靠性测试.连接泄露控制.缓存语句等功能,另外,druid还扩展 ...

随机推荐

  1. Node版本管理器NVM常用命令

    NVM是什么?nvm (Node Version Manager) 是Nodejs版本管理器,可对不同的node版本快速进行切换. 为什么要用NVM?基于node的工具和项目越来越多,但是每个项目使用 ...

  2. 【洛谷P4251】[SCOI2015]小凸玩矩阵(二分+二分图匹配)

    洛谷 题意: 给出一个\(n*m\)的矩阵\(A\).现要从中选出\(n\)个数,任意两个数不能在同一行或者同一列. 现在问选出的\(n\)个数中第\(k\)大的数的最小值是多少. 思路: 显然二分一 ...

  3. python 实现 AES CBC模式加解密

    AES加密方式有五种:ECB, CBC, CTR, CFB, OFB 从安全性角度推荐CBC加密方法,本文介绍了CBC,ECB两种加密方法的python实现 python 在 Windows下使用AE ...

  4. /usr/lib/python2.7/subprocess.py", line 1239, in _execute_child

    Traceback (most recent call last):File "/home/eping/bin/repo", line 685, in main(sys.argv[ ...

  5. 201871010111-刘佳华《面向对象程序设计(java)》第十三周学习总结

    201871010111-刘佳华<面向对象程序设计(java)>第十三周学习总结 实验十一 图形界面事件处理技术 实验时间 2019-11-22 第一部分:理论知识总结 1.事件源:能够产 ...

  6. Django 连接数据库

    配置数据库 Django 默认连接的是SQLite,如果想要连接MySQL则需修改配置:在 setting.py 中找到数据库的默认配置: DATABASES = { 'default': { 'EN ...

  7. 优秀文章 Swagger

    原文:https://www.cnblogs.com/peterYong/p/9569453.html 原文:https://www.cnblogs.com/lhbshg/p/8711604.html

  8. JAVA 中加载属性文件的4种方法

    小总结 : 这个集合属性可以反序列化, 把持久化数据读出来, 输入流中放入要操作的文件! p.load加载这个输入流! p.getProperty( key) 根据这个键获得值! 补充 : web工程 ...

  9. appium--python启动appium服务

    前戏 前面我们都是在cmd下通过输入appium加端口号来启动服务的,在我们做自动化的时候,我们当然不希望我们手动启动appium服务,而是希望通过脚本自动启动appium服务. 我们可以使用subp ...

  10. 剑指offer:滑动窗口的最大值(栈和队列)

    1. 题目描述 /* 给定一个数组和滑动窗口的大小,找出所有滑动窗口里数值的最大值. 例如,如果输入数组{2,3,4,2,6,2,5,1}及滑动窗口的大小3,那么一共存在6个滑动窗口,他们的最大值分别 ...