[PyTorch modification] dataloader.py: implementing darknet's subdivision feature in the PyTorch DataLoader

In darknet, the subdivisions setting splits each training batch into smaller mini-batches so that a large effective batch size still fits in GPU memory: gradients are accumulated over the mini-batches and the weights are updated once per full batch. The modified DataLoader below adds a subdivision argument, keeps the full batch size in batch_size_real, and stores the per-subdivision mini-batch size batch_size // subdivision in batch_size.
dataloader.py
import random
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from . import SequentialSampler, RandomSampler, BatchSampler
import signal
import functools
import collections
import re
import sys
import threading
import traceback
import os
import time
from torch._six import string_classes, int_classes, FileNotFoundError

IS_WINDOWS = sys.platform == "win32"
if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

MANAGER_STATUS_CHECK_INTERVAL = 5.0

if IS_WINDOWS:
    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

        def is_alive(self):
            # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
            return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            return os.getppid() == self.manager_pid


def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples


def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))


def pin_memory_batch(batch):
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


_SIGCHLD_handler_set = False
r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True


class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                for q in self.index_queues:
                    q.put(None)
                # if some workers are waiting to put, make place for them
                try:
                    while not self.worker_result_queue.empty():
                        self.worker_result_queue.get()
                except (FileNotFoundError, ImportError):
                    # Many weird errors can happen here due to Python
                    # shutting down. These are more like obscure Python bugs.
                    # FileNotFoundError can happen when we rebuild the fd
                    # fetched from the queue but the socket is already closed
                    # from the worker side.
                    # ImportError can happen when the unpickler loads the
                    # resource from `get`.
                    pass
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        subdivision (int, optional): darknet-style subdivision count. The full
            batch size is kept in :attr:`batch_size_real`, and the per-subdivision
            mini-batch size ``batch_size // subdivision`` is stored in
            :attr:`batch_size` (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, subdivision=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        # darknet-style subdivision: remember the full batch size and derive the
        # per-subdivision mini-batch size (integer division keeps it an int)
        self.batch_size_real = batch_size
        self.subdivision = subdivision
        self.batch_size = batch_size // subdivision
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            # the batch sampler still yields the full batch_size samples per iteration
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
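
Below is a minimal usage sketch (not part of the modified file) showing one way to consume this loader in a darknet-style training loop. It assumes a map-style dataset whose samples collate into an (inputs, targets) pair of tensors, plus an existing model, criterion, and optimizer; these names are illustrative. As written above, the loader still delivers the full batch_size samples per iteration, while loader.batch_size holds the per-subdivision mini-batch size, so each batch is split into loader.subdivision chunks whose gradients are accumulated before a single weight update.

import torch

loader = DataLoader(dataset, batch_size=64, subdivision=16, shuffle=True, num_workers=4)

for inputs, targets in loader:
    optimizer.zero_grad()
    # forward/backward one mini-batch of loader.batch_size samples at a time
    for sub_in, sub_tgt in zip(torch.split(inputs, loader.batch_size),
                               torch.split(targets, loader.batch_size)):
        loss = criterion(model(sub_in), sub_tgt)
        (loss / loader.subdivision).backward()  # accumulate scaled gradients
    optimizer.step()  # one weight update per full batch, as in darknet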