Pytorch入门随手记
Pytorch入门随手记
什么是Pytorch?
Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的)
是一个动态的过程,数据和图是一起建立的。
tensor.dot(tensor1,tensor2)//tensor各个对应位置相乘再相加
print(net)可以输出网络结构
Pytorch的动态性:网络参数可以有多个不固定的,例如:
来源:https://morvanzhou.github.io/tutorials/machine-learning/torch/5-01-dynamic/
最典型的例子就是 RNN, 有时候 RNN 的 time step 不会一样, 或者在 training 和 testing 的时候,
batch_size和time_step也不一样, 这时, Tensorflow 就头疼了, Tensorflow 的人也头疼了. 哈哈, 如果用一个动态计算图的 Torch, 我们就好理解多了, 写起来也简单多了.激活函数使用层和function,在效果上没什么区别
使用torch.nn.Sequential快速搭建模型
torch.nn.Sequential(
#eg
torch.nn.linear(2,10),
torch.nn.ReLU(),
torch.nn.linear(10,2),
)
这里使用的是匿名对象,所以print出来之后是没有类型名称(即self.hidden和self.predict之类的,输出的时候会显示hidden和predict).
保存和提取神经网络
保存
torch.save(net,"net.pkl")#保存整个神经网络模型,类型名为pkl
torch.save(net.state_dict(),"net_params.pkl")#只保存参数而不保存整个网络
提取
net=torch.load("net.pkl")#提取网络
net2=torch.nn.Sequential(
这里只是举了用Sequential来创建网络的例子,如果不用这种匿名方法的话也是一样的,就是在提取参数之前要搭建一个和原网络完全一样的网络结构
)
net2.load_state_dict(torch.load("net_params.pkl"))#只提取参数
批训练(Mini Batch Training)
BATCH_SIZE=5
x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
torch_dataset=Data.TensorDataset(data_tensor=x,target_tensor=y)
loader=Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,#shuffle如果设置为true,则每次batch都是选择的不一样的数据,设置为False,则每次batch的数据都一样。
num_workers=2,#设置提取数据时候的线程数量
)
for epoch in range(3):
for step,(batch_x,batch_y)in enumerate(loader):#enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
#例如本例中,那个step就是提取的index超参数:在机器学习的上下文中,超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据。通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果。
由此可见,超参数一般是人为指定的、定义在模型之前的一些全局变量,它对模型和训练的过程进行控制。习惯上,用大写来表示。

我觉得它第三个for循环和zip合起来还挺灵性的。
len(train_loader)和len(train_loader.dataset)的区别
这里举个例子:
train_loader = torch.utils.data.DataLoader(
dataset=torch_train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=WORKERS)
len(train_loader.dataset)=len(torch_train_dataset),也就是数据集的大小,和batch_size无关
而len(train_loader)=len(train_loader.dataset)/batch_size并向上取整
Pytorch入门随手记的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
随机推荐
- 安装mysql后必须要做的一件事
Step 1. 检查默认账户和密码 $cat /etc/mysql/debian.cnf # 在ubuntu下查看默认账户名和密码 会看到 [client] host = localhost user ...
- ELK的安全解决方案 X-Pack(1)
安装 X-Pack 前必须安装 elasticsearch. Kibana.logstash,因为之前安装ELK选择的版本都是5.4.1,所以这次选择X-Pack的版本也要是5.4.1的 第一步:下载 ...
- 002-jdk-数据结构-工具类Collections、Arrays、System.arraycopy
常用备注 一.LIst to Array List<String> list = new ArrayList<String>(); Object[] array=list.to ...
- setInterval、clearInterval的回调函数,实现函数间调用的先后顺序
定义: var waitUnitil=function (untillCallBack, nextStepCallBack, count) { if (count == null) { count = ...
- Ubuntu 16.04安装MySQL设置远程访问出现问题的完美解决方案(error:10061)
一.安装mysql 安装mysql过程中,需要设置mysql的root账号的密码,不要忽略了. sudo apt-get install mysql-server apt isntall mysql- ...
- ASP.NET 拼多多用户登录授权后使用code去换取access_token
一.拼多多开放平台 由于本人刚毕业进公司实习 遇到一些问题然后想通过博客来记录和分享给大家一起学习. 第一次写博客没什么经验不是写的很好 请大家多多关照 嘴下留情哈哈 谢谢! 好了 话不多说直接进入主 ...
- 攻防世界WEB新手练习
0x01 view_source 0x02 get_post 这道题就是最基础的get和post请求的发送 flag:cyberpeace{b1e763710ff23f2acf16c2358d3132 ...
- 【VS开发】图像颜色
版权声明:本文为博主原创文章,转载请注明出处http://blog.csdn.net/lg1259156776/. 最近被图像颜色整的乱七八糟的,一会儿YUV422,一会儿RGB,一会儿gray... ...
- npm EPERM: operation not permitted
缓存问题导致 需要删除npmrc文件. 强调:不是nodejs安装目录npm模块下的那个npmrc文件 而是在C:\Users\{账户}\下的.npmrc文件..
- 最新 竞技世界java校招面经 (含整理过的面试题大全)
从6月到10月,经过4个月努力和坚持,自己有幸拿到了网易雷火.京东.去哪儿.竞技世界等10家互联网公司的校招Offer,因为某些自身原因最终选择了竞技世界.6.7月主要是做系统复习.项目复盘.Leet ...