Fairseq-快速可扩展的序列建模工具包
一种快速、可扩展的序列建模工具包,Pytorch的高级封装库,适用于机器翻译、语言模型和篇章总结等建模任务。
抽象
Dataset:数据加载
Fairseq中的Dataset
基本都是按功能逐层封装,按需组合起来。所有数据加载的实现均位于fairseq/data
下面。
两个比较常用的数据处理类:
IndexedDataset
直接处理/读取,bin/raw文件。LanguagepairDataset
包含src和tgt两个Dataset,用于处理成对的数据。比如在机器翻译的中翻英任务中,处理中文和英文文本。
Option:参数定义
Fairseq中的参数统一使用argsparser
库实现,模型通用参数被定义在fairseq/option.py
下。同时每个模型均有其特有参数,通过每个模型的add_args(parser)
函数定义。
在
fairseq/option.py
中定义了6类通用参数,对应的函数分别是get_preprocessing_parser()
,get_training_parser()
,get_generation_parser()
,get_interactive_generation_parser()
,get_eval_lm_parser()
和get_validation_parser
。这6类通用参数又通过add_***_args()
组装起来。在模型各自的实现中,通过继承接口中的
add_args()
添加模型特有的参数,比如fairseq\models\lstm.py
中通过add_args()
添加了LSTM模型的encoder-embed-dim
,encoder-layers
,encoder-bidirectional
等参数。
Model:网络模型的抽象
Fairseq中的Model
负责模型的定义,包括各个模型的总体结构,每个模型提供argeparser
供用户传入自定义参数。所有的模型定义均位于fairseq/models
下。
所有的模型均继承自类BaseFairseqModel
,而BaseFairseqModel
又继承自torch.nn.Module
,因此所有的Fairseq模型均可以作为其它Pytorch代码的模块。模型的具体结构,比如嵌入层的维度、隐藏层的个数由architectures
定义。特别地,
LanguageModel
和EncoderDecoderModel
均直接继承自BaseFairseqModel
,BaseFairseqModel
主要提供add_args()
、build_model()
等统一的接口,以及模型加载等功能。Encoder
和Decoder
均直接继承自torch.nn.Module
,LanguageModel
和EncoderDecoder
包含Encoder
和Decoder
。Decoder
包含一个output_layer
抽象接口,BERT这样的语言模型由于存在输出,因此继承的是Decoder
。
FairseqTask
Fairseq主要以FairseqTask
为核心,使用FairseqTask
将各个部分衔接起来。一个Task可以是TranslationTask
(比如使用Transformer做翻译),也可以是一个LanguageModelTask
。所有的任务定义均位于fairseq/tasks
下。一个FairseqTask
实例需要实现以下功能:
字典存储/加载。
提供加载、分切数据的帮助类,获得装载数据的
Dataloader
、iterator
等。创建模型。
创建
criterion
。循环训练、验证,直至收敛或达到指定训练轮数。
FairseqTask
实现的功能基本上包含了模型运行的全部要素,可以看到主函数的调用流程:
Criterion
所有的准则(criterion)定义在fairseq/criterions
内,准则对给定的模型和小批量数据计算损失(Loss)。也就是:
\]
在Fairseq中,实现所谓的“混合专家”(mixture-of-experts)模型,准则(criterion)实现EM风格(EM-style)的训练,以节约算力。
Optimizer
所有的优化器(optimizer)定义在fairseq/optim
中,优化器根据梯度,更新模型参数。
Scheduler
定义在fairseq/lr_scheduler
中。在训练过程中,调整学习率。
注册
注册机制
Fairseq中许多组件都是公共的,模块之间尽量解耦,需要一种方式指定应该跑哪一个Model,数据装载使用哪一个Dataset。注册机制在Fairseq中大量使用。
以FairseqTask的注册机制为例,FairseqTask包含了多个子类,如TranslationTask
、MaskedLMTask
、LanguageModeling
等。在fairseq/task/__init__.py
中会通过for循环import该目录下的所有文件,最后在TASK_REGISTRY
中可以得到key:cls
形式的模块存储器。其中,key
为字符串,cls
为模块的cls
对象。这种方式可以很方便的通过指定参数,导入想要的模块。在函数装饰器setup_task()
和register_task()
中,通过TASK_REGISTRY
载入和注册task。
举个例子,通过装饰器进行注册
,比如:
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
...
在Model
、Criterion
部分都有该机制的身影。
在主函数train.py
中,通过setup_task()
,build_model()
,build_criterion()
中得到所需部分。
同样地,可以使用注册机制固化模型参数。一些模型仅仅有模型参数上的区别,本质并无区别,比如roberta_base
,roberta_large
。因此需要指定各个模型的具体默认参数,当然这些参数,用户可以通过fairseq的参数系统进行指定。这些模型的具体参数同样可以用注册的方式固定下来,在使用时可以更加方便。
- 对于模型,使用
@register_model
装饰器注册。
@register_model('roberta')
class RobertaModel(FairLanguageModel):
...
- 对于具体的模型结构,使用
@register_model_architecture
装饰器注册。
@register_model_architecture('robtera','roberta_large')
def roberta_large_architecture(args):
args.encoder_layers = getattr(args,'encoder_layers',24)
args.encoder_embed_dim = getattr(args,'encoder_embed_dim',1024)
...
base_architecture(args)
注册的函数对象会在ARCH_CONFIG_REGISTERY
中存储,并在option.py
中调用:
ARCH_CONFIG_REGISTRY[args.arch](args)
实现上的特点
Fairseq使用Pytorch实现,支持多机、多卡、混合精度训练。提升速度,降低显存占用。
分批次
Fairseq依据序列长度对源/目标序列进行分组,相似长度的序列作为一组,以减小对序列的补齐填充操作。每一个mini-batch内的样本在训练过程中不变,但每一轮训练时都会打乱mini-batch间的顺序。当在多卡、多机上运行时,每一个worker的mini-batches平均长度有所不同,以实现更有代表性(more representative)的迭代。
多GPU训练
使用NCCL2库和
torch.distributed
作为GPU间的通信。每个GPU上保留一个模型副本。
前向计算和反向传播异步。Fairseq中每一层的梯度计算完成后,都会把结果存放到缓存中,当缓存大小达到某一个阈值之后,在一个后台线程中同步梯度,反向传播照常进行。在每一个GPU上累加梯度,以减小worker上处理时间的方差,这样就不必等待计算比较慢的worker。
如图所示,图a在同步梯度时,等待最慢的worker,因此产生了大量的等待时间(白色所示,idle)。但Fairseq同时采用了图b和图c的技术,反向传播(back-propagation)和梯度同步(gradient synchronization)同时进行,并且累加梯度以减少worker上面处理时间的“抖动”,从而提升训练速度。
混合精度训练
Fairseq同时支持半精度浮点(half precision float point, FP16)和全精度浮点(full precision float point, FP32)的训练和推断。在前后向以及worker之间规约(all-reduce)时,使用FP16。但在参数更新时仍然采用FP32,以保证计算精度。由于FP16提供的精度有限,为了防止激活和梯度的下溢出,Fairseq实现了所谓的动态损失缩放(dynamic loss scaling)。当FP16的梯度在worker之间同步完成之后,将缩放到FP16的数字恢复为原来的FP32,并更新模型权重。
推断优化
Fairseq通过增量解码(incremental decoding)提供了更快的推理速度。所谓的增量解码,就是在解码时,将之前tokens处于激活beam状态下的模型状态(model states)缓存起来,以备后用,这样每一个新的token进来,只需要计算新的状态即可。也就是说,如果使用FairseqDecoder
接口实现普通
的解码器,对于每一个输出,都需要重新整个解码器隐状态,计算复杂度O(n^2)。而使用FairseqIncrementalDecoder
接口实现增量解码,就可以实现O(n)的解码速度。
在训练和推理阶段,通过用户指定的最大tokens数量,构建动态样本数量的batch。并且Fairseq在保证准确率的前提下,支持FP16精度的推断。相比于FP32,FP16推断将解码速度提高54%。注意:Fairseq中没有batch size的概念,用户通过指定max-tokens
,Fairseq会自动构建不定数量的batch送入模型训练。
Fairseq repo (Python): https://github.com/pytorch/fairseq
Paper: http://cn.arxiv.org/abs/1904.01038
Document: fairseq.readthedocs.io
https://zhuanlan.zhihu.com/p/100249351
https://zhuanlan.zhihu.com/p/100643955
Fairseq-快速可扩展的序列建模工具包的更多相关文章
- 基于机器学习的web异常检测——基于HMM的状态序列建模,将原始数据转化为状态机表示,然后求解概率判断异常与否
基于机器学习的web异常检测 from: https://jaq.alibaba.com/community/art/show?articleid=746 Web防火墙是信息安全的第一道防线.随着网络 ...
- BZOJ_2242_[SDOI2011]计算器_快速幂+扩展GCD+BSGS
BZOJ_2242_[SDOI2011]计算器_快速幂+扩展GCD+BSGS 题意: 你被要求设计一个计算器完成以下三项任务: 1.给定y,z,p,计算Y^Z Mod P 的值: 2.给定y,z,p, ...
- 阿里CTR预估:用户行为长序列建模
本文将介绍Alibaba发表在KDD'19 的论文<Practice on Long Sequential User Behavior Modeling for Click-Through Ra ...
- 【笔记】论文阅读:《Gorilla: 一个快速, 可扩展的, 内存式时序数据库》
英文:Gorilla: A fast, scalable, in-memory time series database 中文:Gorilla: 一个快速, 可扩展的, 内存式时序数据库
- Slickflow.Graph 开源工作流引擎快速入门之四: 图形编码建模工具使用手册
前言: 业务人员绘制流程时,通常使用图形GUI界面交互操作来完成,然而对于需要频繁操作或者管理较多流程的系统管理用户,就需要一款辅助工具,来帮助他们快速完成流程的创建和编辑更新.Slickflow.G ...
- “盛大游戏杯”第15届上海大学程序设计联赛夏季赛暨上海高校金马五校赛题解&&源码【A,水,B,水,C,水,D,快速幂,E,优先队列,F,暴力,G,贪心+排序,H,STL乱搞,I,尼姆博弈,J,差分dp,K,二分+排序,L,矩阵快速幂,M,线段树区间更新+Lazy思想,N,超级快速幂+扩展欧里几德,O,BFS】
黑白图像直方图 发布时间: 2017年7月9日 18:30 最后更新: 2017年7月10日 21:08 时间限制: 1000ms 内存限制: 128M 描述 在一个矩形的灰度图像上,每个 ...
- 【bzoj2242】: [SDOI2011]计算器 数论-快速幂-扩展欧几里得-BSGS
[bzoj2242]: [SDOI2011]计算器 1.快速幂 2.扩展欧几里得(费马小定理) 3.BSGS /* http://www.cnblogs.com/karl07/ */ #include ...
- bzoj 2242: [SDOI2011]计算器 BSGS+快速幂+扩展欧几里德
2242: [SDOI2011]计算器 Time Limit: 10 Sec Memory Limit: 512 MB[Submit][Status][Discuss] Description 你被 ...
- bzoj 2242 [SDOI2011]计算器 快速幂+扩展欧几里得+BSGS
1:快速幂 2:exgcd 3:exbsgs,题里说是素数,但我打的普通bsgs就wa,exbsgs就A了...... (map就是慢)..... #include<cstdio> # ...
随机推荐
- Python os.closerange() 方法
概述 os.closerange() 方法用于关闭所有文件描述符 fd,从 fd_low (包含) 到 fd_high (不包含), 错误会忽略.高佣联盟 www.cgewang.com 语法 clo ...
- PHP var_export() 函数
var_export() 函数用于输出或返回一个变量,以字符串形式表示.高佣联盟 www.cgewang.com高佣联盟 www.cgewang.com var_export() 函数返回关于传递给该 ...
- 华为手机内核代码的编译及刷入教程【通过魔改华为P9 Android Kernel 对抗反调试机制】
0x00 写在前面 攻防对立.程序调试与反调试之间的对抗是一个永恒的主题.在安卓逆向工程实践中,通过修改和编译安卓内核源码来对抗反调试是一种常见的方法.但网上关于此类的资料比较少,且都是基于AOSP ...
- CSMA/CD协议(载波侦听多路访问/碰撞检测) 最小帧长理解
以下的帧长有的是指帧的时间长度,帧的时间长度= 帧长/传输时延
- 找工作的你不容错过的45个PHP面试题附答案(下篇)
找工作的你不容错过的45个PHP面试题附答案(上篇) Q28:你将如何使用PHP创建Singleton类? /** * Singleton class * */ final class UserFac ...
- 使用Android Studio创建模拟器,安装配置Android SDK
Android Studio 一个写安卓APP应用的代码编辑器之类的?嗯,应该是... 这里只是需要用到里面的AVD Manager 创建安卓模拟器(也可以用mumu类的安卓模拟器):SDK Mana ...
- DCGAN实现
DCGAN实现 代码 dcgan.py #!/usr/bin/env python # -*- coding: utf-8 -*- import os import math import argpa ...
- Vue + ccropper.js裁切图片(vue-cropper)
按原比例裁剪图片并且不失真. 安装: cnpm install vue-cropper --save-dev 使用: <template> <div style="disp ...
- 【LeetCode/LintCode 题解】约瑟夫问题 · Joseph Problem
n个人按顺序围成一圈(编号为1~n),从第1个人从1开始报数,报到k的人出列,相邻的下个人重新从1开始报数,报到k的人出列,重复这个过程,直到队伍中只有1个人为止,这就是约瑟夫问题.现在给定n和k,你 ...
- Java—API/Obiect类的equals toString方法/String类/StringBuffer类/正则表达式
API Java 的API(API: Application(应用) Programming(程序) Interface(接口)) 就是JDK中提供给我们使用的类,这些类将底层的代码实现封装了起来 ...