LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels

Damien Scieur
2026.03.26
· arXiv · by 이호민/AI
#Computer Vision#Deep Learning#JEPA#Representation Learning#World Model

Key Points

  1. LeWorldModel (LeWM) introduces a stable, end-to-end Joint Embedding Predictive Architecture (JEPA) that learns world models from raw pixels using only two loss terms: next-embedding prediction and a Gaussian-distributed latent embedding regularizer.
  2. This approach simplifies training by reducing tunable hyperparameters from six to one, enabling efficient learning of compact models on a single GPU and significantly faster planning (up to 48x) compared to foundation-model-based alternatives.
  3. LeWM demonstrates strong competitive performance across diverse 2D and 3D control tasks, while its latent space encodes meaningful physical structure, as confirmed by probing physical quantities and detecting unphysical events.

LeWorldModel (LeWM) introduces a novel Joint Embedding Predictive Architecture (JEPA) designed to learn world models stably and end-to-end from raw pixel observations. Addressing the common challenges of representation collapse, reliance on complex multi-term losses, and auxiliary supervision in existing JEPA methods, LeWM proposes a streamlined approach with only two loss terms and a single primary tunable hyperparameter. This design facilitates stable training on a single GPU within hours, enabling significantly faster planning compared to foundation-model-based world models while maintaining competitive performance across various 2D and 3D control tasks.

The core methodology of LeWM revolves around jointly optimizing an encoder and a predictor using a simplified objective. The system operates in a fully offline, reward-free setting, learning from unannotated trajectories of pixel observations $o_{1:T}$ and corresponding actions $a_{1:T}$.

Model Architecture:
LeWM consists of two main components:

  1. Encoder ($\text{enc}_\theta$): This maps a raw pixel observation $o_t$ into a compact, low-dimensional latent representation $z_t \in \mathbb{R}^d$. It is implemented as a Vision Transformer (ViT), specifically the tiny configuration (e.g., 5M parameters). The latent embedding $z_t$ is derived from the [CLS] token embedding of the ViT's last layer, followed by a 1-layer MLP with Batch Normalization. This projection step is crucial to ensure the latent representation space is amenable to the subsequent anti-collapse regularization.
  2. Predictor ($\text{pred}_\phi$): This models the environment dynamics within the latent space. It predicts the latent representation of the next frame, $\hat{z}_{t+1}$, given the current latent embedding $z_t$ and action $a_t$. The predictor is a Transformer (e.g., 10M parameters) that incorporates actions via Adaptive Layer Normalization (AdaLN) at each layer. It takes a history of $N$ frame representations and predicts the next frame autoregressively with temporal causal masking. A projector network, similar to the encoder's, is applied after the predictor's output.
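To make the two-component interface concrete, here is a shape-level numpy sketch. The linear maps below are hypothetical stand-ins for the ViT encoder and Transformer predictor (only the input/output shapes match the description above; the latent dimension, action dimension, and image size are illustrative assumptions):

```python
import numpy as np

# Hypothetical stand-ins for enc_theta and pred_phi: only shapes and the
# encode/predict interface mirror LeWM; the real components are a ViT and
# an AdaLN Transformer.
D_LATENT, D_ACTION = 16, 2
rng = np.random.default_rng(0)
W_ENC = rng.normal(size=(64 * 64 * 3, D_LATENT)) * 0.01  # stand-in for the encoder
W_Z = rng.normal(size=(D_LATENT, D_LATENT)) * 0.1        # stand-in predictor weights
W_A = rng.normal(size=(D_ACTION, D_LATENT)) * 0.1

def encode(obs: np.ndarray) -> np.ndarray:
    """enc_theta: (H, W, C) pixel frame -> (d,) latent embedding."""
    return obs.reshape(-1) @ W_ENC

def predict(z: np.ndarray, a: np.ndarray) -> np.ndarray:
    """pred_phi: current latent z_t and action a_t -> predicted next latent."""
    return np.tanh(z @ W_Z + a @ W_A)

obs = rng.uniform(size=(64, 64, 3))           # one raw pixel observation o_t
z_t = encode(obs)                             # z_t in R^d
z_next_hat = predict(z_t, np.array([0.5, -0.5]))
print(z_t.shape, z_next_hat.shape)            # prints (16,) (16,)
```

Everything downstream (the loss and the planner) only ever touches these two functions, which is what makes the latent-space rollouts cheap relative to pixel-space prediction.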

Training Objective:
The LeWM training objective is defined as the sum of a prediction loss and an anti-collapse regularization term:
$$L_{\text{LeWM}} \triangleq L_{\text{pred}} + \lambda \, \text{SIGReg}(Z)$$

  1. Prediction Loss ($L_{\text{pred}}$): This is a mean-squared error (MSE) between the predicted next latent embedding $\hat{z}_{t+1}$ and the true next latent embedding $z_{t+1}$. This term incentivizes the encoder to learn representations that are predictable by the predictor. The predictor operates in a teacher-forcing manner.
$$L_{\text{pred}} \triangleq \left\| \hat{z}_{t+1} - z_{t+1} \right\|_2^2, \quad \text{where } \hat{z}_{t+1} = \text{pred}_\phi(z_t, a_t)$$
  2. Sketched-Isotropic-Gaussian Regularizer (SIGReg): This term is critical for preventing representation collapse, a common failure mode where models map all inputs to a trivial, constant representation. SIGReg encourages the distribution of latent embeddings $Z$ (collected over history length $N$, batch size $B$, and embedding dimension $d$) to match an isotropic Gaussian target distribution. It leverages the Cramér–Wold theorem by enforcing normality along multiple random one-dimensional projections.
Specifically, it projects the latent embeddings $Z \in \mathbb{R}^{N \times B \times d}$ onto $M$ random unit-norm directions $u^{(m)} \in S^{d-1}$, yielding one-dimensional projections $h^{(m)} = Z u^{(m)}$. An Epps–Pulley normality test statistic $T(\cdot)$ is then applied to each projection, and these statistics are aggregated:
$$\text{SIGReg}(Z) \triangleq \frac{1}{M} \sum_{m=1}^{M} T\big(h^{(m)}\big)$$
The number of random projections $M$ typically has a negligible impact on performance, making the regularization weight $\lambda$ the only effective hyperparameter to tune. LeWM trains end-to-end without stop-gradients, exponential moving averages, or other heuristic stabilization tricks, allowing gradients to propagate through all components.
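The two-term objective can be sketched in numpy under some stated assumptions: latents are flattened to a (samples, d) batch, the Epps–Pulley statistic uses its closed form with a standard-normal weight on the characteristic-function gap, and projections are tested directly against N(0, 1) so that a collapsed (constant) batch is penalized. The exact weight function and batching in the paper may differ:

```python
import numpy as np

def epps_pulley(y: np.ndarray) -> float:
    """Closed-form Epps-Pulley statistic testing a 1-D sample against N(0, 1),
    assuming a standard-normal weighting of the characteristic-function gap."""
    n = y.shape[0]
    diff = y[:, None] - y[None, :]
    return float(np.exp(-diff ** 2 / 2).sum() / n
                 - np.sqrt(2.0) * np.exp(-y ** 2 / 4).sum()
                 + n / np.sqrt(3.0))

def sigreg(z: np.ndarray, num_proj: int = 32, seed: int = 0) -> float:
    """SIGReg(Z): mean Epps-Pulley statistic over M random unit projections."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(z.shape[1], num_proj))
    u /= np.linalg.norm(u, axis=0, keepdims=True)  # u^(m) on the unit sphere
    h = z @ u                                      # 1-D projections h^(m)
    return float(np.mean([epps_pulley(h[:, m]) for m in range(num_proj)]))

def pred_loss(z_hat: np.ndarray, z_next: np.ndarray) -> float:
    """L_pred: MSE between predicted and true next latents (teacher-forced)."""
    return float(np.mean(np.sum((z_hat - z_next) ** 2, axis=-1)))

def lewm_loss(z_hat, z_next, z_all, lam=0.5):
    """L_LeWM = L_pred + lambda * SIGReg(Z); lambda is the main tunable knob."""
    return pred_loss(z_hat, z_next) + lam * sigreg(z_all)

rng = np.random.default_rng(0)
z_next = rng.normal(size=(256, 16))                # "true" next latents (toy)
z_hat = z_next + 0.1 * rng.normal(size=(256, 16))  # predictor outputs (toy)
healthy = rng.normal(size=(256, 16))               # Gaussian-distributed batch
collapsed = np.zeros((256, 16))                    # every input mapped to zero
print(lewm_loss(z_hat, z_next, healthy))           # small: fits and no collapse
print(lewm_loss(z_hat, z_next, collapsed))         # large: SIGReg flags collapse
```

Note that the collapsed batch is indistinguishable from the healthy one under $L_{\text{pred}}$ alone (a constant latent is trivially predictable), which is exactly why the second term is needed.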

Latent Planning:
At inference time, LeWM enables trajectory optimization using Model Predictive Control (MPC) in its learned latent space. Given an initial observation $o_1$ and a goal observation $o_g$:

  1. The initial and goal observations are encoded into latent embeddings: $z_1 = \text{enc}_\theta(o_1)$ and $z_g = \text{enc}_\theta(o_g)$.
  2. Candidate action sequences $a_{1:H}$ are iteratively optimized. For a given action sequence, the predictor autoregressively rolls out future latent states: $\hat{z}_{t+1} = \text{pred}_\phi(\hat{z}_t, a_t)$, with $\hat{z}_1 = z_1$.
  3. A terminal latent goal-matching objective is minimized:
$$C(\hat{z}_H) = \left\| \hat{z}_H - z_g \right\|_2^2$$
  4. The optimal action sequence $a^*_{1:H}$ is found by solving:
$$a^*_{1:H} = \arg \min_{a_{1:H}} C(\hat{z}_H)$$
This optimization is performed using the Cross-Entropy Method (CEM). To mitigate prediction error accumulation over long horizons, an MPC strategy is employed where only the first $K$ planned actions are executed before replanning from an updated observation.
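The planning loop above can be sketched with CEM in a toy latent space. Here the learned predictor is replaced by a known linear map (an assumption for illustration only), and the horizon, dimensions, and population sizes are arbitrary choices, not the paper's settings:

```python
import numpy as np

# CEM planning sketch: sample action sequences, roll them out in the latent
# space, refit a Gaussian over the elite sequences, repeat. A linear map
# stands in for pred_phi.
rng = np.random.default_rng(0)
D, H = 4, 5                    # latent dimension, planning horizon
A = np.eye(D) * 0.9            # stand-in dynamics: z' = A z + B a
B = np.eye(D)                  # actions act directly on the latent (toy choice)

def rollout(z1: np.ndarray, actions: np.ndarray) -> np.ndarray:
    """Autoregressive latent rollout: z_hat_{t+1} = A z_hat_t + B a_t."""
    z = z1
    for a in actions:
        z = A @ z + B @ a
    return z

def cem_plan(z1, z_goal, iters=20, pop=256, elite=32):
    """Minimize the terminal cost C(z_hat_H) = ||z_hat_H - z_goal||^2 via CEM."""
    mu, sigma = np.zeros((H, D)), np.ones((H, D))
    for _ in range(iters):
        cand = mu + sigma * rng.normal(size=(pop, H, D))  # sample sequences
        costs = [np.sum((rollout(z1, c) - z_goal) ** 2) for c in cand]
        best = cand[np.argsort(costs)[:elite]]            # keep the elites
        mu, sigma = best.mean(axis=0), best.std(axis=0) + 1e-6
    return mu                    # refined mean action sequence a*_{1:H}

z1, z_goal = rng.normal(size=D), rng.normal(size=D)
plan = cem_plan(z1, z_goal)
final = rollout(z1, plan)
print(np.sum((final - z_goal) ** 2))  # terminal goal-matching cost C(z_hat_H)
```

In the actual MPC loop, only the first $K$ actions of `plan` would be executed before re-encoding the new observation and replanning, which limits how far rollout errors can compound.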

Results and Contributions:
LeWM achieves strong control performance across diverse 2D and 3D tasks, outperforming existing end-to-end JEPA approaches like PLDM and remaining competitive with foundation-model-based methods such as DINO-WM at a substantially lower computational cost. Notably, LeWM demonstrates up to 48x faster planning times. The method's stability is highlighted by its smooth training curves and robustness to various architectural and hyperparameter choices, making it significantly easier to tune compared to prior methods. Beyond control, LeWM's latent space is shown to encode meaningful physical structure, as evidenced by successful probing of physical quantities and reliable detection of physically implausible events.