
LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
핵심 포인트
- 1LeWorldModel (LeWM)은 raw pixel에서 stable하게 end-to-end로 훈련되는 최초의 JEPA로, 복잡한 multi-term loss 없이 next-embedding prediction loss와 SIGReg 두 가지 loss term만을 사용하여 collapse를 방지합니다.
- 2이 compact한 모델(15M 파라미터)은 single GPU에서 훈련 가능하며, DINO-WM보다 최대 48배 빠른 planning 속도를 보이면서도 diverse한 2D 및 3D control task에서 PLDM을 능가하거나 competitive한 성능을 달성합니다.
- 3LeWM의 latent space는 meaningful한 physical structure를 인코딩하며, physical quantity probing과 surprise evaluation을 통해 물리적으로 불가능한 이벤트를 신뢰성 있게 감지하는 능력을 보여줍니다.
LeWorldModel (LeWM)는 픽셀 기반 Joint Embedding Predictive Architecture (JEPA)에서 안정적인 End-to-End 학습을 가능하게 하는 새로운 프레임워크입니다. 기존 JEPA 방법론들은 표현 붕괴(representation collapse)를 피하기 위해 복잡한 다중 손실 항, Exponential Moving Average (EMA), 사전 학습된 인코더, 또는 보조 지도(auxiliary supervision)에 의존하여 불안정하거나 복잡했습니다. LeWM은 이러한 한계를 극복하며, 오직 두 가지 손실 항, 즉 다음 임베딩 예측 손실과 잠재 임베딩(latent embedding)이 가우시안 분포를 따르도록 강제하는 정규화(regularization)를 통해 안정적인 학습을 달성합니다. 이를 통해 조절해야 할 하이퍼파라미터의 수를 기존 End-to-End 대안 대비 6개에서 1개로 대폭 줄입니다.
핵심 방법론 (Core Methodology):
LeWM은 오프라인 및 보상 없는(reward-free) 환경에서 학습됩니다. 학습 데이터는 원시 픽셀 관측치 와 관련 행동 로 구성된 궤적(trajectories)입니다.
- 모델 아키텍처 (Model Architecture):
- 인코더 (Encoder): 주어진 프레임 관측치 를 압축된 저차원 잠재 표현 로 매핑합니다.
인코더는 Vision Transformer (ViT)로 구현됩니다. 최종 ViT 레이어의 [CLS] 토큰 임베딩을 사용하여 를 구성하며, 이후 Batch Normalization이 적용된 1-레이어 MLP를 통해 새로운 표현 공간으로 투영됩니다.
- 예측기 (Predictor): 잠재 공간에서 환경의 동역학(dynamics)을 모델링합니다. 현재 잠재 임베딩 와 행동 가 주어졌을 때 다음 프레임 관측치의 임베딩 을 예측합니다.
예측기는 Transformer 구조를 가지며, 행동은 Adaptive Layer Normalization (AdaLN)을 통해 각 레이어에 통합됩니다. 예측기는 개의 프레임 표현 이력을 입력받아 시간적 인과 마스킹(temporal causal masking)을 통해 다음 프레임 표현을 자기회귀적으로(auto-regressively) 예측합니다.
- 학습 목표 (Training Objective):
- 예측 손실 (Prediction Loss, ): Teacher-forcing 방식으로 구현되며, 연속적인 시간 단계의 예측된 임베딩과 실제 임베딩 간의 오차를 계산합니다. 이는 평균 제곱 오차(Mean Squared Error, MSE)를 사용합니다.
이 예측 손실은 인코더가 예측 가능한 표현을 학습하도록 유도합니다.
- Sketched-Isotropic-Gaussian Regularizer (SIGReg): 예측 손실만으로는 표현 붕괴(representation collapse)가 발생할 수 있습니다. 이를 방지하기 위해 SIGReg 정규화 항이 도입되어 임베딩 공간에서 특징 다양성(feature diversity)을 촉진합니다. SIGReg는 잠재 임베딩이 등방성 가우시안(isotropic Gaussian) 목표 분포와 일치하도록 장려합니다.
Cramér–Wold 정리 [39]에 따라, 모든 1차원 주변 분포(marginal distributions)를 일치시키는 것은 전체 결합 분포(full joint distribution)를 일치시키는 것과 동등합니다. SIGReg는 붕괴를 방지하며, 유일한 효과적인 하이퍼파라미터는 정규화 가중치 입니다 (일반적으로 는 고정).
LeWM은 stop-gradient, EMA 또는 추가적인 안정화 휴리스틱을 사용하지 않습니다. 모든 손실 구성 요소를 통해 그래디언트가 전파되며, 모든 파라미터는 End-to-End 방식으로 공동 최적화됩니다.
- 잠재 공간 계획 (Latent Planning):
는 예측기 를 사용하여 자기회귀적으로 계산됩니다: , .
이 유한-수평(finite-horizon) 최적 제어 문제는 Cross-Entropy Method (CEM) [40]를 사용하여 해결됩니다. Model Predictive Control (MPC) 전략이 적용되어 예측 오차 누적을 완화하고 실시간 제어에 근접한 빠른 계획 속도(DINO-WM 대비 48배)를 달성합니다.
결과 요약:
LeWM은 다양한 2D 및 3D 제어 태스크에서 강력한 성능을 보이며, 기존 End-to-End JEPA 접근 방식인 PLDM을 능가합니다. DINO-WM과 같은 파운데이션 모델 기반 World Model과도 경쟁력 있는 성능을 유지하면서 훨씬 낮은 계산 비용을 필요로 합니다. 또한, 잠재 공간에 대한 물리량 프로빙(probing)과 예외 탐지(violation-of-expectation test)를 통해 LeWM이 물리적 구조를 유의미하게 인코딩하고 물리적으로 불가능한 이벤트를 신뢰할 수 있게 감지함을 보여줍니다. SIGReg의 안정성과 단일 하이퍼파라미터 로의 간소화는 LeWM 학습 절차의 견고성을 크게 향상시킵니다.