End-to-End Test-Time Training for Long Context

Marcel Rød
2026.01.31
arXiv · by 넀루
#LLM #Continual Learning #Test-Time Training #Transformer #Long Context

Key Points

  • This paper proposes TTT-E2E, an end-to-end Test-Time Training method that reframes long-context language modeling as a continual learning problem.
  • TTT-E2E operates by continually learning through next-token prediction at test time, and its initialization is meta-learned during training to optimize for this test-time adaptation.
  • This approach achieves performance scaling with context length comparable to full attention while maintaining constant inference latency, resulting in significant speed improvements for very long contexts.

The paper proposes "End-to-End Test-Time Training (TTT-E2E)" for long-context language modeling, framing it as a continual learning problem rather than an architectural one. The method utilizes a standard Transformer with sliding-window attention and continually updates its weights at test time via next-token prediction, effectively compressing context into its parameters. To enable this, the model's initialization is meta-learned during training.

Core Methodology: TTT-E2E

The core idea is to let the language model "learn on the job" by performing gradient updates on its weights using the incoming context during inference. This is achieved through a two-loop optimization process: an inner loop for test-time training and an outer loop for meta-learning the initial weights.

  1. Test-Time Training (Inner Loop):
At test time, given a sequence of tokens $\mathbf{x}_0, \mathbf{x}_1, \ldots, \mathbf{x}_T$, the model performs next-token prediction. Instead of simply predicting $\hat{p}_{t+1} = f(\mathbf{x}_t; W_0)$ with static weights $W_0$, the model iteratively updates its weights.

  • Loss Function: For each token $\mathbf{x}_t$ in the context, the model computes a standard next-token prediction loss:
$$\ell_t(W) = \mathrm{CE}(f(\mathbf{x}_{t-1}; W), \mathbf{x}_t)$$
where $f(\cdot; W)$ is the language model with weights $W$, and CE denotes cross-entropy.

  • Mini-Batch Gradient Descent: To improve efficiency and stability, instead of online gradient descent (where $b = 1$), the paper employs mini-batch gradient descent. Given a batch size $b$, weights are updated every $b$ tokens:
$$W_i = W_{i-1} - \eta \, \frac{1}{b} \sum_{k=(i-1)b+1}^{ib} \nabla \ell_k(W_{i-1})$$
where $W_i$ are the weights after processing the $i$-th mini-batch, and $\eta$ is the learning rate. The final prediction $\hat{p}_{T+1}$ for $\mathbf{x}_{T+1}$ is made using $W_{T/b}$.

  • Integration with Sliding-Window Attention: To keep the model from degenerating into a bigram model within each mini-batch (where only the first token would benefit from prior context), the base architecture uses sliding-window attention with window size $k$. It is crucial that $k \ge b$, i.e. the window is at least as large as the TTT mini-batch, so the model can attend to context within the mini-batch before the weights are updated.
  • Implementation Details for TTT:
    • Selective Parameter Update: Only the MLP (Multi-Layer Perceptron) layers are updated during TTT. Embedding layers, normalization layers, and attention layers are frozen, as updating them caused instability.
    • Partial Block Update: Only the MLP layers in the last 1/4 of the Transformer blocks are updated. This balances computational cost with the ability to compress context.
    • Static Second MLP: In the blocks that are updated, a second, static MLP layer is added to serve as a "safe" storage for pre-trained knowledge, preventing catastrophic forgetting. The total parameter count is kept consistent by reducing the hidden dimension of other MLPs.
    • Decoding Multiple Tokens: When generating multiple tokens, TTT updates are applied only after a full mini-batch of decoded tokens has been accumulated.
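The inner loop above, including the restriction to a subset of trainable parameters, can be sketched in plain Python. This is a toy illustration, not the paper's implementation: the scalar per-name weights and the `grad_loss` callback are hypothetical stand-ins for the Transformer's MLP parameters and for backprop through the cross-entropy loss.

```python
def ttt_inner_loop(params, trainable, tokens, grad_loss, b, eta):
    """Mini-batch test-time training over a token stream (toy sketch).

    params:    dict name -> weight (scalars here; tensors in reality)
    trainable: names updated at test time (e.g. only the MLPs in the
               last 1/4 of the blocks); all other parameters stay frozen
    grad_loss: grad_loss(params, x_prev, x_next) -> dict of per-name
               gradients for one next-token loss (illustrative stub)
    """
    w = dict(params)           # start from the meta-learned W_0
    T = len(tokens) - 1        # number of prediction pairs (x_{t-1}, x_t)
    assert T % b == 0, "assume T divisible by the mini-batch size b"
    for i in range(T // b):
        # average the per-token gradients over the i-th mini-batch
        avg = {n: 0.0 for n in trainable}
        for k in range(i * b, (i + 1) * b):
            g = grad_loss(w, tokens[k], tokens[k + 1])
            for n in trainable:
                avg[n] += g[n] / b
        for n in trainable:    # W_i = W_{i-1} - eta * mean gradient
            w[n] -= eta * avg[n]
    return w                   # W_{T/b}, used to predict x_{T+1}
```

With a toy squared-error gradient, only the names in `trainable` move; frozen entries (playing the role of embeddings, norms, and attention) are returned unchanged.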
  2. Meta-Learning (Outer Loop):
The crucial aspect of TTT-E2E is that the model's initial weights $W_0$ (from training time) are optimized for the *final performance after* test-time training, not just for static pre-training loss. This resolves the mismatch between training and test-time behavior found in prior dynamic evaluation approaches (TTT-naive).

  • End-to-End Training Objective: The training objective for $W_0$ is the average test loss over sequences after TTT has been applied. For mini-batch TTT, this loss is:
$$L(W_0; \mathbf{X}) = \frac{1}{T} \sum_{i=1}^{T/b} \sum_{k=(i-1)b+1}^{ib} \ell_k(W_{i-1})$$
  • Optimization: Optimizing $W_0$ with respect to $L(W_0; \mathbf{X})$ requires computing gradients of gradients, as the inner-loop updates themselves involve gradients. Modern automatic differentiation frameworks handle this efficiently. This outer-loop optimization trains the model to "learn to learn" effectively at test time.
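The "gradients of gradients" can be made concrete with a scalar toy problem: one inner TTT step on a squared-error token loss (a stand-in for the paper's cross-entropy), then the outer gradient computed by hand through that step. Everything here, the loss, the single weight, the one-step unroll, is illustrative, not the paper's actual setup.

```python
def meta_grad(w0, eta, x1, y1, x2, y2):
    """Outer-loop gradient dL/dw0 for a one-step inner loop on the
    toy token loss l(w) = 0.5 * (w*x - y)**2 (hypothetical stand-in)."""
    # inner TTT step on token 1:  w1 = w0 - eta * l1'(w0)
    g1 = (w0 * x1 - y1) * x1
    w1 = w0 - eta * g1
    # outer loss is the *post-update* loss on token 2:  L = l2(w1)
    g2 = (w1 * x2 - y2) * x2              # dL/dw1
    # chaining through the inner update needs the second derivative of
    # l1 (the "gradient of gradients"):  dw1/dw0 = 1 - eta * l1''(w0)
    return g2 * (1.0 - eta * x1 * x1)
```

An autodiff framework produces the same value by differentiating straight through the unrolled inner loop; the closed form above just makes the second-derivative term visible.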

Alternative Derivation (Connection to TTT-KVB)

The paper also presents an alternative derivation, starting from prior work on Key-Value Binding (KVB) in TTT (TTT-KVB). TTT-KVB updates implicit key-value associations using a layer-wise reconstruction loss
$$\ell^{(l)}(\mathbf{W}_{t-1}^{(l)}) = \left\| g(\theta_K^{(l)} \mathbf{x}_t^{(l)}; \mathbf{W}_{t-1}^{(l)}) - \theta_V^{(l)} \mathbf{x}_t^{(l)} \right\|^2.$$
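A minimal scalar sketch of one such inner update, assuming $g$ is linear ($g(\mathbf{k}; W) = W\mathbf{k}$) purely for clarity; TTT-KVB itself uses multi-head LoRA MLPs, and the scalars here are hypothetical stand-ins for per-layer tensors.

```python
def kvb_step(W, x, theta_K, theta_V, eta):
    """One TTT-KVB inner update (scalar sketch): the fast weight W is
    trained to map the key theta_K*x to the value theta_V*x under a
    squared reconstruction loss, with g taken to be linear, g(k) = W*k."""
    k, v = theta_K * x, theta_V * x
    grad = 2.0 * (W * k - v) * k   # d/dW of (W*k - v)**2
    return W - eta * grad
```

Note the contrast with the E2E loss above: this objective is defined per layer from $\theta_K, \theta_V$ projections, rather than from the network's final next-token prediction.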

The derivation proceeds as follows:

  1. Simplified Output Rule: The output generation rule is simplified by reusing the prediction of gg as the output embedding, rather than re-calling gg with updated weights.
  2. End-to-End at Test Time (Key Step): The critical transition is replacing the layer-wise KVB reconstruction loss with the global next-token prediction loss at the network's output. This removes the need for separate $\theta_K, \theta_V$ parameters per layer and makes the test-time training objective directly align with task performance. This intermediate step is called "TTT-E2E all layers MH".
  3. Larger State with Less Compute: The final step involves realizing that for the E2E loss, it's more cost-effective to update fewer blocks (e.g., last 1/4) each containing larger hidden states (regular MLPs instead of multi-head LoRA MLPs used in TTT-KVB). This improves computational efficiency and allows for a larger effective state, which is beneficial for scaling with context length.

Results

TTT-E2E demonstrates superior scaling with context length compared to methods like Mamba 2 and Gated DeltaNet, and it beats even the full-attention Transformer on latency. For 3B models, TTT-E2E maintains its performance advantage at context lengths up to 128K, mirroring the scaling of full attention in test loss. Crucially, TTT-E2E achieves constant inference latency regardless of context length (similar to RNNs and sliding-window attention), making it 2.7x faster than full attention at 128K context.