Skip to content

Reverse MXFP8 quantization row raster#3170

Open
sraman-rgb wants to merge 2 commits into
NVIDIA:mainfrom
sraman-rgb:mxfp8-quantize-reverse-raster
Open

Reverse MXFP8 quantization row raster#3170
sraman-rgb wants to merge 2 commits into
NVIDIA:mainfrom
sraman-rgb:mxfp8-quantize-reverse-raster

Conversation

@sraman-rgb

Copy link
Copy Markdown
Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jul 2, 2026
@greptile-apps

greptile-apps Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR reverses the Y-axis row raster order for MXFP8 quantization kernels, making them process row blocks from bottom-to-top rather than top-to-bottom. Within each block, stages are also processed in reverse, with all global offsets, scale indices, and bounds checks updated consistently.

  • quantize_mxfp8.cuh: Remaps blockIdx.yblock_id_Y = gridDim.y - 1 - blockIdx.y, and introduces logical_stage = STAGES - 1 - stage for all TMA loads, scale writes, and output stores.
  • grouped_layout.cuh: Adds linear_block_id_to_reverse_y_cta_coords to convert linear block IDs to reversed-Y CTA coordinates, and updates advance_to_next_job to take work_blocks_Y.
  • group_quantize_mxfp8.cuh: Uses the new helper for initial block assignment and persistent-kernel scheduling, with prefetch and processing loops updated to use logical_stage-based offsets.
  • New test: Validates that both the generic and grouped quantization paths produce bitwise-identical results after the raster reversal.

Confidence Score: 4/5

Safe to merge; the reverse-raster logic is applied consistently across block Y coordinates, stage offsets, scale indices, bounds checks, and dbias offsets in both the single-tensor and grouped kernel paths.

All offsets derived from the Y block index are consistently updated, and the advance_to_next_job signature change is complete across all call sites. The new test validates bitwise correctness for the generic TMA path and grouped persistent-kernel path, though the specialized cast-only path (column count divisible by 128) is not directly tested.

quantize_mxfp8.cuh and group_quantize_mxfp8.cuh carry the most complex multi-stage pipeline logic; the prefetch priming loop and the stage processing loop both had to be updated in sync, which they are.

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh Introduces reversed Y-raster via block_id_Y and logical_stage; all derived offsets are consistently updated including the initial TMA prefetch.
transformer_engine/common/cast/core/grouped_layout.cuh Adds linear_block_id_to_reverse_y_cta_coords helper and updates advance_to_next_job signature; logic is straightforward and correct.
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Uses new reverse-Y helper for initial block assignment and advance_to_next_job; prefetch and processing loops apply logical_stage correctly.
tests/pytorch/mxfp8/test_mxfp8_quantize_raster.py Tests generic and grouped paths with N=96; padded-column reference approach is sound but the specialized cast-only path is not covered.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Launch MXFP8 kernel"] --> B["Compute block_id_Y = gridDim.y - 1 - blockIdx.y"]
    B --> C["block_offset_Y = block_id_Y x CHUNK_DIM_Y"]
    C --> D["Initial prefetch: rows at (STAGES-1)*BUFF_DIM_Y"]
    D --> E["Stage loop: logical_stage = STAGES-1 - stage"]
    E --> F["stage_offset_Y = logical_stage x BUFF_DIM_Y"]
    F --> G["TMA store @ block_offset_Y + stage_offset_Y"]
    G --> J{More stages?}
    J -- Yes --> E
    J -- No --> K["advance_to_next_job via linear_block_id_to_reverse_y_cta_coords"]
    K --> L{More jobs?}
    L -- Yes --> B
    L -- No --> M["Done"]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["Launch MXFP8 kernel"] --> B["Compute block_id_Y = gridDim.y - 1 - blockIdx.y"]
    B --> C["block_offset_Y = block_id_Y x CHUNK_DIM_Y"]
    C --> D["Initial prefetch: rows at (STAGES-1)*BUFF_DIM_Y"]
    D --> E["Stage loop: logical_stage = STAGES-1 - stage"]
    E --> F["stage_offset_Y = logical_stage x BUFF_DIM_Y"]
    F --> G["TMA store @ block_offset_Y + stage_offset_Y"]
    G --> J{More stages?}
    J -- Yes --> E
    J -- No --> K["advance_to_next_job via linear_block_id_to_reverse_y_cta_coords"]
    K --> L{More jobs?}
    L -- Yes --> B
    L -- No --> M["Done"]
Loading

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +36 to +50
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# N=96 avoids the specialized rowwise cast-only path, exercising the generic TMA path.
x = torch.randn((320, 96), dtype=torch.bfloat16, device="cuda")

q, s = _rowwise_quantize(x)
q_ref, s_ref = _rowwise_quantize_padded_reference(x)

torch.testing.assert_close(q, q_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(s[: x.size(0), : x.size(1) // 32], s_ref, atol=0.0, rtol=0.0)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
def test_mxfp8_grouped_quantize_reverse_raster_preserves_values() -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 No coverage for column-count that takes the specialized cast-only path

The comment at line 38 explicitly notes that N=96 is chosen to avoid the specialized rowwise cast-only path. However, the reverse-raster changes in quantize_mxfp8.cuh (via block_id_Y and logical_stage) also affect that specialized path since block_offset_Y and scales_block_offset_Y_rowwise/colwise are computed before any path divergence. A test with a column count that is a multiple of 128 (e.g., 128 or 256) would exercise the specialized path and strengthen confidence that both paths produce correct values after the raster change.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

import transformer_engine_torch as tex
from transformer_engine.pytorch import MXFP8Quantizer

recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The module-level recipe_available / reason_for_no_recipe assignments run at import time; if the import is collected by pytest without a GPU, the te.is_mxfp8_available call could produce misleading error messages before any skipif guard fires. Assigning inside a helper called lazily is the safer pattern used in other test files in this repo.

Suggested change
recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True)
def _get_mxfp8_availability():
return te.is_mxfp8_available(return_reason=True)
recipe_available, reason_for_no_recipe = _get_mxfp8_availability()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant