
Add fused multi-tensor kernel for 1D blockwise FP8 quantization #3168

Draft
shangxiaokang wants to merge 3 commits into NVIDIA:release_v2.8 from shangxiaokang:fused-multi-tensor

Conversation

@shangxiaokang

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

github-actions bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work.) on Jul 2, 2026
shangxiaokang marked this pull request as draft on July 2, 2026, 10:07
@greptile-apps

greptile-apps Bot commented Jul 2, 2026

Contributor

Greptile Summary

This PR adds a fused multi-tensor CUDA kernel for 1D block-scaled FP8 quantization, amortising kernel-launch overhead when multiple tensors share the same quantization parameters. The implementation extracts the existing per-block logic into a __device__ helper, introduces a MultiBlockwiseQuantizeArgs struct to pack up to 32 tensors into a single kernel launch, and wires a new nvte_multi_quantize_transpose_vector_blockwise C API into the PyTorch extension.

  • A non-atomic bool flag (g_printed_blockwise_fp8_multi_tensor_kernel) used for a once-only debug print creates a data race in multi-threaded runtimes and should be replaced with std::atomic<bool> (or the print removed entirely).
  • The fprintf(stderr, ...) diagnostic message fires on the first invocation and cannot be suppressed; it should be gated on an environment variable or removed before merge.

Confidence Score: 3/5

The fused kernel logic is architecturally sound, but the non-atomic global flag introduces a data race that is undefined behaviour in C++ on any multi-threaded runtime, and the unconditional stderr diagnostic print will surface unexpectedly in user environments on the first training step.

A plain bool that is read and written across threads without synchronisation is a data race, which is undefined behaviour in C++; in practice the most likely symptom is a duplicate print. The stderr diagnostic is also an atypical side-effect for a library kernel dispatch. Both issues are in the same file and are straightforward to fix, but they affect runtime correctness and user experience enough to warrant a revision before merge.

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu — the non-atomic flag and stderr print at the kernel launch site.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu | Core change: extracts the single-tensor kernel body into a device helper and adds a new multi-tensor kernel with per-block tensor dispatch. Contains a data race on a non-atomic global flag used for a one-shot diagnostic print, and debug stderr output that should not reach production. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Adds the fused blockwise multi-tensor path ahead of the per-tensor fallback. Adds an early return after nvte_multi_cast_transpose to avoid fall-through into the new code. Logic looks correct; noop handling is consistent with the individual quantize path. |
| transformer_engine/common/include/transformer_engine/transpose.h | Adds public C API declaration for nvte_multi_quantize_transpose_vector_blockwise with complete doxygen documentation; no issues found. |
| transformer_engine/common/transpose/cast_transpose.h | Adds internal declaration for multi_quantize_transpose_vector_blockwise; declaration matches the implementation signature. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[multi_tensor_quantize_impl] --> B{All FP8 delayed-scaling with both rowwise+columnwise?}
    B -- Yes --> C[nvte_multi_cast_transpose fused kernel]
    C --> Z[return]
    B -- No --> D{All FP8 1D blockwise with matching params?}
    D -- Yes --> E[Build MultiBlockwiseQuantizeArgs aligned + unaligned batches]
    E --> F[launch aligned kernel]
    E --> G[launch unaligned kernel]
    F --> H[multi_block_scaled_1d_cast_transpose_kernel linear scan for tensor_id]
    G --> H
    D -- No --> I[Individual quantize loop per-tensor fallback]
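
The "Build MultiBlockwiseQuantizeArgs" step in the flowchart can be pictured with a small host-side sketch. This is not the PR's code: PackedArgs, kMaxTensors and pack_batches are hypothetical names, the struct models only the block_range prefix sum and the 32-tensor cap mentioned in the summary, and the aligned/unaligned split is left out.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the PR's MultiBlockwiseQuantizeArgs; only the
// fields needed to map a block index to a tensor are modelled here.
constexpr int kMaxTensors = 32;  // assumed to mirror kMaxTensorsPerBlockwiseKernel

struct PackedArgs {
  int num_tensors = 0;
  // block_range[i] is the first global block index owned by tensor i;
  // block_range[num_tensors] is the total block count for the launch.
  int block_range[kMaxTensors + 1] = {0};
};

// Splits per-tensor block counts into batches of at most kMaxTensors tensors
// and builds an exclusive prefix sum of block counts inside each batch.
std::vector<PackedArgs> pack_batches(const std::vector<int>& blocks_per_tensor) {
  std::vector<PackedArgs> batches;
  PackedArgs current;
  for (int n : blocks_per_tensor) {
    assert(n >= 0);
    if (current.num_tensors == kMaxTensors) {
      // Batch is full: emit it and start a fresh one (one launch per batch).
      batches.push_back(current);
      current = PackedArgs{};
    }
    current.block_range[current.num_tensors + 1] =
        current.block_range[current.num_tensors] + n;
    ++current.num_tensors;
  }
  if (current.num_tensors > 0) batches.push_back(current);
  return batches;
}
```

Under these assumptions, 40 tensors would produce two launches (32 tensors, then 8), and each launch would use block_range[num_tensors] as its grid size.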

Comments Outside Diff (1)

  1. transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu, line 232-236 (link)

    P2 Linear scan for tensor-ID lookup lacks an upper-bound guard

    The while (args.block_range[tensor_id + 1] <= bid) loop increments tensor_id without checking tensor_id < args.num_tensors. In production this is safe because n_blocks == block_range[num_tensors], so bid < block_range[num_tensors] always holds. However, without an upper bound, a single off-by-one in the caller (e.g. the wrong n_blocks passed to the launch) would silently walk past the end of block_range and read uninitialised data. Adding a tensor_id < args.num_tensors guard to the loop condition, or an assertion, would make the invariant explicit and aid future debugging.
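
    A guarded version of the scan might look like the sketch below. It is written as plain host C++ so it can be tested in isolation; BlockRanges and find_tensor_id are hypothetical names, and the real kernel would presumably run the same loop in device code with bid derived from blockIdx.

```cpp
#include <cassert>

constexpr int kMaxTensors = 32;  // assumed to mirror kMaxTensorsPerBlockwiseKernel

// Hypothetical subset of the kernel arguments: just the block-range table.
struct BlockRanges {
  int num_tensors;
  int block_range[kMaxTensors + 1];  // exclusive prefix sum of block counts
};

// Returns the index of the tensor that owns block `bid`, or -1 when `bid`
// lies at or beyond block_range[num_tensors] (a caller bug, e.g. a grid
// that is too large).
inline int find_tensor_id(const BlockRanges& args, int bid) {
  assert(args.num_tensors <= kMaxTensors);
  int tensor_id = 0;
  // The first condition is the upper-bound guard: the scan can never index
  // past block_range[num_tensors].
  while (tensor_id < args.num_tensors &&
         args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  return tensor_id < args.num_tensors ? tensor_id : -1;
}
```

    With block_range = {0, 2, 5, 9}, blocks 0 and 1 map to tensor 0, blocks 2 to 4 to tensor 1, blocks 5 to 8 to tensor 2, and block 9 is reported as out of range instead of reading past the table.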

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +170 to +171
constexpr int kMaxTensorsPerBlockwiseKernel = 32; // Keep kernel args comfortably under 4 KB.
bool g_printed_blockwise_fp8_multi_tensor_kernel = false;

P1 Non-atomic global flag causes a data race

g_printed_blockwise_fp8_multi_tensor_kernel is a plain bool, not std::atomic<bool>. In a multi-threaded PyTorch runtime (e.g. two threads dispatching this kernel path simultaneously for the first time), both threads can concurrently read false, concurrently write true, and both print — a classic read-modify-write data race that is undefined behaviour in C++. Beyond the thread-safety issue, writing diagnostic output to stderr from a production kernel dispatch path is atypical for a library: it pollutes the user's error stream on the very first model step and cannot be silenced. Consider replacing with std::atomic<bool> and std::memory_order_relaxed, or removing the print entirely if it was left over from development.

Suggested change
constexpr int kMaxTensorsPerBlockwiseKernel = 32; // Keep kernel args comfortably under 4 KB.
bool g_printed_blockwise_fp8_multi_tensor_kernel = false;
std::atomic<bool> g_printed_blockwise_fp8_multi_tensor_kernel{false};
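
One caveat on the type change: swapping in std::atomic<bool> removes the undefined behaviour, but a separate load followed by a store still lets two threads both observe false and both print. A minimal sketch of a once-only claim using exchange is shown here; g_printed, claim_first_print and maybe_print_once are hypothetical names, not identifiers from the PR.

```cpp
#include <atomic>
#include <cassert>
#include <cstdio>

// Hypothetical stand-in for g_printed_blockwise_fp8_multi_tensor_kernel.
std::atomic<bool> g_printed{false};

// Atomically sets the flag and reports whether this caller was the first.
// exchange() returns the previous value, so exactly one caller sees false.
// Relaxed ordering is enough because no other data is published via the flag.
inline bool claim_first_print() {
  return !g_printed.exchange(true, std::memory_order_relaxed);
}

// Prints the diagnostic at most once per process, even under concurrency.
inline void maybe_print_once(int num_tensors, int n_blocks, int aligned) {
  if (claim_first_print()) {
    std::fprintf(stderr,
                 "[TransformerEngine] Using blockwise FP8 multi-tensor quantize "
                 "kernel (num_tensors=%d, num_blocks=%d, aligned=%d)\n",
                 num_tensors, n_blocks, aligned);
  }
}
```

std::call_once with a std::once_flag would be an alternative if the one-time action needed to complete before other threads proceed; for a best-effort log line the exchange is likely sufficient.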

Comment on lines +610 to +618
}
if (!g_printed_blockwise_fp8_multi_tensor_kernel) {
  g_printed_blockwise_fp8_multi_tensor_kernel = true;
  std::fprintf(stderr,
               "[TransformerEngine] Using blockwise FP8 multi-tensor quantize kernel "
               "(num_tensors=%d, num_blocks=%d, aligned=%d)\n",
               kernel_args.num_tensors, n_blocks, kAligned ? 1 : 0);
  std::fflush(stderr);
}

P2 Diagnostic fprintf to stderr left in production kernel dispatch

This print fires (and flushes) on the first invocation of the fused multi-tensor kernel. For a library called inside large training loops this message will appear unexpectedly in users' stderr and cannot be disabled. If the intent is to aid debugging, it should either be guarded by an environment variable (e.g. NVTE_DEBUG) or removed before merge.
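
If the message is kept, an environment-variable gate could be sketched as follows. The variable name NVTE_DEBUG is taken from the example in the comment above; how it should be parsed here (unset, empty or "0" meaning off) is an assumption, and flag_value_enabled and debug_print_enabled are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Interprets an environment value: unset (nullptr), empty, or "0" disables
// the flag; anything else enables it. This parsing rule is an assumption.
inline bool flag_value_enabled(const char* value) {
  return value != nullptr && value[0] != '\0' && std::strcmp(value, "0") != 0;
}

// Reads NVTE_DEBUG once and caches the result. Initialisation of a
// function-local static is thread-safe since C++11.
inline bool debug_print_enabled() {
  static const bool enabled = flag_value_enabled(std::getenv("NVTE_DEBUG"));
  return enabled;
}
```

The launch site would then wrap the fprintf in if (debug_print_enabled()) { ... }, so the default behaviour is silent and the message only appears when a user opts in.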

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

1 participant