这是对莫凡python的学习笔记。

1.创建数据

import torch
import torch.utils.data as Data BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

可以看到创建了两个一维数据,x:1~10,y:10~1

2.构造数据集对象,及数据加载器对象

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle = False,
num_workers = 2)

num_workers应该指的是多线程

3.输出数据集,这一步主要是看一下batch长什么样子

for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch:',epoch,'| Step:', step, '| batch x:',
batch_x.numpy(), '| batch y:', batch_y.numpy())

输出如下

('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))

可以看到,batch_size等于8,则第二个bacth的数据只有两个。

将batch_size改为5,输出如下

('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))

pytorch批训练数据构造的更多相关文章

  1. pytorch:EDSR 生成训练数据的方法

    Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...

  2. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  3. Pytorch之训练器设置

    Pytorch之训练器设置 引言 深度学习训练的时候有很多技巧, 但是实际用起来效果如何, 还是得亲自尝试. 这里记录了一些个人尝试不同技巧的代码. tensorboardX 说起tensorflow ...

  4. [NN] 随机VS批训练

    本文翻译节选自1998-Efficient BackProp, Yann LeCun et al.. 4.1 随机VS批训练 每一次迭代, 传统训练方式都需要遍历所有数据集来计算平均梯度. 批训练也同 ...

  5. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  6. pytorch1.0批训练神经网络

    pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...

  7. libsvm的安装,数据格式,常见错误,grid.py参数选择,c-SVC过程,libsvm参数解释,svm训练数据,libsvm的使用详解,SVM核函数的选择

    直接conda install libsvm安装的不完整,缺几个.py文件. 第一种安装方法: 下载:http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm. ...

  8. Alink漫谈(七) : 如何划分训练数据集和测试数据集

    Alink漫谈(七) : 如何划分训练数据集和测试数据集 目录 Alink漫谈(七) : 如何划分训练数据集和测试数据集 0x00 摘要 0x01 训练数据集和测试数据集 0x02 Alink示例代码 ...

  9. GitHub上YOLOv5开源代码的训练数据定义

    GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...

随机推荐

  1. log4j配置文件加载方式

    使用背景: apache的log4j是一个功能强大的日志文件,当我们使用eclipse等IDE在项目中配置log4j的时候,需要知道我们的配置文件的加载方式以及如何被加载的. 加载方式: (1).自动 ...

  2. java多线程编程——同步器Exchanger

    类java.util.concurrent.Exchanger提供了一个同步点,在这个同步点,一对线程可以交换数据.每个线程通过exchange()方法的入口提供数据给他的伙伴线程,并接收他的伙伴线程 ...

  3. css知多少(3)——样式来源与层叠规则(转)

    css知多少(3)——样式来源与层叠规则   上一节<css知多少(2)——学习css的思路>有几个人留言表示思路很好.继续期待,而且收到了9个赞,我还是比较欣慰的.没看过的朋友建议先去看 ...

  4. CentOS7下安装pip和pip3

    1.首先检查linux有没有安装python-pip包,直接执行 yum install python-pip 2.没有python-pip包就执行命令 yum -y install epel-rel ...

  5. ngx-bootstrap使用01 安装ngx-bootstrap和bootstrap及其使用、外部样式引入

    1 版本说明 2 新建一个angular项目 ng new 项目名 --stayle=scss 代码解释:创建一个样式文件格式为SCSS的angular项目 技巧01:由于我angular-cli的版 ...

  6. c语言实战: 计算时间差

    计算时间差有两种,一种是把时间都转化为分钟数,一种是把时间都转化为小时,后者是会用到除法所以不可避免产生浮点数,所以我们选择转化为分钟数来计算. //题目:给定两个时间点计算它们的时间差,比如,1:5 ...

  7. 算法Sedgewick第四版-第1章基础-2.1Elementary Sortss-008排序算法的复杂度(比较次数的上下限)

    一. 1. 2.

  8. centos 6.5安装 redis

    版本:redis-2.8.19.tar.gz 检查下面依赖是否安装,如果没有要先安装,不然会有异常. yum install gcc-c++ yum install -y tcl. .获取安装文件 r ...

  9. NSButton添加事件

    -(void)addButton { NSButton* pushButton = [[NSButton alloc] initWithFrame: NSMakeRect(, , , )]; push ...

  10. linux删除文件、创建文件

    1.删除文件 rm huahua.txt 2.创建文件 touch huahua.txt