GitHub - deepseek-ai/FlashMLA: FlashMLA: Efficient Multi-head Latent Attention Kernels
요약
상세 내용
1. 핵심 기능 및 구성:
* Sparse Attention Kernels (DeepSeek Sparse Attention - DSA):
* Prefill 단계의 토큰 단위 희소 어텐션: 입력 시퀀스 전체에 대한 어텐션 계산 대신, indices 텐서로 지정된 topk 개의 토큰에 대해서만 어텐션을 수행합니다. 이를 통해 불필요한 계산을 줄여 효율성을 높입니다.
* Decoding 단계의 토큰 단위 희소 어텐션 (FP8 KV 캐시 사용): 디코딩 과정에서도 특정 토큰에만 집중하여 어텐션을 수행하며, Key-Value(KV) 캐시를 FP8 형식으로 저장하여 메모리 사용량을 줄입니다.
* Dense Attention Kernels:
* Prefill 단계의 밀집 어텐션: 일반적인 Multi-Head Attention (MHA)의 전방향(forward) 및 역방향(backward) 계산을 최적화합니다.
* Decoding 단계의 밀집 어텐션: 디코딩 단계에서 전체 KV 캐시에 대한 밀집 어텐션 계산을 최적화합니다.
2. 성능 하이라이트:
FlashMLA는 최신 NVIDIA GPU (SM90/H800, SM100/B200)에 최적화되어 있습니다.
* Sparse MLA Prefill: H800 SXM5에서 최대 640 TFlops, B200에서 최대 1450 TFlops (forward)의 성능을 달성합니다.
* Sparse MLA Decoding: H800 SXM5에서 최대 410 TFlops (FP8 KV 캐시 사용, bfloat16 연산)를 달성합니다.
* Dense MLA Decoding: H800 SXM5에서 메모리 바운드(memory-bound) 환경에서 최대 3000 GB/s, 컴퓨팅 바운드(computation-bound) 환경에서 최대 660 TFlops를 달성합니다.
* Dense MHA Prefill (SM100): B200에서 전방향 최대 1460 TFlops, 역방향 최대 1000 TFlops를 달성합니다.
* 전반적으로 기존 버전에 비해 컴퓨팅 바운드 워크로드에서 5% ~ 15%의 성능 향상을 제공합니다.
3. 핵심 방법론 (Core Methodology):
FlashMLA는 FlashAttention 2&3 및 CUTLASS 프로젝트에서 영감을 받아 GPU 하드웨어에 최적화된 커스텀 커널을 구현합니다. 특히 희소 어텐션 커널은 DeepSeek Sparse Attention (DSA)의 핵심 메커니즘을 효율적으로 실행합니다.
* MLA Decoding (flash_mla_with_kvcache):
* 이 커널은 디코딩 루프에서 flash_mla_with_kvcache 함수를 호출하여 실행됩니다.
* 입력 매개변수:
* q_i: 현재 쿼리 텐서. [s_q, h_q, d_qk] 형태. s_q는 쿼리 토큰 수 (일반적으로 1).
* kvcache_i: Key-Value 캐시.
* block_table: 블록 테이블 (희소 어텐션 시 indices 사용으로 대체 가능).
* cache_seqlens: 캐시된 시퀀스 길이.
* dv: 헤드 차원 d_v.
* tile_scheduler_metadata, num_splits: get_mla_metadata 함수를 통해 얻는 타이링 스케줄링 정보.
* is_causal: 인과성 여부 (causal masking).
* is_fp8_kvcache: FP8 KV 캐시 사용 여부.
* indices: 희소 어텐션을 위한 3D 인덱스 텐서 (batch_size, seq_len_q, topk).
* 각 indices[i][j][k]는 번째 배치, 번째 쿼리 시퀀스의 번째 토큰이 참조할 KV 캐시 내 토큰의 실제 주소를 인코딩합니다. 이 주소는 (페이지 블록 인덱스) * 페이지 블록 크기 + (페이지 블록 내 오프셋) 형태로 계산됩니다. 유효하지 않은 엔트리는 -1로 설정됩니다. 이 indices 텐스를 사용하면 block_table 매개변수가 필요 없습니다.
* FP8 KV 캐시 형식:
* 각 토큰의 KV 캐시는 총 656 Bytes로 구성됩니다.
* 512 bytes: 양자화된 NoPE(Non-Position Encoding) 부분으로, 512개의 float8_e4m3 값이 저장됩니다.
* 16 bytes: 스케일 팩터로, 4개의 float32 값이 저장됩니다. 각 float32는 128개의 float8_e4m3 값 그룹에 대한 스케일입니다.
* 128 bytes: RoPE(Rotary Position Embedding) 부분으로, 64개의 bfloat16 값이 저장됩니다. 이 부분은 정확도 유지를 위해 양자화되지 않습니다.
* 커널은 FP8 캐시를 bfloat16으로 역양자화(dequantize)한 후 어텐션 계산을 수행하며, 출력도 bfloat16으로 반환합니다.
* 반환 값: (out, lse). out은 어텐션 결과, lse는 각 쿼리 헤드의 어텐션 스코어의 log-sum-exp 값입니다.
* Sparse MLA Prefill (flash_mla_sparse_fwd):
* 이 커널은 희소 어텐션의 prefill 단계에 사용되며, flash_mla_sparse_fwd 함수를 직접 호출합니다.
* 입력 매개변수:
* q: 쿼리 텐서. [s_q, h_q, d_qk] (bfloat16).
* kv: Key-Value 텐서. [s_kv, h_kv, d_qk] (bfloat16).
* indices: 희소 어텐션을 위한 인덱스 텐서. [s_q, h_kv, topk] (int32). 유효하지 않은 엔트리는 -1 또는 s_kv 이상의 값으로 설정됩니다.
* sm_scale: 스케일링 팩터.
* 핵심 연산 (PyTorch 등가 코드):
이 커널은 다음 PyTorch 연산과 등가적인 계산을 수행합니다.
1. kv 및 indices의 차원 조정: h_kv가 1인 경우, kv는 [s_kv, d_qk]로, indices는 [s_q, topk]로 squeeze됩니다.
2. focused_kv 추출: . 이 단계에서 indices에 명시된 topk개의 KV 토큰만 kv 캐시에서 선택됩니다. 결과 focused_kv 텐서의 형태는 [s_q, topk, d_qk]가 됩니다. 즉, 각 쿼리 토큰(s_q 차원)은 자신과 어텐션할 topk개의 키-값 쌍을 가집니다.
3. 어텐션 스코어 P 계산:
여기서 Q의 형태는 [s_q, h_q, d_qk], focused_kv.transpose(-1, -2)의 형태는 [s_q, d_qk, topk]이므로, 행렬 곱셈 결과 P의 형태는 [s_q, h_q, topk]가 됩니다. 는 스케일링 팩터에 포함되어 로그 스페이스에서의 연산이 를 밑으로 하는 로그 스페이스에 해당함을 나타냅니다.
4. max_logits 추출: . 이는 소프트맥스 연산의 수치적 안정성을 위해 일반적으로 사용되는 기술입니다.
5. lse (Log-Sum-Exp) 계산:
이는 를 계산하는 함수입니다. 일반적인 log-sum-exp는 이지만, 여기서는 밑이 2입니다.
6. 소프트맥스 S 계산:
이는 P에 대해 를 적용하고 lse를 빼는 형태의 소프트맥스 연산입니다. 각 쿼리-키 쌍에 대한 어텐션 가중치를 의미합니다. S의 형태는 [s_q, h_q, topk]입니다.
7. 최종 출력 out 계산:
S [s_q, h_q, topk]와 focused_kv [s_q, topk, d_qk]의 행렬 곱셈을 통해 최종 어텐션 출력 out의 형태는 [s_q, h_q, d_qk]가 됩니다.
* 반환 값: (out, max_logits, lse).
* 배치 처리: 이 커널은 직접적인 배치 차원을 지원하지 않습니다. 여러 배치를 처리하려면 입력 텐서를 재구성하고 indices 매개변수를 조정하여 배치 처리를 시뮬레이션해야 합니다.
* Dense MHA Prefill:
* flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func와 같은 표준 FlashAttention 인터페이스를 사용합니다.
4. 지원 환경 및 파트너십:
* GPU 아키텍처: NVIDIA SM90 (H100/H200/H800) 및 SM100 (B100/B200)을 지원합니다.
* CUDA: CUDA 12.8 이상 (SM100 커널의 경우 CUDA 12.9+ 필요).
* PyTorch: PyTorch 2.0 이상.
* MQA/MHA 모드: Multi-Query Attention (MQA) 및 Multi-Head Attention (MHA) 모드를 지원하며, DeepSeek V3.2 Paper의 부록에 설명된 특화된 head_dim_k 및 head_dim_v 구성을 사용합니다.
* 다양한 하드웨어 벤더(MetaX, Moore Threads, Hygon DCU, Intellifusion, Iluvatar Corex, AMD Instinct)에서 FlashMLA의 파생 버전 또는 포트를 제공하여 광범위한 하드웨어 플랫폼을 지원합니다.
FlashMLA는 DeepSeek 모델의 효율성을 높이는 데 핵심적인 역할을 하며, 고도로 최적화된 어텐션 커널을 통해 LLM 추론의 성능을 크게 향상시키는 데 기여합니다.