Reverse MXFP8 quantization row raster #3170
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR reverses the Y-axis row raster order for MXFP8 quantization kernels, so they process row blocks bottom-to-top rather than top-to-bottom. Within each block, stages are also processed in reverse, with all global offsets, scale indices, and bounds checks updated consistently.
Confidence Score: 4/5
Safe to merge; the reverse-raster logic is applied consistently across block Y coordinates, stage offsets, scale indices, bounds checks, and dbias offsets in both the single-tensor and grouped kernel paths.
All offsets derived from the Y block index are consistently updated, and the advance_to_next_job signature change is complete across all call sites. The new test validates bitwise correctness for the generic TMA path and the grouped persistent-kernel path, though the specialized cast-only path (column count divisible by 128) is not directly tested.
quantize_mxfp8.cuh and group_quantize_mxfp8.cuh carry the most complex multi-stage pipeline logic; the prefetch priming loop and the stage processing loop both had to be updated in sync, and they were.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Launch MXFP8 kernel"] --> B["Compute block_id_Y = gridDim.y - 1 - blockIdx.y"]
B --> C["block_offset_Y = block_id_Y x CHUNK_DIM_Y"]
C --> D["Initial prefetch: rows at (STAGES-1)*BUFF_DIM_Y"]
D --> E["Stage loop: logical_stage = STAGES-1 - stage"]
E --> F["stage_offset_Y = logical_stage x BUFF_DIM_Y"]
F --> G["TMA store @ block_offset_Y + stage_offset_Y"]
G --> J{More stages?}
J -- Yes --> E
J -- No --> K["advance_to_next_job via linear_block_id_to_reverse_y_cta_coords"]
K --> L{More jobs?}
L -- Yes --> B
L -- No --> M["Done"]
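The index arithmetic in the flowchart can be sketched in plain Python. This is only an illustration, not the kernel code: the function name `reverse_raster_row_offsets` is hypothetical, blocks are assumed to be visited in ascending `blockIdx.y`, and it assumes `CHUNK_DIM_Y == STAGES * BUFF_DIM_Y` so that the stages tile a chunk exactly.

```python
def reverse_raster_row_offsets(grid_dim_y, chunk_dim_y, buff_dim_y, stages):
    """Return the starting row of each stage, in the order a reverse raster visits them.

    Hypothetical helper for illustration; it mirrors the formulas in the
    flowchart and assumes chunk_dim_y == stages * buff_dim_y.
    """
    offsets = []
    for block_idx_y in range(grid_dim_y):
        # block_id_Y = gridDim.y - 1 - blockIdx.y: the first block handles the last rows.
        block_id_y = grid_dim_y - 1 - block_idx_y
        block_offset_y = block_id_y * chunk_dim_y
        for stage in range(stages):
            # logical_stage = STAGES - 1 - stage: stages also run bottom-to-top.
            logical_stage = stages - 1 - stage
            stage_offset_y = logical_stage * buff_dim_y
            offsets.append(block_offset_y + stage_offset_y)
    return offsets


# Three row blocks of 64 rows each, processed as two 32-row stages per block.
print(reverse_raster_row_offsets(grid_dim_y=3, chunk_dim_y=64, buff_dim_y=32, stages=2))
```

Under these assumptions every 32-row slice is visited exactly once, in strictly descending row order, which is the property the bounds checks and scale indices have to stay consistent with.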
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # N=96 avoids the specialized rowwise cast-only path, exercising the generic TMA path.
    x = torch.randn((320, 96), dtype=torch.bfloat16, device="cuda")

    q, s = _rowwise_quantize(x)
    q_ref, s_ref = _rowwise_quantize_padded_reference(x)

    torch.testing.assert_close(q, q_ref, atol=0.0, rtol=0.0)
    torch.testing.assert_close(s[: x.size(0), : x.size(1) // 32], s_ref, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
def test_mxfp8_grouped_quantize_reverse_raster_preserves_values() -> None:
No coverage for column-count that takes the specialized cast-only path
The comment at line 38 explicitly notes that N=96 is chosen to avoid the specialized rowwise cast-only path. However, the reverse-raster changes in quantize_mxfp8.cuh (via block_id_Y and logical_stage) also affect that specialized path since block_offset_Y and scales_block_offset_Y_rowwise/colwise are computed before any path divergence. A test with a column count that is a multiple of 128 (e.g., 128 or 256) would exercise the specialized path and strengthen confidence that both paths produce correct values after the raster change.
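As a rough sketch of how the suggested coverage could be planned, the snippet below sorts candidate test shapes by the path they are expected to take. It relies on an assumption taken from the review summary, namely that the specialized cast-only path is selected when the column count is divisible by 128; the real dispatch condition in the kernel may check more than that. The names `takes_specialized_path`, `CANDIDATE_SHAPES`, and `group_shapes_by_path` are hypothetical and do not come from the PR.

```python
# Assumption from the review summary: the specialized cast-only path is taken
# when the column count is divisible by 128. The real dispatch may check more.
SPECIALIZED_COLUMN_MULTIPLE = 128


def takes_specialized_path(num_cols: int) -> bool:
    """Hypothetical predicate: True if this column count is expected to hit the cast-only path."""
    return num_cols % SPECIALIZED_COLUMN_MULTIPLE == 0


# (320, 96) is the shape used by the existing test; 128 and 256 columns are
# the reviewer's examples of counts that should reach the specialized path.
CANDIDATE_SHAPES = [(320, 96), (320, 128), (320, 256)]


def group_shapes_by_path(shapes):
    """Split (rows, cols) shapes into the path each one is expected to exercise."""
    groups = {"generic": [], "specialized": []}
    for shape in shapes:
        key = "specialized" if takes_specialized_path(shape[1]) else "generic"
        groups[key].append(shape)
    return groups


print(group_shapes_by_path(CANDIDATE_SHAPES))
```

A parametrized version of the existing test could then draw at least one shape from each group, so that both paths are compared against the reference after the raster change.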
import transformer_engine_torch as tex
from transformer_engine.pytorch import MXFP8Quantizer

recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True)
The module-level recipe_available / reason_for_no_recipe assignments run at import time; if pytest collects this module on a machine without a GPU, the te.is_mxfp8_available call could produce misleading error messages before any skipif guard fires. Assigning inside a helper called lazily is the safer pattern used in other test files in this repo.
Suggested change:
- recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True)
+ def _get_mxfp8_availability():
+     return te.is_mxfp8_available(return_reason=True)
+
+ recipe_available, reason_for_no_recipe = _get_mxfp8_availability()
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: