diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index a0501cca8e..87dfd171f6 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -32,7 +32,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensor, _STORAGE_REGISTRY -from transformer_engine.pytorch.dynamo import TensorProto, to_tensor_proto +from transformer_engine.pytorch.dynamo.traceable_utils import make_empty_traceable, _contiguous_stride from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, @@ -499,6 +499,7 @@ def _nvfp4(with_rht=True): rowwise=True, columnwise=True, with_rht=with_rht, + with_post_rht_amax=with_rht, ) @@ -640,7 +641,7 @@ def fn(inp): # --------------------------------------------------------------------------- -# torch.compile-traceable allocation primitives + TensorProto +# torch.compile-traceable allocation primitives # --------------------------------------------------------------------------- @@ -665,9 +666,7 @@ def fn(inp): def _build_from_primitives(quantizer, shape, dtype, device="cpu"): """Assemble a quantized tensor straight from the quantizer primitives: ``alloc_tensors`` (buffers) + ``create_metadata`` (ctx) + the storage's - ``__tensor_unflatten__`` -- i.e. exactly what ``TensorProto.create_tensor`` - does, but without going through :class:`TensorProto`. - """ + ``__tensor_unflatten__``.""" names = tuple(quantizer._describe_buffers(shape)) # pylint: disable=protected-access ctx = quantizer.create_metadata(shape, dtype=dtype) buffers = quantizer.alloc_tensors(shape, device=device) @@ -713,7 +712,7 @@ def _skip_if_dequantize_unsupported(q): @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) def test_primitives_unflatten_compiles(factory, shape): """create_metadata + alloc_tensors + __tensor_unflatten__ compose and trace - under ``fullgraph=True`` (CPU), without TensorProto.""" + under ``fullgraph=True`` (CPU).""" q = factory() names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access @@ -779,78 +778,64 @@ def test_storage_flatten_unflatten_roundtrip(factory, shape): torch.testing.assert_close(rebuilt.dequantize(), expected, atol=0, rtol=0, equal_nan=True) -# ----- TensorProto ----- +# ----- make_empty_traceable ----- @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) -def test_tensor_proto_matches_primitives(factory, shape): - """TensorProto is a thin wrapper: its ``create_metadata`` / - ``create_inner_tensors`` / ``create_tensor`` match building everything - directly from the quantizer primitives.""" +def test_make_empty_traceable_matches_primitives(factory, shape): + """make_empty_traceable is equivalent to building from quantizer primitives + directly (alloc_tensors + create_metadata + __tensor_unflatten__).""" q = factory() - proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu")) - assert proto.is_quantized - # Metadata matches the quantizer's. - assert proto.create_metadata() == q.create_metadata(shape, dtype=torch.bfloat16) + # Build via make_empty_traceable. + tensor = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu") - # inner_names follows the storage's canonical __tensor_flatten__ order (the - # order the real op flattens its outputs to), while create_inner_tensors - # matches the _describe_buffers geometry (a name->shape/dtype mapping). - bufs = q._describe_buffers(shape) # pylint: disable=protected-access + # Build directly from primitives. direct = _build_from_primitives(q, shape, torch.bfloat16) - names = tuple(direct.__tensor_flatten__()[0]) - assert set(names) == set(bufs) - assert proto.inner_names() == names - inner = proto.create_inner_tensors() - assert len(inner) == len(names) - for name, buf in zip(names, inner): - exp_shape, exp_dtype = bufs[name] - assert tuple(buf.shape) == tuple(exp_shape) - assert buf.dtype == exp_dtype - # The assembled tensor matches one built directly from the primitives. - assert _signature(proto.create_tensor(), names) == _signature(direct, names) + # Both should produce the same buffer layout. + names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access + assert _signature(tensor, names) == _signature(direct, names) @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) -def test_tensor_proto_create_tensor_eager(factory, shape): - """``create_tensor`` (no fake) yields a real quantized tensor.""" +def test_make_empty_traceable_eager(factory, shape): + """make_empty_traceable (no fake) yields a real quantized tensor.""" q = factory() - proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu")) - out = proto.create_tensor() + out = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu") assert isinstance(out, QuantizedTensor) assert tuple(out.shape) == tuple(shape) assert out.dtype == torch.bfloat16 - for name in proto.inner_names(): + names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access + for name in names: assert not isinstance(getattr(out, name), FakeTensor) @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) -def test_tensor_proto_create_tensor_fake(factory, shape): - """``create_tensor`` under ``FakeTensorMode`` yields a fake-backed quantized +def test_make_empty_traceable_fake(factory, shape): + """make_empty_traceable under FakeTensorMode yields a fake-backed quantized tensor with the right shape/dtype and fake inner buffers.""" q = factory() - proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu")) with FakeTensorMode(): - out = proto.create_tensor() + out = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu") assert isinstance(out, QuantizedTensor) assert tuple(out.shape) == tuple(shape) assert out.dtype == torch.bfloat16 - for name in proto.inner_names(): + names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access + for name in names: assert isinstance(getattr(out, name), FakeTensor) @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) -def test_tensor_proto_create_tensor_compiles(factory, shape): - """``TensorProto.create_tensor`` traces under ``fullgraph=True`` (CPU).""" +def test_make_empty_traceable_compiles(factory, shape): + """make_empty_traceable traces under torch.compile(fullgraph=True) (CPU).""" q = factory() + names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access def fn(x): - proto = TensorProto(shape=tuple(x.shape), dtype=x.dtype, quantizer=q, device=x.device) - t = proto.create_tensor() + t = make_empty_traceable(q, tuple(x.shape), dtype=x.dtype, device=x.device) acc = x.new_zeros(()) - for name in proto.inner_names(): + for name in names: acc = acc + getattr(t, name).float().sum() return acc @@ -860,36 +845,32 @@ def fn(x): assert out.shape == () -def test_to_tensor_proto_plain(): - """``to_tensor_proto`` describes a plain tensor.""" - t = torch.empty(2, 3, dtype=torch.float32) - proto = to_tensor_proto(t) - assert not proto.is_quantized - assert proto.shape == (2, 3) - assert proto.dtype == torch.float32 - assert proto.inner_names() == ("data",) +def test_make_empty_traceable_plain_tensor(): + """For non-quantized tensors, make_empty_traceable produces a plain tensor.""" + t = make_empty_traceable(None, (2, 3), dtype=torch.float32, device="cpu") + assert not isinstance(t, QuantizedTensor) + assert t.shape == (2, 3) + assert t.dtype == torch.float32 @pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS) -def test_to_tensor_proto_quantized(factory, shape): - """``to_tensor_proto`` round-trips a quantized tensor back into a proto.""" +def test_make_empty_traceable_roundtrip(factory, shape): + """A tensor built via make_empty_traceable can be flattened and unflattened.""" q = factory() - tensor = TensorProto( - shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu") - ).create_tensor() - - proto = to_tensor_proto(tensor) - assert proto.is_quantized - assert proto.shape == tuple(shape) - assert proto.dtype == torch.bfloat16 - # Same buffer layout as the original tensor. - assert proto.inner_names() == tuple( - q._describe_buffers(shape) - ) # pylint: disable=protected-access - # Rebuilding from the derived proto matches the original tensor's structure. - assert _signature(proto.create_tensor(), proto.inner_names()) == _signature( - tensor, proto.inner_names() + tensor = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu") + names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access + + # Flatten. + flat_names, flat_ctx = tensor.__tensor_flatten__() + assert set(flat_names) == set(names) + inner = {name: getattr(tensor, name) for name in flat_names} + + # Unflatten. + rebuilt = type(tensor).__tensor_unflatten__( + inner, flat_ctx, tuple(tensor.shape), tensor.stride() ) + assert isinstance(rebuilt, QuantizedTensor) + assert _signature(rebuilt, flat_names) == _signature(tensor, names) # --------------------------------------------------------------------------- diff --git a/transformer_engine/pytorch/dynamo/__init__.py b/transformer_engine/pytorch/dynamo/__init__.py index 66e0525a9e..f4864cda5b 100644 --- a/transformer_engine/pytorch/dynamo/__init__.py +++ b/transformer_engine/pytorch/dynamo/__init__.py @@ -5,13 +5,12 @@ """torch.compile glue for Transformer Engine.""" from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer -from .tensor_proto import TensorProto, to_tensor_proto +from .traceable_utils import make_empty_traceable from .custom_op import register_custom_op __all__ = [ "register_value_opaque_quantizer", "is_value_opaque_quantizer", - "TensorProto", - "to_tensor_proto", + "make_empty_traceable", "register_custom_op", ] diff --git a/transformer_engine/pytorch/dynamo/custom_op.py b/transformer_engine/pytorch/dynamo/custom_op.py index 68b2aae59e..9aca5b866f 100644 --- a/transformer_engine/pytorch/dynamo/custom_op.py +++ b/transformer_engine/pytorch/dynamo/custom_op.py @@ -14,7 +14,6 @@ Dict, List, Optional, - Sequence, Tuple, Union, get_args, @@ -24,7 +23,7 @@ import torch -from .tensor_proto import TensorProto, to_tensor_proto, _contiguous_stride +from .traceable_utils import _contiguous_stride, _slot_count, _maybe_reassemble_tensor_subclass from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, @@ -723,11 +722,6 @@ def _get_buckets(cls: type) -> List[_Bucket]: return buckets -def _tensor_field_names(buckets: List[_Bucket]) -> List[str]: - """Names of fields carrying tensors (for building the proto view).""" - return [b.name for b in buckets if isinstance(b, (_TensorBucket, _UniversalTensorBucket))] - - def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: """Return ``(schema_arg_str, slot_names)`` for a bucket list.""" spec = [slot for b in buckets for slot in b.schema_slots()] @@ -761,76 +755,25 @@ def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: object.__setattr__(obj, k, v) return obj - -def _proto_view(obj: Any, tensor_field_names: Sequence[str]) -> Any: - """Copy of dataclass ``obj`` with each tensor field replaced by a :class:`TensorProto`. - - Only tensor fields have a ``TensorProto`` equivalent, so quantizer / scalar - fields are simply carried over unchanged; the fake impl works purely on - geometry. Built with :func:`dataclasses.replace` (the only such construction - Dynamo can trace). - """ - overrides: Dict[str, Any] = {} - for name in tensor_field_names: - value = getattr(obj, name, None) - if value is not None and not isinstance(value, TensorProto): - overrides[name] = to_tensor_proto(value) - if not overrides: - return obj - return dataclasses.replace(obj, **overrides) - - # --------------------------------------------------------------------------- # # Op outputs <-> flat ``Tensor[]`` payload: this is how an op returns / saves # quantized tensors (and wrapper subclasses). Outputs are flattened to their # inner buffers on the way out and rebuilt via ``__tensor_unflatten__`` on the -# way back; on the fake side a TensorProto supplies the geometry. +# way back; the fake impl returns actual tensors whose __tensor_flatten__ +# provides the template for reassembly. # --------------------------------------------------------------------------- # -def _proto_slot_count(proto: Optional[TensorProto]) -> int: - """Flat ``Tensor[]`` slots the value for ``proto`` occupies.""" - if proto is None: - return 1 - return len(proto.inner_names()) - - -def _proto_reassemble( - proto: Optional[TensorProto], - chunk: List[Optional[torch.Tensor]], -) -> Optional[Union[torch.Tensor, QuantizedTensorStorage]]: - """Rebuild the value described by ``proto`` from its flat tensors ``chunk``. - - ``proto`` describes one output: ``None`` (-> ``None``), a plain tensor - (``chunk`` is the single tensor, returned as-is), or a quantized tensor - (``chunk`` are its inner tensors, reassembled into the wrapper subclass via - ``__tensor_unflatten__``). - """ - if proto is None: - return None - if proto.quantizer is None: - return chunk[0] - inner_names = proto.inner_names() - meta = proto.create_metadata() - shape = tuple(proto.shape) - stride = _contiguous_stride(shape) - storage_cls = _STORAGE_REGISTRY[meta["cls"]] - inner_dict = dict(zip(inner_names, chunk)) - return storage_cls.__tensor_unflatten__(inner_dict, meta, shape, stride) - - def _value_to_flat_tensors( - value: Optional[Union[torch.Tensor, QuantizedTensorStorage, TensorProto]], + value: Optional[Union[torch.Tensor, QuantizedTensorStorage]], ) -> List[torch.Tensor]: """Return the flat ``Tensor[]`` slots that represent one op output ``value``. - Inverse of :func:`_proto_reassemble`; the slot count matches - :func:`_proto_slot_count`. + Inverse of :func:`_maybe_reassemble_tensor_subclass`; the slot count matches + :func:`_slot_count`. """ if value is None: return [_encode_none(None)] - if isinstance(value, TensorProto): - return [_encode_none(t) for t in value.create_inner_tensors()] if isinstance(value, torch.Tensor): if type(value) is not torch.Tensor and hasattr( # pylint: disable=unidiomatic-typecheck value, "__tensor_flatten__" @@ -843,7 +786,7 @@ def _value_to_flat_tensors( return [_encode_none(getattr(value, n)) for n in inner_names] raise TypeError( f"unsupported value type {type(value).__name__}; expected None / " - "torch.Tensor / tensor subclass / bare storage / TensorProto." + "torch.Tensor / tensor subclass / bare storage." ) @@ -870,8 +813,7 @@ def _format_fwd_result(result: Any) -> List[torch.Tensor]: def _format_bwd_result(grads: Any, num_grad_inputs: int, op_qualname: str) -> List[torch.Tensor]: """Pack a backward-impl return tuple into the op's ``Tensor[]`` payload. - Each grad occupies exactly one slot (validated against ``num_grad_inputs``); - a :class:`TensorProto` grad is materialized into a single tensor. + Each grad occupies exactly one slot (validated against ``num_grad_inputs``). """ grads = list(grads) if len(grads) != num_grad_inputs: @@ -879,13 +821,7 @@ def _format_bwd_result(grads: Any, num_grad_inputs: int, op_qualname: str) -> Li f"{op_qualname} expected backward_impl to return {num_grad_inputs} grads " f"(one per input_tensors_for_grad entry), got {len(grads)}" ) - out: List[torch.Tensor] = [] - for g in grads: - if isinstance(g, TensorProto): - out.append(_encode_none(g.create_tensor())) - else: - out.append(_encode_none(g)) - return out + return [_encode_none(g) for g in grads] def _split_fwd_fake_result( @@ -946,16 +882,15 @@ def _register_kernel( arg_type: type, arg_names: List[str], buckets: List[_Bucket], - tensor_field_names: List[str], impl: Callable[[Any], Any], fake_impl: Callable[[Any], Any], format_result: Callable[[Any], List[torch.Tensor]], ) -> Any: """Define the op via ``torch.library.custom_op`` with the real ``impl`` + the - ``fake_impl`` (proto), returning the ``CustomOpDef``. + ``fake_impl``, returning the ``CustomOpDef``. The real kernel rebuilds the dataclass and runs ``impl``; the fake kernel - runs the proto fake impl on the :func:`_proto_view`. Both go through + runs the fake impl directly on the unpacked object. Both go through ``format_result``. """ @@ -967,8 +902,7 @@ def _impl(*flat: Any) -> List[torch.Tensor]: def _fake(*flat: Any) -> List[torch.Tensor]: kwargs = dict(zip(arg_names, flat)) obj = _unpack(arg_type, kwargs, buckets) - proto_obj = _proto_view(obj, tensor_field_names) - return format_result(fake_impl(proto_obj)) + return format_result(fake_impl(obj)) op = torch.library.custom_op( f"{_TE_OP_NAMESPACE}::{op_name}", _impl, mutates_args=(), schema=schema_str @@ -984,7 +918,6 @@ def _register_autograd_for_op( fwd_arg_type: type, fwd_arg_names: List[str], fwd_buckets: List[_Bucket], - fwd_tensor_field_names: List[str], bwd_arg_names: List[str], bwd_buckets: List[_Bucket], fwd_slot_defaults: List[Any], @@ -995,7 +928,7 @@ def _register_autograd_for_op( ) -> None: """Wire ``register_autograd`` on a forward op so its backward calls ``bwd_op_name``. - ``setup_context`` re-runs the proto fwd fake impl to recover output / saved + ``setup_context`` re-runs the fwd fake impl to recover output / saved templates, reassembles each flat output chunk, and hands the saved tuple + ``ctx_attrs`` to the module's ``setup_context``. """ @@ -1006,24 +939,23 @@ def _setup_context(ctx, inputs, output): } kwargs = dict(zip(fwd_arg_names, inputs)) fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) - proto_obj = _proto_view(fwd_obj, fwd_tensor_field_names) - user_fakes, saved_fakes, ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(proto_obj)) + user_fakes, saved_fakes, ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(fwd_obj)) cursor = 0 user_outputs: List[Any] = [] - for proto in user_fakes: - n = _proto_slot_count(proto) + for template in user_fakes: + n = _slot_count(template) chunk = [_decode_none(t) for t in output[cursor : cursor + n]] cursor += n - user_outputs.append(_proto_reassemble(proto, chunk)) + user_outputs.append(_maybe_reassemble_tensor_subclass(template, chunk)) saved_list: List[Any] = [] - for proto in saved_fakes: - n = _proto_slot_count(proto) + for template in saved_fakes: + n = _slot_count(template) chunk = [_decode_none(t) for t in output[cursor : cursor + n]] cursor += n - saved_list.append(_proto_reassemble(proto, chunk)) + saved_list.append(_maybe_reassemble_tensor_subclass(template, chunk)) bwd_obj = backward_obj_type() tensors_to_save_from_setup = setup_context_user( @@ -1203,8 +1135,6 @@ def _register_custom_op_impl( fwd_buckets = _get_buckets(fwd_arg_type) bwd_buckets = _get_buckets(backward_arg_type) - fwd_tensor_field_names = _tensor_field_names(fwd_buckets) - bwd_tensor_field_names = _tensor_field_names(bwd_buckets) fwd_schema_args, fwd_arg_names = _build_schema(fwd_buckets) bwd_schema_args, bwd_arg_names = _build_schema(bwd_buckets) @@ -1225,7 +1155,6 @@ def _register_custom_op_impl( arg_type=fwd_arg_type, arg_names=fwd_arg_names, buckets=fwd_buckets, - tensor_field_names=fwd_tensor_field_names, impl=fwd_impl, fake_impl=fwd_fake_impl, format_result=_format_fwd_result, @@ -1236,7 +1165,6 @@ def _register_custom_op_impl( arg_type=backward_arg_type, arg_names=bwd_arg_names, buckets=bwd_buckets, - tensor_field_names=bwd_tensor_field_names, impl=backward_impl, fake_impl=bwd_fake_impl, format_result=lambda g: _format_bwd_result(g, num_grad_inputs, inner_bwd_qualname), @@ -1257,7 +1185,6 @@ def _register_custom_op_impl( "fwd_arg_type": fwd_arg_type, "fwd_arg_names": fwd_arg_names, "fwd_buckets": fwd_buckets, - "fwd_tensor_field_names": fwd_tensor_field_names, "bwd_arg_names": bwd_arg_names, "bwd_buckets": bwd_buckets, "fwd_slot_defaults": fwd_slot_defaults, @@ -1305,19 +1232,18 @@ def _bwd_rule(mode, func, types, args, kwargs): _quantized_tensor_passthrough_ops.add(inner_bwd_op.default) def forward_fn(fwd_args): - proto_obj = _proto_view(fwd_args, fwd_tensor_field_names) - user_fakes, _saved_fakes, _ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(proto_obj)) + user_fakes, _saved_fakes, _ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(fwd_args)) kwargs = _pack(fwd_args, fwd_buckets) flat_in = [kwargs[name] for name in fwd_arg_names] result = outer_fwd_op(*flat_in) cursor = 0 outputs: List[Any] = [] - for proto in user_fakes: - n = _proto_slot_count(proto) + for template in user_fakes: + n = _slot_count(template) chunk = [_decode_none(t) for t in result[cursor : cursor + n]] cursor += n - outputs.append(_proto_reassemble(proto, chunk)) + outputs.append(_maybe_reassemble_tensor_subclass(template, chunk)) if len(outputs) == 1: return outputs[0] diff --git a/transformer_engine/pytorch/dynamo/tensor_proto.py b/transformer_engine/pytorch/dynamo/tensor_proto.py deleted file mode 100644 index 911b4151bf..0000000000 --- a/transformer_engine/pytorch/dynamo/tensor_proto.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""TensorProto: a data-free description of a tensor / quantized tensor.""" - -from __future__ import annotations -import copy as _copy -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -import torch - - -def _contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]: - """Row-major (contiguous) stride for ``shape``.""" - stride: list = [] - acc = 1 - for dim in reversed(shape): - stride.append(acc) - acc *= dim - return tuple(reversed(stride)) - - -@dataclass -class TensorProto: - """A data-free *prototype* of a tensor or quantized tensor. - - Captures ``shape`` / ``dtype`` and, for quantized tensors, the - (value-opaque) ``quantizer`` -- enough to rebuild a tensor without holding - storage. The common abstraction over plain ``torch.Tensor``, - ``QuantizedTensorStorage`` and ``QuantizedTensor``, used for custom-op fake - impls and for reassembling a quantized tensor from bare buffers. - """ - - shape: Tuple[int, ...] - dtype: torch.dtype - quantizer: Optional[Any] = None - requires_grad: bool = False - device: Optional[torch.device] = field(default=None) - - def __post_init__(self) -> None: - # Own a private copy of the quantizer so usage changes (update_usage) - # never touch the shared, value-opaque quantizer. The copy inherits the - # quantizer's current row-/column-wise usage as this proto's layout. - if self.quantizer is not None: - q = self.quantizer - self.quantizer = q.copy() if hasattr(q, "copy") else _copy.copy(q) - - @property - def is_quantized(self) -> bool: - """Whether this proto describes a quantized tensor.""" - return self.quantizer is not None - - def update_usage( - self, - *, - rowwise_usage: Optional[bool] = None, - columnwise_usage: Optional[bool] = None, - ) -> None: - """Mirror ``QuantizedTensor.update_usage`` on the proto's buffer layout. - - Applied to the proto's own quantizer copy, so the shared (value-opaque) - quantizer is never mutated. No-op for plain (non-quantized) protos. - """ - if self.quantizer is None: - return - self.quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) - - def inner_names(self) -> Tuple[str, ...]: - """Names of the flat tensor buffers backing this proto, in order. - - The real op flattens a quantized output via the storage's - ``__tensor_flatten__`` -- i.e. ``_FLATTEN_TENSOR_BUFFERS`` order, keeping - only the present buffers. ``_describe_buffers`` may emit the same buffers - in a different (per-usage) order (e.g. NVFP4 groups each amax right after - its scale), so reorder to the canonical flatten order here to keep the - fake layout aligned with the real one slot-for-slot. - """ - if self.quantizer is None: - return ("data",) - # pylint: disable=protected-access - described = list(self.quantizer._describe_buffers(tuple(self.shape)).keys()) - storage_cls = self.quantizer._storage_metadata(self.dtype)["cls"] - flatten_order = [attr for attr, _ in storage_cls._FLATTEN_TENSOR_BUFFERS] - extra = [name for name in described if name not in flatten_order] - if extra: - raise RuntimeError( - f"{storage_cls.__name__} describes buffer(s) {extra} absent from its " - f"_FLATTEN_TENSOR_BUFFERS {flatten_order}; the fake layout cannot be " - "aligned with the real one slot-for-slot." - ) - return tuple(name for name in flatten_order if name in described) - - def create_metadata(self) -> Dict[str, Any]: - """Data-free ``__tensor_unflatten__`` context describing this tensor.""" - if self.quantizer is None: - return { - "is_tensor": True, - "is_quantized": False, - "dtype": self.dtype, - "requires_grad": self.requires_grad, - } - return self.quantizer.create_metadata( - tuple(self.shape), dtype=self.dtype, requires_grad=self.requires_grad - ) - - def create_inner_tensors(self) -> List[torch.Tensor]: - """Materialize the flat inner buffers (in :meth:`inner_names` order). - - Under ``register_fake`` the ``torch.empty`` calls produce ``FakeTensor``s; - ``requires_grad`` is left default (managed by ``register_autograd``). - """ - device = self.device if self.device is not None else torch.device("cuda") - if self.quantizer is None: - return [torch.empty(tuple(self.shape), dtype=self.dtype, device=device)] - inner = self.quantizer.alloc_tensors(tuple(self.shape), device=device) - return [inner[name] for name in self.inner_names()] - - def create_tensor(self) -> torch.Tensor: - """Materialize an (uninitialized) tensor matching this proto (traceable). - - Quantized protos reassemble the :meth:`create_inner_tensors` buffers via - the storage's ``__tensor_unflatten__``. - """ - if self.quantizer is None: - device = self.device if self.device is not None else torch.device("cuda") - return torch.empty( - tuple(self.shape), - dtype=self.dtype, - device=device, - requires_grad=self.requires_grad, - ) - from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel - _STORAGE_REGISTRY, - ) - - shape = tuple(self.shape) - ctx = self.create_metadata() - inner = dict(zip(self.inner_names(), self.create_inner_tensors())) - storage_cls = _STORAGE_REGISTRY[ctx["cls"]] - return storage_cls.__tensor_unflatten__(inner, ctx, shape, _contiguous_stride(shape)) - - -def to_tensor_proto(tensor: Any) -> TensorProto: - """Build a :class:`TensorProto` describing ``tensor``. - - Works for plain ``torch.Tensor`` and for ``QuantizedTensorStorage`` / - ``QuantizedTensor``. A *bare* storage exposes its shape via ``.size()`` and - its (fake) dtype via ``_dtype`` rather than ``.shape`` / ``.dtype``. - """ - from ..quantized_tensor import ( # pylint: disable=import-outside-toplevel - QuantizedTensorStorage, - ) - - requires_grad = bool(getattr(tensor, "requires_grad", False)) - if isinstance(tensor, QuantizedTensorStorage): - shape = getattr(tensor, "shape", None) - if shape is None: - shape = tensor.size() - dtype = getattr(tensor, "dtype", None) - if dtype is None: - dtype = getattr(tensor, "_dtype", None) - return TensorProto( - shape=tuple(shape), - dtype=dtype, - quantizer=getattr(tensor, "_quantizer", None), - requires_grad=requires_grad, - device=tensor.device, - ) - return TensorProto( - shape=tuple(tensor.shape), - dtype=tensor.dtype, - quantizer=None, - requires_grad=requires_grad, - device=tensor.device, - ) diff --git a/transformer_engine/pytorch/dynamo/traceable_utils.py b/transformer_engine/pytorch/dynamo/traceable_utils.py new file mode 100644 index 0000000000..560599518a --- /dev/null +++ b/transformer_engine/pytorch/dynamo/traceable_utils.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pure-Python, torch.compile-traceable quantized tensor allocation and reassembly.""" + +from __future__ import annotations +import copy as _copy +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch + + +def _contiguous_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Row-major (contiguous) stride for ``shape``.""" + stride: list = [] + acc = 1 + for dim in reversed(shape): + stride.append(acc) + acc *= dim + return tuple(reversed(stride)) + + +def make_empty_traceable( + quantizer, + shape: Tuple[int, ...], + *, + dtype: torch.dtype = torch.float32, + device: Optional[Union[torch.device, str]] = None, + requires_grad: bool = False, +) -> Any: + """Allocate a tensor purely in Python (traceable under torch.compile). + + When ``quantizer`` is not None, produces a quantized tensor via + ``alloc_tensors`` + ``__tensor_unflatten__`` (the compile-friendly + equivalent of ``Quantizer.make_empty``). The quantizer is copied first + so the caller's instance is never mutated. + + When ``quantizer`` is None, falls back to ``torch.empty`` for a plain tensor. + + Stashed metadata (``_te_flat_names``, ``_te_flat_ctx``) + --------------------------------------------------------- + The resulting quantized tensor has these attributes stashed so that + ``forward_fn`` (in custom_op.py) can read them at Dynamo trace time. + + Why: ``forward_fn`` runs inside torch.compile's trace. It needs the flat + buffer names and unflatten context to decode the custom op's flat Tensor[] + return back into a structured QuantizedTensor. Calling + ``__tensor_flatten__()`` for this would cause a graph break (it returns + non-Tensor Python objects -- List[str] and Dict -- that Dynamo cannot + represent as graph nodes). Accessing ``t._quantizer`` and calling + ``_describe_buffers()`` on it also fails: while Dynamo treats the retrieved + quantizer as constant metadata, it wraps it in a generic VariableTracker + that does not support method calls (unlike a closure-captured quantizer + which is recognized as a value-opaque constant). Stashing the buffer names + and context as plain attributes sidesteps both issues -- Dynamo reads them + as constants without needing to call any methods. + + Allocation cost: when ``forward_fn`` calls the fake impl to obtain these + templates, the ``torch.empty`` calls appear as nodes in the initial Dynamo + FX graph. However, because the tensors themselves are never used (only + the stashed metadata is read), AOT autograd's dead-code elimination removes + them before any kernel code is generated. They do not appear in the final + compiled graph. + """ + device = torch.device(device if device is not None else "cuda") + shape = tuple(shape) + if quantizer is None: + return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad) + from ..quantized_tensor import _STORAGE_REGISTRY # pylint: disable=import-outside-toplevel + + # Copy so the caller's quantizer is not mutated by alloc_tensors internals. + # The caller is expected to have already called set_usage() on the quantizer + # before passing it here -- Dynamo tracks those mutations as explicit setattr + # nodes in the graph, so the copy captures the post-mutation state and the + # stashed _te_flat_names reflects the correct buffer layout. + q = quantizer.copy() if hasattr(quantizer, "copy") else _copy.copy(quantizer) + ctx = q.create_metadata(shape, dtype=dtype, requires_grad=requires_grad) + inner = q.alloc_tensors(shape, device=device) + storage_cls = _STORAGE_REGISTRY[ctx["cls"]] + result = storage_cls.__tensor_unflatten__(inner, ctx, shape, _contiguous_stride(shape)) + if requires_grad and hasattr(result, "requires_grad_"): + result.requires_grad_(True) + # TODO: understand why Dynamo does not recognize the quantizer retrieved via + # t._quantizer as the same value-opaque type it would if captured from a + # closure. If that is fixed upstream, the stashed attributes become + # unnecessary and we could compute slot counts directly from the quantizer. + result._te_flat_names = tuple(inner.keys()) + result._te_flat_ctx = ctx + return result + + +# --------------------------------------------------------------------------- # +# Slot counting and reassembly for the custom-op flat Tensor[] protocol. +# --------------------------------------------------------------------------- # + + +def _slot_count(value: Any) -> int: + """Number of flat tensor slots a value occupies in the op's Tensor[] return. + + Reads ``_te_flat_names`` stashed by :func:`make_empty_traceable`, which is + safe to access at Dynamo trace time (treated as constant metadata on a + traceable wrapper subclass). Plain tensors (no stashed names) occupy 1 slot. + """ + if value is None: + return 1 + names = getattr(value, "_te_flat_names", None) + if names is not None: + return len(names) + return 1 + + +def _maybe_reassemble_tensor_subclass( + template: Any, + chunk: List[Optional[torch.Tensor]], +) -> Optional[Union[torch.Tensor, Any]]: + """Rebuild a value from its flat tensors using ``template`` for geometry. + + ``template`` is a tensor produced by the fake impl via + :func:`make_empty_traceable`. Uses the stashed ``_te_flat_names`` / + ``_te_flat_ctx`` attributes for reassembly (trace-safe). For plain tensors + (no stashed metadata), returns the single chunk element directly. + """ + if template is None: + return None + names = getattr(template, "_te_flat_names", None) + ctx = getattr(template, "_te_flat_ctx", None) + if names is None or ctx is None: + return chunk[0] + inner_dict = dict(zip(names, chunk)) + shape = tuple(template.shape) if hasattr(template, "shape") else tuple(template.size()) + stride = _contiguous_stride(shape) + return type(template).__tensor_unflatten__(inner_dict, ctx, shape, stride) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d1ea805077..d1dd5f362b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -71,7 +71,8 @@ prepare_for_saving, restore_from_func_ctx, ) -from ..dynamo import TensorProto, register_custom_op, is_value_opaque_quantizer +from ..dynamo import register_custom_op, is_value_opaque_quantizer +from ..dynamo.traceable_utils import make_empty_traceable from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import clear_columnwise_cache, is_custom @@ -698,10 +699,10 @@ def _linear_forward_impl( def _linear_forward_impl_fake( args: LinearFwdArgs, -) -> Tuple[TensorProto, Optional[TensorProto], Optional[Tuple[Any, ...]], None, Optional[Dict]]: +) -> Tuple[Any, Optional[Any], Optional[Tuple[Any, ...]], None, Optional[Dict]]: """Shape/metadata-only twin of :func:`_linear_forward_impl` for torch.compile, - returning ``TensorProto`` descriptors for the outputs and saved tensors instead - of allocating real data.""" + returning traceable tensors for the outputs and saved tensors instead of + allocating real data via C++ kernels.""" if args.fsdp_group is not None and args.is_grad_enabled: raise NotImplementedError( "Compile-time Linear forward does not support manual TE FSDP " @@ -730,7 +731,7 @@ def _linear_forward_impl_fake( inputmat_is_storage = False inputmat_aliases_inp = False if fp8_or_debug: - if inp.is_quantized: + if isinstance(inp, QuantizedTensorStorage): # Primary-quantized input reused as-is. inputmat_is_storage = True inputmat_aliases_inp = True @@ -764,7 +765,7 @@ def _linear_forward_impl_fake( weightmat_is_storage = False weightmat_aliases_weight = False if fp8_or_debug: - if weight_quantizer is not None and (not weight.is_quantized or debug): + if weight_quantizer is not None and (not isinstance(weight, QuantizedTensorStorage) or debug): columnwise_usage = is_grad_enabled and args.input_requires_grad and not args.is_fsdp2 if args.backward_override is not None: columnwise_usage = False @@ -774,34 +775,30 @@ def _linear_forward_impl_fake( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - elif weight.is_quantized: - weight_quantizer = weight.quantizer + elif isinstance(weight, QuantizedTensorStorage): + weight_quantizer = weight._quantizer - if weight.is_quantized: + if isinstance(weight, QuantizedTensorStorage): # Primary-quantized weight: the impl reuses it as ``weightmat``. weightmat = weight weightmat_is_storage = True weightmat_aliases_weight = True else: - weightmat = TensorProto( - shape=tuple(weight.shape), - dtype=activation_dtype, - quantizer=weight_quantizer, - device=weight.device, + weightmat = make_empty_traceable( + weight_quantizer, tuple(weight.shape), + dtype=activation_dtype, device=weight.device, ) weightmat_is_storage = True update_ws = args.is_first_microbatch is None or args.is_first_microbatch if args.cache_weight and update_ws and args.weight_workspace is None: - new_weight_workspace = TensorProto( - shape=tuple(weight.shape), - dtype=activation_dtype, - quantizer=weight_quantizer, - device=weight.device, + new_weight_workspace = make_empty_traceable( + weight_quantizer, tuple(weight.shape), + dtype=activation_dtype, device=weight.device, ) else: weightmat_aliases_weight = weight.dtype == activation_dtype - weightmat = TensorProto( - shape=tuple(weight.shape), dtype=activation_dtype, device=weight.device + weightmat = make_empty_traceable( + None, tuple(weight.shape), dtype=activation_dtype, device=weight.device, ) if output_quantizer is not None: @@ -815,11 +812,10 @@ def _linear_forward_impl_fake( out_leading = out_leading * args.tp_size elif args.parallel_mode == "row" and args.sequence_parallel: out_leading = out_leading // args.tp_size - out = TensorProto( - shape=(out_leading, *tuple(inp.shape[1:-1]), out_features), + out = make_empty_traceable( + output_quantizer, + (out_leading, *tuple(inp.shape[1:-1]), out_features), dtype=activation_dtype, - quantizer=output_quantizer, - requires_grad=is_grad_enabled and (args.input_requires_grad or args.weight_requires_grad), device=inp.device, ) @@ -837,29 +833,29 @@ def _linear_forward_impl_fake( if inputmat_aliases_inp: inputmat_alias = "inp" elif inputmat_is_storage: - saved_inputmat = TensorProto( - shape=tuple(inp.shape), - dtype=activation_dtype, - quantizer=input_quantizer, - device=inp.device, - ) # Mirror ``_linear_forward_impl``'s post-quantization # ``inputmat.update_usage(...)`` so the saved input's buffer layout # matches -- driven by the same conditions as the real impl. + # Copy the quantizer so the shared instance on fwd_args is not mutated. + save_q = input_quantizer.copy() if hasattr(input_quantizer, "copy") else input_quantizer if own_quantized_input and not save_original_input: if args.backward_override is not None: - saved_inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + save_q.set_usage(rowwise=True, columnwise=False) elif ( args.backward_input_needs_gather and weight_quantizer is not None and weight_quantizer.supports_only_rowwise_all_gather() ): - saved_inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + save_q.set_usage(rowwise=True, columnwise=False) else: - saved_inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + save_q.set_usage(rowwise=False, columnwise=True) + saved_inputmat = make_empty_traceable( + save_q, tuple(inp.shape), + dtype=activation_dtype, device=inp.device, + ) else: - saved_inputmat = TensorProto( - shape=tuple(inp.shape), dtype=activation_dtype, device=inp.device + saved_inputmat = make_empty_traceable( + None, tuple(inp.shape), dtype=activation_dtype, device=inp.device, ) # Slot 1 -- ``wt_save``. Mirror the real impl's alias dedup: the cached @@ -879,8 +875,8 @@ def _linear_forward_impl_fake( elif weightmat_is_storage: wt_save = weightmat else: - wt_save = TensorProto( - shape=tuple(weight.shape), dtype=activation_dtype, device=weight.device + wt_save = make_empty_traceable( + None, tuple(weight.shape), dtype=activation_dtype, device=weight.device, ) # Slot 2 -- ``saved_weight`` (always aliased to ``weight``). @@ -1612,13 +1608,12 @@ def wgrad_gemm( def _linear_backward_impl_fake( args: LinearBwdArgs, -) -> Tuple[Optional[TensorProto], Optional[TensorProto], Optional[TensorProto]]: - """Allocation-free fake of :func:`_linear_backward` on ``TensorProto``. +) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]: + """Traceable fake of :func:`_linear_backward`. - The saved-tensor fields of ``args`` carry - :class:`~transformer_engine.pytorch.dynamo.TensorProto` instances. Returns - ``(wgrad, dgrad, grad_bias)`` protos describing the nature of the gradients, - mirroring the real backward's return contract without allocating storage. + The saved-tensor fields of ``args`` carry real tensors (fake-backed under + tracing). Returns ``(wgrad, dgrad, grad_bias)`` as traceable tensors, + mirroring the real backward's return contract without running C++ kernels. Tensor-/sequence-parallel gather/scatter happens inside the eager backward custom op and is opaque to ``torch.compile``: ``dgrad`` always carries the @@ -1642,7 +1637,6 @@ def _linear_backward_impl_fake( dgrad = None if args.requires_dgrad: - # dgrad has the logical input shape and may be quantized for the next op. # Derive shape from grad_output + weight + SP config instead of args.inp_shape: # inp_shape is not stored in the value bundle under dynamic shapes (SymInt is # not hashable in OpaqueValueBundle), so we reconstruct it here. @@ -1654,10 +1648,10 @@ def _linear_backward_impl_fake( _dgrad_leading = _go_leading * args.tp_size else: _dgrad_leading = _go_leading - dgrad = TensorProto( - shape=(_dgrad_leading, *args.grad_output.shape[1:-1], _in_features), + dgrad = make_empty_traceable( + args.grad_input_quantizer, + (_dgrad_leading, *args.grad_output.shape[1:-1], _in_features), dtype=out_dtype, - quantizer=args.grad_input_quantizer, device=args.grad_output.device, ) @@ -1667,17 +1661,17 @@ def _linear_backward_impl_fake( # requested (mirrors ``quantization_params=grad_weight_quantizer``), # otherwise high precision. Under fuse_wgrad_accumulation the grad is # written into ``main_grad`` in place and no wgrad tensor is returned. - wgrad = TensorProto( - shape=(out_features, in_features), + wgrad = make_empty_traceable( + args.grad_weight_quantizer, + (out_features, in_features), dtype=out_dtype, - quantizer=args.grad_weight_quantizer, device=weight.device, ) grad_bias = None if args.use_bias and args.requires_wgrad: - grad_bias = TensorProto( - shape=(out_features,), dtype=out_dtype, device=args.grad_output.device + grad_bias = make_empty_traceable( + None, (out_features,), dtype=out_dtype, device=args.grad_output.device, ) return wgrad, dgrad, grad_bias diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 86c4c94f7c..2af15ca7b0 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -150,7 +150,7 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: f"{self.__class__.__name__} class does not implement copy_from_storage function" ) - # ----- PyTorch subclass flatten protocol (torch.compile / TensorProto) ----- + # ----- PyTorch subclass flatten protocol (torch.compile / traceable allocation) ----- # Subclasses declare their tensor buffers once, as ``(attribute_name, # constructor_kwarg)`` pairs in flatten order; everything else returned by @@ -427,7 +427,7 @@ def make_empty( result.requires_grad_(True) return result - # ----- Data-free buffer/metadata primitives backing TensorProto ----- + # ----- Data-free buffer/metadata primitives backing make_empty_traceable ----- def _describe_buffers( self, shape: Tuple[int, ...] @@ -440,7 +440,7 @@ def _describe_buffers( """ raise NotImplementedError( f"{self.__class__.__name__} does not implement _describe_buffers; " - "it cannot be used with TensorProto / pure-Python allocation" + "it cannot be used with traceable allocation" ) def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: @@ -454,7 +454,7 @@ def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: """ raise NotImplementedError( f"{self.__class__.__name__} does not implement _storage_metadata; " - "it cannot be used with TensorProto / pure-Python allocation" + "it cannot be used with traceable allocation" ) def alloc_tensors( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index c816b4fb04..04e5281f65 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -73,7 +73,7 @@ def copy(self) -> Float8BlockQuantizer: def _value_fields(self) -> Tuple[str, ...]: return ("dtype", "block_len", "amax_epsilon", "force_pow_2_scales", "block_scaling_dim") - # ----- TensorProto / pure-Python allocation ----- + # ----- traceable allocation ----- def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: return { diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 6d3a53b3d7..423b055f15 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -393,7 +393,7 @@ def _value_fields(self) -> Tuple[str, ...]: # raises so it can never be baked into a torch.compile graph. return ("dtype", "force_pow_2_scales", "amax_epsilon", "with_amax_reduction") - # ----- TensorProto / pure-Python allocation ----- + # ----- traceable allocation ----- def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: return { diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index f804a96f24..bd51a88aaf 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -61,7 +61,7 @@ def copy(self) -> MXFP8Quantizer: def _value_fields(self) -> Tuple[str, ...]: return ("dtype",) - # ----- TensorProto / pure-Python allocation ----- + # ----- traceable allocation ----- def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: return { diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8d51480f51..d22423353c 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -245,8 +245,14 @@ def copy(self) -> NVFP4Quantizer: ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm - quantizer.rht_matrix = self.rht_matrix quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t + if not torch.compiler.is_compiling(): + # Under Dynamo tracing rht_matrix is a FakeTensor on an opaque script + # object; accessing it triggers SourcelessBuilder which cannot wrap + # FakeTensor. The fake impl never runs real quantization so the matrix + # is unnecessary -- it will be rebuilt lazily via _rebuild_derived_state + # if the quantizer is later used outside tracing. + quantizer.rht_matrix = self.rht_matrix return quantizer @@ -366,7 +372,7 @@ def _value_fields(self) -> Tuple[str, ...]: "with_amax_reduction", ) - # ----- TensorProto / pure-Python allocation ----- + # ----- traceable allocation ----- def _storage_metadata(self, fake_dtype: torch.dtype) -> Dict[str, Any]: return {