STEM: Scaling Transformers with Embedding Modules
Paper

STEM: Scaling Transformers with Embedding Modules

Attiano Purpura-Pontoniere
2026.01.22
·Arxiv·by 이호민
#Transformer#Sparsity#LLM#Embedding#Efficiency

핵심 포인트

  • 1STEM은 Transformer 모델에서 FFN의 up-projection을 토큰별 layer-local embedding lookup으로 대체하여, 고정된(static) 희소성을 도입하고 CPU offload를 가능하게 함으로써 per-token compute 및 cross-device communication 부담을 줄이는 방법론입니다.
  • 2이 접근 방식은 극심한 희소성에도 불구하고 안정적인 학습을 제공하며, MoE와 같은 동적 희소성 방식의 단점(training instability, load balancing, communication overhead)을 해결합니다.
  • 3STEM은 dense baseline 대비 최대 3~4%의 정확도 향상과 함께 FLOPs 및 parameter access 감소를 제공하며, 토큰 기반의 특성으로 인해 향상된 지식 저장 능력, Interpretability 및 knowledge editing 기능을 보여줍니다.

STEM (Scaling Transformers with Embedding Modules)은 Fine-grained sparsity가 높은 parametric capacity를 제공하지만, training instability, load balancing, communication overhead 문제에 직면한다는 점을 해결하기 위해 제안된 static, token-indexed 방식의 Transformer architecture이다. MoE(Mixture-of-Experts)와 같은 기존의 sparsity 접근 방식은 런타임 라우팅, load balancing 문제, 통신 오버헤드, 그리고 낮은 interpretability를 특징으로 한다. STEM은 이러한 문제들을 해결하며, 특히 FFN(Feed-Forward Network)의 up-projection을 layer-local embedding lookup으로 대체하고, gate 및 down-projection은 dense하게 유지하는 방식으로 작동한다.

Core Methodology (핵심 방법론)

Transformer의 FFN은 Key-Value memory 관점에서 이해될 수 있다. 여기서 up-projection matrix WuW_u는 입력 xx를 addressing weight를 생성하는 "key"로 매핑하고, down-projection matrix WdW_d의 열들은 "value"로 기능한다. Gated linear units (GLUs)의 경우, gate projection WgW_g는 각 memory slot의 참여를 조절하는 추가적인 "key"를 제공하여 context-adaptive retrieval을 가능하게 한다.

STEM은 이러한 FFN의 memory view에 기반하여, FFN layer의 up-projection인 WuxW_u^\ell x^\ell 부분을 token-indexed embedding lookup U[t]U^\ell[t]로 대체한다. 여기서 tt는 현재 token의 vocabulary ID를 나타내고, URV×dffU^\ell \in \mathbb{R}^{V \times d_{ff}}는 해당 layer의 embedding table이다. STEM layer의 출력 yy^\ell은 다음과 같이 정의된다:
y=Wd(SiLU(Wgx)U[t])y^\ell = W_d^\ell \left( \text{SiLU}(W_g^\ell x^\ell) \odot U^\ell[t] \right)
여기서 \odot는 element-wise multiplication을 나타낸다.

STEM은 기존 PLE(Per Layer Embedding)와 달리 FFN 블록을 보완하는 것이 아니라, FFN의 up-projection을 완전히 대체한다. 이러한 설계는 up-projection이 feature lookup을 위한 주소를 생성하는 반면, gate projection이 context-dependent modulation을 제공한다는 FFN의 역할 분석에 기반한다. 실험적으로, gate projection을 context-agnostic embedding으로 대체하는 것은 성능 저하를 가져오므로, up-projection만 대체하는 것이 더 효과적임이 입증되었다.

Benefits (장점)

  1. Better Training Stability (향상된 학습 안정성): MoE 모델과 달리, STEM은 extreme sparsity에도 불구하고 학습 불안정성이나 loss spikes를 보이지 않는다.
  2. Improved Performance with Larger Knowledge Capacity (더 큰 지식 용량으로 인한 성능 향상): STEM은 학습된 embedding 공간에서 더 큰 angular spread (낮은 pairwise cosine similarity)를 보인다. 이는 representational interference를 줄이고 parametric memory의 addressability를 향상시켜 효과적인 정보 저장 용량을 증가시킨다. 이는 ARC-Challenge 및 OpenBookQA와 같은 knowledge-intensive task에서 dense baseline 대비 9~10%의 큰 성능 향상으로 이어진다.
  3. Interpretability Features (해석 가능성 특징): 각 STEM embedding은 특정 token ID에 연결되어 있어, 개별 "micro-experts"가 명확한 token-level semantics를 가진다. 이는 모델의 지식 attribution을 더 투명하게 만들며, token-indexed nature 덕분에 입력 텍스트를 변경하지 않고도 STEM table index를 변경하여 모델의 output distribution을 제어(knowledge editing 및 injection)할 수 있다.
  4. Improved Long-context Inference (향상된 long-context 추론): context length가 증가함에 따라 더 많은 distinct parameter들이 활성화되어 test-time capacity scaling이 가능해진다. 이는 Needle-in-a-Haystack 벤치마크에서 long context (8k/16k/32k)에서 dense baseline 대비 8.4%에서 13%까지의 성능 향상으로 나타난다.
STEM에 의해 활성화되는 parameter의 수는 다음과 같다:
ParamsSTEMact(L)=SdffLuniqParams_{\text{STEM}}^{\text{act}}(L) = |S| d_{ff} L_{\text{uniq}}
여기서 S|S|는 STEM이 적용된 layer 수, dffd_{ff}는 FFN hidden size, LuniqL_{\text{uniq}}는 sequence 내의 unique token ID의 수이다.
  1. Training and Inference-time efficiency (학습 및 추론 시간 효율성): FFN layer의 up-projection 파라미터를 제거함으로써, STEM은 FFN 파라미터의 약 1/3을 절감한다.
    • FLOPs savings: Training 및 prefill 시, STEM의 per-layer FLOPs 감소는 ΔFtrain=BLddff\Delta F_{\text{train}} = \text{BLd} d_{ff} 이며, saving fraction은 dff4d+2L+3dff\frac{d_{ff}}{4d + 2L + 3d_{ff}}이다.
    • Parameter loading cost savings: Decoding 시, per-layer memory access cost 감소는 ΔMdec=Bddff\Delta M_{\text{dec}} = \text{Bd} d_{ff} 이며, saving fraction은 dff4d+2L+3dff\frac{d_{ff}}{4d + 2L + 3d_{ff}}이다.
  2. VRAM and Communication Savings (VRAM 및 통신 절감): STEM embedding table은 CPU 메모리로 offload될 수 있어 GPU VRAM을 절약하고, expert parallelism에서 발생하는 cross-node 통신 오버헤드를 피한다. Token-indexed 특성상 routing logic 없이 prefetch가 가능하며, 자주 사용되는 embedding을 캐싱하여 traffic을 더욱 줄일 수 있다.

Knowledge Editing with STEM (STEM을 이용한 지식 편집)

STEM의 지식 편집 능력은 특정 토큰의 STEM embedding을 수정함으로써, 입력 텍스트를 변경하지 않고도 모델의 출력을 조작할 수 있음을 보여준다. Source entity ("Spain")의 STEM embedding을 target entity ("Germany")의 embedding으로 교체하면, 모델이 "Berlin"과 같은 관련 정보를 생성하도록 유도할 수 있다. Source와 target entity의 tokenization 길이가 다를 경우, source span이 길면 padding (left padding이 선호됨) 또는 copying (target tokens를 반복하여 채움) 전략을 사용한다.