LLM注意力优化:稀疏注意力

1 minute read

Published:

在长文本情境下,对全部历史token做注意力会带来巨大的计算和访存开销,稀疏注意力(Sparse Attention)则通过让每个query有选择地仅和部分最相关的kv做注意力计算,来大幅减小实际计算量。从DeepSeek NSA,DSA到到DSV4中的CSA,HCA,稀疏注意力几乎已成为超长文本模型的标配,且其基本技术思路也存在诸多共性,包括KV块级压缩、基于相关度分数的top-k选择性注意力、粗细粒度结合的层次化注意力等。稀疏注意力也可以和全注意力、线性注意力等进行结合,通过interleave的方式在模型中交替放置,来实现混合注意力。

目录

Native Sparse Attention(DeepSeek NSA)

NSA(native sparse attention)类似于“局部+全局混合注意力”或“层次化注意力”

alt text

主要为了解决长文本下attention计算成本过高的问题。NSA通过压缩、选择、滑动窗口这三种策略来达到稀疏注意力,同时兼顾了全局粗粒度信息+重要部分的细粒度信息+邻近滑动窗口内的细粒度信息(类似人类阅读时先粗读,再找重点精读,同时看重点附近的内容)。

其是原生可训练的动态分层稀疏注意力。

  • 步骤1:压缩阶段

    先对历史KV进行分块,如上图将一句话分为4块,每块都包括了一部分连续的tokens;

    然后,将每一块都线性变换成一个向量(也即将这一块的所有tokens压缩成一个“总token”,变成了一个词向量),4块就压缩成了4个“总token”,然后再将它们4个拼接;

    然后,用$Q$和上述拼接后的KV向量(长度为4的“总token序列”)做attention,目的是进行“粗粒度的注意力”,也即先大概算一下当前query对于每一部分的关注度。这样就可以得到每一块的注意力分数(它既用于第二步的块筛选,同时也参与到最终的attn score中,作为对于整个context的粗粒度注意力);

  • 步骤2:选择阶段

    选择4块中注意力得分最高的2块,使得query和它们对应的tokens的原始kv cache做注意力(也即,只和比较重要的一部分块的tokens做真正的细粒度注意力)

  • 步骤3:滑动窗口

    为了保护局部信息,还要选择当前token最近的前一段的tokens,让query也和它们做attention(也即确保临近窗口内的tokens一定会被选中做细粒度attention,无论它们在前边步骤中是否属于被选中的块)

最终,将三种注意力(全局所有块的粗粒度+重要块的细粒度+临近窗口内的细粒度)进行合并,然后再通过一个gate机制(线性变化+sigmoid),作为最终的注意力输出

Mixture of Block Attention (Kimi MoBA)

和NSA同一天发布,也是著名的sparse attention工作,致力于减轻长文本下注意力复杂度高的问题。其结合了blockwise attention的分块思想和MoE的门控思想,首先计算token和各个历史block的分数,并使得token只选择top-k个分数最高的块进行关注

alt text

其将MoE原则应用于注意力机制,且可以无缝在稀疏注意力和完整注意力之间切换。

算法:

首先,将上下文划分为多个块,设每个块的token数量为$B$,$I_i$表示第$i$个块:

\[I_i=[(i-1)\times B+1,i\times B]\]

然后,对各个block采用gating机制。具体而言,gating函数$g$为当前query token选取最相关的那些blocks,其首先计算当前query token $q$和每个block的亲密度分数$s_i$,然后在所有blocks的分数中选出top-k个最大的,并被当前token关注。

亲密度分数$s_i$是由query token $q$和第$i$个block中所有token的key的mean pooling(均值)的内积得到的:

\[s_i=\langle q,\mathrm{mean\_pool}(K[I_i])\rangle\]

本质上,sliding window attention和attention sink等都是MoBA的特例(也即固定选取sink tokens/最近窗口中的tokens,相当于MoBA中固定选取这些块)

DeepSeek Sparse Attention (DSA)

在deepseek-v3.2-exp中发布,是一种细粒度的稀疏注意力机制,尤其为长文本情境设计。

简单来说,使用一个lightning indexer来计算query token和各个历史key token之间的index score(也即相关性分数),并选取其中得分最高的k个(训练时k=2048)来和当前token做注意力。

具体而言,输入token $t$和历史token $s$之间的index score $I_{t,s}$为:

\[I_{t,s}=\sum_{j=1}^{H^I}w_{t,j}^I\cdot ReLU(q_{t,j}^I\cdot k_s^I)\]

其中,$H^I$为indexer heads的数量,$q_{t,j}^I\in\mathbb R^{d^I}$和加权权重$w_{t,j}^I$都是由query token的hidden state $h_t$变换而来,$k_s^I$是由历史token的hidden state $h_s$变换而来。注意,这里用于计算index score的$q^I,k^I$是单独生成出来的,并不是计算注意力时真正使用的$q,k$。

由于indexer heads数量很少,且这些indexer向量的维数通常比正常QKV的维数小很多,而且可以用FP8实现,因此indexing开销相对很小。使用最简单的ReLU激活函数可以加快速度,且产生真0可以在稀疏注意力计算时直接skip。

当序列比较短的时候(<4K),DSA反而比较慢,因为indexer的开销占比比较大,此时可以直接使用MHA+mask来模拟稀疏pattern,避免启用高成本的indexer;当序列足够长时再使用真正的稀疏注意力。这种双模系统可以确保在长短序列下选择最优的计算路径。

和MLA集成如下:

DeepSeek CSA & HCA

DeepSeek-V4中采用了CSA(Compressed Sparse Attention)HCA(Heavily Compressed Attention)两种稀疏注意力机制,并将二者interleaved地混合来实现混合注意力,致力于减轻超长文本下做attention的开销压力。

简单来说,CSA首先将每$m$个token的KV Cache压缩提炼成一条(entry),使得KV Cache总量被压缩到原先的$\frac{1}{m}$,然后再使用DSA来让当前query选择性地仅和其中$k$条压缩后到KV Cache entries算attention;另外也会保留一个近邻滑动窗口内的原始KV Cache entries做注意力,来捕捉局部细粒度信息。HCA则更关注极致的KV Cache压缩,其将每$m’(\gg m)$个token的KV Cache压缩为一条,但并不启用稀疏注意力。也即,CSA对KV压缩的少但启用了稀疏注意力,HCA对KV压缩的多并保留全注意力。

CSA的具体实现如下:

alt text

设$H\in\mathbb R^{n\times d}$为输入隐含状态,$n$为序列长度,$d$为隐含维数。

首先分别计算得到两路KV entries $C^a,C^b\in\mathbb R^{n\times c}$:

\[C^a=H\cdot W^{aKV},~~ C^b=H\cdot W^{bKV}\]

以及它们各自对应的压缩权重张量$Z^a,Z^b\in\mathbb R^{n\times c}$:

\[Z^a=H\cdot W^{aZ},~~ Z^b=H\cdot W^{bZ}\]

其中$c$为每个注意力头的维数,$W^{aKV},W^{bKV},W^{aZ},W^{bZ}\in\mathbb R^{d\times C}$为可学习参数。

接下来,$C^a,C^b$中的每$m$条KV Cache entries都会根据压缩权重张量$Z^a,Z^b$以及可学习位置偏置$B^a,B^b\in\mathbb R^{m\times c}$,来被分别压缩成一条entry,然后再将二者错位融合成一条entry。考虑第$i$组tokens,其是由$C^a$中的第$i$组KV entries和$C^b$中的第$i-1$组KV entries融合而成的。

具体而言,首先将压缩权重张量进行归一化,取$Z^a$中对应第$i$组的$m$行$Z^{a}{mi:m(i+1)-1}\in\mathbb R^{m\times c}$与$Z^b$中对应第$i-1$组的$m$行$Z^b{m(i-1):mi-1}\in\mathbb R^{m\times c}$,分别加上位置偏置$B^a,B^b$后,沿序列长度方向堆叠成$2m$行,然后再沿序列长度方向(dim=0)一起做softmax归一化,得到归一化后的压缩权重张量:

\[[S^{a}_{mi:m(i+1)-1};S^b_{m(i-1):mi-1}]=\text{Softmax}([Z^a_{mi:m(i+1)-1}+B^a;Z^b_{m(i-1):mi-1}+B^b])\]

然后使用归一化后的压缩权重沿序列长度方向对$C^a,C^b$中对应的的KV entries做加权求和,最终再将两部分加权求和的结果加到一起,得到第$i$组tokens的最终压缩结果$C_i^{\text{Comp}}\in\mathbb R^c$:

\[C_i^{\text{Comp}}=\sum_{j=mi}^{m(i+1)-1}S_j^a\odot C_j^a+\sum_{j=m(i-1)}^{mi-1}S_j^b\odot C_j^b\]

全部tokens的压缩结果连起来即得到完整的KV压缩结果$C^{\text{Comp}}=[C_0^{\text{Comp}};\cdots;C_{n/m-1}^{\text{Comp}}]\in\mathbb R^{\frac{n}{m}\times c}$,可见其相比于原始KV Cache压缩了$m$倍。

另外,当$i=0$时,取$Z^b_{m(i-1):mi-1}=-\infty$,$C^b_{m(i-1):mi-1}=0$

在得到全部KV压缩结果$C^{\text{Comp}}$后,即可进一步采用DSA的策略来其中选取top-k条和当前query最相关的entries做attn计算,称为core attention。首先,采用和上述相同的流程得到压缩后KV的indexer keys $K^{\text{IComp}}\in\mathbb R^{\frac{n}{m}\times c^I}$,然后再为query token $t$生成$n_h^I$个维数为$c^I$的index queries:

\[c_t^Q=h_t\cdot W^{DQ}\] \[[q_{t,1}^I,\cdots,q_{t,n_h^l}^I]=q_t^I=c_t^Q\cdot W^{IUQ}\]

其中,$h_t\in\mathbb R^{d}$是query token $t$的隐含状态,$c_t^Q\in\mathbb R^{d_c}$是其压缩后的低秩latent向量,$W^DQ,W^{IUQ}$分别为生成indexer queries的down projection和up projection权重,生成的$q_t^I\in\mathbb R^{c^In_h^I}$被切成$n_h^I$个维数为$c^I$的indexer head。

然后即可计算该query token和第$s$个KV压缩块之间的相关度分数$I_{t,s}\in\mathbb R$。该token的各个indexer query head分别和该KV压缩块的indexer key做内积,得到每个head的相关度,然后再通过一套由query本身变换而来的加权权重$w_t^I$进行加权求和,得到最终的相关度分数:

\[[w^{I}_{t,1},\cdots,w^I_{t,n_h^I}]=w_t^I=h_t\cdot W^w\] \[I_{t,s}=\sum_{h=1}^{n_h^I}w_{t,h}^I\cdot\text{ReLU}(q_{t,h}^I\cdot K_s^{\text{IComp}})\]

对于所有压缩KV块都计算了相关度分数后,选出top-k个最高的并将它们对应的KV entry拿出来组成$C_t^{\text{SprsComp}}\in\mathbb R^{k\times c}$,来和query做core attention。

Core attention被实现为MQA形式,也即query为多头,kv为单头,且这里$C_t^{\text{SprsComp}}$被同时用作key和value。具体而言,使用另一个up projection权重$W^{UQ}$作用于上文得到的query latent向量$c_t^Q$并沿隐含维度方向切分,得到query的$n_h$个head:

\[[q_{t,1},\cdots,q_{t,n_h}]=q_t=c_t^Q\cdot W^{UQ}\]

然后即可做core attention,第$i$个head的计算结果为$o_{t,i}\in\mathbb R^c$:

\[o_{t,i}=\text{CoreAttn}(\text{query}=q_{t,i},\text{key}=C_t^{\text{SprsComp}},\text{value}=C_t^{\text{SprsComp}})\]

另外,为了捕捉近邻的细粒度信息,实际上还保留了最近的一个滑动窗口内的所有KV entries做注意力,其结果会和对压缩KV块的计算结果进行结合。

由于直接将$n_h$个head的计算结果连接起来得到的维数$cn_h$在ds-v4的配置下过大,会引入过大计算开销,因此这里采用分组的输出映射。先将$n_h$个head的计算结果分为$g$组,将每组内部的结果分别连接起来得到维数$cn_h/g$的向量后,再通过一个降维映射将其维数变为$d_g<cn_h/g$,然后再将各组的结果连成一个维数$d_gg$的向量$[o_{t,1}^{G’},\cdots,o_{t,g}^{G’}]\in\mathbb R^{d_gg}$,最后再映射为原始隐含维数的$\hat o_t\in\mathbb R^d$。

HCA的具体实现如下:

alt text

总的来说,HCA的压缩策略和CSA差不多,只不过压缩率$m’\gg m$,且没有启用两路错位加和,只有一套$C,Z$:

\[C=H\cdot W^{KV}\] \[Z=H\cdot W^Z\] \[S_{m'i:m'(i+1)-1}=\text{Softmax}(Z_{m'I:m'(i+1)-1}+B)\] \[C_i^{\text{Comp}}=\sum_{j=m'i}^{m'(i+1)-1}S_j\odot C_j\]

在执行Core Attention时,query会和所有压缩KV块的entry做MQA,不再进行top-k筛选。另外在输出时也采用了分组的输出映射