목록으로
GitHub - deepseek-ai/DeepGEMM: DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling
Service2025.03.08

GitHub - deepseek-ai/DeepGEMM: DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling

요약

DeepGEMM은 FP8 및 BF16 GEMM 연산을 위한 효율적인 CUDA 기반 라이브러리로, 일반적인 행렬 곱셈과 MoE(Mix-of-Experts) 모델의 특정 시나리오를 지원합니다.
이 라이브러리는 단순한 디자인과 경량 JIT 컴파일 방식을 채택하여 설치 시 커널 컴파일이 필요 없으며, 높은 성능을 제공하면서 NVIDIA GPU 커널 최적화 학습에 유용한 리소스가 됩니다.
DeepGEMM은 SM90 및 SM100 아키텍처를 지원하며, MoE를 위한 특화된 그룹형 GEMM 및 최신 MQA 커널을 포함하여 다양한 고급 최적화 기능을 제공합니다.

상세 내용

DeepGEMM은 General Matrix Multiplications (GEMMs)를 위한 깨끗하고 효율적인 라이브러리로, FP8 및 BF16(개발 중) 데이터 타입을 지원하며, 일반적인 dense GEMM과 MoE(Mix-of-Experts) 그룹화된 시나리오를 모두 처리합니다. 이 라이브러리는 CUDA로 작성되었으며, 설치 시 커널 컴파일이 필요 없이 경량 JIT(Just-In-Time) 모듈을 사용하여 런타임에 모든 커널을 컴파일합니다.

핵심 방법론 및 설계 철학:

  • 간결한 설계 (Clean and Accessible): DeepGEMM은 CUTLASS 및 CuTe의 일부 개념을 활용하지만, 그들의 템플릿이나 대수에 크게 의존하지 않습니다. 대신, 소수의 핵심 커널 함수만을 사용하여 설계되어 간단하고 NVIDIA GPU 커널 최적화 기술 학습에 유용합니다. 이러한 경량 설계에도 불구하고, 다양한 행렬 형태에서 전문가가 튜닝한 라이브러리와 유사하거나 능가하는 성능을 보여줍니다.
  • JIT 컴파일 시스템:
  • * 런타임 컴파일: 설치 시 커널 컴파일 없이 런타임에 JIT 모듈을 통해 컴파일이 이루어집니다.
    * NVRTC 지원: NVCC 대신 NVRTC를 사용하여 컴파일 속도를 최대 10배 향상시킬 수 있으며, DGJITUSENVRTC=1DG_JIT_USE_NVRTC=1 환경 변수를 통해 활성화됩니다.
    * SM100 최적화: SM100 아키텍처에서는 NVCC 12.9 이상 버전이 FFMA interleaving을 자동으로 수행하므로, NVRTC 및 post-compilation SASS 최적화가 비활성화되어 CPU 오버헤드를 줄입니다.

  • 아키텍처 지원 및 메모리 레이아웃:
  • * SM90/SM100 지원: NVIDIA SM90 (Hopper) 및 SM100 (Blackwell) 아키텍처 GPU를 지원합니다.
    * 메모리 레이아웃:
    * SM90 구현은 NT (non-transposed A, transposed B) 메모리 레이아웃만 지원합니다. 즉, D=C+A@BTD = C + A @ B^T 연산을 수행합니다.
    * SM100 구현은 NT, TN, NN, TT를 포함한 모든 메모리 레이아웃을 지원합니다.
    * 스케일링 팩터 (LHS Scaling Factor): LHS 스케일링 팩터는 TMA(Tensor Memory Accelerator) 정렬된 전치(transposed) 레이아웃이어야 합니다.
    * SM90은 FP32 포맷의 스케일링 팩터를 요구합니다.
    * SM100은 4개의 UE8M0을 하나의 torch.int로 패킹하는 packed UE8M0 포맷을 요구합니다.
    * 입력 전치(transposition)나 FP8 캐스팅과 같은 작업은 사용자가 별도로 처리해야 합니다.

  • GEMM 타입:
  • * Normal Dense GEMMs (비-그룹화): fp8gemmnt,nn,tn,ttfp8_gemm_{nt, nn, tn, tt} 함수를 통해 기본적인 FP8 GEMM을 수행합니다.
    * Grouped GEMMs (contiguous layout): MoE 모델의 전문가들이 동일한 형태를 공유하는 시나리오에 특화되어 M-축만 그룹화하고 N과 K는 고정합니다. 토큰을 단일 텐서로 연결하여 처리하며, 각 전문가 세그먼트는 GEMM M 블록 크기에 정렬되어야 합니다 (get_mk_alignment_for_contiguous_layout() 참조). m_grouped_fp8_gemm_{nt, nn}_contiguous 함수를 제공하며, MoE weight backward를 위한 K-축 그룹화 API (k_grouped_fp8_gemm_tn_contiguous)도 있습니다.
    * Grouped GEMMs (masked layout): 추론 디코딩 단계에서 CUDA graph가 활성화되고 CPU가 각 전문가가 받는 토큰 수를 모르는 경우에 사용됩니다. 마스크 텐서를 제공하여 유효한 부분만 계산하며, m_grouped_fp8_gemm_nt_masked 함수를 사용합니다.

  • MQA (Multi-Query Attention) 커널:
  • * DeepSeek v3.2의 lightning indexer를 위한 scoring 커널을 제공합니다. 비-페이지드(prefilling용)와 페이지드(decoding용) 두 가지 버전이 있습니다.
    * 비-페이지드 fp8_mqa_logits 커널은 다음 계산을 수행합니다:
    * qq: E4M3 텐서, 형태는 [seq_len,num_heads,head_dim][seq\_len, num\_heads, head\_dim]
    * kvkv: E4M3 텐서 (형태: [seq_len_kv,head_dim][seq\_len\_kv, head\_dim])
    * SFkvSF_{kv}: float 스케일링 팩터 (형태: [seq_len_kv][seq\_len\_kv])
    * weightsweights: float 텐서 (형태: [seq_len,num_heads][seq\_len, num\_heads])
    * cu_seq_len_k_start,cu_seq_len_k_endcu\_seq\_len\_k\_start, cu\_seq\_len\_k\_end: int 텐서 (형태: [seq_len][seq\_len])
    * clean_logitsclean\_logits: -inf로 채워지지 않은 로짓을 정리할지 여부
    * 출력 텐서 outout의 형태는 [seq_len,seq_len_kv][seq\_len, seq\_len\_kv]이며, 토큰-대-토큰 로짓을 나타냅니다.
    * 각 쿼리 토큰 ii에 대해, cu_seq_len_k_start[i]cu\_seq\_len\_k\_start[i]부터 cu_seq_len_k_end[i])cu\_seq\_len\_k\_end[i])까지의 모든 키-값 토큰 jj에 대해 로짓 out[i,j]out[i, j]는 다음과 같이 계산됩니다:
    1. kv_j=kv[j,:]×SFkv[j]kv\_j = kv[j, :] \times SF_{kv}[j]
    2. outtemp=q[i,:,:]×kv_jout_{temp} = q[i, :, :] \times kv\_j (이는 헤드별 내적)
    3. outij=ReLU(outtemp)×weights[i,:]out_{ij} = \text{ReLU}(out_{temp}) \times weights[i, :]
    4. out[i,j]=outijout[i, j] = \sum out_{ij} (모든 헤드에 걸쳐 합산)
    여기서 ×\times는 element-wise 곱셈을, ReLU\text{ReLU}는 활성화 함수를 의미합니다.

    유틸리티 및 환경 변수:

    * deep_gemm.set_num_sms, get_num_sms: 사용할 SM(Streaming Multiprocessor) 최대 개수를 설정/조회합니다.
    * deep_gemm.set_tc_util, get_tc_util: Tensor Core 활용률을 설정/조회합니다.
    * deep_gemm.transform_sf_into_required_layout: 스케일링 팩터를 필요한 레이아웃으로 변환합니다.
    * deep_gemm.get_tma_aligned_size: TMA 정렬 요구 크기를 얻습니다.
    * deep_gemm.get_mk_alignment_for_contiguous_layout: 그룹화된 contiguous 레이아웃의 그룹 수준 정렬 요구사항을 얻습니다.
    * deep_gemm.get_mn_major_tma_aligned_tensor, get_mn_major_tma_aligned_packed_ue8m0_tensor: MN-major TMA 정렬된 텐서를 얻습니다.
    * 환경 변수: DG_JIT_DEBUG, DG_JIT_CACHE_DIR, DG_JIT_USE_NVRTC, DG_JIT_NVCC_COMPILER, DG_JIT_PTXAS_VERBOSE, DG_JIT_PRINT_COMPILER_COMMAND, DG_PRINT_CONFIGS 등을 제공하여 JIT 컴파일 및 디버깅을 제어합니다.

    DeepGEMM은 CUTLASS 프로젝트에서 영감을 받아 개발되었으며 MIT 라이선스 하에 배포됩니다.

    원본 보기
    GitHub
    Shared by Anonymous