构建data_loader原理步骤

# engine/default.py
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
# Assume these objects must be constructed in this order.
data_loader = self.build_train_loader(cfg)
...
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
"""
return build_detection_train_loader(cfg)

函数调用关系如下图:

结合前面两篇文章的内容可以看到detectron2在构建model,optimizer和data_loader的时候都是在对应的build.py文件里实现的。我们看一下build_detection_train_loader是如何定义的(对应上图中紫色方框内的部分(自下往上的顺序)):


def build_detection_train_loader(cfg, mapper=None):
"""
A data loader is created by the following steps: 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
2. Start workers to work on the dicts. Each worker will:
* Map each metadata dict into another format to be consumed by the model.
* Batch them by simply putting dicts into a list.
The batched ``list[mapped_dict]`` is what this dataloader will return. Args:
cfg (CfgNode): the config
mapper (callable): a callable which takes a sample (dict) from dataset and
returns the format to be consumed by the model.
By default it will be `DatasetMapper(cfg, True)`. Returns:
a torch DataLoader object
"""
# 获得dataset_dicts
dataset_dicts = get_detection_dataset_dicts(
cfg.DATASETS.TRAIN,
filter_empty=True,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
) # 将dataset_dicts转化成torch.utils.data.Dataset
dataset = DatasetFromList(dataset_dicts, copy=False) # 进一步转化成MapDataset,每次读取数据时都会调用mapper来对dict进行解析
if mapper is None:
mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper) # 采样器
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
if sampler_name == "TrainingSampler":
sampler = samplers.TrainingSampler(len(dataset))
...
batch_sampler = build_batch_data_sampler(
sampler, images_per_worker, group_bin_edges, aspect_ratios
) # 数据迭代器 data_loader
data_loader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.DATALOADER.NUM_WORKERS,
batch_sampler=batch_sampler,
collate_fn=trivial_batch_collator,
worker_init_fn=worker_init_reset_seed,
)
return data_loader

由上面的源代码可以看出总共是五个步骤,我们只对前面三个部分进行详细介绍,后面的采样器和data_loader可以参阅一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

获得dataset_dicts

get_detection_dataset_dicts(dataset_names)函数需要传递的一个重要参数是dataset_names,这个参数其实就是一个字符串,用来指定数据集的名称。通过这个字符串,该函数会调用data/catalog.pyDatasetCatalog类来进行解析得到一个包含数据信息的字典。

解析的原理是:DatasetCatalog有一个字典_REGISTERED,默认已经注册好了例如coco,voc这些数据集的信息。如果你想要使用你自己的数据集,那么你需要在最开始前你需要定义你的数据集名字以及定义一个函数(这个函数不需要传参,而且最后会返回一个dict,该dict包含你的数据集信息),举个栗子:

from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
...
return dict DatasetCatalog.register(my_dataset_name, get_dicts)

当然,如果你的数据集已经是COCO的格式了,那么你也可以使用如下方法进行注册:

from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")

另外需要注意的是一个数据集其实是可以由两个类来定义的,一个是前面介绍了的DatasetCatalog,另一个是MetadataCatalog

MetadataCatalog的作用是记录数据集的一些特征,这样我们就可以很方便的在整个代码中获取数据集的特征信息。在注册DatasetCatalog后,我们可以按如下栗子对MetadataCatalog进行注册并定义我们后面可能会用到的属性特征:

from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"] # 也可以这样
MetadataCatalog.get("my_dataset").set("thing_classes",["person", "dog"])

注意:如果你的数据集名字未注册过,MetadataCatalog.get会自动进行注册,然后会自动设置你所设定的属性值。

其实MetadataCatalog还有其他的特征属性可以设置,如stuff_classes,stuff_colors等等。你可能会好奇thing_classesstuff_classes有什么区别,区别如下:

  • 抽象解释:thing_classes用于指定instance-level任务,stuff_classes用于semantic segmentation任务。
  • 具体解释:像椅子,书这种可数的东西,就可以理解成thing,所以用于instance-level;而雪、天空这种不可数的就理解成stuff,所以用于semantic segmentation。参考On Seeing Stuff: The Perception of Materials by Humans and Machines

最后,get_detection_dataset_dicts会返回一个包含若干个dict的list,之所以是list是因为参数dataset_names也是一个list,这样我们就可以制定多个names来同时对数据进行读取。

解析成DatasetFromList

DatasetFromList(dataset_dict)函数定义在detectron2/data/common.py中,它其实就是一个torch.utils.data.Dataset类,其源码如下

class DatasetFromList(data.Dataset):
"""
Wrap a list to a torch Dataset. It produces elements of the list as data.
""" def __init__(self, lst: list, copy: bool = True):
"""
Args:
lst (list): a list which contains elements to produce.
copy (bool): whether to deepcopy the element when producing it,
so that the result can be modified in place without affecting the
source in the list.
"""
self._lst = lst
self._copy = copy def __len__(self):
return len(self._lst) def __getitem__(self, idx):
if self._copy:
return copy.deepcopy(self._lst[idx])
else:
return self._lst[idx]

这个很简单就不加赘述了

DatsetFromList转化成MapDataset

其实DatsetFromListMapDataset都是torch.utils.data.Dataset的子类,那他们的区别是什么呢?很简单,区别就是后者使用了mapper

在解释mapper是什么之前我们首先要知道的是,在detectron2中,一张图片对应的是一个dict,那么整个数据集就是list[dict]。之后我们再看DatsetFromList,它的__getitem__函数非常简单,它只是简单粗暴地就返回了指定idx的元素。显然这样是不行的,因为在把数据扔给模型训练之前我们肯定还要对数据做一定的处理,而这个工作就是由mapper来做的,默认情况下使用的是detectron2/data/dataset_mapper.py中定义的DatasetMapper,如果你需要自定义一个mapper也可以参考这个写。

DatasetMapper(cfg, is_train=True)

我们继续了解一下DatasetMapper的实现原理,首先看一下官方给的定义:

A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.

简单概括就是这个类是可调用的(callable),所以在下面的源码中可以看到定义了__call__方法。

该类主要做了这三件事:

The callable currently does the following:

  1. Read the image from "file_name"
  2. Applies cropping/geometric transforms to the image and annotations
  3. Prepare data and annotations to Tensor and :class:Instances

其源码如下(有删减):

class DatasetMapper:
def __init__(self, cfg, is_train=True):
# 读取cfg的参数
... def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below # 1. 读取图像数据
image = utils.read_image(dataset_dict["file_name"], format=self.img_format) # 2. 对image和box等做Transformation
if "annotations" not in dataset_dict:
image, transforms = T.apply_transform_gens(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
)
else:
...
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.crop_gen:
transforms = crop_tfm + transforms image_shape = image.shape[:2] # h, w # 3.将数据转化成tensor格式
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
... return dataset_dict

MapDataset

class MapDataset(data.Dataset):
def __init__(self, dataset, map_func):
self._dataset = dataset
self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work self._rng = random.Random(42)
self._fallback_candidates = set(range(len(dataset))) def __len__(self):
return len(self._dataset) def __getitem__(self, idx):
retry_count = 0
cur_idx = int(idx) while True:
data = self._map_func(self._dataset[cur_idx])
if data is not None:
self._fallback_candidates.add(cur_idx)
return data # _map_func fails for this idx, use a random new index from the pool
retry_count += 1
self._fallback_candidates.discard(cur_idx)
cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0] if retry_count >= 3:
logger = logging.getLogger(__name__)
logger.warning(
"Failed to apply `_map_func` for idx: {}, retry count: {}".format(
idx, retry_count
)
)
  • self._fallback_candidates是一个set,它的特点是其中的元素是独一无二的,定义这个的作用是记录可正常读取的数据索引,因为有的数据可能无法正常读取,所以这个时候我们就可以把这个坏数据的索引从_fallback_candidates中剔除,并随机采样一个索引来读取数据。
  • __getitem__中的逻辑就是首先读取指定索引的数据,如果正常读取就把该所索引值加入到_fallback_candidates中去;反之,如果数据无法读取,则将对应索引值删除,并随机采样一个数据,并且尝试3次,若3次后都无法正常读取数据,则报错,但是好像也没有退出程序,而是继续读数据,可能是以为总有能正常读取的数据吧hhh。

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com




2020-01-23 17:45:35

如有意合作,欢迎私戳

邮箱:marsggbo@foxmail.com



2019-10-23 13:37:13

Detectron2源码阅读笔记-(三)Dataset pipeline的更多相关文章

  1. Detectron2源码阅读笔记-(二)Registry&build_*方法

    ​ Trainer解析 我们继续Detectron2代码阅读笔记-(一)中的内容. 上图画出了detectron2文件夹中的三个子文件夹(tools,config,engine)之间的关系.那么剩下的 ...

  2. Detectron2源码阅读笔记-(一)Config&Trainer

    代码结构概览 核心部分 configs:储存各种网络的yaml配置文件 datasets:存放数据集的地方 detectron2:运行代码的核心组件 tools:提供了运行代码的入口以及一切可视化的代 ...

  3. Werkzeug源码阅读笔记(三)

    这次主要讲下werkzeug中的Local. 源码在werkzeug/local.py Thread Local 在Python中,状态是保存在对象中.Thread Local是一种特殊的对象,它是对 ...

  4. CI框架源码阅读笔记5 基准测试 BenchMark.php

    上一篇博客(CI框架源码阅读笔记4 引导文件CodeIgniter.php)中,我们已经看到:CI中核心流程的核心功能都是由不同的组件来完成的.这些组件类似于一个一个单独的模块,不同的模块完成不同的功 ...

  5. CI框架源码阅读笔记2 一切的入口 index.php

    上一节(CI框架源码阅读笔记1 - 环境准备.基本术语和框架流程)中,我们提到了CI框架的基本流程,这里再次贴出流程图,以备参考: 作为CI框架的入口文件,源码阅读,自然由此开始.在源码阅读的过程中, ...

  6. 源码阅读笔记 - 1 MSVC2015中的std::sort

    大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...

  7. libevent源码阅读笔记(一):libevent对epoll的封装

    title: libevent源码阅读笔记(一):libevent对epoll的封装 最近开始阅读网络库libevent的源码,阅读源码之前,大致看了张亮写的几篇博文(libevent源码深度剖析 h ...

  8. jdk源码阅读笔记-LinkedHashMap

    Map是Java collection framework 中重要的组成部分,特别是HashMap是在我们在日常的开发的过程中使用的最多的一个集合.但是遗憾的是,存放在HashMap中元素都是无序的, ...

  9. faster rcnn源码阅读笔记1

    自己保存的源码阅读笔记哈 faster rcnn 的主要识别过程(粗略) (开始填坑了): 一张3通道,1600*1600图像输入中,经过特征提取网络,得到100*100*512的feature ma ...

随机推荐

  1. Linux性能优化实战学习笔记:第二十八讲

    一.案例环境描述 1.环境准备 2CPU,4GB内存 预先安装docker sysstat工具 apt install docker.io sysstat nake git 案例总共由三个容器组成: ...

  2. Linux性能优化实战学习笔记:第三十六讲

    一.上节总结回顾 上一节,我们回顾了经典的 C10K 和 C1000K 问题.简单回顾一下,C10K 是指如何单机同时处理 1 万个请求(并发连接 1 万)的问题,而 C1000K 则是单机支持处理 ...

  3. CSP2019蒸馏记

    Day -\(\infty\) ~ Day -2 认真准备联赛. Day -1 复习模板,全真模拟比赛平衡树 下午进行了湖南大学 2 小时游. Day 0 上午睡过头了 下午日常训练,并没有什么开放日 ...

  4. Elasticsearch由浅入深(七)搜索引擎:_search含义、_multi-index搜索模式、分页搜索以及深分页性能问题、query string search语法以及_all metadata原理

    _search含义 _search查询返回结果数据含义分析 GET _search { , "timed_out": false, "_shards": { , ...

  5. OIDC-Open ID Connect

    OpenID Connect的简称,OIDC=(Identity, Authentication) + OAuth 2.0.它在OAuth2上构建了一个身份层,是一个基于OAuth2协议的身份认证标准 ...

  6. spring boot打包为war包,引入外部jar包

    1,在src/main/resource下新建目录jar,将外部jar包放在该目录下 2,在pom.xml中添加依赖 groupId,artifactId,version可随便写 <depend ...

  7. VUE的$refs和$el的使用

    ref 被用来给元素或子组件注册引用信息 ref 有三种用法: 1.ref 加在普通的元素上,用this.$refs.(ref值) 获取到的是dom元素 2.ref 加在子组件上,用this.$ref ...

  8. CentOS 安装libgdi的方法

    1. 安装必须的包 yum install glib2-devel cairo-devel libjpeg-turbo-devel-1.2.90-8.el7.x86_64 libtiff-devel- ...

  9. cmd命令和linux命令的区别

    cmd命令和linux命令看起来很相似,都是在一个控制台输入一些特定的指令去完成一些特定的操作.可是用过的朋友就会发现这些指令是有很多不同的,可是到底有什么不同,要说又说不上来,所以要了解一下. cm ...

  10. 关于eclipse SE版本不支持建立web工程的问题

    关于eclipse SE版本不支持建立web工程的问题 我们会发现 JAVA eclipse SE版本无法建立 Web 程序的问题...... 最好的解决方法就是下载一个myeclipse 或 Jav ...