
STEM: Scaling Transformers with Embedding Modules
Key Points
- STEM (Scaling Transformers with Embedding Modules) introduces a static, token-indexed method that replaces the FFN up-projection with a layer-local embedding lookup, aiming to address the instability and overheads of fine-grained sparsity.
- This approach enhances training stability, reduces per-token FLOPs and parameter access by eliminating about one-third of FFN parameters, and enables CPU offload by prefetching embeddings.
- STEM demonstrates significant accuracy improvements on knowledge-intensive tasks, provides unique interpretability for knowledge editing, and exhibits robust capacity scaling for long-context performance.
STEM (Scaling Transformers with Embedding Modules) is a novel approach designed to enhance the parametric capacity and efficiency of Transformer models, particularly by addressing challenges inherent in fine-grained sparsity methods such as Mixture-of-Experts (MoE). It aims to provide higher capacity without a proportional increase in per-token compute, while mitigating the training instability, load-balancing difficulties, and communication overhead typically associated with MoE.
The core methodology of STEM involves a static, token-indexed modification to the Feed-Forward Network (FFN) architecture. Specifically, in a gated FFN (e.g., SwiGLU), the up-projection matrix ($W_{\text{up}}$) is replaced with a layer-local, token-indexed embedding lookup. The gate projection ($W_{\text{gate}}$) and down-projection ($W_{\text{down}}$) weights remain dense and shared across tokens. For a given layer $\ell$, input hidden state $x$, and current token ID $t$, the STEM layer computes its output as:

$$\mathrm{STEM}_\ell(x, t) = W_{\text{down}}\left(\sigma(W_{\text{gate}}\, x) \odot E_\ell[t]\right)$$

where $E_\ell \in \mathbb{R}^{V \times d_{\text{ff}}}$ is the per-layer embedding table (with $V$ being the vocabulary size and $d_{\text{ff}}$ the FFN hidden dimension), $E_\ell[t]$ is the row of $E_\ell$ corresponding to token $t$, $\sigma$ is the gate nonlinearity, and $\odot$ denotes element-wise multiplication.
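A minimal NumPy sketch of this forward pass may help make the lookup concrete; the toy dimensions, random initialization, and SiLU gate nonlinearity below are illustrative assumptions, not the paper's exact configuration:

```python
import numpy as np

def silu(z):
    # SiLU nonlinearity, a common choice for the SwiGLU gate (an assumption here)
    return z / (1.0 + np.exp(-z))

def stem_ffn(x, token_id, W_gate, W_down, E):
    """One STEM FFN layer for a single token.

    x        : (d,)        input hidden state
    token_id : int         current token ID
    W_gate   : (d_ff, d)   dense gate projection, shared across tokens
    W_down   : (d, d_ff)   dense down projection, shared across tokens
    E        : (V, d_ff)   layer-local embedding table replacing the up-projection
    """
    gate = silu(W_gate @ x)       # context-dependent gate
    up = E[token_id]              # static, token-indexed "address" vector
    return W_down @ (gate * up)   # element-wise modulation, then down-projection

# Toy dimensions (illustrative only)
rng = np.random.default_rng(0)
d, d_ff, V = 8, 16, 32
x = rng.standard_normal(d)
W_gate = rng.standard_normal((d_ff, d)) * 0.1
W_down = rng.standard_normal((d, d_ff)) * 0.1
E = rng.standard_normal((V, d_ff)) * 0.1
y = stem_ffn(x, token_id=5, W_gate=W_gate, W_down=W_down, E=E)
```

Note that only the up-projection is replaced: the gate and down projections stay dense, so the layer remains a drop-in substitute for a gated FFN.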
This design choice is motivated by the "key-value memory" view of FFNs, where the up-projection typically generates an address vector ($W_{\text{up}}\, x$) that, after the nonlinearity, modulates retrieval from the down-projection ($W_{\text{down}}$). The gate projection ($W_{\text{gate}}$) provides context-dependent modulation. STEM replaces the context-dependent up-projection with a token-dependent embedding $E_\ell[t]$. This choice is critical; ablation studies indicated that replacing the gate projection harms performance because it needs context-adaptivity, whereas the up-projection, acting as an "address generator," benefits from the fixed, token-specific nature of the STEM embeddings.
Key insights and benefits of STEM include:
- Improved Training Stability: Unlike many MoE models which often suffer from loss spikes and training instability due to non-uniform expert routing, STEM exhibits stable training behavior even with extreme sparsity, as it avoids dynamic routing.
- Better Information Storage Capacity: By replacing the up-projection with token-specific embeddings, STEM's embedding space exhibits a significantly larger angular spread (lower pairwise cosine similarity) compared to the address vectors generated by standard FFNs. This reduced redundancy enables more precise and disentangled knowledge attribution, effectively increasing the model's capacity for storing and retrieving information.
- Knowledge Specificity & Interpretability: Each STEM embedding in every layer is tied to a specific token ID, granting "micro-experts" clear, token-level semantics. This direct knowledge attribution allows for interpretability and controllability. The model's output distribution can be systematically steered by surgically modifying the STEM embeddings for a given token ID, even while the input text remains unchanged. This demonstrates that factual knowledge is localized within these embeddings, making it modular and editable.
- Efficiency:
- FLOPs Reduction: During computation-intensive phases (training and prefill), STEM significantly reduces FLOPs. For batch size $B$, sequence length $S$, model width $d$, and FFN hidden size $d_{\text{ff}}$, the per-layer FLOPs saving is $2BSd\,d_{\text{ff}}$, the cost of the eliminated up-projection matmul. The saving fraction is $1/3$ of the gated FFN's FLOPs, since one of its three projection matrices is removed.
- Parameter Loading Cost Reduction: During memory-intensive decoding, the per-layer memory access cost is reduced by the $d \cdot d_{\text{ff}}$ parameters of the eliminated $W_{\text{up}}$ matrix, with the same one-third saving fraction as FLOPs.
- VRAM and Communication Savings: STEM offloads its large embedding tables to CPU memory, fetching only the required token embeddings asynchronously to the GPU. This eliminates the need for expensive all-to-all communication typical in expert parallelism and frees up GPU VRAM (roughly one-third of FFN parameters).
- Context-length Adaptive Parameter Usage: Since STEM employs token-indexed sparsity, the number of distinct parameters activated in a forward pass scales with the number of unique tokens ($U$) in the context window. For $L$ STEM layers, the active STEM-specific parameters number $L \cdot U \cdot d_{\text{ff}}$. As $U$ typically grows sublinearly with sequence length (Heaps' law), longer contexts engage more parameters without increasing per-token FLOPs, leading to practical test-time capacity scaling and improved long-context performance.
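The FLOPs saving fraction and the context-adaptive active-parameter count can be checked with a short calculation; the dimensions below are toy assumptions, not values from the paper:

```python
# Toy dimensions (illustrative assumptions, not from the paper)
B, S, d, d_ff = 4, 1024, 4096, 11008   # batch, sequence, width, FFN hidden size
L, V = 32, 32000                        # STEM layers, vocabulary size

# A gated (SwiGLU) FFN multiplies by three d x d_ff matrices: gate, up, down.
# STEM removes the up-projection, saving one (B*S, d) @ (d, d_ff) matmul per layer.
flops_saved_per_layer = 2 * B * S * d * d_ff
flops_total_per_layer = 3 * (2 * B * S * d * d_ff)
saving_fraction = flops_saved_per_layer / flops_total_per_layer   # one third

# Active STEM-specific parameters scale with unique tokens U, not sequence length S.
U = 700                                 # unique tokens in context (sublinear in S)
active_stem_params = L * U * d_ff
```

The same one-third fraction applies to decode-time parameter loading, since the eliminated $W_{\text{up}}$ accounts for one of the three FFN matrices.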
Knowledge Editing Mechanism: STEM's token-indexed nature enables a unique knowledge editing capability. By replacing the STEM embeddings associated with a source entity's tokens with those of a target entity, the model's output can be steered. If the source ($n_s$) and target ($n_t$) tokenization lengths differ:
- If $n_t < n_s$: Strategies include left-padding the target token sequence with special tokens to match $n_s$, or copying/repeating target tokens to fill the positions.
- If $n_t > n_s$: Strategies involve selecting a representative subset of target tokens or averaging the embeddings across the entire target span.
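These length-mismatch strategies can be sketched as below; the choice of pad token and the mean-then-repeat handling of the longer case are illustrative assumptions (one of the listed options per case), not the paper's exact recipe:

```python
import numpy as np

def align_target_embeddings(E, src_ids, tgt_ids, pad_id=0):
    """Return n_s embedding rows to write over the source entity's span.

    E       : (V, d_ff) STEM embedding table for one layer
    src_ids : source entity token IDs (length n_s)
    tgt_ids : target entity token IDs (length n_t)
    pad_id  : special token used for left-padding (illustrative choice)
    """
    n_s, n_t = len(src_ids), len(tgt_ids)
    if n_t < n_s:
        # Target shorter: left-pad with a special token to length n_s
        padded = [pad_id] * (n_s - n_t) + list(tgt_ids)
        return E[padded]
    if n_t > n_s:
        # Target longer: average embeddings across the whole target span,
        # then repeat the mean to fill the n_s source positions
        mean = E[list(tgt_ids)].mean(axis=0)
        return np.tile(mean, (n_s, 1))
    return E[list(tgt_ids)]

rng = np.random.default_rng(1)
E = rng.standard_normal((32, 16))
rows_shorter = align_target_embeddings(E, src_ids=[3, 4, 5], tgt_ids=[7, 8])
rows_longer = align_target_embeddings(E, src_ids=[3, 4], tgt_ids=[7, 8, 9])
```

In either case the returned block has exactly $n_s$ rows, so the edit overwrites the source span without touching any other token's embeddings.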
Empirically, STEM demonstrates up to 3-4% accuracy improvements over dense baselines, especially on knowledge and reasoning-heavy benchmarks, while reducing per-token FLOPs and parameter accesses.