Sub-JEPA:面向稳定端到端隐变量世界模型的子空间高斯正则化
💡 轻量级实验友好:本项目全部实验在单张显存 24 GB 以上的 GPU(如 RTX 3090 / 4090)上完成,无需多卡或大规模集群。




摘要
我们提出 Sub-JEPA,一种将高斯正则化从全局嵌入空间迁移到低维随机正交子空间的隐变量世界模型。 控制任务的隐变量表示通常分布在低维流形上;在高维全局空间中施加各向同性高斯先验会引入与任务内在几何不匹配的过强偏置。 Sub-JEPA 将全局先验替换为独立的子空间级高斯约束,在放松偏置的同时保留了 LeWorldModel(LeWM) 的防坍缩保证。 在四个连续控制基准上,Sub-JEPA 持续超越 LeWM,且性能提升与有效秩的降低直接相关。
方法
隐变量世界模型由编码器

图 1. Sub-JEPA 通过共享编码器将连续观测编码为隐变量, 并以预测损失训练预测器。虚线以下为关键新增部分:子空间高斯正则化损失—— 隐变量被投影到 K 个冻结的随机正交子空间, 每个子空间独立地被正则化至标准高斯分布。
正交子空间投影。
LeWM 在全量
多子空间高斯正则化。
在每个子空间内,沿随机一维投影方向对子空间嵌入施加 Epps–Pulley 正态性检验,与 LeWM 保持一致。
正则化损失在所有
总目标函数将隐变量预测与子空间正则化结合:
冻结投影矩阵可防止编码器与正则化器之间的协同适应,确保训练过程中高斯约束的一致性。
实验结果
Sub-JEPA 在四个从原始 RGB 观测训练的连续控制基准上进行评估, 与 LeWM、PLDM 和 DINO-WM 进行比较(六个随机种子,均值 ± 标准差)。
| 方法 | Two-Room | Reacher | PushT | OGB-Cube |
|---|---|---|---|---|
| PLDM | 97.00 | 78.00 | 78.00 | 65.00 |
| DINO-WM (w/o proprio.) | 100.00 | 79.00 | 74.00 | 86.00 |
| LeWM | 84.33±4.23 | 82.67±4.42 | 84.67±6.53 | 67.33±5.01 |
| Sub-JEPA(本文) | 95.00±2.76 | 84.00±4.00 | 89.00±5.33 | 76.33±5.99 |

图 2. 四个环境中的有效秩(上)和规划成功率(下)。 Sub-JEPA 相对 LeWM 有效秩降低最大的环境,也呈现出最大的性能提升, 表明子空间正则化通过抑制隐空间中的冗余高秩变化来改善规划。
最大的性能提升出现在 Two-Room(+10.7%)和 OGB-Cube(+9.0%), 这两个环境同时也表现出最大的有效秩降低。 这一直接对应关系支持了我们的假设:当任务动态分布在低维流形上时, 全局高斯先验会迫使隐变量保持不必要的高秩。 子空间正则化使隐变量几何结构向任务的内在维度收缩。

图 3. Two-Room 上的成功率随子空间数 K 和子空间维度 ds 的变化曲面。 水平参考面为 LeWM 基线。Sub-JEPA 在大范围中间参数配置下均超越基线, 说明方法对超参数选择具有较强鲁棒性。
为验证 Sub-JEPA 所学表示更符合任务几何结构,我们在 Two-Room 上可视化隐变量轨迹。 该环境内在维度较低,使得全局高斯先验的不匹配最为显著。 连续观测被编码后,其 [CLS] 嵌入经 UMAP 投影至二维,并按归一化时间索引着色。 Sub-JEPA 在各个 episode 中均产生时序连贯的路径,而 LeWM 的时序结构则较为混乱—— 说明当任务动态本质上是低维时,全局先验会扭曲隐变量几何结构。

图 4. Two-Room 上隐变量轨迹的 UMAP 投影, 按归一化时间索引着色。Sub-JEPA 产生有序、时序连贯的路径; LeWM 的时序结构则较为混乱,表明全局高斯先验在任务动态本质上低维时会扭曲隐变量几何结构。
延续 LeWM 的分析思路,我们通过时序路径直线度来考察隐变量轨迹的几何特性—— 即隐空间中动态演化的线性程度,以相邻时序速度向量的平均余弦相似度衡量。 对于隐变量世界模型而言,更直的轨迹意味着更规则的 rollout,对规划质量尤为重要。 如图 6 所示,Sub-JEPA 在 PushT 和 OGB-Cube 上均产生比 LeWM 更直的轨迹,且无需任何显式优化目标。 这表明子空间正则化相比全局正则化能减少隐空间中的几何畸变, 为表 1 中的规划性能提升提供了几何层面的解释。

图 6. PushT 和 OGB-Cube 上的隐变量轨迹直线度 (相邻时序速度向量的平均余弦相似度)。Sub-JEPA 在无任何显式直线度优化目标的情况下 持续优于 LeWM,表明子空间正则化能减少隐空间中的几何畸变。
结论
Sub-JEPA 是对 LeWM 的最小化改动:将全局高斯正则化替换为