LLM注意力优化:线性注意力

1 minute read

Published:

标准注意力计算关于序列长度呈二次复杂度,在长文本情境下会给模型推理带来巨大开销。而二次复杂度的无法破除的本质原因在于,注意力计算中的softmax迫使必须先算QK再和V运算,而不能先算KV再和Q运算。线性注意力(Linear Attention)通过对标准注意力公式进行近似改写来使得后者可行,从而使得计算开销变成随序列长度线性增长。从另一个角度看,线性注意力本质上也是类似RNN或SSM模型,以recurrent的形式通过有限大小的记忆张量来压缩无限的历史信息,而不像标准注意力下KV Cache随上下文长度线性增长。因此,线性注意力可以从根本上解决注意力二次复杂度的问题,但由于其本质上是用有限空间压缩无限记忆,因此算法性能不一定能赶上标准注意力。目前Qwen3-Next等模型通常以混合注意力的形式来穿插线性注意力层和标准注意力层,从而在算法和系统性能之间达到一个平衡。”

目录

线性注意力基础

使用kernel-based的方法,设计一个kernel $\phi$来对$Q,K$做变换,从而近似非线性的softmax操作:将其近似为kernalize后的线性点积操作:

\[Softmax(QK^T)V\approx\phi(Q)(\phi(K)^TV)\]

Kernelize过程可以引入非线性,例如可以采用$\phi(x)=elu(x)+1$

在attention的原始形式下,由于softmax操作是非线性的,所以必须先计算$O(n^2d)$的内积$Q_{n\times d}K^T_{d\times n}$再和$V$计算,而不能使用结合律来修改计算顺序。

而破除掉softmax后,attention操作变成了线性操作,利用线性操作的结合律,可以转而先计算$O(d^2n)$的外积$K^T_{d\times n}V_{n\times d}$再和$Q_{n\times d}$做计算,这样就把关于序列长度$n$的复杂度降到了线性。

进一步,可以证明线性注意力本质上就是RNN(或者说SSM)的recurrent形式。简单起见先忽略kernalize函数,则对于token $t$来说,其关注之前所有token($token_1\sim token_t$)的attention计算结果$y_t$为:

\[\begin{align} y_t&=\sum_{i=1}^t(q_t^Tk_i)v_i\\ &=\sum_{i=1}^t v_i(k_i^T q_t)\\ &=\left(\sum_{i=1}^t v_i k_i^T\right)q_t\\ &=S_tq_t \end{align}\]

上述表达式通过线性操作的交换律和结合律,通过计算历史k,v的外积并得到一个历史状态矩阵$S$,从而将$token_t$与历史语境做注意力的结果写成了recurrent的形式:

\[S_t=\sum_{i=1}^t v_i k_i^T\] \[S_t=S_{t-1}+v_tk_t^T\] \[y_t=S_tq_t\]

可以理解为“当前token $q_t$从历史语境$S_t$中检索到信息并输出$o_t$”。其和RNN的形式是一样的,只不过RNN使用一个$1\times d$的hidden vector来记录历史信息,而线性注意力使用一个$d\times d$的矩阵$S$来记录历史信息。

这样一来,不仅每一步推理关于序列长度的复杂度变为了$O(n)$,同时也可以只使用固定大小的状态矩阵$S$来保存历史信息,而不需要线性增长的kv cache,在这一点上也继承了RNN对于长文本的优点。因此,线性注意力很适合长文本情境。

训练时,可以将其改写为卷积形式,从而可以:1)在系统层面采用并行的prefix scan来完成计算;2)在数学层面使用FFT变换到频域后将卷积用频域乘法代替,从而避免$O(N^2)$的复杂度

线性注意力的问题:

其更新规则对于所有tokens都一视同仁,但事实上attention机制中不同tokens并不同等重要。例如,总体来看通常越近的token越重要,越远的token越不重要,因此RoPE等位置编码都带有远程衰减属性。

因此,可以尝试给状态矩阵$S$加上各种距离衰减/gating(decay)。例如最基本的data independent decay使得历史信息以指数速度衰减,其中$0<\gamma<1$是数据无关的常数:

\[S_t=\gamma S_{t-1}+v_tk_t^T\]

也可以使用data dependent decay,使得历史信息的衰减率$0<\gamma_t<1$依据当前token $t$决定:

\[S_t=\gamma_t S_{t-1}+v_tk_t^T\]

当前主流linear attention工作的更新公式:

alt text

DeltaNet

https://arxiv.org/pdf/2406.06484

其本质上也是一种data dependent decay,并且在更新状态时更直观地主动抹除了旧信息并写入了新信息。

alt text

设token $t$的输入表征为$h_t$,则有:

\[q_t=W_Qx_t\] \[k_t=W_Kx_t\] \[v_t=W_Vx_t\]

另外设置一个权重$W_\beta$,其可以让当前token $t$决定本步的旧信息和当前信息的混合比例$\beta_t$:

\[\beta_t=\sigma(W_\beta x_t)\]

使用当前token的key $k_t$来从历史语境$S_{t-1}$中检索出old value,代表旧信息:

\[v_t^{old}=S_{t-1}k_t\]

然后,即可使用混合比例$\beta_t$对于旧value信息$v_t^{old}$和当前token的value信息$v_t$进行加权平均,组合得到“当前value信息”:

\[v_t^{new}=\beta_t v_t+(1-\beta_t)v_{t}^{old}\]

最终得到DeltaNet的状态更新规则,其直观上为:在状态矩阵中擦除旧信息$v_t^{old}k_t^T$,并写入当前信息$v_t^{new}k_t^T$,从而得到新状态:

\[S_t=S_{t-1}-v_t^{old}k_t^T+v_t^{new}k_t^T\]

当前token的输出为:

\[y_t=S_tq_t\]

整理上述公式,DeltaNet的状态更新公式(也即Delta Rule)可以写成:

\[S_t=S_{t-1}(I-\beta_t k_tk_t^T)+\beta_t v_tk_t^T\]

DeltaNet的网络结构:

事实上,近期的linear attention工作通常在$q,k,v$上进一步加一个短卷积层(“depthwise-separable short convolution”)。其通常被加到q/k/v projection之后,对于$q_t,k_t,v_t$进行施加。它是一个optional的选项,2024之后的linear attention工作通常会采用,很轻量级,例如DeltaNet中采用kernel_size=4。实验表明加上它以后算法效果会更好些。

其作用是“token-mixing”,也即将当前token和临近历史token的信息进行混合,增强临近token之间的局部交互,从而使得输入特征更丰富。纯线性投影生成的$q_t,k_t,v_t=W_qh_t,W_kh_t,W_vh_t$只考虑当前输入token $x_t$,而在SSM中模型往往缺乏精确的局部token偏移和比较能力,也缺乏位置信息。因此,引入一个短卷积可以隐式地捕捉$x_{t-1},x_{t-2}$等临近token的信息,在计算成本低廉的前提下提升模型(在时间维度上的)表达能力。

具体而言,设卷积核大小为$N$,则需要维护一个动态滑动窗口来维护最近的$N$个tokens的$q,k,v$:${q_{t-N+1},\cdots,q_{t-1},q_t},{k\cdots},{v\cdots}$。以$q$为例,对于当前token线性投影生成的$q_t=W_qh_t$做完卷积后,得到的新query为:

\[\tilde q_t=\sum_{\tau=0}^{K-1}w_\tau\cdot q_{t-\tau}\]

其中$w_\tau$为卷积核权重。例如$K=3$时:$\tilde q_t=w_0q_t+w_1q_{t-1}+w_2q_{t-2}$。

然后,再用做完token mixing后的$\tilde q,\tilde k,\tilde v$进一步进行Delta rule中的状态更新等。

Gated DeltaNet

https://arxiv.org/pdf/2412.06464

Gated DeltaNet在DeltaNet的基础上,借鉴Mamba2直接针对状态矩阵$S_{t-1}$再加一个data dependent的decay(或者说叫gating):

\[\tilde S_{t-1}=\alpha_tS_{t-1}\]

其中$0<\alpha_t<1$是和当前token $t$有关的衰减因子(也是由$h_t$线性变换而来),其可以灵活控制希望保留的记忆量。然后再用这个衰减后的$\tilde S_{t-1}$来做DeltaNet中的操作即可:

将公式整理后即可得到Gated DeltaNet的$S_t$更新公式:

\[S_t=S_{t-1}(\textcolor{red}{\alpha_t}(I-\beta_t k_tk_t^T)+\beta_t v_tk_t^T\]

可见,其融合了Mamba2的更新公式:

\[S_t=\alpha_t S_{t-1}+v_tk_t^T\]

和DeltaNet的更新公式:

\[S_t=S_{t-1}(I-\beta_t k_tk_t^T)+\beta_t v_tk_t^T\]

Qwen3-Next Hybrid Attention

Qwen3-Next为Gated DeltaNet和Gated Attention以3:1混合的hybrid attention架构,也即有几层是前者,有几层是后者,交替出现,从而支持高效的超长文本处理。

  • Gated DeltaNet

    是一种线性注意力(或者说SSM),其cache(状态矩阵)大小一直不变,适合处理超长文本,是长文本高效性的主要来源。详见线性注意力部分。

  • Gated Attention

    尝试在普通的Softmax Attention模块的各个部分加入gating机制。具体而言,gating机制的通式为:

    \[Y'=Y\odot\sigma(XW_\theta)\]

    其中$Y$是需要被gating调制的输入,$X$是用于计算gating值的另一个输入(一般来说$Y$其实也是$X$变换而来的,在这个工作中$X$取的是pre-norm后进入attention前的那个hidden state),$W_\theta$表示可学习的gating权重,$\sigma$为某种激活函数

    例如,LLaMA MLP的SwiGLU中:$Y=\text{up_proj}(X)$,$XW_\theta=\text{gate_proj}(X)$,则$Y’=\text{up_proj}(X)\odot Act(\text{gate_proj}(X))$

    对比实验结论证明,在$W_V$和$W_O$之间加一个gate的训练效果最好。

    在attention中加入gating效果很好的可能原因是:

    • 原始的$W_V,W_O$之间不存在非线性,二者本质上可以融合成一个矩阵,因此 $W_V,W_O$之间加一个非线性操作有利于提高模型表达能力

    • gating可以带来input dependent sparsity,可以过滤掉不重要的信息

    • gating可以避免attention sink(可能是因为筛掉了不重要的tokens,留下的都是重要的tokens,所以attention不太需要再把多余注意力扔到sink处),非常有利于训练稳定性、长文本等

    • gating可以加强长文本外推能力。实验发现使用gated attention+YaRN做长文本外推后,模型性能下降更少。

Kimi Delta Attention

https://arxiv.org/pdf/2510.26692

通过更细粒度的gating来进一步提升Gated DeltaNet。在Gated DeltaNet Attention(GDA)中,每个时刻的衰减因子$\alpha_t\in(0,1)$是一个标量,也即$S_{t-1}\in\mathbb R^{d\times d}$中的所有元素都乘以同一个$\alpha_t$:

\[S_t=(I-\beta_t k_tk_t^T)\textcolor{red}{\alpha_t}S_{t-1}+\beta_t v_tk_t^T\]

但Kimi Delta Attention(KDA)中进一步把$\alpha_t$扩展为一个对角阵$\mathrm{Diag}(\mathbf{\alpha}_t)$,使得能够更细粒度地控制memory decay和positional awareness。更新公式变为:

\[S_t=(I-\beta_t k_tk_t^T)\textcolor{red}{\mathrm{Diag}(\alpha_t)}S_{t-1}+\beta_t v_tk_t^T\]

alt text

完整的Kimi linear是linear attention(KDA)和full attention(MLA)比例为3:1的hybrid attention结构,且FFN部分采用MoE: