scatter()scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会

PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改

scatter(dim, index, src) 的参数有 3 个

  • dim:沿着哪个维度进行索引
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素,可以是一个标量或一个张量

这个 scatter  可以理解成放置元素或者修改元素

简单说就是通过一个张量 src  来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定

官方文档给出了 3维张量 的具体操作说明,如下所示

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

exmaple:

x = torch.rand(2, 5)

#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]]) torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

具体地说,我们的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),一个二维张量,下面用图简单说明

我们是 2维 张量,一开始进行 $self[index[0][0]][0]$,其中 $index[0][0]$ 的值是0,所以执行 $self[0][0] = x[0][0] = 0.1940$

$self[index[i][j]][j] = src[i][j] $

再比如$self[index[1][0]][0]$,其中 $index[1][0]$ 的值是2,所以执行 $self[2][0] = x[1][0] = 0.2078$

src 除了可以是张量外,也可以是一个标量

example:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)

#tensor([[7., 7., 7., 7., 7.],
# [0., 7., 0., 7., 0.],
# [7., 0., 7., 0., 7.]]

scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子

example:

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

PyTorch笔记之 scatter() 函数的更多相关文章

  1. [Pytorch] pytorch笔记 <三>

    pytorch笔记 optimizer.zero_grad() 将梯度变为0,用于每个batch最开始,因为梯度在不同batch之间不是累加的,所以必须在每个batch开始的时候初始化累计梯度,重置为 ...

  2. [Pytorch] pytorch笔记 <二>

    pytorch笔记2 用到的关于plt的总结 plt.scatter scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, ...

  3. [Pytorch] pytorch笔记 <一>

    pytorch笔记 - torchvision.utils.make_grid torchvision.utils.make_grid torchvision.utils.make_grid(tens ...

  4. matplotlib 知识点13:绘制散点图(scatter函数精讲)

    散点图是指在回归分析中,数据点在直角坐标系平面上的分布图,散点图表示因变量随自变量而变化的大致趋势,据此可以选择合适的函数对数据点进行拟合. 用两组数据构成多个坐标点,考察坐标点的分布,判断两变量之间 ...

  5. IOS学习笔记07---C语言函数-printf函数

    IOS学习笔记07---C语言函数-printf函数 0 7.C语言5-printf函数 ------------------------- ----------------------------- ...

  6. IOS学习笔记06---C语言函数

    IOS学习笔记06---C语言函数 --------------------------------------------  qq交流群:创梦技术交流群:251572072              ...

  7. Typescript 学习笔记三:函数

    中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...

  8. 交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数

    分类问题中,交叉熵函数是比较常用也是比较基础的损失函数,原来就是了解,但一直搞不懂他是怎么来的?为什么交叉熵能够表征真实样本标签和预测概率之间的差值?趁着这次学习把这些概念系统学习了一下. 首先说起交 ...

  9. ES6学习笔记<三> 生成器函数与yield

    为什么要把这个内容拿出来单独做一篇学习笔记? 生成器函数比较重要,相对不是很容易理解,单独做一篇笔记详细聊一聊生成器函数. 标题为什么是生成器函数与yield? 生成器函数类似其他服务器端语音中的接口 ...

随机推荐

  1. 常用sql---表记录数和占用空间统计

    1.每张表的记录数和占用空间 select owner as 用户名, table_name as 表名, num_rows as 记录数, ROUND(t.NUM_ROWS * t.AVG_ROW_ ...

  2. 【学习】019 SpringBoot

    一.SpringBoot介绍 1.1.SpringBoot简介 在您第1次接触和学习Spring框架的时候,是否因为其繁杂的配置而退却了?在你第n次使用Spring框架的时候,是否觉得一堆反复黏贴的配 ...

  3. DevOps书单:调研了101名专家,推荐这39本必读书籍

    任何一个领域都遵循从新人到熟手,从熟手到专家的路径.在成长过程中,DevOps人经常会陷入没人带,没人管,找不到职业方向的迷茫. DevOps是在商业演进与企业协作的进化过程中诞生的一个全新职业,被很 ...

  4. kettle imestamp : Unable to get timestamp from resultset at index 22

    在做ETL的时候,连接MySQL读取含有timestamp类型的表,出现如下错误: 经Google,据说是MySQL自身的问题.解决方法也很简单,在Spoon的数据库连接中,打开选项,加入一行命令参数 ...

  5. LeetCode--022--括号生成(python)

    给出 n 代表生成括号的对数,请你写出一个函数,使其能够生成所有可能的并且有效的括号组合. 例如,给出 n = 3,生成结果为: class Solution: def generateParenth ...

  6. 02-webpack的作用

    webpack的作用,将不同静态资源的类型打包成一个JS文件,在html页面应用该JS文件的时候,JS文件里的html就可以正常的运行,去执行操作. 也可以加载前端页面的CSS样式.Img图片

  7. github上.md的编写

    # algs4 一:大标题 =========== 二:中标题 ------------ 三:1~6级标题 # 一级标题 ## 二级标题 ### 三级标题 #### 四级标题 ##### 五级标题 # ...

  8. 【PowerOJ1746&网络流24题】航空路线问题(费用流)

    题意: 思路: [问题分析] 求最长两条不相交路径,用最大费用最大流解决. [建模方法] 把第i个城市拆分成两个顶点<i.a>,<i.b>. 1.对于每个城市i,连接(< ...

  9. qbzt day3 晚上 平衡树的一些思想

    pks大佬的blog 二叉查找树 任何一个节点左子树的所有元素都小于这个节点,右子树的所有元素都大于这个节点 查找一个节点:从根节点开始,比他小就向左走,比他大就向右走 平衡树:解决二叉查找树的一些痛 ...

  10. (转)sql server 事务与try catch

    本文转载自:http://www.cnblogs.com/sky_Great/archive/2013/01/09/2852417.html sql普通事务 begin transaction tr ...