1. 引言

  最近在刷开源的Pytorch版动手学深度学习,里面谈到几个高级选择函数,如index_select,masked_select,gather等。这些函数大多很容易理解,但是对于gather函数,确实有些难理解,官方文档开始也看得一脸懵,感觉不太直观。下面谈谈我对这几个函数的一些理解。

2. 维度的理解

  对于numpy和pytorch,其数组在做维度运算上刚开始可能会给人一种直观上的误解,以numpy求矩阵某个维度的最大值为例(pytorch的理解也是一样的)

import numpy as np
a = np.arange(1, 13).reshape(3, 4)
"""
result:
a = [[1, 2, 3, 4],
[5, 6, 7, 8,],
[9, 10, 11, 12]]
""" # 对a维度0求最大值
a.max(axis = 0)
"""
result:
[9, 10, 11, 12]
""" # 对a维度1求最大值
a.max(axis = 1)
"""
result:
[4, 8, 12]
"""

  如果对a矩阵在维度0上找最大值,根据我们直观上的经验应该是[4, 8, 12]。即从[1, 2, 3, 4]找到4,从[5, 6, 7, 8]找到8,从[9, 10, 11, 12]找到12。但是从上面结果来看,numpy运算却给了我们直观上认为是列最大值的结果[9, 10, 11, 12]。

  实际numpy(pytorch)运算应该理解为往给定的维度进行移动运算。还是以维度0为例,维度0上有3个向量,分别为[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往维度0移动,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素计算最大值,得到[5, 6, 7, 8],再和[9, 10, 11, 12]运算得到结果[9, 10, 11, 12]。

  另外,对于维度为3的数组,在numpy和pytorch中,应该把维度0理解为通道数,维度1和维度2才是对应高和宽。如果是3维数组对应着用于多输入通道和单输出通道的卷积核(维度为U x V x D),那么4维数组就对应着用于多输入通道和多输出通道的卷积核(维度为U x V x D x P),此时,维度0则为多通道卷积核数量的方向,维度1为通道数,维度2和3才是分别对应高和宽。

3. gather函数

pytorch和numpy中许多函数都涉及维度运算,gather也不例外,但是它相对于其他函数更难理解。依然先来看一个例子

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15]]
""" # 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]) # axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
""" # axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起来可能有点复杂,我们来一步步的分析它,先从gather维度为0开始讲起。

  1. a.gather(0, b)分为3个部分,a是需要被提取元素的矩阵,0代表的是提取的维度为0,b是提取元素的索引

    • 其中规定b和a是同维张量,即a是2维张量,b也必须是2维张量
  2. 0除了代表往维度0的方向提取元素外,还有一个特权---提取结果output可以在这个维度上的长度与a不同。打个比方,a现在的shape为(5, 3),那么提取结果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。
  3. 根据0的特权,导致了给定的b张量除了维度0外,其他的维度大小必须和a一样。其中张量b实际上包含以下两个信息
    • b可以利用除用于gather的维度(此处为维度0)外的维度来定位出唯一一个向量,也就是a[:, ?](三维度也是同理的,有a[:, ?1, ?2]),?的取值范围为a同维度的index。
    • 对于上述定位出的向量,通过b中的元素来定位提取向量中的哪一个元素。
    • 上面说得可能有点抽象,实际上b中的每个元素都能在a中提取出一个元素。举个具体点的例子,按照上面所说的,b[0, 0]可以提取a中的一个元素。对于b[0,0],除了维度0外,可以通过维度1来定位出唯一一个向量a[:, 0]。因为b[0, 0]的元素为0,即提取的是a[:, 0]的第0个元素---1,并将其作为output1[0, 0]的提取结果。

      下图给出了维度0和维度1,gather运算的图示

对于3维或者更高维度的张量gather的原理也是一样的

4. index_select函数

其他的高级选择函数都比较容易理解,这里简单的提一下。torch.index_select主要是根据传入的tensor来往给定的axis方向来选取张量

import torch
a = torch.arange(9).reshape(3, 3)
torch.index_select(a, 0, torch.tensor([0, 2]))
"""
result:
[[0, 1, 2],
[6, 7, 8]]
"""

5. masked_select函数

实际上就是通过掩码条件来选择元素,像torch.masked_select(x, x>0.5),实际上是和x[x>0.5]等价的,最后返回的是一维张量

import torch
a = torch.rand(5, 3) # 结果和a[a > 0.5]等价
torch.masked_select(a, a>0.5)

6. nonzero函数

找到非零元素的index

import torch
a = torch.eye(3)
torch.nonzero(a) """
result: 对应着非零元素的index
[[0, 0],
[1, 1],
[2, 2]]
"""

理解pytorch几个高级选择函数(如gather)的更多相关文章

  1. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  2. 关于Pytorch的二维tensor的gather和scatter_操作用法分析

    看得不明不白(我在下一篇中写了如何理解gather的用法) gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下: out[i][j] = input[index[i][j]] ...

  3. 理解PyTorch的自动微分机制

    参考Getting Started with PyTorch Part 1: Understanding how Automatic Differentiation works 非常好的文章,讲解的非 ...

  4. 理解pytorch中的softmax中的dim参数

    import torch import torch.nn.functional as F x1= torch.Tensor( [ [1,2,3,4],[1,3,4,5],[3,4,5,6]]) y11 ...

  5. [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)

    一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...

  6. 《深入理解Java虚拟机:JVM高级特性与最佳实践》【PDF】下载

    <深入理解Java虚拟机:JVM高级特性与最佳实践>[PDF]下载链接: https://u253469.pipipan.com/fs/253469-230062566 内容简介 作为一位 ...

  7. 什么是pytorch(1开始)(翻译)

    Deep Learning with PyTorch: A 60 Minute Blitz 作者: Soumith Chintala 部分翻译:me 本内容包含: 在高级层面理解pytorch的ten ...

  8. 万字综述,核心开发者全面解读PyTorch内部机制

    斯坦福大学博士生与 Facebook 人工智能研究所研究工程师 Edward Z. Yang 是 PyTorch 开源项目的核心开发者之一.他在 5 月 14 日的 PyTorch 纽约聚会上做了一个 ...

  9. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

随机推荐

  1. 状态压缩动态规划(状压DP)详解

    0 引子 不要999,也不要888,只要288,只要288,状压DP带回家.你买不了上当,买不了欺骗.它可以当搜索,也可以卡常数,还可以装B,方式多样,随心搭配,自由多变,一定符合你的口味! 在计算机 ...

  2. laravel中elastisearch安装和测试运行是否成功(注意是windows下的操作)

    1.去elasticsearch官网下载,如果太慢可以在我上一个随笔看下载地址 2.下载完解压缩,在cmd中找到到elasticsearch的bin目录下执行.\elasticsearch.bat - ...

  3. 解决 SQL 注入和 XSS 攻击(Node.js 项目中)

    1.SQL 注入 SQL 注入,一般是通过把 SQL 命令插入到 Web 表单提交或输入域名或页面请求的查询字符串,最终达到欺骗服务器执行恶意的 SQL 命令. SQL 注入示例 在登录界面,后端会根 ...

  4. manualresetevent的用法学习

    ManualResetEvent 允许线程通过发信号互相通信. 通常,此通信涉及一个线程在其他线程进行之前必须完成的任务. 当一个线程开始一个活动(此活动必须完成后,其他线程才能开始)时,它调用 Re ...

  5. Unity2018.4.7导出Xcode工程报错解决方案

    1. unity导出xcode工程有两种模式,一种为模拟器运行的工程,一种为真机运行的工程,这里遇到的错误,都是导出模拟器运行工程时报的错误. 错误1: unity UnityMetalSupport ...

  6. Codeforces 1389 题解(A-E)

    AC代码 A. LCM Problem 若\(a < b\),则\(LCM(a,b)\)是\(a\)的整数倍且\(LCM(a,b) \ne a\),所以\(LCM(a,b) \ge 2a\),当 ...

  7. 跟着兄弟连系统学习Linux-【day08】

    day08-20200605 p27.软件包管理简 windows 和 linux 软件是不同的版本. Linux源码包,开源的.绝大部分都是C语言写的.源码包安装速度比较慢.需要先编译后再安装.脚本 ...

  8. jQuery提供的Ajax方法

    jQuery提供了4个ajax方法:$.get()  $.post()  $.ajax()  $.getJSON() 1.$.get() $.get(var1,var2,var3,var4): 参数1 ...

  9. 快速启动CMD窗口的办法

    在 文件管理器的 地址栏输入cmd回车,cmd会快速在此路径下打开. --END-- 2020-01-07

  10. C Primer Plus 学习笔记

    随笔: 1)C语言中%3d%2d什么意思? 格式化规定字符, 以"%"开始, 后跟一个或几个规定字符,用来确定输出内容格式.在"%"和字母之间插进数字表示最大场 ...