thumbnail: https://image.zhangxiann.com/jeison-higuita-W19AQY42rUk-unsplash.jpg

toc: true

date: 2020/2/19 20:17:25

disqusId: zhangxian

categories:

  • PyTorch

tags:

  • AI
  • Deep Learning

本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson2/rmb_classification/

人民币 二分类

实现 1 元人民币和 100 元人民币的图片二分类。前面讲过 PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。

数据模块又可以细分为 4 个部分:

  • 数据收集:样本和标签。
  • 数据划分:训练集、验证集和测试集
  • 数据读取:对应于PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
  • 数据预处理:对应于 PyTorch 的 transforms

# DataLoader 与 DataSet

torch.utils.data.DataLoader()

  1. torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

功能:构建可迭代的数据装载器

  • dataset: Dataset 类,决定数据从哪里读取以及如何读取
  • batchsize: 批大小
  • num_works:num_works: 是否多进程读取数据
  • sheuffle: 每个 epoch 是否乱序
  • drop_last: 当样本数不能被 batchsize 整除时,是否舍弃最后一批数据

Epoch, Iteration, Batchsize

  • Epoch: 所有训练样本都已经输入到模型中,称为一个 Epoch
  • Iteration: 一批样本输入到模型中,称为一个 Iteration
  • Batchsize: 批大小,决定一个 iteration 有多少样本,也决定了一个 Epoch 有多少个 Iteration

假设样本总数有 80,设置 Batchsize 为 8,则共有 $80 \div 8=10$ 个 Iteration。这里 $1 Epoch = 10 Iteration$。

假设样本总数有 86,设置 Batchsize 为 8。如果drop_last=True则共有 10 个 Iteration;如果drop_last=False则共有 11 个 Iteration。

torch.utils.data.Dataset

功能:Dataset 是抽象类,所有自定义的 Dataset 都需要继承该类,并且重写__getitem()__方法和__len__()方法 。__getitem()__方法的作用是接收一个索引,返回索引对应的样本和标签,这是我们自己需要实现的逻辑。__len__()方法是返回所有样本的数量。

数据读取包含 3 个方面

  • 读取哪些数据:每个 Iteration 读取一个 Batchsize 大小的数据,每个 Iteration 应该读取哪些数据。
  • 从哪里读取数据:如何找到硬盘中的数据,应该在哪里设置文件路径参数
  • 如何读取数据:不同的文件需要使用不同的读取方法和库。

这里的路径结构如下,有两类人民币图片:1 元和 100 元,每一类各有 100 张图片。

  • RMB_data

    • 1
    • 100

首先划分数据集为训练集、验证集和测试集,比例为 8:1:1。

数据划分好后的路径构造如下:

  • rmb_split

    • train

      • 1
      • 100
    • valid
      • 1
      • 100
    • test
      • 1
      • 100

实现读取数据的 Dataset,编写一个get_img_info()方法,读取每一个图片的路径和对应的标签,组成一个元组,再把所有的元组作为 list 存放到self.data_info变量中,这里需要注意的是标签需要映射到 0 开始的整数: rmb_label = {"1": 0, "100": 1}

  1. @staticmethod
  2. def get_img_info(data_dir):
  3. data_info = list()
  4. # data_dir 是训练集、验证集或者测试集的路径
  5. for root, dirs, _ in os.walk(data_dir):
  6. # 遍历类别
  7. # dirs ['1', '100']
  8. for sub_dir in dirs:
  9. # 文件列表
  10. img_names = os.listdir(os.path.join(root, sub_dir))
  11. # 取出 jpg 结尾的文件
  12. img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
  13. # 遍历图片
  14. for i in range(len(img_names)):
  15. img_name = img_names[i]
  16. # 图片的绝对路径
  17. path_img = os.path.join(root, sub_dir, img_name)
  18. # 标签,这里需要映射为 0、1 两个类别
  19. label = rmb_label[sub_dir]
  20. # 保存在 data_info 变量中
  21. data_info.append((path_img, int(label)))
  22. return data_info

然后在Dataset 的初始化函数中调用get_img_info()方法。

  1. def __init__(self, data_dir, transform=None):
  2. """
  3. rmb面额分类任务的Dataset
  4. :param data_dir: str, 数据集所在路径
  5. :param transform: torch.transform,数据预处理
  6. """
  7. # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
  8. self.data_info = self.get_img_info(data_dir)
  9. self.transform = transform

然后在__getitem__()方法中根据index 读取self.data_info中路径对应的数据,并在这里做 transform 操作,返回的是样本和标签。

  1. def __getitem__(self, index):
  2. # 通过 index 读取样本
  3. path_img, label = self.data_info[index]
  4. # 注意这里需要 convert('RGB')
  5. img = Image.open(path_img).convert('RGB') # 0~255
  6. if self.transform is not None:
  7. img = self.transform(img) # 在这里做transform,转为tensor等等
  8. # 返回是样本和标签
  9. return img, label

__len__()方法中返回self.data_info的长度,即为所有样本的数量。

  1. # 返回所有样本的数量
  2. def __len__(self):
  3. return len(self.data_info)

train_lenet.py中,分 5 步构建模型。

第 1 步设置数据。首先定义训练集、验证集、测试集的路径,定义训练集和测试集的transforms。然后构建训练集和验证集的RMBDataset对象,把对应的路径和transforms传进去。再构建DataLoder,设置 batch_size,其中训练集设置shuffle=True,表示每个 Epoch 都打乱样本。

  1. # 构建MyDataset实例train_data = RMBDataset(data_dir=train_dir, transform=train_transform)valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
  2. # 构建DataLoder
  3. # 其中训练集设置 shuffle=True,表示每个 Epoch 都打乱样本
  4. train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
  5. valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

第 2 步构建模型,这里采用经典的 Lenet 图片分类网络。

  1. net = LeNet(classes=2)
  2. net.initialize_weights()

第 3 步设置损失函数,这里使用交叉熵损失函数。

  1. criterion = nn.CrossEntropyLoss()

第 4 步设置优化器。这里采用 SGD 优化器。

  1. optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
  2. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略

第 5 步迭代训练模型,在每一个 epoch 里面,需要遍历 train_loader 取出数据,每次取得数据是一个 batchsize 大小。这里又分为 4 步。第 1 步进行前向传播,第 2 步进行反向传播求导,第 3 步使用optimizer更新权重,第 4 步统计训练情况。每一个 epoch 完成时都需要使用scheduler更新学习率,和计算验证集的准确率、loss。

  1. for epoch in range(MAX_EPOCH):
  2. loss_mean = 0.
  3. correct = 0.
  4. total = 0.
  5. net.train()
  6. # 遍历 train_loader 取数据
  7. for i, data in enumerate(train_loader):
  8. # forward
  9. inputs, labels = data
  10. outputs = net(inputs)
  11. # backward
  12. optimizer.zero_grad()
  13. loss = criterion(outputs, labels)
  14. loss.backward()
  15. # update weights
  16. optimizer.step()
  17. # 统计分类情况
  18. _, predicted = torch.max(outputs.data, 1)
  19. total += labels.size(0)
  20. correct += (predicted == labels).squeeze().sum().numpy()
  21. # 打印训练信息
  22. loss_mean += loss.item()
  23. train_curve.append(loss.item())
  24. if (i+1) % log_interval == 0:
  25. loss_mean = loss_mean / log_interval
  26. print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
  27. epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
  28. loss_mean = 0.
  29. scheduler.step() # 更新学习率
  30. # 每个 epoch 计算验证集得准确率和loss
  31. ...
  32. ...

我们可以看到每个 iteration,我们是从train_loader中取出数据的。

  1. def __iter__(self):
  2. if self.num_workers == 0:
  3. return _SingleProcessDataLoaderIter(self)
  4. else:
  5. return _MultiProcessingDataLoaderIter(self)

这里我们没有设置多进程,会执行_SingleProcessDataLoaderIter的方法。我们以_SingleProcessDataLoaderIter为例。在_SingleProcessDataLoaderIter里只有一个方法_next_data(),如下:

  1. def _next_data(self):
  2. index = self._next_index() # may raise StopIteration
  3. data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  4. if self._pin_memory:
  5. data = _utils.pin_memory.pin_memory(data)
  6. return data

在该方法中,self._next_index()是获取一个 batchsize 大小的 index 列表,代码如下:

  1. def _next_index(self):
  2. return next(self._sampler_iter) # may raise StopIteration

其中调用的sampler类的__iter__()方法返回 batch_size 大小的随机 index 列表。

  1. def __iter__(self):
  2. batch = []
  3. for idx in self.sampler:
  4. batch.append(idx)
  5. if len(batch) == self.batch_size:
  6. yield batch
  7. batch = []
  8. if len(batch) > 0 and not self.drop_last:
  9. yield batch

然后再返回看 dataloader_next_data()方法:

  1. def _next_data(self):
  2. index = self._next_index() # may raise StopIteration
  3. data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  4. if self._pin_memory:
  5. data = _utils.pin_memory.pin_memory(data)
  6. return data

在第二行中调用了self._dataset_fetcher.fetch(index)获取数据。这里会调用_MapDatasetFetcher中的fetch()函数:

  1. def fetch(self, possibly_batched_index):
  2. if self.auto_collation:
  3. data = [self.dataset[idx] for idx in possibly_batched_index]
  4. else:
  5. data = self.dataset[possibly_batched_index]
  6. return self.collate_fn(data)

这里调用了self.dataset[idx],这个函数会调用dataset.__getitem__()方法获取具体的数据,所以__getitem__()方法是我们必须实现的。我们拿到的data是一个 list,每个元素是一个 tunple,每个 tunple 包括样本和标签。所以最后要使用self.collate_fn(data)把 data 转换为两个 list,第一个 元素 是样本的batch 形式,形状为 [16, 3, 32, 32] (16 是 batch size,[3, 32, 32] 是图片像素);第二个元素是标签的 batch 形式,形状为 [16]。

所以在代码中,我们使用inputs, labels = data来接收数据。

PyTorch 数据读取流程图

首先在 for 循环中遍历`DataLoader`,然后根据是否采用多进程,决定使用单进程或者多进程的`DataLoaderIter`。在`DataLoaderIter`里调用`Sampler`生成`Index`的 list,再调用`DatasetFetcher`根据`index`获取数据。在`DatasetFetcher`里会调用`Dataset`的`__getitem__()`方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 `collate_fn()`函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的`tenser`。

下图是我们的训练过程的 loss 曲线:

**参考资料**

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学

[PyTorch 学习笔记] 2.1 DataLoader 与 DataSet的更多相关文章

  1. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  2. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  3. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  4. [深度学习] pytorch学习笔记(3)(visdom可视化、正则化、动量、学习率衰减、BN)

    一.visdom可视化工具 安装:pip install visdom 启动:命令行直接运行visdom 打开WEB:在浏览器使用http://localhost:8097打开visdom界面 二.使 ...

  5. 莫烦PyTorch学习笔记(六)——批处理

    1.要点 Torch 中提供了一种帮你整理你的数据结构的好东西, 叫做 DataLoader, 我们能用它来包装自己的数据, 进行批训练. 而且批训练可以有很多种途径. 2.DataLoader Da ...

  6. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

  7. [PyTorch 学习笔记] 5.1 TensorBoard 介绍

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/tensorboard_methods.py http ...

  8. PyTorch学习笔记6--案例2:PyTorch神经网络(MNIST CNN)

    上一节中,我们使用autograd的包来定义模型并求导.本节中,我们将使用torch.nn包来构建神经网络. 一个nn.Module包含各个层和一个forward(input)方法,该方法返回outp ...

  9. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

随机推荐

  1. vue同时安装element ui跟 vant

    记一个卡了我比较久的问题,之前弄的心态爆炸各种问题. 现在来记录一下,首先我vant是已经安装成功了的. 然后引入element ui npm i element-ui -S 接着按需引入,安装插件 ...

  2. Django学习路32_创建管理员及内容补充+前面内容复习

    创建管理员 python manage.py createsuperuser   数据库属性命名限制 1.不能是python的保留关键字 2.不允许使用连续的下划线,这是由django的查询方式决定的 ...

  3. PHP array_diff_assoc() 函数

    实例 比较两个数组的键名和键值,并返回差集: <?php$a1=array("a"=>"red","b"=>"g ...

  4. CF1037H Security 线段树合并 SAM

    LINK:Security 求一个严格大于T的字符串 是原字符串S[L,R]的子串. 容易想到尽可能和T相同 然后再补一个尽可能小的字符即可. 出于这种思想 可以在SAM上先跑匹配 然后枚举加哪个字符 ...

  5. ElasticSearch学习中的坑

    elasticsearch 版本为 6.8.2 1 安装完启动报错:   解决,建立新用户执行 [root@localhost bin]# ./elasticsearch [2019-09-01T05 ...

  6. Latex—参考文献

    在写文章的最后最让我头疼的就是参考文献的问题了.网上的资料也有很多,这里整合了很多资料得出了一个用bib文件的方法. 1.  显示确定参考文献(一句没什么用的废话). 2.  利用谷歌学术(镜像),如 ...

  7. 利用Python操作MySQL数据库

    前言 在工作中,我们需要经常对数据库进行操作,比如 Oracle.MySQL.SQL Sever 等,今天我们就学习如何利用Python来操作 MySQL 数据库. 本人环境:Python 3.7.0 ...

  8. Node.js 和 Python之间如何进行选择?

    转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具.解决方案和服务,赋能开发者. 原文出处:https://dzone.com/articles/nodejs-vs-python-which ...

  9. System.out.println()相关源码

    System.out.println是一个Java语句,一般情况下是将传递的参数,打印到控制台. System:是 java.lang包中的一个final类.根据javadoc,“java.lang. ...

  10. 解决CocoaPods could not find compatible versions for pod "React/Core"

    react-native框架中,在ios文件夹下执行pod install命令时出现的问题. 下面时完整的异常信息: [!] CocoaPods could not find compatible v ...