[PyTorch][torch.compile] Replace TensorProto with make_empty_traceable #13

Open
kshitij12345 wants to merge 3 commits into pggPL:linear_compile from kshitij12345:linear_compile_no_proto

@kshitij12345 kshitij12345 commented Jul 2, 2026

Replace the 172-line TensorProto dataclass with a single function make_empty_traceable(quantizer, shape, dtype, device) that directly allocates traceable quantized tensors. The fake impls now return actual tensors (which become FakeTensors under register_fake) instead of intermediate descriptors.

The key insight: make_empty_traceable stashes _te_flat_names and _te_flat_ctx on the resulting tensor. Dynamo treats non-callable attributes on traceable wrapper subclasses as constant metadata, so forward_fn can read slot counts and reassembly info from these attributes without calling __tensor_flatten__ (which would cause a graph break, since it returns non-Tensor Python objects).
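
A minimal sketch of the attribute-stashing idea, for readers who want to see the shape of it. It uses a plain-Python stand-in and does not import PyTorch or Transformer Engine; `StandInQuantizedTensor`, `make_empty_standin`, `slot_count`, and `reassemble` are hypothetical names for illustration, and only the attribute names `_te_flat_names` and `_te_flat_ctx` come from this PR. The real make_empty_traceable may differ in every other detail.

```python
from typing import Any, Dict, List, Sequence, Tuple


class StandInQuantizedTensor:
    """Hypothetical stand-in for a traceable quantized tensor (not a TE class)."""

    def __init__(self, data: List[float], scale_inv: List[float], ctx: Dict[str, Any]):
        self._data = data
        self._scale_inv = scale_inv
        self._ctx = ctx

    def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]:
        # Same shape as the wrapper-subclass protocol: the names of the inner
        # tensor fields, plus a non-tensor context object.
        return ["_data", "_scale_inv"], self._ctx


def make_empty_standin(shape: Sequence[int], dtype: str) -> StandInQuantizedTensor:
    """Allocate an empty stand-in and stash its flatten metadata on it."""
    numel = 1
    for dim in shape:
        numel *= dim
    t = StandInQuantizedTensor([0.0] * numel, [1.0], {"shape": tuple(shape), "dtype": dtype})
    # Flatten once, at allocation time, and keep the result as plain attributes.
    names, ctx = t.__tensor_flatten__()
    t._te_flat_names = tuple(names)
    t._te_flat_ctx = ctx
    return t


def slot_count(t: StandInQuantizedTensor) -> int:
    # Reads the stashed attribute only; never calls __tensor_flatten__.
    return len(t._te_flat_names)


def reassemble(flat_values: Sequence[Any], like: StandInQuantizedTensor) -> StandInQuantizedTensor:
    """Rebuild a stand-in from flat values using only the stashed metadata."""
    names = like._te_flat_names
    if len(flat_values) != len(names):
        raise ValueError(f"expected {len(names)} values, got {len(flat_values)}")
    out = StandInQuantizedTensor.__new__(StandInQuantizedTensor)
    for name, value in zip(names, flat_values):
        setattr(out, name, value)
    out._ctx = like._te_flat_ctx
    out._te_flat_names = names
    out._te_flat_ctx = like._te_flat_ctx
    return out
```

In this sketch, code on the traced path would call only `slot_count` and `reassemble`, both of which touch plain attributes rather than the flatten method.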

This eliminates:

  • TensorProto class and to_tensor_proto helper (tensor_proto.py deleted)
  • _proto_view (converted tensor fields to TensorProto before fake impls)
  • _tensor_field_names (identified fields for _proto_view)
  • _proto_slot_count / _proto_reassemble (operated on TensorProto objects)
  • TensorProto branch in _value_to_flat_tensors and _format_bwd_result

The fake impls in linear.py now use:

  • isinstance(inp, QuantizedTensorStorage) instead of inp.is_quantized
  • weight._quantizer instead of weight.quantizer (TensorProto field)
  • make_empty_traceable(...) instead of TensorProto(...)
  • Direct set_usage on quantizer instead of proto.update_usage()

Test Plan:

python -m pytest tests/pytorch/test_torch_compile.py -v

Authored with Claude.

@kshitij12345 marked this pull request as draft on July 2, 2026 10:35
@kshitij12345 force-pushed the linear_compile_no_proto branch 2 times, most recently from bda5d6c to 8877da0, on July 2, 2026 10:42
@kshitij12345 marked this pull request as ready for review on July 2, 2026 10:53
@kshitij12345 force-pushed the linear_compile_no_proto branch from 8877da0 to 3bd5ab5 on July 2, 2026 13:35
Replace the 172-line TensorProto dataclass with a single function
make_empty_traceable(quantizer, shape, dtype, device) that directly allocates
traceable quantized tensors. The fake impls now return actual tensors (which
become FakeTensors under register_fake) instead of intermediate descriptors.

The key insight: make_empty_traceable stashes _te_flat_names and _te_flat_ctx
on the resulting tensor. Dynamo treats non-callable attributes on traceable
wrapper subclasses as constant metadata, so forward_fn can read slot counts
and reassembly info from these attributes without calling __tensor_flatten__
(which would cause a graph break since it returns non-Tensor Python objects).

This eliminates:
- TensorProto class and to_tensor_proto helper (tensor_proto.py deleted)
- _proto_view (converted tensor fields to TensorProto before fake impls)
- _tensor_field_names (identified fields for _proto_view)
- _proto_slot_count / _proto_reassemble (operated on TensorProto objects)
- TensorProto branch in _value_to_flat_tensors and _format_bwd_result

The fake impls in linear.py now use:
- isinstance(inp, QuantizedTensorStorage) instead of inp.is_quantized
- weight._quantizer instead of weight.quantizer (TensorProto field)
- make_empty_traceable(...) instead of TensorProto(...)
- Direct set_usage on quantizer instead of proto.update_usage()

Test Plan:

```
python -m pytest tests/pytorch/test_torch_compile.py -v -k 'not nvfp4'
```

Authored with Claude.
@kshitij12345 force-pushed the linear_compile_no_proto branch from 3bd5ab5 to b24d259 on July 2, 2026 13:37
# TODO: understand why Dynamo does not recognize the quantizer retrieved via
# t._quantizer as the same value-opaque type it would if captured from a
# closure. If that is fixed upstream, the stashed attributes become
# unnecessary and we could compute slot counts directly from the quantizer.

@kshitij12345 kshitij12345 Jul 2, 2026

PyTorch Issue: pytorch/pytorch#188796

quantizer.optimize_for_gemm = self.optimize_for_gemm
quantizer.rht_matrix = self.rht_matrix
quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t
if not torch.compiler.is_compiling():

TODO: Understand this better

@kshitij12345 kshitij12345 Jul 3, 2026

This happens because NVFP4Quantizer is an OpaqueObject, yet we try to attach a Tensor to it.

…der is_compiling

Under Dynamo tracing, rht_matrix is a FakeTensor attached to an opaque
script object. Accessing it in copy() triggers SourcelessBuilder which
cannot wrap FakeTensor, causing an InternalTorchDynamoError.

The fake impl never runs real quantization, so rht_matrix is unnecessary
during tracing. Guard the tensor field copies with
torch.compiler.is_compiling() -- the matrix will be rebuilt lazily via
_rebuild_derived_state if the quantizer is later used outside tracing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
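
The guard described in the commit message above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in this PR: `StandInQuantizer` is a hypothetical class, and the compile check is passed in as a callable so the sketch runs without PyTorch installed. Real code would call `torch.compiler.is_compiling()` directly, as in the diff context shown earlier.

```python
from typing import Any, Callable, Optional


def _not_compiling() -> bool:
    # Stand-in for torch.compiler.is_compiling(); always False outside tracing.
    return False


class StandInQuantizer:
    """Hypothetical quantizer with one scalar field and one tensor-like field."""

    def __init__(self, optimize_for_gemm: bool = False, rht_matrix: Optional[Any] = None):
        self.optimize_for_gemm = optimize_for_gemm
        self.rht_matrix = rht_matrix

    def copy(self, is_compiling: Callable[[], bool] = _not_compiling) -> "StandInQuantizer":
        quantizer = StandInQuantizer()
        # Plain Python fields are always safe to copy.
        quantizer.optimize_for_gemm = self.optimize_for_gemm
        # Tensor fields are copied only outside tracing. While tracing they
        # are left unset, on the assumption that they can be rebuilt later.
        if not is_compiling():
            quantizer.rht_matrix = self.rht_matrix
        return quantizer
```

Calling `copy()` with a predicate that returns True leaves `rht_matrix` as None on the copy, which is the state a lazy rebuild step would then have to handle.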
@kshitij12345 force-pushed the linear_compile_no_proto branch from f506cb1 to c33cd00 on July 2, 2026 15:17
The C++ quantize kernel requires with_post_rht_amax=True when with_rht
is enabled. The test factory was creating an NVFP4Quantizer with
with_rht=True but with_post_rht_amax defaulting to False, causing
'Pre-RHT amax is not supported yet' at quantize time.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
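
The flag constraint described in the commit message above could be checked before construction along these lines. Both helpers are hypothetical and exist only for illustration; the flag names `with_rht` and `with_post_rht_amax` come from the commit message, while the actual NVFP4Quantizer constructor signature and the real test factory are not shown in this thread.

```python
from typing import Dict


def check_rht_flags(with_rht: bool, with_post_rht_amax: bool) -> None:
    """Fail early instead of waiting for the kernel to reject the combination."""
    if with_rht and not with_post_rht_amax:
        raise ValueError(
            "with_rht=True requires with_post_rht_amax=True "
            "(pre-RHT amax is not supported yet)"
        )


def make_test_quantizer_kwargs(with_rht: bool) -> Dict[str, bool]:
    """Hypothetical factory helper: keep the two flags in step with each other."""
    kwargs = {"with_rht": with_rht, "with_post_rht_amax": with_rht}
    check_rht_flags(**kwargs)
    return kwargs
```

Tying `with_post_rht_amax` to `with_rht` in one place is a design choice of this sketch; it means a test cannot enable one flag and forget the other.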