[PyTorch] Add optional caller-provided output/grad-input buffers to GroupedLinear and fused grouped MLP #3161

Open
phu0ngng wants to merge 6 commits into NVIDIA:main from phu0ngng:pyt-gg-w-symm

Conversation

@phu0ngng

@phu0ngng phu0ngng commented Jul 1, 2026

Collaborator

Description

Let callers optionally supply preallocated buffers for the forward output and the backward input gradient of GroupedLinear and the fused grouped MLP, instead of having the module/op allocate them internally. The caller can then pass a symmetric-memory buffer that is fed directly into TE EP's combine, eliminating the D2D copy from the token buffer into TE EP's internal staging buffer.

Type of change

  • Documentation change
  • Bug fix
  • New feature (non-breaking change which adds functionality)
  • Breaking change
  • Infra/Build change
  • Code refactoring

Changes

  • GroupedLinear.forward: optional out / dgrad_out buffers, wired through the legacy and GroupedTensor paths (incl. FP8). Validated via a shared validate_or_alloc_output helper (shape/dtype/device/contiguous/no-grad).
  • Fusible ops: Sequential.forward takes an op_kwargs mapping (keyed by module or index) that routes per-op kwargs to the target op, used to pass output/grad-input buffers to grouped linear / grouped MLP.
  • Fused grouped MLP: the GEMMs write the output and dgrad directly into the caller buffers for both nvfp4 and MXFP8, with no copy. The GEMM writes only the packed valid rows (per m_splits); any padded tail is left untouched. Buffers may be supplied independently.
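
The validation step described above can be sketched in plain PyTorch. This is a minimal illustration, not the PR's implementation: the name `validate_or_alloc_output_sketch` is hypothetical, the argument order is modeled on the call quoted in the review thread, and the actual helper may check more or report errors differently.

```python
from typing import Optional, Sequence

import torch


def validate_or_alloc_output_sketch(
    buffer: Optional[torch.Tensor],
    shape: Sequence[int],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Return `buffer` if it can serve as an output, else allocate a new tensor.

    Hypothetical stand-in for the PR's helper; the real checks may differ.
    """
    if buffer is None:
        # No caller buffer was supplied: fall back to internal allocation.
        return torch.empty(tuple(shape), dtype=dtype, device=device)
    if tuple(buffer.shape) != tuple(shape):
        raise ValueError(f"expected shape {tuple(shape)}, got {tuple(buffer.shape)}")
    if buffer.dtype != dtype:
        raise ValueError(f"expected dtype {dtype}, got {buffer.dtype}")
    if buffer.device != device:
        raise ValueError(f"expected device {device}, got {buffer.device}")
    if not buffer.is_contiguous():
        raise ValueError("caller buffer must be contiguous")
    if buffer.requires_grad:
        raise ValueError("caller buffer must not require grad")
    return buffer


cpu = torch.device("cpu")
caller_buffer = torch.empty(4, 8)
out = validate_or_alloc_output_sketch(caller_buffer, (4, 8), torch.float32, cpu)
fresh = validate_or_alloc_output_sketch(None, (4, 8), torch.float32, cpu)
```

When a valid buffer is passed, the returned tensor is the caller's tensor itself, so anything written into it is visible to the caller without a copy.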

Dependency

The in-place MXFP8 path requires a cuDNN frontend change that exposes an optional output tensor on the grouped-GEMM quant wrapper: NVIDIA/cudnn-frontend#338.

Testing

  • test_grouped_linear_caller_output_buffers: covers out/dgrad_out/both across the legacy and GroupedTensor paths, checking buffer aliasing, bit-exact match vs. internal allocation, untouched padded tail, and ValueError on shape mismatch.
  • test_grouped_linear_caller_buffers (ops) and test_grouped_mlp_caller_buffers: verify the fused op runs, the output aliases the buffer, and grads match internal allocation.
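
The properties these tests check (aliasing, bit-exact match, untouched padded tail) can be illustrated with a plain-PyTorch stand-in. This sketch does not use Transformer Engine: `grouped_matmul` is a hypothetical helper built on `torch.mm`, and the sizes and sentinel value are arbitrary.

```python
import torch

torch.manual_seed(0)

m_splits = [3, 2]      # valid rows per group
padded_rows = 8        # the caller buffer has room for 8 rows; only 5 are valid
in_features, out_features = 4, 6

# Integer-valued float32 data keeps the matmul exact regardless of summation order.
x = torch.randint(-3, 4, (sum(m_splits), in_features)).float()
weights = [torch.randint(-3, 4, (in_features, out_features)).float() for _ in m_splits]


def grouped_matmul(x, weights, m_splits, out=None):
    """Plain-PyTorch stand-in for a grouped GEMM, optionally writing into `out`."""
    if out is None:
        out = torch.empty(sum(m_splits), weights[0].shape[1])
    start = 0
    for rows, w in zip(m_splits, weights):
        # Each group's result is written into its row slice of `out`.
        torch.mm(x[start:start + rows], w, out=out[start:start + rows])
        start += rows
    return out


# Reference result with internal allocation.
reference = grouped_matmul(x, weights, m_splits)

# Caller-provided buffer, pre-filled with a sentinel to detect stray writes.
sentinel = -123.0
caller_buffer = torch.full((padded_rows, out_features), sentinel)
result = grouped_matmul(x, weights, m_splits, out=caller_buffer)

valid = sum(m_splits)
aliases_buffer = result.data_ptr() == caller_buffer.data_ptr()
matches_reference = torch.equal(caller_buffer[:valid], reference)
tail_untouched = bool((caller_buffer[valid:] == sentinel).all())
```

Pre-filling the buffer with a sentinel is one simple way to confirm that rows beyond sum(m_splits) are never written.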

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested a review from ksivaman as a code owner July 1, 2026 11:35
@phu0ngng phu0ngng requested a review from ptrendx July 1, 2026 11:37
@greptile-apps

greptile-apps Bot commented Jul 1, 2026

Contributor

Greptile Summary

This PR adds optional caller-provided preallocated buffers for the forward output (out) and backward grad-input (dgrad_out) in GroupedLinear and the fused grouped MLP, allowing symmetric-memory buffers to be fed directly into TE EP's combine without an intermediate D2D copy.

  • A shared validate_or_alloc_output helper is added in _common.py for shape/dtype/device/contiguity/grad validation, used consistently across both the module-level _GroupedLinear and the fused op paths.
  • Sequential.forward gains an op_kwargs parameter routing per-op keyword arguments to the correct basic-op slot inside an OperationFuser, backed by a _resolve_op_kwargs helper that walks module groups and aligns slots.
  • Both NVFP4 and MXFP8 fused-MLP forward/backward paths write directly into the caller buffer (MXFP8 via as_strided to match the cuDNN kernel's packed 3D layout); any padded tail rows beyond sum(m_splits) are intentionally left unchanged.
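
As a rough illustration of the `as_strided` idea, the sketch below builds a 3D view over a 2D caller buffer so that writes through the view land in the caller's memory. The `(1, rows, cols)` layout is an assumption chosen for illustration; the layout the cuDNN kernel actually expects is not stated here.

```python
import torch

rows, cols = 4, 6
caller_buffer = torch.zeros(rows, cols)

# Assumed layout for illustration only: a leading unit dimension over the same memory.
view_3d = torch.as_strided(
    caller_buffer, size=(1, rows, cols), stride=(rows * cols, cols, 1)
)

# Stand-in for a kernel writing its result through the 3D view.
values = torch.arange(rows * cols, dtype=torch.float32).reshape(1, rows, cols)
view_3d.copy_(values)

# The view shares storage with the caller buffer, so no copy back is needed.
shares_storage = view_3d.data_ptr() == caller_buffer.data_ptr()
```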

Confidence Score: 5/5

Safe to merge; the buffer aliasing, validation, and routing logic are correct across all code paths examined.

The core validate_or_alloc_output helper is consistent and well-guarded. The _resolve_op_kwargs group-index tracking aligns correctly with _module_groups across every module-layout combination. Both NVFP4 and MXFP8 forward/backward paths correctly alias or write into the caller buffer, and the OperationFuser safely handles basic_op_kwargs=None by expanding it to empty dicts. Tests cover bit-exact match, alias verification, padded-tail invariance, and shape-mismatch rejection. The only items flagged are a stale pylint suppression comment and missing NVFP4 fused-MLP test coverage — neither affects correctness.

No files require special attention; the minor items in ops/basic/grouped_linear.py and tests/pytorch/test_grouped_mlp.py are non-blocking.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/_common.py | Adds OUTPUT_BUFFER_KEY / GRAD_INPUT_BUFFER_KEY constants and validate_or_alloc_output helper; straightforward and correct. |
| transformer_engine/pytorch/ops/sequential.py | Adds op_kwargs to forward and _resolve_op_kwargs to route per-op kwargs into fuser basic-op slots; the group-index tracking logic is correct for all module-layout combinations. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds out/dgrad_out parameters to GroupedLinear.forward, a module-local _validate_or_alloc_output, and threads the buffers through both grouped-tensor and legacy paths correctly. |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Reads out_buffer and dgrad_out from basic_op_kwargs in fuser_forward and save_ctx; the pylint disable=unused-argument on basic_op_kwargs in save_ctx is now stale since the parameter is actively used. |
| transformer_engine/pytorch/ops/fused/grouped_mlp.py | Wires output_buffer and dgrad_out through both NVFP4 and MXFP8 forward/backward paths; MXFP8 uses as_strided to provide the caller buffer in the kernel's 3D packed format. NVFP4 caller-buffer path has no dedicated test in the fused-MLP test suite. |
| tests/pytorch/test_grouped_linear.py | Adds thorough parametrized tests covering out/dgrad_out/both on legacy and grouped-tensor paths, bit-exact match, padded-tail invariance, and shape-mismatch ValueError. |
| tests/pytorch/test_fusible_ops.py | Adds test_grouped_linear_caller_buffers covering two-op Sequential routing, alias verification, bit-exact match, and shape-mismatch rejection. |
| tests/pytorch/test_grouped_mlp.py | Adds test_grouped_mlp_caller_buffers for the MXFP8 fused path; the NVFP4 fused-MLP path with caller buffers is not covered. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant Sequential
    participant OperationFuser
    participant GroupedLinear_fc1
    participant GroupedLinear_fc2
    participant GroupedMLP_CuTeGEMM

    Caller->>Sequential: forward(x, split_sizes, op_kwargs)
    Sequential->>Sequential: _resolve_op_kwargs → group_kwargs per fuser
    Sequential->>OperationFuser: "forward(*xs, basic_op_kwargs)"
    OperationFuser->>OperationFuser: maybe_fuse_ops → GroupedMLP_CuTeGEMMGLU

    alt Fused MLP path
        OperationFuser->>GroupedMLP_CuTeGEMM: fuser_forward(basic_op_kwargs)
        GroupedMLP_CuTeGEMM->>GroupedMLP_CuTeGEMM: "output_buffer = basic_op_kwargs[-1].get(output)"
        GroupedMLP_CuTeGEMM->>GroupedMLP_CuTeGEMM: "fc1_ctx.dgrad_out = basic_op_kwargs[0].get(grad_input)"
        GroupedMLP_CuTeGEMM->>GroupedMLP_CuTeGEMM: GEMM writes into output_buffer (no copy)
        GroupedMLP_CuTeGEMM-->>Caller: fc2_out aliasing output_buffer
    else Unfused sequential path
        OperationFuser->>GroupedLinear_fc1: fuser_forward(basic_op_kwargs[0])
        GroupedLinear_fc1->>GroupedLinear_fc1: save dgrad_out on ctx
        OperationFuser->>GroupedLinear_fc2: fuser_forward(basic_op_kwargs[1])
        GroupedLinear_fc2->>GroupedLinear_fc2: validate_or_alloc_output(out_buffer)
        GroupedLinear_fc2-->>Caller: out aliasing out_buffer
    end

    Caller->>Sequential: backward(dy)
    Sequential->>GroupedMLP_CuTeGEMM: fuser_backward
    GroupedMLP_CuTeGEMM->>GroupedMLP_CuTeGEMM: fc1_ctx.dgrad_out → validate_or_alloc_output
    GroupedMLP_CuTeGEMM->>GroupedMLP_CuTeGEMM: GEMM dgrad into grad_input_buffer
    GroupedMLP_CuTeGEMM-->>Caller: grad_input aliasing grad_input_buffer

Reviews (5): Last reviewed commit: "Use 256-aligned splits in caller-buffer ..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py
@pggPL pggPL self-requested a review July 1, 2026 12:59
@phu0ngng phu0ngng requested a review from timmoon10 as a code owner July 1, 2026 16:07
@phu0ngng phu0ngng marked this pull request as draft July 1, 2026 16:08
Comment on lines +175 to +176
output: Optional[torch.Tensor] = None,
grad_input: Optional[torch.Tensor] = None,

@timmoon10 timmoon10 Jul 1, 2026

Member

I'm skeptical of this API because it looks general, but is actually only supported with grouped linear and grouped MLP. How about something like the following:

def forward(
    self,
    input: torch.Tensor,
    *extra_inputs: torch.Tensor,
    extra_kwargs: Optional[Sequence[Optional[dict[str, Any]]]] = None,
):
    ...

    if extra_kwargs is None:
        extra_kwargs = [None] * len(self)

    # Figure out what extra_kwargs go to each module group, and fail if a non-None extra_kwargs goes to a non-basic op

    for module_group in self._module_groups:
        xs = module_group(*xs, basic_op_kwargs=module_group_basic_op_kwargs)

Collaborator Author

Thanks, Tim. Good call: the output and grad_input kwargs overpromised generality. Building on your suggestion, I went with a generic per-op kwargs pass-through, but keyed by module or index instead of a positional list:

seq(x, split_sizes, probs, split_sizes,
    op_kwargs={fc1: {GRAD_INPUT_BUFFER_KEY: g}, fc2: {OUTPUT_BUFFER_KEY: o}})

Sequential resolves each key to its group and basic-op slot and forwards it as basic_op_kwargs, with no new dispatch path, and raises if a key targets a non-fusible op. This keeps the container kwarg-agnostic, so the buffers are now just op-specific kwargs, while avoiding None padding and positional coupling.

Let me know what you think about this new design.
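
A simplified, pure-Python sketch of this routing is shown below. It is an illustration under stated assumptions rather than the PR's `_resolve_op_kwargs`: the classes `BasicOp` and `PlainModule`, the function `resolve_op_kwargs`, and the string keys `"grad_input"` / `"output"` are all hypothetical, and the group/slot bookkeeping is flattened into one kwargs dict per module position.

```python
from typing import Any, Dict, List, Sequence, Union


class BasicOp:
    """Hypothetical fusible op that accepts per-op kwargs."""

    def __init__(self, name: str) -> None:
        self.name = name


class PlainModule:
    """Hypothetical non-fusible module that cannot receive per-op kwargs."""


def resolve_op_kwargs(
    modules: Sequence[Any],
    op_kwargs: Dict[Union[int, Any], Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Map {module-or-index: kwargs} to one kwargs dict per module position."""
    resolved: List[Dict[str, Any]] = [{} for _ in modules]
    for key, kwargs in op_kwargs.items():
        if isinstance(key, int):
            # Integer keys address a position, with Python-style negative indexing.
            if not -len(modules) <= key < len(modules):
                raise IndexError(f"op index {key} is out of range")
            index = key % len(modules)
        else:
            # Module keys are matched by identity.
            matches = [i for i, module in enumerate(modules) if module is key]
            if not matches:
                raise KeyError("op_kwargs key is not a module in this container")
            index = matches[0]
        if not isinstance(modules[index], BasicOp):
            raise ValueError(f"module at index {index} does not accept op kwargs")
        resolved[index].update(kwargs)
    return resolved


fc1, act, fc2 = BasicOp("fc1"), BasicOp("act"), BasicOp("fc2")
resolved = resolve_op_kwargs(
    [fc1, act, fc2],
    {fc1: {"grad_input": "g"}, 2: {"output": "o"}},
)
```

Ops that are not named in the mapping receive an empty dict, which is how this sketch avoids None padding for ops that need no extra kwargs.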

# when the value above was materialized into it; otherwise it is copied.
if output_buffer is not None:
    output_buffer = validate_or_alloc_output(output_buffer, fc2_out_shape, dtype, device)
    if output_buffer.data_ptr() != fc2_out.data_ptr():


In what case can this condition be false? It seems to me that it is always true, since I do not see where we feed the user-provided buffer to the kernel.

Collaborator Author

You were right that it was always true for MXFP8, which uses CuTe GEMM.

I pushed a new commit that reworks this so that it does not trigger extra copies: the caller buffer is passed straight into the CuTe GEMM via d_tensor, so the kernel writes output/dgrad in place for both NVFP4 and MXFP8. This needs a cuDNN FE change, for which I opened a PR: NVIDIA/cudnn-frontend#338. Until it lands and the bundled FE is bumped, the in-place MXFP8 path can't run, so for testing we will need to check out this cuDNN FE manually.

@phu0ngng phu0ngng changed the title from "[PyTorch] Add optional caller-provided out/dgrad_out buffers to GroupedLinear" to "[PyTorch] Add optional caller-provided output/grad-input buffers to GroupedLinear and fused grouped MLP" Jul 2, 2026
@phu0ngng phu0ngng marked this pull request as ready for review July 2, 2026 14:16
phu0ngng and others added 4 commits July 2, 2026 07:33
…ar module and fusible ops

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ping

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…uffers, eliminating the D2D copy + cleanup

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
pre-commit-ci Bot and others added 2 commits July 2, 2026 14:35

3 participants