Skip to content

Add MXFP8 support with cuBLASMp#3145

Open
almogsegal wants to merge 2 commits into
NVIDIA:mainfrom
almogsegal:add-mxfp8-and-nvfp4-with-cublasmp
Open

Add MXFP8 support with cuBLASMp#3145
almogsegal wants to merge 2 commits into
NVIDIA:mainfrom
almogsegal:add-mxfp8-and-nvfp4-with-cublasmp

Conversation

@almogsegal

@almogsegal almogsegal commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add MXFP8 support in comm_gemm.cpp (cuBLASMp path).
  • Add cuBLASMp + MXFP8 tests in test_comm_gemm_overlap.py.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 25, 2026
@almogsegal

almogsegal commented Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

MXFP8 comparison of cuBLASMp vs UB on DGX-B200:

GPUs Op UB time cuBLASMp time Speedup Faster
2 AG 0.0969 ms 0.0802 ms 1.21x cuBLASMp
2 RS 0.0867 ms 0.0730 ms 1.19x cuBLASMp
4 AG 0.1215 ms 0.1255 ms 1.03x UB
4 RS 0.1357 ms 0.1049 ms 1.29x cuBLASMp
8 AG 0.2380 ms 0.2146 ms 1.11x cuBLASMp
8 RS 0.2191 ms 0.2141 ms 1.02x cuBLASMp

@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch from c6517dd to 0c4c53a Compare June 25, 2026 10:19
@greptile-apps

greptile-apps Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR enables MXFP8 (block scaling) support for the cuBLASMp comm+GEMM overlap path by extending cublasmp_gemm to canonicalize MXFP8 inputs and configure the appropriate VEC32_UE8M0 scale mode. The corresponding test skip conditions are removed so the feature is now exercised in CI.

  • comm_gemm.cpp: Adds an MXFP8 branch inside canonicalize_input that selects rowwise vs. columnwise data/scales based on the transpose direction (keeping the transpose flag unchanged, since MXFP8 columnwise data preserves logical shape unlike tensor-FP8). Adds a compile-time guard (CUBLASMP_VERSION < 801) that injects a runtime error for unsupported library versions, and uses CUBLASMP_MATMUL_MATRIX_SCALE_VEC32_UE8M0 for scale mode setup.
  • test_comm_gemm_overlap.py: Removes the two pytest.skip calls that blocked cuBLASMp+MXFP8 combinations and introduces an explicit COMM_GEMM_QUANTIZATION_PARAMS list with human-readable test IDs to replace the two separate parametrize decorators.

Confidence Score: 5/5

Safe to merge; the MXFP8 canonicalization path follows the established FP8 pattern and is protected by both a compile-time version guard and runtime assertions on pointer validity.

The new MXFP8 branch in canonicalize_input correctly keeps the transpose flag unchanged (MXFP8 columnwise data is not a transposed view), mirrors the cublaslt_gemm.cu reference path, and guards the entire path with CUBLASMP_VERSION checks. Validation is thorough: mixed scaling modes are rejected, swizzled-scale format is asserted, and per-direction null checks are in place. Test infrastructure change is mechanical and correct.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/common/comm_gemm/comm_gemm.cpp Extends cublasmp_gemm with MXFP8 block-scaling support: adds validation, a new canonicalize_input MXFP8 branch that picks rowwise/columnwise data without flipping the transpose flag, and a VEC32_UE8M0 scale mode with a compile-time CUBLASMP_VERSION guard. Logic is consistent with the cublaslt_gemm.cu reference path.
tests/pytorch/distributed/test_comm_gemm_overlap.py Removes two pytest.skip guards that blocked cuBLASMp+MXFP8 test cases, and consolidates the separate use_cublasmp / quantization parametrize decorators into a single explicit COMM_GEMM_QUANTIZATION_PARAMS list with readable IDs.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[cublasmp_gemm called] --> B{scaling_mode A & B}
    B -->|both tensor_scaling| C[FP8 / BF16 path]
    B -->|both mxfp8_scaling| D[MXFP8 path]
    B -->|mixed / unknown| E[NVTE_CHECK fails]

    D --> F{CUBLASMP_VERSION < 801?}
    F -->|yes| G[NVTE_ERROR - runtime throw]
    F -->|no| H[Check with_gemm_swizzled_scales]

    H --> I[canonicalize_input A]
    I --> J{transa?}
    J -->|yes| K[use row-wise data, keep trans flag]
    J -->|no| L[use columnwise data, keep trans flag]

    H --> M[canonicalize_input B]
    M --> N{transb?}
    N -->|yes| O[use columnwise data, keep trans flag]
    N -->|no| P[use row-wise data, keep trans flag]

    K --> Q[Set A scale mode VEC32_UE8M0]
    L --> Q
    O --> R[Set B scale mode VEC32_UE8M0]
    P --> R

    Q --> U[cublasMpMatmul]
    R --> U
    C --> T[Set scale mode SCALAR_FP32]
    T --> U
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[cublasmp_gemm called] --> B{scaling_mode A & B}
    B -->|both tensor_scaling| C[FP8 / BF16 path]
    B -->|both mxfp8_scaling| D[MXFP8 path]
    B -->|mixed / unknown| E[NVTE_CHECK fails]

    D --> F{CUBLASMP_VERSION < 801?}
    F -->|yes| G[NVTE_ERROR - runtime throw]
    F -->|no| H[Check with_gemm_swizzled_scales]

    H --> I[canonicalize_input A]
    I --> J{transa?}
    J -->|yes| K[use row-wise data, keep trans flag]
    J -->|no| L[use columnwise data, keep trans flag]

    H --> M[canonicalize_input B]
    M --> N{transb?}
    N -->|yes| O[use columnwise data, keep trans flag]
    N -->|no| P[use row-wise data, keep trans flag]

    K --> Q[Set A scale mode VEC32_UE8M0]
    L --> Q
    O --> R[Set B scale mode VEC32_UE8M0]
    P --> R

    Q --> U[cublasMpMatmul]
    R --> U
    C --> T[Set scale mode SCALAR_FP32]
    T --> U
Loading

Reviews (2): Last reviewed commit: "Enable cuBLASMp MXFP8 overlap tests" | Re-trigger Greptile

Comment thread tests/pytorch/distributed/run_gemm_with_overlap.py Outdated
Comment thread transformer_engine/common/comm_gemm/comm_gemm.cpp Outdated
Comment thread transformer_engine/common/comm_gemm/comm_gemm.cpp Outdated
@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch 2 times, most recently from 85280d5 to d0559ca Compare June 25, 2026 10:54
@almogsegal almogsegal changed the title Add MXFP8 and NVFP4 support with cuBLASMp Add MXFP8 support with cuBLASMp Jun 25, 2026
@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch 2 times, most recently from c0e28c9 to b7ce126 Compare June 25, 2026 11:04

@denera denera left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall, but our build flow supports cuBLASMp 0.8.0 while the new MXFP8 path requires 0.8.1. This is not accounted for in the tests, so they're going to blanket fail on MXFP8 when cuBLASMp version is 0.8.0.

I'd recommend just bumping the minimum cuBLASMp version to 0.8.1 in the build flow (see here).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants