[PyTorch][torch.compile] Make quantizers opaque value objects#7
Open
pggPL wants to merge 5 commits into
Open
Conversation
…ompile Give tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4) value-object semantics so torch.compile can treat them as baked-in constants: - Add opt-in value identity to the base Quantizer (_value_fields / _value_key / __eq__ / __hash__). Quantizers holding live tensors (delayed-scaling Float8Quantizer) and custom quantizers keep identity semantics. - New transformer_engine/pytorch/dynamo.py houses the torch.compile glue: __fx_repr__, value-key reconstruction and register_value_opaque_quantizer (gracefully a no-op on PyTorch builds without the opaque-object API). - Register the four tensorless quantizers as value opaque types. Also fix CustomRecipe state caching in TransformerEngineBaseModule: set_meta_tensor now rebuilds quantizers when the CustomRecipe instance changes (e.g. nested te.autocast regions) instead of reusing the first recipe's state, since every CustomRecipe shares the CustomRecipeState type but carries its own qfactory. Move the quantizer value-object tests into tests/pytorch/test_torch_compile.py and add that file to the L0 pytorch unittest QA suite. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…globals Follow-up to the value-opaque quantizer support: - Remove the module-level _QUANTIZER_VALUE_REGISTRY (qualname -> class) and _quantizer_from_value_key. __fx_repr__ now captures the quantizer class directly in the FX globals and reconstructs via _rebuild_quantizer(cls, items), matching how PyTorch's own value opaque types (e.g. DTensor placements) reconstruct themselves. This removes global mutable state and the qualname collision risk. - Consolidate the quantizer value-object tests in test_torch_compile.py down to two functions and exercise reconstruction through the public __fx_repr__ path instead of internal helpers. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Replace the single dynamo.py module with a dynamo/ package so the
torch.compile glue can grow with a clear responsibility split across the
stacked branches. This branch owns the value-opaque quantizer layer.
* dynamo/quantizer_opaque.py -- register_value_opaque_quantizer and helpers
* dynamo/__init__.py -- re-exports the public API so callers keep importing
from transformer_engine.pytorch.dynamo unchanged
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
A value-opaque quantizer must not carry live distributed state. Scan the quantizer attributes in __fx_repr__ and raise TypeError if any holds a torch.distributed.ProcessGroup (e.g. a non-None deprecated amax_reduction_group), so it cannot be silently baked into a torch.compile FX graph. Clarify the related comments accordingly. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is registered as a value-opaque quantizer but was missing from the value-semantics / __fx_repr__ round-trip test. Add it to _VALUE_QUANTIZERS (skipped without CUDA, which it needs to construct). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
kshitij12345
approved these changes
Jun 9, 2026
| repr_str, globals_ = a.__fx_repr__() | ||
| rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used | ||
| assert rebuilt == a and rebuilt is not a | ||
| assert hash(rebuilt) == hash(a) |
There was a problem hiding this comment.
It would be good to also test that torch.compile(fullgraph=True) + quantizer to verify that registration actually worked and won't be broken.
def fn(quantizer):
return quantizer
torch.compile(fn, fullgraph=True)(some_quantizer)|
|
||
| try: | ||
| register_opaque_type(cls, typ="value") | ||
| except (ImportError, AttributeError, RuntimeError, TypeError): |
There was a problem hiding this comment.
I don't think we should catch ImportError and AttributeError here.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Tensorless quantizers in TE (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4)
are fully described by a handful of plain, reproducible scalars — they hold no
live tensors and no process groups. This PR turns them into opaque value
objects so
torch.compilecan treat them as baked-in constants: twoquantizers with the same configuration become interchangeable, hashable, and
reconstructible inside an FX graph.
Quantizers that hold live state (delayed-scaling
Float8Quantizer, which keepsscale/amaxtensors) and any user-defined quantizer keep the defaultidentity semantics, so the change is opt-in and backward compatible. On older
PyTorch builds without the opaque-object API the registration is a graceful
no-op.
Along the way this also un-breaks the existing
test_torch_compile.pysuite:that file lived on
mainbut was never wired into CI, and itstest_autocast_nested_customcase (nestedte.autocastwith multipleCustomRecipeinstances) was failing because of theCustomRecipestate-cachingbug fixed here. The file is now run in CI and passes.
Type of change
Changes
Quantizer(
_value_fields/_value_key/__eq__/__hash__). ReturningNonefrom
_value_fields()(the default) keeps identity semantics.transformer_engine/pytorch/dynamo.pyholding thetorch.compileglue:__fx_repr__, value-key reconstruction andregister_value_opaque_quantizer(gracefully no-op without PyTorch'sopaque-object API).
MXFP8Quantizer,Float8BlockQuantizer,Float8CurrentScalingQuantizerandNVFP4Quantizeras value opaque types(the deprecated
amax_reduction_groupis never part of the value).CustomRecipestate caching inTransformerEngineBaseModule.set_meta_tensor:rebuild quantizers when the
CustomRecipeinstance changes (e.g. nestedte.autocastregions) instead of reusing the first recipe's state, sinceevery
CustomRecipeshares theCustomRecipeStatetype but carries its ownqfactory. This fixes the previously-failingtest_autocast_nested_custom.tests/pytorch/test_torch_compile.pyin theL0_pytorch_unittestQAsuite (it existed on
mainbut was never run in CI), and add the quantizervalue-object tests to it. Bringing it into CI required fixing the existing
CustomRecipetorch.compile path: theqfactorynow dispatches onQuantizerRole.tensor_typesupplied byToyLinear.get_quantizer_roles.Checklist: