[PyTorch] Enable fused FP8 block-scaling path in GroupedLinear module and fusible ops#3171

Open
denera wants to merge 7 commits into NVIDIA:main from denera:pytorch/grouped-linear-fused-fp8bs

denera commented Jul 2, 2026

Collaborator

Description

This PR integrates the new grouped FP8 block-scaling quantization kernels into the fused GroupedTensor path in GroupedLinear.

Notes

  • DGRAD and WGRAD grouped GEMMs executing serially on the same stream appear to violate workspace serialization under specific graph-replay conditions, leading to a deadlock in WGRAD. This PR resolves the deadlock by allocating and caching two 32 MiB cuBLAS workspaces, so that DGRAD and WGRAD each use an isolated workspace.
  • This PR enables power-of-2 constrained scales for grouped FP8 block-scaling at the kernel level, for numerical parity with the split-quantize path on SM90. The FP8BS->MXFP8 scale-broadcast mechanism remains disabled on SM100+.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

denera added 5 commits July 2, 2026 20:35
…ng quantize

The default Float8BlockScaling recipe constrains scales to powers of 2,
so the fused grouped path must honor the flag to stay numerically
consistent with the unfused path. Thread a runtime pow_2_scales argument
through the grouped quantize kernels (the shared scale helper already
implements the rounding) and drop the force_pow_2_scales rejections.

Also add a quantization-config parameter to nvte_group_quantize_dbias,
which previously had no way to receive force_pow_2_scales or
amax_epsilon on the bgrad path.

Signed-off-by: Alp Dener <adener@nvidia.com>
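
The power-of-2 constraint this commit threads through the kernels can be illustrated with a minimal Python sketch. This is not the kernel code: `compute_scale` is a hypothetical helper, the FP8 E4M3 maximum of 448 is assumed as the target range, and rounding the scale down (rather than to nearest) is an assumption about the shared scale helper's behavior.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumed target format)


def compute_scale(amax: float, pow_2_scales: bool, eps: float = 0.0) -> float:
    """Return the multiplier that maps a block with max magnitude `amax` into FP8 range.

    Hypothetical illustration only; `eps` plays the role of amax_epsilon.
    """
    amax = max(amax, eps)
    if amax == 0.0 or not math.isfinite(amax):
        return 1.0  # nothing sensible to scale; fall back to identity
    scale = FP8_E4M3_MAX / amax
    if pow_2_scales:
        # frexp gives scale = m * 2**e with 0.5 <= m < 1, so 2**(e - 1) is the
        # largest power of 2 that does not exceed `scale`.
        _, exponent = math.frexp(scale)
        scale = math.ldexp(1.0, exponent - 1)
    return scale
```

With the flag honored in both the fused and unfused paths, the same block amax yields the same scale in either path, which is the numerical-consistency property the commit message refers to.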
…r module

Admit Float8BlockQuantizer in the fused GroupedTensor path on Hopper.
The existing usage flags already match the Hopper TN-only mapping and
the grouped GEMM selects transposed columnwise storage for NN/NT
layouts, so only the path predicate changes.

The fused path is an explicit opt-in via
NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM, so raise on Blackwell
(SM100/SM110) instead of silently falling back; the fused path has no
MXFP8-broadcast emulation.

Extend the fused dbias path (tex.bgrad_group_quantize) to FP8 block
scaling when dgrad is required (dbias is computed in the rowwise pass).

Add fp8_block_scaling to the fused-path tests with a Hopper-only gate,
assert the fused path engages via a group_quantize spy, and add a
Blackwell error-path test.

Signed-off-by: Alp Dener <adener@nvidia.com>
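
The opt-in behavior described in this commit can be sketched as a pure function. The function name, the `RuntimeError` type, the convention that the value `"1"` enables the env var, and the handling of architectures older than Hopper are all assumptions for illustration; only the env-var name and the Hopper/Blackwell split come from the commit message.

```python
import os
from typing import Mapping, Tuple

FUSED_ENV_VAR = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"


def fused_fp8_block_scaling_supported(
    compute_capability: Tuple[int, int],
    env: Mapping[str, str] = os.environ,
) -> bool:
    """Hypothetical module-path predicate for FP8 block scaling."""
    if env.get(FUSED_ENV_VAR, "0") != "1":
        return False  # user did not opt in to the fused path
    major, _minor = compute_capability
    if major == 9:
        return True  # Hopper: fused grouped path is admitted
    if major >= 10:
        # Blackwell: explicit opt-in, so fail loudly instead of falling back.
        raise RuntimeError(
            f"{FUSED_ENV_VAR}=1 is not supported with FP8 block scaling "
            f"on SM{major}{_minor}"
        )
    return False  # assumed: older architectures use the unfused path
```

Raising instead of returning `False` on Blackwell keeps an explicit opt-in from being silently ignored.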
Replace the blanket FP8 block-scaling rejection in
BasicOperation.reset_recipe_state with a per-op
supports_float8_block_scaling flag and opt in the GroupedLinear op.
Mirror the module-path predicate and fused-bgrad changes; since the
graph-safe flow is default-on here (no env-var opt-in), other
architectures fall back to the split-quantize flow instead of raising.

Force use_split_accumulator=True for FP8 block-scaling operands in
general_grouped_gemm_for_grouped_tensor, matching non-grouped
general_gemm: cuBLAS has no fast-accum FP8 block-scaling algorithm, so
the ops-layer forward failed algo selection without it.

Add fp8_block_scaling coverage to the ops GroupedLinear tests. The
CUDA-graph-safe test skips it for now: the replayed wgrad for the last
expert diverges between replays depending on process allocation
history; under investigation. Graph capture remains covered by the
module-path test.

Signed-off-by: Alp Dener <adener@nvidia.com>
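
The per-op opt-in flag can be pictured with toy classes that borrow the names from this commit. The simplified `reset_recipe_state` signature (a recipe name string) and the `Bias` op are assumptions for illustration and do not reflect the real method signatures.

```python
class BasicOperation:
    # Ops must opt in explicitly; the default keeps the previous rejection.
    supports_float8_block_scaling = False

    def reset_recipe_state(self, recipe_name: str) -> None:
        """Toy version: reject FP8 block scaling unless the op opted in."""
        if (
            recipe_name == "float8_block_scaling"
            and not self.supports_float8_block_scaling
        ):
            raise NotImplementedError(
                f"{type(self).__name__} does not support FP8 block scaling"
            )
        self.recipe_name = recipe_name


class GroupedLinear(BasicOperation):
    supports_float8_block_scaling = True  # opts in, as in this commit


class Bias(BasicOperation):
    pass  # hypothetical op that inherits the default and still raises
```

A class attribute keeps the check in one place in the base class while letting each op declare support without overriding the method.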
general_grouped_gemm_for_grouped_tensor allocated its setup workspace
(the cuBLAS per-group pointer/dimension arrays) and its cuBLAS
workspace with per-call torch.empty. Under make_graphed_callables the
forward and backward graphs share one capture memory pool, and a
per-call allocation's block returns to that pool as soon as the Python
reference dies, so blocks alias across the two graphs and captured
kernels from one graph overwrite the GEMM metadata the other graph
reads at replay. Observed as allocation-history-dependent failures in
the ops-layer GroupedLinear cuda-graph test: capture-time
cublasLtMatmulAlgoGetHeuristic NOT_SUPPORTED errors and corrupted
wgrad outputs. This is also the likely mechanism behind the FP8
block-scaling wgrad corruption under CUDA graphs previously observed
on Hopper and attributed to cuBLAS.

Cache the setup workspace per (device, group size) and reuse the
cached per-device cuBLAS workspace from the non-grouped path;
consecutive GEMMs reusing one workspace are ordered by the stream.

Signed-off-by: Alp Dener <adener@nvidia.com>
…ole cuBLAS workspaces

The grouped-tensor GEMM path shared one persistent cuBLAS workspace across all
grouped matmuls. cuBLAS's grouped GEMM keeps a grid-synchronization flag in the
first bytes of that workspace and zeros it (via a captured memset) before each
matmul. When the dgrad and wgrad grouped matmuls of a GroupedLinear backward share
one workspace inside a replayed CUDA graph, that flag is aliased between the two
matmuls; on the second graph replay the second matmul's cooperative kernel
deadlocks with cuBLAS 13.6 (and corrupts the last expert's wgrad on cuBLAS < 13.6).
The two matmuls are strictly stream-ordered (single stream, all-DEFAULT graph
edges, no programmatic dependent launch), so this is shared-workspace reuse, not
concurrent co-scheduling.

Give dgrad/forward (slot 0) and wgrad (slot 1) distinct persistent cuBLAS
workspaces, dedicated to the grouped path. Each slot remains a single persistent
allocation, so CUDA-graph capture safety is preserved.

Also drop the cuBLAS-version gate that skipped the FP8 block-scaling GroupedLinear
CUDA-graph test, so it now exercises the fix on all supported cuBLAS versions.

Signed-off-by: Alp Dener <adener@nvidia.com>
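
The two-slot workspace scheme can be sketched with the standard library alone. Here a `bytearray` stands in for a persistent device buffer, and the function names are hypothetical; the 32 MiB size and the slot assignment (0 for forward/dgrad, 1 for wgrad) follow the description above.

```python
from functools import lru_cache

WORKSPACE_BYTES = 32 * 1024 * 1024  # 32 MiB per slot


@lru_cache(maxsize=None)
def get_grouped_workspace(device: str, slot: int) -> bytearray:
    """Allocate once per (device, slot) and return the same buffer afterwards."""
    if slot not in (0, 1):
        raise ValueError(f"unknown workspace slot {slot}")
    return bytearray(WORKSPACE_BYTES)  # stand-in for a persistent GPU allocation


def workspace_for(device: str, is_wgrad: bool) -> bytearray:
    """Slot 0 serves forward and dgrad; slot 1 is reserved for wgrad."""
    return get_grouped_workspace(device, 1 if is_wgrad else 0)
```

Because each slot is allocated once and never freed, the buffers keep stable addresses across graph capture and replay, while dgrad and wgrad no longer share the bytes that hold the synchronization flag.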
denera force-pushed the pytorch/grouped-linear-fused-fp8bs branch from bbfa1da to b35924c on July 2, 2026 20:35

denera commented Jul 2, 2026

Collaborator Author

/te-ci core pytorch

denera linked an issue on Jul 2, 2026 that may be closed by this pull request

greptile-apps Bot commented Jul 2, 2026

Contributor

Greptile Summary

This PR integrates FP8 block-scaling quantization into the fused GroupedTensor path of GroupedLinear, enabling grouped FP8 block-scaling GEMMs on Hopper (SM90) while explicitly raising on Blackwell when the user opts into the fused path via env var. A separate fix eliminates a cuBLAS-workspace-aliasing deadlock under CUDA-graph replay by caching two distinct 32 MiB cuBLAS workspaces (one for fprop/dgrad, one for wgrad) and a persistent setup workspace.

  • FP8 block-scaling fused path (Hopper-only): The force_pow_2_scales restriction is lifted from the grouped quantize kernels by threading the bool pow_2_scales argument down to the device kernels; all callers from dispatch/quantize.cuh and cast_grouped_dbias.cu are updated, including the new NVTEQuantizationConfig argument on nvte_group_quantize_dbias.
  • Dual cuBLAS workspace fix: general_grouped_gemm_for_grouped_tensor now fetches one of two @lru_cache-backed 32 MiB workspaces keyed by (device, slot) where slot 1 is wgrad-exclusive, preventing the cooperative-kernel grid-sync flag from being aliased across dgrad and wgrad in the same replayed CUDA graph.
  • supports_float8_block_scaling opt-in flag: Added to BasicOperation (default False) so the existing NotImplementedError becomes op-specific instead of blanket-blocking all fusible ops.

Confidence Score: 5/5

Safe to merge. The cuBLAS workspace fix correctly isolates dgrad and wgrad into separate persistent buffers, the pow-2-scales threading is mechanically complete through all call sites, and the Blackwell error path is tested with a dedicated negative test.

All changed kernels receive the new pow_2_scales argument at every launch site; the prior runtime rejection of force_pow_2_scales=True is removed in both C++ and CUDA layers consistently. The dual-workspace strategy is sound: fprop and dgrad share slot 0 safely because they run in separate CUDA-graph captures, while dgrad and wgrad—which are in the same backward graph and were the source of the deadlock—now get distinct 32 MiB buffers. The fuse_bgrad guard that gates FP8 block-scaling dbias fusion on input_requires_grad correctly maps to the kernel dependency on a rowwise quantization pass. Test coverage adds parametric cases across both Hopper-only SM-range tests, a new Blackwell error path test, and call-count instrumentation to confirm the fused path is actually exercised.

No files require special attention. The most load-bearing changes are in gemm.py (workspace caching) and group_quantize_fp8_blockwise.cuh (pow_2_scales threading), both of which are mechanically complete and well-documented.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/pytorch/cpp_extensions/gemm.py` | Replaces per-call workspace allocations with two lru_cache-backed persistent tensors (setup + two cuBLAS slots); adds _is_fp8_blockwise helper to force use_split_accumulator for FP8 block-scaling operands. Logic is correct and well-commented. |
| `transformer_engine/pytorch/module/grouped_linear.py` | Adds Float8BlockQuantizer detection in _is_grouped_tensor_path_supported (raises on Blackwell with env-var opt-in, returns True on Hopper). The fuse_bgrad guard now gates FP8 block-scaling on ctx.requires_dgrad, matching the kernel requirement for a rowwise pass. |
| `transformer_engine/pytorch/ops/basic/grouped_linear.py` | Mirrors the module-layer path: FP8 block-scaling returns True on Hopper, falls back on other arches. fuse_bgrad mirrors the module's gating on ctx.input_requires_grad. Updated comment for the fallback dbias path is accurate. |
| `transformer_engine/pytorch/csrc/extensions/cast.cpp` | bgrad_group_quantize now constructs a QuantizationConfigWrapper from the Float8BlockQuantizer's knobs before the two-call workspace sizing protocol, forwarding force_pow_2_scales and amax_epsilon. MXFP8 stays with the default zero-initialized config (correct). |
| `transformer_engine/pytorch/csrc/quantizer.cpp` | Removes the runtime check that rejected force_pow_2_scales in create_grouped_tensor, enabling the fused path to honor pow-2 scale constraints. The restriction is now lifted at the kernel level instead. |
| `transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh` | Threads bool pow_2_scales through all four kernel launch paths and passes it to compute_scale_from_types. Kernel signatures for group_quantize_blockwise_1d/2d updated consistently. |
| `transformer_engine/common/cast/dispatch/quantize.cuh` | Removes the NVTE_CHECK assertions blocking force_pow_2_scales=true in both fwd and bwd helpers; passes the config field through to the blockwise launcher. |
| `transformer_engine/common/cast/cast_grouped_dbias.cu` | nvte_group_quantize_dbias gains a NVTEQuantizationConfig parameter and forwards it to the dispatch helper instead of nullptr. Public header updated in lockstep. |
| `transformer_engine/pytorch/ops/op.py` | Adds supports_float8_block_scaling class attribute (default False) so FP8 block-scaling NotImplementedError is op-specific; GroupedLinear opts in with True. |
| `tests/pytorch/test_grouped_linear.py` | Adds fp8_block_scaling param to grouped-tensor-path and cuda-graph-safe tests, a monkeypatch group_quantize call counter to verify the fused path is taken, and a new Blackwell error test. Skip conditions are correctly scoped to Hopper-only. |
| `transformer_engine/common/recipe/__init__.py` | Adds a NOTE docstring explaining why use_split_accumulator is always forced True for FP8 block scaling, addressing the previously flagged silent override concern. |

Reviews (2): Last reviewed commit: "[PyTorch] Address review: document split..."

Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
…ale dbias comment

- general_grouped_gemm_for_grouped_tensor: expand the comment to state that the fused
  grouped FP8 block-scaling GEMM forces use_split_accumulator=True and intentionally
  overrides the caller-supplied value, consistent with the Float8BlockScaling recipe
  (which fixes it True for fprop/dgrad/wgrad).
- Float8BlockScaling recipe docstring: document that FP8 block scaling always uses
  split accumulation and that the fused grouped GEMM path ignores any caller- or
  recipe-supplied use_split_accumulator value.
- GroupedLinear ops backward: correct the stale "BF16/FP16 path" comment; that branch
  also handles quantized paths where bgrad fusion did not apply (e.g. FP8 block
  scaling without a dgrad pass).

Signed-off-by: Alp Dener <adener@nvidia.com>

Development

Successfully merging this pull request may close these issues.

Blockwise (1x128 and 128x128) FP8 grouped quantization
