-
Notifications
You must be signed in to change notification settings - Fork 763
[PyTorch][torch.compile] Make quantizers opaque value objects #3152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f3401df
c4ad54c
a06324b
ea5b396
aa65e34
e1b1db6
8c33d0e
945f62d
e3c8f43
3f68621
32d1768
28bde9e
2c3c5df
826f271
613c545
9db604f
3011dfd
fe5e5db
f6b6d78
6f66c3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """torch.compile glue for Transformer Engine.""" | ||
|
|
||
| from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer | ||
|
|
||
| __all__ = [ | ||
| "register_value_opaque_quantizer", | ||
| "is_value_opaque_quantizer", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Value-opaque quantizers for torch.compile.""" | ||
|
|
||
| from __future__ import annotations | ||
| from typing import Any, Dict, Tuple | ||
|
|
||
| from ..constants import DType | ||
|
|
||
|
|
||
# Name of the class attribute that ``register_value_opaque_quantizer`` sets to
# ``True`` and that ``is_value_opaque_quantizer`` reads back.
#
# Registration marks the class with this attribute rather than recording it in a
# module-level set. It looks odd but is a deliberate workaround: the check must
# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a
# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate
# ``type(q) in some_set`` (no equality/hash rules for the opaque class object),
# which would graph-break under ``fullgraph=True``.
_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque"
|
|
||
|
|
||
def is_value_opaque_quantizer(quantizer: Any) -> bool:
    """Tell whether *quantizer* carries the torch.compile value-opaque marker.

    The marker is the attribute named by ``_VALUE_OPAQUE_FLAG``, which
    :func:`register_value_opaque_quantizer` sets on the class after a
    successful registration. Objects without the attribute yield ``False``.
    """
    # Keep this a plain attribute probe with a default: the marker exists
    # precisely so the check stays a simple ``getattr``.
    marker = getattr(quantizer, _VALUE_OPAQUE_FLAG, False)
    return marker
|
|
||
|
|
||
def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any:
    """Materialize a tensorless quantizer of type *cls* from its value items.

    The ``__fx_repr__`` attached to value-opaque quantizers emits a call to
    this function, so generated FX code uses it to build the quantizer
    constant.

    Args:
        cls: Quantizer class to instantiate.
        items: ``(attribute name, value)`` pairs, i.e. the second element of
            the tuple returned by ``Quantizer._value_key()``.

    Returns:
        A new instance of *cls* with every item assigned as an attribute.
    """
    # ``__init__`` is skipped on purpose. The value items already capture every
    # value-defining field (including derived ones), while the constructors
    # have heterogeneous signatures and side effects: argument names do not
    # always match the attribute names, and some fields (e.g. ``internal`` or
    # ``optimize_for_gemm``) are not constructor arguments at all.
    rebuilt = cls.__new__(cls)
    saw_amax_reduction_flag = False
    for attr_name, attr_value in items:
        # ``dtype`` is stored as a plain int in the value key, so it goes
        # back through ``DType.cast`` before being assigned.
        restored = DType.cast(attr_value) if attr_name == "dtype" else attr_value
        object.__setattr__(rebuilt, attr_name, restored)
        if attr_name == "with_amax_reduction":
            saw_amax_reduction_flag = True

    # The deprecated amax-reduction group is not a value field; default it to
    # ``None`` so attribute access keeps working on the rebuilt quantizer.
    if saw_amax_reduction_flag and not hasattr(rebuilt, "amax_reduction_group"):
        object.__setattr__(rebuilt, "amax_reduction_group", None)

    # Give the class a chance to restore non-value derived state that
    # ``__init__`` would normally build but that cannot live in the value key
    # (e.g. NVFP4's ``rht_matrix`` tensor).
    derived_state_hook = getattr(rebuilt, "_rebuild_derived_state", None)
    if derived_state_hook is not None:
        derived_state_hook()
    return rebuilt
|
|
||
|
|
||
def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]:
    """Implementation of ``__fx_repr__`` attached to value-opaque quantizers.

    Returns:
        A pair ``(expression, fx_globals)``. ``expression`` is source text that
        calls :func:`_rebuild_quantizer` with the quantizer class and the
        ``repr`` of its value items. ``fx_globals`` maps the two names used in
        that expression (the helper and the class name) to the actual objects,
        so codegen can resolve them without any global registry.

    Raises:
        TypeError: Propagated from :meth:`Quantizer._value_key` when the
            quantizer stores a process group (e.g. a non-``None`` deprecated
            ``amax_reduction_group``). Live distributed state must never be
            baked into the graph as a constant; pass the reduction group per
            quantize call instead of storing it on the quantizer.
    """
    quantizer_cls = type(self)
    cls_name = quantizer_cls.__name__
    # Element 1 of the value key is the tuple of (name, value) items.
    value_items = self._value_key()[1]
    expression = f"_rebuild_quantizer({cls_name}, {value_items!r})"
    # NOTE(review): the class is keyed by its bare ``__name__``; two quantizer
    # classes sharing a ``__name__`` would use the same key -- confirm how FX
    # codegen treats that case.
    fx_globals = {"_rebuild_quantizer": _rebuild_quantizer, cls_name: quantizer_cls}
    return expression, fx_globals
|
|
||
|
|
||
def register_value_opaque_quantizer(cls: type) -> None:
    """Register a tensorless quantizer class as a torch.compile value opaque type.

    Attaches ``__fx_repr__`` to *cls* and registers it with
    ``torch._library.opaque_object``. The call is safe on any PyTorch build:
    when the opaque-object API is missing, or the registration is rejected,
    only ``__fx_repr__`` is attached (harmless) and the function returns, so
    Transformer Engine keeps importing and running in eager mode.

    The class is marked with ``_VALUE_OPAQUE_FLAG`` only after the
    registration step completes without error.

    The quantizer class must already provide value ``__eq__`` / ``__hash__``
    and a non-``None`` ``_value_fields`` (see
    :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`).
    """
    # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the
    # class, so install it first. Only the class's own namespace is inspected.
    # The hook is attached here rather than defined on ``Quantizer`` to keep as
    # much torch.compile-specific logic as possible in this module.
    if "__fx_repr__" not in vars(cls):
        setattr(cls, "__fx_repr__", _quantizer_fx_repr)

    try:
        from torch._library.opaque_object import (  # pylint: disable=import-outside-toplevel
            register_opaque_type,
            is_opaque_value_type,
        )
    except (ImportError, AttributeError):
        # Older PyTorch without the opaque-object API: eager value semantics
        # still work; torch.compile specialization on the quantizer does not.
        return

    try:
        needs_registration = not is_opaque_value_type(cls)
        if needs_registration:
            register_opaque_type(cls, typ="value")
    except (RuntimeError, TypeError):
        # Keep TE importable: neither the opaque-type query nor the
        # registration must crash the import, e.g. on PyTorch versions with
        # only partial / experimental opaque-object support.
        return

    # Reached only when the class is (now) a registered value opaque type.
    setattr(cls, _VALUE_OPAQUE_FLAG, True)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import transformer_engine_torch as tex | ||
|
|
||
| from transformer_engine.common.recipe import Recipe | ||
| from transformer_engine.pytorch.constants import dist_group_type | ||
| from transformer_engine.pytorch.tensor._quantization_helpers import ( | ||
| _QuantizeFunc, | ||
| _IdentityFunc, | ||
|
|
@@ -408,6 +409,73 @@ def get_usages(self) -> Dict[str, bool]: | |
| "columnwise": self.columnwise_usage, | ||
| } | ||
|
|
||
#: Attributes shared by every quantizer that take part in value identity.
#: ``_value_key`` reads these first, followed by the subclass-specific names
#: returned by ``_value_fields``.
_BASE_VALUE_FIELDS: Tuple[str, ...] = (
    "rowwise_usage",
    "columnwise_usage",
    "internal",
    "optimize_for_gemm",
)
|
|
||
def _value_fields(self) -> Optional[Tuple[str, ...]]:
    """Names of the subclass-specific attributes that define this quantizer's value.

    The base implementation returns ``None``, which marks the quantizer as
    *not* representable as a value opaque object: ``__eq__`` and ``__hash__``
    then keep their identity-based behaviour. As a consequence, passing such
    a quantizer as an argument to a custom op causes a graph break under
    torch.compile, because it cannot be baked into the FX graph as a constant.

    Subclasses that are value quantizers override this to return a tuple of
    attribute names.
    """
    return None
|
|
||
def _check_value_has_no_process_group(self) -> None:
    """Reject a quantizer that stores a process group.

    A value quantizer cannot carry live distributed state into the FX graph,
    so a stored ``amax_reduction_group`` is refused; the group has to be
    passed per quantize call instead.

    This runs on every ``_value_key`` call (and therefore on every value
    ``__eq__`` / ``__hash__``), so it is kept cheap by inspecting only the
    ``amax_reduction_group`` attribute.

    Raises:
        TypeError: If ``amax_reduction_group`` is a ``dist_group_type``
            instance.
    """
    stored_group = getattr(self, "amax_reduction_group", None)
    if not isinstance(stored_group, dist_group_type):
        return
    raise TypeError(
        f"{type(self).__name__} cannot be used as a torch.compile value "
        "object: 'amax_reduction_group' holds a torch.distributed.ProcessGroup, "
        "which is live distributed state and must not be baked into an FX "
        "graph. Pass the amax reduction group per quantize call instead of "
        "storing it on the quantizer."
    )
|
|
||
def _value_key(self) -> Tuple[Any, ...]:
    """Build the hashable, reproducible key that identifies this quantizer's value.

    Only valid for value quantizers (``_value_fields()`` is not ``None``).

    Returns:
        ``(qualname, items)`` where ``qualname`` is the class ``__qualname__``
        and ``items`` is a tuple of ``(attribute name, value)`` pairs: the
        ``_BASE_VALUE_FIELDS`` first, then the names from ``_value_fields()``.

    Raises:
        TypeError: From ``_check_value_has_no_process_group`` when a process
            group is stored on the quantizer.
    """
    subclass_fields = self._value_fields()  # pylint: disable=assignment-from-none
    assert subclass_fields is not None, f"{type(self).__name__} is not a value quantizer"
    self._check_value_has_no_process_group()

    def _plain_value(field_name: str) -> Any:
        raw = getattr(self, field_name)
        # ``DType`` is an ``IntEnum``; keep only the int so the key stays
        # plain: hashable and ``repr``-reproducible for FX codegen.
        return int(raw) if field_name == "dtype" else raw

    all_names = self._BASE_VALUE_FIELDS + tuple(subclass_fields)
    pairs = tuple((field_name, _plain_value(field_name)) for field_name in all_names)
    return (type(self).__qualname__, pairs)
|
|
||
def __eq__(self, other: object) -> Any:
    """Compare value quantizers by configuration, everything else by identity.

    Returns ``True`` for the same object, ``NotImplemented`` when either
    side is not a value quantizer or the two types differ (Python then falls
    back to identity comparison), and otherwise the result of comparing the
    two value keys. ``_value_key`` rejects a stored ProcessGroup, so that
    error can surface from here.
    """
    if other is self:
        return True
    if self._value_fields() is None:
        return NotImplemented
    if type(other) is not type(self):
        return NotImplemented
    if other._value_fields() is None:
        return NotImplemented
    own_key = self._value_key()
    their_key = other._value_key()
    return own_key == their_key
|
greptile-apps[bot] marked this conversation as resolved.
|
||
|
|
||
def __hash__(self) -> int:
    """Hash value quantizers by their value key, all others by identity."""
    if self._value_fields() is not None:
        return hash(self._value_key())
    # Not a value quantizer: keep the default identity-based hash.
    return object.__hash__(self)
|
|
||
|
|
||
| class QuantizedTensor(torch.Tensor): | ||
| """Abstract base class for tensor with quantized data | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand the purpose of this test (and the logic needed to make it work). Is it here to make sure that if somebody accidentally adds the group to the quantizer then it would loudly fail e.g. in the CI?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was a suggestion from @kshitij12345
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Context around the test - There was a possible code-path previously (now fixed) where having a ProcessGroup on the quantizer wouldn't trigger the unsupported TypeError (and fail later). This test verifies that the error would be raised with a clear message.