准备数据

准备 COCO128 数据集,其是 COCO train2017 前 128 个数据。按 YOLOv5 组织的目录:

$ tree ~/datasets/coco128 -L 2
/home/john/datasets/coco128
├── images
│   └── train2017
│   ├── ...
│   └── 000000000650.jpg
├── labels
│   └── train2017
│   ├── ...
│   └── 000000000650.txt
├── LICENSE
└── README.txt

详见 Train Custom Data

定义 Dataset

torch.utils.data.Dataset 是一个数据集的抽象类。自定义数据集时,需继承 Dataset 并覆盖如下方法:

  • __len__: len(dataset) 获取数据集大小。
  • __getitem__: dataset[i] 访问第 i 个数据。

详见:

自定义实现 YOLOv5 数据集的例子:

import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple import numpy as np
import torch
import torchvision
from PIL import Image class YOLOv5(torchvision.datasets.vision.VisionDataset): def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(YOLOv5, self).__init__(root, transforms, transform, target_transform)
images_dir = Path(root) / 'images' / name
labels_dir = Path(root) / 'labels' / name
self.images = [n for n in images_dir.iterdir()]
self.labels = []
for image in self.images:
base, _ = os.path.splitext(os.path.basename(image))
label = labels_dir / f'{base}.txt'
self.labels.append(label if label.exists() else None) def __getitem__(self, idx: int) -> Tuple[Any, Any]:
img = Image.open(self.images[idx]).convert('RGB') label_file = self.labels[idx]
if label_file is not None: # found
with open(label_file, 'r') as f:
labels = [x.split() for x in f.read().strip().splitlines()]
labels = np.array(labels, dtype=np.float32)
else: # missing
labels = np.zeros((0, 5), dtype=np.float32) boxes = []
classes = []
for label in labels:
x, y, w, h = label[1:]
boxes.append([
(x - w/2) * img.width,
(y - h/2) * img.height,
(x + w/2) * img.width,
(y + h/2) * img.height])
classes.append(label[0]) target = {}
target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
target["labels"] = torch.as_tensor(classes, dtype=torch.int64) if self.transforms is not None:
img, target = self.transforms(img, target) return img, target def __len__(self) -> int:
return len(self.images)

以上实现,继承了 VisionDataset 子类。其 __getitem__ 返回了:

  • image: PIL Image, 大小为 (H, W)
  • target: dict, 含以下字段:
    • boxes (FloatTensor[N, 4]): 真实标注框 [x1, y1, x2, y2], x 范围 [0,W], y 范围 [0,H]
    • labels (Int64Tensor[N]): 上述标注框的类别标识

读取 Dataset

dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017')
print(f'dataset: {len(dataset)}')
print(f'dataset[0]: {dataset[0]}')

输出:

dataset: 128
dataset[0]: (<PIL.Image.Image image mode=RGB size=640x480 at 0x7F6F9464ADF0>, {'boxes': tensor([[249.7296, 200.5402, 460.5399, 249.1901],
[448.1702, 363.7198, 471.1501, 406.2300],
...
[ 0.0000, 188.8901, 172.6400, 280.9003]]), 'labels': tensor([44, 51, 51, 51, 51, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45,
45, 50, 50, 50, 51, 51, 60, 42, 44, 45, 45, 45, 50, 51, 51, 51, 51, 51,
51, 44, 50, 50, 50, 45])})

预览:

使用 DataLoader

训练需要批量提取数据,可以使用 DataLoader :

dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017',
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])) dataloader = DataLoader(dataset, batch_size=64, shuffle=True,
collate_fn=lambda batch: tuple(zip(*batch))) for batch_i, (images, targets) in enumerate(dataloader):
print(f'batch {batch_i}, images {len(images)}, targets {len(targets)}')
print(f' images[0]: shape={images[0].shape}')
print(f' targets[0]: {targets[0]}')

输出:

batch 0, images 64, targets 64
images[0]: shape=torch.Size([3, 480, 640])
targets[0]: {'boxes': tensor([[249.7296, 200.5402, 460.5399, 249.1901],
[448.1702, 363.7198, 471.1501, 406.2300],
...
[ 0.0000, 188.8901, 172.6400, 280.9003]]), 'labels': tensor([44, 51, 51, 51, 51, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45,
45, 50, 50, 50, 51, 51, 60, 42, 44, 45, 45, 45, 50, 51, 51, 51, 51, 51,
51, 44, 50, 50, 50, 45])}
batch 1, images 64, targets 64
images[0]: shape=torch.Size([3, 248, 640])
targets[0]: {'boxes': tensor([[337.9299, 167.8500, 378.6999, 191.3100],
[383.5398, 148.4501, 452.6598, 191.4701],
[467.9299, 149.9001, 540.8099, 193.2401],
[196.3898, 142.7200, 271.6896, 190.0999],
[134.3901, 154.5799, 193.9299, 189.1699],
[ 89.5299, 162.1901, 124.3798, 188.3301],
[ 1.6701, 154.9299, 56.8400, 188.3700]]), 'labels': tensor([20, 20, 20, 20, 20, 20, 20])}

源码

参考

APIs:

GoCoding 个人实践的经验分享,可关注公众号!

PyTorch 自定义数据集的更多相关文章

  1. [转载]pytorch自定义数据集

    为什么要定义Datasets: PyTorch提供了一个工具函数torch.utils.data.DataLoader.通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快 ...

  2. Pytorch划分数据集的方法

    之前用过sklearn提供的划分数据集的函数,觉得超级方便.但是在使用TensorFlow和Pytorch的时候一直找不到类似的功能,之前搜索的关键字都是"pytorch split dat ...

  3. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  4. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  5. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  8. Pytorch自定义数据库

    1)前言 虽然torchvision.datasets中已经封装了好多通用的数据集,但是我们在使用Pytorch做深度学习任务的时候,会面临着自定义数据库来满足自己的任务需要.如我们要训练一个人脸关键 ...

  9. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

随机推荐

  1. 2.centos 7清空文件和文件夹

    1.清空文件 测试文件:a.txt 1)方法一,[root@centos test]# > a.txt [root@centos test]# cat a.txt 1hjbfao hjkl23o ...

  2. python内置常量是什么?

    摘要:学习Python的过程中,我们会从变量常量开始学习,那么python内置的常量你知道吗? 一个新产品,想熟悉它,最好的办法就是查看说明书! 没错,Python也给我们准备了这样的说明书--Pyt ...

  3. 2019牛客暑期多校训练营(第一场)E ABBA (DP/卡特兰数)

    传送门 知识点:卡特兰数/动态规划 法一:动态规划 由题意易知字符串的任何一个前缀都满足\(cnt(A) - cnt(B) \le n , cnt(B)-cnt(A)\le m\) \(d[i][j] ...

  4. zjnu1735BOB (单调队列,单调栈)

    Description Little Bob is a famous builder. He bought land and wants to build a house. Unfortunately ...

  5. 51Nod - 1632

    B国拥有n个城市,其交通系统呈树状结构,即任意两个城市存在且仅存在一条交通线将其连接.A国是B国的敌国企图秘密发射导弹打击B国的交通线,现假设每条交通线都有50%的概率被炸毁,B国希望知道在被炸毁之后 ...

  6. Codeforces Round #660 (Div. 2) C. Uncle Bogdan and Country Happiness (DFS)

    题意:有\(n\)个人,每个人居住在某个节点,所有人都在节点\(1\)上班,下班后沿着最短路径回家,在回家途中心情可能会变差(心情只会变差不会变好),每个节点都有一个开心值,开心值等于所有经过时的好心 ...

  7. Gome 高性能撮合引擎微服务

    Gome 高性能撮合引擎微服务 使用 Golang 做计算,gRPC 做服务,ProtoBuf 做数据交换,RabbitMQ 做队列,Redis 做缓存实现的高性能撮合引擎微服务 依赖 具体依赖信息可 ...

  8. 手摸手带你学移动端WEB开发

    HTML常用标签总结 手摸手带你学CSS HTML5与CSS3知识点总结 手摸手带你学移动端WEB开发 好好学习,天天向上 本文已收录至我的Github仓库DayDayUP:github.com/Ro ...

  9. K8S(03)核心插件-Flannel网络插件

    系列文章说明 本系列文章,可以基本算是 老男孩2019年王硕的K8S周末班课程 笔记,根据视频来看本笔记最好,否则有些地方会看不明白 需要视频可以联系我 K8S核心网络插件Flannel 目录 系列文 ...

  10. Django实现文件上传

    一.HTML <!DOCTYPE html> <html lang="en"> <head> <meta charset="UT ...