Trainer解析

我们继续Detectron2代码阅读笔记-(一)中的内容。

上图画出了detectron2文件夹中的三个子文件夹(tools,config,engine)之间的关系。那么剩下的文件夹又是如何起作用的呢?


def main(args):
cfg = setup(args) if args.eval_only:
...
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
if cfg.TEST.AUG.ENABLED:
trainer.register_hooks(
[hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
)
return trainer.train()

build_*方法

我们从trainer = Trainer(cfg)开始进一步了解。

Detectron2代码阅读笔记-(一)中已经提到过一连串的Trainer的继承关系如下:

tools.train_net.Trainer->detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase,而detectron2.engine.default.DefaultTrainer在其__init__(self, cfg)函数中定义了解析cfg。如下面代码所示,cfg会作为参数倍若干个build_*方法解析,得到解析后的model,optimizer,data_loader等。

from detectron2.modeling import build_model
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg) ... self.register_hooks(self.build_hooks()) @classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model

下面我们以DefaultTrainer.build_model为例来介绍注册机制,该方法调用了detectron2/modeling/meta_arch/build_model.pybuild_model函数,其源代码如下:

from detectron2.utils.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
def build_model(cfg):
"""
Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
return META_ARCH_REGISTRY.get(meta_arch)(cfg)
  • meta_arch = cfg.MODEL.META_ARCHITECTURE: 根据超参数获得网络结构的名字
  • return META_ARCH_REGISTRY.get(meta_arch)(cfg):META_ARCH_REGISTRY是一个Registry类(这个在后面会详细介绍),可以将这一行代码拆成如下几个步骤:
model = META_ARCH_REGISTRY.get(meta_arch)
return model(cfg)

注册机制Registry

那么Registry到底是什么呢?在分析源代码之前我们先了解一下如何使用它,假如你想自己实现一个新的backbone网络,那么你可以这样做:

首先在detectron2中定义好如下(实际上已经定义了):

# detectron2/modeling/backbone/build.py
BACKBONE_REGISTRY = Registry('BACKBONE')

之后在你创建的新的文件下按如下方式创建你的backbone

# detectron2/modeling/backbone/your_backbone.py
from .build import BACKBONE_REGISTRY # 方式1
@BACKBONE_REGISTRY.register()
class MyBackbone():
... # 方式2
class MyBackbone():
...
BACKBONE_REGISTRY.register(MyBackbone)

Registry源代码如下(有删减):

class Registry(object):
def __init__(self, name):
self._name = name
self._obj_map = {} def _do_register(self, name, obj):
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
self._obj_map[name] = obj def register(self, obj=None):
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class return deco # used as a function call
name = obj.__name__
self._do_register(name, obj) def get(self, name):
ret = self._obj_map.get(name)
if ret is None:
raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))
return ret
  • 首先是__init__部分:

    • self._name则是你要注册的名字,例如对于完整的模型而言,name一般取META_ARCH。当然如果你需要自定义backbone网络,你也可以定义一个Registry('BACKBONE')
    • self._obj_map:其实就是一个字典。以模型为例,key就是你的模型名字,而value就是对应的模型类。这样你在传参时只需要修改一下模型名字就能使用不同的模型了。具体实现方法就是后面这几个函数。
  • register: 可以看到该方法定义了注册的两种方式,一种是当obj==None的时候,使用装饰器的方式注册,另外一种就是直接将obj作为参数调用_do_register进行注册。
  • _do_register:真正注册的函数,可以看到它首先会判断name是否已经存在于self._obj_map了。什么意思呢?还是以backbone为例,我们定义了一个BACKBONE_REGISTRY = Registry('BACKBONE'),然后又定义了很多种backbone,而这些backbone都使用@BACKBONE_REGISTRY.register()的方式注册到了BACKBONE_REGISTRY._obj_map中了,所以才取名为Registry,还是蛮形象的吼。
  • get: 这个其实就是根据key值对字典进行取值。

Detectron2 整体代码架构

虽然Detectron2还有很多部分没有介绍到,但是源代码分析到这应该对整体架构有了一定的理解了,具体的一些细节会在后续的文章中进行分析。现对Detectron2 整体代码架构总结一下:

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com






2019-10-15 13:16:32

Detectron2源码阅读笔记-(二)Registry&build_*方法的更多相关文章

  1. Detectron2源码阅读笔记-(一)Config&Trainer

    代码结构概览 核心部分 configs:储存各种网络的yaml配置文件 datasets:存放数据集的地方 detectron2:运行代码的核心组件 tools:提供了运行代码的入口以及一切可视化的代 ...

  2. werkzeug源码阅读笔记(二) 下

    wsgi.py----第二部分 pop_path_info()函数 先测试一下这个函数的作用: >>> from werkzeug.wsgi import pop_path_info ...

  3. Detectron2源码阅读笔记-(三)Dataset pipeline

    构建data_loader原理步骤 # engine/default.py from detectron2.data import ( MetadataCatalog, build_detection ...

  4. werkzeug源码阅读笔记(二) 上

    因为第一部分是关于初始化的部分的,我就没有发布出来~ wsgi.py----第一部分 在分析这个模块之前, 需要了解一下WSGI, 大致了解了之后再继续~ get_current_url()函数 很明 ...

  5. Android源码阅读笔记二 消息处理机制

    消息处理机制: .MessageQueue: 用来描述消息队列2.Looper:用来创建消息队列3.Handler:用来发送消息队列 初始化: .通过Looper.prepare()创建一个Loope ...

  6. Apollo源码阅读笔记(二)

    Apollo源码阅读笔记(二) 前面 分析了apollo配置设置到Spring的environment的过程,此文继续PropertySourcesProcessor.postProcessBeanF ...

  7. 【原】FMDB源码阅读(二)

    [原]FMDB源码阅读(二) 本文转载请注明出处 -- polobymulberry-博客园 1. 前言 上一篇只是简单地过了一下FMDB一个简单例子的基本流程,并没有涉及到FMDB的所有方方面面,比 ...

  8. Three.js源码阅读笔记-5

    Core::Ray 该类用来表示空间中的“射线”,主要用来进行碰撞检测. THREE.Ray = function ( origin, direction ) { this.origin = ( or ...

  9. jdk源码阅读笔记-LinkedHashMap

    Map是Java collection framework 中重要的组成部分,特别是HashMap是在我们在日常的开发的过程中使用的最多的一个集合.但是遗憾的是,存放在HashMap中元素都是无序的, ...

随机推荐

  1. cad 一个小技巧--复制视口带冻结信息

    cad使用 ctrl+c 和 ctrl+v 进行跨文件复制视口的时候,会出现复制视口冻结信息丢失,因为你只选择了视口进行复制, 如果要实现带上冻结信息,那么要把含有相关图层的图元一起 ctrl+c/v ...

  2. Alpha冲刺(11/10)——2019.5.3

    作业描述 课程 软件工程1916|W(福州大学) 团队名称 修!咻咻! 作业要求 项目Alpha冲刺(团队) 团队目标 切实可行的计算机协会维修预约平台 开发工具 Eclipse 团队信息 队员学号 ...

  3. .net core 在 View 中使用 Jquery 无效问题

    问题描述: 在 View 视图中使用模板 _Layout.cshtml,其中模板已经调用了 Jquery.js ,但是在 View 视图下写 js 无效.后来通过浏览器查看自己写的 js 压根没加载出 ...

  4. Python3+Robot Framework+RIDE安装使用教程

    一.说明 Python3----网上很多文章都是用Python2,Robot Framework的部分文档没更新也直接写着不支持Python3(如RIDE does not yet support P ...

  5. 【C++】C++中基类的析构函数为什么要用virtual虚析构函数?

    正面回答: 当基类的析构函数不是虚函数,并且基类指针指向一个派生类对象,然后通过基类指针来删除这个派生类对象时,如果基类的析构函数不是虚析构函数,那么派生类的析构函数就不会被调用,从而产生内存泄漏 # ...

  6. [转帖]centos7 使用kubeadm 快速部署 kubernetes 国内源

    centos7 使用kubeadm 快速部署 kubernetes 国内源 https://www.cnblogs.com/qingfeng2010/p/10540832.html 前言 搭建kube ...

  7. Linux命令中service的用法

    用途说明 service命令用于对系统服务进行管理,比如启动(start).停止(stop).重启(restart).查看状态(status)等.相关的命令还包括chkconfig.ntsysv等,c ...

  8. PowerBuilder学习笔记之1开发环境

    Powerbuilder Classic 12.5开发环境(PB经典 12.5) 教材链接:https://wenku.baidu.com/view/5e087d6ab9f67c1cfad6195f3 ...

  9. 『Blocks 区间dp』

    Blocks Description Some of you may have played a game called 'Blocks'. There are n blocks in a row, ...

  10. go ---switch语句

    package main import ( "fmt" ) func main() { var ar = [...]string{"A", "B&qu ...