Add MXFP8 support with cuBLASMp#3145
Conversation
|
MXFP8 comparison of cuBLASMp vs UB on DGX-B200:
|
c6517dd to
0c4c53a
Compare
Greptile SummaryThis PR enables MXFP8 (block scaling) support for the cuBLASMp comm+GEMM overlap path by extending
Confidence Score: 5/5Safe to merge; the MXFP8 canonicalization path follows the established FP8 pattern and is protected by both a compile-time version guard and runtime assertions on pointer validity. The new MXFP8 branch in canonicalize_input correctly keeps the transpose flag unchanged (MXFP8 columnwise data is not a transposed view), mirrors the cublaslt_gemm.cu reference path, and guards the entire path with CUBLASMP_VERSION checks. Validation is thorough: mixed scaling modes are rejected, swizzled-scale format is asserted, and per-direction null checks are in place. Test infrastructure change is mechanical and correct. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[cublasmp_gemm called] --> B{scaling_mode A & B}
B -->|both tensor_scaling| C[FP8 / BF16 path]
B -->|both mxfp8_scaling| D[MXFP8 path]
B -->|mixed / unknown| E[NVTE_CHECK fails]
D --> F{CUBLASMP_VERSION < 801?}
F -->|yes| G[NVTE_ERROR - runtime throw]
F -->|no| H[Check with_gemm_swizzled_scales]
H --> I[canonicalize_input A]
I --> J{transa?}
J -->|yes| K[use row-wise data, keep trans flag]
J -->|no| L[use columnwise data, keep trans flag]
H --> M[canonicalize_input B]
M --> N{transb?}
N -->|yes| O[use columnwise data, keep trans flag]
N -->|no| P[use row-wise data, keep trans flag]
K --> Q[Set A scale mode VEC32_UE8M0]
L --> Q
O --> R[Set B scale mode VEC32_UE8M0]
P --> R
Q --> U[cublasMpMatmul]
R --> U
C --> T[Set scale mode SCALAR_FP32]
T --> U
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[cublasmp_gemm called] --> B{scaling_mode A & B}
B -->|both tensor_scaling| C[FP8 / BF16 path]
B -->|both mxfp8_scaling| D[MXFP8 path]
B -->|mixed / unknown| E[NVTE_CHECK fails]
D --> F{CUBLASMP_VERSION < 801?}
F -->|yes| G[NVTE_ERROR - runtime throw]
F -->|no| H[Check with_gemm_swizzled_scales]
H --> I[canonicalize_input A]
I --> J{transa?}
J -->|yes| K[use row-wise data, keep trans flag]
J -->|no| L[use columnwise data, keep trans flag]
H --> M[canonicalize_input B]
M --> N{transb?}
N -->|yes| O[use columnwise data, keep trans flag]
N -->|no| P[use row-wise data, keep trans flag]
K --> Q[Set A scale mode VEC32_UE8M0]
L --> Q
O --> R[Set B scale mode VEC32_UE8M0]
P --> R
Q --> U[cublasMpMatmul]
R --> U
C --> T[Set scale mode SCALAR_FP32]
T --> U
Reviews (2): Last reviewed commit: "Enable cuBLASMp MXFP8 overlap tests" | Re-trigger Greptile |
85280d5 to
d0559ca
Compare
c0e28c9 to
b7ce126
Compare
denera
left a comment
There was a problem hiding this comment.
LGTM overall, but our build flow supports cuBLASMp 0.8.0 while the new MXFP8 path requires 0.8.1. This is not accounted for in the tests, so they're going to blanket fail on MXFP8 when cuBLASMp version is 0.8.0.
I'd recommend just bumping the minimum cuBLASMp version to 0.8.1 in the build flow (see here).
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: