GitHub - deepseek-ai/FlashMLA: FlashMLA: Efficient Multi-head Latent Attention Kernels
핵심 포인트
- 1FlashMLA는 DeepSeek-V3 및 DeepSeek-V3.2-Exp 모델을 구동하는 DeepSeek의 최적화된 Attention Kernels 라이브러리로, Sparse Attention과 Dense Attention Kernels을 포함합니다.
- 2이 라이브러리는 FP8 KV cache를 사용하는 sparse decoding에서 최대 410 TFlops, dense decoding에서 최대 660 TFlops 등 높은 성능을 달성하며 NVIDIA SM90 및 SM100 GPU 아키텍처를 지원합니다.
- 3FlashMLA는 MLA Decoding과 Sparse/Dense MLA Prefill을 위한 Kernels 사용법을 제공하며, FlashAttention 및 cutlass 프로젝트에서 영감을 받아 다양한 하드웨어 플랫폼에 적용 가능합니다.
FlashMLA는 DeepSeek-AI에서 개발한 최적화된 어텐션 커널 라이브러리로, DeepSeek-V3 및 DeepSeek-V3.2-Exp 모델의 핵심 기술입니다. 이 라이브러리는 고성능의 희소(Sparse) 및 밀집(Dense) 어텐션 커널을 제공하여, 대규모 언어 모델의 추론(prefill 및 decoding) 단계에서 컴퓨팅 효율성을 극대화합니다.
1. 핵심 기능 및 구성:
- Sparse Attention Kernels (DeepSeek Sparse Attention - DSA):
- Prefill 단계의 토큰 단위 희소 어텐션: 입력 시퀀스 전체에 대한 어텐션 계산 대신,
indices텐서로 지정된topk개의 토큰에 대해서만 어텐션을 수행합니다. 이를 통해 불필요한 계산을 줄여 효율성을 높입니다. - Decoding 단계의 토큰 단위 희소 어텐션 (FP8 KV 캐시 사용): 디코딩 과정에서도 특정 토큰에만 집중하여 어텐션을 수행하며, Key-Value(KV) 캐시를 FP8 형식으로 저장하여 메모리 사용량을 줄입니다.
- Prefill 단계의 토큰 단위 희소 어텐션: 입력 시퀀스 전체에 대한 어텐션 계산 대신,
- Dense Attention Kernels:
- Prefill 단계의 밀집 어텐션: 일반적인 Multi-Head Attention (MHA)의 전방향(forward) 및 역방향(backward) 계산을 최적화합니다.
- Decoding 단계의 밀집 어텐션: 디코딩 단계에서 전체 KV 캐시에 대한 밀집 어텐션 계산을 최적화합니다.
2. 성능 하이라이트:
FlashMLA는 최신 NVIDIA GPU (SM90/H800, SM100/B200)에 최적화되어 있습니다.
- Sparse MLA Prefill: H800 SXM5에서 최대 640 TFlops, B200에서 최대 1450 TFlops (forward)의 성능을 달성합니다.
- Sparse MLA Decoding: H800 SXM5에서 최대 410 TFlops (FP8 KV 캐시 사용, bfloat16 연산)를 달성합니다.
- Dense MLA Decoding: H800 SXM5에서 메모리 바운드(memory-bound) 환경에서 최대 3000 GB/s, 컴퓨팅 바운드(computation-bound) 환경에서 최대 660 TFlops를 달성합니다.
- Dense MHA Prefill (SM100): B200에서 전방향 최대 1460 TFlops, 역방향 최대 1000 TFlops를 달성합니다.
- 전반적으로 기존 버전에 비해 컴퓨팅 바운드 워크로드에서 5% ~ 15%의 성능 향상을 제공합니다.
3. 핵심 방법론 (Core Methodology):
FlashMLA는 FlashAttention 2&3 및 CUTLASS 프로젝트에서 영감을 받아 GPU 하드웨어에 최적화된 커스텀 커널을 구현합니다. 특히 희소 어텐션 커널은 DeepSeek Sparse Attention (DSA)의 핵심 메커니즘을 효율적으로 실행합니다.
- MLA Decoding (
flash_mla_with_kvcache):- 이 커널은 디코딩 루프에서
flash_mla_with_kvcache함수를 호출하여 실행됩니다. - 입력 매개변수:
q_i: 현재 쿼리 텐서.[s_q, h_q, d_qk]형태.s_q는 쿼리 토큰 수 (일반적으로 1).kvcache_i: Key-Value 캐시.block_table: 블록 테이블 (희소 어텐션 시indices사용으로 대체 가능).cache_seqlens: 캐시된 시퀀스 길이.dv: 헤드 차원d_v.tile_scheduler_metadata,num_splits:get_mla_metadata함수를 통해 얻는 타이링 스케줄링 정보.is_causal: 인과성 여부 (causal masking).is_fp8_kvcache: FP8 KV 캐시 사용 여부.indices: 희소 어텐션을 위한 3D 인덱스 텐서(batch_size, seq_len_q, topk).- 각
indices[i][j][k]는 번째 배치, 번째 쿼리 시퀀스의 번째 토큰이 참조할 KV 캐시 내 토큰의 실제 주소를 인코딩합니다. 이 주소는(페이지 블록 인덱스) * 페이지 블록 크기 + (페이지 블록 내 오프셋)형태로 계산됩니다. 유효하지 않은 엔트리는 -1로 설정됩니다. 이indices텐스를 사용하면block_table매개변수가 필요 없습니다.
- 각
- FP8 KV 캐시 형식:
- 각 토큰의 KV 캐시는 총 656 Bytes로 구성됩니다.
512 bytes: 양자화된 NoPE(Non-Position Encoding) 부분으로, 512개의float8_e4m3값이 저장됩니다.16 bytes: 스케일 팩터로, 4개의float32값이 저장됩니다. 각float32는 128개의float8_e4m3값 그룹에 대한 스케일입니다.128 bytes: RoPE(Rotary Position Embedding) 부분으로, 64개의bfloat16값이 저장됩니다. 이 부분은 정확도 유지를 위해 양자화되지 않습니다.- 커널은 FP8 캐시를
bfloat16으로 역양자화(dequantize)한 후 어텐션 계산을 수행하며, 출력도bfloat16으로 반환합니다.
- 반환 값:
(out, lse).out은 어텐션 결과,lse는 각 쿼리 헤드의 어텐션 스코어의 log-sum-exp 값입니다.
- 이 커널은 디코딩 루프에서
- Sparse MLA Prefill (
flash_mla_sparse_fwd):- 이 커널은 희소 어텐션의 prefill 단계에 사용되며,
flash_mla_sparse_fwd함수를 직접 호출합니다. - 입력 매개변수:
q: 쿼리 텐서.[s_q, h_q, d_qk](bfloat16).kv: Key-Value 텐서.[s_kv, h_kv, d_qk](bfloat16).indices: 희소 어텐션을 위한 인덱스 텐서.[s_q, h_kv, topk](int32). 유효하지 않은 엔트리는 -1 또는s_kv이상의 값으로 설정됩니다.sm_scale: 스케일링 팩터.
- 핵심 연산 (PyTorch 등가 코드):
- 이 커널은 희소 어텐션의 prefill 단계에 사용되며,
kv및indices의 차원 조정:h_kv가 1인 경우,kv는[s_kv, d_qk]로,indices는[s_q, topk]로squeeze됩니다.focused_kv추출: . 이 단계에서indices에 명시된topk개의 KV 토큰만kv캐시에서 선택됩니다. 결과focused_kv텐서의 형태는[s_q, topk, d_qk]가 됩니다. 즉, 각 쿼리 토큰(s_q차원)은 자신과 어텐션할topk개의 키-값 쌍을 가집니다.- 어텐션 스코어
P계산:
여기서
Q의 형태는 [s_q, h_q, d_qk], focused_kv.transpose(-1, -2)의 형태는 [s_q, d_qk, topk]이므로, 행렬 곱셈 결과 P의 형태는 [s_q, h_q, topk]가 됩니다. 는 스케일링 팩터에 포함되어 로그 스페이스에서의 연산이 를 밑으로 하는 로그 스페이스에 해당함을 나타냅니다.
max_logits추출: . 이는 소프트맥스 연산의 수치적 안정성을 위해 일반적으로 사용되는 기술입니다.lse(Log-Sum-Exp) 계산:
이는 를 계산하는 함수입니다. 일반적인 log-sum-exp는 이지만, 여기서는 밑이 2입니다.
- 소프트맥스
S계산:
이는
P에 대해 를 적용하고 lse를 빼는 형태의 소프트맥스 연산입니다. 각 쿼리-키 쌍에 대한 어텐션 가중치를 의미합니다. S의 형태는 [s_q, h_q, topk]입니다.
- 최종 출력
out계산:
S [s_q, h_q, topk]와 focused_kv [s_q, topk, d_qk]의 행렬 곱셈을 통해 최종 어텐션 출력 out의 형태는 [s_q, h_q, d_qk]가 됩니다.
- 반환 값:
(out, max_logits, lse). - 배치 처리: 이 커널은 직접적인 배치 차원을 지원하지 않습니다. 여러 배치를 처리하려면 입력 텐서를 재구성하고
indices매개변수를 조정하여 배치 처리를 시뮬레이션해야 합니다.
- Dense MHA Prefill:
flash_attn_varlen_func,flash_attn_varlen_qkvpacked_func,flash_attn_varlen_kvpacked_func와 같은 표준 FlashAttention 인터페이스를 사용합니다.
4. 지원 환경 및 파트너십:
- GPU 아키텍처: NVIDIA SM90 (H100/H200/H800) 및 SM100 (B100/B200)을 지원합니다.
- CUDA: CUDA 12.8 이상 (SM100 커널의 경우 CUDA 12.9+ 필요).
- PyTorch: PyTorch 2.0 이상.
- MQA/MHA 모드: Multi-Query Attention (MQA) 및 Multi-Head Attention (MHA) 모드를 지원하며, DeepSeek V3.2 Paper의 부록에 설명된 특화된
head_dim_k및head_dim_v구성을 사용합니다. - 다양한 하드웨어 벤더(MetaX, Moore Threads, Hygon DCU, Intellifusion, Iluvatar Corex, AMD Instinct)에서 FlashMLA의 파생 버전 또는 포트를 제공하여 광범위한 하드웨어 플랫폼을 지원합니다.
FlashMLA는 DeepSeek 모델의 효율성을 높이는 데 핵심적인 역할을 하며, 고도로 최적화된 어텐션 커널을 통해 LLM 추론의 성능을 크게 향상시키는 데 기여합니다.