PyTorch笔记之 scatter() 函数
scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会
PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改
scatter(dim, index, src) 的参数有 3 个
- dim:沿着哪个维度进行索引
- index:用来 scatter 的元素索引
- src:用来 scatter 的源元素,可以是一个标量或一个张量
这个 scatter 可以理解成放置元素或者修改元素
简单说就是通过一个张量 src 来修改另一个张量,哪个元素需要修改、用 src 中的哪个元素来修改由 dim 和 index 决定
官方文档给出了 3维张量 的具体操作说明,如下所示
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
exmaple:
x = torch.rand(2, 5) #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
# [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]]) torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
# [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
# [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])
具体地说,我们的 index 是 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]),一个二维张量,下面用图简单说明
我们是 2维 张量,一开始进行 $self[index[0][0]][0]$,其中 $index[0][0]$ 的值是0,所以执行 $self[0][0] = x[0][0] = 0.1940$
$self[index[i][j]][j] = src[i][j] $
再比如$self[index[1][0]][0]$,其中 $index[1][0]$ 的值是2,所以执行 $self[2][0] = x[1][0] = 0.2078$
src 除了可以是张量外,也可以是一个标量
example:
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7) #tensor([[7., 7., 7., 7., 7.],
# [0., 7., 0., 7., 0.],
# [7., 0., 7., 0., 7.]]
scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子
example:
class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
PyTorch笔记之 scatter() 函数的更多相关文章
- [Pytorch] pytorch笔记 <三>
pytorch笔记 optimizer.zero_grad() 将梯度变为0,用于每个batch最开始,因为梯度在不同batch之间不是累加的,所以必须在每个batch开始的时候初始化累计梯度,重置为 ...
- [Pytorch] pytorch笔记 <二>
pytorch笔记2 用到的关于plt的总结 plt.scatter scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, ...
- [Pytorch] pytorch笔记 <一>
pytorch笔记 - torchvision.utils.make_grid torchvision.utils.make_grid torchvision.utils.make_grid(tens ...
- matplotlib 知识点13:绘制散点图(scatter函数精讲)
散点图是指在回归分析中,数据点在直角坐标系平面上的分布图,散点图表示因变量随自变量而变化的大致趋势,据此可以选择合适的函数对数据点进行拟合. 用两组数据构成多个坐标点,考察坐标点的分布,判断两变量之间 ...
- IOS学习笔记07---C语言函数-printf函数
IOS学习笔记07---C语言函数-printf函数 0 7.C语言5-printf函数 ------------------------- ----------------------------- ...
- IOS学习笔记06---C语言函数
IOS学习笔记06---C语言函数 -------------------------------------------- qq交流群:创梦技术交流群:251572072 ...
- Typescript 学习笔记三:函数
中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...
- 交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数
分类问题中,交叉熵函数是比较常用也是比较基础的损失函数,原来就是了解,但一直搞不懂他是怎么来的?为什么交叉熵能够表征真实样本标签和预测概率之间的差值?趁着这次学习把这些概念系统学习了一下. 首先说起交 ...
- ES6学习笔记<三> 生成器函数与yield
为什么要把这个内容拿出来单独做一篇学习笔记? 生成器函数比较重要,相对不是很容易理解,单独做一篇笔记详细聊一聊生成器函数. 标题为什么是生成器函数与yield? 生成器函数类似其他服务器端语音中的接口 ...
随机推荐
- 理解长短期记忆网络(LSTM NetWorks)
转自:http://www.csdn.net/article/2015-11-25/2826323 原文链接:Understanding LSTM Networks(译者/刘翔宇 审校/赵屹华 责编/ ...
- Qt的QSettings类和.ini文件读写
Detailed Description QSettings类提供了持久的跨平台的应用程序设置.用户通常期望应用程序记住它的设置(窗口大小.位置等)所有会话.这些信息通常存储在Windows系统注册表 ...
- JavaWeb中的文件上传和下载功能的实现
导入相关支持jar包:commons-fileupload.jar,commons-io.jar 对于文件上传,浏览器在上传的过程中是将文件以流的形式提交到服务器端的,如果直接使用Servlet获取上 ...
- 双层for循环用java中的stream流来实现
//双重for循环for (int i = 0; i < fusRecomConfigDOList.size(); i++) { for (int j = 0; j < fusRecomC ...
- python如何调用c编译好可执行程序
python如何调用c编译好可执行程序 以下总结出几种在Python 中调用 C/C++ 代码的方法 ------------------------------------------- ...
- poj 3352 : Road Construction 【ebcc】
题目链接 题意:给出一个连通图,求最少加入多少条边可使图变成一个 边-双连通分量 模板题,熟悉一下边连通分量的定义.最后ans=(leaf+1)/2.leaf为原图中size为1的边-双连通分量 #i ...
- LeetCode--008--字符串转换整数 (atoi)(python)
示例 1: 输入: "42"输出: 42示例 2: 输入: " -42"输出: -42解释: 第一个非空白字符为 '-', 它是一个负号. 我们尽可能将负号与 ...
- 解决Intellij IDEA中项目不能识别yml配置文件
问题:能读取资源路径下的properties配置文件但是不能读yml配置文件 因为无法读取配置yml配置文件,所以不能配置bean,导致项目启动报错. 解决方法: 在VM options中设置虚拟机加 ...
- 简单说说JavaBean的使用
一:JavaBean定义 JavaBean是一种可重复使用.跨平台的软件组件.JavaBean可分为两种:一种是有用户界面(UI,User Interface)的JavaBean,例如中的那些可视化图 ...
- Linux内核设计与实现 总结笔记(第十五章)进程地址空间
一.地址空间 进程地址空间由进程可寻址的虚拟内存组成,内核允许进程使用这种虚拟内存中的地址. 每个进程都有一个32位或64位的平坦地址空间,空间的具体大小取决于体系结构.“平坦”指的是地址空间范围是一 ...