Skip to content

[PyTorch] Add torch.compile custom-op path for Linear#9

Open
pggPL wants to merge 6 commits into
tensor_proto_mechanismfrom
linear_compile
Open

[PyTorch] Add torch.compile custom-op path for Linear#9
pggPL wants to merge 6 commits into
tensor_proto_mechanismfrom
linear_compile

Conversation

@pggPL

@pggPL pggPL commented Jun 7, 2026

Copy link
Copy Markdown
Owner

Build a torch.compile custom-op framework in dynamo.py that traces Linear forward+backward as single graph nodes (no graph break into the eager autograd.Function):

  • OpaqueSimpleMetadata bundle (with nested-dict support) and per-field buckets mapping the LinearFwd/BwdArgs dataclasses to op schema slots; quantizers ride through as value-opaque objects (own slot each).
  • Fakes stay TensorProto -> TensorProto; the framework converts tensor inputs to TensorProto at the boundary (via dataclasses.replace, which Dynamo can trace) and materializes output protos into fake tensors in register_fake. Saved-tensor aliases resolved by name from ctx_attrs.
  • Two-tier op + register_torch_dispatch flattens Float8Tensor weight inputs; quantized outputs are rebuilt via _ToSubclassFn (autograd-aware).
  • Mirror _linear_forward_impl set_usage / input+weight pipeline in the forward fake so register_fake layout matches eager.
  • Replace LinearBwdArgs.fp8_recipe with precomputed split-accumulator bools (recipe object is not compile-safe across the op boundary).
  • Dispatch through the op under torch.compiler.is_compiling(); drop @no_torch_dynamo from Linear.forward.

Tests: test_te_linear_compiles (bf16 + every recipe), quantized FP8 weight input. Backward through a Float8Tensor output is a strict xfail (AOTAutograd demands a subclass cotangent and the linear backward has no FP8-cotangent path).

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 2cccc30 to 2e252f9 Compare June 16, 2026 15:31
@pggPL pggPL requested a review from cyanguwa as a code owner June 16, 2026 15:31
pggPL added a commit that referenced this pull request Jun 16, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL added a commit that referenced this pull request Jun 16, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL added a commit that referenced this pull request Jun 16, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from 2e252f9 to ba92f5b Compare June 16, 2026 16:05
pggPL added a commit that referenced this pull request Jun 16, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the linear_compile branch 2 times, most recently from 9a81f4a to dcf5b2f Compare June 16, 2026 16:06
@pggPL pggPL force-pushed the tensor_proto_mechanism branch from ba92f5b to b1273ea Compare June 16, 2026 16:12
pggPL added a commit that referenced this pull request Jun 16, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL added a commit that referenced this pull request Jun 22, 2026
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

@kshitij12345 kshitij12345 left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Review the custom_op.py file.

torch._dynamo.reset()
# dynamic=False for now: a symbolic shape would land in an OpaqueValueBundle
# (value-opaque op arg) whose hash chokes on non-nested SymInt. Force static
# shapes (recompile per shape) until the bundle handles symbolic shapes.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dynamic=False for now: What happens with dynamic=True? Is it a hard error Or a graph break?

if isinstance(input_quantizer, NVFP4Quantizer):
rht_matrix = get_rht_matrix(
input_quantizer._with_random_sign_mask, inp.device.index
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious to know, are these still required if the warm-up allocated these tensors so that these are not allocated on CUDAGraph's private memory-pool and hence the CUDAGraph Trees doesn't error out due to allocated tensor not being returned (/silently ignored).


_TE_OP_NAMESPACE = "transformer_engine_compile"


Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: It would probably be good to have a comment explain the abstractions in this file, their purpose and how to fit together.

"""Whether ``value`` may be stored inside an instance (recursive)."""
if isinstance(value, cls.PRIMITIVE_TYPES):
return True
if isinstance(value, Enum):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should Enum be a primitive?


# Workaround for PyTorch issue: FxGraphCachePickler handles FakeScriptObject
# but not the real ProcessGroup that appears in example_inputs at inductor
# compile time. Register a copyreg reducer so the pickler can hash the key.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: Should this be supported by PyTorch?

_register_autograd_for_op(
fwd_op=inner_fwd_def, bwd_op_name=inner_bwd_name, **autograd_common
)
_register_autograd_for_op(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: does the outer op required grad rule?

"""
raise NotImplementedError

def pack(self, owner: Any) -> List[Tuple[str, Any]]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be called flatten (or maybe unpack). pack seems like we are going to pack objects into one single object but this seems to do the opposite.

@pggPL pggPL force-pushed the tensor_proto_mechanism branch 9 times, most recently from 29e5245 to 99c1377 Compare June 29, 2026 13:10
@pggPL pggPL force-pushed the tensor_proto_mechanism branch 3 times, most recently from 9e78a6c to 50c11cd Compare June 29, 2026 13:46
pggPL and others added 6 commits June 30, 2026 11:38
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
(cherry picked from commit bfce3a7)
for more information, see https://pre-commit.ci

(cherry picked from commit afe364b)
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
(cherry picked from commit 22f80e4)
Squashed PR #9 (linear_compile) onto the rebased base.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
(cherry picked from commit 84dbc6b)
…tom_op

Replace the low-level torch.library.Library (_TE_LIB.define/.impl + functional register_fake/register_autograd/register_torch_dispatch with lib=) with the standard torch.library.custom_op API, passing the dynamically built schema explicitly via schema=. register_fake/register_autograd/register_torch_dispatch are now methods on the returned CustomOpDef. Drops the TOR901 Library usage and is robust to re-registration (get_library_allowing_overwrite).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
(cherry picked from commit d590560)
…e sentinel

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants