参数初始化参

数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调:有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调:有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())

optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

加载部分预训练模型:其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 
model_dict = model.state_dict() # 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
    diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    pretrained_dict.update(diff)
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
    keys = []
    for k,v in pretrained_dict.items():
        keys.append(k)
    i = 0
    for k,v in model_dict.items():
        if v.size() == pretrained_dict[keys[i]].size():
            print(k, ',', keys[i])
            model_dict[k]=pretrained_dict[keys[i]]
        i = i + 1
    model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

PyTorch保存模型与加载模型+Finetune预训练模型使用的更多相关文章

  1. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  2. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  3. 如何使用 opencv 加载 darknet yolo 预训练模型?

    如何使用 opencv 加载 darknet yolo 预训练模型? opencv 版本 > 3.4 以上 constexpr const char *image_path = "da ...

  4. 莫烦python教程学习笔记——保存模型、加载模型的两种方法

    # View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...

  5. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  6. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  7. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  8. MindSpore保存与加载模型

    技术背景 近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等 ...

  9. NeHe OpenGL教程 第三十一课:加载模型

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

随机推荐

  1. 3,EasyNetQ-发布/订阅

    一.发布 在发布/订阅模式中的角色是彼此陌生的. 一个发布者只是向世界说这个已经发生了,一位订阅者告诉世界“我在乎这个”. 在这个模型中,没有人关心特定的事件是很好的. 消息可能有一个订阅者,可能有2 ...

  2. yum与apt命令比较,yum安装出现No package vim available解决办法

    yum (Yellowdog Updater Modified)是一个集与查找,安装,更新和删除程序的Linux软件.它运行在RPM包兼容的Linux发行版本上,如:RedHat, Fedora, S ...

  3. 用户 'IIS APPPOOL\DefaultAppPool' 登录失败【收藏】

    转载:http://blog.csdn.net/wenjie315130552/article/details/7246143 问题是应用程序连接池的问题.网上有些朋友说是Temp文件夹的权限的问题. ...

  4. vs 2010 :类型化数据集DataSet应用

    1.启动服务器资源管理器,建立数据库连接 2.在项目中创建数据集 3.为数据集添加表对象 4.为表适配器tableAdapter添加参数化查询 5.修改表适配器的主查询,或添加其他查询 Update: ...

  5. ant design的一些坑

    1.在本地修改ant design的某些样式可以生效,但在线上就失效了.比如collapse组件里的箭头图标在本地和在线上的类名有变化,本地类名,线上类名:箭头图标的svg样式在线上会自动添加一个内联 ...

  6. C#如何直接调用非托管代码

    C#如何直接调用非托管代码,通常有2种方法: 1.  直接调用从 DLL 导出的函数. 2.  调用 COM 对象上的接口方法 我主要讨论从dll中导出函数,基本步骤如下: 1.使用 C# 关键字 s ...

  7. JTAG – A technical overview and Timing

    This document provides you with interesting background information about the technology that underpi ...

  8. Git 忽略某个目录中的文件,同时保留这个目录

    类似的一个问题是项目根目录下可能有 logs 一类的目录, 我们希望他人把仓库 clone 下来的时候能够已经携带了这个目录, 但又不希望让这个目录中的日志文件进版本库. 之前看到一些项目用了一种比较 ...

  9. 关于bootstrap的treeview不显示多选(复选框)的问题,以及联动选择的问题,外加多选后取值

    最近做项目用到了treeview.因为涉及到多选的问题,很是棘手,于是乎,我决定查看原生JS,探个究竟.需要引用官方的bootstrap-treeview.js都知道吧,对于所需要引用的,我就不多说了 ...

  10. 《Go学习笔记 . 雨痕》方法

    一.定义 方法 是与对象实例绑定的特殊函数. 方法 是面向对象编程的基本概念,用于维护和展示对象的自身状态.对象是内敛的,每个实例都有各自不同的独立特征,以 属性 和 方法 来暴露对外通信接口.普通函 ...