世界模型浅析:JEPA架构的基本思路
Published:
无论是LLM还是Diffusion,这些生成模型本质上都是在基于统计概率来拟合世界中的每一个token/pixel,但它们可能只是在统计意义上进行“续写”,而不是真正理解了世界运行的物理规律,例如“杯子被摔了会碎”、“有人碰杯子会碰倒”等。LeCun认为真正的智能并不是能预测下一个token/pixel,而是能理解世界如何演化(“如果我这样做,世界会发生什么”)。其提出的JEPA架构通过预测“抽象后的未来状态”而不是预测token/pixel本身,来使得模型学习高层次的世界演化规律,相当于一个“脑内物理引擎”。
另外,现有的生成模型试图直接拟合原始的token/pixel,但由于现实世界的复杂性,生成对象存在了大量的细节特征,例如“每根头发”、“光照噪声”等都存在了很大的不确定性,但事实上这些细节可能根本就不重要,例如想知道“推了杯子”的结果时无需预测每个水花像素的轨迹,只需预测“杯子会倾倒”、“杯子会摔碎”这些结果就好。
JEPA(Joint-Embedding Prediction Architecture,联合嵌入预测架构)是LeCun提出的世界模型表征学习架构。其核心思路在于让模型预测世界状态的抽象表示(latent embedding)。例如,一张图片映射为抽象表示向量后,其中包含的信息可能表示“有人”、“有车”、“桌上有杯子”等“世界状态”,而并不是包含了所有的原始像素。“Joint Embedding”指的是模型会同时学习如何表示世界和如何预测未来世界状态。
下面以I-JEPA为例来分析JEPA的基本思路。
其训练任务类似MAE,mask掉图片的一部分,希望模型基于剩下的其他部分作为上下文来恢复出mask的部分。但MAE在像素级进行的重建会过度关注纹理、局部细节、高频噪声等,而不是物体、语义、空间关系这种高层次的信息,因此I-JEPA预测的不是像素,而是mask部分的抽象表征。
设$x$为context信息(图片中剩余部分),$y$为target信息(图片中被mask的部分),$z$为其他辅助信息(例如target的位置坐标等),则:
生成式架构如上图中所示,其将$x$编码到隐空间后,结合$z$的辅助信息,通过decoder网络直接预测目标$\hat y$,并和原始target信息$y$算loss:
\[s_x=\text{x-encoder}(x)\]
JEPA架构如上图右所示,其将$x,y$均编码到隐空间,得到它们的抽象表示$s_x,s_y$,然后再由$s_x$结合$z$通过predictor网络预测目标的抽象表征$\hat s_y$,并让其和target信息的抽象表征$s_y$做loss:
\[s_x=\text{x-encoder}(x)\]
可见,JEPA中会同时学习“如何表征世界”(也即encoder网络)和“如何预测未来世界状态”(也即predictor网络),这就是“联合预测架构”。
以一张“狗在草地上”的图像为例,一些实现细节如下:
首先使用ViT将图片切成patch,然后随机选择一些mask的区域作为target。I-JEPA会遮住足够大的区域(语义级别的大区域),因为如果target太小的话模型可以靠纹理或局部连续性等预测出来,只有当其足够大时才能迫使模型必须理解了语义才能预测出来。在上图中,红黄蓝三个方框里的“耳朵”、“尾巴”、“草地”就是target区域,
然后构建context区域。I-JEPA并没有选取全部剩余区域,而是经过采样来得到空间分散的一些区域作为上下文,这样可以避免模型仅通过像素级的连续性就能完成补全,而是必须要理解这个物体。上图中“狗头”就是context区域,模型看到存在“狗头”就应该能推断出来还会出现“尾巴”、“耳朵”等,而不是在预测“狗毛的每个像素”。
这里还使用了额外信息$z$作为预测的辅助信息,它可以是target的位置、mask的位置等信息,让模型知道它看到的是哪部分区域,要预测的又是哪部分区域。在上图中对应的是predictor前的红黄蓝三个实心块,表示了三个目标区域的位置信息。
另外,模型中对于“如何表征世界”的编码器实际上包含了两个模块,分别是编码context和target的encoder。如果二者直接独立地一起训练的话可能导致表征坍塌,例如它们全都输出0来拉低loss,而学不到有用的信息。因此,I-JEPA通过EMA(指数移动平均)来让二者存在时间延迟,使得target encoder是context encoder的慢速移动平均:
\[\theta_{\text{target}}\leftarrow m\theta_{\text{target}}+(1-m)\theta_{\text{context}}\]这样可以一定程度上缓解collapse问题,避免两边很轻易地一起坍塌到输出0并获得极小loss,而是让二者存在一定的时间延迟:就算其中一个collapse了,另一个也不会立刻collapse到一样的状态,而是仍保持某个复杂状态,此时loss不会一下子被拉的很小,能够产生一个较大的梯度来推动collapse的一边跳出该区域。
总的来看,JEPA提供了一种学习高层次的世界知识的基础架构,训练后的模型可以产出理解了世界深层信息的抽象表征,可以进一步轻松地被用于图像理解、深度估计等下游任务。
