Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f3401df
[PyTorch] Make tensorless quantizers opaque value objects for torch.c…
pggPL Jun 6, 2026
c4ad54c
[PyTorch] Drop quantizer value registry; reconstruct via __fx_repr__ …
pggPL Jun 6, 2026
a06324b
[PyTorch] Split dynamo.py into a dynamo/ package
pggPL Jun 7, 2026
ea5b396
[PyTorch] Raise in quantizer __fx_repr__ when a process group is stored
pggPL Jun 8, 2026
aa65e34
[PyTorch] Cover NVFP4 in quantizer value-object test
pggPL Jun 8, 2026
e1b1db6
Reject a value quantizer that carries an amax reduction group in __eq…
pggPL Jun 16, 2026
8c33d0e
Recognize value-opaque quantizers via a class flag
pggPL Jun 16, 2026
945f62d
Address review: narrow opaque-type except, add fullgraph test, fix nv…
pggPL Jun 29, 2026
e3c8f43
Restore NVFP4 rht_matrix on value-key rebuild; assert quantize round-…
pggPL Jun 29, 2026
3f68621
Enforce process-group rejection in _value_key, not __fx_repr__; add test
pggPL Jun 29, 2026
32d1768
Strengthen fullgraph test: quantize/dequantize via a custom op, not p…
pggPL Jun 29, 2026
28bde9e
Clarify comments: rht_matrix_random_sign_mask_t derivation; why the o…
pggPL Jun 29, 2026
2c3c5df
Reword opaque-flag comment: self-contained, no Linear reference
pggPL Jun 29, 2026
826f271
Cover is_opaque_value_type with the import-safety guard too
pggPL Jun 29, 2026
613c545
Stamp value-opaque flag only after successful registration
pggPL Jun 30, 2026
9db604f
Drop verbose comments around value-opaque flag stamping
pggPL Jun 30, 2026
3011dfd
Narrow value process-group check to amax_reduction_group
pggPL Jun 30, 2026
fe5e5db
Shorten amax_reduction_group check comment
pggPL Jun 30, 2026
f6b6d78
Merge remote-tracking branch 'upstream/main' into make_qunatizers_opaque
pggPL Jun 30, 2026
6f66c3e
Drop trivial value-equality boilerplate from quantizer test
pggPL Jun 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 170 additions & 1 deletion tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@
from transformer_engine.common import recipe
from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
Float8Quantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
)
from utils import recipe_id

Expand Down Expand Up @@ -384,3 +388,168 @@ def fn(inp):

out = compiled(inp)
out.sum().backward()


# ---------------------------------------------------------------------------
# Value-opaque quantizers
# ---------------------------------------------------------------------------


def _mxfp8(dtype=tex.DType.kFloat8E4M3):
return MXFP8Quantizer(fp8_dtype=dtype)


def _blockwise(force_pow_2_scales=True):
return Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
force_pow_2_scales=force_pow_2_scales,
)


def _current_scaling(amax_epsilon=0.0):
return Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=torch.device("cpu"),
amax_epsilon=amax_epsilon,
)


def _nvfp4(with_rht=True):
# Default with_rht=True so the quantize round-trip below exercises the
# derived ``rht_matrix`` tensor (the field most likely to be dropped on
# value-key reconstruction).
return NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_rht=with_rht,
)


def _hw_available(quantizer):
"""Whether this HW can actually run the quantize kernel for *quantizer*."""
if isinstance(quantizer, MXFP8Quantizer):
return mxfp8_available
if isinstance(quantizer, NVFP4Quantizer):
return nvfp4_available
if isinstance(quantizer, Float8BlockQuantizer):
return fp8_block_scaling_available
return fp8_available # Float8CurrentScalingQuantizer


# (factory, kwargs producing a different-but-valid config)
_VALUE_QUANTIZERS = [
pytest.param(_mxfp8, id="mxfp8"),
pytest.param(_blockwise, id="float8_blockwise"),
pytest.param(_current_scaling, id="float8_current_scaling"),
pytest.param(
_nvfp4,
id="nvfp4",
marks=pytest.mark.skipif(
not torch.cuda.is_available(),
reason="NVFP4Quantizer requires CUDA to construct",
),
),
]


@pytest.mark.parametrize("factory", _VALUE_QUANTIZERS)
def test_quantizer_value_object(factory):
"""Value semantics + ``__fx_repr__`` round-trip via the production FX path."""
a = factory()

# ``__fx_repr__`` (used by torch.compile codegen) rebuilds an equal object.
repr_str, globals_ = a.__fx_repr__()
rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used
assert rebuilt == a and rebuilt is not a
assert hash(rebuilt) == hash(a)

# The rebuilt quantizer must also *behave* identically, not just compare
# equal: equality only looks at the value key, so a field the kernel needs
# but that is absent from the key (e.g. NVFP4's derived ``rht_matrix``) would
# slip through the checks above and only blow up at quantize time. Run the
# real quantize kernel on both and require bit-exact results.
if torch.cuda.is_available() and _hw_available(a):
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
torch.testing.assert_close(rebuilt(x).dequantize(), a(x).dequantize(), rtol=0.0, atol=0.0)


def test_value_quantizer_rejects_process_group():
"""A value quantizer holding a live ProcessGroup must refuse to be turned
into a value key / FX constant (raise), not silently drop the group."""
import torch.distributed as dist # pylint: disable=import-outside-toplevel

created = not dist.is_initialized()
if created:
dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
try:
q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
q.amax_reduction_group = dist.group.WORLD
# Every value-materialization path must reject it (hash, eq, __fx_repr__).
with pytest.raises(TypeError):
hash(q)
with pytest.raises(TypeError):
q.__fx_repr__()
finally:
if created:
dist.destroy_process_group()
Comment on lines +479 to +497

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand the purpose of this test (and the logic needed to make it work). Is it here to make sure that if somebody accidentally adds the group to the quantizer then it would loudly fail e.g. in the CI?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was suggestion of @kshitij12345

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context around the test - There was a possible code-path previously (now fixed) where having a ProcessGroup on the quantizer wouldn't trigger the unsupported TypeError (and fail later). This test verifies that the error would be raised with a clear message.



if _opaque_available:
# A minimal custom op taking a tensor and a value-opaque quantizer that
# quantizes + dequantizes inside it, one per production quantizer class.
# ``test_quantizer_value_object_fullgraph`` drives this under
# ``torch.compile(fullgraph=True)`` so the quantizer is used *inside* the
# graph -- proving the opaque-type registration took effect (a graph break
# would make ``fullgraph=True`` raise).
_qdq_lib = torch.library.Library("test_te_qdq", "DEF")
_QDQ_OPS = {}
for _qcls in (
MXFP8Quantizer,
Float8BlockQuantizer,
Float8CurrentScalingQuantizer,
NVFP4Quantizer,
):
_op = f"qdq_{_qcls.__name__}"
_qdq_lib.define(f"{_op}(Tensor x, {get_opaque_type_name(_qcls)} q) -> Tensor")

@torch.library.impl(f"test_te_qdq::{_op}", "CompositeExplicitAutograd", lib=_qdq_lib)
def _qdq_impl(x, q):
return q(x).dequantize()

@torch.library.register_fake(f"test_te_qdq::{_op}", lib=_qdq_lib)
def _qdq_fake(x, q):
return torch.empty_like(x)

_QDQ_OPS[_qcls] = getattr(torch.ops.test_te_qdq, _op)


@pytest.mark.skipif(
not _opaque_available,
reason="torch.compile opaque-object support requires PyTorch >= 2.11",
)
@pytest.mark.parametrize("factory", _VALUE_QUANTIZERS)
def test_quantizer_value_object_fullgraph(factory):
"""Quantizer is usable *inside* a torch.compile(fullgraph=True) graph.

A custom op quantizes+dequantizes with the (opaque value) quantizer; the
compiled result must match eager. ``fullgraph=True`` raises on any graph
break, so this proves the opaque-type registration actually took effect --
unlike merely passing the quantizer through.
"""
q = factory()
if not (torch.cuda.is_available() and _hw_available(q)):
pytest.skip("format not supported on this HW")

op = _QDQ_OPS[type(q)]
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

def fn(inp):
return op(inp, q)

ref = fn(x)
torch._dynamo.reset()
out = torch.compile(fn, fullgraph=True)(x)
torch.testing.assert_close(out, ref, rtol=0.0, atol=0.0)
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""torch.compile glue for Transformer Engine."""

from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer

__all__ = [
"register_value_opaque_quantizer",
"is_value_opaque_quantizer",
]
114 changes: 114 additions & 0 deletions transformer_engine/pytorch/dynamo/quantizer_opaque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Value-opaque quantizers for torch.compile."""

from __future__ import annotations
from typing import Any, Dict, Tuple

from ..constants import DType


# Registration marks the class with this attribute rather than recording it in a
# module-level set. It looks odd but is a deliberate workaround: the check must
# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a
# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate
# ``type(q) in some_set`` (no equality/hash rules for the opaque class object),
# which would graph-break under ``fullgraph=True``.
_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque"


def is_value_opaque_quantizer(quantizer: Any) -> bool:
"""Whether *quantizer*'s class is registered as a torch.compile value-opaque
type."""
return getattr(quantizer, _VALUE_OPAQUE_FLAG, False)


def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any:
"""Rebuild a tensorless quantizer of type *cls* from its value items.

Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the
generated FX code calls this to materialize the quantizer constant.
"""
# Bypass ``__init__`` and restore the value attributes directly: the value
# items already capture every value-defining field (including derived ones),
# and the constructors have heterogeneous signatures / side effects.
Comment on lines +34 to +36

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh? What are the issues with just calling the constructor?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many things have different names in constructor than as parameters or are not present as args (like ".internal" or ".optimize_for_gemm")

obj = cls.__new__(cls)
field_names = set()
for name, value in items:
if name == "dtype":
value = DType.cast(value)
object.__setattr__(obj, name, value)
field_names.add(name)
# The deprecated amax-reduction group is not a value field; initialize it to
# None so attribute access keeps working on the rebuilt quantizer.
if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"):
object.__setattr__(obj, "amax_reduction_group", None)
# Restore non-value derived state that ``__init__`` would normally build but
# that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor).
finalize = getattr(obj, "_rebuild_derived_state", None)
if finalize is not None:
finalize()
return obj


def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]:
"""``__fx_repr__`` for value-opaque quantizers (attached at registration).

Returns an evaluable expression that rebuilds the quantizer via
:func:`_rebuild_quantizer`, capturing both the helper and the quantizer
class itself in the FX globals so codegen can resolve them with no global
registry and no qualname collisions.

Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer
stores a process group (e.g. a non-``None`` deprecated
``amax_reduction_group``): live distributed state must never be baked into
the graph as a constant. Pass the reduction group per quantize call instead
of storing it on the quantizer.
"""
cls = type(self)
items = self._value_key()[1]
return (
f"_rebuild_quantizer({cls.__name__}, {items!r})",
{"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls},
)


def register_value_opaque_quantizer(cls: type) -> None:
"""Register a tensorless quantizer class as a torch.compile value opaque type.

Attaches ``__fx_repr__`` and registers the class with
``torch._library.opaque_object``. Safe to call on any PyTorch build: on
versions without the opaque-object API it only attaches ``__fx_repr__``
(harmless), so Transformer Engine keeps importing and running in eager mode.

The quantizer class must already provide value ``__eq__`` / ``__hash__`` and
a non-``None`` ``_value_fields`` (see
:class:`transformer_engine.pytorch.quantized_tensor.Quantizer`).
"""
# ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the
# class, so attach it before registering.
if "__fx_repr__" not in cls.__dict__:
cls.__fx_repr__ = _quantizer_fx_repr
Comment thread
pggPL marked this conversation as resolved.
Comment on lines +90 to +93

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then why aren'twe just adding that to the Quantizer class? I feel that I miss some context here as a lot of this file feels like a workaround to some issues that are not described.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do this, I wanted to have as much of torch.compile specific logic here


try:
from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel
register_opaque_type,
is_opaque_value_type,
)
except (ImportError, AttributeError):
# Older PyTorch without the opaque-object API: eager value semantics
# still work; torch.compile specialization on the quantizer does not.
return

try:
if not is_opaque_value_type(cls):
register_opaque_type(cls, typ="value")
except (RuntimeError, TypeError):
# Keep TE importable: neither the opaque-type query nor the registration
# must crash the import, e.g. on PyTorch versions with only partial /
# experimental opaque-object support.
return

setattr(cls, _VALUE_OPAQUE_FLAG, True)
68 changes: 68 additions & 0 deletions transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.constants import dist_group_type
from transformer_engine.pytorch.tensor._quantization_helpers import (
_QuantizeFunc,
_IdentityFunc,
Expand Down Expand Up @@ -408,6 +409,73 @@ def get_usages(self) -> Dict[str, bool]:
"columnwise": self.columnwise_usage,
}

#: Attributes shared by every quantizer that take part in value identity.
_BASE_VALUE_FIELDS: Tuple[str, ...] = (
"rowwise_usage",
"columnwise_usage",
"internal",
"optimize_for_gemm",
)

def _value_fields(self) -> Optional[Tuple[str, ...]]:
"""Subclass-specific value-defining attribute names, or ``None``.

Returning ``None`` (the default) means the quantizer cannot be represented as
a value opaque object and keeps identity-based equality/hashing.
This also means that passing such a quantizer as an argument to a custom op
causes a graph break under torch.compile, since it cannot be baked into the
FX graph as a constant.
"""
return None

def _check_value_has_no_process_group(self) -> None:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often is this method called? Only once during torch.compile or does it need to be called every time during the graph launch to check the guards?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

every time - I make it faster by just checking amax_reduction_group

# A value quantizer cannot carry live distributed state into the FX
# graph; reject a stored ``amax_reduction_group`` and pass it per
# quantize call instead.
if isinstance(getattr(self, "amax_reduction_group", None), dist_group_type):
raise TypeError(
f"{type(self).__name__} cannot be used as a torch.compile value "
"object: 'amax_reduction_group' holds a torch.distributed.ProcessGroup, "
"which is live distributed state and must not be baked into an FX "
"graph. Pass the amax reduction group per quantize call instead of "
"storing it on the quantizer."
)

def _value_key(self) -> Tuple[Any, ...]:
"""Hashable, reproducible key identifying this quantizer's value.

Only valid for value quantizers (``_value_fields()`` is not ``None``).
"""
fields = self._value_fields() # pylint: disable=assignment-from-none
assert fields is not None, f"{type(self).__name__} is not a value quantizer"
self._check_value_has_no_process_group()
items = []
for name in self._BASE_VALUE_FIELDS + tuple(fields):
value = getattr(self, name)
if name == "dtype":
# ``DType`` is an ``IntEnum``; store the int so the key stays
# plain: hashable and ``repr``-reproducible for FX codegen.
value = int(value)
items.append((name, value))
return (type(self).__qualname__, tuple(items))

def __eq__(self, other: object) -> Any:
# Value quantizers compare by configuration; everything else keeps the
# default identity semantics (returning ``NotImplemented`` makes Python
# fall back to identity). ``_value_key`` rejects a stored ProcessGroup.
if self is other:
return True
if self._value_fields() is None or type(self) is not type(other):
return NotImplemented
if other._value_fields() is None:
return NotImplemented
return self._value_key() == other._value_key()
Comment thread
greptile-apps[bot] marked this conversation as resolved.

def __hash__(self) -> int:
if self._value_fields() is None:
return object.__hash__(self)
return hash(self._value_key())


class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
Expand Down
Loading
Loading