GitHub - deepseek-ai/DeepGEMM: DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling
요약
상세 내용
핵심 방법론 및 설계 철학:
* NVRTC 지원: NVCC 대신 NVRTC를 사용하여 컴파일 속도를 최대 10배 향상시킬 수 있으며, 환경 변수를 통해 활성화됩니다.
* SM100 최적화: SM100 아키텍처에서는 NVCC 12.9 이상 버전이 FFMA interleaving을 자동으로 수행하므로, NVRTC 및 post-compilation SASS 최적화가 비활성화되어 CPU 오버헤드를 줄입니다.
* 메모리 레이아웃:
* SM90 구현은 NT (non-transposed A, transposed B) 메모리 레이아웃만 지원합니다. 즉, 연산을 수행합니다.
* SM100 구현은 NT, TN, NN, TT를 포함한 모든 메모리 레이아웃을 지원합니다.
* 스케일링 팩터 (LHS Scaling Factor): LHS 스케일링 팩터는 TMA(Tensor Memory Accelerator) 정렬된 전치(transposed) 레이아웃이어야 합니다.
* SM90은 FP32 포맷의 스케일링 팩터를 요구합니다.
* SM100은 4개의 UE8M0을 하나의
torch.int로 패킹하는 packed UE8M0 포맷을 요구합니다.* 입력 전치(transposition)나 FP8 캐스팅과 같은 작업은 사용자가 별도로 처리해야 합니다.
* Grouped GEMMs (contiguous layout): MoE 모델의 전문가들이 동일한 형태를 공유하는 시나리오에 특화되어 M-축만 그룹화하고 N과 K는 고정합니다. 토큰을 단일 텐서로 연결하여 처리하며, 각 전문가 세그먼트는 GEMM M 블록 크기에 정렬되어야 합니다 (
get_mk_alignment_for_contiguous_layout() 참조). m_grouped_fp8_gemm_{nt, nn}_contiguous 함수를 제공하며, MoE weight backward를 위한 K-축 그룹화 API (k_grouped_fp8_gemm_tn_contiguous)도 있습니다.* Grouped GEMMs (masked layout): 추론 디코딩 단계에서 CUDA graph가 활성화되고 CPU가 각 전문가가 받는 토큰 수를 모르는 경우에 사용됩니다. 마스크 텐서를 제공하여 유효한 부분만 계산하며,
m_grouped_fp8_gemm_nt_masked 함수를 사용합니다.* 비-페이지드
fp8_mqa_logits 커널은 다음 계산을 수행합니다:* : E4M3 텐서, 형태는
* : E4M3 텐서 (형태: )
* : float 스케일링 팩터 (형태: )
* : float 텐서 (형태: )
* : int 텐서 (형태: )
* : -inf로 채워지지 않은 로짓을 정리할지 여부
* 출력 텐서 의 형태는 이며, 토큰-대-토큰 로짓을 나타냅니다.
* 각 쿼리 토큰 에 대해, 부터 까지의 모든 키-값 토큰 에 대해 로짓 는 다음과 같이 계산됩니다:
1.
2. (이는 헤드별 내적)
3.
4. (모든 헤드에 걸쳐 합산)
여기서 는 element-wise 곱셈을, 는 활성화 함수를 의미합니다.
유틸리티 및 환경 변수:
* deep_gemm.set_num_sms, get_num_sms: 사용할 SM(Streaming Multiprocessor) 최대 개수를 설정/조회합니다.
* deep_gemm.set_tc_util, get_tc_util: Tensor Core 활용률을 설정/조회합니다.
* deep_gemm.transform_sf_into_required_layout: 스케일링 팩터를 필요한 레이아웃으로 변환합니다.
* deep_gemm.get_tma_aligned_size: TMA 정렬 요구 크기를 얻습니다.
* deep_gemm.get_mk_alignment_for_contiguous_layout: 그룹화된 contiguous 레이아웃의 그룹 수준 정렬 요구사항을 얻습니다.
* deep_gemm.get_mn_major_tma_aligned_tensor, get_mn_major_tma_aligned_packed_ue8m0_tensor: MN-major TMA 정렬된 텐서를 얻습니다.
* 환경 변수: DG_JIT_DEBUG, DG_JIT_CACHE_DIR, DG_JIT_USE_NVRTC, DG_JIT_NVCC_COMPILER, DG_JIT_PTXAS_VERBOSE, DG_JIT_PRINT_COMPILER_COMMAND, DG_PRINT_CONFIGS 등을 제공하여 JIT 컴파일 및 디버깅을 제어합니다.
DeepGEMM은 CUTLASS 프로젝트에서 영감을 받아 개발되었으며 MIT 라이선스 하에 배포됩니다.