[Common][PyTorch] Add strided batched GEMM in BF16/MXFP8 #3160

Open
yaox12 wants to merge 3 commits into NVIDIA:main from yaox12:xiny/batch_gemm

@yaox12 yaox12 commented Jul 1, 2026

Member

Description

This PR adds experimental cuBLASLt-backed strided batched GEMM support to Transformer Engine, together with a new fusible PyTorch BatchedGEMM operation.

The main use case is applying one weight matrix per group without rearranging interleaved activations, such as

[S, B, G, D] @ [G, R, D] -> [S, B, G, R]

where G is the batch dim. Both contiguous [G, ..., D] and [..., G, D] layouts are supported through explicit leading dimensions and batch strides, avoiding data permutation or .contiguous() calls.
For MXFP8, cuBLASLt expects each batch's swizzled scaling factors to be stored consecutively. MXFP8 tensors therefore retain compact scales, while the PyTorch extension packs the required row-wise or column-wise scales into temporary per-call buffers immediately before launching GEMM. No persistent batch-packed scale metadata is exposed through the Python API.
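The layout claim above can be checked with a small reference loop. This is a minimal sketch in plain Python, not the PR's kernel: `strided_batched_gemm` here is a hypothetical row-major helper (cuBLASLt itself is column-major), and the shapes are arbitrary. It shows that only the leading dimension and batch stride change between the [..., G, D] and [G, ..., D] activation layouts, with S and B flattened into M = S * B rows.

```python
from itertools import product

def strided_batched_gemm(a, b, d, batch, m, n, k,
                         lda, stride_a, ldb, stride_b, ldd, stride_d):
    """Reference: d[g, i, j] = sum_p a[g, i, p] * b[g, j, p] on flat buffers."""
    for g, i, j in product(range(batch), range(m), range(n)):
        acc = 0
        for p in range(k):
            acc += a[g * stride_a + i * lda + p] * b[g * stride_b + j * ldb + p]
        d[g * stride_d + i * ldd + j] = acc
    return d

S, B, G, D, R = 2, 3, 4, 5, 2
M = S * B  # leading [S, B] dims flattened into rows

# Logical tensors: x[m][g][d] (activations) and w[g][r][d] (one weight per group).
x = [[[(7 * m + 3 * g + d) % 5 - 2 for d in range(D)] for g in range(G)] for m in range(M)]
w = [[[(g + 2 * r + 3 * d) % 7 - 3 for d in range(D)] for r in range(R)] for g in range(G)]

# Same values stored in the two memory layouts, without any later permutation.
x_interleaved = [x[m][g][d] for m in range(M) for g in range(G) for d in range(D)]  # [M, G, D]
x_batch_major = [x[m][g][d] for g in range(G) for m in range(M) for d in range(D)]  # [G, M, D]
w_flat = [w[g][r][d] for g in range(G) for r in range(R) for d in range(D)]         # [G, R, D]

# [..., G, D]: rows are G*D apart, consecutive groups are D apart.
out_interleaved = strided_batched_gemm(
    x_interleaved, w_flat, [0] * (M * G * R), G, M, R, D,
    lda=G * D, stride_a=D, ldb=D, stride_b=R * D, ldd=G * R, stride_d=R)

# [G, ..., D]: rows are D apart, consecutive groups are M*D apart.
out_batch_major = strided_batched_gemm(
    x_batch_major, w_flat, [0] * (G * M * R), G, M, R, D,
    lda=D, stride_a=M * D, ldb=D, stride_b=R * D, ldd=R, stride_d=M * R)
```

Both calls read the activations in place; the interleaved call writes an [M, G, R] output and the batch-major call writes [G, M, R].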

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add an experimental common API for cuBLASLt strided batched GEMM.
  • Support independent leading dimensions and batch strides for A, B, C, and D.
  • Support high-precision and MXFP8 input pairs with high-precision output.
  • Add kernels for packing compact MXFP8 row-wise and column-wise scales as:
    [batch0_scales][batch1_scales]....
  • Pack MXFP8 scales internally in the PyTorch C++ extension before each GEMM launch.
  • Add the fusible PyTorch BatchedGEMM operation with:
    • [G, ..., D] and [..., G, D] activation layouts.
    • High-precision and MXFP8 forward and backward.
    • Cached row-wise and column-wise forward quantization for backward reuse.
    • Compact MXFP8 weights created through quantized_model_init.
    • Explicit rejection of backward overrides and unsupported recipes.
  • Add corresponding tests.
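The [batch0_scales][batch1_scales]... packing listed above can be illustrated as a gather. The sketch below is an assumption-laden model, not the PR's CUDA kernel: it assumes row-wise compact scales hold one value per 32-element block along D, stored as [M, G, D/32] for interleaved data or [G, M, D/32] for batch-major data, and it omits the GEMM swizzle and any padding. `pack_rowwise_scales` is a hypothetical name.

```python
def pack_rowwise_scales(compact, rows, batches, blocks, interleaved):
    """Gather compact scales so each batch's scales are stored consecutively.

    compact     -- flat list of scales, [rows, batches, blocks] if interleaved,
                   else [batches, rows, blocks]
    returns     -- flat list laid out as [batches, rows, blocks]
    """
    packed = []
    for batch in range(batches):
        for row in range(rows):
            for blk in range(blocks):
                if interleaved:
                    src = (row * batches + batch) * blocks + blk
                else:
                    src = (batch * rows + row) * blocks + blk
                packed.append(compact[src])
    return packed

ROWS, BATCHES, D = 3, 2, 64
BLOCKS = D // 32  # one scale per 32 elements along D

compact = list(range(ROWS * BATCHES * BLOCKS))  # stand-in scale values
packed_from_interleaved = pack_rowwise_scales(compact, ROWS, BATCHES, BLOCKS, interleaved=True)
packed_from_batch_major = pack_rowwise_scales(compact, ROWS, BATCHES, BLOCKS, interleaved=False)
```

Under these assumptions the batch-major case is already in the packed order, so the gather only moves data for the interleaved layout.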

Current limitations

  • A and B must both be high precision or both be MXFP8.
  • Output quantization and fused epilogues are not supported.
  • MXFP8 operands must use contiguous [G, ..., D] or [..., G, D] layouts.
  • Tensor parallelism is not supported by BatchedGEMM.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

yaox12 added 3 commits July 1, 2026 00:45
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
greptile-apps Bot commented Jul 1, 2026

Contributor

Greptile Summary

This PR adds experimental cuBLASLt-backed strided batched GEMM support to Transformer Engine, covering BF16 and MXFP8 precision, along with a new fusible BatchedGEMM PyTorch operation that handles both [G, ..., D] and [..., G, D] activation layouts without requiring data permutation.

  • New C API (nvte_cublas_gemm_strided_batched): accepts explicit leading dimensions and batch strides for A, B, C, and D; supports high-precision and MXFP8 pairs with high-precision output; MXFP8 scales must be pre-packed and swizzled before the call.
  • MXFP8 scale packing (nvte_pack_mxfp8_scales_for_batched_gemm / nvte_pack_mxfp8_columnwise_scales_for_batched_gemm): GPU kernels that gather compact per-batch scales into the consecutive [batch0_scales][batch1_scales]… layout cuBLASLt requires, handling both batch-major and interleaved layouts.
  • BatchedGEMM fusible op: full forward/backward in BF16 and MXFP8, with cached quantized operands for backward reuse; rejects unsupported recipes and backward overrides.

Confidence Score: 3/5

The new GEMM path is well-tested and mathematically correct, but the C++ GEMM function leaks cuBLASLt handles whenever an MXFP8 call fails due to a cuBLAS version mismatch — both at compile time and at runtime.

In cublas_gemm_strided_batched, Adesc/Bdesc/Cdesc/Ddesc/operationDesc are allocated before the MXFP8 version guard. The compile-time #else NVTE_ERROR and the runtime cublas_version() NVTE_CHECK both throw after those handles exist but before the cleanup block, making it a definite leak on every failed MXFP8 strided batched GEMM call against an old library.

transformer_engine/common/gemm/cublaslt_gemm.cu — the strided batched GEMM function needs RAII-style cleanup for the cuBLASLt descriptors it allocates before the version check.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Adds cublas_gemm_strided_batched and nvte_cublas_gemm_strided_batched. Resource leak: Adesc/Bdesc/Cdesc/Ddesc/operationDesc are allocated before the MXFP8 cuBLAS version guard; exceptions thrown there are never caught before cleanup. |
| transformer_engine/pytorch/ops/basic/batched_gemm.py | New BatchedGEMM fusible op with full forward/backward for BF16 and MXFP8. Minor ordering issue in _validate_input. Scale packing, matrix stride logic, and quantizer configuration are correctly implemented. |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds strided_batched_gemm C++ extension with inline MXFP8 scale packing. C is always aliased to D (accumulate in-place), which is intentional but undocumented at the Python layer. |
| transformer_engine/common/swizzle/swizzle.cu | Adds pack/columnwise-pack kernels for batched GEMM with correct tile-based swizzle for both batch-major and interleaved layouts. |
| transformer_engine/pytorch/cpp_extensions/gemm.py | Python wrapper with correct MXFP8 compact-scale assertion and workspace size handling. C-matrix aliasing to D is undocumented. |
| tests/pytorch/test_strided_batched_gemm.py | New test file with good boundary coverage including accumulate mode, scale rejection, out-of-bounds strides, and invalid scale shapes. |
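The out-of-bounds stride rejection mentioned for the tests comes down to an extent calculation. The PR's own check is not shown on this page, so the following is only a sketch of the general idea under stated assumptions: column-major matrices, non-negative strides, and hypothetical helper names `required_elements` and `check_strided_batch`.

```python
def required_elements(rows, cols, ld, batch_stride, batch_count):
    """Smallest buffer length (in elements) that covers every addressed element.

    Element (i, j) of batch g sits at g * batch_stride + j * ld + i, so the
    largest offset is reached at the last row, last column, last batch.
    """
    if rows == 0 or cols == 0 or batch_count == 0:
        return 0
    return (batch_count - 1) * batch_stride + (cols - 1) * ld + rows

def check_strided_batch(numel, rows, cols, ld, batch_stride, batch_count):
    """Raise ValueError if the strided batch would read past the buffer."""
    if ld < rows:
        raise ValueError(f"leading dimension {ld} is smaller than rows {rows}")
    needed = required_elements(rows, cols, ld, batch_stride, batch_count)
    if needed > numel:
        raise ValueError(f"buffer has {numel} elements but {needed} are addressed")
    return needed

# Interleaved [M, G, D] activations viewed as G column-major D x M matrices:
# columns are G*D apart and consecutive batches are D apart.
M, G, D = 6, 4, 5
needed = check_strided_batch(M * G * D, rows=D, cols=M, ld=G * D,
                             batch_stride=D, batch_count=G)
```

Note that the batches overlap in address range for the interleaved layout, which is why the check uses the maximum offset rather than batch_count times the matrix size.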

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant PY as Python (BatchedGEMM / strided_batched_gemm)
    participant CPP as C++ Extension (gemm.cpp)
    participant SWIZ as Swizzle Kernel (swizzle.cu)
    participant GEMM as GEMM Kernel (cublaslt_gemm.cu)

    PY->>CPP: strided_batched_gemm(A, B, out, layout, ...)
    alt MXFP8 inputs
        CPP->>SWIZ: nvte_pack_mxfp8_scales_for_batched_gemm(compact packed, stream)
        SWIZ-->>CPP: packed [batch0_scales][batch1_scales]...
        CPP->>SWIZ: nvte_pack_mxfp8_columnwise_scales_for_batched_gemm(compact packed, stream)
        SWIZ-->>CPP: packed column-wise scales
    end
    CPP->>GEMM: "nvte_cublas_gemm_strided_batched(A lda/stridea, B ldb/strideb, C=D, D ldd/strided, ...)"
    GEMM->>GEMM: CheckStridedBatchBuffer (bounds check)
    GEMM->>GEMM: CheckMXFP8BatchScaleBuffer (scale size check)
    GEMM->>GEMM: cublasLtMatmulAlgoGetHeuristic + cublasLtMatmul
    GEMM-->>CPP: D (output written in-place)
    CPP-->>PY: out tensor

Comments Outside Diff (1)

  1. transformer_engine/pytorch/cpp_extensions/gemm.py, lines 1412-1491

    P2 C-matrix always aliased to D — undocumented design restriction

    The public C-level API nvte_cublas_gemm_strided_batched accepts independent C, ldc, and stridec parameters (useful for a fused D = A*B + beta*C with distinct accumulation buffers). This Python wrapper always passes out for both C and D, permanently losing that flexibility. A comment in the docstring noting that accumulation is always done in-place (beta is applied to out itself) would prevent surprises for future callers who discover the separate C parameter in the C++ header.
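The aliasing described here can be made concrete with a toy GEMM. This is a plain-Python sketch, not the wrapper's code; `gemm_accumulate` is a hypothetical helper that hard-codes C = D the way the comment says the Python wrapper does.

```python
def gemm_accumulate(a, b, out, alpha=1.0, beta=0.0):
    """Compute out <- alpha * (a @ b) + beta * out, reading C from out itself."""
    rows, inner, cols = len(a), len(b), len(b[0])
    for i in range(rows):
        for j in range(cols):
            acc = sum(a[i][p] * b[p][j] for p in range(inner))
            # The previous contents of out play the role of C.
            out[i][j] = alpha * acc + beta * out[i][j]
    return out

a = [[1, 2], [3, 4]]
identity = [[1, 0], [0, 1]]

accumulated = gemm_accumulate(a, identity, [[10, 10], [10, 10]], beta=1.0)
overwritten = gemm_accumulate(a, identity, [[10, 10], [10, 10]], beta=0.0)
```

With beta = 1 the existing values in out are added to the product; with beta = 0 they are discarded. A separate C buffer, which the C-level API allows, has no equivalent in this form.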


Reviews (1): Last reviewed commit: "add check"

Comment on lines +1046 to +1085
if (use_mxfp8) {
#if CUBLAS_VERSION >= 120800
  NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
             "MXFP8 strided batched GEMM requires cuBLAS 12.8+, but run-time cuBLAS version is ",
             transformer_engine::cuda::cublas_version());
  NVTE_CHECK(
      inputA->with_gemm_swizzled_scales,
      "MXFP8 A scales for strided batched GEMM must be packed by batch and GEMM-swizzled.");
  NVTE_CHECK(
      inputB->with_gemm_swizzled_scales,
      "MXFP8 B scales for strided batched GEMM must be packed by batch and GEMM-swizzled.");
  fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
  fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                   CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                   &A_scale_inverse, sizeof(A_scale_inverse)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                   CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                   &B_scale_inverse, sizeof(B_scale_inverse)));
  const cublasLtMatmulMatrixScale_t scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
  const cublasLtMatmulMatrixScale_t scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;

  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));

  // Workaround for the same heuristic-cache issue handled in the non-batched MXFP8 GEMM.
  if (transformer_engine::cuda::cublas_version() <= 120803) {
    const int64_t dummy_a_vec_stride = 1;
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
        sizeof(dummy_a_vec_stride)));
  }
#else
  NVTE_ERROR(
      "MXFP8 strided batched GEMM requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
      CUBLAS_VERSION);
#endif  // CUBLAS_VERSION >= 120800
}

P1 cuBLASLt descriptor leak on every failed MXFP8 call

Adesc, Bdesc, Cdesc, Ddesc (lines 1010–1023) and operationDesc (line 1029) are all successfully created before this block. If use_mxfp8 is true and the binary was compiled with CUBLAS_VERSION < 120800, the NVTE_ERROR on line 1081 throws an exception. Equally, even when compiled against a new-enough cuBLAS, the runtime NVTE_CHECK at line 1048 (comparing cublas_version() >= 120800) will throw if the runtime library is older. In both cases the five already-created descriptors are never destroyed, leaking one handle per call attempt. The cleanup block at the end of the function is only reached on the happy path.
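The control flow behind this leak can be modeled without cuBLASLt. The sketch below is a Python analogy only: `FakeDescriptor`, `gemm_leaky`, `gemm_scoped`, and `leaked_after` are hypothetical names, and the real fix in cublaslt_gemm.cu would be written in C++ (for example with scope-bound owners for the descriptors), which is not shown here. It demonstrates that cleanup placed at the end of a function is skipped when a version check throws, whereas cleanup tied to scope exit still runs.

```python
from contextlib import ExitStack

MIN_CUBLAS_VERSION = 120800
DESCRIPTOR_NAMES = ("Adesc", "Bdesc", "Cdesc", "Ddesc", "operationDesc")
live = set()  # names of descriptors created but not yet destroyed

class FakeDescriptor:
    """Stand-in for a cuBLASLt descriptor; tracks create/destroy only."""
    def __init__(self, name):
        self.name = name
        live.add(name)

    def destroy(self):
        live.discard(self.name)

def gemm_leaky(runtime_version):
    """Mirrors the reported flow: create, check version, clean up at the end."""
    descs = [FakeDescriptor(name) for name in DESCRIPTOR_NAMES]
    if runtime_version < MIN_CUBLAS_VERSION:
        raise RuntimeError("MXFP8 strided batched GEMM requires cuBLAS 12.8+")
    for desc in descs:  # only reached on the happy path
        desc.destroy()

def gemm_scoped(runtime_version):
    """Registers cleanup at creation time so it runs on every exit path."""
    with ExitStack() as stack:
        for name in DESCRIPTOR_NAMES:
            stack.callback(FakeDescriptor(name).destroy)
        if runtime_version < MIN_CUBLAS_VERSION:
            raise RuntimeError("MXFP8 strided batched GEMM requires cuBLAS 12.8+")

def leaked_after(fn, runtime_version):
    """Run fn, swallow the version error, and report what is still alive."""
    live.clear()
    try:
        fn(runtime_version)
    except RuntimeError:
        pass
    return sorted(live)
```

Calling `leaked_after(gemm_leaky, 120600)` reports all five descriptors still alive, while the scoped variant reports none. Moving the version check ahead of descriptor creation would be another way to avoid the leak for this particular failure.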

Comment on lines +236 to +242
def _validate_input(self, input_: torch.Tensor) -> int:
    if is_quantized_tensor(input_):
        raise ValueError("BatchedGEMM expects a high-precision input tensor")
    if not isinstance(input_, torch.Tensor):
        raise TypeError(
            f"BatchedGEMM expects a torch.Tensor input (got {type(input_).__name__})"
        )

P2 is_quantized_tensor(input_) is evaluated before isinstance(input_, torch.Tensor). For a completely non-tensor input (e.g., a plain Python list), is_quantized_tensor may raise an unexpected error before the more informative TypeError is reached. Checking isinstance first ensures the type guard fires before any attribute access.

Suggested change
- def _validate_input(self, input_: torch.Tensor) -> int:
-     if is_quantized_tensor(input_):
-         raise ValueError("BatchedGEMM expects a high-precision input tensor")
-     if not isinstance(input_, torch.Tensor):
-         raise TypeError(
-             f"BatchedGEMM expects a torch.Tensor input (got {type(input_).__name__})"
-         )
+ def _validate_input(self, input_: torch.Tensor) -> int:
+     if not isinstance(input_, torch.Tensor):
+         raise TypeError(
+             f"BatchedGEMM expects a torch.Tensor input (got {type(input_).__name__})"
+         )
+     if is_quantized_tensor(input_):
+         raise ValueError("BatchedGEMM expects a high-precision input tensor")
