第五章——Pytorch中常用的工具
1. 数据处理
数据加载
在Pytorch 中,数据加载可以通过自己定义的数据集对象来实现。数据集对象被抽象为Dataset类,实现自己定义的数据集需要继承Dataset,并实现两个Python魔法方法。
__getitem__
: 返回一条数据或一个样本。obj[index]
等价于obj.__getitem__(index)
.__len__
: 返回样本的数量。len(obj)
等价于obj.__len__()
.
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self,root):
imgs=os.listdir(root)
#所有图片的绝对路径
#这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
self.imgs=[os.path.join(root, img) for img in imgs]
def __getitem__(self, index):
img_path=self.imgs[index]
#dog->1, cat->0
label=1 if 'dog' in img_path.split("/")[-1] else 0
pil_img=Image.open(img_path)
array=np.asarray(pil_img)
data=t.from_numpy(array)
return data,label
def __len__(self):
return len(self.image)
dataset=DogCat('N:/百度网盘/kaggle/DogCat')
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:
print(img.size(),img.float().mean(),label)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
结果:
torch.Size([280, 300, 3]) tensor(71.6653) 0
torch.Size([396, 312, 3]) tensor(131.8400) 0
torch.Size([414, 500, 3]) tensor(156.6921) 0
torch.Size([375, 499, 3]) tensor(96.8243) 0
torch.Size([445, 431, 3]) tensor(103.8582) 1
torch.Size([373, 302, 3]) tensor(160.0512) 1
torch.Size([240, 288, 3]) tensor(95.1983) 1
torch.Size([499, 375, 3]) tensor(90.5196) 1
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
问题:结果大小不一,这对于batch训练的神经网络来说很不友好。
返回的样本数值交大,未归一化至【-1,1】
针对上述问题,pytorch提供了torchvision。它是一个视觉工具包,提供了很多视觉图像处理的工具。
其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。
对PIL Image的常见操作如下:
- Scale/Resize: 调整尺寸,长宽比保持不变; #Resize
- CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片;
- Pad: 填充;
- ToTensor: 将PIL Image对象转换成Tensor,会自动将【0,255】归一化至【0,1】。
对Tensor的常见操作如下:
- Normalize: 标准化,即减均值,除以标准差;
ToPILImage:将Tensor转为PIL Image.
如果要对图片进行多个操作,可通过Compose将这些操作拼接起来,类似于nn.Sequential.
这些操作定义之后是以对象的形式存在,真正使用时需要调用它的__call__
方法,类似于nn.Mudule.
例如:要将图片调整为224*224,首先应构建操作trans=Scale((224,224))
,然后调用trans(img)
.
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transforms=T.Compose([
T.Resize(224), #缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), #从图片中间裁剪出224*224的图片
T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差
])
class DogCat(data.Dataset):
def __init__(self,root, transforms=None):
imgs=os.listdir(root)
self.imgs=[os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path=self.imgs[index]
#dog->1, cat->0
label=1 if 'dog' in img_path.split("/")[-1] else 0
data=Image.open(img_path)
if self.transforms:
data=self.transforms(data)
return data,label
def __len__(self):
return len(self.imgs)
dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:
print(img.size(),label)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
结果:
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
除了上述操作外,transforms还可以通过Lambda
封装自定义的转换策略.
例如相对PIL Image进行随机旋转,则可以写成trans=T.Lambda(lambda img: img.rotate(random()*360))
.
ImageFolder
下面介绍一个会经常使用到的Dataset——ImageFolder,它的实现和上述DogCat很相似。
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform = None, loader = default_loader)
- 1
- 2
它主要有四个参数:
- root :在root指定的路径下寻找图片
- transform: 对PIL Image进行转换操作, transform的输入是使用loader读取图片返回的对象;
- target_transform :对label的转换;
- loader: 指定加载图片的函数,默认操作是读取为PIL Image对象。
label是按照文件夹名顺序排序后存成字典的,即{类名:类序号(从0开始)}
,一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一直,如果不是这种命名规则,建议通过self.class_to_idx
属性了解label和文件夹名的映射关系。
from torchvision.datasets import ImageFolder
dataset=ImageFolder('N:\\data\\')
dataset.class_to_idx
- 1
- 2
- 3
运行结果:
{'cat': 0, 'dog': 1}
- 1
输入:
#所有图片的路径和对应的label
dataset.imgs
- 1
- 2
输出:
[('N:\\data\\cat\\cat.1.jpg', 0),
('N:\\data\\cat\\cat.2.jpg', 0),
('N:\\data\\cat\\cat.3.jpg', 0),
('N:\\data\\cat\\cat.4.jpg', 0),
('N:\\data\\dog\\dog.9131.jpg', 1),
('N:\\data\\dog\\dog.9132.jpg', 1),
('N:\\data\\dog\\dog.9133.jpg', 1),
('N:\\data\\dog\\dog.9134.jpg', 1)]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
#没有任何的transform,所以返回的还是PIL Image对象
dataset[0][1] #第一维是第几张图,第二维为1返回label
- 1
- 2
输出:0
dataset[0][0] #第一维是第几张图,第二维为0返回图片数据,返回的Image对象如图所示:
- 1
输出:
加上transform:
normilize=T.Normalize(mean=[0.4,0.4,0.4],std=[0.2,0.2,0.2])
transform=T.Compose([
T.RandomResizedCrop (224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normilize,
])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
dataset=ImageFolder('N:\\data\\',transform=transform)
#深度学习中图片数据一般保存为CxHxWx,即通道数x图片高x图片宽
dataset[0][0].size()
- 1
- 2
- 3
输出:
torch.Size([3, 224, 224])
- 1
to_img=T.ToPILImage()
#0.2和0.4是标准差和均值的近似
to_img(dataset[0][0]*0.2+0.4)
- 1
- 2
- 3
输出:
DataLoader加载数据
Dateset只负责数据的抽象,一次调用__getitem__
只返回一个样本。
在训练神经网络时,是对一个batch的数据进行操作,同时还要进行shuffle和并行加速等。
对此,pytorch
提供了DataLoader
帮助我们实现这些功能。
DataLoader的函数定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
- 1
- 2
- dataset: 加载的数据集)Dataset对象;
- batch_size: 批大小;
- shuffle:是否将数据打乱;
- sampler:样本抽样
- num_workers:使用多进程加载的进程数,0代表不使用多进程;
- collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可;
- pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些;
- drop_last:dataset 中的数据个数可能不是 batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
from torch.utils.data import DataLoader
dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter=iter(dataloader)
imgs,labels=next(dataiter)
imgs.size()
- 1
- 2
- 3
- 4
- 5
- 6
输出:
torch.Size([3, 3, 224, 224])
- 1
dataloader是一个可迭代的对象,我们可以像使用迭代器一样使用它,例如:
for batch_datas,batch_labels in dataloader:
train()
- 1
- 2
或
dataiter=iter(dataloader)
batch_datas,batch_labels =next(dataiter)
- 1
- 2
sampler:采样模块
Pytorch 中还提供了一个sampler
模块,用来对数据进行采样。
常用的有随机采样器RandonSampler
,当dataloader
的shuffle
参数为True
时,系统会自动调用这个采样器 ,实现打乱数据。
默认的采样器是SequentialSampler, 它会按顺序一个一个进行采样。
这里介绍另外一个很有用的采样方法:它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它进行重采样。
构建WeightedRandomSampler
时需提供两个参数:每个样本的权重weights
、共选取的样本总数num_samples
,以及一个可选参数replacement
。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部样本数目。
replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。
如果设为False,则当某一类样本被全部选取完,但样本数目仍为达到num_samples时,sampler将不会再从该类中选取数据,此时可能导致weights参数失效。
下面举例说明:
1)
#dataset=DogCat('N:/百度网盘/kaggle/DogCat/',transforms=transforms)
dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
#img,label=dataset[0]#相当于调用dataset.__getitem__(0)
#狗的图片取出的概率是猫的概率的两倍
#两类取出的概率与weights的绝对值大小无关,之和比值有关
weights=[2 if label==1 else 1 for data ,label in dataset]
weights
- 1
- 2
- 3
- 4
- 5
- 6
- 7
输出:
[1, 1, 1, 1, 2, 2, 2, 2]
- 1
2)
from torch.utils.data.sampler import WeightedRandomSampler
sampler=WeightedRandomSampler(weights,
num_samples=9,
replacement=True)
dataloader=DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas,labels in dataloader:
print(labels.tolist())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
输出:
[1, 0, 1]
[1, 0, 1]
[1, 0, 1]
- 1
- 2
- 3
可见猫狗样本比例约为1:2,另外一共有8个样本,却返回了9个样本,说明样本有被重复返回的,这就是replacement参数的作用。
下面我们将replacement设置为False.
from torch.utils.data.sampler import WeightedRandomSampler
sampler=WeightedRandomSampler(weights,num_samples=8,replacement=False)
dataloader=DataLoader(dataset,batch_size=4,sampler=sampler)
for datas,labels in dataloader:
print(labels.tolist())
- 1
- 2
- 3
- 4
- 5
输出:
[0, 0, 1, 0]
[1, 0, 1, 1]
- 1
- 2
在这种情况下,num_samples等于dataset的样本总数,为了 不重复选取,sampler会将每个样本都返回,这样就失去了weight的意义。
从上面的例子可见sampler
在采样中的作用:如果指定了sampler
,shuffle
将不再生效,并且sampler.num_smples
会覆盖dataset
的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples
.
总结:
完整代码:
import os
from PIL import Image
from torch.utils import data
#import numpy as np
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
transforms=T.Compose([
T.Resize(224), #缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), #从图片中间裁剪出224*224的图片
T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差
])
class DogCat(data.Dataset):
def __init__(self,root, transforms=None):
imgs=os.listdir(root)
self.imgs=[os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path=self.imgs[index]
#dog->1, cat->0
label=1 if 'dog' in img_path.split("/")[-1] else 0
data=Image.open(img_path)
if self.transforms:
data=self.transforms(data)
return data,label
def __len__(self):
return len(self.imgs)
dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
print("******dataset*************")
print("dataset")
for img,label in dataset:
print(img.size(),label)
dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
#狗的图片取出的概率是猫的概率的两倍
#两类取出的概率与weights的绝对值大小无关,之和比值有关
weights=[2 if label==1 else 1 for data ,label in dataset]
print("******weights**************")
print("weight:{}".format(weights))
print("******sampler**************")
sampler=WeightedRandomSampler(weights,num_samples=8,replacement=False)
dataloader=DataLoader(dataset,batch_size=4,sampler=sampler)
for datas,labels in dataloader:
print(labels.tolist())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
输出:
******dataset*************
dataset
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
******weights**************
weight:[1, 1, 1, 1, 2, 2, 2, 2]
******sampler**************
[0, 0, 1, 1]
[1, 0, 1, 0]
第五章——Pytorch中常用的工具的更多相关文章
- Java基础学习(五)-- Java中常用的工具类、枚举、Java中的单例模式之详解
Java中的常用类 1.Math : 位于java.lang包中 (1)Math.PI:返回一个最接近圆周率的 (2)Math.abs(-10):返回一个数的绝对值 (3)Math.cbrt(27): ...
- 第五章、 Linux 常用網路指令
http://linux.vbird.org/linux_server/0140networkcommand.php 第五章. Linux 常用網路指令 切換解析度為 800x600 最近更新 ...
- 计算机图形学 opengl版本 第三版------胡事民 第四章 图形学中的向量工具
计算机图形学 opengl版本 第三版------胡事民 第四章 图形学中的向量工具 一 基础 1:向量分析和变换 两个工具 可以设计出各种几何对象 点和向量基于坐标系定义 拇指指向z轴正 ...
- 【全面解禁!真正的Expression Blend实战开发技巧】第五章 从最常用ButtonStyle开始 - ImageButton
原文:[全面解禁!真正的Expression Blend实战开发技巧]第五章 从最常用ButtonStyle开始 - ImageButton 本章围绕ImageButton深入讨论,为什么是Image ...
- java中常用的工具类(一)
我们java程序员在开发项目的是常常会用到一些工具类.今天我汇总了一下java中常用的工具方法.大家可以在项目中使用.可以收藏!加入IT江湖官方群:383126909 我们一起成长 一.String工 ...
- shell编程系列7--shell中常用的工具find、locate、which、whereis
shell编程系列7--shell中常用的工具find.locate.which.whereis .文件查找之find命令 语法格式:find [路径] [选项] [操作] 选项 -name 根据文件 ...
- java中常用的工具类(三)
继续分享java中常用的一些工具类.前两篇的文章中有人评论使用Apache 的lang包和IO包,或者Google的Guava库.后续的我会加上的!谢谢支持IT江湖 一.连接数据库的综合类 ...
- java中常用的工具类(二)
下面继续分享java中常用的一些工具类,希望给大家带来帮助! 1.FtpUtil Java 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...
- SDK中常用的工具
Android SDK包含了各种各样的定制工具,简介如下: 一.Android模拟器(Android Emulator )它是在你的计算机上运行的一个虚拟移动设备.你可以使用模拟器来在一个实际的And ...
随机推荐
- 吴裕雄--天生自然C++语言学习笔记:C++ STL 教程
C++ STL(标准模板库)是一套功能强大的 C++ 模板类,提供了通用的模板类和函数,这些模板类和函数可以实现多种流行和常用的算法和数据结构,如向量.链表.队列.栈. C++ 标准模板库的核心包括以 ...
- express连接数据库 读取表
connection 连接数据库 connection.query 查询表 1.依赖 const mysql = require('mysql'); 连接数据库代码 var connecti ...
- POJ 3368:Frequent values
Frequent values Time Limit: 2000MS Memory Limit: 65536K Total Submissions: 14764 Accepted: 5361 ...
- 并发 ping
参考 [root@RS2 ~]# cat .sh #!/bin/bash # --, by wwy #------------------------------------------------- ...
- 1. react 简书 项目初始化
1. 创建 react 项目 npx create-react-app my-app 2. src 目录下删除 除了 index.js index.css app.js 的文件 3. 引入 style ...
- filter滤镜兼容ie的rgba属性
要在一个页面中设置一个半透明的白色div.这个貌似不是难题,只需要给这个div设置如下的属性即可: background: rgba(255,255,255,0.1); 但是要兼容到ie8.这个就有点 ...
- oracle 查询char类型的数据
曾经遇到一个坑. ';//使用PLSQL工具 能查出结果 偏偏在java代码里面查询不出结果. select taskdate from taskinfo where taskdate='201808 ...
- BZOJ:2186: [Sdoi2008]沙拉公主的困惑
问题:可能逆元不存在吗? 题解: Gcd(a,b)==Gcd(b,a-b); 从数据范围可以看出应该求M!的欧拉函数: 然后通过Gcd转化过去 一开始没想到 #include<iostream& ...
- VLOOKUP返回#N/A结果
VLOOKUP返回#N/A结果 1.无目标值 使用control+f查找是否存在所要搜索的值. 2.位置错误 所要搜索区域,被搜索值必须在首列. 3.格式错误 搜索值和被搜索区域格式需一致. 4.特殊 ...
- Redis: Reducing Memory Usage
High Level Tips for Redis Most of Stream-Framework's users start out with Redis and eventually move ...