文章来自微信公众号【机器学习炼丹术】。

上一节课,讲解了MNIST图像分类的一个小实战,现在我们继续深入学习一下pytorch的一些有的没的的小知识来作为只是储备。

参考目录:

@

1 pytorch数据结构

1.1 默认整数与浮点数

【pytorch默认的整数是int64】

pytorch的默认整数是用64个比特存储,也就是8个字节(Byte)存储的。

【pytorch默认的浮点数是float32】

pytorch的默认浮点数是用32个比特存储,也就是4个字节(Byte)存储的。

import torch
import numpy as np
#----------------------
print('torch的浮点数与整数的默认数据类型')
a = torch.tensor([1,2,3])
b = torch.tensor([1.,2.,3.])
print(a,a.dtype)
print(b,b.dtype)

输出:

torch的浮点数与整数的默认数据类型
tensor([1, 2, 3]) torch.int64
tensor([1., 2., 3.]) torch.float32

1.2 dtype修改变量类型

print('torch的浮点数与整数的默认数据类型')
a = torch.tensor([1,2,3],dtype=torch.int8)
b = torch.tensor([1.,2.,3.],dtype = torch.float64)
print(a,a.dtype)
print(b,b.dtype)

输出结果:

torch的浮点数与整数的默认数据类型
tensor([1, 2, 3], dtype=torch.int8) torch.int8
tensor([1., 2., 3.], dtype=torch.float64) torch.float64

1.3 变量类型有哪些

张量的数据类型其实和numpy.array基本一一对应,除了不支持str,主要有下面几种形式:

torch.float64 # 等同于(torch.double)
torch.float32 # 默认,FloatTensor
torch.float16
torch.int64 # 等同于torch.long
torch.int32 # 默认
torch.int16
torch.int8
torch.uint8 # 二进制码,表示0-255
torch.bool

在创建变量的时候,想要创建指定的变量类型,上文中提到了用dtype关键字来控制,但是我个人更喜欢使用特定的构造函数:

print('torch的构造函数')
a = torch.IntTensor([1,2,3])
b = torch.LongTensor([1,2,3])
c = torch.FloatTensor([1,2,3])
d = torch.DoubleTensor([1,2,3])
e = torch.tensor([1,2,3])
f = torch.tensor([1.,2.,3.])
print(a.dtype)
print(b.dtype)
print(c.dtype)
print(d.dtype)
print(e.dtype)
print(f.dtype)

输出结果:

torch的构造函数
torch.int32
torch.int64
torch.float32
torch.float64
torch.int64
torch.float32

因此我们可以得到结果:

  • torch.IntTensor对应torch.int32
  • torch.LongTensor对应torch.int64,LongTensor常用在深度学习中的标签值 ,比方说分类任务中的类别标签0,1,2,3等,要求用ing64的数据类型;
  • torch.FloatTensor对应torch.float32FloatTensor常用做深度学习中可学习参数或者输入数据的类型
  • torch.DoubleTensor对应torch.float64
  • torch.tensor则有一个推断的能力,加入输入的数据是整数,则默认int64,相当于LongTensor;假如输入数据是浮点数,则默认float32,相当于FLoatTensor。刚好对应深度学习中的标签和参数的数据类型,所以一般情况下,直接使用tensor就可以了,但是假如出现报错的时候,也要学会使用dtype或者构造函数来确保数据类型的匹配

1.4 数据类型转换

【使用torch.float()方法】

print('数据类型转换')
a = torch.tensor([1,2,3])
b = a.float()
c = a.double()
d = a.long()
print(b.dtype)
print(c.dtype)
print(d.dtype)
>>> 数据类型转换
>>> torch.float32
>>> torch.float64
>>> torch.int64

我个人比较习惯这个的方法。

【使用type方法】

b = a.type(torch.float32)
c = a.type(torch.float64)
d = a.type(torch.int64)
print(b.dtype) # torch.float32
print(c.dtype) # torch.float64
print(d.dtype) # torch.int64

2 torch vs numpy

PyTorch是一个python包,目的是加入深度学习应用, torch基本上是实现了numpy的大部分必要的功能,并且tensor是可以利用GPU进行加速训练的。

2.1 两者转换

转换时非常非常简单的:

import torch
import numpy as np a = np.array([1.,2.,3.])
b = torch.tensor(a)
c = b.numpy()
print(a)
print(b)
print(c)

输出结果:

[1. 2. 3.]
tensor([1., 2., 3.], dtype=torch.float64)
[1. 2. 3.]

下面的内容就变得有点意思了,是内存复制相关的。假如a和b两个变量共享同一个内存,那么改变a的话,b也会跟着改变;如果a和b变量的内存复制了,那么两者是两个内存,所以改变a是不会改变b的。下面是讲解numpy和torch互相转换的时候,什么情况是共享内存,什么情况下是内存复制 (其实这个问题,也就是做个了解罢了,无用的小知识)

【Tensor()转换】

当numpy的数据类型和torch的数据类型相同时,共享内存;不同的时候,内存复制

print('numpy 和torch互相转换1')
a = np.array([1,2,3],dtype=np.float64)
b = torch.Tensor(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 不共享内存

因为np.float64和torch.float32数据类型不同

print('numpy 和torch互相转换2')
a = np.array([1,2,3],dtype=np.float32)
b = torch.Tensor(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存

因为np.float32和torch.float32数据类型相同

【from_numpy()转换】

print('from_numpy()')
a = np.array([1,2,3],dtype=np.float64)
b = torch.from_numpy(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存
a = np.array([1,2,3],dtype=np.float32)
b = torch.from_numpy(a)
b[0] = 999
print('共享内存' if a[0]==b[0] else '不共享内存')
>>> 共享内存

如果你使用from_numpy()的时候,不管是什么类型,都是共享内存的。

【tensor()转换】

更常用的是这个tensor(),注意看T的大小写, 如果使用的是tensor方法,那么不管输入类型是什么,torch.tensor都会进行数据拷贝,不共享内存。

【.numpy()】

tensor转成numpy的时候,.numpy方法是内存共享的哦。如果想改成内存拷贝的话,可以使用.numpy().copy()就不共享内存了。或者使用.clone().numpy()也可以实现同样的效果。clone是tensor的方法,copy是numpy的方法。

【总结】

记不清的话,就记住,tensor()数据拷贝了,.numpy()共享内存就行了。

2.2 两者区别

【命名】

虽然PyTorch实现了Numpy的很多功能,但是相同的功能却有着不同的命名方式,这让使用者迷惑。

例如创建随机张量的时候:

print('命名规则')
a = torch.rand(2,3,4)
b = np.random.rand(2,3,4)

【张量重塑】

这部分会放在下一章节详细说明~

3 张量

  • 标量:数据是一个数字
  • 向量:数据是一串数字,也是一维张量
  • 矩阵:数据二维数组,也是二维张量
  • 张量:数据的维度超过2的时候,就叫多维张量

3.1 张量修改尺寸

  • pytorch常用reshape和view
  • numpy用resize和reshape
  • pytorch也有resize但是不常用

【reshape和view共享内存(常用)】

a = torch.arange(0,6)
b = a.reshape((2,3))
print(b)
c = a.view((2,3))
print(c)
a[0] = 999
print(b)
print(c)

输出结果:

tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[999, 1, 2],
[ 3, 4, 5]])
tensor([[999, 1, 2],
[ 3, 4, 5]])

上面的a,b,c三个变量其实是共享同一个内存,迁一而动全身。而且要求遵旨规则:原始数据有6个元素,所以可以修改成\(2\times 3\)的形式,但是无法修改成\(2\times 4\)的形式 ,我们来试试:

a = torch.arange(0,6)
b = a.reshape((2,4))

会抛出这样的错误:

【torch的resize_(不常用)】

但是pytorch有一个不常用的函数(对我来说用的不多),resize,这个方法可以不遵守这个规则:

a = torch.arange(0,6)
a.resize_(2,4)
print(a)

输出结果为:



自动的补充了两个元素。虽然不知道这个函数有什么意义。。。。。。

这里可以看到函数resize后面有一个_,这个表示inplace=True的意思,当有这个_或者参数inplace的时候,就是表示所作的修改是在原来的数据变量上完成的,也就不需要赋值给新的变量了。

【numpy的resize与reshape(常用)】

import numpy as np
a = np.arange(0,6)
a.resize(2,3)
print(a)
import numpy as np
a = np.arange(0,6)
b = a.reshape(2,3)
print(b)

两个代码块的输出都是下面的,区别在于numpy的resize是没有返回值的,相当于inplace=True了,直接在原变量的进行修改,而reshape是有返回值的,不在原变量上修改(但是呢reshape是共享内存的):

[[0 1 2]
[3 4 5]]

3.2 张量内存存储结构

tensor的数据结构包含两个部分:

  • 头信息区Tensor:保存张量的形状size,步长stride,数据类型等信息
  • 存储区Storage:保存真正的数据

头信息区Tensor的占用内存较小,主要的占用内存是Storate。

每一个tensor都有着对应的storage,一般不同的tensor的头信息可能不同,但是却可能使用相同的storage。(这里就是之前共享内存的view、reshape方法,虽然头信息的张量形状size发生了改变,但是其实存储的数据都是同一个storage)

3.3 存储区

我们来查看一个tensor的存储区:

import torch
a = torch.arange(0,6)
print(a.storage())

输出为:

 0
1
2
3
4
5
[torch.LongStorage of size 6]

然后对tensor变量做一个view的变换:

b = a.view(2,3)

这个b.storage()输出出来时和a.storate(),相同的,这也是为什么view变换是内存共享的了。

# id()是获取对象的内存地址
print(id(a)==id(b)) # False
print(id(a.storage)==id(b.storage)) # True

可以发现,其实a和b虽然存储区是相同的,但是其实a和b整体式不同的。自然,这个不同就不同在头信息区,应该是尺寸size改变了。这也就是头信息区不同,但是存储区相同,从而节省大量内存

我们更进一步,假设对tensor切片了,那么切片后的数据是否共享内存,切片后的数据的storage是什么样子的呢?

print('研究tensor的切片')
a = torch.arange(0,6)
b = a[2]
print(id(a.storage)==id(b.storage))

输出结果为:

>>> True

没错,就算切片之后,两个tensor依然使用同一个存储区,所以相比也是共享内存的,修改一个另一个也会变化。

#.data_ptr(),返回tensor首个元素的内存地址。
print(a.data_ptr(),b.data_ptr())
print(b.data_ptr()-a.data_ptr())

输出为:

2080207827328 2080207827344
16

这是因为b的第一个元素和a的第一个元素内存地址相差了16个字节,因为默认的tesnor是int64,也就是8个字节一个元素,所以这里相差了2个整形元素

3.4 头信息区

依然是上面那两个tensor变量,a和b

a = torch.arange(0,6)
b = a.view(2,3)
print(a.stride(),b.stride())

输出为:

(1,) (3, 1)

变量a是一维数组,并且就是[0,1,2,3,4,5],所以步长stride是1;而b是二维数组,是[[0,1,2],[3,4,5]],所以就是先3个3个分成第一维度的,然后再1个1个的作为第二维度。

由此可见,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的头信息,这种做法更节省内存,同时提升了处理速度。

【小白学PyTorch】9 tensor数据结构与存储结构的更多相关文章

  1. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  2. SQL Server 索引(一)数据结构和存储结构

    本文关注以下方面(本文所有的讨论基于SQL Server数据库): 索引的分类: 索引的结构: 索引的存储 一.索引定义分类 让我们先来回答几个问题: 什么是索引? 索引是对数据库表中一列或多列的值进 ...

  3. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  4. 【小白学PyTorch】19 TF2模型的存储与载入

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  5. 小白学PyTorch 动态图与静态图的浅显理解

    文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...

  6. 【小白学PyTorch】17 TFrec文件的创建与读取

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  7. 【小白学PyTorch】1 搭建一个超简单的网络

    文章目录: 目录 1 任务 2 实现思路 3 实现过程 3.1 引入必要库 3.2 创建训练集 3.3 搭建网络 3.4 设置优化器 3.5 训练网络 3.6 测试 1 任务 首先说下我们要搭建的网络 ...

  8. 【小白学PyTorch】3 浅谈Dataset和Dataloader

    文章目录: 目录 1 Dataset基类 2 构建Dataset子类 2.1 Init 2.2 getitem 3 dataloader 1 Dataset基类 PyTorch 读取其他的数据,主要是 ...

  9. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

随机推荐

  1. java 网络通信协议、UDP与TCP

    一 网络通信协议 通过计算机网络可以使多台计算机实现连接,位于同一个网络中的计算机在进行连接和通信时需要遵守一定 的规则,这就好比在道路中行驶的汽车一定要遵守交通规则一样.在计算机网络中,这些连接和通 ...

  2. c# Attribute会不会影响性能

    Attribute很方便,标记一个类,设置这个类的额外信息,而不用另外设计存储这个信息. 那么频繁大量使用Attribute会不会影响类的性能. 为此,简单测试. 代码: 略............. ...

  3. Android TextView 字数过多,用跑马灯滚动形式实现

    上代码: <TextView android:layout_width="120dp" android:layout_height="wrap_content&qu ...

  4. JavaScript 用七种方式教你判断一个变量是否为数组类型

    JavaScript 如何判断一个变量是否为数组类型 引言 正文 方法一 方法二 方法三 方法四 方法五 方法六 方法七 结束语 引言 我们如何判断一个变量是否为数组类型呢? 今天来给大家介绍七种方式 ...

  5. linux root用户下没有.ssh目录

    .ssh 是记录密码信息的文件夹,如果没有登录过root的话,就没有 .ssh 文件夹,因此登录 localhost ,并输入密码就会生成了 ssh localhost

  6. Fiddler显示指定host请求,以及过滤无用的css,js

    第一步 右侧窗口点击filters 第二步 点击Use Fiters 第三步 第一个选项不动 no zone filter ,第二个选项选择 show only following hosts 第四步 ...

  7. TCL(事务控制语言)

    #TCL/*Transaction Control Language 事务控制语言 事务:一个或一组sql语句组成一个执行单元,这个执行单元要么全部执行,要么全部不执行. 案例:转账 张三丰 1000 ...

  8. 存储池与存储卷,使用virt-install创建虚拟机

    原文链接:https://www.cnblogs.com/zknublx/p/9199658.html 创建存储池 1.建立存储池的目录 mkdir /kvm/images 2.为了安全性,更改目录的 ...

  9. 从Vessel到二代裸金属容器,云原生的新一波技术浪潮涌向何处?

    摘要:云原生大势,深度解读华为云四大容器解决方案如何加速技术产业融合. 云原生,可能是这两年云服务领域最火的词. 相较于传统的应用架构,云原生构建应用简便快捷,部署应用轻松自如.运行应用按需伸缩,是企 ...

  10. 精讲响应式WebClient第6篇-请求失败自动重试机制,强烈建议你看一看

    本文是精讲响应式WebClient第6篇,前篇的blog访问地址如下: 精讲响应式webclient第1篇-响应式非阻塞IO与基础用法 精讲响应式WebClient第2篇-GET请求阻塞与非阻塞调用方 ...