Pytorch collate_fn用法
By default, Dataloader
use collate_fn
method to pack a series of images and target as tensors (first dimension of tensor is batch size). The default collate_fn
expects all the images in a batch to have the same size because it uses torch.stack()
to pack the images. If the images provided by Dataset
have variable size, you have to provide your custom collate_fn
. A simple example is shown below:
# a simple custom collate function, just to show the idea # `batch` is a list of tuple where first element is image tensor and # second element is corresponding label def my_collate(batch):
data = [item[0] for item in batch] # just form a list of tensor target = [item[1] for item in batch]
target = torch.LongTensor(target)
return [data, target]
Reference: Writing Your Own Custom Dataset for Classification in PyTorch
By default, torch stacks the input image to from a tensor of size N*C*H*W
, so every image in the batch must have the same height and width. In order to load a batch with variable size input image, we have to use our own collate_fn
which is used to pack a batch of images.
For image classification, the input to collate_fn
is a list of with size batch_size
. Each element is a tuple where the first element is the input image(a torch.FloatTensor
) and the second element is the image label which is simply an int
. Because the samples in a batch have different size, we can store these samples in a list ans store the corresponding labels in torch.LongTensor
. Then we put the image list and the label tensor into a list and return the result.
here is a very simple snippet to demonstrate how to write a custom collate_fn
:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt # a simple custom collate function, just to show the idea
def my_collate(batch):
data = [item[0] for item in batch]
target = [item[1] for item in batch]
target = torch.LongTensor(target)
return [data, target] def show_image_batch(img_list, title=None):
num = len(img_list)
fig = plt.figure()
for i in range(num):
ax = fig.add_subplot(1, num, i+1)
ax.imshow(img_list[i].numpy().transpose([1,2,0]))
ax.set_title(title[i]) plt.show() # do not do randomCrop to show that the custom collate_fn can handle images of different size
train_transforms = transforms.Compose([transforms.Scale(size = 224),
transforms.ToTensor(),
]) # change root to valid dir in your system, see ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
transform=train_transforms) trainset = DataLoader(dataset=train_dataset,
batch_size=4,
shuffle=True,
collate_fn=my_collate, # use custom collate function here
pin_memory=True) trainiter = iter(trainset)
imgs, labels = trainiter.next() # print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])
Reference: How to create a dataloader with variable-size input
Dataloader的测试用例:
import torch
import torch.utils.data as Data
import numpy as np test = np.array([0,1,2,3,4,5,6,7,8,9,10,11]) inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)])) torch_dataset = Data.TensorDataset(inputing,target)
batch = 3 loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=batch, # 批大小
# 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
).unsqueeze(0) for j in range(len(x[0]))
)
) for (i,j) in loader:
print(i)
print(j)
Reference: DataLoader的collate_fn参数
pytorch 读取变长数据
https://zhuanlan.zhihu.com/p/60129684
Pytorch collate_fn用法的更多相关文章
- pytorch faster_rcnn
代码地址:https://github.com/jwyang/faster-rcnn.pytorch 1.fasterRCNN.train():这个不是让网络进行训练,而是让module in tra ...
- Transformers 简介(下)
作者|huggingface 编译|VK 来源|Github Transformers是TensorFlow 2.0和PyTorch的最新自然语言处理库 Transformers(以前称为pytorc ...
- 深度学习与CV教程(8) | 常见深度学习框架介绍
作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...
- Pytorch 一些函数用法
PyTorch中view的用法:https://blog.csdn.net/york1996/article/details/81949843 max用法 import torch d=torch.T ...
- 关于Pytorch的二维tensor的gather和scatter_操作用法分析
看得不明不白(我在下一篇中写了如何理解gather的用法) gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下: out[i][j] = input[index[i][j]] ...
- [转载]PyTorch中permute的用法
[转载]PyTorch中permute的用法 来源:https://blog.csdn.net/york1996/article/details/81876886 permute(dims) 将ten ...
- Pytorch中randn和rand函数的用法
Pytorch中randn和rand函数的用法 randn torch.randn(*sizes, out=None) → Tensor 返回一个包含了从标准正态分布中抽取的一组随机数的张量 size ...
- Pytorch中nn.Conv2d的用法
Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...
- PyTorch中view的用法
相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...
随机推荐
- 计蒜客 置换的玩笑(DFS)
传送门 题目大意: 小蒜头又调皮了.这一次,姐姐的实验报告惨遭毒手. 姐姐的实验报告上原本记录着从 1 到 n 的序列,任意两个数字间用空格间隔.但是“坑姐”的蒜头居然把数字间的空格都给删掉了,整个数 ...
- 题解 P4942 【小凯的数字】
题目 为什么看到很多题解区的 dalao 都用逆元?是我太菜了吧 [分析] 首先,根据弃九验算法的原理,显然可以得到:一个 \(n\) 位数 \(a_1a_2a_3\dots a_n\equiv a_ ...
- 17.3.10--->关于数值溢出问题
取值范围: short.int.long 占用的字节数不同,所能表示的数值范围也不同.以32位平台为例,下面是它们的取值范围: 数据类型 所占字 ...
- Java机器学习软件介绍
Java机器学习软件介绍 编写程序是最好的学习机器学习的方法.你可以从头开始编写算法,但是如果你要取得更多的进展,建议你采用现有的开源库.在这篇文章中你会发现有关Java中机器学习的主要平台和开放源码 ...
- ZJNU 2136 - 会长的正方形
对于n*m网格 取min(n,m)作为最大的正方形边长 则答案可以表示成 s=1~min(n,m) 对于一个s*s的正方形 用oblq数组储存有多少四个角都在这个正方形边上的正方形 以4*4为例 除了 ...
- 一、Cookie和Session介绍
会话跟踪 1. 什么是会话 * 用户拨打10086,从服务台接通后会话开始: * 用户发出话费查询请求,服务台响应.这是该会话中的一个请求: * 用户发出套餐变更请求,服务台响应.这是该会话中的 ...
- python与mysql部分函数和控制流语法对比
条件语句 python语法 a=int(input("输入一个数[0,100]成绩:")) if 100>=a>=90: print("优") el ...
- debian8.8安装sougou输入法
传送门:http://www.cnblogs.com/ligongzi/p/6137601.html 亲测可用
- SaltStack事件驱动 – event reactor
Event是SaltStack里面的对每个事件的一个记录,它相比job更加底层,Event能记录更加详细的SaltStack事件,比如Minion服务启动后请求Master签发证书或者证书校验的过程, ...
- mybatis 自定义类型转换器 (后台日期类型插入数据库)
后台日期类型插入数据库 有以下几种发法: 1 调用数据库 日期字符串转日期函数 str_to_date("日期","yyyy-MM-dd HH:mm:ss") ...