一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label,如下面这个例子:

  1. import torch
  2. from torch.utils.data import DataLoader
  3. from torchvision import transforms
  4. import torchvision.datasets as datasets
  5. import matplotlib.pyplot as plt
  6.  
  7. # a simple custom collate function, just to show the idea
  8. def my_collate(batch):
  9. data = [item[0] for item in batch]
  10. target = [item[1] for item in batch]
  11. target = torch.LongTensor(target)
  12. return [data, target]
  13.  
  14. def show_image_batch(img_list, title=None):
  15. num = len(img_list)
  16. fig = plt.figure()
  17. for i in range(num):
  18. ax = fig.add_subplot(1, num, i+1)
  19. ax.imshow(img_list[i].numpy().transpose([1,2,0]))
  20. ax.set_title(title[i])
  21.  
  22. plt.show()
  23.  
  24. # do not do randomCrop to show that the custom collate_fn can handle images of different size
  25. train_transforms = transforms.Compose([transforms.Scale(size = 224),
  26. transforms.ToTensor(),
  27. ])
  28.  
  29. # change root to valid dir in your system, see ImageFolder documentation for more info
  30. train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
  31. transform=train_transforms)
  32.  
  33. trainset = DataLoader(dataset=train_dataset,
  34. batch_size=4,
  35. shuffle=True,
  36. collate_fn=my_collate, # use custom collate function here
  37. pin_memory=True)
  38.  
  39. trainiter = iter(trainset)
  40. imgs, labels = trainiter.next()
  41.  
  42. # print(type(imgs), type(labels))
  43. show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])

pytorch 中Dataloader中的collate_fn参数的更多相关文章

  1. pytorch中DataLoader, DataSet, Sampler之间的关系

    转自:https://mp.weixin.qq.com/s/RTv0cUWvc0kuXBeNoXVu_A 自上而下理解三者关系 首先我们看一下DataLoader.__next__的源代码长什么样,为 ...

  2. pytorch :: Dataloader中的迭代器和生成器应用

    在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader. 为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭 ...

  3. ARTS-S pytorch中backward函数的gradient参数作用

    导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...

  4. Eclipse中自动提示的方法参数都是arg0,arg1的解决方法

    Eclipse中自动提示的方法参数都是arg0,arg1,就不能根据参数名来推断参数的含义,非常不方便. 解决方法:Preferences->Java->Installed JREs,发现 ...

  5. C#调用SQL中的存储过程中有output参数,存储过程执行过程中返回信息

      C#调用SQL中的存储过程中有output参数,类型是字符型的时候一定要指定参数的长度.不然获取到的结果总是只有第一字符.本人就是由于这个原因,折腾了很久.在此记录一下,供大家以后参考! 例如: ...

  6. URL地址中使用中文作为的参数【转】

    原文:http://blog.csdn.net/blueheart20/article/details/43766713 引言: 在Restful类的服务设计中,经常会碰到需要在URL地址中使用中文作 ...

  7. SQL Server存储过程中使用表值作为输入参数示例

    这篇文章主要介绍了SQL Server存储过程中使用表值作为输入参数示例,使用表值参数,可以不必创建临时表或许多参数,即可向 Transact-SQL 语句或例程(如存储过程或函数)发送多行数据,这样 ...

  8. 在VS中向命令行添加参数的方法

    在VS中向命令行添加参数的方法 在VS中向命令行添加参数,即向main()函数传递参数的方法: 右键单击要 添加参数的工程-->属性-->配置属性-->调试,在右侧“命令参数”栏输入 ...

  9. R中的par()函数的参数

    把R中par()函数的主要参数整理了一下(另外本来还整理了每个参数的帮助文档中文解释,但是太长,就分类之后,整理为图表,excel不便放上来,就放了这些表的截图)

随机推荐

  1. 使用systemctl管理nginx

    [Unit] Description=nginx After=network.target [Service] Type=forking ExecStartPre=/data/apps/nginx/s ...

  2. 洛谷 P1082 同余方程 题解

    每日一题 day31 打卡 Analysis 题目问的是满足 ax mod b = 1 的最小正整数 x.(a,b是正整数) 但是不能暴力枚举 x,会超时. 把问题转化一下.观察 ax mod b = ...

  3. 014_Python3 循环语句

    1.while 循环 #!/usr/bin/env python3   n = 100   sum = 0 counter = 1 while counter <= n:     sum = s ...

  4. vue vue-cli中引入全局less变量的方式

    我们经常用less定义一些全局变量,比如主题的颜色,为了避免在每个组件中引用我首先尝试放在main.js中,发现并不起作用... 先看vue-cli2.x 版本如何解决 1.安装; npm insta ...

  5. [nginx]nginx的一个奇葩问题 500 Internal Server Error phpstudy2018 nginx虚拟主机配置 fastadmin常见问题处理

    [nginx]nginx的一个奇葩问题 500 Internal Server Error 解决方案 nginx 一直报500 Internal Server Error 错误,配置是通过phpstu ...

  6. golang orm

    package main import ( "fmt" "github.com/astaxie/beego/orm" _"github.com/go- ...

  7. 6、httpd2.4 编译安装LAMP

    www.itjc8.com 新特性: MPM支持运营DSO机制(动态共享对象),以模块形式按需加载 支持event MPM 支持异步读写 支持每模块及每个目录分别使用各自的日志级别 每请求配置 增强版 ...

  8. Java 基础:单例模式 Singleton Pattern

    1.简介 单例模式(Singleton Pattern)是 Java 中最简单的设计模式之一.这种类型的设计模式属于创建型模式,它提供了一种创建对象的最佳方式. 这种模式涉及到一个单一的类,该类负责创 ...

  9. C# 获取枚举值/获取名字和值

    枚举 int 转 枚举名称 public void Test() { //调用 string name1= ConvertEnumToString<ActionLogType>(1); s ...

  10. [WEB安全]SSRF中URL的伪协议

    当我们发现SSRF漏洞后,首先要做的事情就是测试所有可用的URL伪协议 0x01 类型 file:/// dict:// sftp:// ldap:// tftp:// gopher:// file: ...