『PyTorch』矩阵乘法总结
1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
,其中mat1
(\(n\times m\)),mat2
(\(m\times d\)),输出out
的维度是(\(n\times d\))。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
(\(b\times n \times m\)),bmat2
(\(b\times m \times d\)),输出out
的维度是(\(b \times n \times d\))。
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
3. 多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)
支持broadcast操作,使用起来比较复杂。
针对多维数据 matmul()
乘法,我们可以认为该matmul()
乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input
(\(1000 \times 500 \times 99 \times 11\)), other
(\(500 \times 11 \times 99\))那么我们可以认为torch.matmul(input, other, out=None)
乘法首先是进行后两位矩阵乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析两个参数的batch size分别是 \(( 1000 \times 500)\) 和 \(500\) , 可以广播成为 \((1000 \times 500)\), 因此最终输出的维度是(\(1000 \times 500 \times 99 \times 99\))。
4. 矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
,其中other
乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可
5. 两个运算符 @ 和 *
@
:矩阵乘法,自动执行适合的矩阵乘法函数*
:element-wise乘法
『PyTorch』矩阵乘法总结的更多相关文章
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor
Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
随机推荐
- C#简单实现表达式目录树(Expression)
1.什么是表达式目录树 :简单的说是一种语法树,或者说是一种数据结构(Expression) 2.用Lambda声明表达式目录树: 1 2 3 4 5 Expression<Func<in ...
- protected访问权限
Java中protected方法访问权限的问题 protected 修饰的成员变量或方法,只能在同包或子类可访问; package 1 public class TestPackage { prote ...
- Go: LeetCode简单题,简单做(sort.Search)
前言 正值端午佳节,LeetCode也很懂.这两天都是简单题,早点做完去包粽子. 故事从一道简单题说起 第一个错误的版本 简单题 你是产品经理,目前正在带领一个团队开发新的产品.不幸的是,你的产品的最 ...
- Git修改历史commit的author信息
前言 "嘀嗒嘀嗒",抬头看向墙上的钟表,此时已是凌晨1点.小明终于把Go语言圣经第二章的笔记写完,保存commit,提交,然后睡觉. 额,等等,不对,小明发现他用的是公司的git账 ...
- nacos在nginx下集群以及数据库问题
持久化mysql时指定数据库编辑application.properties spring.datasource.platform=mysql db.num=1 db.url.0=jdbc:mysql ...
- jdbc操作mysql(四):利用反射封装
前言 有了前面利用注解拼接sql语句,下面来看一下利用反射获取类的属性和方法 不过好像有一个问题,数据库中的表名和字段中带有下划线该如何解决呢 实践操作 工具类:获取connection对象 publ ...
- git所遇到的问题
出现这种情况,或 ERROR: Repository not found. fatal: 无法读取远程仓库. 解决办法如下: 1.先输入$ git remote rm origin(删除关联的orig ...
- 树莓派4B切换国内源-亲测有效
参考:https://blog.csdn.net/qq_30290661/article/details/103386997 修改/etc/apt/sources.list,去掉自带的源,添加如下源: ...
- Servlet学习笔记(二)之Servlet路径映射配置、Servlet接口、ServletConfig、ServletContext
Servlet路径映射配置 要使Servlet对象正常的运行,需要进行适当的配置,以告诉Web容器哪个请求调用哪个Servlet对象处理,对Servlet起到一个注册的作用.Servlet的配置信息包 ...
- nios eclipse提示LED_PIO_BASE没有声明,怎么回事?
这是因为名字不一致引起的比如,在生成SOPC系统时,双击PIO(Parallel I/O)(在Avalon Modules -> Other 下),为系统添加输出接口,你没有把该组件改名成LED ...