使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口。在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等。

模型保存/加载

1.所有模型参数

训练过程中,有时候会由于各种原因停止训练,这时候我们训练过程中就需要注意将每一轮epoch的模型保存(一般保存最好模型与当前轮模型)。一般使用pytorch里面推荐的保存方法。该方法保存的是模型的参数。

  1. #保存模型到checkpoint.pth.tar
  2. torch.save(model.module.state_dict(), checkpoint.pth.tar’)

对应的加载模型方法为(这种方法需要先反序列化模型获取参数字典,因此必须先load模型,再load_state_dict):

  1. mymodel.load_state_dict(torch.load(‘checkpoint.pth.tar’))

有了上面的保存后,现以一个例子说明如何在inference AND/OR resume train使用。

  1. #保存模型的状态,可以设置一些参数,后续可以使用
  2. state = {'epoch': epoch + 1,#保存的当前轮数
  3. 'state_dict': mymodel.state_dict(),#训练好的参数
  4. 'optimizer': optimizer.state_dict(),#优化器参数,为了后续的resume
  5. 'best_pred': best_pred#当前最好的精度
  6. ,....,...}
  7. #保存模型到checkpoint.pth.tar
  8. torch.save(state, checkpoint.pth.tar’)
  9. #如果是best,则复制过去
  10. if is_best:
  11. shutil.copyfile(filename, directory + 'model_best.pth.tar')
  12. checkpoint = torch.load('model_best.pth.tar')
  13. model.load_state_dict(checkpoint['state_dict'])#模型参数
  14. optimizer.load_state_dict(checkpoint['optimizer'])#优化参数
  15. epoch = checkpoint['epoch']#epoch,可以用于更新学习率等
  16. #有了以上的东西,就可以继续重新训练了,也就不需要担心停止程序重新训练。
  17. train/eval
  18. ....
  19. ....

上面是pytorch建议使用的方法,当然还有第二种方法。这种方法灵活性不高,不推荐。

  1. #保存
  2. torch.save(mymodel,‘checkpoint.pth.tar’)
  3. #加载
  4. mymodel = torch.load(‘checkpoint.pth.tar’)

2.部分模型参数

在很多时候,我们加载的是已经训练好的模型,而训练好的模型可能与我们定义的模型不完全一样,而我们只想使用一样的那些层的参数。

有几种解决方法:

(1)直接在训练好的模型开始搭建自己的模型,就是先加载训练好的模型,然后再它基础上定义自己的模型;

  1. model_ft = models.resnet18(pretrained=use_pretrained)
  2. self.conv1 = model_ft.conv1
  3. self.bn = model_ft.bn
  4. ... ...

(2) 自己定义好模型,直接加载模型

  1. #第一种方法:
  2. mymodelB = TheModelBClass(*args, **kwargs)
  3. # strict=False,设置为false,只保留键值相同的参数
  4. mymodelB.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
  5. #第二种方法:
  6. # 加载模型
  7. model_pretrained = models.resnet18(pretrained=use_pretrained)
  8. # mymodel's state_dict,
  9. # 如: conv1.weight
  10. # conv1.bias
  11. mymodelB_dict = mymodelB.state_dict()
  12. # 将model_pretrained的建与自定义模型的建进行比较,剔除不同的
  13. pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict}
  14. # 更新现有的model_dict
  15. mymodelB_dict.update(pretrained_dict)
  16. # 加载我们真正需要的state_dict
  17. mymodelB.load_state_dict(mymodelB_dict)
  18. # 方法2可能更直观一些

参数初始化

第二个问题是参数初始化问题,在很多代码里面都会使用到,毕竟不是所有的都是有预训练参数。这时就需要对不是与预训练参数进行初始化。pytorch里面的每个Tensor其实是对Variabl的封装,其包含data、grad等接口,因此可以用这些接口直接赋值。这里也提供了怎样把其他框架(caffe/tensorflow/mxnet/gluonCV等)训练好的模型参数直接赋值给pytorch.其实就是对data直接赋值。

pytorch提供了初始化参数的方法:

  1. def weight_init(m):
  2. if isinstance(m,nn.Conv2d):
  3. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  4. m.weight.data.normal_(0,math.sqrt(2./n))
  5. elif isinstance(m,nn.BatchNorm2d):
  6. m.weight.data.fill_(1)
  7. m.bias.data.zero_()

但一般如果没有很大需求初始化参数,也没有问题(不确定性能是否有影响的情况下),pytorch内部是有默认初始化参数的。

Fintune

最后就是精调了,我们平时做实验,至少backbone是用预训练的模型,将其用作特征提取器,或者在它上面做精调。

用于特征提取的时候,要求特征提取部分参数不进行学习,而pytorch提供了requires_grad参数用于确定是否进去梯度计算,也即是否更新参数。以下以minist为例,用resnet18作特征提取:

  1. #加载预训练模型
  2. model = torchvision.models.resnet18(pretrained=True)
  3. #遍历每一个参数,将其设置为不更新参数,即不学习
  4. for param in model.parameters():
  5. param.requires_grad = False
  6. # 将全连接层改为mnist所需的10类,注意:这样更改后requires_grad默认为True
  7. model.fc = nn.Linear(512, 10)
  8. # 优化
  9. optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

用于全局精调时,我们一般对不同的层需要设置不同的学习率,预训练的层学习率小一点,其他层大一点。这要怎么做呢?

  1. # 加载预训练模型
  2. model = torchvision.models.resnet18(pretrained=True)
  3. model.fc = nn.Linear(512, 10)
  4. # 参考:https://blog.csdn.net/u012759136/article/details/65634477
  5. ignored_params = list(map(id, model.fc.parameters()))
  6. base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
  7. # 对不同参数设置不同的学习率
  8. params_list = [{'params': base_params, 'lr': 0.001},]
  9. params_list.append({'params': model.fc.parameters(), 'lr': 0.01})
  10. optimizer = torch.optim.SGD(params_list,
  11. 0.001
  12. momentum=args.momentum,
  13. weight_decay=args.weight_decay)

最后整理一下目前,pytorch预训练的基础模型:

(1)torchvision

torchvision里面已经提供了不同的预训练模型,一般也够用了。

pytorch/visiongithub.com

包含了alexnet/densenet各种版本(densenet121/densenet169/densenet201/densenet161)/inception_v3/resnet各种版本(resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152')/SqueezeNet各种版本( 'squeezenet1_0', 'squeezenet1_1')/VGG各种版本( 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19')

(2)其他预训练好的模型,如,SENet/NASNet等。

Cadene/pretrained-models.pytorchgithub.com

(3)gluonCV转pytorch的模型,包括,分类网络,分割网络等,这里的精度均比其他框架高几个百分点。

zhanghang1989/gluoncv-torchgithub.com

PyTorch模型读写、参数初始化、Finetune的更多相关文章

  1. [PyTorch]PyTorch中模型的参数初始化的几种方法(转)

    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本文目录 1. xavier初始化 2. kaiming初始化 3. 实际使用中看到的初始化 3.1 ResNeXt,de ...

  2. Pytorch基础(6)----参数初始化

    一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...

  3. pytorch对模型参数初始化

    1.使用apply() 举例说明: Encoder :设计的编码其模型 weights_init(): 用来初始化模型 model.apply():实现初始化 # coding:utf- from t ...

  4. pytorch和tensorflow的爱恨情仇之参数初始化

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...

  5. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

  6. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

  7. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  8. ubuntu之路——day15.1 只用python的numpy在底层检验参数初始化对模型的影响

    首先感谢这位博主整理的Andrew Ng的deeplearning.ai的相关作业:https://blog.csdn.net/u013733326/article/details/79827273 ...

  9. DEX-6-caffe模型转成pytorch模型办法

    在python2.7环境下 文件下载位置:https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ 1.可视化模型文件prototxt 1)在线可视化 ...

随机推荐

  1. WPF 饼状图,柱形图,折线图 (2 折线图)

    折线图在柱形图的基础上,做了一些修改.大概效果和用法如下. X轴和Y轴的刻度,使用用了Path的Figures属性,绘制多条Figure+LineSegment完成. 同时,由于折线图很可能会画多条线 ...

  2. 如何在github上传本地项目代码

    首先你要在github上申请一个账号 网址:https://github.com/ 然后你要下载一个git工具 网址:https://gitforwindows.org/ 进入官网直接下载就行,下载完 ...

  3. 根据当前设备的宽度,动态计算出rem的换算比例,实现页面中元素的等比缩放

    ~function anonymous(window){ //根据当前设备的宽度,动态计算出rem的换算比例,实现页面中元素的等比缩放 let computedREM = function compu ...

  4. tar.gz 文件解压

    tar.gz 文件解压 解压缩 file.tar.gz 的过程中出现如下所示问题: tar: 它似乎不像是一个 tar 归档文件 tar: 跳转到下一个头 tar: 由于前次错误,将以上次的错误状态退 ...

  5. Zookeeper 序列化机制

    一.到底在哪些地方需要使用序列化技术呢? 二.Zookeeper(分布式协调服务组件+存储系统) Java 序列化机制 Hadoop序列化机制 Zookeeper序列化机制 一.到底在哪些地方需要使用 ...

  6. Springboot 日志、配置文件、接口数据如何脱敏?老鸟们都是这样玩的!

    一.前言 核心隐私数据无论对于企业还是用户来说尤其重要,因此要想办法杜绝各种隐私数据的泄漏.下面陈某带大家从以下三个方面讲解一下隐私数据如何脱敏,也是日常开发中需要注意的: 配置文件数据脱敏 接口返回 ...

  7. Python实现Thrift Server

    近期在项目中存在跨编程语言协作的需求,使用到了Thrift.本文将记录用python实现Thrift服务端的方法. 环境准备 根据自身实际情况下载对应的Thrift编译器,比如我在Windows系统上 ...

  8. openwrt开发笔记二:树莓派刷openwrt

    前言及准备 本笔记适用于第一次给树莓派刷openwrt系统的玩家,对刷机过程及注意事项进行了记录,刷机之后对openwrt进行一些简单配置. 使用openwrt源码制作固件需要花费一点时间. 平台环境 ...

  9. Python - 面向对象编程 - __init__() 构造方法

    什么是构造方法 在创建类时, 可手动添加一个   __init__() 方法,称为构造方法,这是一个实例方法 构造方法用于创建实例对象时使用,每当创建一个类的实例对象时,Python 解释器都会自动调 ...

  10. Selenium系列(十九) - Web UI 自动化基础实战(6)

    如果你还想从头学起Selenium,可以看看这个系列的文章哦! https://www.cnblogs.com/poloyy/category/1680176.html 其次,如果你不懂前端基础知识, ...