DeepSeek MLA学习笔记
Published:
MLA(Multi-Head Latent Attention)是DeepSeek V2中提出的一种KV Cache压缩技术。不同于以往大多数方法致力于在token数量维度压缩KV Cache,MLA通过大幅缩减KV所需的通道数量来实现压缩效果,并通过一些算法和系统优化来削减额外的推理开销和兼容RoPE位置编码,从而使得模型decoding过程中针对KV Cache的访存压力大幅减小,有利于减少显存压力和提升推理速度。
压缩逻辑:
对于第$t$个token,假设其原始hidden state为$h_t\in\mathbb R^{d}$,希望将其维数压缩到$d_c\ll d$。
首先,使用$W^{DKV}\in\mathbb R^{d\times d_c}$来将hidden state $h_t$压缩为$c_t^{KV}\in\mathbb R^{d_c}$,它在推理中会被缓存起来作为历史context(代替kv cache):
\[c_t^{KV} = h_tW^{DKV}\]当需要计算attention时,再分别使用$W^{UK}\in\mathbb R^{d_c\times d}$和$W^{UV}\in\mathbb R^{d_c\times d}$将keys和values还原回原始维度(这里假设$d=d_hn_h$,其中$d_h$为head dim,$n_h$为num heads):
\[\begin{align} k_t^C &= c_t^{KV}W^{UK}\\ v_t^C &= c_t^{KV}W^{UV} \end{align}\]算子融合:
另外,在不考虑位置编码干扰的情况下,由于我们的最终目标是算出attention+output proj结果而不是得到$q,k,v$,因此算attention的时候可以进一步将$W^{UK}$吸收到$W^Q$中,将$W^{UV}$吸收到$W^O$中。设$u_{t_1}$为token $t_1,t_2$做注意力的最终结果:
\[u_{t_1}=[o_{t_1,1},\cdots, o_{t_1,h_h}]W^O\]则head $i$上的计算为:
\[\begin{align} o_{t,i}&=\left(\mathrm{Softmax}\left(\frac{q_{t_1,i}k_{t_2,i}^T}{\sqrt{d_h}}\right)v_{t_2,i}\right)W^O\\ &=\left(\mathrm{Softmax}\left(\frac{(h_{t_1} W^Q)(c_{t_2}^{KV}W^{UK})^T}{\sqrt{d_h}}\right)c_{t_2}^{KV}W^{UV}\right)W^O\\ &=\mathrm{Softmax}\left(\frac{h_{t_1} \textcolor{red}{(W^QW^{UK^T})}c_{t_2}^{KV^T}}{\sqrt{d_h}}\right)c_{t_2}^{KV}\textcolor{red}{(W^{UV}W^O)}\\ \end{align}\]对query的压缩
为了减小训练阶段的activation memory,同样对queries也做了压缩,虽然它不会有助于缩小kv cache:
\[\begin{align} c_t^Q&=h_t W^{DQ}\\ q_t^C&=c_t^Q W^{UQ} \end{align}\]
Decoupled RoPE
如果按照普通方法给还原后的key加RoPE的话,则会导致无法按照上式吸收掉$W^{UK}$,因为中间的RoPE矩阵$W^\Theta$是和位置有关的动态量,不能被静态权重$W^Q,W^{UK}$吸收:
\[(h_{t_1} W^QW^{\Theta_{t_1}})(c_{t_2}^{KV}W^{UK}W^{\Theta_{t_2}})^T=h_{t_1}\textcolor{red}{(W^QW^{\Theta_{t_1}}W^{\Theta_{t_2}} W^{UK^T}})c_t^{KV^T}\]因此,提出一种将RoPE解耦的策略,其使用额外的多头queries $q_t^R={q_{t,i}^R}_{i=1,\cdots,n_h}\in\mathbb R^{d_h^R}$和一个所有head共享的key $k_t^R\in\mathbb R^{d_h^R}$,来承载RoPE,其中$d_h^R$为解耦的queries和key的per-head维数。它们由RoPE单独算出:
\[[q_{t,1}^R,\cdots,q_{t,n_h}^R]=q_t^R=c_t^{Q}W^{QR}W^{\Theta_t}\]
其中,$W^{QR}\in\mathbb R^{d_c’\times d_h^R n_h}$,$W^{KR}\in\mathbb R^{d\times d_h^R}$,也即$q_t^R$是由压缩表征$c_t^Q$经过线性变换+RoPE得到的,$k_t^R$是由原始hidden dimension $h_t$经过线性变换+RoPE得到的。
在用于计算时,将原$q,k$和上述专门储存RoPE信息的$q,k$进行逐head的concat(沿feature维度):
\[q_{t,i}=[q_{t,i}^C;q_{t,i}^R]\] \[k_{t,i}=[k_{t,i}^C,k_{t}^R]\]也即,在预训练阶段,人为将$q,k$的每个head的维数由$d_h$扩展到了$d_h+d_h^R$。这也使得计算attention时除以的scaling factor由$\sqrt{d_h}$变成了$\sqrt{d_h+d_h^R}$。推理过程中,需要把$k_t^R$也存储起来作为历史context各个token的位置信息。
总的来看,MLA在预训练阶段将RoPE的作用方式由矩阵乘改为了concat,这就使得恢复原始维数的$k_t^C$可以和缓存的$k_t^R$通过高效的concat进行组合,而不需要每次都乘一遍RoPE矩阵。
综上来看,每个token $t$需要被缓存的量为:$c_t^{KV}$和$k_t^R$,则$l$个token需要占据显存为:$(d_c+d_h^R)l$。
在deepseek-v2中,设置$d_c=4d_h$,$d_h^R=\frac{d_h}{2}$,可见位置信息缓存占用的额外开销实际上非常小,甚至小于了一个head的维数。
