Add fused multi-tensor kernel for 1D blockwise FP8 quantization#3168
Add fused multi-tensor kernel for 1D blockwise FP8 quantization#3168shangxiaokang wants to merge 3 commits into
Conversation
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds a fused multi-tensor CUDA kernel for 1D block-scaled FP8 quantization, amortising kernel-launch overhead when multiple tensors share the same quantization parameters. The implementation extracts the existing per-block logic into a
Confidence Score: 3/5The fused kernel logic is architecturally sound, but the non-atomic global flag introduces a data race that is undefined behaviour in C++ on any multi-threaded runtime, and the unconditional stderr diagnostic print will surface unexpectedly in user environments on the first training step. A plain bool read-modify-write across threads without synchronisation is undefined behaviour; in practice it can cause duplicate prints or torn reads. The stderr diagnostic is also an atypical side-effect for a library kernel dispatch. Both issues are in the same file and are straightforward to fix, but they affect runtime correctness and user experience enough to warrant a revision before merge. transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu — the non-atomic flag and stderr print at the kernel launch site. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[multi_tensor_quantize_impl] --> B{All FP8 delayed-scaling with both rowwise+columnwise?}
B -- Yes --> C[nvte_multi_cast_transpose fused kernel]
C --> Z[return]
B -- No --> D{All FP8 1D blockwise with matching params?}
D -- Yes --> E[Build MultiBlockwiseQuantizeArgs aligned + unaligned batches]
E --> F[launch aligned kernel]
E --> G[launch unaligned kernel]
F --> H[multi_block_scaled_1d_cast_transpose_kernel linear scan for tensor_id]
G --> H
D -- No --> I[Individual quantize loop per-tensor fallback]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[multi_tensor_quantize_impl] --> B{All FP8 delayed-scaling with both rowwise+columnwise?}
B -- Yes --> C[nvte_multi_cast_transpose fused kernel]
C --> Z[return]
B -- No --> D{All FP8 1D blockwise with matching params?}
D -- Yes --> E[Build MultiBlockwiseQuantizeArgs aligned + unaligned batches]
E --> F[launch aligned kernel]
E --> G[launch unaligned kernel]
F --> H[multi_block_scaled_1d_cast_transpose_kernel linear scan for tensor_id]
G --> H
D -- No --> I[Individual quantize loop per-tensor fallback]
|
| constexpr int kMaxTensorsPerBlockwiseKernel = 32; // Keep kernel args comfortably under 4 KB. | ||
| bool g_printed_blockwise_fp8_multi_tensor_kernel = false; |
There was a problem hiding this comment.
Non-atomic global flag causes a data race
g_printed_blockwise_fp8_multi_tensor_kernel is a plain bool, not std::atomic<bool>. In a multi-threaded PyTorch runtime (e.g. two threads dispatching this kernel path simultaneously for the first time), both threads can concurrently read false, concurrently write true, and both print — a classic read-modify-write data race that is undefined behaviour in C++. Beyond the thread-safety issue, writing diagnostic output to stderr from a production kernel dispatch path is atypical for a library: it pollutes the user's error stream on the very first model step and cannot be silenced. Consider replacing with std::atomic<bool> and std::memory_order_relaxed, or removing the print entirely if it was left over from development.
| constexpr int kMaxTensorsPerBlockwiseKernel = 32; // Keep kernel args comfortably under 4 KB. | |
| bool g_printed_blockwise_fp8_multi_tensor_kernel = false; | |
| std::atomic<bool> g_printed_blockwise_fp8_multi_tensor_kernel{false}; |
| } | ||
| if (!g_printed_blockwise_fp8_multi_tensor_kernel) { | ||
| g_printed_blockwise_fp8_multi_tensor_kernel = true; | ||
| std::fprintf(stderr, | ||
| "[TransformerEngine] Using blockwise FP8 multi-tensor quantize kernel " | ||
| "(num_tensors=%d, num_blocks=%d, aligned=%d)\n", | ||
| kernel_args.num_tensors, n_blocks, kAligned ? 1 : 0); | ||
| std::fflush(stderr); | ||
| } |
There was a problem hiding this comment.
Diagnostic
fprintf to stderr left in production kernel dispatch
This print fires (and flushes) on the first invocation of the fused multi-tensor kernel. For a library called inside large training loops this message will appear unexpectedly in users' stderr and cannot be disabled. If the intent is to aid debugging, it should either be guarded by an environment variable (e.g. NVTE_DEBUG) or removed before merge.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: