Martins P., Marinho Z. and Martins A. \(\infty\)-former: Infinite Memory Transformer. arXiv preprint arXiv:2109.00301, 2021.

在transformer中引入一种长期记忆机制.

主要内容

假设\(X \in \mathbb{R}^{L \times d}\), 即每一行\(x_i\)代表一个token对应的特征.

Attention需要进行如下的步骤:

\[Q = XW^Q, K = X W^K, V = XW^V, \\
Z = \mathrm{softmax}(\frac{QK^T}{\sqrt{d}})V.
\]

为了符号简易起见, 我们不考虑multi-head的情形, 下面的思想可以直接应用之.

我们知道, 可以通过径向基函数来逼近任意的连续函数:

\[\sum_{k} b_k \psi_k (t) \rightarrow f(t).
\]

现在, 我们令\(t_i = \frac{i}{L}\), 即对\(L\)个tokens冠以时序, \(X\)的每一列都可以看成一个特殊的\(f_j(t)\)的位于\(t_i, i=0,1,\cdots, L-1\)处的值.

给定\(N\)个基函数\(\psi_k (t), k=0,1,\cdots, N-1\), 我们要通过求解系数\(\bm{b}_j = [b_{j0}, b_{j1},\cdots b_{j,N-1}]^T\)来逼近\(f_j\)(\(X\)的第\(j\)列).

设\(\Psi \in \mathbb{R}^{N \times L}, \Psi_{ki}=\psi_{k}(t_i)\), \(B \in \mathbb{R}^{d \times N}, B_{jk} = b_{jk}\).

作者通过岭回归来求解系数\(b\):

\[B = \arg \min_{B} \|B \Psi - X^T\|_F^2 + \lambda \|B\|_F^2,
\]

其显示表达式为:

\[B = X^T\Psi^T(\Psi\Psi^T + \lambda I)^{-1}.
\]

\[X^T \approx B\Psi \rightarrow x_i \approx B \psi (t_i).
\]

现在我们用\(\tilde{X} := \Psi^T B^T\)来代替\(X\), 则

\[K = \tilde{X} W^K = \Psi^TB^TW^K, \tilde{V} = \tilde{X}W^V = \Psi^TB^TW^V.
\]

注意, 我们并不对\(Q\)进行替换, 因为这个只是用作长期的记录用, Q每次重新计算.

对于每个\(q_i\), 我们构建一个其关于\(t\)的密度函数\(p_i(t)\), 文中假设其满足高斯分布:

\[\mathcal{N}(t; \mu_i; \sigma_i^2).
\]

\(\mu_i, \sigma_i^2\)分别通过如下估计:

\[\mu_i = \mathrm{sigmoid} (w_{\mu}^T K q_i)
=\mathrm{sigmoid} (w_{\mu}^T B^TW^K q_i), \\
\sigma^2_i = \mathrm{softplus} (w_{\sigma}^T K q_i)
=\mathrm{softplus} (w_{\sigma}^T B^TW^K q_i). \\
\]

注意最后令\(w^T\Psi^T = w^T\)既然\(\Psi\)是事先确定的.

我们知道

\[\mathrm{softmax}(\frac{Kq_i}{\sqrt{d}})
\]

实际上求解的是一个离散化的\(p_i(t)\), 即\(q_i\)和\(k_j\)的相合程度, 而

\[\mathrm{softmax}(\frac{Kq_i}{\sqrt{d}})^TV
\]

实际上就是求解期望

\[\mathbb{E}_{p_i}[v(t)].
\]

现在我们近似了一个连续的\(p_i(t)\), 也可以通过这种方式得到最后的\(z_i\):

\[\mathbb{E}_{p_i}[v(t)]
=\mathbb{E}_{p_i}[\psi^T(t)B^TW^V]
=\mathbb{E}_{p_i}[\psi^T(t)]B^TW^V.
\]

当我们取\(\psi\)为高斯径向基函数的时候, 上述是由显示解的.

现在来剖析一下, 好在哪里?

原本的\(K\)是\(L\times d\)的, 现在由于我们只需要计算\(B^TW\), 故实际上只有\(N \times d\), 我们可以选取很大的\(L\)但是选择较小的\(N\)来避免较高的复杂度.

如何扩展?

难不成每一次都要重新计算\(B\)? 倘若真的是这样就谈不上是长期记忆了.

作者采取了一种比较巧的方法, 实际上, 现在的\(B\psi(t)\)可以看成是一个\(d\)维的向量函数.

我们首先将其进行压缩至\([0, \tau], \tau \in (0, 1)\):

\[B\psi(t /\tau),
\]

如此一来, 整个函数的能量集中在\([0, \tau]\)中, 我们可以用剩下的\((\tau, 1]\)来放置新的\(X\).

我们首先从\([0, \tau]\)中采样\(M\)个点\(t_0, \cdots, t_{M-1}\), 并得到:

\[X_{past} = [x_0, \cdots, x_{M-1}]^T \in \mathbb{R}^{M \times d}, x_m=\psi^T(t_m/\tau)B^T.
\]

加上新的\(X_{new}\), 我们有

\[X = [X_{past}^T, X_{new}^T]^T \in \mathbb{R}^{(M + L) \times d},
\]

对\(X\)按照上面的逻辑重新估计\(B\)即可更新记忆.

关于如何采样这\(M\)个点, 作者提了一种sticky memories的方法, 将其与密度函数联系在一起, 便不细讲了.

实验细节

在看这篇论文的时候, 困扰我的就是这个径向基函数是怎么选的?

举一个作者在Language Modeling中的例子便可:

选取150个高斯径向基函数\(\mathcal{N}(t;\mu, \sigma^2)\), 其中

\(\mu\)从\([0, 1]\)中均匀采样, \(\sigma \in \{0.01, 0.05\}\).

还有用KL散度防止一般化就不讲了. 感觉本文有趣的点就是压缩这个地方, 还有对\(\Psi\)的处理.

随机推荐

  1. day13 装饰器与语法糖

    day13 装饰器与语法糖 一.装饰器 1.什么是装饰器 装饰器就是装饰别人的工具,具体是指为被装饰者添加新功能 装饰器->函数 被装饰者->函数 2.为何要用装饰器 装饰器的核心思想:( ...

  2. 【swift】CoreData Crash(崩溃)(Failed to call designated initializer on NSManagedObject class)

    感谢另一篇博客:https://blog.csdn.net/devday/article/details/6577985 里面的图片和介绍,发现问题如他描述的一样,没有bundle 我的Xcode版本 ...

  3. shell脚本采集系统cpu、内存、磁盘、网络信息

    有不少朋友不知道如何用shell脚本采集linux系统相关信息,包括cpu.内存.磁盘.网络等信息,这里脚本小编做下讲解,大家一起来看看吧. 一.cpu信息采集 1),采集cpu使用率采集算法:通过/ ...

  4. Linux服务加入systemctl|service管理

    一.加入systemctl 1.添加 vim /usr/lib/systemd/system/user_timejob.service # copy to /usr/lib/systemd/syste ...

  5. Samba 源码解析之内存管理

    由于工作需要想研究下Samba的源码,下载后发现目录结构还是很清晰的.一般大家可能会对source3和source4文件夹比较疑惑.这两个文件夹针对的是Samba主版本号,所以你可以暂时先看一个.这里 ...

  6. ciscn_2019_s_6

    例行检查 没有开启nx保护,考虑用shellcode来做这道题 程序放入ida查看 我们可以输入48个字符覆盖0使printf打印出bp的值 继续看这里,buf的大小实际上只有0x38的大小,但是re ...

  7. android 使用 perfetto 抓取atrace

    最近项目的原因需要抓自定义的一些atrace,发现使用google 自带的systrace python脚本抓出来的log使用chrome已经打不开了. 想着用用比较时髦的perfetto吧,发现无论 ...

  8. CF1427A Avoiding Zero 题解

    Content 请将一个长度为 \(n\) 的数列 \(A\) 重新排序,使得这个数列所有的前缀和 \(\neq 0\),或者证明没有这样的方案. 数据范围:\(t\) 组数据,\(1\leqslan ...

  9. ubuntu16.04 开启ipv6支持

     1)vim /etc/default/grub将GRUB_CMDLINE_LINUX中下面的这一项删除:ipv6.disable=12)执行 grub-mkconfig -o /boot/grub/ ...

  10. jackson-databind-2.2.3.jar,ackson-annotations-2.2.3.jar和jackson-core-2.2.3.jar下载

    jackson包开发下载,包括如下3个jar文件 jackson-databind-2.2.3.jar,还需要jackson-annotations-2.2.3.jar和jackson-core-2. ...