加载数据集dataloader

from torch.utils.data import DataLoader
form 自己写的dataset import Dataset train_set = Dataset(train=True)
val_set = Dataset(train=False) image_datasets = {
'train': train_set, 'val': val_set
} batch_size = 4 dataloaders = {
'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2),
'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
} dataset_sizes = {
x: len(image_datasets[x]) for x in image_datasets.keys()
}
print(dataset_sizes) for epoch in range(num_epochs):
for phase in ['train', 'val']:
if phase == 'train':
# for param_group in optimizer.param_groups:
# print("LR", param_group['lr'])
model.train()
else:
model.eval()

以上适用于train一遍test一遍的情况

或者分别加载训练和测试:

train_dataset = Dataset('train')
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=2, collate_fn=collate_fn) test_dataset = Dataset('eval')
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False,
num_workers=2, collate_fn=collate_fn)

自己写Dataset

from torch.utils.data import Dataset
import os
import cv2
import torch
import numpy as np class Dataset(Dataset):
def __init__(self,train):
if train:
self.datapath = {'image': '/home/myy/code/Final_Project/data_train.txt', 'target':'/home/myy/code/Final_Project/gt_train.txt'}
else:
self.datapath = {'image': '/home/myy/code/Final_Project/data_test.txt', 'target':'/home/myy/code/Final_Project/gt_test.txt'}
# self.datapath = {'image': '/home/myy/code/Final_Project/test_small_data.txt', 'target':'/home/myy/code/Final_Project/test_small.txt'}
self.image_list, self.target_list = self.read_txt(self.datapath) # 此处可以依据需要自己定义一些函数
# 注意调用前要加上`self.`
# 比如以下两个读取数据的函数,read_txt、read_json就是自己定义的
def read_txt(self,datapath):
im =[]
target_image = []
print(datapath)
with open(datapath['image'], 'r') as f:
image_list = f.readlines()
with open(datapath['target'], 'r') as f:
target_list = f.readlines()
return image_list, target_list def read_json(save_path, encoding='utf8'):
jsondata = []
with open(save_path, 'r', encoding=encoding) as f:
content = f.read()
content = json.loads(content)
for key in content:
jsondata.append(content[key])
return jsondata def __getitem__(self, item):
# 最核心的部分,经过处理,要返回输入和gt return img, target def __len__(self):
# 这可以根据具体情况修改,不写也行
return len(self.data)

[深度学习]-Dataset数据集加载的更多相关文章

  1. 什么是pytorch(4.数据集加载和处理)(翻译)

    数据集加载和处理 这里主要涉及两个包:torchvision.datasets 和torch.utils.data.Dataset 和DataLoader torchvision.datasets是一 ...

  2. OFRecord 数据集加载

    OFRecord 数据集加载 在数据输入一文中知道了使用 DataLoader 及相关算子加载数据,往往效率更高,并且学习了如何使用 DataLoader 及相关算子. 在 OFrecord 数据格式 ...

  3. 深入java虚拟机学习 -- 类的加载机制(续)

    昨晚写 深入java虚拟机学习 -- 类的加载机制 都到1点半了,由于第二天还要工作,没有将上篇文章中的demo讲解写出来,今天抽时间补上昨晚的例子讲解. 这里我先把昨天的两份代码贴过来,重新看下: ...

  4. 【Java Web开发学习】Spring加载外部properties配置文件

    [Java Web开发学习]Spring加载外部properties配置文件 转载:https://www.cnblogs.com/yangchongxing/p/9136505.html 1.声明属 ...

  5. Python3读取深度学习CIFAR-10数据集出现的若干问题解决

    今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集.当我兴高采烈的运行代码时,却发现了一些错误: # -*- coding: utf-8 -*- import pickle as p ...

  6. 深度学习常用数据集 API(包括 Fashion MNIST)

    基准数据集 深度学习中经常会使用一些基准数据集进行一些测试.其中 MNIST, Cifar 10, cifar100, Fashion-MNIST 数据集常常被人们拿来当作练手的数据集.为了方便,诸如 ...

  7. Recorder︱深度学习小数据集表现、优化(Active Learning)、标注集网络获取

    一.深度学习在小数据集的表现 深度学习在小数据集情况下获得好效果,可以从两个角度去解决: 1.降低偏差,图像平移等操作 2.降低方差,dropout.随机梯度下降 先来看看深度学习在小数据集上表现的具 ...

  8. PIE SDK 多数据源的复合数据集加载

    1. 功能简介 GIS遥感图像数据复合是将多种遥感图像数据融合成一种新的图像数据的技术,是目前遥感应用分析的前沿,PIESDK通过复合数据技术可以将多幅幅影像数据集(多光谱和全色数据)组合成一幅多波段 ...

  9. tensorflow数据集加载

    本篇涉及的内容主要有小型常用的经典数据集的加载步骤,tensorflow提供了如下接口:keras.datasets.tf.data.Dataset.from_tensor_slices(shuffl ...

随机推荐

  1. APISpace万券齐发,API采购大放价

    Eolink APISpace 是 Eolink 旗下专业的API 数据交易平台,上面拥有海量的API,开发者可以根据需求自由选择. 环境天气 全国天气预报,支持全国以及全球多个城市的天气查询,包含国 ...

  2. 作业二、安装CentOS7.9

    一.安装环境 1.VMware Workstation 16 Pro 2.CentOS7.9 二.部署系统 步骤1.进入VMware,点击创建新的虚拟机 步骤2.进入新建虚拟机向导,选择典型(推荐) ...

  3. 使用 Azure 静态 Web 应用服务免费部署 Hexo 博客

    一.前言 最近在折腾 Hexo 博客,试了一下 Azure 的静态 Web 应用服务,发现特别适合静态文档类型的网站,而且具有免费额度,支持绑定域名.本文只是以 Hexo 作为示例,其他类型的框架也是 ...

  4. SSH远程登录:两台或多台服务器之间免密登录设置

    有两台(或多台)同局域网的服务器A:192.168.2.21,B:192.168.2.25.让A,B这两台服务器之间能两两互相免密登录,并且每台服务器都可以自我免密登录(自我免密登录即:ssh loc ...

  5. 字符输出流_Writer类&FileWrite类介绍和字符输出流的基本使用_写出单个字符到文件

    字符输出流_Writer类&FileWrite类介绍 java.io.Writer:字符输出流,是所有字符输出流的最顶层的父类,是一个抽象类 共性抽象方法: void write(int c) ...

  6. 别再用 System.currentTimeMillis 统计耗时了,太 Low,试试 Spring Boot 源码在用的 StopWatch吧,够优雅!

    大家好,我是二哥呀! 昨天,一位球友问我能不能给他解释一下 @SpringBootApplication 注解是什么意思,还有 Spring Boot 的运行原理,于是我就带着他扒拉了一下这个注解的源 ...

  7. python操作ini文件

    简介 ini文件作为常见的配置文件,因此需要对ini文件做处理,此处使用configparser模块,本文介绍以下ini文件常用的处理方式. 需要读取的ini文件 如下文件,[ ]包含的称为secti ...

  8. mysql show操作

    SHOW CHARACTER SET 显示所有可用的字符集 SHOW CHARACTER SET; SHOW CHARACTER SET LIKE 'latin%'; SHOW COLLATION 输 ...

  9. React报错之Property 'X' does not exist on type 'HTMLElement'

    正文从这开始~ 总览 在React中,当我们试图访问类型为HTMLElement 的元素上不存在的属性时,就会发生Property 'X' does not exist on type 'HTMLEl ...

  10. 万字长文:从计算机本源深入探寻volatile和Java内存模型

    万字长文:从计算机本源深入探寻volatile和Java内存模型 前言 在本篇文章当中,主要给大家深入介绍Volatile关键字和Java内存模型.在文章当中首先先介绍volatile的作用和Java ...