数据导入与处理

来自这里

在解决任何机器学习问题时,都需要在处理数据上花费大量的努力。PyTorch提供了很多工具来简化数据加载,希望使代码更具可读性。在本教程中,我们将学习如何从繁琐的数据中加载、预处理数据或增强数据。

开始本教程之前,请确认你已安装如下Python包:

  • scikit-image:图像IO操作和格式转换
  • pandas:更方便解析CSV

我们接下来要处理的数据集是人脸姿态。这意味着人脸的注释如下:

总之,每个面部都有68个不同标记点。

可以从这里下载数据集,并将其解压后存放到目录‘data/faces/’。

数据集来自带有面部注释的CSV文件,文件内容类似以下格式:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

接下来我们快速读取CSV文件,并从(N,2)数组中获取注释,N表示标记数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2) print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出:

Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]

现在我们写一个简单的帮助函数:展示图片和它的标记,用它来展示样本。

def show_landmarks(image,landmarks):
'''
展示带标记点的图像
'''
plt.imshow(image)
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
plt.pause(10) plt.figure()
show_landmarks(io.imread(os.path.join('data/faces',img_name)),landmarks)
plt.show()

数据集类(Dataset class)

torch.utils.data.Dataset是一个表示数据集的抽象类。你自定义的数据集应该继承Dataset并重写以下方法:

  • len 这样len(dataset)是可以返回数据集的大小
  • getitem 支持索引操作,比如dataset[i]来获取第i个样本。

现在我们来实现我们的面部标记数据集类。我们将在__init__中读取CSV,然后再__getitem__中读取图像。这样可以高效利用内存,因为所有的图像并不是都存在在内存中,而是按需读取。

我们数据集的样本是字典格式的:{'image':image,'landmarks':landmarks}。我们的数据集将采用可选参数transform,以便任何必要的处理都可以被应用在样本上。在下一节中我们会看到transform的用途。

class FaceLandmarksDataset(Dataset):
'''
Face Landmarks Dataset
''' def __init__(self,csv_file,root_dir,transform=None):
'''
param csv_file(string): 带注释的CSV文件路径
param root_dit(string): 存储图像的路径
param transform(callable,optional): 被应用到样本的可选transform操作
'''
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform def __len__(self):
return len(self.landmarks_frame) def __getitem__(self,idx):
img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx,1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1,2)
sample = {'image':image,'landmarks':landmarks} if self.transform:
sample = self.transform(sample) return sample

现在我们实例化这个类,并且迭代输出部分样本。我们打印输出前4个样本并展示它们的标记。

face_dataset = FaceLandmarksDataset(
csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/') fig = plt.figure() for i in range(len(face_dataset)):
sample = face_dataset[i] print(i, sample['image'].shape, sample['landmarks'].shape) ax = plt.subplot(1, 4, i+1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample) if i == 3:
plt.show()
break

输出:

0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

Transforms(转换)

从上面的例子可以看出这些样本的尺寸并不一致。大多数神经网络都期望图像的尺寸是固定的。这样的话,我们就需要一些处理代码。接下来我们创建三个变换函数:

  • Rescale:缩放图像
  • RandomCrop:随机裁剪图像。这是数据扩充。
  • ToTensor:将numpy图像转为torch图像(我们需要交换轴)。

我们将以类而不是简单的函数的方式来实现它们,这样就不需要在每次调用时都传递转换需要的参数。这样我们只需要实现__call__方法,需要的话还可以实现__init__方法。然后我们可以按如下的方式使用:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

下面展示如何将这些转换同时应用在图像和标记点。

class Rescale(object):
'''
按给定的尺寸缩放图像 param output_size (tuple or int): 目标输出尺寸。如果是tuple,输出为匹配的输出尺寸;如果是int,则匹配较小的图像边缘,保证相同的长宽比例。
''' def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size*h/w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size*w/h
else:
new_h, new_w = self.output_size new_h, new_w = int(new_h), int(new_w) img = transform.resize(image, (new_h, new_w)) landmarks = landmarks*[new_w/w, new_h/h]
return {'image': img, 'landmarks': landmarks} class RandomCrop(object):
'''
随机裁剪图像 param output_size (tuple or int): 目标输出尺寸。如果是int,正方形裁剪
''' def __init__(self, output_size):
assert isinstance(output_size, (int, tuple)) if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] h, w = image.shape[:2]
new_h, new_w = self.output_size top = np.random.randint(0, h-new_h)
left = np.random.randint(0, w-new_w) image = image[top:top+new_h, left:left+new_w] landmarks = landmarks - [left, top] return {'image': image, 'landmarks': landmarks} class ToTensor(object):
'''
将ndarrays格式样本转换为Tensors
''' def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks'] image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}

组合变换

现在,我们在样本上应用转换。

比如我们想将图片的短边缩放为256然后在随机裁剪出一个224的正方形,那么我们将用到RescaleRandomCroptorchvision.transforms.Compost可以帮助我们完成上述组合操作。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256), RandomCrop(224)]) fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample) ax = plt.subplot(1, 3, i+1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()

遍历数据集

接下来我们将上面的代码整合起来,创建一个带有组合变换的数据集。综上所述,每次采样该数据集时:

  • 从文件中动态读取图像
  • 转换应用到读取的图像上
  • 由于其中一种转换是随机的,因此数据在抽样时得到了扩充

我们可以是像之前一样用for i in range循环遍历创建的数据集:

transformed_dataset = FaceLandmarksDataset(
csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()])
) for i in range(len(transformed_dataset)):
sample = transformed_dataset[i] print(i, sample['image'].size(), sample['landmarks'].size()) if i == 3:
break

输出:

0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

然而,只是使用建档的for训练遍历数据,我们将丢失很多特征。尤其是我们丢失了:

  • 批量处理数据
  • 移动数据
  • 使用multiprocessing并行加载数据

torch.utils.data.DataLoader是一个提供了所有这些功能的迭代器。接下来使用的参数是明朗的。一个有趣的参数是collate_fn。你可以使用collate_fn指定需要如何对样本进行批量处理。然而,默认的collate足够胜任大多数使用场景。

dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=4) def show_landmarks_batch(sample_batched):
'''
批量展示样本
'''
images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
batch_size = len(sample_batched)
im_size = images_batch.size(2)
grid_border_size = 2 grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0))) for i in range(batch_size):
plt.scatter(
landmarks_batch[i, :, 0].numpy() + i * im_size +
(i+1)*grid_border_size,
landmarks_batch[i, :, 1].numpy() + grid_border_size,
s=10, marker='.', c='r'
)
plt.title('Batch from dataloader') for i_batch,sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['landmarks'].size()) if i_batch == 3:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break

输出:

0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后续:torchvision

在本教程中,我们了解了如何实现并使用数据集、转换和数据导入。torchvision包提供了一些常用的数据集和转换。你甚至可能不需要编写自定义的类。在torchvision中最常用的数据集是ImageFolder。它假设图像的组织方式如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

‘ants’、‘bees’等等都是类的标签。在PIL.Image上操作的类似常用的转化,如RandomHorizontalFlipScale,都是可用的。你可以使用它们来编写想下面的数据导入:

import torch
from torchvision import transforms, datasets data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)

[PyTorch入门]之数据导入与处理的更多相关文章

  1. Kafka Connect使用入门-Mysql数据导入到ElasticSearch

    1.Kafka Connect Connect是Kafka的一部分,它为在Kafka和外部存储系统之间移动数据提供了一种可靠且伸缩的方式,它为连接器插件提供了一组API和一个运行时-Connect负责 ...

  2. oracle数据库数据导入导出步骤(入门)

    oracle数据库数据导入导出步骤(入门) 说明: 1.数据库数据导入导出方法有多种,可以通过exp/imp命令导入导出,也可以用第三方工具导出,如:PLSQL 2.如果熟悉命令,建议用exp/imp ...

  3. R语言基础入门之二:数据导入和描述统计

    by 写长城的诗 • October 30, 2011 • Comments Off This post was kindly contributed by 数据科学与R语言 - go there t ...

  4. pytorch入门2.0构建回归模型初体验(数据生成)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. go语言入门教程百度网盘 mysql图形化操作与数据导入

    mysql图形化操作与数据导入 @author:Davie 版权所有:北京千锋互联科技有限公司 数据库存储技术 数据库(Database)是按照数据结构来组织.存储和管理数据的仓库.每个数据库都有一个 ...

  6. 三分钟教会你Python数据分析—数据导入,小白基础入门必看内容

    前言 文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理. 作者:小白 PS:如有需要Python学习资料的小伙伴可以加点击下方链接自行 ...

  7. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  8. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  9. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

随机推荐

  1. Python 进阶 - 面向对象

    Python 面向对象 面向过程 把完成某个需求的所有步骤,从头到尾逐步实现 根据开发需求,将某些功能独立的代码封装成一个又一个函数 最后完成的代码,就是顺序地调用不同的函数 面向过程特点: 注重步骤 ...

  2. 项目部署篇之——下载安装Xftp6,Xshell6

    俗话说工欲善其事必先利其器,想要在服务器上部署环境就得先安装操作工具. 我用的是xshell6,和xftp6.下面是下载连接,都是免费版的,不需要破解 xftp6链接:https://pan.baid ...

  3. 支付宝H5支付demo

    支付宝H5支付 首先我们必须注册一个支付宝应用(本案例就直接用支付宝的沙箱环境,这个沙箱也就是支付宝提供给开发者的一个测试环境) 登录地址:https://open.alipay.com/platfo ...

  4. 好看的UI组合,为以后自己写组件库做准备

    1. 黑色格子背景 { color: rgb(255, 255, 255); text-shadow: 1px 1px 0 rgba(0,0,0,.3); rgb(62, 64, 74); backg ...

  5. Docker Compose文件详解 V2

    Compose file reference 语法: web:      build: ./web      ports:      - "5000:5000"      volu ...

  6. oracle_(第一课) 安装oracle数据库

    首先去官网下载两个架包链接如下:官网链接 第一步:将两个架包解压到同一个database目录下.如截图所示: 第二步:打开setup应用程序 打开后就到了下面这个页面 第三步:配置安全更新 环境变量配 ...

  7. Educational Codeforces Round 48 (Rated for Div. 2)异或思维

    题:https://codeforces.com/contest/1016/problem/D 题意:有一个 n * m 的矩阵, 现在给你 n 个数, 第 i 个数 a[ i ] 代表 i 这一行所 ...

  8. vue点击复制文本粘贴

    <template>  <ul>      <li> <input type="text" class="inpNone&quo ...

  9. Qt 项目中main主函数及其作用

    main.cpp 是实现 main() 函数的文件,下面是 main.cpp 文件的内容. #include "widget.h" #include <QApplicatio ...

  10. “大屏,您好!” SONIQ声光揭新品“U•F•O”神秘面纱

    作为全球第一批做互联网智能电视的传媒企业,SONIQ声光于4月22日在中国大饭店举行了盛大的新品发布会.其中的重头戏就是当天发布会上作为先锋部队入驻中国电视市场的"UFO".笔者作 ...