在Tensorflow、Numpy和PyTorch中都提供了使用einsum的api,einsum是一种能够简洁表示点积、外积、转置、矩阵-向量乘法、矩阵-矩阵乘法等运算的领域特定语言。在Tensorflow等计算框架中使用einsum,操作矩阵运算时可以免于记忆和使用特定的函数,并且使得代码简洁,高效。

如对矩阵\(A\in \mathbb{R}^{I×K}​\)和矩阵\(B\in \mathbb{R}^{K×J}​\)做矩阵乘,然后对列求和,最终得到向量\(c\in \mathbb{R}^J​\),即:

\[\mathbb{R}^{I×K}\bigotimes \mathbb{R}^{K×J}\to \mathbb{R}^{I×J}\to \mathbb{R}^{J}
\]

使用爱因斯坦求和约定表示为:

\[c_j=\sum_i\sum_kA_{ik}B_{kj}=A_{ik}B_{kj}
\]

在Tensorflow、Numpy和PyTorch中对应的einsum字符串为:

ik,kj->j

在上面的字符串中,隐式地省略了重复的下标\(k\),表示在该维度矩阵乘;另外输出中未指明下标\(i\),表示在该维度累加。

Numpy、PyTorch和Tensorflow中的einsum

einsum在Numpy中的实现为np.einsum,在PyTorch中的实现为torch.einsum,在Tensorflow中的实现为tf.einsum,均使用同样的函数签名einsum(equation,operands),其中,equation传入爱因斯坦求和约定的字符串,而operands则是张量序列。在Numpy、Tensorflow中是变长参数列表,而在PyTorch中是列表。上述例子中,在Tensorflow中可写作:

tf.einsum('ik,kj->j',mat1,mat2)

其中,mat1、mat2为执行该运算的两个张量。注意:这里的(i,j,k)的命名是任意的,但在一个表达式中要一致。

PyTorch和Tensorflow像Numpy支持einsum的好处之一就是,einsum可以用于深度网络架构的任意计算图,并且可以反向传播。在Numpy和Tensorflow中的调用格式如下:

\[result=\mathop{einsum}('\square \square, \square \square \square,\square \square\to \square \square',arg1,arg2,arg3)
\]

其中,\(\square\)是占位符,表示张量维度;arg1,arg3是矩阵,arg2是三阶张量,运算结果是矩阵。注意:einsum处理可变数量的输入。上面例子中,einsum制定了三个参数的操作,但同样可以操作一个参数、两个参数和三个参数及以上的操作。

典型的einsum表达式

前置知识

  • 内积

    又称点积、点乘,对应位置数字相乘,结果是一个标量,有见向量内积和矩阵内积等。

    向量\(\vec a\)和向量\(\vec b\)的内积:

    \[\vec a=[a_1,a_2,...,a_n]\\
    \vec b=[b_1,b_2,...,b_n]\\
    \vec a\cdot \vec b^T=a_1b_1+a_2b_2+...+a_nb_n
    \]

    内积几何意义:

    \[\vec a \cdot \vec b^T=|\vec a||\vec b|\mathop{cos}\theta
    \]

  • 外积

    又称叉乘、叉积、向量积,行向量矩阵乘列向量,结果是二阶张量。注意到:张量的外积作为张量积的同义词。外积是一种特殊的克罗内克积。

    向量\(\vec a\)和向量\(\vec b\)的外积:

    \[\begin{bmatrix}
    b_1
    \\b_2
    \\ b_3
    \\ b_4
    \end{bmatrix}\bigotimes[a_1,a_2,a_3]=\begin{bmatrix}
    a_1b_1 & a_2b_1 & a_3b_1 \\
    a_1b_2 & a_2b_2 & a_3b_2 \\
    a_1b_3 & a_2b_3 & a_3b_3 \\
    a_1b_4 & a_2b_4 & a_3b_4 \\
    \end{bmatrix}
    \]

    外积的几何意义:

    \[\vec a=(x_1,y_1,z_1)\\
    \vec b=(x_2,y_2,z_2)\\
    \vec a\bigotimes\vec b=\begin{vmatrix}
    i & j & k\\
    x_1 & y_1 & z_1\\
    x_2 & y_2 & z_2
    \end{vmatrix}=(y_1z_2-y_2z_1)\vec i-(x_1z_2-x_2z_1)\vec j+(x_1y_2-x_2y_1)\vec k
    \]

    其中,

    \[\vec i=(1,0,0)\\
    \vec j=(0,1,0)\\
    \vec k=(0,0,1)
    \]

由于PyTorch可以实时输出运算结果,以PyTorch使用einsum表达式为例。

  • 矩阵转置

    \[B_{ji}=A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->ji',[a])
    >>>tensor([[0, 3],
    [1, 4],
    [2, 5]])
  • 求和

    \[b=\sum_{i}\sum_{j}A_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->',[a])
    >>>tensor(15)
  • 列求和(列维度不变,行维度消失)

    \[b_j=\sum_iA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->j',[a])
    >>>tensor([ 3.,  5.,  7.])
  • 列求和(列维度不变,行维度消失)

    \[b_i=\sum_jA_{ij}
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ij->i', [a])
    >>>tensor([  3.,  12.])
  • 矩阵-向量相乘

    \[c_i=\sum_k A_{ik}b_k
    \]

    a=torch.arange(6).reshape(2,3)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]])
    torch.einsum('ik,k->i',[a,b])
    >>>tensor([  5.,  14.])
  • 矩阵-矩阵乘法

    \[C_{ij}=\sum_{k}A_{ik}B_{kj}
    \]

    a=torch.arange(6).reshape(2,3)
    b=torch.arange(15).reshape(3,5)
    >>>tensor([[0, 1, 2],
    [3, 4, 5]]) >>>tensor([[ 0, 1, 2, 3, 4],
    [ 5, 6, 7, 8, 9],
    [10, 11, 12, 13, 14]])
    torch.einsum('ik,kj->ij',[a,b])
    >>>tensor([[ 25,  28,  31,  34,  37],
    [ 70, 82, 94, 106, 118]])
  • 点积

    • 向量

      \[c=\sum_i a_i b_i
      \]

      a=torch.arange(3)
      b=torch.arange(3,6)
      >>>tensor([0, 1, 2])
      >>>tensor([3, 4, 5])
      torch.einsum('i,i->',[a,b])
      >>>tensor(14.)
    • 矩阵

      \[c=\sum_i\sum_j A_{ij}B_{ij}
      \]

      a=torch.arange(6).reshape(2,3)
      b=torch.arange(6,12).reshape(2,3)
      >>>tensor([[0, 1, 2],
      [3, 4, 5]]) >>>tensor([[ 6, 7, 8],
      [ 9, 10, 11]])
      torch.einsum('ij,ij->',[a,b])
      >>>tensor(145.)
  • 外积

    \[C_{ij}=a_i b_j
    \]

    a=torch.arange(3)
    b=torch.arange(3,7)
    >>>tensor([0, 1, 2])
    >>>tensor([3, 4, 5, 6])
    torch.einsum('i,j->ij',[a,b])
    >>>tensor([[  0.,   0.,   0.,   0.],
    [ 3., 4., 5., 6.],
    [ 6., 8., 10., 12.]])
  • batch矩阵乘

    \[C_{ijl}=\sum_{k}A_{ijk}B_{ikl}
    \]

    a=torch.randn(3,2,5)
    b=torch.randn(3,5,3)
    >>>tensor([[[-1.4131e+00,  3.8372e-02,  1.2436e+00,  5.4757e-01,  2.9478e-01],
    [ 1.3314e+00, 4.4003e-01, 2.3410e-01, -5.3948e-01, -9.9714e-01]], [[-4.6552e-01, 5.4318e-01, 2.1284e+00, 9.5029e-01, -8.2193e-01],
    [ 7.0617e-01, 9.8252e-01, -1.4406e+00, 1.0071e+00, 5.9477e-01]], [[-1.0482e+00, 4.7110e-02, 1.0014e+00, -6.0593e-01, -3.2076e-01],
    [ 6.6210e-01, 3.7603e-01, 1.0198e+00, 4.6591e-01, -7.0637e-04]]]) >>>tensor([[[-2.1797e-01, 3.1329e-04, 4.3139e-01],
    [-1.0621e+00, -6.0904e-01, -4.6225e-01],
    [ 8.5050e-01, -5.8867e-01, 4.8824e-01],
    [ 2.8561e-01, 2.6806e-01, 2.0534e+00],
    [-5.5719e-01, -3.3391e-01, 8.4069e-03]], [[ 5.2877e-01, 1.4361e+00, -6.4232e-01],
    [ 1.0813e+00, 8.5241e-01, -1.1759e+00],
    [ 4.9389e-01, -1.7523e-01, -9.5224e-01],
    [-1.3484e+00, -5.4685e-01, 8.5539e-01],
    [ 3.7036e-01, 3.4368e-01, -4.9617e-01]], [[-2.1564e+00, 3.0861e-01, 3.4261e-01],
    [-2.3679e+00, -2.5035e-01, 1.8104e-02],
    [ 1.1075e+00, 7.2465e-01, -2.0981e-01],
    [-6.5387e-01, -1.3914e-01, 1.5205e+00],
    [-1.6561e+00, -3.5294e-01, 1.9589e+00]]])
    torch.einsum('ijk,ikl->ijl',[a,b])
    >>>tensor([[[ 1.3170, -0.7075,  1.1067],
    [-0.1569, -0.2170, -0.6309]], [[-0.1935, -1.3806, -1.1458],
    [-0.4135, 1.7577, 0.3293]], [[ 4.1854, 0.5879, -2.1180],
    [-1.4922, 0.7846, 0.7267]]])
  • 张量缩约

    batch矩阵相乘是张量缩约的一个特例,比如有两个张量,一个n阶张量\(A\in \mathbb{R}^{I_1×l_2×...×I_n}​\),一个m阶张量\(B\in \mathbb{R}^{J_1×J_2×...×J_m}​\)。取n=4,m=5,假定维度\(I_2=J_3​\)且\(I_3=J_5​\),将这两个张量在这两个维度上(A张量的第2、3维度,B张量的第3、5维度)相乘,获得新张量\(C\in \mathbb{R}^{I_1×I_4×J_1×J_2×J_4}​\),如下所示:

    \[C_{I_1×I_4×J_1×J_2×J_4}=\sum_{I_2==J_3}\sum_{I_3==J_5}A_{I_1×I_2×I_3×I_4}B_{J_1×J_2×J_3×J_4×J_5}
    \]

    a=torch.randn(2,3,5,7)
    b=torch.randn(11,13,3,17,5) torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
    >>>torch.Size([2, 7, 11, 13, 17])
  • 多张量计算

    如前所述,einsum可用于超过两个张量的计算,以双线性变换为例:

    \[D_ij=\sum_k\sum_lA_{ik}B_{jkl}C_{il}
    \]

    a=torch.randn(2,3)
    b=torch.randn(5,3,7)
    c=torch.randn(2,7) torch.einsum('ik,jkl,il->ij',[a,b,c]).shape
    >>>torch.Size([2,5])

kimiyoung/transformer-xl的tf部分大量使用了einsum表达式。

einsum满足你一切需要:深度学习中的爱因斯坦求和约定

向量点乘(内积)和叉乘(外积、向量积)概念及几何意义解读

矩阵外积与内积

外积-wiki

einsum:爱因斯坦求和约定的更多相关文章

  1. 爱因斯坦求和约定 (Einstein summation convention)

  2. MindSpore尝鲜之爱因斯坦求和

    技术背景 在前面的博客中,我们介绍过关于numpy中的张量网络的一些应用,同时利用相关的张量网络操作,我们可以实现一些分子动力学模拟中的约束算法,如LINCS等.在最新的nightly版本的MindS ...

  3. einsum函数介绍-张量常用操作

    einsum函数说明 pytorch文档说明:\(torch.einsum(equation, **operands)\) 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和 ...

  4. NumPy v1.15手册汉化

    NumPy参考 数组创建 零 和 一 empty(shape[, dtype, order]):返回给定形状和类型的新数组,而不初始化条目 empty_like(prototype[, dtype,  ...

  5. numpy函数查询手册

    写了个程序,对Numpy的绝大部分函数及其说明进行了中文翻译. 原网址:https://docs.scipy.org/doc/numpy/reference/routines.html#routine ...

  6. NumPy之:ndarray中的函数

    NumPy之:ndarray中的函数 目录 简介 简单函数 矢量化数组运算 条件逻辑表达式 统计方法 布尔数组 排序 文件 线性代数 随机数 简介 在NumPy中,多维数组除了基本的算数运算之外,还内 ...

  7. Differential Geometry之第四章标架与曲面论的基本定理

    第四章.标架与曲面论的基本定理 1.活动标架 2.自然标架的运动方程 爱因斯坦求和约定(Einstein summation convention) 3.曲面的结构方程 4.曲面的存在唯一性定理 5. ...

  8. 记号(notation)的学习

    数学的记号(notation) 记号具体代表什么含义,取决于你的定义: 比如这样的 d⃗  一个向量,每个分量 d(i) 表示的是从初始结点 v 到当前节点 vi 的最短路径:也即这样的一个向量的每一 ...

  9. 如何基于MindSpore实现万亿级参数模型算法?

    摘要:近来,增大模型规模成为了提升模型性能的主要手段.特别是NLP领域的自监督预训练语言模型,规模越来越大,从GPT3的1750亿参数,到Switch Transformer的16000亿参数,又是一 ...

随机推荐

  1. 新一代Xamarin

    新一代Xamarin竟然可以将.NET代码原生编译成:Jar包供Java原生调用.swift类库.obj-c类库.C++类库 供目标平台传统代码直接调用 之前和很多朋友聊到Xamarin觉得确实不错, ...

  2. Java与C#的语法区别

    1.作用域 在java中 { { int a=1; } int a=2;//以上a作用域外的以下,再声明同名的变量,是允许的: } 在C#中,以上是不允许的[只要在同一个作用域内,以上或以下的代码中 ...

  3. maven Java.lang.ClassNotFoundException: org.springframework.web.servlet.DispatcherServlet

    如果你可以确认你的maven Dependencies中已经导入了如下的jar包,那么你就要检查下Deployment Assembly 选中项目 alt+enter,然后查看maven依赖有没有被添 ...

  4. 64 位系统 vs2013 配置 OpenCV-3.1.0

    参考:64 位系统 vs2013 配置 opencv3.0 1. 环境准备 进入官网 http://opencv.org/,下载最新版本的 opencv(以本文 opencv-3.1.0 为例,.ex ...

  5. LR杂记 - Linux的系统监控工具vmstat详细说明

    一.前言 非常显然从名字中我们就能够知道vmstat是一个查看虚拟内存(Virtual Memory)使用状况的工具,可是如何通过vmstat来发现系统中的瓶颈呢?在回答这个问题前,还是让我们回想一下 ...

  6. 判断软件的闲置时间(使用Application.OnMessage过滤所有键盘鼠标消息)

    GetLastInputInfo是检测系统输入的,应用到某个程序中不合适! 此问题有二种解法来监控输入消息: 1.用线程级HOOK,钩上MOUSEHOOK与KEYBOARDHOOK 2.在Applic ...

  7. ADO.NET- 中批量添加数据的几种实现方法比较

    在.Net中经常会遇到批量添加数据,如将Excel中的数据导入数据库,直接在DataGridView控件中添加数据再保存到数据库等等. 方法一:一条一条循环添加 通常我们的第一反应是采用for或for ...

  8. 音频、视频等文件格式(.ts、.meta)及其认识

    MPEG:Moving Picture Experts Group,动态图像专家组, JPEG:Joint Photographic Experts Group,联合图像专家组 1. .ts .ts ...

  9. 读取数据变JSON传值!

    $(document).on("click",".btn_small",function(){                                v ...

  10. [Linux] ssh秘钥对免密码登陆

    准备两台linux服务器 a和b , 在a上使用ssh命令登陆b服务器 , 并且不用 输入密码 1.在a服务器上,比如是root用户 ,进去/root/.ssh目录 ,没有就创建, 就是进入家目录的. ...