deepseek-ai/DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling
Key Points
- DeepGEMM is a CUDA-based library designed for clean and efficient General Matrix Multiplications (GEMMs), supporting FP8 and BF16 data types across dense and MoE grouped scenarios.
- It features a lightweight JIT module for runtime kernel compilation, simplifying its design by focusing on a limited set of core functions without heavy reliance on complex templates from other libraries.
- Despite its simplicity, DeepGEMM achieves competitive performance, matching or exceeding expert-tuned libraries, with reported peak performance of up to 1550 TFLOPS on H800, and includes specialized kernels for tasks like MoE weight gradients and MQA logits.
DeepGEMM is a CUDA-based library for General Matrix Multiplications (GEMMs), emphasizing efficiency and clarity. It supports FP8 and BF16 (work-in-progress) data types for both standard dense GEMMs and grouped scenarios relevant to Mixture-of-Experts (MoE) models. The library is designed to be lightweight, avoiding heavy reliance on complex template metaprogramming from frameworks like CUTLASS, instead focusing on a limited set of core kernel functions to serve as an accessible resource for NVIDIA GPU kernel optimization. Despite its simplified design, DeepGEMM claims to match or exceed the performance of expert-tuned libraries across various matrix shapes.
The core methodology of DeepGEMM revolves around a Just-In-Time (JIT) compilation strategy using a lightweight C++ module, eliminating the need for pre-installation kernel compilation. Initially, it eschewed NVRTC and post-compilation SASS optimizations, relying on NVCC 12.9+ for automatic FFMA interleaving, though NVRTC support was later introduced as an option for faster compilation.
DeepGEMM provides optimized GEMM kernels under a consistent naming convention that encodes each operand's transposition, with dedicated functions covering the dense GEMM cases.
A key distinction lies in memory layout support:
- SM90 (Hopper) architecture: Supports only the NT memory layout (non-transposed A, transposed B).
- SM100 (Blackwell) architecture: Supports all memory layouts (NT, TN, NN, TT).
For all architectures, the Left-Hand Side (LHS) scaling factor is required to be TMA-aligned and transposed. The format of these scaling factors differs by architecture: SM90 requires FP32, while SM100 requires packed UE8M0 format, where four UE8M0 values are packed into a single 32-bit integer. The library expects users to handle input transpositions and FP8 casting independently.
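The SM100 packing described above can be sketched in NumPy. Note the hedges: `fp32_to_ue8m0` and `pack_ue8m0_x4` are illustrative helper names (DeepGEMM ships its own scaling-factor transform utilities), and rounding each FP32 scale up to the nearest power of two is an assumption of this sketch.

```python
import numpy as np

def fp32_to_ue8m0(scales: np.ndarray) -> np.ndarray:
    # UE8M0 stores only an unsigned biased exponent: value = 2**(e - 127).
    # Round each FP32 scale up to the nearest power of two (an assumed
    # rounding choice; the library's exact rounding may differ).
    e = np.ceil(np.log2(scales)).astype(np.int32) + 127
    return np.clip(e, 0, 255).astype(np.uint8)

def pack_ue8m0_x4(ue8m0_bytes: np.ndarray) -> np.ndarray:
    # Pack each group of four UE8M0 bytes into one 32-bit integer
    # (native little-endian byte order on common platforms).
    assert ue8m0_bytes.size % 4 == 0
    return ue8m0_bytes.reshape(-1, 4).view(np.uint32).ravel()

scales = np.array([1.0, 2.0, 0.5, 4.0], dtype=np.float32)
packed = pack_ue8m0_x4(fp32_to_ue8m0(scales))  # one uint32 per 4 scales
```

Because UE8M0 carries only an exponent, power-of-two scales round-trip exactly; arbitrary scales lose their mantissa bits.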
DeepGEMM specifically caters to grouped GEMMs crucial for MoE models, offering two primary grouping schemes:
- Contiguous Layout Grouping: Designed for scenarios like MoE training forward passes or inference prefilling where experts share the same shape. It groups along the M-axis, while N and K dimensions remain fixed. Tokens processed by different experts are concatenated into a single "contiguous" tensor. A critical requirement is that each expert segment within this contiguous tensor must be aligned to the GEMM M block size, which can be retrieved via `get_mk_alignment_for_contiguous_layout()`. A K-axis grouped API (`k_grouped_fp8_gemm_tn_contiguous`) is also provided for MoE weight gradient computations.
- Masked Layout Grouping: Applicable during inference decoding, particularly when CUDA graphs are enabled and the CPU is unaware of the exact number of tokens each expert receives. This method employs a mask tensor, allowing the kernel to compute only the valid portions, exemplified by `m_grouped_fp8_gemm_nt_masked`.
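The M-alignment requirement of the contiguous layout can be sketched as a small padding computation. `aligned_segments` is a hypothetical helper, and the alignment value 128 below is a placeholder; the real value must be queried via `get_mk_alignment_for_contiguous_layout()`.

```python
def aligned_segments(tokens_per_expert, m_alignment):
    """Round each expert's token count up to the GEMM M block size and
    return (padded_counts, start_offsets) into the contiguous tensor."""
    padded, offsets, offset = [], [], 0
    for n_tokens in tokens_per_expert:
        # Ceil-divide by the alignment, then scale back up.
        p = -(-n_tokens // m_alignment) * m_alignment
        padded.append(p)
        offsets.append(offset)
        offset += p
    return padded, offsets

# With a placeholder alignment of 128 (query the real value via
# get_mk_alignment_for_contiguous_layout()):
counts, starts = aligned_segments([300, 65, 128], 128)
# counts == [384, 128, 128]; starts == [0, 384, 512]
```

The padding rows between an expert's last real token and its aligned segment end are wasted work, which is why the alignment is kept as small as the kernel's M block size allows.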
Additionally, DeepGEMM includes specialized MQA (Multi-Query Attention) kernels: `fp8_mqa_logits` (non-paged, for prefilling) and `fp8_paged_mqa_logits` (for paged decoding). For the non-paged `fp8_mqa_logits`, given queries `q` (E4M3), keys/values `kv` (E4M3) with a float scaling factor, per-head `weights` (float), and `cu_seq_len_k_start`/`cu_seq_len_k_end` (int), the output logits are computed for each query token `i` and each KV token `j` in the range `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])` as follows:
- `s = q[i] @ kv[j]` (per-head dot products, yielding one score per head)
- `s = s * weights[i]` (element-wise weighting of each head)
- `out[i, j] = sum(s)` (summation over heads, yielding a scalar logit).
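The three steps above can be written as a plain-float NumPy reference. This is a semantic sketch only: `mqa_logits_reference` is a hypothetical name, FP8 quantization and scaling are omitted, and filling out-of-range positions with `-inf` is an assumption of this sketch rather than a documented property of the kernel.

```python
import numpy as np

def mqa_logits_reference(q, kv, weights, k_start, k_end):
    """Float reference for the MQA logits computation sketched above.

    q:       [seq_len, num_heads, head_dim]
    kv:      [seq_len_kv, head_dim]
    weights: [seq_len, num_heads]
    k_start, k_end: per-query valid KV ranges (int sequences)
    """
    seq_len, seq_len_kv = q.shape[0], kv.shape[0]
    # Assumed fill value for positions outside the valid KV range.
    out = np.full((seq_len, seq_len_kv), -np.inf, dtype=np.float32)
    for i in range(seq_len):
        for j in range(k_start[i], k_end[i]):
            s = q[i] @ kv[j]      # per-head dot products -> [num_heads]
            s = s * weights[i]    # element-wise per-head weighting
            out[i, j] = s.sum()   # reduce over heads -> scalar logit
    return out
```

A reference like this is mainly useful for validating the fused FP8 kernel's output against an unquantized baseline on small shapes.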
The library also provides utility functions for managing SM counts, Tensor Core utilization, scaling factor transformations, and TMA alignment queries. Environment variables allow fine-tuning JIT behavior, cache directories, compiler selection (NVRTC vs. NVCC), and debugging output.