einsum函数说明

pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。

# 矩阵乘法例子引入
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

两个基本概念,自由索引/自由标(Free indices)和求和索引/哑标(Summation indices):

  • 自由索引,出现在箭头右边的索引
  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,

接着是介绍三条基本规则:

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

两条特殊规则:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,详见下方转置例子

单操作数

获取对角线元素diagonal

einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。

\[A_i = B_{ii}
\]

上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i

torch.einsum('ii->i', torch.randn(4, 4))

# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)

迹trace

求解矩阵的迹(trace),即对角线元素的和。

\[t = \Sigma_{i=1}^{n} A_{ii}
\]

t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->或省略箭头ii

torch.einsum('ii', torch.randn(4, 4))

矩阵转置

\[A_{ij} = B_{ji}
\]

A 和 B 都是二维方阵。einsum 可以表达为 ij->ji

torch.einsum('ij -> ji',a)

pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。

A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4]) # 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)

求和

\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)

列求和:

\[b_{j}=\sum_{i} A_{i j}=A_{i j}
\]
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.]) # 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.

双操作数

矩阵乘法

\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj}
\]

第一个学习的 einsum 表达式是,ik,kj->ij。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij 了。

torch.einsum('ik,kj->ij', a, b) 

# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b

矩阵-向量相乘

\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b]) tensor([ 5., 14.])

批量矩阵乘 batch matrix multiplication

\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk}
\]
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]]) # 等价操作
torch.bmm(As, Bs)

向量内积 dot

\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i}
\]
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.) # 等价操作
torch.dot(a, b)

矩阵内积 dot

\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)

哈达玛积

\[C_{i j}=A_{i j} B_{i j}
\]
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])

外积 outer

\[C_{i j}=a_{i} b_{j}
\]
a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b]) tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])

einsum其他规则和例子判断:

  • 输入中多次出现的字符,将被用作求和。例子,kj,ji 完整的表达式是 kj,ji->ik,矩阵乘法再相乘。
  • 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,ab->aa 是非法的,ab->c 是非法的,ab->a 是合法的。
  • 省略符 ... 是用来跳过部分维度。例子,...ij,...jk 表示 batch 矩阵乘法。
  • 在输出没有指定的情况下,省略符优先级高于普通字符。例子,b...a 完整的表达式是 b...a->...ab,可以将一个形状为 (a,b,c) 的矩阵变为形状为 (b,c,a) 的矩阵。
  • 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,i,i,i 表示将三个一维向量按位相乘,并相加。
  • 除了箭头,其他任何地方都可以加空格。例子,i j , j k -> ik 是合法的,ij,jk - > ik 是非法的。
  • 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为 (4,3,3) 的矩阵,表达式 ab->a 是非法的,abc-> 是合法的。

实际使用

实现multi headed attention

https://nn.labml.ai/transformers/mha.html

如何优雅地实现多头自注意力

计算注意力score:

\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d}
\]
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh

计算attention输出:

\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V
\]
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k] x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]

参考文献:

https://zhuanlan.zhihu.com/p/361209187

如何优雅地实现多头自注意力

https://rockt.github.io/2018/04/30/einsum **

einsum函数介绍-张量常用操作的更多相关文章

  1. Git介绍及常用操作演示(一)--技术流ken

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  2. CI 知识 :Git介绍及常用操作

    Git介绍 Git(读音为/gɪt/.)是一个开源的分布式版本控制系统,可以有效.高速的处理从很小到非常大的项目版本管理. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发 ...

  3. python 文件操作: 文件操作的函数, 模式及常用操作.

    1.文件操作的函数: open("文件名(路径)", mode = '模式', encoding = "字符集") 2.模式: r , w , a , r+ , ...

  4. JavaScript基础DOM介绍和常用操作(5)

    day53 参考:https://www.cnblogs.com/liwenzhou/p/8011504.html JavaScript引入方式 location对象 window.location ...

  5. 简单的git入门介绍及常用操作

    集中式版本控制系统采用中央服务器上存储的所有文件和实现团队协作.但是CVCS主要缺点是中央服务器的单点故障,即故障.不幸的是,如果中央服务器宕机一小时,然后在该时段没有人可以合作.即使在最坏的情况下, ...

  6. Docker介绍及常用操作演示(一)--技术流ken

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  7. Docker介绍及常用操作演示(一)

    Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的Linux机器上,也可以实现虚拟化.容器是完全使用沙箱机制,相互 ...

  8. Docker常用命令汇总,和常用操作举例

    Docker命令 docker 常用命令如下 管理命令: container 管理容器 image 管理镜像 network 管理网络 node 管理Swarm节点 plugin 管理插件 secre ...

  9. go语言之进阶篇字符串操作常用函数介绍

    下面这些函数来自于strings包,这里介绍一些我平常经常用到的函数,更详细的请参考官方的文档. 一.字符串操作常用函数介绍 1.Contains func Contains(s, substr st ...

随机推荐

  1. 报错需要选择一个空目录,或者选择的非空目录下存在 app.json 或者 project.config.json解决方案

    前言 小程序的第一个坑就是,创建了一个小程序项目,却在微信web开发者工具无法打开... 报了个错:需要选择一个空目录,或者选择的非空目录下存在 app.json 或者 project.config. ...

  2. 前端面试题整理——手写简易jquery

    class jQuery { constructor(selector) { const result = document.querySelectorAll(selector) console.lo ...

  3. vue和react给我的感受

    以下纯属个人使用两个框架的感想和体会: 不知道你们是否有这种感觉~ 我vue和react都用过一段时间,但是vue给我感觉就是经常会忘记语法,需要对照文档才知道怎么写( 难不成是我没喝六个核桃的原因吗 ...

  4. 关于页面中css某些情况下出现不知原因的隔断解决办法

    第一种方法:body{margin:0px;padding:0px position:absolute; top:0px;left:0px;} html{ width:100%; overflow-x ...

  5. Hyperledger Fabric无系统通道启动及通道的创建和删除

    前言 在Hyperledger Fabric组织的动态添加和删除中,我们已经完成了在运行着的网络中动态添加和删除组织,但目前为止,我们启动 orderer 节点的方式都是通过系统通道的方式,这样自带系 ...

  6. form表单请求

    form 表单的acton属性指向url:端口号/(服务器get,post的参数), meyhod='get'/'post'  请求方式,必须要加上name属性. <form action=&q ...

  7. 设置网站标题时找不到index.html问题解决

    都知道,修改网站标题在根目录index.html里修改.但是在vue3更新后,index.html就没有放这里了,放到了public中.去public中一眼就能看到.我也是去那里就找到了.

  8. JS将某个数组分割为N个对象一组(如,两两一组,三三一组等)

    方法一: var result = []; var data = [ {name:'chen',age:'25'}, {name:'chen',age:'25'}, {name:'chen',age: ...

  9. python基础练习题(题目 字母识词)

    day22 --------------------------------------------------------------- 实例031:字母识词 题目 请输入星期几的第一个字母来判断一 ...

  10. Unity实现A*寻路算法学习1.0

    一.A*寻路算法的原理 如果现在地图上存在两点A.B,这里设A为起点,B为目标点(终点) 这里为每一个地图节点定义了三个值 gCost:距离起点的Cost(距离) hCost:距离目标点的Cost(距 ...