[Common][PyTorch] Add strided batched GEMM in BF16/MXFP8 #3160
Signed-off-by: Xin Yao <xiny@nvidia.com>
Greptile Summary
This PR adds experimental cuBLASLt-backed strided batched GEMM support to Transformer Engine, covering BF16 and MXFP8 precision, along with a new fusible
Confidence Score: 3/5
The new GEMM path is well-tested and mathematically correct, but the C++ GEMM function leaks cuBLASLt handles whenever an MXFP8 call fails due to a cuBLAS version mismatch, both at compile time and at runtime.
In cublas_gemm_strided_batched, Adesc/Bdesc/Cdesc/Ddesc/operationDesc are allocated before the MXFP8 version guard. The compile-time #else NVTE_ERROR and the runtime cublas_version() NVTE_CHECK both throw after those handles exist but before the cleanup block, making it a definite leak on every failed MXFP8 strided batched GEMM call against an old library.
transformer_engine/common/gemm/cublaslt_gemm.cu: the strided batched GEMM function needs RAII-style cleanup for the cuBLASLt descriptors it allocates before the version check.
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant PY as Python (BatchedGEMM / strided_batched_gemm)
participant CPP as C++ Extension (gemm.cpp)
participant SWIZ as Swizzle Kernel (swizzle.cu)
participant GEMM as GEMM Kernel (cublaslt_gemm.cu)
PY->>CPP: strided_batched_gemm(A, B, out, layout, ...)
alt MXFP8 inputs
CPP->>SWIZ: nvte_pack_mxfp8_scales_for_batched_gemm(compact packed, stream)
SWIZ-->>CPP: packed [batch0_scales][batch1_scales]...
CPP->>SWIZ: nvte_pack_mxfp8_columnwise_scales_for_batched_gemm(compact packed, stream)
SWIZ-->>CPP: packed column-wise scales
end
CPP->>GEMM: "nvte_cublas_gemm_strided_batched(A lda/stridea, B ldb/strideb, C=D, D ldd/strided, ...)"
GEMM->>GEMM: CheckStridedBatchBuffer (bounds check)
GEMM->>GEMM: CheckMXFP8BatchScaleBuffer (scale size check)
GEMM->>GEMM: cublasLtMatmulAlgoGetHeuristic + cublasLtMatmul
GEMM-->>CPP: D (output written in-place)
CPP-->>PY: out tensor
  if (use_mxfp8) {
#if CUBLAS_VERSION >= 120800
    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
               "MXFP8 strided batched GEMM requires cuBLAS 12.8+, but run-time cuBLAS version is ",
               transformer_engine::cuda::cublas_version());
    NVTE_CHECK(
        inputA->with_gemm_swizzled_scales,
        "MXFP8 A scales for strided batched GEMM must be packed by batch and GEMM-swizzled.");
    NVTE_CHECK(
        inputB->with_gemm_swizzled_scales,
        "MXFP8 B scales for strided batched GEMM must be packed by batch and GEMM-swizzled.");
    fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
    fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     &B_scale_inverse, sizeof(B_scale_inverse)));
    const cublasLtMatmulMatrixScale_t scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
    const cublasLtMatmulMatrixScale_t scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;

    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));

    // Workaround for the same heuristic-cache issue handled in the non-batched MXFP8 GEMM.
    if (transformer_engine::cuda::cublas_version() <= 120803) {
      const int64_t dummy_a_vec_stride = 1;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
          operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
          sizeof(dummy_a_vec_stride)));
    }
#else
    NVTE_ERROR(
        "MXFP8 strided batched GEMM requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
        CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
  }
cuBLASLt descriptor leak on every failed MXFP8 call
Adesc, Bdesc, Cdesc, Ddesc (lines 1010–1023) and operationDesc (line 1029) are all successfully created before this block. If use_mxfp8 is true and the binary was compiled with CUBLAS_VERSION < 120800, the NVTE_ERROR on line 1081 throws an exception. Equally, even when compiled against a new-enough cuBLAS, the runtime NVTE_CHECK at line 1048 (comparing cublas_version() >= 120800) will throw if the runtime library is older. In both cases the five already-created descriptors are never destroyed, leaking one handle per call attempt. The cleanup block at the end of the function is only reached on the happy path.
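One way to address the leak described above is to let each descriptor be owned by a std::unique_ptr with a custom deleter, so that stack unwinding destroys it when a check throws. The sketch below is not the PR's code: FakeDesc, fake_create, fake_destroy, and gemm_like are hypothetical stand-ins so that it compiles without CUDA. In the real function the deleters would presumably call cublasLtMatrixLayoutDestroy and cublasLtMatmulDescDestroy.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for a cuBLASLt descriptor handle (not a real cuBLASLt type).
struct FakeDesc {
  int id;
};

// Counts descriptors that were created but not yet destroyed.
static int g_live_descs = 0;

FakeDesc *fake_create(int id) {
  ++g_live_descs;
  return new FakeDesc{id};
}

void fake_destroy(FakeDesc *desc) {
  --g_live_descs;
  delete desc;
}

// Deleter that releases the handle whenever the owning unique_ptr goes out of scope.
struct FakeDescDeleter {
  void operator()(FakeDesc *desc) const noexcept { fake_destroy(desc); }
};
using DescPtr = std::unique_ptr<FakeDesc, FakeDescDeleter>;

// Same shape as the reviewed function: descriptors first, then a check that may throw.
void gemm_like(bool version_ok) {
  DescPtr a_desc(fake_create(0));
  DescPtr b_desc(fake_create(1));
  DescPtr op_desc(fake_create(2));
  if (!version_ok) {
    // Assumption: this plays the role of the failing NVTE_CHECK / NVTE_ERROR.
    throw std::runtime_error("MXFP8 strided batched GEMM requires cuBLAS 12.8+");
  }
  // ... configure and launch the matmul here ...
}  // All three descriptors are destroyed here, on both the normal and the throwing path.

// Runs one failing call and reports how many descriptors are still alive.
int live_after_failed_call() {
  try {
    gemm_like(false);
  } catch (const std::runtime_error &) {
    // Expected: the version check failed.
  }
  return g_live_descs;
}
```

An alternative with the same effect would be to move the version checks ahead of any descriptor creation. The RAII form has the advantage of also covering every other NVTE_CHECK_CUBLAS call that can throw between creation and cleanup.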
def _validate_input(self, input_: torch.Tensor) -> int:
    if is_quantized_tensor(input_):
        raise ValueError("BatchedGEMM expects a high-precision input tensor")
    if not isinstance(input_, torch.Tensor):
        raise TypeError(
            f"BatchedGEMM expects a torch.Tensor input (got {type(input_).__name__})"
        )
is_quantized_tensor(input_) is evaluated before isinstance(input_, torch.Tensor). For a completely non-tensor input (e.g., a plain Python list), is_quantized_tensor may raise an unexpected error before the more informative TypeError is reached. Checking isinstance first ensures the type guard fires before any attribute access.
Suggested change:
def _validate_input(self, input_: torch.Tensor) -> int:
    if not isinstance(input_, torch.Tensor):
        raise TypeError(
            f"BatchedGEMM expects a torch.Tensor input (got {type(input_).__name__})"
        )
    if is_quantized_tensor(input_):
        raise ValueError("BatchedGEMM expects a high-precision input tensor")
Description
This PR adds experimental cuBLASLt-backed strided batched GEMM support to Transformer Engine, together with a new fusible PyTorch BatchedGEMM operation.

The main use case is applying one weight matrix per group without rearranging interleaved activations, such as
where G is the batch dim. Both contiguous [G, ..., D] and [..., G, D] layouts are supported through explicit leading dimensions and batch strides, avoiding data permutation or .contiguous() calls.

For MXFP8, cuBLASLt expects each batch's swizzled scaling factors to be stored consecutively. MXFP8 tensors therefore retain compact scales, while the PyTorch extension packs the required row-wise or column-wise scales into temporary per-call buffers immediately before launching GEMM. No persistent batch-packed scale metadata is exposed through the Python API.
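To illustrate how a leading dimension plus a batch stride can describe both layouts without copying, here is a minimal sketch. It assumes a single flattened middle dimension M and row-major element offsets as PyTorch reports them, not cuBLAS's column-major convention. StridedView, view_batch_first, view_batch_inner, offset, and required_elements are hypothetical names, not part of the PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one batched operand: row stride and per-batch stride, in elements.
struct StridedView {
  int64_t ld;            // distance between consecutive rows of one batch's matrix
  int64_t batch_stride;  // distance between the same element in consecutive batches
};

// Contiguous [G, M, D]: each batch owns a dense M x D block.
StridedView view_batch_first(int64_t /*G*/, int64_t M, int64_t D) { return {D, M * D}; }

// Contiguous [M, G, D]: batches are interleaved inside every row group.
StridedView view_batch_inner(int64_t G, int64_t /*M*/, int64_t D) { return {G * D, D}; }

// Offset of element (g, m, d) in the underlying buffer.
int64_t offset(const StridedView &v, int64_t g, int64_t m, int64_t d) {
  return g * v.batch_stride + m * v.ld + d;
}

// Smallest buffer size (in elements) that contains every addressed element;
// this is the kind of quantity a bounds check on the buffer would compare against.
int64_t required_elements(const StridedView &v, int64_t G, int64_t M, int64_t D) {
  return offset(v, G - 1, M - 1, D - 1) + 1;
}
```

For both layouts required_elements evaluates to G * M * D, which matches the intuition that the two views address the same dense buffer and differ only in how the batch index maps to memory.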
Type of change
Changes
[batch0_scales][batch1_scales]....
[G, ..., D] and [..., G, D] activation layouts.
Current limitations
[G, ..., D] or [..., G, D] layouts.
Checklist: