【小白学PyTorch】9 tensor数据结构与存储结构
文章来自微信公众号【机器学习炼丹术】。
上一节课,讲解了MNIST图像分类的一个小实战,现在我们继续深入学习一下pytorch的一些有的没的的小知识来作为只是储备。
参考目录:
@
1 pytorch数据结构
1.1 默认整数与浮点数
【pytorch默认的整数是int64】
pytorch的默认整数是用64个比特存储,也就是8个字节(Byte)存储的。
【pytorch默认的浮点数是float32】
pytorch的默认浮点数是用32个比特存储,也就是4个字节(Byte)存储的。
import torch
import numpy as np
#----------------------
print('torch的浮点数与整数的默认数据类型')
a = torch.tensor([1,2,3])
b = torch.tensor([1.,2.,3.])
print(a,a.dtype)
print(b,b.dtype)
输出:
torch的浮点数与整数的默认数据类型
tensor([1, 2, 3]) torch.int64
tensor([1., 2., 3.]) torch.float32
1.2 dtype修改变量类型
print('torch的浮点数与整数的默认数据类型')
a = torch.tensor([1,2,3],dtype=torch.int8)
b = torch.tensor([1.,2.,3.],dtype = torch.float64)
print(a,a.dtype)
print(b,b.dtype)
输出结果:
torch的浮点数与整数的默认数据类型
tensor([1, 2, 3], dtype=torch.int8) torch.int8
tensor([1., 2., 3.], dtype=torch.float64) torch.float64
1.3 变量类型有哪些
张量的数据类型其实和numpy.array基本一一对应,除了不支持str
,主要有下面几种形式:
torch.float64 # 等同于(torch.double)
torch.float32 # 默认,FloatTensor
torch.float16
torch.int64 # 等同于torch.long
torch.int32 # 默认
torch.int16
torch.int8
torch.uint8 # 二进制码,表示0-255
torch.bool
在创建变量的时候,想要创建指定的变量类型,上文中提到了用dtype关键字来控制,但是我个人更喜欢使用特定的构造函数:
print('torch的构造函数')
a = torch.IntTensor([1,2,3])
b = torch.LongTensor([1,2,3])
c = torch.FloatTensor([1,2,3])
d = torch.DoubleTensor([1,2,3])
e = torch.tensor([1,2,3])
f = torch.tensor([1.,2.,3.])
print(a.dtype)
print(b.dtype)
print(c.dtype)
print(d.dtype)
print(e.dtype)
print(f.dtype)
输出结果:
torch的构造函数
torch.int32
torch.int64
torch.float32
torch.float64
torch.int64
torch.float32
因此我们可以得到结果:
torch.IntTensor
对应torch.int32
torch.LongTensor
对应torch.int64
,LongTensor常用在深度学习中的标签值 ,比方说分类任务中的类别标签0,1,2,3等,要求用ing64的数据类型;torch.FloatTensor
对应torch.float32
。FloatTensor常用做深度学习中可学习参数或者输入数据的类型torch.DoubleTensor
对应torch.float64
torch.tensor
则有一个推断的能力,加入输入的数据是整数,则默认int64,相当于LongTensor;假如输入数据是浮点数,则默认float32,相当于FLoatTensor。刚好对应深度学习中的标签和参数的数据类型,所以一般情况下,直接使用tensor就可以了,但是假如出现报错的时候,也要学会使用dtype或者构造函数来确保数据类型的匹配
1.4 数据类型转换
【使用torch.float()方法】
print('数据类型转换')
a = torch.tensor([1,2,3])
b = a.float()
c = a.double()
d = a.long()
print(b.dtype)
print(c.dtype)
print(d.dtype)
>>> 数据类型转换
>>> torch.float32
>>> torch.float64
>>> torch.int64
我个人比较习惯这个的方法。
【使用type方法】
b = a.type(torch.float32)
c = a.type(torch.float64)
d = a.type(torch.int64)
print(b.dtype) # torch.float32
print(c.dtype) # torch.float64
print(d.dtype) # torch.int64
2 torch vs numpy
PyTorch是一个python包,目的是加入深度学习应用, torch基本上是实现了numpy的大部分必要的功能,并且tensor是可以利用GPU进行加速训练的。
2.1 两者转换
转换时非常非常简单的:
import torch
import numpy as np
a = np.array([1.,2.,3.])
b = torch.tensor(a)
c = b.numpy()
print(a)
print(b)
print(c)
输出结果:
[1. 2. 3.]
tensor([1., 2., 3.], dtype=torch.float64)
[1. 2. 3.]
下面的内容就变得有点意思了,是内存复制相关的。假如a和b两个变量共享同一个内存,那么改变a的话,b也会跟着改变;如果a和b变量的内存复制了,那么两者是两个内存,所以改变a是不会改变b的。下面是讲解numpy和torch互相转换的时候,什么情况是共享内存,什么情况下是内存复制 (其实这个问题,也就是做个了解罢了,无用的小知识)
【Tensor()转换】
当numpy的数据类型和torch的数据类型相同时,共享内存;不同的时候,内存复制
print('numpy 和torch互相转换1')
a = np.array([1,2,3],dtype=np.float64)
b = torch.Tensor(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 不共享内存
因为np.float64和torch.float32数据类型不同
print('numpy 和torch互相转换2')
a = np.array([1,2,3],dtype=np.float32)
b = torch.Tensor(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存
因为np.float32和torch.float32数据类型相同
【from_numpy()转换】
print('from_numpy()')
a = np.array([1,2,3],dtype=np.float64)
b = torch.from_numpy(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存
a = np.array([1,2,3],dtype=np.float32)
b = torch.from_numpy(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存
如果你使用from_numpy()的时候,不管是什么类型,都是共享内存的。
【tensor()转换】
更常用的是这个tensor(),注意看T的大小写, 如果使用的是tensor方法,那么不管输入类型是什么,torch.tensor都会进行数据拷贝,不共享内存。
【.numpy()】
tensor转成numpy的时候,.numpy
方法是内存共享的哦。如果想改成内存拷贝的话,可以使用.numpy().copy()
就不共享内存了。或者使用.clone().numpy()
也可以实现同样的效果。clone是tensor的方法,copy是numpy的方法。
【总结】
记不清的话,就记住,tensor()数据拷贝了,.numpy()共享内存就行了。
2.2 两者区别
【命名】
虽然PyTorch实现了Numpy的很多功能,但是相同的功能却有着不同的命名方式,这让使用者迷惑。
例如创建随机张量的时候:
print('命名规则')
a = torch.rand(2,3,4)
b = np.random.rand(2,3,4)
【张量重塑】
这部分会放在下一章节详细说明~
3 张量
- 标量:数据是一个数字
- 向量:数据是一串数字,也是一维张量
- 矩阵:数据二维数组,也是二维张量
- 张量:数据的维度超过2的时候,就叫多维张量
3.1 张量修改尺寸
- pytorch常用reshape和view
- numpy用resize和reshape
- pytorch也有resize但是不常用
【reshape和view共享内存(常用)】
a = torch.arange(0,6)
b = a.reshape((2,3))
print(b)
c = a.view((2,3))
print(c)
a[0] = 999
print(b)
print(c)
输出结果:
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[999, 1, 2],
[ 3, 4, 5]])
tensor([[999, 1, 2],
[ 3, 4, 5]])
上面的a,b,c三个变量其实是共享同一个内存,迁一而动全身。而且要求遵旨规则:原始数据有6个元素,所以可以修改成\(2\times 3\)的形式,但是无法修改成\(2\times 4\)的形式 ,我们来试试:
a = torch.arange(0,6)
b = a.reshape((2,4))
会抛出这样的错误:
【torch的resize_(不常用)】
但是pytorch有一个不常用的函数(对我来说用的不多),resize
,这个方法可以不遵守这个规则:
a = torch.arange(0,6)
a.resize_(2,4)
print(a)
输出结果为:
自动的补充了两个元素。虽然不知道这个函数有什么意义。。。。。。
这里可以看到函数resize后面有一个_,这个表示inplace=True的意思,当有这个_或者参数inplace的时候,就是表示所作的修改是在原来的数据变量上完成的,也就不需要赋值给新的变量了。
【numpy的resize与reshape(常用)】
import numpy as np
a = np.arange(0,6)
a.resize(2,3)
print(a)
import numpy as np
a = np.arange(0,6)
b = a.reshape(2,3)
print(b)
两个代码块的输出都是下面的,区别在于numpy的resize是没有返回值的,相当于inplace=True了,直接在原变量的进行修改,而reshape是有返回值的,不在原变量上修改(但是呢reshape是共享内存的):
[[0 1 2]
[3 4 5]]
3.2 张量内存存储结构
tensor
的数据结构包含两个部分:
- 头信息区Tensor:保存张量的形状size,步长stride,数据类型等信息
- 存储区Storage:保存真正的数据
头信息区Tensor的占用内存较小,主要的占用内存是Storate。
每一个tensor都有着对应的storage,一般不同的tensor的头信息可能不同,但是却可能使用相同的storage。(这里就是之前共享内存的view、reshape方法,虽然头信息的张量形状size发生了改变,但是其实存储的数据都是同一个storage)
3.3 存储区
我们来查看一个tensor的存储区:
import torch
a = torch.arange(0,6)
print(a.storage())
输出为:
0
1
2
3
4
5
[torch.LongStorage of size 6]
然后对tensor变量做一个view的变换:
b = a.view(2,3)
这个b.storage()
输出出来时和a.storate()
,相同的,这也是为什么view变换是内存共享的了。
# id()是获取对象的内存地址
print(id(a)==id(b)) # False
print(id(a.storage)==id(b.storage)) # True
可以发现,其实a和b虽然存储区是相同的,但是其实a和b整体式不同的。自然,这个不同就不同在头信息区,应该是尺寸size改变了。这也就是头信息区不同,但是存储区相同,从而节省大量内存
我们更进一步,假设对tensor切片了,那么切片后的数据是否共享内存,切片后的数据的storage是什么样子的呢?
print('研究tensor的切片')
a = torch.arange(0,6)
b = a[2]
print(id(a.storage)==id(b.storage))
输出结果为:
>>> True
没错,就算切片之后,两个tensor依然使用同一个存储区,所以相比也是共享内存的,修改一个另一个也会变化。
#.data_ptr(),返回tensor首个元素的内存地址。
print(a.data_ptr(),b.data_ptr())
print(b.data_ptr()-a.data_ptr())
输出为:
2080207827328 2080207827344
16
这是因为b的第一个元素和a的第一个元素内存地址相差了16个字节,因为默认的tesnor是int64,也就是8个字节一个元素,所以这里相差了2个整形元素
3.4 头信息区
依然是上面那两个tensor变量,a和b
a = torch.arange(0,6)
b = a.view(2,3)
print(a.stride(),b.stride())
输出为:
(1,) (3, 1)
变量a是一维数组,并且就是[0,1,2,3,4,5],所以步长stride是1;而b是二维数组,是[[0,1,2],[3,4,5]],所以就是先3个3个分成第一维度的,然后再1个1个的作为第二维度。
由此可见,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的头信息,这种做法更节省内存,同时提升了处理速度。
【小白学PyTorch】9 tensor数据结构与存储结构的更多相关文章
- 【小白学PyTorch】20 TF2的eager模式与求导
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- SQL Server 索引(一)数据结构和存储结构
本文关注以下方面(本文所有的讨论基于SQL Server数据库): 索引的分类: 索引的结构: 索引的存储 一.索引定义分类 让我们先来回答几个问题: 什么是索引? 索引是对数据库表中一列或多列的值进 ...
- 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)
文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...
- 【小白学PyTorch】19 TF2模型的存储与载入
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 小白学PyTorch 动态图与静态图的浅显理解
文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...
- 【小白学PyTorch】17 TFrec文件的创建与读取
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】1 搭建一个超简单的网络
文章目录: 目录 1 任务 2 实现思路 3 实现过程 3.1 引入必要库 3.2 创建训练集 3.3 搭建网络 3.4 设置优化器 3.5 训练网络 3.6 测试 1 任务 首先说下我们要搭建的网络 ...
- 【小白学PyTorch】3 浅谈Dataset和Dataloader
文章目录: 目录 1 Dataset基类 2 构建Dataset子类 2.1 Init 2.2 getitem 3 dataloader 1 Dataset基类 PyTorch 读取其他的数据,主要是 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
随机推荐
- java 网络通信协议、UDP与TCP
一 网络通信协议 通过计算机网络可以使多台计算机实现连接,位于同一个网络中的计算机在进行连接和通信时需要遵守一定 的规则,这就好比在道路中行驶的汽车一定要遵守交通规则一样.在计算机网络中,这些连接和通 ...
- c# Attribute会不会影响性能
Attribute很方便,标记一个类,设置这个类的额外信息,而不用另外设计存储这个信息. 那么频繁大量使用Attribute会不会影响类的性能. 为此,简单测试. 代码: 略............. ...
- Android TextView 字数过多,用跑马灯滚动形式实现
上代码: <TextView android:layout_width="120dp" android:layout_height="wrap_content&qu ...
- JavaScript 用七种方式教你判断一个变量是否为数组类型
JavaScript 如何判断一个变量是否为数组类型 引言 正文 方法一 方法二 方法三 方法四 方法五 方法六 方法七 结束语 引言 我们如何判断一个变量是否为数组类型呢? 今天来给大家介绍七种方式 ...
- linux root用户下没有.ssh目录
.ssh 是记录密码信息的文件夹,如果没有登录过root的话,就没有 .ssh 文件夹,因此登录 localhost ,并输入密码就会生成了 ssh localhost
- Fiddler显示指定host请求,以及过滤无用的css,js
第一步 右侧窗口点击filters 第二步 点击Use Fiters 第三步 第一个选项不动 no zone filter ,第二个选项选择 show only following hosts 第四步 ...
- TCL(事务控制语言)
#TCL/*Transaction Control Language 事务控制语言 事务:一个或一组sql语句组成一个执行单元,这个执行单元要么全部执行,要么全部不执行. 案例:转账 张三丰 1000 ...
- 存储池与存储卷,使用virt-install创建虚拟机
原文链接:https://www.cnblogs.com/zknublx/p/9199658.html 创建存储池 1.建立存储池的目录 mkdir /kvm/images 2.为了安全性,更改目录的 ...
- 从Vessel到二代裸金属容器,云原生的新一波技术浪潮涌向何处?
摘要:云原生大势,深度解读华为云四大容器解决方案如何加速技术产业融合. 云原生,可能是这两年云服务领域最火的词. 相较于传统的应用架构,云原生构建应用简便快捷,部署应用轻松自如.运行应用按需伸缩,是企 ...
- 精讲响应式WebClient第6篇-请求失败自动重试机制,强烈建议你看一看
本文是精讲响应式WebClient第6篇,前篇的blog访问地址如下: 精讲响应式webclient第1篇-响应式非阻塞IO与基础用法 精讲响应式WebClient第2篇-GET请求阻塞与非阻塞调用方 ...