LLM知识记忆:DeepSeek Engram解析
Published:
虽然MoE通过条件计算扩大了模型的容量,但Transformer本身并不具备原生的知识查表能力,其不得不通过计算来低效地模拟信息检索的过程。而语言建模任务本质上包含组合推理与知识检索两种任务,其中前者需要动态的计算,后者则是静态和刻板的(如命名实体和公式化的模式等),其可以使用N-gram这种无需昂贵计算就能有效捕捉局部依赖模式的方法进行知识检索,免去低级重复的推理。因此,DeepSeek提出一种称为Engram的条件记忆机制,作为和MoE互补的另一个稀疏维度。这种设计使得静态的知识可以通过Engram检索获得,无需引入昂贵计算,只需$O(1)$的查表操作,而动态推理则走MoE计算。
Engram
https://arxiv.org/pdf/2601.07372
MoE通过条件计算来提升模型容量,也即根据当前上下文内容来稀疏地激活特定的参数,来处理动态的逻辑;Engram则通过条件记忆来基于当前上下文内容稀疏地检索静态的embeddings,来获取固定的知识。具体而言,传统的N-gram将当前的局部上下文作为key,用来以$O(1)$的复杂度在一个巨大的embedding表中进行查找,Engram即为基于它设计的,可以和MoE形成互补。
Engram模块可以被插入到模型的某些层中,其可以检索静态的N-gram memory,并通过一个上下文感知的gating将其与本层隐含状态进行融合。
设输入序列为$X=(x_1,\cdots,x_T)$,在第$l$层的隐含状态为$H^{(l)}\in\mathbb R^{T\times d}$。对于每个输入位置$t$的token $x_t$来说,Engram模块会执行检索和融合两个操作:首先提取以其为后缀的N-gram信息(也即当前输入序列中该token所在的局部上下文),并通过hash操作将其映射为一个index后,从一个可学习的embedding表(memory table)中以$O(1)$复杂度检索回对应的embedding向量(其在训练后已经蕴含和这个局部上下文相关的知识信息),作为检索来的记忆embedding;然后,再以当前全局上下文(当前输入)的hidden states作为key来对检索来的memory做个attention来计算相关度分数,用来判断它和全局上下文真正的相关度,并以此作为加权系数将其融入到当前hidden states中,从而完成记忆的检索和融合。
查表检索来的记忆embedding本质上是作为一种bias来指导生成下一个token,它可以被看成一种“方向向量”,用来使得模型在早期层就“相当确信地知道”接下来要生成的“候选token”,后边的一系列深层则可以在此基础上将计算用在更深度的推理上,比如反复思考该“候选token”的合理性等,从而使得最终输出结果更准确、是更复杂推理后到结果。而如果没有Engram等话,则模型从浅到深绝大部分层的计算都浪费在了识别简单模式上,好不容易推出下一个候选token是什么的时候已经快到模型输出层了,模型已经不剩几层能用来对这个答案进行进一步深度校准思考了,导致输出的候选答案相对来说是“欠考虑的”,准确性上可能就打了折扣,而这是由于模型把过多的计算都浪费在了简单低级推理上所致。
本质上,Engram可以跳过本来需要的多层计算(模式识别等低级推理),在浅层直接让模型感知到最可能的下一个候选token是什么,并以这个状态作为“起始点”,把后边的大量深层计算都用在深度思考来refine这个结果上,而没有Engram的话则模型的大部分层计算都浪费在了识别简单模式上,在很深层时才能到达那个“起始点”,而此时留给复杂推理refine的计算层已经不多了,模型有限的计算容量很多都已浪费在了之前的低级推理上。
示例(简化版):
输入的prompt为"The capital of France",在模型等某一层中,对于"France"这个token来说,提取以其为后缀的2-gram和3-gram:"of France"、"capital of France",然后对每个n-gram短语做hash后,各自映射得到一个index,并根据index检索到memory table中对应的embedding向量,然后再将2-gram和3-gram的embedding向量拼接后,作为"France"局部上下文检索到的静态知识memory向量。
然后,以prompt在当前层的hidden state为query,以这个memory向量为key,计算一下这个memory向量和当前全局上下文的关联度,并得到一个相关度分数,用来确知该memory到底是否强适用于当前上下文,并以此作为融合权重来将memory向量融合到hidden state中。
这个memory向量可以看成一个语义模板,在训练过程中,它学习了“当见到
"capital of France"这个模式时接下来要强烈关联"is"、"Paris等token”的知识,在推理时浅层就查表得到并把它加到hidden state中,模型就会在浅层感知到下一个token很可能应该生成"is"、"Paris"等。而知道这个信息后,以它为起始点,接下来的那些层的计算就可以用于进行更深入的校验和推理,如”"is"的时态变化是否正确“、“"Paris"是否符合当前上下文的时代信息限制”等,从而使得最后一层输出的结果是深度考虑后的。Engram通过在早期层就把模型带到了一个很接近答案的状态,来使得模型尽可能把有限的计算容量都用在深度思考上。
而如果没有它的话,Transformer还要一层层推理来理解“首都”的含义、“首都”、“法国”这些东西的关联、有关“法国”的信息等,才能最终推理出“
"Paris"”这个答案,但当模型感知到这一点时已经快到输出层了,没什么层剩下用于进一步验证了,最终只能仓促输出这个结果,准确率就可能打折扣。而这就是模型浪费了太多计算在识别模式等低级推理上,导致后期没有足够计算被剩下用于校验推理所致。
具体实现细节如下:
首先,将局部上下文(n-gram窗口)通过hash映射为一个index,并用于在memory table中检索对应的memory embeddings:
Tokenizer Compression(分词压缩):
为了尽可能提升语义密度,将语义相同但形式不同的token进行合并,得到一个canonical id(来表示这一组语义相同的token统一后得到的单个token),如
"Apple"和"apple",这样能够将128K的词表压缩23%,减小了N-gram的组合空间。在数学上,定义一个词表映射层$\mathcal P$,其可以将每个位置$t$的token $x_t$映射为它对应的canonical id:
\[x_t'=\mathcal P(x_t)\]将以其为后缀的n-gram窗口内的所有tokens都映射为它们对应的canonical id,从而得到$x_t$的分词压缩后的n-gram窗口$g_{t,n}$:
\[g_{t,n}=(x_{t-n+1}',\cdots,x_t')\]为了充分捕捉不同粒度的局部信息,后面会结合多种$n$取值下的n-gram并融合它们检索到的记忆信息,本文默认取2-gram和3-gram。
Multi-head hashing(多头哈希):
为了尽可能避免哈希碰撞,这里对每个n-gram窗口都使用多个hash head进行映射,从而得到多个index。设总共使用$K$个哈希头,其中第$k$个头对于n-gram窗口的哈希映射函数为$\varphi_{n,k}$,其映射窗口$g_{t,n}$得到一个index值$z_{t,n,k}$:
\[z_{t,n,k}=\varphi_{n,k}(g_{t,n})\]每个$n$取值下的每个哈希头$k$都拥有一个memory table $\mathbf E_{n,k}$,根据index查表可得其对应的memory embedding向量$\mathbf e_{t,n,k}$:
\[\mathbf e_{t,n,k}=\mathbf E_{n,k}[z_{t,n,k}]\]
最终,将所有$n$取值下的所有哈希头的结果全都连接起来,得到为$x_t$检索到的总memory向量:
\[\mathbf e_t=\\|_{n=2}^N\\|_{k=1}^K \mathbf e_{t,n,k}\]接下来进行Context-aware Gating,来计算memory向量和当前全局上下文的相关度分数。具体而言,使用位置$t$的隐含状态向量$\mathbf h_t$作为动态query(它经过前面层的attention后已经聚合了全局信息),使用memory向量$\mathbf e_t$变换得到key,然后计算得到相关度分数$\alpha_t$:
\[\mathbf k_t=W_K\mathbf e_t\] \[\alpha_t=\sigma\left(\frac{\text{RMSNorm}(\mathbf h_t)^T\text{RMSNorm}(\mathbf{k}_t)}{\sqrt d}\right)\]然后用该相关度分数来缩放memory向量,从而调控其融入隐含状态后带来的影响大小,缩放后的memory向量表示为$\mathbf{\tilde{ v}}_t$:
\[\mathbf {\tilde v}_t=\alpha_t\cdot \mathbf e_t\]如此缩放后,可以使得当获取的记忆$\mathbf e_t$和当前上下文$\mathbf h_t$相悖时相关度分数变得很小,从而抑制memory向量带来的噪声的影响。当前序列所有缩放后的memory向量堆叠组成$\tilde{\mathbf V}=[\mathbf{\tilde{v}}_1;\cdots;\mathbf{\tilde{v}}_T]\in\mathbb R^{T\times d}$,即为整个序列获取的记忆张量。
最后,为了增大感受野(混合邻近的局部信息)并提升模型非线性,引入一个大小为$w=4$的1D卷积和一个SiLU激活函数,来将整个序列的记忆张量沿序列方向进行局部混合与非线性激活,最终得到记忆张量$\mathbf Y\in\mathbb R^{T\times d}$:
\[\mathbf Y=\text{SiLU}(\text{Conv1D}(\text{RMSNorm}(\tilde{\mathbf V})))+\tilde{\mathbf V}\]记忆张量最终被加到当前隐含状态上,从而得到融合了记忆后的隐含状态:
\[\mathbf H^{(l)}\leftarrow \mathbf H^{(l)}+\mathbf Y\]为了定量化描述MoE和Engram二者的协同作用,建模一个稀疏度分配问题:给定一个固定的总参数量限制,该如何最优地将容量分配给MoE专家和Engram记忆。实验表明,以MoE部分的分配比例为横轴,val loss为纵轴,呈现一个U形曲线。而若给Engram无限记忆空间,则loss会随其embedding数量增加而以对数趋势减小。
系统设计、与mHC结合等:留坑
