End-to-End Test-Time Training for Long Context
Paper

End-to-End Test-Time Training for Long Context

Marcel Rød
2026.01.31
·Arxiv·by 네루
#LLM#Continual Learning#Test-Time Training#Transformer#Long Context

핵심 포인트

  • 1본 논문은 long-context language modeling을 continual learning 문제로 재정의하고, standard Transformer와 Test-Time Training (TTT)을 결합하여 기존 아키텍처의 한계를 극복합니다.
  • 2제안하는 TTT-E2E는 next-token prediction을 통해 test-time에 모델이 계속 학습하게 하고, meta-learning으로 training-time에 TTT에 최적화된 초기화를 학습하여 End-to-End 방식을 구현합니다.
  • 3그 결과, TTT-E2E는 full attention Transformer와 유사하게 context 길이에 따라 성능이 확장되면서도, RNN처럼 context 길이에 무관하게 일정한 inference latency를 유지하여 128K context에서 2.7배 더 빠릅니다.

Long-context 언어 모델링은 Transformer의 self-attention이 컨텍스트 길이에 따라 계산 비용이 O(T2)O(T^2)로 증가하여 비효율적인 문제를 겪는다. 반면 Mamba 2나 Gated DeltaNet과 같은 RNN 계열 모델은 O(T)O(T)의 상수를 유지하지만 긴 컨텍스트에서 효과가 떨어진다. 본 논문은 이러한 long-context 언어 모델링을 아키텍처 설계 문제가 아닌 continual learning 문제로 재정의한다.

제안하는 방법은 Test-Time Training (TTT)의 한 형태로, 표준 Transformer 아키텍처와 sliding-window attention만을 사용하여 테스트 시간에 모델이 주어진 컨텍스트 상에서 다음 토큰 예측(next-token prediction)을 통해 학습을 지속한다. 이를 통해 모델은 읽는 컨텍스트를 가중치에 압축하여 저장한다. 또한, 테스트 시간 학습을 위한 모델의 초기화를 개선하기 위해 훈련 시간(training time)에 meta-learning을 적용한다. 이 전체적인 방법론을 End-to-End (E2E) TTT (TTT-E2E)라고 명명한다.

핵심 방법론 (TTT-E2E)

  1. TTT를 통한 컨텍스트 압축 (Next-Token Prediction):
테스트 시퀀스 X=(x1,,xT)X = (x_1, \dots, x_T)가 주어지면, 모델은 t=1,,Tt=1, \dots, T에 대해 xt1x_{t-1}으로부터 xtx_t를 예측하는 훈련 작업을 수행한다. 모델의 가중치 WW에 대해 시점 tt에서의 손실은 다음과 같다:
t(W)=CE(f(xt1;W),xt)\ell_t(W) = \text{CE}(f(x_{t-1}; W), x_t)
이 손실을 통해 가중치를 순차적으로 업데이트한다. 온라인 그라디언트 하강(online gradient descent)의 경우, 각 시점 tt마다 가중치 Wt1W_{t-1}를 사용하여 xt1x_{t-1}로부터 xtx_t를 예측하고, 손실 t(Wt1)\ell_t(W_{t-1})에 대한 그라디언트를 계산하여 가중치를 업데이트한다:
Wt=Wt1ηt(Wt1)W_t = W_{t-1} - \eta \nabla \ell_t(W_{t-1})
여기서 η\eta는 학습률이며, W0W_0는 테스트 시간 시작 시의 초기 가중치이다. 최종적으로 TT번째 토큰 xTx_T를 사용하여 다음 토큰을 예측한다: p^T+1=f(xT;WT)\hat{p}_{T+1} = f(x_T; W_T). 이 과정은 컨텍스트 정보를 모델의 가중치에 지속적으로 "압축"하여 저장하는 효과를 가져온다.

  1. Meta-Learning을 통한 E2E 훈련:
TTT의 효과를 극대화하기 위해, 훈련 시 모델의 초기 가중치 W0W_0가 테스트 시간 TTT 후의 성능에 최적화되도록 meta-learning을 사용한다. 기존의 'TTT-naive' 접근 방식은 W0W_0만을 사용하여 손실을 최소화했지만, TTT-E2E는 테스트 시 모델의 동작(가중치 업데이트)을 훈련 루프 내에서 모방한다. 훈련 시퀀스 XX에 대한 최종 손실은 TTT 과정에서 업데이트된 가중치로 계산된 토큰별 손실의 합이다:
L(W0;X)=1Tt=1Tt(Wt1)=1Tt=1TCE(f(xt1;Wt1),xt)L(W_0; X) = \frac{1}{T} \sum_{t=1}^{T} \ell_t(W_{t-1}) = \frac{1}{T} \sum_{t=1}^{T} \text{CE}(f(x_{t-1}; W_{t-1}), x_t)
L(W0;X)L(W_0; X)를 최소화하도록 W0W_0를 최적화한다. 이는 inner loop (테스트 시간 가중치 업데이트)와 outer loop (초기 가중치 W0W_0 최적화)로 구성된 meta-learning 프레임워크를 따른다. Outer loop에서는 inner loop의 그라디언트를 포함하는 '그라디언트의 그라디언트(gradients of gradients)'를 계산해야 한다.

  1. Mini-Batch TTT 및 Sliding Window Attention:
온라인 그라디언트 하강(b=1)은 효율성 및 안정성 문제를 야기할 수 있다. 이를 해결하기 위해 미니배치(mini-batch) 업데이트를 도입한다. 시퀀스를 크기 bb의 미니배치로 나누어 업데이트한다:
Wi=Wi1η1bt=(i1)b+1ibt(Wi1)W_i = W_{i-1} - \eta \frac{1}{b} \sum_{t=(i-1)b+1}^{ib} \nabla \ell_t(W_{i-1})
여기서 i=1,,T/bi=1, \dots, T/b 이다.
미니배치 TTT는 각 배치 내에서 모델이 이전 토큰에 대한 기억을 잃는 문제를 발생시킬 수 있다. 이를 보완하기 위해 모델 아키텍처에 sliding-window attention 레이어를 추가한다. 이는 TTT가 가중치를 업데이트하기 전에 모델이 각 미니배치 내의 컨텍스트를 기억할 수 있도록 보장한다. 예를 들어, kbk \ge b (window size \ge mini-batch size)로 설정된다.

  1. 구현 세부 사항 (Implementation Details):
    • MLP 레이어만 TTT: Transformer 블록 중 MLP 레이어만 TTT 중에 업데이트한다. 임베딩, 정규화, 어텐션 레이어는 고정된다. 이는 outer loop의 안정성을 높인다.
    • 일부 블록만 TTT: 전체 Transformer 블록의 1/4 (가장 마지막 블록들)만 TTT 중에 업데이트한다. 이는 계산 비용과 컨텍스트 스케일링 능력 사이의 균형을 맞춘다.
    • 두 개의 MLP 레이어: TTT 중 업데이트되는 블록에 정적(static)인 두 번째 MLP 레이어를 추가하여 사전 훈련된 지식의 손실을 방지한다. 비교를 위해 전체 네트워크의 MLP 히든 차원을 조정하여 총 파라미터 수는 동일하게 유지한다.

대안적 파생 (Alternative Derivation):
TTT-E2E는 이전 연구인 Key-Value Binding (KVB) 기반의 TTT (TTT-KVB)에서 파생될 수 있다. TTT-KVB는 각 레이어에서 key-value 쌍을 재구성하는 손실을 최적화하여 컨텍스트를 압축한다. TTT-E2E는 이 레이어별 재구성 손실을 전체 네트워크의 표준 다음 토큰 예측 손실로 대체하여 테스트 시 E2E 특성을 확보한다. 이 변경은 TTT-KVB의 추가적인 outer-loop 파라미터(θK,θV\theta_K, \theta_V)를 제거한다. 또한, TTT-E2E는 TTT-KVB의 다중 헤드(multi-head) MLP 업데이트 방식 대신 일반 MLP를 사용하며, 모든 레이어를 업데이트하는 대신 마지막 1/4 블록만 업데이트하여 더 적은 계산으로 더 큰 유효 상태(effective state)를 가진다.

결과:
TTT-E2E는 128K 컨텍스트 길이에서 Transformer with full attention과 유사한 성능을 유지하면서도, Mamba 2나 Gated DeltaNet과 달리 컨텍스트 길이에 따라 손실이 크게 증가하지 않음을 보여준다. 또한, RNN과 유사하게 컨텍스트 길이에 관계없이 일정한 추론 지연 시간(constant inference latency)을 가지며, 128K 컨텍스트에서 full attention보다 2.7배 빠르다. 3B 모델 기준으로, TTT-E2E는 SWA나 RNN 계열 모델보다 훨씬 나은 손실 성능을 보이며, full attention에 견줄 만한 성능을 보여준다.