pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

-------------------------------------------------------------------------------------------------------------------------------

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

-----------

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

model = torch.load(PATH)

model.eval()

--------------------------------------------------------------------------------------------------------------------------------------

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

----------------------------------------------------------------------------------------------------------------------

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
--------------------------------------------------------------------------------------------

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
param.requires_grad = False
注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

---------------------------------------------------------------------------------------------

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)

def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

# initial model
model = TheModelClass()

#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())

print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])

print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)

print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)

model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)
---------------------
作者:wzg2016
来源:CSDN
原文:https://blog.csdn.net/strive_for_future/article/details/83240081
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch 状态字典:state_dict 模型和参数保存的更多相关文章

  1. IOS第四天-新浪微博 -存储优化OAuth授权账号信息,下拉刷新,字典转模型

    *************application - (BOOL)application:(UIApplication *)application didFinishLaunchingWithOpti ...

  2. iOS开发——高级技术精选OC篇&Runtime之字典转模型实战

    Runtime之字典转模型实战 如果您还不知道什么是runtime,那么请先看看这几篇文章: http://www.cnblogs.com/iCocos/p/4734687.html http://w ...

  3. pytorch如何能够保证模型的可重复性

    问题背景是这样的: 我用了自己定义了pytorch中的模型,并且,在main函数中设置了随机种子用来保证模型初始化的参数是一致的,同时pytorch中的随机种子也能够影响dropout的作用,见链接 ...

  4. Runtime之字典转模型实战

    Runtime之字典转模型实战 先来看看怎么使用Runtime给模型类赋值 iOS开发中的Runtime可谓是功能强大,同时Runtime使用起来也是非常灵活的,今天博客的内容主要就是使用到一丁点的R ...

  5. ios开发runtime学习五:KVC以及KVO,利用runtime实现字典转模型

    一:KVC和KVO的学习 #import "StatusItem.h" /* 1:总结:KVC赋值:1:setValuesForKeysWithDictionary实现原理:遍历字 ...

  6. django----orm查询优化 MTV与MVC模型 choice参数 ajax serializers

    目录 orm查询优化 only defer select_related 与 prefetch_related MTV 与 MVC 模型 choice参数 Ajax 前端代码 后端代码 前后端传输数据 ...

  7. iOS开发——网络篇——JSON和XML,NSJSONSerialization ,NSXMLParser(XML解析器),NSXMLParserDelegate,MJExtension (字典转模型),GDataXML(三方框架解析XML)

    一.JSON 1.JSON简介什么是JSONJSON是一种轻量级的数据格式,一般用于数据交互服务器返回给客户端的数据,一般都是JSON格式或者XML格式(文件下载除外) JSON的格式很像OC中的字典 ...

  8. iOS开发——UI基础-懒加载,plist文件,字典转模型,自定义view

    一.懒加载 只有使用到了商品数组才会创建数组 保证数组只会被创建一次 只要能够保证数组在使用时才创建, 并且只会创建一次, 那么我们就称之为懒加载 lazy - (void)viewDidLoad 控 ...

  9. iOS开发UI篇—字典转模型

    iOS开发UI篇—字典转模型 一.能完成功能的“问题代码” 1.从plist中加载的数据 2.实现的代码 // // LFViewController.m // 03-应用管理 // // Creat ...

随机推荐

  1. 杨柳絮-Info:太原市多部门通力合作科学治理杨柳飞絮效果好

    ylbtech-杨柳絮-Info:太原市多部门通力合作科学治理杨柳飞絮效果好 1.返回顶部 1. 太原市多部门通力合作科学治理杨柳飞絮效果好 2016-04-21 07:16 4月10日,随着气温升高 ...

  2. 机器学习之决策树(ID3)算法与Python实现

    机器学习之决策树(ID3)算法与Python实现 机器学习中,决策树是一个预测模型:他代表的是对象属性与对象值之间的一种映射关系.树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每 ...

  3. 【《Objective-C基础教程 》笔记】(八)OC的基本事实和OC杂七杂八的疑问

    一.疑问 1.成员变量.实例变量.局部变量的差别和联系,在訪问.继承上怎样表现. 2.属性@property 和 {变量列表} 是否同样.有什么不同. 3.类方法.类成员.类属性:实例方法.实例变量. ...

  4. QLabel添加Click信号

    使用自定义label来实现此功能 其他控件可参照此例. #include "customerqlabel.h" CustomerQlabel::CustomerQlabel(QWi ...

  5. Apache Camel,Spring Boot 实现文件复制,转移 (转)

    基本框架 Apache Camel Spring Boot Maven 开发过程 1.新建一个POM(quickstart)项目,在POM文件中添加Camel和Spring Boot的依赖 <p ...

  6. Directx11教程(19) 画一个简单的地形

    原文:Directx11教程(19) 画一个简单的地形       通常我们在xz平面定义一个二维的网格,然后y的值根据一定的函数计算得到,比如正弦.余弦函数的组合等等,可以得到一个看似不错的地形或者 ...

  7. 微服务开源生态报告 No.4

    「微服务开源生态报告」,汇集各个开源项目近期的社区动态,帮助开发者们更高效的了解到各开源项目的最新进展. 社区动态包括,但不限于:版本发布.人员动态.项目动态和规划.培训和活动. 非常欢迎国内其他微服 ...

  8. 关于element-ui的弹框问题

    el-dialog获取数据. el-dialog加载到页面中的时候,其实已经加载好了.只是默认隐藏了. 第一次点击的时候弹出,为何拿不到数据?之后再次操作就一点问题都没有了.

  9. yum方式安装MySQL【转】

    在CentOS7中默认安装有MariaDB,这个是MySQL的分支,但为了需要,还是要在系统中安装MySQL,而且安装完成之后可以直接覆盖掉MariaDB. 另外至2919年5月4号, 默认安装的my ...

  10. Chef 安装

    http://www.tuicool.com/articles/RnAVn2 三个角色: chef server, chef workstation, chef nodes(chef clients) ...