PyTorch保存模型与加载模型+Finetune预训练模型使用
参数初始化参
数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
model.parameters())
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.fc.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 如何使用 opencv 加载 darknet yolo 预训练模型?
如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- keras模型的保存与重新加载
# 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
- MindSpore保存与加载模型
技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...
- NeHe OpenGL教程 第三十一课:加载模型
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...
随机推荐
- 3,EasyNetQ-发布/订阅
一.发布 在发布/订阅模式中的角色是彼此陌生的. 一个发布者只是向世界说这个已经发生了,一位订阅者告诉世界“我在乎这个”. 在这个模型中,没有人关心特定的事件是很好的. 消息可能有一个订阅者,可能有2 ...
- yum与apt命令比较,yum安装出现No package vim available解决办法
yum (Yellowdog Updater Modified)是一个集与查找,安装,更新和删除程序的Linux软件.它运行在RPM包兼容的Linux发行版本上,如:RedHat, Fedora, S ...
- 用户 'IIS APPPOOL\DefaultAppPool' 登录失败【收藏】
转载:http://blog.csdn.net/wenjie315130552/article/details/7246143 问题是应用程序连接池的问题.网上有些朋友说是Temp文件夹的权限的问题. ...
- vs 2010 :类型化数据集DataSet应用
1.启动服务器资源管理器,建立数据库连接 2.在项目中创建数据集 3.为数据集添加表对象 4.为表适配器tableAdapter添加参数化查询 5.修改表适配器的主查询,或添加其他查询 Update: ...
- ant design的一些坑
1.在本地修改ant design的某些样式可以生效,但在线上就失效了.比如collapse组件里的箭头图标在本地和在线上的类名有变化,本地类名,线上类名:箭头图标的svg样式在线上会自动添加一个内联 ...
- C#如何直接调用非托管代码
C#如何直接调用非托管代码,通常有2种方法: 1. 直接调用从 DLL 导出的函数. 2. 调用 COM 对象上的接口方法 我主要讨论从dll中导出函数,基本步骤如下: 1.使用 C# 关键字 s ...
- JTAG – A technical overview and Timing
This document provides you with interesting background information about the technology that underpi ...
- Git 忽略某个目录中的文件,同时保留这个目录
类似的一个问题是项目根目录下可能有 logs 一类的目录, 我们希望他人把仓库 clone 下来的时候能够已经携带了这个目录, 但又不希望让这个目录中的日志文件进版本库. 之前看到一些项目用了一种比较 ...
- 关于bootstrap的treeview不显示多选(复选框)的问题,以及联动选择的问题,外加多选后取值
最近做项目用到了treeview.因为涉及到多选的问题,很是棘手,于是乎,我决定查看原生JS,探个究竟.需要引用官方的bootstrap-treeview.js都知道吧,对于所需要引用的,我就不多说了 ...
- 《Go学习笔记 . 雨痕》方法
一.定义 方法 是与对象实例绑定的特殊函数. 方法 是面向对象编程的基本概念,用于维护和展示对象的自身状态.对象是内敛的,每个实例都有各自不同的独立特征,以 属性 和 方法 来暴露对外通信接口.普通函 ...