Skip to content

[PyTorch] Add joint forward-backward op fusion pass#3080

Merged
timmoon10 merged 3 commits into
NVIDIA:mainfrom
timmoon10:tmoon/joint-forward-backward-fusions
Jun 9, 2026
Merged

[PyTorch] Add joint forward-backward op fusion pass#3080
timmoon10 merged 3 commits into
NVIDIA:mainfrom
timmoon10:tmoon/joint-forward-backward-fusions

Conversation

@timmoon10

Copy link
Copy Markdown
Member

Description

The op fuser assumes that forward fusions and backward fusions can be applied independently, so we enforce a contract that fused ops are interchangeable with the corresponding unfused ops. However, in the process of developing the grouped MLP fused ops, we have identified several optimizations that only make sense when the forward and backward fusions are performed together (e.g. recomputing FC2 input tensors instead of caching for backward).

This PR adds infrastructure to support joint forward-backward fusions, which relax the interchangeability contract so the forward and backward passes can have coupled implementations. Refactoring the grouped MLP ops as such a joint fused op will be deferred to a follow-up PR.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add joint forward-backward op fusion pass

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Introduce a third operation fusion pass for joint forward-backward
fusions, applied before the forward-only and backward-only passes. A
joint fused op implements both fuser_forward and fuser_backward, so the
two halves can cooperate (e.g. the forward saving reduced state that
only its own backward knows how to recompute) and need not be
individually interchangeable with the unfused ops.

Add register_forward_backward_fusion, split the fusion application and
basic-op reconciliation in OperationFuser so the forward/backward passes
build on the joint grouping, and spell out the interchangeability
contracts in the register_*_fusion docstrings. Add a custom joint
fusion unit test and a user guide section.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 requested a review from vthumbe1503 June 4, 2026 00:27
@timmoon10 timmoon10 added the enhancement New feature or request label Jun 4, 2026
@greptile-apps

greptile-apps Bot commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds infrastructure for joint forward-backward op fusion in TransformerEngine's PyTorch op fuser. Previously, forward and backward fusions were applied independently starting from the same list of basic ops; now a third "joint" pass runs first and produces a shared joint_ops list from which both the forward-only and backward-only fusion passes are seeded.

  • Splits the old _fuse_ops classmethod into two static helpers — _apply_fusions (runs a sequence of fusion functions) and _map_to_basic_ops (validates and annotates the result) — and adds a new three-stage pipeline in maybe_fuse_ops.
  • Adds register_forward_backward_fusion, a new public API that appends a fusion function to the new OperationFuser.forward_backward_fusion_functions class-level registry; joint fusions are applied before the existing forward-only and backward-only passes so the resulting FusedOperation is referenced by both _forward_ops and _backward_ops.
  • Strengthens bounds checks in _map_to_basic_ops (idx >= len(basic_ops) guards) to prevent silent IndexError on malformed fusion output, and adds a test, documentation, and API export.

Confidence Score: 5/5

The change is safe to merge: it adds a new optional fusion pass that is a no-op when no joint fusion functions are registered, leaving all existing forward-only and backward-only fusion behavior unchanged.

The refactoring is narrowly scoped — it introduces one new class-level registry list, splits a single classmethod into two well-tested statics, and adds the three-stage pipeline in maybe_fuse_ops. When forward_backward_fusion_functions is empty (the default), joint_ops is identical to basic_ops, so existing paths are unaffected. The test verifies both numerical correctness and the critical invariant that the same FusedOperation object is referenced by both _forward_ops and _backward_ops.

No files require special attention. The core logic in fuser.py has been carefully refactored with no regressions to existing behavior.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fuser.py Core change: _fuse_ops classmethod split into _apply_fusions + _map_to_basic_ops statics; maybe_fuse_ops now runs joint fusions first. New register_forward_backward_fusion public API added. Bounds guards in _map_to_basic_ops are a good defensive improvement.
transformer_engine/pytorch/ops/op.py Updated FusedOperation docstring to explain the new equivalence-contract distinction between forward-only/backward-only fusions and joint fusions. No logic changes; doc update is accurate.
tests/pytorch/test_fusible_ops.py Adds test_custom_forward_backward_fused_op demonstrating a joint LinearSiLU op that recomputes the SiLU input in its backward instead of caching it. Verifies both numerical correctness and that the same FusedOperation object appears in both forward and backward op lists.
docs/examples/op_fuser/op_fuser.rst New "Joint forward-backward fusions" section with a worked LinearSiLU example. Docstring update to the existing interchangeability warning is accurate.
transformer_engine/pytorch/ops/init.py Adds register_forward_backward_fusion to the public exports in alphabetical order.
docs/api/pytorch.rst One-line autoapifunction directive for register_forward_backward_fusion inserted alongside the other register functions.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["basic_ops (list of BasicOperation)"] --> B["_apply_fusions\n(forward_backward_fusion_functions)"]
    B --> C["joint_ops\n(BasicOp | JointFusedOp)"]
    C --> D["_apply_fusions\n(forward_fusion_functions)"]
    C --> E["_apply_fusions\n(backward_fusion_functions)"]
    D --> F["_map_to_basic_ops"]
    E --> G["_map_to_basic_ops"]
    F --> H["_forward_ops\n(op, basic_op_idxs)"]
    G --> I["_backward_ops\n(op, basic_op_idxs)"]
    H --> J["Forward pass"]
    I --> K["Backward pass"]
    J <-- "shared JointFusedOp instance" --> K
Loading

Reviews (2): Last reviewed commit: "Review suggestion from @vthumbe1503" | Re-trigger Greptile

Comment on lines +5185 to +5194
dsilu = s * (1 + y * (1 - s))
dy = dout * dsilu

# Linear backward
dx = torch.matmul(dy, w).to(dtype=dtype)
dw = torch.matmul(dy.T, x).to(dtype=dtype)

# grad_input, grad params per basic op, grad extra inputs per basic op
return dx, [(dw,), ()], [(), ()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _enabled flag silently drops the joint fusion on re-fuse

CustomLinearSiLU._enabled is set to False after the first call to fuse_ops, so if maybe_fuse_ops is triggered a second time (e.g. recipe type changes, first_op_requiring_backward shifts, or amax-history length changes), the fusion function returns the ops list unchanged. On that re-run _forward_ops and _backward_ops would revert to the two unfused basic ops, causing the assert isinstance(forward_ops[0][0], CustomLinearSiLU) assertions below to fail silently or with a confusing error rather than a clear "joint fusion was not reapplied" message. The current test is safe because it only calls the model once, but the pattern is fragile: any future extension that adds a second forward call (e.g., to test different recipe configurations) will break without an obvious explanation. Consider either resetting _enabled at the start of each maybe_fuse_ops-triggering call, or restructuring the fuse function to be idempotent (fuse only if the first op is not already a CustomLinearSiLU).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is quick and hacky. The right fix would be a way to unregister fusions, but that's outside the scope of this PR.

Comment on lines +5196 to +5204
def fuse_ops(
ops: list[FusibleOperation],
**unused,
) -> list[FusibleOperation]:
"""Apply fusion the first time this function is called"""
if CustomLinearSiLU._enabled:
CustomLinearSiLU._enabled = False
op = CustomLinearSiLU(linear=ops[0], silu=ops[1])
return [op] + ops[2:]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 fuse_ops indexes ops without bounds or type guards

When _enabled is True, the function unconditionally accesses ops[0] and ops[1] and constructs CustomLinearSiLU(linear=ops[0], silu=ops[1]) without checking len(ops) >= 2 or that the ops are the expected types. If a future pipeline change reduces the number of basic ops to fewer than two (or the joint-fusion pass is called earlier in a different context), this raises an uncaught IndexError with no diagnostic message. Compared to the documented sliding-window pattern in op_fuser.rst — which uses isinstance checks before fusing — the test's fuse_ops skips these guards entirely, making it a less reliable reference for users adapting this code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is quick and hacky. This function prioritizes simplicity over robustness.

@timmoon10

Copy link
Copy Markdown
Member Author

/te-ci pytorch

vthumbe1503
vthumbe1503 previously approved these changes Jun 8, 2026

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment thread transformer_engine/pytorch/ops/fuser.py Outdated
Comment thread transformer_engine/pytorch/ops/fuser.py Outdated
Comment thread transformer_engine/pytorch/ops/fuser.py Outdated
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10

Copy link
Copy Markdown
Member Author

/te-ci pytorch

@timmoon10 timmoon10 merged commit 2a30d03 into NVIDIA:main Jun 9, 2026
19 of 25 checks passed
@timmoon10 timmoon10 deleted the tmoon/joint-forward-backward-fusions branch June 9, 2026 01:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants