图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

一、所有图片放在一个文件夹内

这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test)) f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()

经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

前期工作就装备好了,接着就进入正题了:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image def default_loader(path):
return Image.open(path).convert('RGB') class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label def __len__(self):
return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()

自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

二、不同类别的图片放在不同的文件夹内

同样先准备数据,这里以flowers数据集为例,下载:

http://download.tensorflow.org/example_images/flower_photos.tgz

花总共有五类,分别放在5个文件夹下。大致如下图:

我的路径是d:/flowers/.

数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
) print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size()) show_batch(batch_x)
plt.axis('off')
plt.show()

就是这样。

pytorch学习:准备自己的图片数据的更多相关文章

  1. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

  2. pytorch: 准备、训练和测试自己的图片数据

    大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...

  3. pytorch初步学习(一):数据读取

    最近从tensorflow转向pytorch,感受到了动态调试的方便,也感受到了一些地方的不同. 所有实验都是基于uint16类型的单通道灰度图片. 一开始尝试用opencv中的cv.imread读取 ...

  4. [深度学习] pytorch利用Datasets和DataLoader读取数据

    本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...

  5. Python库 - Albumentations 图片数据增强库

    Python图像处理库 - Albumentations,可用于深度学习中网络训练时的图片数据增强. Albumentations 图像数据增强库特点: 基于高度优化的 OpenCV 库实现图像快速数 ...

  6. 【深度学习】Pytorch学习基础

    目录 pytorch学习 numpy & Torch Variable 激励函数 回归 区分类型 快速搭建法 模型的保存与提取 批训练 加速神经网络训练 Optimizer优化器 CNN MN ...

  7. tensorflow学习笔记三:实例数据下载与读取

    一.mnist数据 深度学习的入门实例,一般就是mnist手写数字分类识别,因此我们应该先下载这个数据集. tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们 ...

  8. Caffe初试(三)使用caffe的cifar10网络模型训练自己的图片数据

    由于我涉及一个车牌识别系统的项目,计划使用深度学习库caffe对车牌字符进行识别.刚开始接触caffe,打算先将示例中的每个网络模型都拿出来用用,当然这样暴力的使用是不会有好结果的- -||| ,所以 ...

  9. 纠错:基于FPGA串口发送彩色图片数据至VGA显示

    今天这篇文章是要修改之前的一个错误,前面我写过一篇基于FPGA的串口发送图片数据至VGA显示的文章,最后是显示成功了,但是显示的效果图,看起来确实灰度图,当时我默认我使用的MATLAB代码将图片数据转 ...

随机推荐

  1. PTA_输入符号及符号个数打印沙漏(C++)

    思路:想将所有沙漏所需符号数遍历一遍,然后根据输入的数判断需要输出多少多少层的沙漏,然后分两部分输出沙漏.   #include<iostream> #include<cstring ...

  2. CF987B - High School: Become Human

    Year 2118. Androids are in mass production for decades now, and they do all the work for humans. But ...

  3. 使用 Swoole 来加速 Laravel应用

    Swoole 是为 PHP 开发的生产级异步编程框架. 他是一个纯 C 开发的扩展, 他允许 PHP 开发者在 PHP 中写 高性能,可扩展的并发 TCP, UDP, Unix socket, HTT ...

  4. Alpha冲刺(2/10)——2019.4.24

    作业描述 课程 软件工程1916|W(福州大学) 团队名称 修!咻咻! 作业要求 项目Alpha冲刺(团队) 团队目标 切实可行的计算机协会维修预约平台 开发工具 Eclipse 团队信息 队员学号 ...

  5. dotnet core 3.0 linux 部署小贴士

    dotnet core 3.0 目前还是测试版,在linux下安装 sdk 需要有一些注意事项 1.下载url https://dotnet.microsoft.com/download/thank- ...

  6. 通过源码理解HashMap的并发问题

    最近在学习有关于Java的基础知识,在学习到HashMap的相关知识的时候,了解了HashMap的并发中会出现的问题,在此记录,加深理解(这篇文章是基于Java1.7的,主要是为了更加直观,更新版本的 ...

  7. python爬取网页内容demo

    #html文本提取 from bs4 import BeautifulSoup html_sample = '\ <html> \ <body> \ <h1 id = & ...

  8. Hibernte

    什么是CRM?(了解) CRM(customer relationship management)即客户关系管理,是指企业用CRM技术来管理与客户之间的关系.在不同场合下,CRM可能是一个管理学术语, ...

  9. python学习笔记(7)

    第七章 文件和数据格式化 文件的使用 文件是数据的抽象和集合 文件是存储在辅助存储器上的数据序列 文件是数据存储的一种形式 文件展现形态:文本文件和二进制文件 文本文件 由单一特定编码组成的文件,如U ...

  10. window下如何使用文本编辑器(如记事本)创建、编译和执行Java程序

    window下如何使用文本编辑器(如记事本)创建Java源代码文件,并编译执行 第一步:在一个英文目录下创建一个 .text 文件 第二步:编写代码 第三步:保存文件 方法一:选择 文件>另存为 ...