pytorch ImageFolder的覆写
在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:
CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
使用可见pytorch torchvision.ImageFolder的使用
这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能
首先先分析下其源代码:
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp'] class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: :: root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png root/cat/.png
root/cat/nsdf3.png
root/cat/asd932_.png Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path. Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples
ImageFolder的代码很简单,主要是继承了DatasetFolder:
def has_file_allowed_extension(filename, extensions):
"""查看文件是否是支持的可扩展类型 Args:
filename (string): 文件路径
extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型 Returns:
bool: True if the filename ends with one of given extensions
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表 def make_dataset(dir, class_to_idx, extensions):
"""
返回形如[(图像路径, 该图像对应的类别索引值),(),...]
"""
images = []
dir = os.path.expanduser(dir)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item) return images class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext root/class_y/.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext Args:
root (string): 根目录路径
loader (callable): 根据给定的路径来加载样本的可调用函数
extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): 用于样本标签的transform函数 Attributes:
classes (list): 类别名列表
class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': , 'dog': }
samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
targets (list): 在数据集中每张图片的类索引值,为列表
""" def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': , 'dog': }
# 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == :
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions))) self.root = root
self.loader = loader
self.extensions = extensions self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[] for s in samples] #所有图像的类索引值组成的列表 self.transform = transform
self.target_transform = target_transform def _find_classes(self, dir):
"""
在数据集中查找类文件夹。 Args:
dir (string): 根目录路径 Returns:
返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': , 'dog': }. Ensures:
保证没有类名是另一个类目录的子目录
"""
if sys.version_info >= (, ):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
classes.sort() #然后对类名进行排序
class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': , 'dog': }
return classes, class_to_idx #然后返回类名和类索引 def __getitem__(self, index):
"""
Args:
index (int): Index Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path) # 加载图片
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target) return sample, target def __len__(self):
return len(self.samples) def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
此时想要覆写ImageFolder,代码为:
class CustomImageFolder(ImageFolder):
"""
为了得到两张图(其中一张是随机选取的)的图像和索引值信息
"""
def __init__(self, root, transform=None):
super(CustomImageFolder, self).__init__(root, transform)
self.indices = range(len(self)) #该文件夹中的长度 def __getitem__(self, index1):
index2 = random.choice(self.indices) #从[,indices]中随机抽取一个数字,为了随机选取一张图 path1 = self.imgs[index1][] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
label1 = self.imgs[index1][]
path2 = self.imgs[index2][]
label2 = self.imgs[index2][] img1 = self.loader(path1)
img2 = self.loader(path2)
if self.transform is not None:
img1 = self.transform(img1)
img2 = self.transform(img2) return img1, img2, label1, label2
pytorch ImageFolder的覆写的更多相关文章
- java重载与覆写
很多同学对于overload和override傻傻分不清楚,建议不要死记硬背概念性的知识,要理解着去记忆. 先给出我的定义: overload(重载):在同一类或者有着继承关系的类中,一组名称相同,参 ...
- 在C#中该如何阻止虚方法的覆写
在开发过程中,我们为了让一个类更有生命力,有时会用virtual来修饰一个方法好让子类来覆写它.但是如果有更新的子子类来覆写时,我们又不想让其影响到上一层的覆写,这时候就要用到new virtual来 ...
- JAVA继承与覆写
实例:数组操作 首先是开发一个整型数组父类,要求从外部控制数组长度,并实现保存数据以及输出.然后子类中实现排序和反转. 基础父类代码如下: class Array { private int data ...
- C#类的继承,方法的重载和覆写
在网易云课堂上看到唐大仕老师讲解的关于类的继承.方法的重载和覆写的一段代码,注释比较详细,在此记下以加深理解. 小总结: 1.类的继承:允许的实例化方式:Student t=new Student() ...
- C#使用基类的引用 and 虚方法和覆写方法
结论:使用基类的引用,访问派生类对象时,得到的是基类的成员. 虚方法和覆写方法
- Java中方法的覆写
所谓方法的覆写override就是子类定义了与父类中同名的方法,但是在方法覆写时必须考虑权限,即被子类覆写的方法不能拥有比父类方法更加严格的访问权限. 修饰符分别为public.protected.d ...
- 黑马程序员——JAVA基础之简述 类的继承、覆写
------- android培训.java培训.期待与您交流! ---------- 继承的概述: 多个类中存在相同属性和行为时,将这些内容抽取到单独一个类中,那么多个类无需再定义这些属性和行为,只 ...
- JAVA中继承时方法的重载(overload)与重写/覆写(override)
JAVA继承时方法的重载(overload)与重写/覆写(override) 重载-Override 函数的方法参数个数或类型不一致,称为方法的重载. 从含义上说,只要求参数的个数或参数的类型不一致就 ...
- Android开发之Source无法覆写public void onClick(View v)
初学Android开发,在为一个按钮[该按钮继承OnClickListener()]写监听时,发现无法在Source中引入public void onClick(View v),当时非常纳闷,平常情况 ...
随机推荐
- Docker 安装HDFS
网上拉取Docker模板,使用singlarities/hadoop镜像 [root@localhost /]# docker pull singularities/hadoop 查看: [root@ ...
- Docker镜像管理基础篇
Docker镜像管理基础篇 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.Docker Images Docker镜像还有启动容器所需要的文件系统及其内容,因此,其用于创建并启 ...
- 发布WS接口与实现WS接口[小列子]
webservice简介:Web Service技术, 能使得运行在不同机器上的不同应用无须借助附加的.专门的第三方软件或硬件, 就可相互交换数据或集成.依据Web Service规范实施的应用之间, ...
- javascript数据结构与算法——列表
前言: 1. 数据的存储结构顺序不重要,也不必对数据进行查找,列表就是一种很好的数据存储结构; 2.此列表采用仿原生数组的原型链上的方法来写,具体可以参考MDN数组介绍,并么有用prototype来构 ...
- jpa之No property buyerOpenId found for type OrderMaster! Did you mean 'buyerOpenid'?
java.lang.IllegalStateException: Failed to load ApplicationContext at org.springframework.test.conte ...
- 08 c++中运算符重载(未完成)
参考:轻松搞定c++语言 定义:赋予已有运算符多重含义,实现一名多用(比较函数重载) 运算符重载的本质是函数重载 重载函数的格式: 函数类型 operator 运算符名称(形参表列) { 重载实体 ...
- 洛谷 P1250 种树 题解
差分约束系统,维护前缀和,根据式子d[ b ] < = d[ e + 1 ] - t,可以看出要连e和b - 1,但占用了超级源点0,所以要把区间向后移,这样就可以用超级源点0来保持图的连通性( ...
- (尚026)Vue_案例_动态初始化显示(尚025)
(1).当前页面需要变化什么样的数据? 答:列表;应该有个todos:[]数组;数组中包含每个元素均为一个对象;有数据titles:'xxx';(勾不勾选)complete:'布尔类型' (2).数组 ...
- 自用 goodsdetail
JSON.parse(data.parameter) 存的字符串 <select id="getGoodsBaseInfoById" resultType="co ...
- Windbg命令的语法规则系列(三)
五.源文件行语法 可以将源文件行号指定为MASM表达式的全部或部分.这些数字计算出与该源代码行对应的可执行代码的偏移量.不能使用源代码行作为C++表达式的一部分.必须用重音符(`)将源文件和行号表达式 ...