PyTorch保存模型与加载模型+Finetune预训练模型使用
参数初始化参
数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
model.parameters())
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.fc.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 如何使用 opencv 加载 darknet yolo 预训练模型?
如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- keras模型的保存与重新加载
# 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
- MindSpore保存与加载模型
技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...
- NeHe OpenGL教程 第三十一课:加载模型
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...
随机推荐
- php调用python脚本
主要参考两篇文章 PHP中的换行详解 利用PHP调试Python Python小窥 - 写给Python的入门者 这两篇文章结合起来进行测试,主要过程如下 cd /var/www/html mkdir ...
- @repository的含义,并且有时候却不用写,为什么?
//最后发现是这样的:@repository跟@Service,@Compent,@Controller这4种注解是没什么本质区别,都是声明作用,取不同的名字只是为了更好区分各自的功能.下图更多的作用 ...
- Android Intent Service
Android Intent Service 学习自 Android 官方文档 https://blog.csdn.net/iromkoear/article/details/63252665 Ove ...
- Activity-Flag标志位
Activity-Flag标志位 学习自 <Android开发艺术探索> 标志位漫谈 var intent: Intent = Intent(this, Test2Activity::cl ...
- nginx命令大全
sudo nginx #打开 nginx nginx -s reload|reopen|stop|quit #重新加载配置|重启|停止|退出 nginx nginx -t #测试配置是否有语法错误 n ...
- LOJ.2721.[NOI2018]屠龙勇士(扩展CRT 扩展欧几里得)
题目链接 LOJ 洛谷 rank前3无压力(话说rank1特判打表有意思么) \(x*atk[i] - y*p[i] = hp[i]\) 对于每条龙可以求一个满足条件的\(x_0\),然后得到其通解\ ...
- GitLab目录迁移方法
在生产环境上迁移GitLab的目录需要注意一下几点: 1.目录的权限必须为755或者775 2.目录的用户和用户组必须为git:git 3.如果在深一级的目录下,那么git用户必须添加到上一级目录的账 ...
- 使用CefSharp在.Net程序中嵌入Chrome浏览器(六)——调试
chrome强大的调试功能令许多开发者爱不释手,在使用cef的时候,我们也可以继承这强大的开发者工具. 集成调试: 我们可以使用如下函数直接使用集成在chrome里的开发者工具 _chrome.Sho ...
- JTAG TAP Controller
The TAP controller is a synchronous finite state machine that responds to changes at the TMS and TCK ...
- Go内置库模块 flag
import "flag" flag包实现了命令行参数的解析.每个参数认为一条记录,根据实际进行定义,到一个set集合.每条都有各自的状态参数. 在使用flag时正常流程: 1. ...