转自:https://blog.csdn.net/Vivianyzw/article/details/81061765

东风的地方

1. 直接加载预训练模型

在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:


  1. net = SNet()
  2. net.load_state_dict(torch.load("model_1599.pkl"))

这种方式是针对于之前保存模型时以保存参数的格式使用的:

torch.save(net.state_dict(), "model/model_1599.pkl")

pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。

下面介绍第二种模型保存和加载的方式:


  1. net = SNet()
  2. torch.save(net, "model_1599.pkl")
  3. snet = torch.load("model_1599.pkl")

这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。

2. 加载一部分预训练模型

模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:


  1. net = SNet()
  2. model_dict = net.state_dict()
  3. vgg16 = models.vgg16(pretrained=True)
  4. pretrained_dict = vgg16.state_dict()
  5. pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  6. model_dict.update(pretrained_dict)
  7. net.load_state_dict(model_dict)

也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。

3. 微调经典网络

因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。

首先比如加载vgg16(带有预训练参数的形式):


  1. import torchvision.models as models
  2. vgg16 = models.vgg16(pretrained=True)

比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:


  1. import torch.nn as nn
  2. vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)

4. 修改经典网络

这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。

先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:


  1. def feature_layer():
  2. layers = []
  3. pool1 = ['4', '9', '16']
  4. pool2 = ['23', '30']
  5. vgg16 = models.vgg16(pretrained=True).features
  6. for name, layer in vgg16._modules.items():
  7. if isinstance(layer, nn.Conv2d):
  8. layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]
  9. elif name in pool1:
  10. layers += [layer]
  11. elif name == pool2[0]:
  12. layers += [nn.MaxPool2d(2, 1, 1)]
  13. elif name == pool2[1]:
  14. layers += [nn.MaxPool2d(2, 1, 0)]
  15. else:
  16. continue
  17. features = nn.Sequential(*layers)
  18. #feat3 = features[0:24]
  19. return features

大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:


  1. class SNet(nn.Module):
  2. def __init__(self):
  3. super(SNet, self).__init__()
  4. self.features = feature_layer()
  5. def forward(self, x):
  6. x = self.features(x)
  7. return x

[Pytorch]Pytorch加载预训练模型(转)的更多相关文章

  1. pytorch中修改后的模型如何加载预训练模型

    问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...

  2. 使用Huggingface在矩池云快速加载预训练模型和数据集

    作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...

  3. pytorch加载预训练模型参数的方式

    1.直接使用默认程序里的下载方式,往往比较慢: 2.通过修改源代码,使得模型加载已经下载好的参数,修改地方如下: 通过查找自己代码里所调用网络的类,使用pycharm自带的函数查找功能(ctrl+鼠标 ...

  4. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  5. Tensorflow加载预训练模型和保存模型

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

  6. PyTorch模型加载与保存的最佳实践

    一般来说PyTorch有两种保存和读取模型参数的方法.但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题. 第一种方案是保存整个模型: 1 torch.save(model_object, ...

  7. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 scikit-image:用于图像的IO和变换 pandas:用于更容易地进行csv解 ...

  8. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  9. pytorch数据加载器

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...

随机推荐

  1. ubuntu 创建桌面快捷方式

    $sudo apt-get install gnome-panel $gnome-desktop-item-edit  /home/xxx/桌面 --create-new 命令行:填入程序名称,如/u ...

  2. bootstrap+html5+css3

    一.栅格和块阴影 <!DOCTYPE html> <html> <head> <title>Bootstrap 实例 - 堆叠的水平</title ...

  3. [转帖]双剑合璧:CPU+GPU异构计算完全解析

    引用自:http://tech.sina.com.cn/mobile/n/2011-06-20/18371792199.shtml 这篇文章写的深入浅出,把异构计算的思想和行业趋势描述的非常清楚,难得 ...

  4. WCF(四) 深入契约

    服务契约中的请求-响应操作 1.请求-响应模式 [OperationContract]//1默认就是 请求-相应 Requst- Replay DateTime GetDateTime(); [Ope ...

  5. onethink后台登陆修改验证码!

    验证码: $config = array( 'fontSize' => 30, // 验证码字体大小 'length' => 3, // 验证码位数 'useNoise' => fa ...

  6. 浅谈Lambda表达式详解

    lambda简介 lambda运算符:所有的lambda表达式都是用新的lambda运算符 " => ",可以叫他,“转到”或者 “成为”.运算符将表达式分为两部分,左边指定 ...

  7. 修改mysql的字符集和默认存储引擎

    转自:http://blog.csdn.net/wyzxg/article/details/8779682 author:skatetime:2012/05/18 修改mysql的字符集和默认存储引擎 ...

  8. Network of Schools---poj1236(强连通分量)

    题目链接 题意:学校有一些单向网络,现在需要传一些文件 求:1,求最少需要向几个学校分发文件才能让每个学校都收到, 2,需要添加几条网络才能从任意一个学校分发都可以传遍所有学校. 解题思路(参考大神的 ...

  9. arc 和 非arc兼容

    1,选择项目中的Targets,选中你所要操作的Target, 2,选Build Phases,在其中Complie Sources中选择需要ARC的文件双击, 并在输入框中输入:-fobjc-arc ...

  10. Python之迭代器及生成器

    一. 迭代器 1.1 什么是可迭代对象 字符串.列表.元组.字典.集合 都可以被for循环,说明他们都是可迭代的. 我们怎么来证明这一点呢? from collections import Itera ...