[PyTorch] Add joint forward-backward op fusion pass#3080
Conversation
Introduce a third operation fusion pass for joint forward-backward fusions, applied before the forward-only and backward-only passes. A joint fused op implements both fuser_forward and fuser_backward, so the two halves can cooperate (e.g. the forward saving reduced state that only its own backward knows how to recompute) and need not be individually interchangeable with the unfused ops. Add register_forward_backward_fusion, split the fusion application and basic-op reconciliation in OperationFuser so the forward/backward passes build on the joint grouping, and spell out the interchangeability contracts in the register_*_fusion docstrings. Add a custom joint fusion unit test and a user guide section. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile SummaryThis PR adds infrastructure for joint forward-backward op fusion in TransformerEngine's PyTorch op fuser. Previously, forward and backward fusions were applied independently starting from the same list of basic ops; now a third "joint" pass runs first and produces a shared
Confidence Score: 5/5The change is safe to merge: it adds a new optional fusion pass that is a no-op when no joint fusion functions are registered, leaving all existing forward-only and backward-only fusion behavior unchanged. The refactoring is narrowly scoped — it introduces one new class-level registry list, splits a single classmethod into two well-tested statics, and adds the three-stage pipeline in No files require special attention. The core logic in Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["basic_ops (list of BasicOperation)"] --> B["_apply_fusions\n(forward_backward_fusion_functions)"]
B --> C["joint_ops\n(BasicOp | JointFusedOp)"]
C --> D["_apply_fusions\n(forward_fusion_functions)"]
C --> E["_apply_fusions\n(backward_fusion_functions)"]
D --> F["_map_to_basic_ops"]
E --> G["_map_to_basic_ops"]
F --> H["_forward_ops\n(op, basic_op_idxs)"]
G --> I["_backward_ops\n(op, basic_op_idxs)"]
H --> J["Forward pass"]
I --> K["Backward pass"]
J <-- "shared JointFusedOp instance" --> K
Reviews (2): Last reviewed commit: "Review suggestion from @vthumbe1503" | Re-trigger Greptile |
| dsilu = s * (1 + y * (1 - s)) | ||
| dy = dout * dsilu | ||
|
|
||
| # Linear backward | ||
| dx = torch.matmul(dy, w).to(dtype=dtype) | ||
| dw = torch.matmul(dy.T, x).to(dtype=dtype) | ||
|
|
||
| # grad_input, grad params per basic op, grad extra inputs per basic op | ||
| return dx, [(dw,), ()], [(), ()] | ||
|
|
There was a problem hiding this comment.
_enabled flag silently drops the joint fusion on re-fuse
CustomLinearSiLU._enabled is set to False after the first call to fuse_ops, so if maybe_fuse_ops is triggered a second time (e.g. recipe type changes, first_op_requiring_backward shifts, or amax-history length changes), the fusion function returns the ops list unchanged. On that re-run _forward_ops and _backward_ops would revert to the two unfused basic ops, causing the assert isinstance(forward_ops[0][0], CustomLinearSiLU) assertions below to fail silently or with a confusing error rather than a clear "joint fusion was not reapplied" message. The current test is safe because it only calls the model once, but the pattern is fragile: any future extension that adds a second forward call (e.g., to test different recipe configurations) will break without an obvious explanation. Consider either resetting _enabled at the start of each maybe_fuse_ops-triggering call, or restructuring the fuse function to be idempotent (fuse only if the first op is not already a CustomLinearSiLU).
There was a problem hiding this comment.
The test is quick and hacky. The right fix would be a way to unregister fusions, but that's outside the scope of this PR.
| def fuse_ops( | ||
| ops: list[FusibleOperation], | ||
| **unused, | ||
| ) -> list[FusibleOperation]: | ||
| """Apply fusion the first time this function is called""" | ||
| if CustomLinearSiLU._enabled: | ||
| CustomLinearSiLU._enabled = False | ||
| op = CustomLinearSiLU(linear=ops[0], silu=ops[1]) | ||
| return [op] + ops[2:] |
There was a problem hiding this comment.
fuse_ops indexes ops without bounds or type guards
When _enabled is True, the function unconditionally accesses ops[0] and ops[1] and constructs CustomLinearSiLU(linear=ops[0], silu=ops[1]) without checking len(ops) >= 2 or that the ops are the expected types. If a future pipeline change reduces the number of basic ops to fewer than two (or the joint-fusion pass is called earlier in a different context), this raises an uncaught IndexError with no diagnostic message. Compared to the documented sliding-window pattern in op_fuser.rst — which uses isinstance checks before fusing — the test's fuse_ops skips these guards entirely, making it a less reliable reference for users adapting this code.
There was a problem hiding this comment.
The test is quick and hacky. This function prioritizes simplicity over robustness.
|
/te-ci pytorch |
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
|
/te-ci pytorch |
Description
The op fuser assumes that forward fusions and backward fusions can be applied independently, so we enforce a contract that fused ops are interchangeable with the corresponding unfused ops. However, in the process of developing the grouped MLP fused ops, we have identified several optimizations that only make sense when the forward and backward fusions are performed together (e.g. recomputing FC2 input tensors instead of caching for backward).
This PR adds infrastructure to support joint forward-backward fusions, which relax the interchangeability contract so the forward and backward passes can have coupled implementations. Refactoring the grouped MLP ops as such a joint fused op will be deferred to a follow-up PR.
Type of change
Changes
Checklist: