在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的使用

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: :: root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png root/cat/.png
root/cat/nsdf3.png
root/cat/asd932_.png Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path. Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
"""查看文件是否是支持的可扩展类型 Args:
filename (string): 文件路径
extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型 Returns:
bool: True if the filename ends with one of given extensions
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表 def make_dataset(dir, class_to_idx, extensions):
"""
返回形如[(图像路径, 该图像对应的类别索引值),(),...]
"""
images = []
dir = os.path.expanduser(dir)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item) return images class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext root/class_y/.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext Args:
root (string): 根目录路径
loader (callable): 根据给定的路径来加载样本的可调用函数
extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): 用于样本标签的transform函数 Attributes:
classes (list): 类别名列表
class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': , 'dog': }
samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
targets (list): 在数据集中每张图片的类索引值,为列表
""" def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': , 'dog': }
# 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == :
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions))) self.root = root
self.loader = loader
self.extensions = extensions self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[] for s in samples] #所有图像的类索引值组成的列表 self.transform = transform
self.target_transform = target_transform def _find_classes(self, dir):
"""
在数据集中查找类文件夹。 Args:
dir (string): 根目录路径 Returns:
返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': , 'dog': }. Ensures:
保证没有类名是另一个类目录的子目录
"""
if sys.version_info >= (, ):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
classes.sort() #然后对类名进行排序
class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': , 'dog': }
return classes, class_to_idx #然后返回类名和类索引 def __getitem__(self, index):
"""
Args:
index (int): Index Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path) # 加载图片
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target) return sample, target def __len__(self):
return len(self.samples) def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
"""
为了得到两张图(其中一张是随机选取的)的图像和索引值信息
"""
def __init__(self, root, transform=None):
super(CustomImageFolder, self).__init__(root, transform)
self.indices = range(len(self)) #该文件夹中的长度 def __getitem__(self, index1):
index2 = random.choice(self.indices) #从[,indices]中随机抽取一个数字,为了随机选取一张图 path1 = self.imgs[index1][] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
label1 = self.imgs[index1][]
path2 = self.imgs[index2][]
label2 = self.imgs[index2][] img1 = self.loader(path1)
img2 = self.loader(path2)
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2) return img1, img2, label1, label2

pytorch ImageFolder的覆写的更多相关文章

  1. java重载与覆写

    很多同学对于overload和override傻傻分不清楚,建议不要死记硬背概念性的知识,要理解着去记忆. 先给出我的定义: overload(重载):在同一类或者有着继承关系的类中,一组名称相同,参 ...

  2. 在C#中该如何阻止虚方法的覆写

    在开发过程中,我们为了让一个类更有生命力,有时会用virtual来修饰一个方法好让子类来覆写它.但是如果有更新的子子类来覆写时,我们又不想让其影响到上一层的覆写,这时候就要用到new virtual来 ...

  3. JAVA继承与覆写

    实例:数组操作 首先是开发一个整型数组父类,要求从外部控制数组长度,并实现保存数据以及输出.然后子类中实现排序和反转. 基础父类代码如下: class Array { private int data ...

  4. C#类的继承,方法的重载和覆写

    在网易云课堂上看到唐大仕老师讲解的关于类的继承.方法的重载和覆写的一段代码,注释比较详细,在此记下以加深理解. 小总结: 1.类的继承:允许的实例化方式:Student t=new Student() ...

  5. C#使用基类的引用 and 虚方法和覆写方法

    结论:使用基类的引用,访问派生类对象时,得到的是基类的成员. 虚方法和覆写方法

  6. Java中方法的覆写

    所谓方法的覆写override就是子类定义了与父类中同名的方法,但是在方法覆写时必须考虑权限,即被子类覆写的方法不能拥有比父类方法更加严格的访问权限. 修饰符分别为public.protected.d ...

  7. 黑马程序员——JAVA基础之简述 类的继承、覆写

    ------- android培训.java培训.期待与您交流! ---------- 继承的概述: 多个类中存在相同属性和行为时,将这些内容抽取到单独一个类中,那么多个类无需再定义这些属性和行为,只 ...

  8. JAVA中继承时方法的重载(overload)与重写/覆写(override)

    JAVA继承时方法的重载(overload)与重写/覆写(override) 重载-Override 函数的方法参数个数或类型不一致,称为方法的重载. 从含义上说,只要求参数的个数或参数的类型不一致就 ...

  9. Android开发之Source无法覆写public void onClick(View v)

    初学Android开发,在为一个按钮[该按钮继承OnClickListener()]写监听时,发现无法在Source中引入public void onClick(View v),当时非常纳闷,平常情况 ...

随机推荐

  1. xshell连接linux使用vim无法正常使用小键盘

    解决方法 文件-->属性-->终端-->终端类型-->linux 之后重新连接即可

  2. X2E车载数据记录仪

            随着智能驾驶及网联技术深入应用,汽车中传输的数据量与日俱增,包括多种总线数据.视频数据.雷达数据.定位数据等等.据悉,高级别智能驾驶汽车中每秒传输的总线数据就达到G比特级别.而从产品开 ...

  3. Jquery无须浏览实现直接下载文件

    一.常用方式: 1.通常GET方式 后面带明文参数,不安全. window.location.href = 'http://localhost:1188/FileDownload.aspx?token ...

  4. Linux学习笔记——管道PIPE

    管道:当从一个进程连接数据流到另一个进程时,使用术语管道(pipe).# include <unistd.h> int pipe(int filedes[2]); //创建管道 pipe( ...

  5. 关于System.AccessViolationException异常

    什么是AccessViolationException 试图读写受保护内存时引发的异常. 继承 Object Exception SystemException AccessViolationExce ...

  6. .NET Core入门程序及命令行练习

    用命令行一步一步新建项目.添加Package.Restore.Build.Run 执行的实现方式,更让容易让我们了解.NET Core的运行机制. 准备工作 安装.NET Core 运行环境,下载地址 ...

  7. nexus 3.17.0 做为golang 的包管理工具

    nexus 3.17.0 新版本对于go 包管理的支持是基于go mod 的,同时我们也需要一个athens server 然后在nexus 中配置proxy 类型的repo 参考配置 来自官方的配置 ...

  8. JMX脚本在某些机器上报错,有的运行超时

    运行超时的 是因为在server端运行命令执行脚本,是server给agent下达的指定,但是server端到agent的10050端口没开,所以或一致堵死在那,知道执行超时, 解决:开通server ...

  9. 7kyu kata

    https://www.codewars.com/kata/isograms/train/java CW 大神 solution: public class isogram { public stat ...

  10. bootstrap导航条组件

    一.导航条模板(官方文档) <nav class="navbar navbar-default"> <div class="container-flui ...