最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题。首先咱们先定义一个网络来进行后续的分析:

1、本文通用的网络模型

import torch
import torch.nn as nn
'''
定义网络中第一个网络模块 Net1
'''
class Net1(nn.Module):
def __init__(self):
super().__init__() # input size [B, 1, 3, 3] ==> [B, 1, 3, 3]
self.n = nn.Conv2d(1, 2, 3, padding=1)
def forward(self, x):
x = self.n(x)
return x
'''
定义网络中第二个网络模块 Net2
'''
class Net2(nn.Module):
def __init__(self):
super().__init__() self.n = nn.Sequential(
# input size [B, 1, 3, 3] ==> [B, 2, 3, 3]
nn.Conv2d(2, 2, 3, padding=1), # input size [B, 2, 3, 3] ==> [B, 1, 1, 1]
nn.Conv2d(2, 1, 3, padding=0),
)
def forward(self, x):
x = self.n(x)
return x
'''
定义网络中主网络模块 Network
'''
class Network(nn.Module):
def __init__(self):
super().__init__()
self.head = Net1()
self.tail = Net2()
def forward(self, x):
x = self.head(x)
x = self.tail(x)
return x

网络模块已经搭建好,我们先实例化一个模型然后打印看一下网络结构是否正确:

model = Network()	# 实例化网络模型
print(model) # 输出网络结构
Input = torch.randn(1,1,3,3) # 自定义数据输入
Output = model(Input) # 计算网络输出
print("Input 的维度为:{},Output 的维度为:{}".format(Input.shape, Output.shape))

则输出结果为:

Network(
(head): Net1(
(n): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(tail): Net2(
(n): Sequential(
(0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
)
)
)
Input 的维度为:torch.Size([1, 1, 3, 3]),Output 的维度为:torch.Size([1, 1, 1, 1])

从输出结果看,网络包含两个子模块 headtail,这两个子模块分别是类 Net1Net2 的实例化对象。在 Net2 的定义中,使用了 nn.Sequential() 函数,它能够将包含在里面的网络按照输入顺序进行组合,封装成一个新的模块,适用于网络中大量重复的结构,比如 Conv-ReLU-Conv 等模块。

2、对模型进行训练得到权重

我们先对网络做一个简单的训练,训练代码如下:

model = Network()	# 实例化网络模型
print(model) # 输出网络结构 torch.manual_seed(0) # 固定随机种子,确保每次产生的随机输入一致,方便我们评估训练结果
Input = torch.randn(1,1,3,3) # 自定义数据输入 Iter_num = 10 # 定义最大的迭代次数
Label = torch.tensor(1.0) # 定义有监督训练的label,这里的label必须是float类型的Tensor,否则会出错
criterion = nn.MSELoss() # 定义损失函数,这里选用MSE import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr = 0.01) #定义优化器,这里采用随机梯度下降(SGD) for index in range(Iter_num):
Output = model(Input) # 计算网络输出
loss = criterion(Output, Label) # 计算loss
loss.backward() # 反向传播计算梯度
optimizer.step() # 梯度更新
print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data))

训练过程如下:

Iter:0/10	loss:1.4089158773422241	Output:tensor([[[[-0.1870]]]])
Iter:1/10 loss:1.3796569108963013 Output:tensor([[[[-0.1746]]]])
Iter:2/10 loss:1.323099136352539 Output:tensor([[[[-0.1503]]]])
Iter:3/10 loss:1.2428957223892212 Output:tensor([[[[-0.1149]]]])
Iter:4/10 loss:1.143916130065918 Output:tensor([[[[-0.0695]]]])
Iter:5/10 loss:1.0316702127456665 Output:tensor([[[[-0.0157]]]])
Iter:6/10 loss:0.9117376208305359 Output:tensor([[[[0.0452]]]])
Iter:7/10 loss:0.7892979979515076 Output:tensor([[[[0.1116]]]])
Iter:8/10 loss:0.6688111424446106 Output:tensor([[[[0.1822]]]])
Iter:9/10 loss:0.5538586378097534 Output:tensor([[[[0.2558]]]])

3、模型存储

3.1 模型参数一起存储与加载
'''
这种方式存储模型的参数,而非整个模型
'''
torch.save(model.state_dict(), model_path) # 存储网络模型的参数
checkpoint = torch.load(model_path) # 先加载模型的参数
model.load_state_dict(checkpoint) # 再将加载的参数填入实例化的网络模型中
'''
这种方式存储整个模型
'''
torch.save(model,model_path) # 直接存储整个模型,包括模型结构和参数
model = torch.load(model_path) # 不用实例化,直接加载就可以用

存储整个模型与存储模型参数的区别:

  1. 整个模型:是保存整个网络结构和参数,使用时会加载结构和其中的参数,即边搭框架边填充参数;
  2. 仅参数:仅保存网络模型中的参数,在使用时需要先用训练时的模型实例化,再往里面填入参数,即需要先搭好框架再往框架里填参数。

下面我们就分别通过这两种方式进行模型存储与加载:

model_path_dict = './ckpt_dict.pth'	# 模型参数的存储路径
torch.save(model.state_dict(), model_path_dict) model_path_model = './ckpt_model.pth' # 整个模型的存储路径
torch.save(model, model_path_model) model_test = Network() # 重新实例化一个网络对象
test_out = model_test(Input) # 先看一下初始化输出
print("test_out: ", test_out.data) checkpoint = torch.load(model_path_dict) # 采用加载参数的方式加载与训练模型
model_test.load_state_dict(checkpoint)
print("test_out1: ", model_test(Input).data) # 查看预训练模型加载后的输出 model_test2 = torch.load(model_path_model) # 直接加载整个模型
print("test_out1: ", model_test2(Input).data) # 查看预训练模型加载后的输出

对应的输出结果如下:

test_out:   tensor([[[[0.1190]]]])  # 网络刚开始的输出结果
test_out1: tensor([[[[0.2558]]]]) # 加载参数后的网络输出
test_out2: tensor([[[[0.2558]]]]) # 加载整个模型后的网络输出

从结果中可以看出,这两种方式加载网络模型的效果是一样的,但是只存储参数的模型所占空间为 2731字节,整个模型所占的空间为4071字节,所以一般建议采取第一种方法。

3.2 模型参数分开存储
model_path_dict2 = './ckpt_dict2.pth'	# 模型的存储路径
torch.save({
'net1':model.head.state_dict(),
'net2':model.tail.state_dict(),
}, model_path_dict2) # 将模型的head和tail模块分开存储
model3 = Network() # 实例化一个新的网络
print("test_out: ", model3(Input).data) # 测试一下原始输出 checkpoint = torch.load(model_path_dict2)
model3.head.load_state_dict(checkpoint['net1']) # 给不同的模块分别加载不同的模型
model3.tail.load_state_dict(checkpoint['net2'])
print("test_out: ", model3(Input).data) #测试一下最后的输出
test_out:  tensor([[[[-0.1870]]]])
test_out: tensor([[[[0.2558]]]])

4、加载模型的部分参数

很多时候我们在训练过程中或多或少都会遇到如下问题:

  1. 已经有了与网络匹配的预训练模型,根据情况需要在网络中添加一个小模块,但是还想利用之前的与训练模型
  2. 虽然用的是同一个网络结构,但是由于定义的方法不一样,导致与训练模型的 key 对应不上

在这些情况下,上述加载模型的方式不能很好地解决这些问题,因此在加载模型时需要更精细的控制才能满足我们的要求。首先我们要先了解一下网络加载模型的实质,其实网络和模型都是按照字典的格式进行存储的,如下所示:

net_dic = model.state_dict()	# 加载网络的字典
for key, value in net_dic.items(): # 显示网络的 key value 值
print(key)
print(value)
for key, value in checkpoint.items(): # 显示模型的 key value 值
print(key)
print(value)

输出结果如下:

"""
这是网络的key-value
"""
head.n.weight
tensor([[[[-0.2744, 0.2048, -0.0635],
[-0.1417, 0.2827, -0.2909],
[ 0.0396, -0.0686, 0.2342]]],
...])
head.n.bias
tensor([-0.2389, 0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
[ 0.1010, -0.1735, -0.0215],
[ 0.0153, 0.1298, -0.2054]]
...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598, 0.2197, 0.1340],
[-0.1290, 0.1500, -0.1595],
[-0.1066, 0.0536, 0.1065]],
...]])
tail.n.1.bias
tensor([0.0029])
"""
这是与训练模型的key-value
"""
head.n.weight
tensor([[[[-0.2744, 0.2048, -0.0635],
[-0.1417, 0.2827, -0.2909],
[ 0.0396, -0.0686, 0.2342]]],
...])
head.n.bias
tensor([-0.2389, 0.0188])
tail.n.0.weight
tensor([[[[-0.1658, -0.1408, -0.1394],
[ 0.1010, -0.1735, -0.0215],
[ 0.0153, 0.1298, -0.2054]],
...]])
tail.n.0.bias
tensor([0.0328, 0.1939])
tail.n.1.weight
tensor([[[[ 0.0598, 0.2197, 0.1340],
[-0.1290, 0.1500, -0.1595],
[-0.1066, 0.0536, 0.1065]],
...]])
tail.n.1.bias
tensor([0.0029])

因此模型加载的实质可以总结为:找到网络与模型相同的key,将模型对应的参数填入到网络中去。因此若要解决上述问题,只需要在加载模型参数时,进行 if-else 判断进行选择特定的网络层或者筛选特定的模型参数。所以 3.1节中加载模型参数可以写成:

checkpoint = torch.load(model_path_dict)	# 采用加载参数的方式加载与训练模型
model_stic = model.state_dict() # 提取网络的字典
state_dic = {k:v for k,v in checkpoint.items() if k in model_stic.keys()} # 找出待加载模型中与网络key一样的参数
model_stic.update(state_dic) # 更新网络参数
print("test_out1: ", model_test(Input).data) # 查看预训练模型加载后的输出

5、冻结模型的部分参数

在训练网络的时候,有的时候不一定需要网络的每个结构都按照同一个学习率更新,或者有的模块干脆不更新,因此这就需要冻结部分模型参数的梯度,但是又不能截断反向传播的梯度流,不然就会导致网络无法正常训练。

5.1 方法一:requires_grad = false
for name, para in model.named_parameters():
if 'tail' in name:
para.requires_grad = False # 将 tail 模块的梯度更新关闭,即冻结tail的参数 for para in model.parameters(): # 在训练前输出一下网络参数,与训练后进行对比
print(para) for index in range(Iter_num):
Output = model(Input)
loss = criterion(Output, Label)
loss.backward()
optimizer.step()
print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data)) for para in model.parameters(): # 输出训练后的模型参数
print(para)

训练前的网络的部分参数:

Parameter containing:
tensor([[[[ 0.1211, 0.2768, -0.0686],
[ 0.2494, -0.0537, 0.0353],
[ 0.3018, -0.3092, -0.2098]]],
...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208, 0.0399],
[-0.2201, -0.1703, -0.1215],
[ 0.1487, 0.1382, -0.1045]],
...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
[ 0.2094, 0.1792, -0.2351],
[ 0.0441, -0.0397, -0.0388]],
...]])
Parameter containing:
tensor([0.1177])

训练后网络的参数:

Parameter containing:
tensor([[[[ 0.1256, 0.2754, -0.0720],
[ 0.2429, -0.0717, 0.0461],
[ 0.2887, -0.3248, -0.2124]]],
...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208, 0.0399],
[-0.2201, -0.1703, -0.1215],
[ 0.1487, 0.1382, -0.1045]],
...]])
Parameter containing:
tensor([ 0.0469, -0.2050])
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
[ 0.2094, 0.1792, -0.2351],
[ 0.0441, -0.0397, -0.0388]],
...]])
Parameter containing:
tensor([0.1177])

通过对比可以发现,网络只更新了 head 层的参数,被冻结的 tail 层参数并没有更新。

5.2 从优化器中设置更新的网络层
import torch.optim as optim
optimizer = optim.SGD(model.head.parameters(), lr = 0.001) # 在优化器中只填入head层的参数
for para in model.parameters(): # 在训练前输出一下网络参数,与训练后进行对比
print(para) for index in range(Iter_num):
Output = model(Input)
loss = criterion(Output, Label)
loss.backward()
optimizer.step()
print("Iter:{}/{}\tloss:{}\tOutput:{}".format(index, Iter_num, loss.data, Output.data)) for para in model.parameters(): # 输出训练后的模型参数
print(para)

训练前的网络的部分参数:

Parameter containing:
tensor([[[[ 0.1211, 0.2768, -0.0686],
[ 0.2494, -0.0537, 0.0353],
[ 0.3018, -0.3092, -0.2098]]],
...], requires_grad=True)
Parameter containing:
tensor([0.1487, 0.1616], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208, 0.0399],
[-0.2201, -0.1703, -0.1215],
[ 0.1487, 0.1382, -0.1045]],
...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
[ 0.2094, 0.1792, -0.2351],
[ 0.0441, -0.0397, -0.0388]],
...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

训练后的网络的部分参数:

Parameter containing:
tensor([[[[ 0.1256, 0.2754, -0.0720],
[ 0.2429, -0.0717, 0.0461],
[ 0.2887, -0.3248, -0.2124]]],
...], requires_grad=True)
Parameter containing:
tensor([0.1525, 0.1894], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0124, -0.1208, 0.0399],
[-0.2201, -0.1703, -0.1215],
[ 0.1487, 0.1382, -0.1045]],
...]], requires_grad=True)
Parameter containing:
tensor([ 0.0469, -0.2050], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0217, -0.1475, -0.2197],
[ 0.2094, 0.1792, -0.2351],
[ 0.0441, -0.0397, -0.0388]],
...]], requires_grad=True)
Parameter containing:
tensor([0.1177], requires_grad=True)

对比这两种方法都能够实现网络某一层参数的冻结而不影响其它层的梯度更新,但是仔细观察发现方法一中不更新参数的网络层的 requires_grad = False,而方法二中所有层的 requires_grad = True。由于个人知识水平有限,难免有错误的地方,还请不吝指正,相互学习,共同进步。

全面解析Pytorch框架下模型存储,加载以及冻结的更多相关文章

  1. 学习笔记TF049:TensorFlow 模型存储加载、队列线程、加载数据、自定义操作

    生成检查点文件(chekpoint file),扩展名.ckpt,tf.train.Saver对象调用Saver.save()生成.包含权重和其他程序定义变量,不包含图结构.另一程序使用,需要重新创建 ...

  2. 转载:通过扩大IE使用内存,解决skyline在IE下模型不能加载的方法

    转自:https://www.cnblogs.com/cannel/p/5261009.html 环境:skyline TerraExploere 6.6,win 10 sp1 64位,ie 11 情 ...

  3. 通过扩大IE使用内存,解决skyline在IE下模型不能加载的方法

    环境:skyline TerraExploere 6.6.1,win10 专业版 64位,ie 11 情况描述:在ie下浏览三维场景,ie占用内存不断增大并且内存占用固定在一个最高范围内,三维场景中部 ...

  4. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  5. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  6. 浏览器环境下Javascript脚本加载与执行探析之DOMContentLoaded

    在”浏览器环境下Javascript脚本加载与执行探析“系列文章的前几篇,分别针对浏览器环境下JavaScript加载与执行相关的知识点或者属性进行了探究,感兴趣的同学可以先行阅读前几篇文章,了解相关 ...

  7. 重温.NET下Assembly的加载过程 ASP.NET Core Web API下事件驱动型架构的实现(三):基于RabbitMQ的事件总线

    重温.NET下Assembly的加载过程   最近在工作中牵涉到了.NET下的一个古老的问题:Assembly的加载过程.虽然网上有很多文章介绍这部分内容,很多文章也是很久以前就已经出现了,但阅读之后 ...

  8. 重温.NET下Assembly的加载过程

    最近在工作中牵涉到了.NET下的一个古老的问题:Assembly的加载过程.虽然网上有很多文章介绍这部分内容,很多文章也是很久以前就已经出现了,但阅读之后发现,并没能解决我的问题,有些点写的不是特别详 ...

  9. NET下Assembly的加载过程

    NET下Assembly的加载过程 最近在工作中牵涉到了.NET下的一个古老的问题:Assembly的加载过程.虽然网上有很多文章介绍这部分内容,很多文章也是很久以前就已经出现了,但阅读之后发现,并没 ...

随机推荐

  1. leetcode 1081

    开始的思路是遍历存储每个字符的所有位置,再进行扫描处理,但是实际操作并没有很熟练,于是在讨论区学习后,有了下面的解法! 首先需要知道不同的字符在字符串中的最后的位置(理论上的最优位置) 然后扫描字符串 ...

  2. Zoho Projects助力企业项目高效管理

    挑选项目管理工具,就和人买衣服.买鞋子是一样的,除了看外观,最重要的是合适.随着项目管理工具的不断发展,市面上有很多工具都非常优秀,也能解决企业.团队的实际需求. 对于项目管理来说,最重要的在于人员协 ...

  3. Visual Studio/VS中任务列表的妙用

    一.任务列表开启方法 首先说下开启的方法:视图-任务列表,即可打开任务列表. 快捷键Ctrl+'\'+T,熟练了可以快速开启.注意,'\'键是回车键上面的'',不要按成了'/' 二.任务列表标签设置 ...

  4. [DB] 数据库概述

    基本概念 关系模型:包括关系数据结构.关系操作集合.关系完整性约束三部分 关系型数据库:建立在关系模型基础上的数据库.由多张能互相联接的二维行列表格组成. 非关系型数据库(Nosql(Not Only ...

  5. [bug] Junit initializationError

    原因 导包错误 解决 先删除 import org.junit.Test; 再导入正确的包 参考 https://blog.csdn.net/javae100/article/details/7978 ...

  6. 006.Ansible自定义变量

    ansible支持变量,用于存储会在整个项目中重复使用到的一些值.以简化项目的创建与维护,降低出错的机率. 变量的定义: 变量名应该由字母.数字下划数组成 变量名必须以字母开头 ansible内置关键 ...

  7. mysql基础之忘掉密码解决办法及恢复root最高权限办法

    如果忘记了mysql的root用户的密码,可以使用如下的方法,重置root密码. 方法一: 1.停止当前mysql进程 systemctl stop mariadb 2.mysql进程停止后,使用如下 ...

  8. linux 解压总结

    tar解压 gz解压 bz2等各种解压文件使用方法 .tar 解包:tar xvf FileName.tar 打包:tar cvf FileName.tar DirName (注:tar是打包,不是压 ...

  9. android格式化日期

    import android.text.format.DateFormat import java.util.* dateTextView.text = DateFormat.format(" ...

  10. ASP.Net Core5.0 EF Core使用记录

    打算把之前开源的 基于ASP.Net Core开发一套通用后台框架 重新用ASP.Net Core 5写一遍,也算是巩固一下旧知识,学习下新知识.本文是项目搭建初期关于 EF Core 的使用记录 1 ...