[PyTorch] Enable fused FP8 block-scaling path in GroupedLinear module and fusible ops #3171
denera wants to merge 7 commits into
…ng quantize
The default Float8BlockScaling recipe constrains scales to powers of 2, so the fused grouped path must honor the flag to stay numerically consistent with the unfused path. Thread a runtime pow_2_scales argument through the grouped quantize kernels (the shared scale helper already implements the rounding) and drop the force_pow_2_scales rejections.
Also add a quantization-config parameter to nvte_group_quantize_dbias, which previously had no way to receive force_pow_2_scales or amax_epsilon on the bgrad path.
Signed-off-by: Alp Dener <adener@nvidia.com>
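To make the power-of-2 constraint concrete, here is a minimal host-side sketch of a per-block scale computation. It is not the kernel code from this PR: the function name `block_scale`, the choice of the E4M3 format, the round-down direction, and the neutral scale for all-zero blocks are assumptions made for illustration. Only the `pow_2_scales` and `amax_epsilon` names come from the commit message.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def block_scale(amax: float, pow_2_scales: bool, amax_epsilon: float = 0.0) -> float:
    """Return the multiplier that maps a block with the given amax into FP8 range."""
    # Clamp amax from below so tiny blocks do not produce enormous scales.
    amax = max(amax, amax_epsilon)
    if amax == 0.0:
        return 1.0  # assumption: neutral scale for an all-zero block
    scale = FP8_E4M3_MAX / amax
    if pow_2_scales:
        # frexp gives scale = m * 2**e with 0.5 <= m < 1, so 2**(e - 1)
        # is the largest power of two that does not exceed scale.
        _, exponent = math.frexp(scale)
        scale = math.ldexp(1.0, exponent - 1)
    return scale
```

With the flag off, a block with amax 1.0 gets scale 448.0; with the flag on, the same block gets 256.0. That gap is why the fused and unfused paths diverge numerically if only one of them honors the flag.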
…r module
Admit Float8BlockQuantizer in the fused GroupedTensor path on Hopper. The existing usage flags already match the Hopper TN-only mapping and the grouped GEMM selects transposed columnwise storage for NN/NT layouts, so only the path predicate changes.
The fused path is an explicit opt-in via NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM, so raise on Blackwell (SM100/SM110) instead of silently falling back; the fused path has no MXFP8-broadcast emulation.
Extend the fused dbias path (tex.bgrad_group_quantize) to FP8 block scaling when dgrad is required (dbias is computed in the rowwise pass).
Add fp8_block_scaling to the fused-path tests with a Hopper-only gate, assert the fused path engages via a group_quantize spy, and add a Blackwell error-path test.
Signed-off-by: Alp Dener <adener@nvidia.com>
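The opt-in and raise-on-Blackwell behavior described in this commit can be sketched as a small predicate. This is an illustrative sketch, not the module's actual code: the function name, the integer encoding of the SM architecture, the `"1"` value for the environment variable, and the fallback on architectures other than Hopper and Blackwell are all assumptions. Only the environment variable name and the Hopper/Blackwell behavior come from the commit message.

```python
FUSED_ENV_VAR = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"


def use_fused_block_scaling_path(sm_arch: int, env: dict) -> bool:
    """Decide whether FP8 block scaling takes the fused GroupedTensor path.

    sm_arch is a hypothetical integer encoding (90 for Hopper, 100 or 110
    for Blackwell); env is a mapping such as os.environ.
    """
    if env.get(FUSED_ENV_VAR, "0") != "1":
        return False  # the user did not opt in, so use the unfused path
    if sm_arch in (100, 110):
        # The user explicitly asked for the fused path, so fail loudly
        # rather than silently falling back.
        raise RuntimeError(
            "Fused grouped FP8 block scaling is not supported on Blackwell"
        )
    # Assumption: any other non-Hopper architecture falls back quietly.
    return sm_arch == 90
```

Raising instead of falling back keeps an explicit opt-in honest: a user who sets the variable on unsupported hardware learns about it immediately.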
Replace the blanket FP8 block-scaling rejection in BasicOperation.reset_recipe_state with a per-op supports_float8_block_scaling flag and opt in the GroupedLinear op. Mirror the module-path predicate and fused-bgrad changes; since the graph-safe flow is default-on here (no env-var opt-in), other architectures fall back to the split-quantize flow instead of raising.
Force use_split_accumulator=True for FP8 block-scaling operands in general_grouped_gemm_for_grouped_tensor, matching non-grouped general_gemm: cuBLAS has no fast-accum FP8 block-scaling algorithm, so the ops-layer forward failed algo selection without it.
Add fp8_block_scaling coverage to the ops GroupedLinear tests. The CUDA-graph-safe test skips it for now: the replayed wgrad for the last expert diverges between replays depending on process allocation history; under investigation. Graph capture remains covered by the module-path test.
Signed-off-by: Alp Dener <adener@nvidia.com>
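The per-op opt-in pattern described in this commit looks roughly like the following. This is a simplified sketch with hypothetical class bodies: the real reset_recipe_state does not take a plain string, and `OtherOp` is an invented stand-in for any op that has not opted in. Only the `supports_float8_block_scaling` flag and the two class/method names come from the commit message.

```python
class BasicOperation:
    # Ops must opt in explicitly; the default preserves the old rejection.
    supports_float8_block_scaling = False

    def reset_recipe_state(self, recipe_name: str) -> None:
        """Simplified stand-in: accept or reject a recipe by name."""
        if recipe_name == "fp8_block_scaling" and not self.supports_float8_block_scaling:
            raise NotImplementedError(
                f"{type(self).__name__} does not support FP8 block scaling"
            )
        self.recipe_name = recipe_name


class GroupedLinear(BasicOperation):
    supports_float8_block_scaling = True  # opted in


class OtherOp(BasicOperation):
    """Hypothetical op that has not opted in."""
```

A class-level flag keeps the check in one place while letting each op declare support independently, so enabling a further op later is a one-line change.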
general_grouped_gemm_for_grouped_tensor allocated its setup workspace (the cuBLAS per-group pointer/dimension arrays) and its cuBLAS workspace with per-call torch.empty. Under make_graphed_callables the forward and backward graphs share one capture memory pool, and a per-call allocation's block returns to that pool as soon as the Python reference dies, so blocks alias across the two graphs and captured kernels from one graph overwrite the GEMM metadata the other graph reads at replay.
Observed as allocation-history-dependent failures in the ops-layer GroupedLinear cuda-graph test: capture-time cublasLtMatmulAlgoGetHeuristic NOT_SUPPORTED errors and corrupted wgrad outputs. This is also the likely mechanism behind the FP8 block-scaling wgrad corruption under CUDA graphs previously observed on Hopper and attributed to cuBLAS.
Cache the setup workspace per (device, group size) and reuse the cached per-device cuBLAS workspace from the non-grouped path; consecutive GEMMs reusing one workspace are ordered by the stream.
Signed-off-by: Alp Dener <adener@nvidia.com>
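The caching fix amounts to replacing a per-call allocation with a lookup keyed on (device, group size). A minimal sketch of that pattern follows, using a `bytearray` as a stand-in for a device buffer so it runs without a GPU. The function name, the string device key, and the `bytes_per_group` sizing are hypothetical and not taken from the PR.

```python
_setup_workspaces: dict = {}


def get_setup_workspace(device: str, num_groups: int, bytes_per_group: int = 64) -> bytearray:
    """Return a persistent setup workspace for this device and group count.

    The buffer is allocated once and kept alive by the module-level cache,
    so its memory is never handed back to an allocator pool between calls.
    """
    key = (device, num_groups)
    workspace = _setup_workspaces.get(key)
    if workspace is None:
        # Stand-in for a device allocation such as torch.empty(...).
        workspace = bytearray(num_groups * bytes_per_group)
        _setup_workspaces[key] = workspace
    return workspace
```

Because the cache holds a reference for the lifetime of the process, the block cannot be recycled and aliased by another captured graph, which is the failure mode the commit message describes.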
…ole cuBLAS workspaces
The grouped-tensor GEMM path shared one persistent cuBLAS workspace across all grouped matmuls. cuBLAS's grouped GEMM keeps a grid-synchronization flag in the first bytes of that workspace and zeros it (via a captured memset) before each matmul. When the dgrad and wgrad grouped matmuls of a GroupedLinear backward share one workspace inside a replayed CUDA graph, that flag is aliased between the two matmuls; on the second graph replay the second matmul's cooperative kernel deadlocks with cuBLAS 13.6 (and corrupts the last expert's wgrad on cuBLAS < 13.6). The two matmuls are strictly stream-ordered (single stream, all-DEFAULT graph edges, no programmatic dependent launch), so this is shared-workspace reuse, not concurrent co-scheduling.
Give dgrad/forward (slot 0) and wgrad (slot 1) distinct persistent cuBLAS workspaces, dedicated to the grouped path. Each slot remains a single persistent allocation, so CUDA-graph capture safety is preserved.
Also drop the cuBLAS-version gate that skipped the FP8 block-scaling GroupedLinear CUDA-graph test, so it now exercises the fix on all supported cuBLAS versions.
Signed-off-by: Alp Dener <adener@nvidia.com>
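The slot scheme can be sketched as a cache keyed on (device, slot), again with a `bytearray` standing in for a device buffer. The enum, the function name, and the default buffer size are hypothetical; only the slot numbering (0 for forward/dgrad, 1 for wgrad) comes from the commit message.

```python
from enum import IntEnum


class WorkspaceSlot(IntEnum):
    FORWARD_DGRAD = 0  # forward and dgrad grouped matmuls
    WGRAD = 1          # wgrad grouped matmul


_grouped_cublas_workspaces: dict = {}


def get_grouped_cublas_workspace(device: str, slot: WorkspaceSlot, num_bytes: int = 1024) -> bytearray:
    """Return the persistent workspace for one (device, slot) pair.

    Each slot is allocated once and reused, so dgrad and wgrad never share
    the leading bytes where the synchronization flag is said to live.
    num_bytes is a placeholder size for this sketch.
    """
    key = (device, int(slot))
    workspace = _grouped_cublas_workspaces.get(key)
    if workspace is None:
        workspace = bytearray(num_bytes)  # stand-in for a device allocation
        _grouped_cublas_workspaces[key] = workspace
    return workspace
```

Keeping exactly one persistent buffer per slot preserves the property the earlier caching commit relied on: no allocation happens during capture, and no buffer is ever returned to the pool.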
bbfa1da to b35924c
for more information, see https://pre-commit.ci
/te-ci core pytorch
Greptile Summary
This PR integrates FP8 block-scaling quantization into the fused
Confidence Score: 5/5
Safe to merge. The cuBLAS workspace fix correctly isolates dgrad and wgrad into separate persistent buffers, the pow-2-scales threading is mechanically complete through all call sites, and the Blackwell error path is tested with a dedicated negative test.
All changed kernels receive the new pow_2_scales argument at every launch site; the prior runtime rejection of force_pow_2_scales=True is removed in both C++ and CUDA layers consistently. The dual-workspace strategy is sound: fprop and dgrad share slot 0 safely because they run in separate CUDA-graph captures, while dgrad and wgrad, which are in the same backward graph and were the source of the deadlock, now get distinct 32 MiB buffers. The fuse_bgrad guard that gates FP8 block-scaling dbias fusion on input_requires_grad correctly maps to the kernel dependency on a rowwise quantization pass. Test coverage adds parametric cases across both Hopper-only SM-range tests, a new Blackwell error path test, and call-count instrumentation to confirm the fused path is actually exercised.
No files require special attention. The most load-bearing changes are in gemm.py (workspace caching) and group_quantize_fp8_blockwise.cuh (pow_2_scales threading), both of which are mechanically complete and well-documented.
Important Files Changed
Reviews (2): Last reviewed commit: "[PyTorch] Address review: document split..."
…ale dbias comment
- general_grouped_gemm_for_grouped_tensor: expand the comment to state that the fused grouped FP8 block-scaling GEMM forces use_split_accumulator=True and intentionally overrides the caller-supplied value, consistent with the Float8BlockScaling recipe (which fixes it True for fprop/dgrad/wgrad).
- Float8BlockScaling recipe docstring: document that FP8 block scaling always uses split accumulation and that the fused grouped GEMM path ignores any caller- or recipe-supplied use_split_accumulator value.
- GroupedLinear ops backward: correct the stale "BF16/FP16 path" comment; that branch also handles quantized paths where bgrad fusion did not apply (e.g. FP8 block scaling without a dgrad pass).
Signed-off-by: Alp Dener <adener@nvidia.com>
Description
This PR integrates the new grouped FP8 block-scaling quantization kernels into the fused GroupedTensor path in GroupedLinear.
Notes
Type of change
Checklist: