deepseek-ai/FlashMLA: Efficient Multi-head Latent Attention Kernels
Key Points
- FlashMLA is DeepSeek's optimized library of attention kernels, powering the DeepSeek-V3 and DeepSeek-V3.2-Exp models with highly efficient dense and sparse attention implementations.
- It delivers substantial performance gains for the prefill and decoding stages, achieving up to 660 TFLOPS on NVIDIA H800 GPUs and featuring token-level sparse attention with FP8 KV cache support.
- The library provides kernels for the SM90/SM100 architectures, including MQA and MHA modes, and is supported on various other GPU platforms through community collaborations.
FlashMLA is DeepSeek's optimized library of attention kernels, designed to power large language models such as DeepSeek-V3 and DeepSeek-V3.2-Exp. It provides implementations of both sparse and dense attention, covering the prefill and decoding stages, and is optimized for NVIDIA GPU architectures (SM90/SM100).
The core methodology of FlashMLA revolves around highly optimized CUDA kernels that aim to maximize throughput (TFLOPS) for compute-bound workloads and memory bandwidth (GB/s) for memory-bound configurations.
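To illustrate the compute-bound vs. memory-bound distinction, the sketch below estimates a kernel's arithmetic intensity (FLOPs per byte of memory traffic) against the GPU's balance point. The H800 peak figures and the decoding workload sizes are illustrative assumptions, not numbers from this repository.

```python
# Rough roofline sketch: a kernel is compute-bound when its arithmetic
# intensity exceeds the GPU's balance point, and memory-bound otherwise.
# Peak figures are approximate public H800 SXM5 specs (assumption).
PEAK_TFLOPS_BF16 = 990          # ~ peak BF16 tensor-core throughput
PEAK_BW_TBPS = 3.35             # ~ peak HBM bandwidth

def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte of memory traffic."""
    return flops / bytes_moved

# Machine balance point, in FLOP per byte.
balance = (PEAK_TFLOPS_BF16 * 1e12) / (PEAK_BW_TBPS * 1e12)

# Hypothetical decoding step: one query token attending to 8192 cached
# tokens of head_dim 576 in bf16 (2 bytes/element). QK^T plus PV is
# roughly 4 * s_kv * d FLOPs, while the whole KV cache must be streamed.
s_kv, d = 8192, 576
flops = 4 * s_kv * d
bytes_moved = s_kv * d * 2
ai = arithmetic_intensity(flops, bytes_moved)
print(f"balance={balance:.0f} FLOP/B, decode intensity={ai:.0f} FLOP/B")
```

Single-token decoding lands far below the balance point, which is why decoding throughput is quoted in GB/s while prefill throughput is quoted in TFLOPS.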
1. Sparse Attention Kernels:
These kernels implement DeepSeek Sparse Attention (DSA), which is token-level sparse.
- Decoding Stage: Uses `flash_mla_with_kvcache`.
  - FP8 KV Cache: For decoding, the kernel supports an FP8 KV cache; matrix multiplication is performed in bfloat16 after dequantization. The FP8 KV cache format per token (656 bytes) consists of:
    - First 512 bytes: "quantized NoPE" (512 `float8_e4m3` values).
    - Next 16 bytes: scale factors (4 `float32` values, each scaling 128 `float8_e4m3` values).
    - Last 128 bytes: "RoPE" (64 `bfloat16` values), left unquantized for accuracy.
  - Sparsity (`indices`): Token-level sparsity is enabled by an `indices` tensor of shape `(batch_size, seq_len_q, topk)`. `indices[i][j][k]` specifies the physical memory location (block index * page_block_size + offset within block) of the k-th relevant token for the j-th query in the i-th batch. Invalid entries are marked with -1.
  - Input: `q_i` (query), `kvcache_i` (KV cache), `block_table`, `cache_seqlens`, `dv`, `tile_scheduler_metadata`, `num_splits`, `is_causal`, `is_fp8_kvcache`, `indices`.
  - Output: `(out, lse)`, where `out` is the attention result and `lse` is the log-sum-exp of the attention scores for each query head.
  - Performance: Achieves 410 TFLOPS (compute-bound) on H800 SXM5 for FP8 sparse decoding.
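The per-token cache layout above can be checked with a small NumPy sketch. Since NumPy has no `float8_e4m3` dtype, the quantized values are stood in for by plain floats; the point is the byte accounting and the one-scale-per-128-values dequantization scheme, not the fp8 encoding itself.

```python
import numpy as np

# Illustrative model of the 656-byte-per-token FP8 KV-cache layout
# described above: 512 quantized NoPE values (1 byte each), 4 float32
# scales (4 bytes each, one per 128-value block), and 64 unquantized
# bfloat16 RoPE values (2 bytes each).
NOPE_VALS, SCALE_COUNT, BLOCK, ROPE_VALS = 512, 4, 128, 64
assert NOPE_VALS * 1 + SCALE_COUNT * 4 + ROPE_VALS * 2 == 656

def dequantize_nope(nope_q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Apply one scale factor per contiguous 128-value block."""
    return (nope_q.reshape(SCALE_COUNT, BLOCK) * scales[:, None]).reshape(NOPE_VALS)

# Toy example: quantized values of 1.0 with per-block scales 1..4
nope_q = np.ones(NOPE_VALS, dtype=np.float32)
scales = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
deq = dequantize_nope(nope_q, scales)
print(deq[0], deq[129], deq[300], deq[500])   # 1.0 2.0 3.0 4.0
```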
- Prefill Stage: Uses `flash_mla_sparse_fwd`.
  - Inputs: `q` (query), `kv` (key-value), `indices`, `sm_scale`. Note: this kernel does not support a batch dimension directly; multi-batch inference requires reshaping the inputs and `indices`. Invalid entries in `indices` are marked with -1.
  - Output: `(out, max_logits, lse)`.
  - Equivalent PyTorch operations (schematic), given query `q`, key-value `kv`, and `indices`, assuming a single batch for simplicity:
    - Select the focused KV tokens: `kv_sel = kv[indices]`, with invalid (-1) entries masked out.
    - Compute the logits: `p = sm_scale * (q @ kv_sel.transpose(-1, -2))`.
    - Compute the maximum logits: `max_logits = p.max(dim=-1).values`.
    - Compute the log-sum-exp in base 2: `lse = torch.logsumexp(p, dim=-1) / math.log(2)`.
    - Compute the sparse attention scores: `s = torch.softmax(p, dim=-1)`.
    - Compute the output: `out = s @ v_sel`, where `v_sel` is the value part of `kv_sel` (its first `dv` dimensions).

    The kernel returns `(out, max_logits, lse)`.
- Performance: Achieves up to 640 TFLOPS on H800 SXM5 and 1450 TFLOPS on B200.
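The steps above can be collected into a small NumPy reference. The shapes, the single implicit head, and the "values are the first `dv` dimensions of `kv`" convention are illustrative assumptions for this sketch, not the kernel's exact interface.

```python
import numpy as np

# NumPy reference for sparse prefill attention: gather top-k KV tokens
# per query via `indices`, run softmax attention over only those tokens,
# and also return the per-query max logits and a base-2 log-sum-exp.
def sparse_attention_ref(q, kv, indices, sm_scale, dv):
    # q: (s_q, d), kv: (s_kv, d), indices: (s_q, topk), -1 = invalid
    valid = indices >= 0
    safe_idx = np.where(valid, indices, 0)
    kv_sel = kv[safe_idx]                        # (s_q, topk, d)
    logits = sm_scale * np.einsum("qd,qkd->qk", q, kv_sel)
    logits = np.where(valid, logits, -np.inf)    # mask invalid tokens
    max_logits = logits.max(axis=-1)
    # Base-2 log-sum-exp, computed stably around the row maximum.
    lse = np.log2(np.exp(logits - max_logits[:, None]).sum(-1)) \
          + max_logits / np.log(2)
    scores = np.exp(logits - max_logits[:, None])
    scores /= scores.sum(-1, keepdims=True)
    # Use the first dv dims of the gathered tokens as the value part.
    out = np.einsum("qk,qkd->qd", scores, kv_sel[..., :dv])
    return out, max_logits, lse
```

Comparing such a reference against the kernel's `(out, max_logits, lse)` on random inputs is a common way to validate a fused attention implementation.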
2. Dense Attention Kernels:
These implement standard Multi-Head Attention (MHA).
- Prefill Stage: Leverages `flash_attn_varlen_func`, `flash_attn_varlen_qkvpacked_func`, and `flash_attn_varlen_kvpacked_func`, similar in usage to the `flash_attn` library.
  - Performance: Achieves up to 1460 TFLOPS (forward) and 1000 TFLOPS (backward) on B200 (reported by NVIDIA).
- Decoding Stage: Achieves up to 3000 GB/s (memory-bound) and 660 TFLOPS (compute-bound) on H800 SXM5.
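The "varlen" functions follow the flash_attn convention of packing variable-length sequences into one tensor delimited by a cumulative-lengths array. The reference below sketches that calling convention with plain softmax attention; the function name and shapes are illustrative, not the library's API.

```python
import numpy as np

# Sketch of the varlen packing convention: sequences of different
# lengths are concatenated along the token axis, and cu_seqlens holds
# the prefix sums that delimit each sequence.
def attention_varlen_ref(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, d); cu_seqlens: (batch + 1,) prefix sums
    out = np.empty_like(v)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        qi, ki, vi = q[start:end], k[start:end], v[start:end]
        logits = qi @ ki.T / np.sqrt(q.shape[-1])
        logits -= logits.max(-1, keepdims=True)   # numerical stability
        p = np.exp(logits)
        p /= p.sum(-1, keepdims=True)
        out[start:end] = p @ vi
    return out

# Two packed sequences, of lengths 3 and 2 (5 tokens total).
cu_seqlens = np.array([0, 3, 5])
```

Packing avoids padding shorter sequences to the batch maximum, which is what makes the varlen entry points efficient for ragged batches.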
Performance Summary & Requirements:
FlashMLA achieves significant performance gains, including a 5%-15% improvement for compute-bound workloads, reaching up to 660 TFLOPS on NVIDIA H800 SXM5 GPUs for dense MLA decoding.
- Hardware: SM90 (Hopper) and SM100 (Blackwell) architectures.
- Software: CUDA 12.8+ (12.9+ for SM100) and PyTorch 2.0+.
- MLA Mode: "MLA mode" refers to the attention configuration.
- MQA (Multi-Query Attention): Uses a single KV head shared by all query heads (h_kv = 1); with the MLA KV-cache layout described above, this corresponds to a Q/K head dim of 576 (512 NoPE + 64 RoPE) and a V head dim of 512. Supported for Dense Decoding (SM90, BF16 KV cache), Sparse Decoding (SM90 & SM100, FP8 KV cache), and Sparse Prefill (SM90 & SM100).
- MHA (Multi-Head Attention): Uses one KV head per query head (h_kv = h_q). Supported for Dense Prefill (SM100).
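The MQA vs. MHA distinction above amounts to how the KV head axis relates to the query head axis. A minimal sketch (toy sizes, single token block) shows the broadcast and the resulting KV-cache savings:

```python
import numpy as np

# MQA vs. MHA head layout: in MQA a single KV head (h_kv = 1) is
# broadcast across all h_q query heads, shrinking the KV cache by a
# factor of h_q; in MHA each query head has its own KV head.
h_q, s, d = 4, 6, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((h_q, s, d))

kv_mha = rng.standard_normal((h_q, s, d))   # MHA: one KV head per query head
kv_mqa = rng.standard_normal((1, s, d))     # MQA: one shared KV head

# Per-head attention logits; the shared MQA head is broadcast to h_q.
logits_mha = np.einsum("hqd,hkd->hqk", q, kv_mha)
logits_mqa = np.einsum("hqd,hkd->hqk", q, np.broadcast_to(kv_mqa, (h_q, s, d)))
print(logits_mha.shape, logits_mqa.shape)   # both (4, 6, 6)
```

The broadcast costs nothing in memory, which is why MQA-style layouts are attractive for decoding, where the KV cache dominates memory traffic.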
FlashMLA is inspired by the FlashAttention 2 & 3 and CUTLASS projects, providing a robust and performant attention kernel library.