pytorch学习:准备自己的图片数据
图片数据一般有两种情况:
1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。
2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。
针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:
一、所有图片放在一个文件夹内
这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。
先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:
import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test)) f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()
经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:
前期工作就装备好了,接着就进入正题了:
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image def default_loader(path):
return Image.open(path).convert('RGB') class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label def __len__(self):
return len(self.imgs) train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。
二、不同类别的图片放在不同的文件夹内
同样先准备数据,这里以flowers数据集为例,下载:
http://download.tensorflow.org/example_images/flower_photos.tgz
花总共有五类,分别放在5个文件夹下。大致如下图:
我的路径是d:/flowers/.
数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
) print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader)) def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader') for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size()) show_batch(batch_x)
plt.axis('off')
plt.show()
就是这样。
pytorch学习:准备自己的图片数据的更多相关文章
- [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制
PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...
- pytorch: 准备、训练和测试自己的图片数据
大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试.如果我们是自己的图片数据,又该怎么做呢? 一.我的数据 我在学习的时候,使用的是fashion-mnist.这个 ...
- pytorch初步学习(一):数据读取
最近从tensorflow转向pytorch,感受到了动态调试的方便,也感受到了一些地方的不同. 所有实验都是基于uint16类型的单通道灰度图片. 一开始尝试用opencv中的cv.imread读取 ...
- [深度学习] pytorch利用Datasets和DataLoader读取数据
本文简单描述如果自定义dataset,代码并未经过测试(只是说明思路),为半伪代码.所有逻辑需按自己需求另外实现: 一.分析DataLoader train_loader = DataLoader( ...
- Python库 - Albumentations 图片数据增强库
Python图像处理库 - Albumentations,可用于深度学习中网络训练时的图片数据增强. Albumentations 图像数据增强库特点: 基于高度优化的 OpenCV 库实现图像快速数 ...
- 【深度学习】Pytorch学习基础
目录 pytorch学习 numpy & Torch Variable 激励函数 回归 区分类型 快速搭建法 模型的保存与提取 批训练 加速神经网络训练 Optimizer优化器 CNN MN ...
- tensorflow学习笔记三:实例数据下载与读取
一.mnist数据 深度学习的入门实例,一般就是mnist手写数字分类识别,因此我们应该先下载这个数据集. tensorflow提供一个input_data.py文件,专门用于下载mnist数据,我们 ...
- Caffe初试(三)使用caffe的cifar10网络模型训练自己的图片数据
由于我涉及一个车牌识别系统的项目,计划使用深度学习库caffe对车牌字符进行识别.刚开始接触caffe,打算先将示例中的每个网络模型都拿出来用用,当然这样暴力的使用是不会有好结果的- -||| ,所以 ...
- 纠错:基于FPGA串口发送彩色图片数据至VGA显示
今天这篇文章是要修改之前的一个错误,前面我写过一篇基于FPGA的串口发送图片数据至VGA显示的文章,最后是显示成功了,但是显示的效果图,看起来确实灰度图,当时我默认我使用的MATLAB代码将图片数据转 ...
随机推荐
- PTA_输入符号及符号个数打印沙漏(C++)
思路:想将所有沙漏所需符号数遍历一遍,然后根据输入的数判断需要输出多少多少层的沙漏,然后分两部分输出沙漏. #include<iostream> #include<cstring ...
- CF987B - High School: Become Human
Year 2118. Androids are in mass production for decades now, and they do all the work for humans. But ...
- 使用 Swoole 来加速 Laravel应用
Swoole 是为 PHP 开发的生产级异步编程框架. 他是一个纯 C 开发的扩展, 他允许 PHP 开发者在 PHP 中写 高性能,可扩展的并发 TCP, UDP, Unix socket, HTT ...
- Alpha冲刺(2/10)——2019.4.24
作业描述 课程 软件工程1916|W(福州大学) 团队名称 修!咻咻! 作业要求 项目Alpha冲刺(团队) 团队目标 切实可行的计算机协会维修预约平台 开发工具 Eclipse 团队信息 队员学号 ...
- dotnet core 3.0 linux 部署小贴士
dotnet core 3.0 目前还是测试版,在linux下安装 sdk 需要有一些注意事项 1.下载url https://dotnet.microsoft.com/download/thank- ...
- 通过源码理解HashMap的并发问题
最近在学习有关于Java的基础知识,在学习到HashMap的相关知识的时候,了解了HashMap的并发中会出现的问题,在此记录,加深理解(这篇文章是基于Java1.7的,主要是为了更加直观,更新版本的 ...
- python爬取网页内容demo
#html文本提取 from bs4 import BeautifulSoup html_sample = '\ <html> \ <body> \ <h1 id = & ...
- Hibernte
什么是CRM?(了解) CRM(customer relationship management)即客户关系管理,是指企业用CRM技术来管理与客户之间的关系.在不同场合下,CRM可能是一个管理学术语, ...
- python学习笔记(7)
第七章 文件和数据格式化 文件的使用 文件是数据的抽象和集合 文件是存储在辅助存储器上的数据序列 文件是数据存储的一种形式 文件展现形态:文本文件和二进制文件 文本文件 由单一特定编码组成的文件,如U ...
- window下如何使用文本编辑器(如记事本)创建、编译和执行Java程序
window下如何使用文本编辑器(如记事本)创建Java源代码文件,并编译执行 第一步:在一个英文目录下创建一个 .text 文件 第二步:编写代码 第三步:保存文件 方法一:选择 文件>另存为 ...