GitHub - EverMind-AI/MSA
Key Points
- Memory Sparse Attention (MSA) is an end-to-end trainable, scalable sparse latent-memory framework designed to overcome LLM context length limitations, enabling efficient processing of up to 100M tokens.
- MSA achieves this through innovations like scalable sparse attention with document-wise RoPE for extrapolation, KV cache compression via a Memory Parallel inference engine, and Memory Interleave for multi-round reasoning.
- Evaluated on long-context QA and Needle-in-a-Haystack benchmarks, MSA consistently outperforms RAG and other leading long-context models, exhibiting remarkable stability with less than 9% degradation from 16K to 100M tokens.
The paper introduces Memory Sparse Attention (MSA), a scalable, end-to-end trainable latent-memory framework designed to overcome the context length limitations of large language models (LLMs), enabling effective processing of 100M-token contexts. Existing approaches, such as linear attention, fixed-size state memory, and external storage (RAG/agents), often suffer from precision decay, latency growth, lack of end-to-end differentiability, or complex pipelines. MSA addresses these shortcomings by integrating scalable sparse attention, efficient KV cache management, and a multi-round reasoning mechanism.
Core Ideas and Contributions:
- Memory-Sparse Attention (MSA): An end-to-end trainable, scalable sparse attention layer utilizing document-wise RoPE, achieving near-linear complexity and exhibiting less than 9% accuracy degradation when scaling from 16K to 100M tokens.
- KV Cache Compression + Memory Parallel: A tiered storage system where GPU-resident routing keys facilitate distributed scoring, while content K/V pairs reside in host DRAM, enabling on-demand transfers. This architecture, combined with a Memory Parallel inference engine, delivers 100M-token throughput on 2×A800 GPUs.
- Memory Interleave: An adaptive process that cycles through "generative retrieval → context expansion → generation" to significantly boost multi-hop reasoning across scattered memory segments.
- Comprehensive Evaluation: MSA demonstrates superior stability and accuracy compared to same-backbone RAG, best-of-breed RAG pipelines, and leading long-context models across various long-context QA and Needle-in-a-Haystack (NIAH) benchmarks.
Overall Design Architecture and Methodology:
MSA integrates retrieval and generation into a single differentiable loop. The architecture processes document latent states by performing chunk-mean pooling on their Key (K), Value (V), and router-key representations to achieve compression.
A router projector computes relevance scores between the query's router-key representation and each document's compressed router keys using cosine similarity, mean-pooled over attention heads, followed by a token-wise maximum. This relevance score is then used to select the Top-k most relevant documents. The compressed K/V pairs of these selected Top-k documents are concatenated with the query's local K/V for autoregressive decoding. Crucially, this routing mechanism is applied only to the upper layers of the model, while lower layers maintain independent document processing, ensuring hierarchical alignment and preserving fine-grained information.
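The chunk-mean pooling and routing score described above can be sketched in NumPy. Shapes, the chunk size, and all function names here are illustrative assumptions, not the paper's implementation:

```python
import numpy as np

def chunk_mean_pool(reps: np.ndarray, chunk_size: int) -> np.ndarray:
    """Compress per-token states [T, H, D] into per-chunk means [C, H, D]."""
    n_chunks = reps.shape[0] // chunk_size
    return reps[: n_chunks * chunk_size].reshape(
        n_chunks, chunk_size, *reps.shape[1:]
    ).mean(axis=1)

def route_topk(query_rk: np.ndarray, doc_rks: list, k: int) -> list:
    """Score each document against the query's router keys; return Top-k indices.

    query_rk: [Tq, H, D] router-key states of the query tokens.
    doc_rks:  per-document compressed router keys, each [C, H, D].
    Relevance = cosine similarity, mean-pooled over heads, then a token-wise max.
    """
    def normalize(x):
        return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-8)

    q = normalize(query_rk)                    # [Tq, H, D]
    scores = []
    for rk in doc_rks:
        d = normalize(rk)                      # [C, H, D]
        # cosine similarity for every (query token, chunk, head) triple
        sim = np.einsum("qhd,chd->qch", q, d)  # [Tq, C, H]
        head_mean = sim.mean(axis=-1)          # mean over heads -> [Tq, C]
        scores.append(head_mean.max())         # token-wise max -> scalar score
    return list(np.argsort(scores)[::-1][:k])
```

A document whose compressed router keys align with any query token scores highly, so a single matching chunk is enough to pull the whole document into the sparse context.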
Positional Encoding with RoPE:
MSA employs a hybrid RoPE (Rotary Positional Embedding) strategy:
- Parallel (document-wise) RoPE: Each retrieved document's positional encoding is reset from 0. This prevents positional drift between models trained on shorter contexts (e.g., 64K tokens) and those inferred on extremely long contexts (e.g., 100M tokens), enabling strong extrapolation.
- Global RoPE (active context): For the query and the concatenated retrieved blocks, the query's starting index is offset past the retrieved blocks, preserving causal ordering: background information → query → generation.
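A minimal sketch of this hybrid position-index assignment, assuming the query's global offset equals the total length of the retrieved blocks (the exact offset convention is an assumption):

```python
def msa_position_ids(doc_lens: list, query_len: int):
    """Hybrid RoPE indices for [doc_1 ... doc_k, query] (illustrative sketch).

    Parallel (document-wise) RoPE: every retrieved document restarts at 0,
    so absolute positions never exceed the range seen in training.
    Global RoPE: the query starts after the retrieved blocks, preserving
    the causal order background information -> query -> generation.
    """
    doc_positions = [list(range(n)) for n in doc_lens]
    query_start = sum(doc_lens)  # assumed offset: total retrieved length
    query_positions = list(range(query_start, query_start + query_len))
    return doc_positions, query_positions
```

Because every document restarts at 0, a model trained on 64K-token contexts never sees a rotary angle outside its training range, which is what enables extrapolation to 100M tokens.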
Inference Pipeline:
The MSA inference pipeline consists of three stages:
- Global Memory Encoding (Offline): The entire corpus is pre-processed offline by forwarding it through the model to cache chunk-pooled K, V, and router-key representations. These compressed representations form the global memory bank.
- Online Routing & Context Assembly: Given a query, it is first projected to its router-key representation, which is then matched against the pre-computed router keys in the global memory bank to identify and select the Top-k most relevant document chunks. Only the selected documents' K/V content is loaded and concatenated with the local context (the query and the tokens generated so far). The Memory Parallel mechanism aids this process by sharding the routing keys across GPUs, broadcasting the query, performing local scoring, and then running a global reduction to determine the Top-k. Content K/V remains in host DRAM and is asynchronously fetched only when selected, balancing VRAM usage and throughput for 100M-token deployment.
- Sparse Generation: Autoregressive decoding is performed over the assembled sparse context, which now includes the query, previous generated tokens, and the retrieved Top-k document chunks.
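The three stages above can be condensed into a toy tiered-memory sketch. The in-process containers stand in for GPU VRAM and host DRAM, and the class and method names are illustrative assumptions:

```python
import numpy as np

class TieredMemoryBank:
    """Tiered storage sketch: compact routing keys stay 'GPU-resident'
    (here: an always-loaded list), while full content K/V lives in a
    slower 'host DRAM' store and is fetched only for the Top-k chunks."""

    def __init__(self):
        self.routing_keys = []  # small, always resident: one key per chunk
        self.content_kv = {}    # large, fetched on demand: chunk_id -> (K, V)

    def encode(self, chunk_id, k, v, router_key):
        # Stage 1 (offline): cache chunk-pooled representations.
        self.routing_keys.append((chunk_id, router_key))
        self.content_kv[chunk_id] = (k, v)

    def assemble_context(self, query_rk, top_k):
        # Stage 2 (online): score the resident routing keys against the
        # query, then pull only the Top-k chunks' content K/V.
        def cos(a, b):
            return float(np.dot(a, b) /
                         (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))
        scored = [(cid, cos(query_rk, rk)) for cid, rk in self.routing_keys]
        selected = sorted(scored, key=lambda s: -s[1])[:top_k]
        # Stage 3 then decodes over this assembled sparse context.
        return [(cid, *self.content_kv[cid]) for cid, _ in selected]
```

In the real system the scoring step is sharded across GPUs with a global reduction, and the content fetch from DRAM is asynchronous; this sketch only shows the data flow.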
Memory Interleave:
To facilitate complex multi-hop reasoning, MSA utilizes Memory Interleave. This involves an adaptive, alternating pipeline: "generative retrieval → context expansion → generation". This iterative process allows the model to dynamically retrieve additional relevant information based on intermediate generation results, expanding the context and enhancing its reasoning capabilities across multiple, scattered memory segments.
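The alternation can be sketched as a simple loop. The callables (`retrieve`, `generate`, `needs_more`) and the round cap are caller-supplied assumptions, not the paper's API:

```python
def memory_interleave(query, retrieve, generate, needs_more, max_rounds=4):
    """Adaptive generation-retrieval loop (Memory Interleave sketch).

    retrieve(q)      -> list of memory chunks relevant to q.
    generate(ctx)    -> draft answer from the current sparse context.
    needs_more(draft)-> follow-up query derived from the draft, or None
                        when the reasoning chain is complete.
    """
    context = list(retrieve(query))
    draft = generate(context)
    for _ in range(max_rounds):
        follow_up = needs_more(draft)
        if follow_up is None:           # reasoning chain is complete
            break
        context += retrieve(follow_up)  # context expansion
        draft = generate(context)       # regenerate with expanded memory
    return draft
```

Each round lets an intermediate draft surface the next hop's query, so facts scattered across distant memory segments can be pulled in one hop at a time.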
Training Details:
MSA is trained with 158.95 billion tokens of continuous pretraining, incorporating an auxiliary routing loss to optimize the document selection process. This is followed by a two-stage Supervised Fine-Tuning (SFT) curriculum, scaling from 8K to 64K tokens, which helps the model learn to leverage longer contexts effectively. Ablation studies confirmed that curriculum extension, Memory Interleave, continuous pretraining, and injecting original text all contribute substantially to MSA's performance, with their removal causing 5%–37% drops in accuracy depending on the task.
Results:
MSA consistently outperforms strong baselines:
- Long-Context QA: On 9 QA datasets (memory banks ranging from 277K to 10M tokens) using a Qwen3-4B-Instruct-2507 backbone, MSA achieved an average LLM judge score of 3.760, representing a +16.0% improvement over standard RAG, +11.5% over RAG+rerank, and +14.8% over HippoRAG2. MSA led on 8 out of 9 datasets within the same-backbone group.
- Best-of-Breed RAG Stacks Comparison: When compared against SOTA RAG stacks using larger backbones (e.g., KaLMv2 + Qwen3-235B, KaLMv2 + Llama-3.3-70B, with and without reranking), MSA achieved the highest score on 4 of 9 datasets and maintained an average score of 3.760, yielding relative gains of +7.2%, +5.0%, +10.7%, and +5.4% over the strongest configurations. The paper notes that performance gaps on some datasets (e.g., MuSiQue) are largely attributable to the differing parameter counts and intrinsic reasoning capacities of the backbone models.
- NIAH (RULER) Stability: MSA demonstrates remarkable stability in accuracy across extreme context lengths. On RULER, MSA maintained 94.84% accuracy at 1M tokens, while the unmodified backbone collapsed beyond 128K (down to 24.69% at 1M). Hybrid linear-attention models also degraded noticeably at 128K/256K, and external-memory agents, while stable, showed weaker absolute accuracy and steeper decay than MSA. Across an unprecedented 16K to 100M token range, MSA showed less than 9% degradation, suggesting a practical path to decouple memory capacity from reasoning.
MSA's design effectively decouples memory capacity from reasoning, offering a robust and practical solution for deploying LLMs in extremely long-context scenarios.