pytorch 6 batch_train 批训练
import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
# BATCH_SIZE = 5
BATCH_SIZE = 8 # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10) # this is x data (torch tensor)
y = torch.linspace(10, 1, 10) # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=False, # 设置不随机打乱数据 random shuffle for training
num_workers=2, # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
for epoch in range(3): # 全部的数据使用3遍,train entire dataset 3 times
for step, (batch_x, batch_y) in enumerate(loader): # for each training step
# train your data...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
show_batch()
BATCH_SIZE = 8 , 所有数据利用三次
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
END
pytorch 6 batch_train 批训练的更多相关文章
- pytorch:EDSR 生成训练数据的方法
Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- [NN] 随机VS批训练
本文翻译节选自1998-Efficient BackProp, Yann LeCun et al.. 4.1 随机VS批训练 每一次迭代, 传统训练方式都需要遍历所有数据集来计算平均梯度. 批训练也同 ...
- pytorch1.0批训练神经网络
pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...
- [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路
[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 目录 [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路 0x00 摘要 0x01 痛点 0x02 难点 0 ...
- [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程
[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程 目录 [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程 0x00 摘要 0x01 ...
- [源码解析] PyTorch 分布式之弹性训练(3)---代理
[源码解析] PyTorch 分布式之弹性训练(3)---代理 目录 [源码解析] PyTorch 分布式之弹性训练(3)---代理 0x00 摘要 0x01 总体背景 1.1 功能分离 1.2 Re ...
- [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑
[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 目录 [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑 0x00 ...
- [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎
[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎 目录 [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎 0x00 摘要 0x0 ...
随机推荐
- Myeclipse中将项目上传到码云
公司实习之后习惯是代码上传到svn上,最近想起来个人的一些代码上传的到码云上比较方便,根据网上分享的博客内容结合自己的整理记录 其中大多数是参考了https://blog.csdn.net/izzyl ...
- Centos如何安装 jdk 环境变量
一.编辑 profile 文件 vim /etc/profile 二.在 profile 文件下面最下面加上以下内容 export JAVA_HOME=/usr/local/java/jdk1.7.0 ...
- 模仿学习小游戏外星人入侵-Python学习,体会“函数”编程
游戏类如下: # !/usr/bin/python # -*- coding:utf-8 -*- """ Author :ZFH File :alien.py Softw ...
- UVa11183 - Teen Girl Squad(最小树形图-裸)
Problem I Teen Girl Squad Input: Standard Input Output: Standard Output -- 3 spring rolls please. - ...
- No unique bean of type [net.shougongfang.action.paymoney.AlipayPayMoneyReturnObj] is defined: Unsat
0 你把@Service放到实现类上吧.这个问题好像不止一个人在问啦 2013年10月25日 10:34 shidan66 30 0 1 1 加入评论 00 1,@service放到实现上 2. ...
- 技术总结--android篇(一)--MVC模式
先介绍下MVC模式:MVC全名是Model View Controller,是模型(model)-视图(view)-控制器(controller)的缩写,一种软件设计典范,用一种业务逻辑.数据.界面显 ...
- 【UML】UML世界的构成
UML概述 全名:Unified Modeling Language 中文名:统一建模语言 发展历程:"始于1997年一个OMG标准.它是一个支持模型化和软件系统开发的图形化语言,为软件开发 ...
- LeetCode——Valid Parentheses
Given a string containing just the characters '(', ')', '{', '}', '[' and ']', determine if the inpu ...
- Word技巧杂记(一)——去掉页眉上方的黑线
今天在调整文章的格式时,突然发现在页眉的上方有一条巨粗无比的黑线,不知从何处冒出来的(如下图) 经过长时间的研究,终于发现原来这是页面的边框.解决办法也很简单: 格式->边框与底纹->页面 ...
- 【BZOJ 2038】小Z的袜子
[题目链接] https://www.lydsy.com/JudgeOnline/problem.php?id=2038 [算法] 莫队算法 [代码] #include<bits/stdc++. ...