[PyTorch][torch.compile] Make quantizers opaque value objects#3152
pggPL wants to merge 20 commits into
…ompile
Give tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4) value-object semantics so torch.compile can treat them as baked-in constants:
- Add opt-in value identity to the base Quantizer (_value_fields / _value_key / __eq__ / __hash__). Quantizers holding live tensors (delayed-scaling Float8Quantizer) and custom quantizers keep identity semantics.
- New transformer_engine/pytorch/dynamo.py houses the torch.compile glue: __fx_repr__, value-key reconstruction and register_value_opaque_quantizer (gracefully a no-op on PyTorch builds without the opaque-object API).
- Register the four tensorless quantizers as value opaque types.
Also fix CustomRecipe state caching in TransformerEngineBaseModule: set_meta_tensor now rebuilds quantizers when the CustomRecipe instance changes (e.g. nested te.autocast regions) instead of reusing the first recipe's state, since every CustomRecipe shares the CustomRecipeState type but carries its own qfactory.
Move the quantizer value-object tests into tests/pytorch/test_torch_compile.py and add that file to the L0 pytorch unittest QA suite.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…globals
Follow-up to the value-opaque quantizer support:
- Remove the module-level _QUANTIZER_VALUE_REGISTRY (qualname -> class) and _quantizer_from_value_key. __fx_repr__ now captures the quantizer class directly in the FX globals and reconstructs via _rebuild_quantizer(cls, items), matching how PyTorch's own value opaque types (e.g. DTensor placements) reconstruct themselves. This removes global mutable state and the qualname collision risk.
- Consolidate the quantizer value-object tests in test_torch_compile.py down to two functions and exercise reconstruction through the public __fx_repr__ path instead of internal helpers.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
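The reconstruction path this commit describes can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyQuantizer` and its fields are hypothetical, and the return shape of `__fx_repr__` (a source string plus a dict of globals) is inferred from the commit message and the sequence diagram on this page rather than taken from the PR's code.

```python
from typing import Any, Dict, Tuple

Items = Tuple[Tuple[str, Any], ...]


def _rebuild_quantizer(cls: type, items: Items) -> Any:
    """Rebuild an instance from its value items without calling __init__."""
    obj = cls.__new__(cls)
    for name, value in items:
        setattr(obj, name, value)
    return obj


class ToyQuantizer:
    """Hypothetical tensorless quantizer (not a Transformer Engine class)."""

    def __init__(self, fp8_dtype: str, block_len: int = 32) -> None:
        self.dtype = fp8_dtype      # constructor argument and field names differ
        self.block_len = block_len
        self.internal = False       # field that is not a constructor argument

    def _value_key(self) -> Items:
        return tuple((n, getattr(self, n)) for n in ("dtype", "block_len", "internal"))

    def __fx_repr__(self) -> Tuple[str, Dict[str, Any]]:
        # The class object itself is placed in the globals, so no
        # qualname -> class registry is needed to find it again.
        cls = type(self)
        source = f"_rebuild_quantizer({cls.__name__}, {self._value_key()!r})"
        return source, {"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls}


original = ToyQuantizer("e4m3", block_len=64)
original.internal = True
source, fx_globals = original.__fx_repr__()
# Roughly what generated graph code would do with the string and its globals.
rebuilt = eval(source, dict(fx_globals))
```

Because the class travels in the globals dict, two classes that happen to share a qualified name cannot shadow each other, which is the collision risk the commit mentions.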
Replace the single dynamo.py module with a dynamo/ package so the
torch.compile glue can grow with a clear responsibility split across the
stacked branches. This branch owns the value-opaque quantizer layer.
* dynamo/quantizer_opaque.py -- register_value_opaque_quantizer and helpers
* dynamo/__init__.py -- re-exports the public API so callers keep importing
from transformer_engine.pytorch.dynamo unchanged
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
A value-opaque quantizer must not carry live distributed state. Scan the quantizer attributes in __fx_repr__ and raise TypeError if any holds a torch.distributed.ProcessGroup (e.g. a non-None deprecated amax_reduction_group), so it cannot be silently baked into a torch.compile FX graph. Clarify the related comments accordingly. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is registered as a value-opaque quantizer but was missing from the value-semantics / __fx_repr__ round-trip test. Add it to _VALUE_QUANTIZERS (skipped without CUDA, which it needs to construct). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…__/__hash__ The amax reduction group is excluded from the value key, so a value quantizer that stored one would compare/hash equal to a groupless one and let torch.compile reuse a graph that skips the reduction. __eq__/__hash__ now raise (mirroring __fx_repr__, which already rejects any process-group-bearing quantizer). The group should be passed per quantize call, not stored on the quantizer. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Add is_value_opaque_quantizer() + the _te_compile_value_opaque flag stamped at registration, so dynamo-traced code can detect registered quantizers (and fall back to eager for unregistered ones). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…fp4 value key
- Narrow register_opaque_type except to (RuntimeError, TypeError): the API is already imported above, so ImportError/AttributeError there only mask real errors.
- Add test_quantizer_value_object_fullgraph exercising torch.compile(fullgraph=True) end-to-end to verify opaque-type registration took effect.
- Restore missing NVFP4Quantizer._with_random_sign_mask assignment required by _value_fields()/_value_key().
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
Greptile Summary
This PR turns stateless TE quantizers (
Confidence Score: 4/5
Safe to merge after fixing the
The transformer_engine/pytorch/tensor/nvfp4_tensor.py - the
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant User
participant Compile as torch.compile
participant Q as Quantizer
participant RQ as _rebuild_quantizer
participant NQ as NVFP4Quantizer
User->>Compile: "compile(fn, fullgraph=True)"
Compile->>Q: __fx_repr__()
Q->>Q: _value_key() checks no process group
Q-->>Compile: repr string + globals dict
Note over Compile: bake repr into FX graph constant
Compile->>RQ: cls.__new__(cls) then set fields from items
RQ->>NQ: _rebuild_derived_state()
NQ->>NQ: get_rht_matrix(_with_random_sign_mask, device)
RQ-->>Compile: reconstructed quantizer
Compile-->>User: compiled fn with quantizer baked in
…trip _rebuild_quantizer only restores value-key fields, so a reconstructed NVFP4Quantizer was missing the derived rht_matrix tensor (not hashable, so not in the value key) and failed at copy()/quantize time. Add a _rebuild_derived_state hook (called by _rebuild_quantizer) that NVFP4Quantizer uses to rebuild rht_matrix from _with_random_sign_mask (lru_cache -> cheap). Extend test_quantizer_value_object to also quantize with the original and the rebuilt quantizer and require bit-exact results (gated on HW support), so a field the kernel needs but the value key omits can no longer slip through. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
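The derived-state hook from this commit can be illustrated without torch. The sketch below is an assumption-laden stand-in: `build_sign_matrix` plays the role of a cached builder such as `get_rht_matrix`, and `ToyRhtQuantizer` and `rebuild` are hypothetical names, not the PR's implementation.

```python
from functools import lru_cache
from typing import Any, Tuple

Items = Tuple[Tuple[str, Any], ...]


@lru_cache(maxsize=None)
def build_sign_matrix(with_random_sign_mask: bool) -> Tuple[Tuple[int, ...], ...]:
    """Hypothetical stand-in for a costly derived object, cached per argument."""
    sign = -1 if with_random_sign_mask else 1
    return tuple(tuple(sign if r == c else 0 for c in range(4)) for r in range(4))


class ToyBase:
    def _rebuild_derived_state(self) -> None:
        """Default hook: nothing to rebuild."""


class ToyRhtQuantizer(ToyBase):
    def __init__(self, with_random_sign_mask: bool) -> None:
        self._with_random_sign_mask = with_random_sign_mask
        # Derived from the flag; deliberately not part of the value key.
        self.matrix = build_sign_matrix(with_random_sign_mask)

    def _value_key(self) -> Items:
        return (("_with_random_sign_mask", self._with_random_sign_mask),)

    def _rebuild_derived_state(self) -> None:
        self.matrix = build_sign_matrix(self._with_random_sign_mask)


def rebuild(cls: type, items: Items) -> Any:
    obj = cls.__new__(cls)
    for name, value in items:
        setattr(obj, name, value)
    # Without this call the rebuilt object would have no `matrix` attribute.
    obj._rebuild_derived_state()
    return obj


q = ToyRhtQuantizer(True)
q2 = rebuild(ToyRhtQuantizer, q._value_key())
```

The second call to the builder is a cache hit, so the rebuilt object shares the same derived object as the original instead of recomputing it.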
kshitij12345
left a comment
Overall LGTM, would be good to resolve the inline comments before merging.
Move the ProcessGroup guard out of the (overridable) __fx_repr__ into Quantizer._value_key -- the single point every value-materialization path (__eq__/__hash__/__fx_repr__) goes through -- so a custom __fx_repr__ can no longer bypass it. Generalizes the old amax-only check to any field holding a ProcessGroup. Add a test that a value quantizer carrying a live group raises. Addresses review on NVIDIA#3152. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…assthrough Replace the trivial pass-through fullgraph test with one that drives each production quantizer through a minimal custom op (quantize + dequantize) under torch.compile(fullgraph=True) and compares to eager -- so the opaque-type registration is actually exercised inside the graph (a graph break would make fullgraph=True raise). Op registration sits right before the test. Also drop stale comments referencing the old __fx_repr__-side process-group guard. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…paque flag
- rht_matrix_random_sign_mask_t is a device-independent int derived from _with_random_sign_mask (the device only places a throwaway tensor); fix the misleading comment.
- Explain why registration uses a class attribute, not a registry set: is_value_opaque_quantizer is traced inside the compile graph and dynamo can bake a getattr constant but cannot do 'type(q) in set' on the opaque class.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
is_opaque_value_type(cls) sat between the import guard and the register_opaque_type guard, so on a partial/experimental opaque-object build it could raise RuntimeError/TypeError and crash TE import. Move it inside the same except so the 'registration never crashes import' promise holds for both calls. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch
    a, b = factory(), factory()
    # Same config -> equal, same hash, interchangeable as a dict/set key.
    assert a is not b
    assert a == b
    assert hash(a) == hash(b)
    assert {a: "x"}[b] == "x"
    # Different config -> not equal.
    assert a != factory(**other_kwargs)
What exactly are we testing in this part?
Nothing. I can remove it; it seems to be Claude boilerplate.
def test_value_quantizer_rejects_process_group():
    """A value quantizer holding a live ProcessGroup must refuse to be turned
    into a value key / FX constant (raise), not silently drop the group."""
    import torch.distributed as dist  # pylint: disable=import-outside-toplevel

    created = not dist.is_initialized()
    if created:
        dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
    try:
        q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        q.amax_reduction_group = dist.group.WORLD
        # Every value-materialization path must reject it (hash, eq, __fx_repr__).
        with pytest.raises(TypeError):
            hash(q)
        with pytest.raises(TypeError):
            q.__fx_repr__()
    finally:
        if created:
            dist.destroy_process_group()
I'm not sure I understand the purpose of this test (and the logic needed to make it work). Is it here to make sure that if somebody accidentally adds the group to the quantizer then it would loudly fail e.g. in the CI?
Context for the test: there was previously a code path (now fixed) where a ProcessGroup on the quantizer would not trigger the unsupported-type TypeError and would instead fail later. This test verifies that the error is raised with a clear message.
    # Bypass ``__init__`` and restore the value attributes directly: the value
    # items already capture every value-defining field (including derived ones),
    # and the constructors have heterogeneous signatures / side effects.
Huh? What are the issues with just calling the constructor?
Many fields have different names in the constructor than as attributes, or are not constructor arguments at all (like ".internal" or ".optimize_for_gemm").
    # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the
    # class, so attach it before registering.
    if "__fx_repr__" not in cls.__dict__:
        cls.__fx_repr__ = _quantizer_fx_repr
Then why aren't we just adding that to the Quantizer class? I feel I am missing some context here, as a lot of this file feels like a workaround for issues that are not described.
We can do this. I wanted to keep as much of the torch.compile-specific logic here as possible.
        """
        return None

    def _check_value_has_no_process_group(self) -> None:
How often is this method called? Only once during torch.compile or does it need to be called every time during the graph launch to check the guards?
Every time. I made it faster by checking only amax_reduction_group.
        return quantizer

    def _value_fields(self) -> Tuple[str, ...]:
        return ("dtype", "block_len", "amax_epsilon", "force_pow_2_scales", "block_scaling_dim")
There are some ways of getting that list without hardcoding it again, e.g. using inspect.get_annotations in the cases like we have where we provide all of the data members up front or via vars.
I see. It's quite messy, but that is a result of the mess in the quantizers. Some fields are set from constructor arguments, and the argument names do not always match the field names. Some fields, like .internal or .block_len, are not constructor arguments at all. And some are excluded on purpose, like amax_reduction_group or rht_matrix.
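The trade-off discussed in this thread can be made concrete with a toy class. Everything below is hypothetical (`ToyBlockQuantizer` and its fields are invented for illustration); it only shows why deriving the field list from the constructor signature or from `vars()` would not match a hand-written `_value_fields`.

```python
import inspect
from typing import Optional, Tuple


class ToyBlockQuantizer:
    """Hypothetical quantizer mimicking the naming mismatches described above."""

    def __init__(self, fp8_dtype: str, group: Optional[object] = None) -> None:
        self.dtype = fp8_dtype              # argument name differs from field name
        self.block_len = 128                # field that is not a constructor argument
        self.amax_reduction_group = group   # excluded from the value on purpose

    def _value_fields(self) -> Tuple[str, ...]:
        return ("dtype", "block_len")


q = ToyBlockQuantizer("e4m3")
explicit = set(q._value_fields())

# Option 1: constructor parameters. Misses block_len and uses the wrong name for dtype.
ctor_args = set(inspect.signature(ToyBlockQuantizer.__init__).parameters) - {"self"}

# Option 2: vars(). Picks up every attribute, including the excluded group.
from_vars = set(vars(q))
extra_from_vars = from_vars - explicit
```

In this toy setup neither automatic source reproduces the explicit tuple, so an automatic approach would still need an exclusion list and a renaming table.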
        ``_rebuild_quantizer`` calls this hook to rebuild it; the ``lru_cache`` on
        :func:`get_rht_matrix` makes an already-seen (flag, device) a cheap hit.
        """
        self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())
Using current device here is a little annoying since in principle it does not have to be the right device (which instead should be the input device).
That's not a problem introduced by this PR; we already use the current device in the constructor.
Move the _VALUE_OPAQUE_FLAG setattr to the end of register_value_opaque_quantizer, after register_opaque_type succeeds (or the type is already opaque). Previously the flag was set up front, so is_value_opaque_quantizer reported True even when the opaque-object API was missing or registration raised, since both paths are swallowed. Eager value semantics (__eq__/__hash__/__fx_repr__) are independent of the flag, so this only tightens the predicate to mean torch actually knows the type as opaque. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
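The ordering this commit describes (stamp the flag only after registration succeeds) can be sketched as follows. The flag name comes from the commit messages on this page; `register_value_opaque`, `is_value_opaque`, and the `register_fn` parameter are hypothetical stand-ins for the real registration call, which is not reproduced here.

```python
from typing import Callable, Optional

_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque"


def register_value_opaque(cls: type, register_fn: Optional[Callable[[type], None]]) -> None:
    """Sketch: `register_fn` stands in for the framework call; None means no such API."""
    if register_fn is None:
        return  # older build: graceful no-op, flag stays unset
    try:
        register_fn(cls)
    except (RuntimeError, TypeError):
        return  # swallowed so that importing the package never crashes
    # Reached only when registration actually succeeded.
    setattr(cls, _VALUE_OPAQUE_FLAG, True)


def is_value_opaque(obj: object) -> bool:
    # A plain class-attribute lookup, as opposed to a membership test in a registry set.
    return getattr(type(obj), _VALUE_OPAQUE_FLAG, False)


class Registered:
    pass


class Failed:
    pass


class NoApi:
    pass


def failing_register(cls: type) -> None:
    raise RuntimeError("simulated registration failure")


register_value_opaque(Registered, lambda cls: None)
register_value_opaque(Failed, failing_register)
register_value_opaque(NoApi, None)
```

With the flag set last, the predicate is true only for the class whose registration went through, which matches the tightened meaning described in the commit.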
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
_check_value_has_no_process_group ran on every guard eval (via __eq__/__hash__) and scanned all of vars(self) recursively. The only attribute that can hold a ProcessGroup is the deprecated amax_reduction_group, so check it directly (O(1)) and drop the _contains_process_group helper. Same guarantee, off the hot path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
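A minimal sketch of the O(1) guard, assuming the single-attribute check described in this commit. `FakeProcessGroup` replaces `torch.distributed.ProcessGroup` so the example runs without torch, and `ToyValueQuantizer` is hypothetical.

```python
from typing import Any, Optional, Tuple


class FakeProcessGroup:
    """Stand-in for a live process group (assumption: no torch needed here)."""


class ToyValueQuantizer:
    def __init__(self, dtype: str, amax_reduction_group: Optional[Any] = None) -> None:
        self.dtype = dtype
        self.amax_reduction_group = amax_reduction_group

    def _check_value_has_no_process_group(self) -> None:
        # One attribute lookup instead of a recursive scan of vars(self).
        if self.amax_reduction_group is not None:
            raise TypeError("value quantizer must not store a process group")

    def _value_key(self) -> Tuple[Tuple[str, Any], ...]:
        # Every value-materialization path goes through here, so the guard
        # cannot be bypassed by overriding __eq__, __hash__ or __fx_repr__ alone.
        self._check_value_has_no_process_group()
        return (("dtype", self.dtype),)

    def __eq__(self, other: object) -> bool:
        if type(other) is not type(self):
            return NotImplemented
        return self._value_key() == other._value_key()

    def __hash__(self) -> int:
        return hash((type(self), self._value_key()))


plain_a = ToyValueQuantizer("e4m3")
plain_b = ToyValueQuantizer("e4m3")
grouped = ToyValueQuantizer("e4m3", amax_reduction_group=FakeProcessGroup())
```

Here `plain_a == plain_b` holds, while hashing or comparing `grouped` raises `TypeError` instead of silently treating it as equal to a groupless quantizer.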
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
# Conflicts:
#   transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
#   transformer_engine/pytorch/tensor/float8_tensor.py
#   transformer_engine/pytorch/tensor/mxfp8_tensor.py
#   transformer_engine/pytorch/tensor/nvfp4_tensor.py
Remove the a==b / hash / dict-key block that just exercised Python's own dict semantics; equality and hashing are still covered by the __fx_repr__ round-trip (rebuilt == a, hash match) and the bit-exact kernel check. other_kwargs is now unused, so drop it from the parametrization and both test signatures. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description
Tensorless quantizers in TE (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4)
are fully described by a handful of plain, reproducible scalars; they hold no
live tensors and no process groups. This PR turns them into opaque value
objects so torch.compile can treat them as baked-in constants: two
quantizers with the same configuration become interchangeable, hashable, and
reconstructible inside an FX graph.

Quantizers that hold live state (delayed-scaling Float8Quantizer, which keeps
scale/amax tensors) and any user-defined quantizer keep the default
identity semantics, so the change is opt-in and backward compatible. On older
PyTorch builds without the opaque-object API the registration is a graceful
no-op.

Along the way this also un-breaks the existing test_torch_compile.py suite:
that file lived on main but was never wired into CI, and its
test_autocast_nested_custom case (nested te.autocast with multiple
CustomRecipe instances) was failing because of the CustomRecipe state-caching
bug fixed here. The file is now run in CI and passes.
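The opt-in nature of the value semantics described above can be sketched in a few lines. This is an assumption-based illustration: the method names follow the description, but `ToyQuantizerBase`, `ToyStateless`, and `ToyStateful` are hypothetical and the bodies are not the PR's code.

```python
from typing import Any, Optional, Tuple


class ToyQuantizerBase:
    def _value_fields(self) -> Optional[Tuple[str, ...]]:
        return None  # default: keep identity semantics

    def _value_key(self) -> Optional[Tuple[Tuple[str, Any], ...]]:
        fields = self._value_fields()
        if fields is None:
            return None
        return tuple((name, getattr(self, name)) for name in fields)

    def __eq__(self, other: object) -> bool:
        key = self._value_key()
        if key is None or type(other) is not type(self):
            return self is other
        return key == other._value_key()

    def __hash__(self) -> int:
        key = self._value_key()
        if key is None:
            return object.__hash__(self)  # identity-based hash
        return hash((type(self), key))


class ToyStateless(ToyQuantizerBase):
    """Opts in: fully described by plain scalars."""

    def __init__(self, dtype: str, block_len: int) -> None:
        self.dtype = dtype
        self.block_len = block_len

    def _value_fields(self) -> Tuple[str, ...]:
        return ("dtype", "block_len")


class ToyStateful(ToyQuantizerBase):
    """Does not opt in: imagine it owns live scale/amax buffers."""

    def __init__(self) -> None:
        self.scale = [1.0]


a, b = ToyStateless("e4m3", 32), ToyStateless("e4m3", 32)
c, d = ToyStateful(), ToyStateful()
```

Two `ToyStateless` instances with the same configuration compare equal and can stand in for each other as dict keys, while two `ToyStateful` instances stay distinct even though their contents match.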
Type of change
Changes
- Add opt-in value identity to the base Quantizer (_value_fields / _value_key / __eq__ / __hash__). Returning None from _value_fields() (the default) keeps identity semantics.
- New transformer_engine/pytorch/dynamo.py holding the torch.compile glue: __fx_repr__, value-key reconstruction and register_value_opaque_quantizer (gracefully a no-op without PyTorch's opaque-object API).
- Register MXFP8Quantizer, Float8BlockQuantizer, Float8CurrentScalingQuantizer and NVFP4Quantizer as value opaque types (the deprecated amax_reduction_group is never part of the value).
- Fix CustomRecipe state caching in TransformerEngineBaseModule.set_meta_tensor: rebuild quantizers when the CustomRecipe instance changes (e.g. nested te.autocast regions) instead of reusing the first recipe's state, since every CustomRecipe shares the CustomRecipeState type but carries its own qfactory. This fixes the previously-failing test_autocast_nested_custom.
- Run tests/pytorch/test_torch_compile.py in the L0_pytorch_unittest QA suite (it existed on main but was never run in CI), and add the quantizer value-object tests to it. Bringing it into CI required fixing the existing CustomRecipe torch.compile path: the qfactory now dispatches on QuantizerRole.tensor_type supplied by ToyLinear.get_quantizer_roles.
- __fx_repr__ already rejects any quantizer holding a process group, and __eq__/__hash__ now raise too. The group is excluded from the value key, so a stored group would otherwise compare/hash equal to a groupless quantizer and let torch.compile reuse a graph that skips the reduction. Pass the group per quantize call instead.
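The CustomRecipe state-caching fix listed above comes down to keying the cached state on the recipe instance rather than on the state type. The sketch below uses hypothetical classes (`ToyCustomRecipe`, `ToyRecipeState`, `ToyModule`) and assumes an identity comparison; the actual check in `set_meta_tensor` may differ.

```python
from typing import Callable, List, Optional


class ToyCustomRecipe:
    def __init__(self, qfactory: Callable[[], List[str]]) -> None:
        self.qfactory = qfactory  # each recipe instance carries its own factory


class ToyRecipeState:
    """All custom recipes share this one state type."""

    def __init__(self, recipe: ToyCustomRecipe) -> None:
        self.recipe = recipe
        self.quantizers = recipe.qfactory()


class ToyModule:
    def __init__(self) -> None:
        self._state: Optional[ToyRecipeState] = None

    def set_meta_tensor(self, recipe: ToyCustomRecipe) -> None:
        # Checking only isinstance(self._state, ToyRecipeState) would reuse the
        # first recipe's quantizers for every later recipe.
        if self._state is not None and self._state.recipe is recipe:
            return
        self._state = ToyRecipeState(recipe)


outer = ToyCustomRecipe(lambda: ["outer-quantizer"])
inner = ToyCustomRecipe(lambda: ["inner-quantizer"])
module = ToyModule()

module.set_meta_tensor(outer)
state_outer = module._state
module.set_meta_tensor(outer)       # same instance: cached state is kept
kept = module._state is state_outer
module.set_meta_tensor(inner)       # nested region with a different recipe
inner_quantizers = module._state.quantizers
```

Switching from `outer` to `inner` rebuilds the quantizers, which is the behaviour nested autocast regions with different recipes need.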
Checklist: