Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作。PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以及新的研究。同时它还内置了对Google Colab的支持,并与Papers With Code集成。目前 PyTorchHub 包括了一系列与图像分类、分割、生成以及转换相关的模型。
可复现性是许多研究领域的基本要求,这其中当然包括基于机器学习技术的研究领域。然而, 许多机器学习相关论文要么无法复现,要么难以重现。随着论文数量的持续增长,包括目前在 arXiv 上预印刷的数万份论文以及提交给会议的论文,研究工作的可复现性变得越来越重要。虽然其中许多论文都附有代码以及训练好的模型,但这种帮助显然非常有限,复现过程中仍有大量需要读者自己摸索的步骤。下面让我们来看一下如何通过 PyTorch Hub 这一利器完成快速的模型发布与工作复现。

如何快速发布模型
这部分主要介绍了对于模型发布者来说如何快速高效的将自己的模型加入 PyTorch Hub 库。PyTorch Hub 支持通过添加简单的 hubconf.py 文件将预先训练的模型(模型定义和预先训练重)发布到 GitHub 存储库。这提供了模型列表以及其依赖库列表。一些示例可以在torchvision,huggingface-bert和gan-model-zoo存储库中找到。
Pytoch 社区给出了 torchvision 的 hubconf.py 文件的示例:
|
# Optional list of dependencies required by the package
|
|
|
dependencies = ['torch']
|
|
|
from torchvision.models.alexnet import alexnet
|
|
|
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
|
|
|
from torchvision.models.inception import inception_v3
|
|
|
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d
|
|
|
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
|
|
|
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
|
|
|
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
|
|
|
from torchvision.models.googlenet import googlenet
|
|
|
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
|
|
|
from torchvision.models.mobilenet import mobilenet_v2
|
在 torchvision 中,模型有以下特性:
- 每个模型文件可以被独立执行或实现某个功能
- 不需要除了 PyTorch 之外的任何软件包(在 hubconf.py 中编码为 dependencies[‘torch’])
- 他们不需要单独的入口点,因为模型在创建时可以无缝地开箱即用。
PyTroch 社区认为最小化包依赖性可减少用户加载模型时遇到的困难。这里他们给出了一个更为复杂的例子——HuggingFace’s BERT 模型,它的 hubconf.py 如下:
|
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
|
|
|
from hubconfs.bert_hubconf import (
|
|
|
bertTokenizer,
|
|
|
bertModel,
|
|
|
bertForNextSentencePrediction,
|
|
|
bertForPreTraining,
|
|
|
bertForMaskedLM,
|
|
|
bertForSequenceClassification,
|
|
|
bertForMultipleChoice,
|
|
|
bertForQuestionAnswering,
|
|
|
bertForTokenClassification
|
|
|
)
|
此外,对于每个模型,PyTorch 官方提到都需要为其创建一个入口点。下面是一个用于指定 bertForMaskedLM 模型的入口点的代码片段,这部分代码完成的功能是返回加载了预训练参数的模型。
|
def bertForMaskedLM(*args, **kwargs):
|
|
|
"""
|
|
|
BertForMaskedLM includes the BertModel Transformer followed by the
|
|
|
pre-trained masked language modeling head.
|
|
|
Example:
|
|
|
...
|
|
|
"""
|
|
|
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
|
|
|
return model
|
这些入口点可以看成是复杂的模型结构的一种封装形式。它们可以在提供简洁高效的帮助文档的同时完成下载预训练权重的功能(例如,通过 pretrained = True),也可以集成其他特定功能,例如可视化。
通过 hubconf.py,模型发布者可以在 Github 上基于template提交他们的合并请求。PyTorch 社区希望通过 PyTorch Hub 创建一系列高质量、易复现且效果好的模型以提高研究工作的复现性。因此,PyTorch 会通过与模型发布者合作的方式以完善请求,并有可能会在某些情况下拒绝发布一些低质量的模型。一旦 PyTorch 社区接受了模型发布者的请求,这些新的模型将会很快出现在 PyTorch Hub 的网页上以供用户浏览。
用户工作流
对于想使用 PyTorch Hub 对别人的工作进行复现的用户,PyTorch Hub 提供了以下几个步骤:1)浏览可用的模型;2)加载模型;3)探索已加载的模型。下面让我们来浏览几个例子。
浏览可用的入口点
用户可以使用 torch.hub.list() API 列出仓库中的所有可用入口点。
|
>>> torch.hub.list('pytorch/vision')
|
|
|
>>>
|
|
|
['alexnet',
|
|
|
'deeplabv3_resnet101',
|
|
|
'densenet121',
|
|
|
...
|
|
|
'vgg16',
|
|
|
'vgg16_bn',
|
|
|
'vgg19',
|
|
|
'vgg19_bn']
|
注意,PyTorch Hub 还允许辅助入口点(除了预训练模型),例如,用于 BERT 模型预处理的 bertTokenizer,它可以使用户工作流程更加顺畅。
加载模型
对于 PyTroch Hub 中可用的模型,用户可以使用 torch.hub.load() API 加载模型入口点。此外,torch.hub.help() API 可以提供有关如何实例化模型的有用信息。
|
print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
|
|
|
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
|
由于仓库的持有者会不断添加错误修复以及性能改进,PyTorch Hub 允许用户通过调用以下内容简单地获取最新更新:
|
model = torch.hub.load(..., force_reload=True)
|
这一举措可以有效地减轻仓库持有者重复发布模型的负担,从而使他们能够更专注于自己的研究工作。同时,也确保了用户可以获得最新版本的模型。
此外,对于用户来说,稳定性也是一个重要问题。因此,某些模型所有者会从特征的分支或标签为他们提供服务,以确保代码的稳定性。例如,pytorch_GAN_zoo 会从 hub 分支为他们提供服务:
|
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
|
这里,传递给 hub.load() 的 * args,** kwargs 用于实例化模型。在上面的示例中,pretrained = True 和 useGPU = False 被传递给模型的入口点。
探索已加载的模型
从 PyTorch Hub 加载模型后,用户可以使用以下工作流查看已加载模型的可用方法,并更好地了解运行它所需的参数。
其中,dir(model) 可以查看模型中可用的方法。下面是 bertForMaskedLM 的一些方法:
|
>>> dir(model)
|
|
|
>>>
|
|
|
['forward'
|
|
|
...
|
|
|
'to'
|
|
|
'state_dict',
|
|
|
]
|
help(model.forward)则会提供使已加载的模型运行时所需参数的视图:
|
>>> help(model.forward)
|
|
|
>>>
|
|
|
Help on method forward in module pytorch_pretrained_bert.modeling:
|
|
|
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
|
|
|
...
|
- BERT:https://pytorch.org/hub/huggingface_pytorch-pretrained-bert_bert/
- DeepLabV3:https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
其他探索方式与相关资源
PyTorch Hub 中提供的模型也支持 Colab,并且会直接链接在 Papers With Code 上,用户只需单击链接即可开始使用:

PyTorch 提供了一些相关资源帮助用户快速上手 PyTorch Hub:
- PyTorch Hub API 手册链接:https://pytorch.org/docs/stable/hub.html
- 模型提交地址:https://github.com/pytorch/hub
- 浏览 PyTorch Hub 网页以学习更多可用模型:https://pytorch.org/hub
- 在 Paper with Code 上浏览更多模型:https://paperswithcode.com/
FAQ
问:如果我们想贡献一个 Hub 中已经有了的模型,但也许我的模型具有更高的准确性,我还应该贡献吗?
答:是的,请提交您的模型,Hub 的下一步是开发投票系统以展示最佳模型。
问:谁负责保管 PyTorch Hub 的模型权重?
答:作为贡献者,您负责保管模型权重。您可以在您喜欢的云存储中托管您的模型,或者如果它符合限制,则可以在 GitHub 上托管您的模型。 如果您无法保管权重,请通过 Hub 仓库中提交问题的方式与我们联系。
问:如果我的模型使用了私有化数据进行训练怎么办?我还应该贡献这个模型吗?
答:请不要提交您的模型!PyTorch Hub 以开源研究为中心,并扩展到使用公开数据集来训练这些模型。如果提交了私有模型的合并请求,我们将恳请您重新提交使用公开数据进行训练后的模型。
问:我下载的模型保存在哪里?
答:我们遵循 XDG 基本目录规范,并遵循缓存文件和目录的通用标准。这些位置按以下顺序使用:
- 调用 hub.set_dir(<PATH_TO_HUB_DIR>)
- 如果环境变量了 TORCH_HOME,则为 $TORCH_HOME/hub。
- 如果设置了环境变量 XDG_CACHE_HOME,则为 $ XDG_CACHE_HOME / torch / hub。
- ~/.cache/torch/hub
相关推荐:
Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易的更多相关文章
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包
Computational Network Toolkit (CNTK) 是微软出品的开源深度学习工具包 用 CNTK 搞深度学习 (一) 入门 Computational Network Toolk ...
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 《深度学习框架PyTorch:入门与实践》的Loss函数构建代码运行问题
在学习陈云的教程<深度学习框架PyTorch:入门与实践>的损失函数构建时代码如下: 可我运行如下代码: output = net(input) target = Variable(t.a ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 深度学习数据特征提取:ICCV2019论文解析
深度学习数据特征提取:ICCV2019论文解析 Goal-Driven Sequential Data Abstraction 论文链接: http://openaccess.thecvf.com/c ...
- 深度学习 目标检测算法 SSD 论文简介
深度学习 目标检测算法 SSD 论文简介 一.论文简介: ECCV-2016 Paper:https://arxiv.org/pdf/1512.02325v5.pdf Slides:http://w ...
- 神工鬼斧惟肖惟妙,M1 mac系统深度学习框架Pytorch的二次元动漫动画风格迁移滤镜AnimeGANv2+Ffmpeg(图片+视频)快速实践
原文转载自「刘悦的技术博客」https://v3u.cn/a_id_201 前段时间,业界鼎鼎有名的动漫风格转化滤镜库AnimeGAN发布了最新的v2版本,一时间街谈巷议,风头无两.提起二次元,目前国 ...
- 深度学习框架PyTorch一书的学习-第五章-常用工具模块
https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...
随机推荐
- 2019.8.9 NOIP模拟测试15 反思总结
日常爆炸,考得一次比一次差XD 可能还是被身体拖慢了学习的进度吧,虽然按理来说没有影响.大家听的我也听过,大家学的我也没有缺勤多少次. 那么果然还是能力问题吗……? 虽然不愿意承认,但显然就是这样.对 ...
- Thinkphp 不足之处
1.报错机制 //控制器里面直接输出如下内容,代码不提示.TP报错机制已经开启 echo $aaaaaa; bbbbbbbbb; eco bbbbbbbb; 正常应该给出以下提示 Notice: Un ...
- 高效整洁CSS代码原则(上)
CSS学起来并不难,但在大型项目中,就变得难以管理,特别是不同的人在CSS书写风格上稍有不同,团队上就更加难以沟通,为此总结了一些如何实现高效整洁的CSS代码原则: 1. 使用Reset但并非全局Re ...
- Linux C/C++开发
首先就是要熟练在vim里面写代码,其实就是没有提示和自动补全了,这个问题并不大. 我服务器gcc版本是4.8.5,所以就按照这个来了 https://gcc.gnu.org/onlinedocs/gc ...
- Python数据分析与展示[第一周]
ipython 中的问号 获得相关的描述信息 %run 系统文件 执行某一个文件 ipython的模式命令 %magic 显示所有的魔术命令 %hist 命令历史输入信息 %pdb 异常发 ...
- 斯坦福CS课程列表
http://exploredegrees.stanford.edu/coursedescriptions/cs/ CS 101. Introduction to Computing Principl ...
- BMDP为常规的统计分析提供了大量的完备的函数系统,如:方差分析(ANOVA)、回归分析(Regression)、非参数分析(Nonparametric Analysis)、时间序列(Times Series)等等。此外,BMDP特别擅于进行出色的生存分析(Survival Analysis )。许多年来,一大批世界范围内顶级的统计学家都曾今参与过BMDP的开发工作。这不仅使得BMDP的权威性得到
BMDP是Bio Medical Data Processing的缩写,是世界级的统计工具软件,至今已经有40多年的历史.目前在国际上与SAS.SPSS被并称为三大统计软件包.BMDP是一个大 ...
- shops
#!/usr/bin/env python #coding:utf- import urllib2,sys,re,os,string reload(sys); sys.setdefaultencodi ...
- 【JZOJ3885】【长郡NOIP2014模拟10.22】搞笑的代码
ok 在OI界存在着一位传奇选手--QQ,他总是以风格迥异的搞笑代码受世人围观 某次某道题目的输入是一个排列,他使用了以下伪代码来生成数据 while 序列长度<n do { 随机生成一个整数属 ...
- 【python小随笔】单例模式设计(易懂版)
1:单例模式原理 大道理:希望在系统中某个对象只能存在一个,单例模式是最好的解决方案,单例模式是一种常见的软件设置模式,在它的核心结构中只包含一个被称为单例类的特殊类,通过单例模式可以保证系统中的一个 ...
