『PyTorch』矩阵乘法总结
1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
,其中mat1
(\(n\times m\)),mat2
(\(m\times d\)),输出out
的维度是(\(n\times d\))。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
(\(b\times n \times m\)),bmat2
(\(b\times m \times d\)),输出out
的维度是(\(b \times n \times d\))。
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
3. 多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)
支持broadcast操作,使用起来比较复杂。
针对多维数据 matmul()
乘法,我们可以认为该matmul()
乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input
(\(1000 \times 500 \times 99 \times 11\)), other
(\(500 \times 11 \times 99\))那么我们可以认为torch.matmul(input, other, out=None)
乘法首先是进行后两位矩阵乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析两个参数的batch size分别是 \(( 1000 \times 500)\) 和 \(500\) , 可以广播成为 \((1000 \times 500)\), 因此最终输出的维度是(\(1000 \times 500 \times 99 \times 99\))。
4. 矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
,其中other
乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可
5. 两个运算符 @ 和 *
@
:矩阵乘法,自动执行适合的矩阵乘法函数*
:element-wise乘法
『PyTorch』矩阵乘法总结的更多相关文章
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor
Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
随机推荐
- C#串口通信SeriPort 电表DLT645 RS234/RS485
难受,三个多月前有一个电表电量监控的项目.做完了就没再管了.今天有需求需要改一些地方,但是....我想不起来干了啥,怎么干的啦.真的完全忘了.....项目名称叫啥都忘了.找了半天 不知道有没有和我一样 ...
- 【设计模式】java设计模式目录
1.创建型模式 JDK1.5枚举Singleton 单例模式 AbstractFactory 工厂方法模式 简单工厂模式 Builder Prototype 2.结构型 java设计模式 ...
- 03.SpringMVC之器
整体结构介绍 在Servlet的继承结构中一共有5个类,GenericServlet和HttpServlet在java中剩下的三个类HttpServletBean.FrameworkServlet和D ...
- BootstrapTable插件的使用 【转】
一.什么是Bootstrap-table? 在业务系统开发中,对表格记录的查询.分页.排序等处理是非常常见的,在Web开发中,可以采用很多功能强大的插件来满足要求,且能极大的提高开发效率,本随笔介绍这 ...
- Eclipse插件 -- 阿里巴巴扫描编码规插件
一.github地址: https://github.com/alibaba/p3c 二..eclipse插件的安装 此处示例采用eclipse,版本为 Neon.1 Release RC3 (4.6 ...
- JDBC基础篇(MYSQL)——通过JDBC连接数据库的三种方式
package day01_jdbc; import java.sql.Connection; import java.sql.Driver; import java.sql.DriverManage ...
- vue中的v-cloak指令
v-cloak不需要表达式,它会在vue实例结束编译时从绑定的html元素上移除,经常和display:none;配合使用: <div id="app" v-cloak> ...
- C# - 音乐小闹钟_BetaV2.0
时间:2017-11-21 作者:byzqy 介绍: 虽然上一版本基本实现了闹钟的功能,但是界面.功能.用户体验(简直谈不上体验^_^),以及众多的bug,所以升级,刻不容缓! 还是先看一下Beta ...
- VSCode中相对路径设置问题
使用的版本 对于import xxx操作,相对路径为sys.path 对于open("test.txt",'r')文件打开操作,相对路径为os.getcwd() 对于termina ...
- 一次PHP大马提权
记一次PHP提权 发现 PHP大马:指木马病毒:PHP大马,就是PHP写的提取站点权限的程序:因为带有提权或者修改站点功能,所以称为叫木马. 自从师哥那里听说过之后,一直感叹于PHP大马的神奇...但 ...