大模型中的Attention Sink与Outliers的产生机制与消除
Published:
相比于传统的小型神经网络,大模型中存在Attention Sink和Outliers等独特的现象,它们的存在会给低比特量化等模型优化过程带来额外的挑战,并且也可能损害模型的算法性能。这里浅记一下对于它们产生机制的理解,以及消除它们的一些常见方法。
目录
1. Attention Sink和Outlier的产生
1.1 Attention Sink的产生机制
指的是注意力机制中所有token都会给start token一个非常大注意力分数的现象。
产生原因:
softmax强制要求每个token对所有历史tokens的注意力权重(softmax后的注意力分数)之和为1,但有时候当前token和所有历史token的关联性都比较低(也即它实际上并不想关注任何token,或者说它不需要历史tokens给它额外信息来更新它的representation(attention操作的本质就是在当前层通过交互其他的tokens的信息来更新当前token的表征向量)),则其会把“多余”的注意力分配到某些固定位置。而由于start token <bos>对于所有后续token都可见,因此它就变成了大家共用的“注意力垃圾桶”,会起到吸收多余注意力的作用。
本质上,是大模型在学习“不关注任何token”的能力
attention sink的存在使得输入序列的微小扰动(如替换单词)对其他token表示的影响较小,通过固定分配部分注意力到无关的token(如<bos>)来减少其他token间的过度信息混合,起到了稳定作用(尤其在长文本下)。
另外,如果从位置编码的角度考虑,则开头几个tokens可以作为绝对位置的“锚点”。相对位置编码原则上只能识别相对位置,但有些任务可能比较依赖绝对位置,则可以通过开头几个绝对位置约为0的token作为标杆,使得每个token在某种程度上也能识别出它们自己所在的绝对位置。
实践策略:在sparse attention(kvcache压缩)中显式地确保<bos>总能得到关注,如streamingllm中的做法。这样可以确保在超长文本上也不至于出现token间过度信息混合的问题。
1.2 Outliers的产生机制
指模型规模大到一定程度的情况下,某些权重或激活中存在少量数值非常大的元素的现象。例如ffn中的down_proj层的输出activation就特别容易产生outliers。
它们通常出现在某些固定通道中,activation中的outliers还会集中存在于某些tokens中(例如标点符号等语义不太明显的token)。
产生原因:
同样和softmax的数学特性有关。实验发现当activation的outliers集中于某些标点符号的representation中时,它们也会在注意力机制中吸收大量的注意力分数。因此,这些outlier的产生本质上还是因为很多tokens并不想和其他tokens有那么多的信息交互(或者说不想在attention层太剧烈地更新自己的representation),导致在标点符号等语义不明显的token上产生了巨大的注意力分数,进一步导致了这些标点符号对应的activation中产生了outliers。
2. 如何避免Attention Sink和Outliers的出现
若想从根源上解决outlier/sink的出现,可以采用如下方法。本质上就是显式地让attention具有“不关注任何token”的能力。实验证明它们一般能提升预训练模型的效果:
(1)可以在训练时采用off-by-one softmax:
在softmax的分母上+1:
\[\frac{\exp(x_i)}{\sum_j\exp(x_j)}\] \[\Downarrow\] \[\frac{\exp(x_i)}{ \textcolor{red}{1+} \sum_j\exp(x_j)}\]这样就使得总的注意力权重之和不一定为1,如果需要的话可以对任何token的注意力权重都接近0。
(2)可以通过在预训练阶段加入专门的占位符作为算softmax时吸收注意力的sink,来显式地使得attention具有“pay no attention”的能力。
例如gpt-oss中采用的方法:
class Attention: def __init__(self, config, ...): self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) # ... def forward(self, x): # get QKV, reshape ... # S = self.sinks.reshape(...) QK = Q @ K.T # QK.shape = [.., .., seq_len, seq_len], S.shape = [..., ..., seq_len, 1] # 也即,把sink作为一个“token”加入到key的末尾,但其“注意力分数”并不是由Q @ S.T算出来的,而是训出来的一个固定偏置值(可见它是直接加到Q @ K.T这个算出来的attention score上的) QK = torch.cat([QK, S], dim=-1) # 算softmax时会引入sink的影响 W = torch.softmax(QK, dim=-1) # 算完softmax后,再把sink的输出结果扔掉 W = W[..., :-1] attn = W @ V可见,在算softmax时,会直接加入一个学好的固定偏置值,这样就使得q对于“真实的tokens”的keys的注意力分布之和不必为1。
(3)可以通过gating机制来抑制attention sink的产生
在Qwen3-Next gated attention中提出,通过尝试在$W_Q,W_K,W_V,W_O$前后加入gating,从而带来input dependent sparsity,可以过滤掉不重要的信息。其能够抑制sink可能是因为筛掉了不重要的tokens,留下的都是重要的tokens,所以attention不太需要再把多余注意力扔到sink处,非常有利于训练稳定性、长文本等。
