
LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
Key Points
- LeWorldModel (LeWM) introduces a stable, end-to-end Joint Embedding Predictive Architecture (JEPA) that learns world models from raw pixels using only two loss terms: next-embedding prediction and a Gaussian-distributed latent embedding regularizer.
- This approach simplifies training by reducing tunable hyperparameters from six to one, enabling efficient learning of compact models on a single GPU and significantly faster planning (up to 48x) compared to foundation-model-based alternatives.
- LeWM demonstrates strong competitive performance across diverse 2D and 3D control tasks, while its latent space encodes meaningful physical structure, as confirmed by probing physical quantities and detecting unphysical events.
LeWorldModel (LeWM) introduces a novel Joint Embedding Predictive Architecture (JEPA) designed to learn world models stably and end-to-end from raw pixel observations. Addressing the common challenges of representation collapse, reliance on complex multi-term losses, and auxiliary supervision in existing JEPA methods, LeWM proposes a streamlined approach with only two loss terms and a single primary tunable hyperparameter. This design facilitates stable training on a single GPU within hours, enabling significantly faster planning compared to foundation-model-based world models while maintaining competitive performance across various 2D and 3D control tasks.
The core methodology of LeWM revolves around jointly optimizing an encoder and a predictor using a simplified objective. The system operates entirely in a fully offline and reward-free setting, learning from unannotated trajectories of pixel observations o_t and corresponding actions a_t.
Model Architecture:
LeWM consists of two main components:
- Encoder (E): This maps a raw pixel observation o_t into a compact, low-dimensional latent representation z_t = E(o_t). It is implemented as a Vision Transformer (ViT), specifically the tiny configuration (e.g., 5M parameters). The latent embedding is derived from the [CLS] token embedding of the ViT's last layer, followed by a 1-layer MLP with Batch Normalization. This projection step is crucial to ensure the latent representation space is amenable to the subsequent anti-collapse regularization.
- Predictor (P): This models the environment dynamics within the latent space. It predicts the latent representation of the next frame, ẑ_{t+1} = P(z_t, a_t), given the current latent embedding z_t and action a_t. The predictor is a Transformer (e.g., 10M parameters) that incorporates actions via Adaptive Layer Normalization (AdaLN) at each layer. It takes a history of frame representations and predicts the next frame autoregressively with temporal causal masking. A projector network, similar to the encoder's, is applied after the predictor's output.
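As an illustration of the action-conditioning mechanism described above, here is a minimal NumPy sketch of AdaLN-style modulation inside a predictor layer; all dimensions, weight matrices, and names are illustrative assumptions, not taken from LeWM:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize features to zero mean / unit variance along the last axis."""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def adaln(h, action_emb, W_scale, W_shift):
    """Adaptive LayerNorm: scale/shift the normalized features with
    per-sample parameters regressed from the action embedding."""
    scale = action_emb @ W_scale   # (B, D) modulation scale
    shift = action_emb @ W_shift   # (B, D) modulation shift
    return layer_norm(h) * (1.0 + scale) + shift

rng = np.random.default_rng(0)
B, D, A = 4, 32, 8                   # batch, latent dim, action-embedding dim (illustrative)
h = rng.normal(size=(B, D))          # token features inside a predictor layer
a = rng.normal(size=(B, A))          # embedded action
W_scale = rng.normal(size=(A, D)) * 0.01
W_shift = rng.normal(size=(A, D)) * 0.01
out = adaln(h, a, W_scale, W_shift)
print(out.shape)  # (4, 32)
```

The `1.0 + scale` convention keeps the layer close to a plain LayerNorm when the modulation weights are small, a common initialization choice for AdaLN-conditioned Transformers.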
Training Objective:
The LeWM training objective is defined as the sum of a prediction loss and an anti-collapse regularization term:
- Prediction Loss (L_pred): This is a mean-squared error (MSE) between the predicted next latent embedding ẑ_{t+1} and the true next latent embedding z_{t+1} = E(o_{t+1}). This term incentivizes the encoder to learn representations that are predictable by the predictor. The predictor operates in a teacher-forcing manner.
- Sketched-Isotropic-Gaussian Regularizer (SIGReg): This term is critical for preventing representation collapse, a common failure mode where models map all inputs to a trivial, constant representation. SIGReg encourages the distribution of latent embeddings (collected across the history, batch, and embedding dimensions) to match an isotropic Gaussian target distribution. It leverages the Cramér–Wold theorem by enforcing normality along multiple random one-dimensional projections.
The number of random projections typically has a negligible impact on performance, making the regularization weight the only effective hyperparameter to tune. LeWM trains end-to-end without stop-gradients, exponential moving averages, or other heuristic stabilization tricks, allowing gradients to propagate through all components.
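The two-term objective can be sketched in NumPy as follows. The exact per-projection normality statistic used by SIGReg is not specified here, so this sketch substitutes a simple first/second-moment-matching penalty along random unit projections (an assumption for illustration, not the paper's statistic):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigreg(z, num_proj=64):
    """Stand-in SIGReg: penalize deviation of random 1D projections of z
    from a standard Gaussian, here via mean/variance matching only."""
    D = z.shape[-1]
    dirs = rng.normal(size=(D, num_proj))
    dirs /= np.linalg.norm(dirs, axis=0, keepdims=True)  # unit projection directions
    proj = z.reshape(-1, D) @ dirs                       # (N, num_proj) projected samples
    mean_pen = (proj.mean(axis=0) ** 2).mean()           # target mean 0
    var_pen = ((proj.var(axis=0) - 1.0) ** 2).mean()     # target variance 1
    return mean_pen + var_pen

def lewm_loss(z_pred, z_true, z_all, lam=1.0):
    """Total objective: next-embedding MSE plus SIGReg on the embeddings,
    weighted by the single tunable hyperparameter lam."""
    pred = ((z_pred - z_true) ** 2).mean()
    return pred + lam * sigreg(z_all)

z_true = rng.normal(size=(8, 16))                 # toy "true" next embeddings
z_pred = z_true + 0.1 * rng.normal(size=(8, 16))  # toy predictions
loss = lewm_loss(z_pred, z_true, z_true)
print(float(loss))
```

Because both terms are differentiable in the encoder and predictor outputs, gradients flow through the whole pipeline, matching the no-stop-gradient training described above.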
Latent Planning:
At inference time, LeWM enables trajectory optimization using Model Predictive Control (MPC) in its learned latent space. Given an initial observation o_0 and a goal observation o_g:
- The initial and goal observations are encoded into latent embeddings: z_0 = E(o_0) and z_g = E(o_g).
- Candidate action sequences are iteratively optimized. For a given action sequence (a_0, …, a_{H−1}), the predictor autoregressively rolls out future latent states: ẑ_{t+1} = P(ẑ_t, a_t), with ẑ_0 = z_0.
- A terminal latent goal-matching objective is minimized: C(a_0, …, a_{H−1}) = ‖ẑ_H − z_g‖², where ẑ_H is the final predicted latent and z_g the encoded goal.
- The optimal action sequence is found by solving: (a*_0, …, a*_{H−1}) = argmin over (a_0, …, a_{H−1}) of C(a_0, …, a_{H−1}).
This optimization is performed using the Cross-Entropy Method (CEM). To mitigate prediction-error accumulation over long horizons, an MPC strategy is employed where only the first few planned actions are executed before replanning from an updated observation.
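The planning loop above can be sketched with a toy stand-in for the learned predictor; the linear latent dynamics, horizon, and CEM population sizes below are illustrative assumptions, not LeWM's actual settings:

```python
import numpy as np

rng = np.random.default_rng(0)
D, A, H = 8, 2, 5                    # latent dim, action dim, planning horizon (illustrative)
W_z = np.eye(D) * 0.9                # toy linear latent dynamics standing in
W_a = rng.normal(size=(A, D)) * 0.5  # for the learned predictor P

def rollout(z0, actions):
    """Autoregressively roll the (toy) predictor: z_{t+1} = f(z_t, a_t)."""
    z = z0
    for a in actions:
        z = z @ W_z + a @ W_a
    return z

def cem_plan(z0, z_goal, iters=5, pop=256, elite=32):
    """Cross-Entropy Method over action sequences,
    minimizing the terminal goal-matching cost ||z_H - z_goal||^2."""
    mu = np.zeros((H, A))
    sigma = np.ones((H, A))
    for _ in range(iters):
        cand = mu + sigma * rng.normal(size=(pop, H, A))          # sample candidates
        costs = np.array([np.sum((rollout(z0, c) - z_goal) ** 2) for c in cand])
        elites = cand[np.argsort(costs)[:elite]]                  # keep the best sequences
        mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mu  # planned action sequence

z0 = rng.normal(size=D)
z_goal = rng.normal(size=D)
plan = cem_plan(z0, z_goal)
final = rollout(z0, plan)
print(float(np.sum((final - z_goal) ** 2)))  # terminal goal-matching cost
```

In the MPC setting, only a short prefix of `plan` would be executed before re-encoding the new observation and calling `cem_plan` again.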
Results and Contributions:
LeWM achieves strong control performance across diverse 2D and 3D tasks, outperforming existing end-to-end JEPA approaches like PLDM and remaining competitive with foundation-model-based methods such as DINO-WM at a substantially lower computational cost. Notably, LeWM demonstrates up to 48x faster planning times. The method's stability is highlighted by its smooth training curves and robustness to various architectural and hyperparameter choices, making it significantly easier to tune compared to prior methods. Beyond control, LeWM's latent space is shown to encode meaningful physical structure, as evidenced by successful probing of physical quantities and reliable detection of physically implausible events.