看得不明不白(我在下一篇中写了如何理解gather的用法)

gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:

out[i][j] = input[index[i][j]][j]  # dim=0
out[i][j] = input[i][index[i][j]] # dim=1

二维tensor的gather操作

针对0轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]])
print("index = \n", index) #index是2维
print("index的形状: ",index.shape) #index形状是(1,4)

输出

index =
tensor([[0, 1, 2, 3]])
index的形状: torch.Size([1, 4])

分割线============

针对1轴

注意index此时的值

输入

index = t.LongTensor([[0,1,2,3]]).t()  #index是2维
print("index = \n", index) #index形状是(4,1)
print("index的形状: ",index.shape)

输出

index =
tensor([[0],
[1],
[2],
[3]])
index的形状: torch.Size([4, 1])

分割线===========

再来看看几个例子

注意index在以0轴和1轴为标准时的表达式是不一样的。

b.gather()中取0维时,输出的结果是行形式,取1维时,输出的结果是列形式。

  • b是一个 $ 3\times4 $ 型的
>>> import torch as t
>>> b = t.arange(0,12).view(3,4)
>>> b
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> index = t.LongTensor([[0,1,2]]) >>> index
tensor([[0, 1, 2]]) >>> b.gather(0,index) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index2 = t.LongTensor([[0,1,2]]).t() >>> b.gather(1,index2) #运行成功了
tensor([[ 0],
[ 5],
[10]]) >>> index3 = t.LongTensor([[0,1,2,3]]).t() >>> b.gather(1,index3) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620
  • b是一个 $ 6\times6 $ 型的
>>> import torch as t
>>> b = t.arange(0,36).view(6,6)
>>> b
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]]) >>> index = t.LongTensor([[0,1,2,3,4,5,6]])
>>> b.gather(0,index) #运行失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index = t.LongTensor([[0,1,2,3,4,5]])
>>> b.gather(0,index) #运行成功了
tensor([[ 0, 7, 14, 21, 28, 35]])
>>> b.gather(1,index)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()
>>> b.gather(1,index2) #运行成功了
tensor([[ 0],
[ 7],
[14],
[21],
[28],
[35]]) >>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
>>> b.gather(1,index3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620 >>> index4 = t.LongTensor([[0,1,2,3,4]])
>>> b.gather(0,index4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:\new-builder_3\win-wheel\pytorch\aten\src\th\generic/THTensorMath.cpp:620

与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。





与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。

out = input.gather(dim, index)
-->近似逆操作
out = Tensor()
out.scatter_(dim, index)

根据StackOverflow上的问题修改代码如下:

输入

# 把两个对角线元素放回去到指定位置
c = t.zeros(4,4)
c.scatter_(1, index, b.float())

输出

tensor([[ 0.,  0.,  0.,  3.],
[ 0., 5., 6., 0.],
[ 0., 9., 10., 0.],
[12., 0., 0., 15.]])

关于Pytorch的二维tensor的gather和scatter_操作用法分析的更多相关文章

  1. 2017.11.17 C++系列---用malloc动态给c++二维数组的申请与释放操作

    方法一:利用二级指针申请一个二维数组. #include<stdio.h> #include<stdlib.h> int main() { int **a; //用二级指针动态 ...

  2. Pytorch学习笔记(二)——Tensor

    一.对Tensor的操作 从接口的角度讲,对Tensor的操作可以分为两类: (1)torch.function (2)tensor.function 比如torch.sum(a, b)实际上和a.s ...

  3. Javascript生成二维码(QR)

    网络上已经有非常多的二维码编码和解码工具和代码,很多都是服务器端的,也就是说需要一台服务器才能提供二维码的生成.本着对服务器性能的考虑,这种小事情都让服务器去做,感觉对不住服务器,尤其是对于大流量的网 ...

  4. 二维指针*(void **)的研究(uC/OS-II案例) 《转载》

    uC/OS-II内存管理函数内最难理解的部分就是二维指针,本文以图文并茂的方式对二维指针进行了详细分析与讲解.看完本文,相信对C里面指针的概念又会有进一步的认识. 一.OSMemCreate( ) 函 ...

  5. Stars(二维树状数组)

    Stars Time Limit: 5000/2000 MS (Java/Others) Memory Limit: 32768/65536 K (Java/Others) Total Submiss ...

  6. iOS原生实现二维码拉近放大

    http://www.cocoachina.com/ios/20180416/23033.html 2018-04-16 15:34 编辑: yyuuzhu 分类:iOS开发 来源:程序鹅 8 300 ...

  7. Java打印M图形(二维数组)——(九)

    对于平面图形输出集合图形与数字组合的,用二维数组.先在Excel表格中分析一下,找到简单的规律.二维数组的行数为行高,列数为最后一个数大小. 对于减小再增大再减小再增大的,可以用一个boolean标志 ...

  8. 混沌数学之二维logistic模型

    上一节讲了logistic混沌模型,这一节对其扩充一下讲二维 Logistic映射.它起着从一维到高维的衔接作用,对二维映射中混沌现象的研究有助于认识和预测更复杂的高维动力系统的性态.通过构造一次藕合 ...

  9. UVa 11297 Census (二维线段树)

    题意:给定上一个二维矩阵,有两种操作 第一种是修改 c x y val 把(x, y) 改成 val 第二种是查询 q x1 y1 x2 y2 查询这个矩形内的最大值和最小值. 析:二维线段树裸板. ...

随机推荐

  1. apache2+svn Cannot load modules/mod_dav_svn.so into server: \xd5\xd2\xb2\xbb\xb5\xbd\xd6\xb8\xb6\xa8\xb5\xc4\xc4\xa3\xbf\xe9\xa1\xa3

    按照svn里的readme文件安装配置apache2与svn后, 启动apache2服务的时候 出现下面的问题 Cannot load C:/Program Files/Apache Software ...

  2. Hadoop科普文—常见的45个问题解答 &#183; Hadoop

    个模式 · 单机(本地)模式 · 伪分布式模式 · 全分布式模式 2.  单机(本地)模式中的注意点? 在单机模式(standalone)中不会存在守护进程,全部东西都执行在一个JVM上. 这里相同没 ...

  3. ios -逆向-代码混淆

    该方法只能针对有.m.h的类进行混淆,静态库等只有.h文件的没法进行混淆 代码混淆,刚刚看到是不是有点懵逼,反正我是最近才接触到这么个东西,因为之前对于代码和APP,只需要实现功能就好了,根本没有考虑 ...

  4. ios -WKWebView 高度 准确,留有空白的解决方案

    #import "ViewController.h" #import <WebKit/WebKit.h> @interface ViewController ()< ...

  5. ThinkPHP部分内置函数

    D.F.S.C.L.A.I 他们都在functions.php这个文件家下面我分别说明一下他们的功能 D() 加载Model类M() 加载Model类 A() 加载Action类L() 获取语言定义C ...

  6. C语言结构体数组内带字符数组初始化和赋值

    1.首先定义结构体数组: typedef struct BleAndTspRmtCmd{ char terminal[3]; char note[3]; char rmtCmd[10]; char c ...

  7. AngularCli项目中添加字体图标(Font)详解

    本文主要讲如何在AngularCli生成的项目中使用字体图标. 一 SVG图标准备 将需要转换为字体图标的图片转换为SVG格式. 这个让项目视觉设计师搞定即可. 二 SVG图标转Font 可以通过Ic ...

  8. 数据库之MySQL(一)

    概述 1.什么是数据库 ?   数据的仓库,如:在ATM的示例中我们创建了一个 db 目录,称其为数据库 2.什么是 MySQL.Oracle.SQLite.Access.MS SQL Server等 ...

  9. xlwt 模块 操作excel

    1.xlwt 基本用法 import xlwt #1 新建文件 new_file = open('test.xls', 'w') new_file.close() #2 创建工作簿 wookbook ...

  10. 测试开发面试的Linux面试题:常用命令

    Hello,大家好上次给大家介绍了vim使用方法,今天来给大家讲一讲linux系统文件命令 (1)Linux的文件系统目录配置要遵循FHS规范,规范定义的两级目录规范如下:        /home  ...