GitHub - deepseek-ai/FlashMLA: FlashMLA: Efficient Multi-head Latent Attention Kernels

deepseek-ai
2025.03.08
· GitHub · by Anonymous
#Attention Kernels #LLM #Sparse Attention #Deep Learning #GPU Computing

Key Points

  • FlashMLA is DeepSeek's optimized library of attention kernels, powering the DeepSeek-V3 and DeepSeek-V3.2-Exp models with highly efficient dense and sparse attention implementations.
  • It delivers substantial performance gains in both the prefill and decoding stages, achieving up to 660 TFLOPS on NVIDIA H800 GPUs and featuring token-level sparse attention with FP8 KV cache support.
  • The library provides kernels for the SM90/SM100 architectures, including MQA and MHA modes, and is supported on various other GPU platforms through community collaborations.

FlashMLA is DeepSeek's optimized library of attention kernels, designed to power large language models such as DeepSeek-V3 and DeepSeek-V3.2-Exp. It provides implementations of both sparse and dense attention, covering the prefill and decoding stages, and is optimized for NVIDIA GPU architectures (SM90/SM100).

The core methodology of FlashMLA revolves around highly optimized CUDA kernels that aim to maximize throughput (TFLOPS) for compute-bound workloads and memory bandwidth (GB/s) for memory-bound configurations.
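To make the compute-bound vs. memory-bound distinction concrete, a roofline-style back-of-envelope check classifies a workload by comparing its arithmetic intensity with the machine balance. The peak figures below are illustrative assumptions for an H800-class GPU, not numbers taken from FlashMLA:

```python
# Roofline-style sketch: a kernel is compute-bound when its arithmetic
# intensity (FLOPs per byte of memory traffic) exceeds the machine balance.
# PEAK_* values are illustrative assumptions, not FlashMLA measurements.
PEAK_TFLOPS = 990.0   # assumed BF16 tensor-core peak, TFLOPS
PEAK_GBPS = 3350.0    # assumed HBM bandwidth, GB/s

def bound_kind(flops: float, bytes_moved: float) -> str:
    """Classify a workload as compute-bound or memory-bound."""
    intensity = flops / bytes_moved                      # FLOPs per byte
    balance = (PEAK_TFLOPS * 1e12) / (PEAK_GBPS * 1e9)   # ~295 FLOPs/byte
    return "compute-bound" if intensity > balance else "memory-bound"

# Long-sequence prefill reuses each KV byte across many queries, giving high
# intensity; single-token decoding streams the whole KV cache every step.
assert bound_kind(flops=1e12, bytes_moved=1e9) == "compute-bound"
assert bound_kind(flops=1e9, bytes_moved=1e9) == "memory-bound"
```

This split is why the figures below quote TFLOPS for compute-bound configurations and GB/s for memory-bound decoding.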

1. Sparse Attention Kernels:
These kernels implement DeepSeek Sparse Attention (DSA), which is token-level sparse.

  • Decoding Stage: Utilizes flash_mla_with_kvcache.
    • FP8 KV Cache: For decoding, the kernel supports an FP8 KV cache, where matrix multiplication is performed in bfloat16 after dequantization. The FP8 KV cache format per token (656 Bytes) consists of:
      • First 512 bytes: "quantized NoPE" (512 float8_e4m3 values).
      • Next 16 bytes: Scale factors (4 float32 values, each scaling 128 float8_e4m3 values).
      • Last 128 bytes: "RoPE" (64 bfloat16 values), not quantized for accuracy.
    • Sparsity (indices): Token-level sparsity is enabled by an indices tensor of shape (batch_size, seq_len_q, topk). indices[i][j][k] specifies the physical memory location (block index * page_block_size + offset within block) of the k-th relevant token for the j-th query in the i-th batch. Invalid entries are marked with -1.
    • Input: q_i (query), kvcache_i (KV cache), block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, is_causal, is_fp8_kvcache, indices.
    • Output: (out, lse), where out is the attention result and lse is the log-sum-exp of attention scores for each query head.
    • Performance: Achieves 410 TFLOPS (compute-bound) on H800 SXM5 for FP8 sparse decoding.
  • Prefill Stage: Uses flash_mla_sparse_fwd.
    • Inputs: q (query, shape [S_q, H_q, D_{qk}]), kv (key-value, shape [S_{kv}, H_{kv}, D_{qk}]), indices (shape [S_q, H_{kv}, topk]), sm_scale. Note: this kernel does not support a batch dimension directly; multi-batch inference requires reshaping inputs and indices. Invalid indices are -1 or ≥ S_{kv}.
    • Output: (out, max_logits, lse).
    • Equivalent PyTorch Operations: Given query Q \in \mathbb{R}^{S_q \times H_q \times D_{qk}}, key-value KV \in \mathbb{R}^{S_{kv} \times H_{kv} \times D_{qk}}, and indices I \in \mathbb{Z}^{S_q \times H_{kv} \times \text{topk}} (assuming H_{kv} = 1 for simplicity):
      1. Select focused KV tokens KV_{focused} \in \mathbb{R}^{S_q \times \text{topk} \times D_{qk}}, where KV_{focused}[i, k, :] = KV[I[i, k], :].
      2. Compute logits P \in \mathbb{R}^{S_q \times H_q \times \text{topk}}:
         P = (Q @ KV_{focused}^T) \cdot \text{sm\_scale} \cdot \log_2(e)
      3. Compute maximum logits M \in \mathbb{R}^{S_q \times H_q}:
         M = \max_{\text{dim}=-1}(P)
      4. Compute the log-sum-exp L \in \mathbb{R}^{S_q \times H_q} (base 2):
         L_{ij} = \log_2 \left( \sum_{k=1}^{\text{topk}} 2^{P_{ijk}} \right)
      5. Compute sparse attention scores S \in \mathbb{R}^{S_q \times H_q \times \text{topk}}:
         S = 2^{(P - L)}
      6. Compute the output O \in \mathbb{R}^{S_q \times H_q \times D_{qk}}:
         O = S @ KV_{focused}
      The kernel returns (O, M, L).
  • Performance: Achieves up to 640 TFLOPS on H800 SXM5 and 1450 TFLOPS on B200.
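The six equivalent operations above can be sketched directly in NumPy. `sparse_prefill_ref` is a hypothetical helper name for illustration, with H_kv = 1 as assumed above:

```python
import numpy as np

def sparse_prefill_ref(q, kv, indices, sm_scale):
    """NumPy sketch of the equivalent operations above (H_kv = 1).

    q:       (S_q, H_q, D_qk)   queries
    kv:      (S_kv, 1, D_qk)    key-value tokens
    indices: (S_q, 1, topk)     token indices; -1 or >= S_kv marks invalid
    Returns (O, M, L), with logits and lse in base 2 as the kernel does.
    """
    s_kv = kv.shape[0]
    idx = indices[:, 0, :]                                  # (S_q, topk)
    invalid = (idx < 0) | (idx >= s_kv)
    # 1. gather focused KV tokens (clip keeps the gather in bounds;
    #    invalid positions are masked out of the logits below)
    kv_focused = kv[np.clip(idx, 0, s_kv - 1), 0, :]        # (S_q, topk, D_qk)
    # 2. base-2 logits: P = (Q @ KV_focused^T) * sm_scale * log2(e)
    p = np.einsum('qhd,qkd->qhk', q, kv_focused) * sm_scale / np.log(2.0)
    p = np.where(invalid[:, None, :], -np.inf, p)
    # 3. max logits and 4. base-2 log-sum-exp (computed stably around M)
    m = p.max(axis=-1)                                      # (S_q, H_q)
    l = m + np.log2(np.exp2(p - m[..., None]).sum(axis=-1))
    # 5. normalized sparse attention scores and 6. output
    s = np.exp2(p - l[..., None])
    o = np.einsum('qhk,qkd->qhd', s, kv_focused)
    return o, m, l
```

Since S = 2^(P - L) equals the softmax of the natural-base logits, O matches standard softmax attention restricted to the selected tokens; working in base 2 is a common kernel trick because exp2 maps to a cheap hardware instruction.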
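The 656-byte-per-token FP8 KV cache layout described in the decoding section can be unpacked as follows. This is a NumPy sketch with a hand-rolled float8_e4m3 decoder and a hypothetical helper name (`unpack_fp8_kv_token`); the real kernel of course uses native GPU FP8 support and dequantizes to bfloat16 on chip:

```python
import numpy as np

def decode_e4m3(b: int) -> float:
    """Decode one float8_e4m3fn byte (1 sign, 4 exponent, 3 mantissa bits)."""
    s = -1.0 if b & 0x80 else 1.0
    e, m = (b >> 3) & 0xF, b & 0x7
    if e == 0:                                  # subnormal
        return s * (m / 8.0) * 2.0 ** -6
    if e == 0xF and m == 0x7:                   # e4m3fn: NaN, no infinities
        return float('nan')
    return s * (1.0 + m / 8.0) * 2.0 ** (e - 7)

def unpack_fp8_kv_token(raw: bytes):
    """Split one 656-byte FP8 KV cache token into dequantized NoPE + RoPE."""
    assert len(raw) == 656
    buf = np.frombuffer(raw, dtype=np.uint8)
    # first 512 bytes: quantized NoPE (512 float8_e4m3 values)
    nope = np.array([decode_e4m3(int(b)) for b in buf[:512]], dtype=np.float32)
    # next 16 bytes: 4 float32 scale factors, one per 128 NoPE values
    scales = np.frombuffer(buf[512:528].tobytes(), dtype=np.float32)
    nope = (nope.reshape(4, 128) * scales[:, None]).reshape(512)
    # last 128 bytes: 64 bfloat16 RoPE values (kept unquantized for accuracy);
    # widen bf16 -> float32 by placing the 16 bits in the high half
    rope_u16 = np.frombuffer(buf[528:].tobytes(), dtype=np.uint16)
    rope = (rope_u16.astype(np.uint32) << 16).view(np.float32)
    return nope, rope
```

Note how the split mirrors the layout above: 512 quantized dims, four per-128-element scales, and a 64-element RoPE tail left in bfloat16.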

2. Dense Attention Kernels:
These implement standard Multi-Head Attention (MHA).

  • Prefill Stage: Leverages functions like flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func, and flash_attn_varlen_kvpacked_func, similar in usage to the flash_attn library.
    • Performance: Achieves up to 1460 TFLOPS (forward) and 1000 TFLOPS (backward) on B200 (reported by NVIDIA).
  • Decoding Stage: Achieves up to 3000 GB/s (memory-bound) and 660 TFLOPS (compute-bound) on H800 SXM5.
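As a rough illustration of why decoding is memory-bound: each decode step must stream the KV cache of every past token, so latency has a floor of KV bytes divided by bandwidth. The byte counts below follow from the layouts in this document, and the bandwidth defaults to the 3000 GB/s figure quoted above; real kernels also read weights and write outputs, so this is only a lower bound:

```python
# Lower-bound estimate of per-step decode latency from KV-cache traffic alone.
BF16_TOKEN_BYTES = 576 * 2   # head_dim_k = 576 at 2 bytes per element
FP8_TOKEN_BYTES = 656        # FP8 layout: 512 e4m3 + 16 scale + 128 bf16 bytes

def decode_step_floor_ms(batch: int, seq_len: int,
                         bytes_per_token: int, gbps: float = 3000.0) -> float:
    """Minimum milliseconds per decode step, assuming the full KV cache of
    every sequence in the batch is read once per generated token."""
    total_bytes = batch * seq_len * bytes_per_token
    return total_bytes / (gbps * 1e9) * 1e3

# e.g. batch 64 at 32k context: the FP8 cache moves 656/1152 ≈ 57% of the
# bytes of a BF16 cache, so its latency floor shrinks proportionally.
bf16 = decode_step_floor_ms(64, 32768, BF16_TOKEN_BYTES)
fp8 = decode_step_floor_ms(64, 32768, FP8_TOKEN_BYTES)
```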

Performance Summary & Requirements:
FlashMLA achieves significant performance gains, including a 5%-15% improvement on compute-bound workloads, reaching up to 660 TFLOPS for dense MLA decoding on NVIDIA H800 SXM5 GPUs.

  • Hardware: SM90 (Hopper) and SM100 (Blackwell) architectures.
  • Software: CUDA 12.8+ (12.9+ for SM100) and PyTorch 2.0+.
  • MLA Mode: "MLA Mode" refers to the attention configuration.
    • MQA (Multi-Query Attention): Typically uses head_dim_k = 576 with head_dim_v = 512. Supported for Dense Decoding (SM90, BF16 KV cache), Sparse Decoding (SM90 & SM100, FP8 KV cache), and Sparse Prefill (SM90 & SM100).
    • MHA (Multi-Head Attention): Uses head_dim_k = 192/128 with head_dim_v = 128. Supported for Dense Prefill (SM100).
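The MQA head dimensions are consistent with the FP8 cache layout described earlier: head_dim_k = 576 splits into 512 NoPE dims plus 64 RoPE dims, while the value side uses only the 512-dim part (head_dim_v = 512). A small shape sanity check with illustrative sizes and an un-paged cache layout (the real kernel pages the cache through a block_table):

```python
import numpy as np

# MQA-style MLA shapes: head_dim_k = 576 = 512 NoPE + 64 RoPE dims,
# matching the 512 float8_e4m3 + 64 bfloat16 values per FP8 cache token.
HEAD_DIM_K, HEAD_DIM_V, ROPE_DIMS = 576, 512, 64
assert HEAD_DIM_K == HEAD_DIM_V + ROPE_DIMS

batch, num_q_heads, seq_kv = 2, 128, 1024       # illustrative sizes
q = np.zeros((batch, 1, num_q_heads, HEAD_DIM_K), np.float32)   # 1 decode token
kvcache = np.zeros((batch, seq_kv, 1, HEAD_DIM_K), np.float32)  # shared KV head
out = np.zeros((batch, 1, num_q_heads, HEAD_DIM_V), np.float32) # 512-dim values
```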

FlashMLA draws inspiration from the FlashAttention 2 & 3 and CUTLASS projects, providing a robust and performant attention kernel library.