构建data_loader原理步骤

# engine/default.py
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
# Assume these objects must be constructed in this order.
data_loader = self.build_train_loader(cfg)
...
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
"""
return build_detection_train_loader(cfg)

函数调用关系如下图:

结合前面两篇文章的内容可以看到detectron2在构建model,optimizer和data_loader的时候都是在对应的build.py文件里实现的。我们看一下build_detection_train_loader是如何定义的(对应上图中紫色方框内的部分(自下往上的顺序)):


def build_detection_train_loader(cfg, mapper=None):
"""
A data loader is created by the following steps: 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
2. Start workers to work on the dicts. Each worker will:
* Map each metadata dict into another format to be consumed by the model.
* Batch them by simply putting dicts into a list.
The batched ``list[mapped_dict]`` is what this dataloader will return. Args:
cfg (CfgNode): the config
mapper (callable): a callable which takes a sample (dict) from dataset and
returns the format to be consumed by the model.
By default it will be `DatasetMapper(cfg, True)`. Returns:
a torch DataLoader object
"""
# 获得dataset_dicts
dataset_dicts = get_detection_dataset_dicts(
cfg.DATASETS.TRAIN,
filter_empty=True,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
) # 将dataset_dicts转化成torch.utils.data.Dataset
dataset = DatasetFromList(dataset_dicts, copy=False) # 进一步转化成MapDataset,每次读取数据时都会调用mapper来对dict进行解析
if mapper is None:
mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper) # 采样器
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
if sampler_name == "TrainingSampler":
sampler = samplers.TrainingSampler(len(dataset))
...
batch_sampler = build_batch_data_sampler(
sampler, images_per_worker, group_bin_edges, aspect_ratios
) # 数据迭代器 data_loader
data_loader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.DATALOADER.NUM_WORKERS,
batch_sampler=batch_sampler,
collate_fn=trivial_batch_collator,
worker_init_fn=worker_init_reset_seed,
)
return data_loader

由上面的源代码可以看出总共是五个步骤,我们只对前面三个部分进行详细介绍,后面的采样器和data_loader可以参阅一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

获得dataset_dicts

get_detection_dataset_dicts(dataset_names)函数需要传递的一个重要参数是dataset_names,这个参数其实就是一个字符串,用来指定数据集的名称。通过这个字符串,该函数会调用data/catalog.pyDatasetCatalog类来进行解析得到一个包含数据信息的字典。

解析的原理是:DatasetCatalog有一个字典_REGISTERED,默认已经注册好了例如coco,voc这些数据集的信息。如果你想要使用你自己的数据集,那么你需要在最开始前你需要定义你的数据集名字以及定义一个函数(这个函数不需要传参,而且最后会返回一个dict,该dict包含你的数据集信息),举个栗子:

from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
...
return dict DatasetCatalog.register(my_dataset_name, get_dicts)

当然,如果你的数据集已经是COCO的格式了,那么你也可以使用如下方法进行注册:

from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")

另外需要注意的是一个数据集其实是可以由两个类来定义的,一个是前面介绍了的DatasetCatalog,另一个是MetadataCatalog

MetadataCatalog的作用是记录数据集的一些特征,这样我们就可以很方便的在整个代码中获取数据集的特征信息。在注册DatasetCatalog后,我们可以按如下栗子对MetadataCatalog进行注册并定义我们后面可能会用到的属性特征:

from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"] # 也可以这样
MetadataCatalog.get("my_dataset").set("thing_classes",["person", "dog"])

注意:如果你的数据集名字未注册过,MetadataCatalog.get会自动进行注册,然后会自动设置你所设定的属性值。

其实MetadataCatalog还有其他的特征属性可以设置,如stuff_classes,stuff_colors等等。你可能会好奇thing_classesstuff_classes有什么区别,区别如下:

  • 抽象解释:thing_classes用于指定instance-level任务,stuff_classes用于semantic segmentation任务。
  • 具体解释:像椅子,书这种可数的东西,就可以理解成thing,所以用于instance-level;而雪、天空这种不可数的就理解成stuff,所以用于semantic segmentation。参考On Seeing Stuff: The Perception of Materials by Humans and Machines

最后,get_detection_dataset_dicts会返回一个包含若干个dict的list,之所以是list是因为参数dataset_names也是一个list,这样我们就可以制定多个names来同时对数据进行读取。

解析成DatasetFromList

DatasetFromList(dataset_dict)函数定义在detectron2/data/common.py中,它其实就是一个torch.utils.data.Dataset类,其源码如下

class DatasetFromList(data.Dataset):
"""
Wrap a list to a torch Dataset. It produces elements of the list as data.
""" def __init__(self, lst: list, copy: bool = True):
"""
Args:
lst (list): a list which contains elements to produce.
copy (bool): whether to deepcopy the element when producing it,
so that the result can be modified in place without affecting the
source in the list.
"""
self._lst = lst
self._copy = copy def __len__(self):
return len(self._lst) def __getitem__(self, idx):
if self._copy:
return copy.deepcopy(self._lst[idx])
else:
return self._lst[idx]

这个很简单就不加赘述了

DatsetFromList转化成MapDataset

其实DatsetFromListMapDataset都是torch.utils.data.Dataset的子类,那他们的区别是什么呢?很简单,区别就是后者使用了mapper

在解释mapper是什么之前我们首先要知道的是,在detectron2中,一张图片对应的是一个dict,那么整个数据集就是list[dict]。之后我们再看DatsetFromList,它的__getitem__函数非常简单,它只是简单粗暴地就返回了指定idx的元素。显然这样是不行的,因为在把数据扔给模型训练之前我们肯定还要对数据做一定的处理,而这个工作就是由mapper来做的,默认情况下使用的是detectron2/data/dataset_mapper.py中定义的DatasetMapper,如果你需要自定义一个mapper也可以参考这个写。

DatasetMapper(cfg, is_train=True)

我们继续了解一下DatasetMapper的实现原理,首先看一下官方给的定义:

A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.

简单概括就是这个类是可调用的(callable),所以在下面的源码中可以看到定义了__call__方法。

该类主要做了这三件事:

The callable currently does the following:

  1. Read the image from "file_name"
  2. Applies cropping/geometric transforms to the image and annotations
  3. Prepare data and annotations to Tensor and :class:Instances

其源码如下(有删减):

class DatasetMapper:
def __init__(self, cfg, is_train=True):
# 读取cfg的参数
... def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below # 1. 读取图像数据
image = utils.read_image(dataset_dict["file_name"], format=self.img_format) # 2. 对image和box等做Transformation
if "annotations" not in dataset_dict:
image, transforms = T.apply_transform_gens(
([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
)
else:
...
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
if self.crop_gen:
transforms = crop_tfm + transforms image_shape = image.shape[:2] # h, w # 3.将数据转化成tensor格式
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
... return dataset_dict

MapDataset

class MapDataset(data.Dataset):
def __init__(self, dataset, map_func):
self._dataset = dataset
self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work self._rng = random.Random(42)
self._fallback_candidates = set(range(len(dataset))) def __len__(self):
return len(self._dataset) def __getitem__(self, idx):
retry_count = 0
cur_idx = int(idx) while True:
data = self._map_func(self._dataset[cur_idx])
if data is not None:
self._fallback_candidates.add(cur_idx)
return data # _map_func fails for this idx, use a random new index from the pool
retry_count += 1
self._fallback_candidates.discard(cur_idx)
cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0] if retry_count >= 3:
logger = logging.getLogger(__name__)
logger.warning(
"Failed to apply `_map_func` for idx: {}, retry count: {}".format(
idx, retry_count
)
)
  • self._fallback_candidates是一个set,它的特点是其中的元素是独一无二的,定义这个的作用是记录可正常读取的数据索引,因为有的数据可能无法正常读取,所以这个时候我们就可以把这个坏数据的索引从_fallback_candidates中剔除,并随机采样一个索引来读取数据。
  • __getitem__中的逻辑就是首先读取指定索引的数据,如果正常读取就把该所索引值加入到_fallback_candidates中去;反之,如果数据无法读取,则将对应索引值删除,并随机采样一个数据,并且尝试3次,若3次后都无法正常读取数据,则报错,但是好像也没有退出程序,而是继续读数据,可能是以为总有能正常读取的数据吧hhh。

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com




2020-01-23 17:45:35

如有意合作,欢迎私戳

邮箱:marsggbo@foxmail.com



2019-10-23 13:37:13

Detectron2源码阅读笔记-(三)Dataset pipeline的更多相关文章

  1. Detectron2源码阅读笔记-(二)Registry&build_*方法

    ​ Trainer解析 我们继续Detectron2代码阅读笔记-(一)中的内容. 上图画出了detectron2文件夹中的三个子文件夹(tools,config,engine)之间的关系.那么剩下的 ...

  2. Detectron2源码阅读笔记-(一)Config&Trainer

    代码结构概览 核心部分 configs:储存各种网络的yaml配置文件 datasets:存放数据集的地方 detectron2:运行代码的核心组件 tools:提供了运行代码的入口以及一切可视化的代 ...

  3. Werkzeug源码阅读笔记(三)

    这次主要讲下werkzeug中的Local. 源码在werkzeug/local.py Thread Local 在Python中,状态是保存在对象中.Thread Local是一种特殊的对象,它是对 ...

  4. CI框架源码阅读笔记5 基准测试 BenchMark.php

    上一篇博客(CI框架源码阅读笔记4 引导文件CodeIgniter.php)中,我们已经看到:CI中核心流程的核心功能都是由不同的组件来完成的.这些组件类似于一个一个单独的模块,不同的模块完成不同的功 ...

  5. CI框架源码阅读笔记2 一切的入口 index.php

    上一节(CI框架源码阅读笔记1 - 环境准备.基本术语和框架流程)中,我们提到了CI框架的基本流程,这里再次贴出流程图,以备参考: 作为CI框架的入口文件,源码阅读,自然由此开始.在源码阅读的过程中, ...

  6. 源码阅读笔记 - 1 MSVC2015中的std::sort

    大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...

  7. libevent源码阅读笔记(一):libevent对epoll的封装

    title: libevent源码阅读笔记(一):libevent对epoll的封装 最近开始阅读网络库libevent的源码,阅读源码之前,大致看了张亮写的几篇博文(libevent源码深度剖析 h ...

  8. jdk源码阅读笔记-LinkedHashMap

    Map是Java collection framework 中重要的组成部分,特别是HashMap是在我们在日常的开发的过程中使用的最多的一个集合.但是遗憾的是,存放在HashMap中元素都是无序的, ...

  9. faster rcnn源码阅读笔记1

    自己保存的源码阅读笔记哈 faster rcnn 的主要识别过程(粗略) (开始填坑了): 一张3通道,1600*1600图像输入中,经过特征提取网络,得到100*100*512的feature ma ...

随机推荐

  1. C++对象布局

    <C++应用程序性能优化><深度探索C++对象模型>笔记 #include<iostream> using namespace std; class student ...

  2. c++基础第一篇

    前言:我是从c和c++对比的角度来讲解c++的基础知识. (1)c++格式如下: #include <iostream> //标准输入输出头文件 using namespace std; ...

  3. [LeetCode] 893. Groups of Special-Equivalent Strings 特殊字符串的群组

    You are given an array A of strings. Two strings S and T are special-equivalent if after any number ...

  4. 分布式共识算法 (三) Raft算法

    系列目录 分布式共识算法 (一) 背景 分布式共识算法 (二) Paxos算法 分布式共识算法 (三) Raft算法 分布式共识算法 (四) BTF算法 一.引子 1.1 介绍 Raft 是一种为了管 ...

  5. 团队作业第五次—项目冲刺-Day7

    Day7 part1-SCRUM: 项目相关 作业相关 具体描述 所属班级 2019秋福大软件工程实践Z班 作业要求 团队作业第五次-项目冲刺 作业正文 hunter--冲刺集合 团队名称 hunte ...

  6. 第26课 std::async异步任务

    一. std::async函数模板 (一)std::async和std::thread的区别 1. 两者最明显的区别在于async采用默认启动策略时并不一定创建新的线程.如果系统资源紧张,那么std: ...

  7. 如何写APA格式的论文

    一.一般准则 FONT :   TIMES NEW ROMAN SIZE                    :   12 DOUBLE-SPACING INDENT               : ...

  8. python字符串格式化方法%s和format函数

    1.%s方法 一个例子 print("my name is %s and i am %d years old" %("xiaoming",18) 输出结果:my ...

  9. Mysql load data infile 命令导入含中文csv源数据文件 【错误代码 1300】

    [1]Load data infile 命令导入含中文csv源数据文件 报错:Invalid utf8 character string: '??֧' (1)问题现象 csv格式文件源数据: 导入SQ ...

  10. 插件油泼猴+脚本 for chrome 安装 - https://greasyfork.org/zh-CN

    http://chromecj.com/utilities/2018-09/1525.html 一.将 *.crx 改名为 *.zip 二.访问 chrome://flags/#extensions- ...