[PyTorch][torch.compile] Add TensorProto mechanism #8
Open
pggPL wants to merge 18 commits into
Open
Conversation
…location Introduce TensorProto: a data-free prototype of a (possibly quantized) tensor that captures shape/dtype and a value-opaque quantizer, and can materialize the tensor purely in Python (Quantizer.alloc_tensors + storage __tensor_unflatten__), so it traces under torch.compile with no graph break. - Add the PyTorch wrapper-subclass flatten protocol (__tensor_flatten__ / __tensor_unflatten__) to QuantizedTensorStorage, driven by a per-class _FLATTEN_TENSOR_BUFFERS declaration, plus a storage-class registry for qualname-based reconstruction in FX graphs. - Add pure-Python, traceable allocation to Quantizer (alloc_tensors, create_metadata, _describe_buffers, _storage_scalars, _resolve_storage_cls) implemented for FP8 current-scaling, MXFP8 and FP8 blockwise quantizers. - Add allocation-free fake forward/backward for BasicLinear operating on TensorProto, as a basis for torch.library custom-op fake impls. - Add CPU smoke tests and torch.compile fullgraph tests. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The allocation-free TensorProto fakes belong next to the real _linear_forward_impl / _linear_backward in module/linear.py rather than in the ops BasicLinear. Add _linear_forward_impl_fake and _linear_backward_impl_fake that take the LinearFwdArgs/LinearBwdArgs bags (with TensorProto in the tensor fields) and return protos mirroring the real impls' output/saved-tensor and (wgrad, dgrad, grad_bias) contracts. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The forward fake now follows _linear_forward_impl's control flow and returns the full (out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs) tuple: it models the quantize_weight workspace (returned only when a fresh FP8 workspace is created and cached), the (saved_inputmat, weightmat, saved_weight, bias) save layout with alias dedup, and the save_original_input / FSDP2 branches. The backward fake quantizes wgrad with grad_weight_quantizer to match the real wgrad GEMM. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is a value-opaque (tensorless) quantizer but was missing the pure-Python allocation hooks. Add _describe_buffers / _storage_scalars / _resolve_storage_cls for it (FP4 data packed 2-per-byte as uint8, E4M3 scales as uint8, FP32 amax -- per-row when row_scaled_nvfp4, else scalar) and declare NVFP4TensorStorage._FLATTEN_TENSOR_BUFFERS (rowwise/columnwise data + scale_inv + amax). Also rework test_tensor_proto_flatten_roundtrip_compiles: instead of calling __tensor_flatten__ / __tensor_unflatten__ directly inside a compiled fn (unsupported by Dynamo, so the test was red for every quantizer), pass the quantized tensor across the compile boundary so Dynamo/AOTAutograd runs the protocol itself. Cover NVFP4 in the smoke and compile tests, guarded by CUDA since NVFP4Quantizer builds its RHT matrix on the current device. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Mirror _linear_forward_impl/_linear_backward more closely in the fake impls: set_usage on input/weight/output/grad-input quantizers, input/weight pipeline classification with name-based saved-tensor aliasing, honor fuse_wgrad_accumulation and use_bias+requires_wgrad in the backward fake, and reject manual TE FSDP (fsdp_group) with a clear NotImplementedError. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Merge the dynamo/ package split and carve the TensorProto layer into its own module. This branch owns the data-free tensor description used by the custom-op fake impls. * dynamo/tensor_proto.py -- TensorProto, to_tensor_proto, _contiguous_stride * dynamo/quantizer_opaque.py -- inherited unchanged from make_qunatizers_opaque * dynamo/__init__.py -- re-exports TensorProto / to_tensor_proto too Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The base Quantizer.create_metadata keeps shape in its signature for API symmetry with callers, but the default implementation does not use it. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The forward/backward fake implementations rejected tensor-parallel configs outright. TP/SP communication is opaque to torch.compile (it runs inside the eager custom op), so the only shape effect to model is on the output's leading dim: column+sequence-parallel all-gathers the input (leading *= tp_size) and row+sequence-parallel reduce-scatters the output (leading //= tp_size). The saved input and gradients stay rank-local, so the backward fake needs no extra modeling. Also widen LinearFwdArgs.inp to TensorOrQuantized so quantized inputs flatten correctly. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The fake forward (`_linear_forward_impl_fake`) decides how many FP8 inner buffers it emits from `args.input_requires_grad` / `args.weight_requires_grad` (captured at op-call time). The real `_linear_forward_impl` was instead reading the live `inp.requires_grad` / `weight.requires_grad`, so the two could disagree on the custom-op output arity. Under `torch.compile` with CUDA-graph trees (`mode="reduce-overhead"`) the static graph inputs are detached during capture, flipping the live flags to False mid-capture. The weight then drops its column-wise (`_transpose`) buffer, the op returns fewer tensors than the fake declared, every following slot shifts, and `assert_size_stride` fails. Reading the captured `args.*` flags keeps real and fake in lockstep; eager behavior is unchanged (same value captured). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
TensorProto now carries a rowwise/columnwise usage snapshot (taken from the quantizer at construction) and an update_usage() that mirrors QuantizedTensor.update_usage. Buffer description/allocation apply the usage to a quantizer copy, so the shared (value-opaque) quantizer is never mutated and the proto stays self-consistent. The Linear fake forward now calls saved_inputmat.update_usage(...) instead of pre-mirroring set_usage on the quantizer. Also move the eager (CPU) TensorProto smoke tests from test_tensor_proto.py into test_torch_compile.py (already in the L0 QA suite) and add an update_usage test that checks buffer re-description plus quantizer isolation. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…a; rework proto tests
Consolidate the two opt-in hooks _resolve_storage_cls + _storage_scalars
into a single _storage_metadata returning {cls, scalars}; create_metadata
reads from it. Restructure the TensorProto tests around the quantizer
primitives (create_metadata / alloc_tensors / flatten-unflatten) and
TensorProto.create_tensor (eager, fake and compiled).
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…sor_proto_mechanism Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> # Conflicts: # tests/pytorch/test_torch_compile.py # transformer_engine/pytorch/dynamo/__init__.py
Check the storage flatten/unflatten round-trip by value via dequantize() instead of only shape/dtype, and gate it per quantizer format using the is_*_available() flags (mirrors test_numerics), since dequantize runs the real CUDA kernel and each format has its own hardware requirement. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…ge snapshot Make TensorProto a plain (non-frozen) dataclass that copies the quantizer in __post_init__, so update_usage mutates the private copy directly. Drops the rowwise/columnwise usage fields and the lazy _usage_adjusted_quantizer; the shared value-opaque quantizer is still never mutated. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Restore coverage for to_tensor_proto: a plain tensor maps to a non-quantized proto, and a quantized tensor round-trips back into a proto with matching shape/dtype/buffer layout. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Invert the `if not backward_needs_input: pass / elif ...` antipattern into a plain `if backward_needs_input:` guard. Behavior is unchanged. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The bucket holds all non-tensor constructor kwargs (scalars, dtype, quantizer), so "scalars" was misleading. Rename the key and the _flatten_scalar_ctx helper across quantized_tensor and all storage _storage_metadata implementations, and clarify the TensorProto primitives section header. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR introduces
TensorProto— a data-free prototype of a tensor (or quantized tensor) that captures everything needed to reason about and rebuild a tensor without holding any storage: its logicalshape/dtypeand, for quantized tensors, the value-opaquequantizerdefining the layout.The key property is that
TensorProto.create_tensor()materializes a quantized tensor purely in Python (viaQuantizer.alloc_tensors+ the storage's__tensor_unflatten__), so it traces undertorch.compile(fullgraph=True)with no graph break — unlikemake_empty, which goes through the opaque C++tex.create_empty_quantized_tensor. This is the foundation for writingtorch.librarycustom-op fake implementations of quantized ops.This builds on the value-opaque quantizer work (so a
TensorProtois itself safe to treat as a compile-time constant).Type of change
Changes
dynamo.py: AddTensorProtodataclass (shape,dtype,quantizer,requires_grad,device) withis_quantized,inner_names(),create_metadata()andcreate_tensor(), plus ato_tensor_proto()helper that builds a proto from a plaintorch.Tensoror aQuantizedTensorStorage/QuantizedTensor.quantized_tensor.py:__tensor_flatten__/__tensor_unflatten__) toQuantizedTensorStorage, driven by a per-class_FLATTEN_TENSOR_BUFFERSdeclaration of(attribute_name, constructor_kwarg)pairs._STORAGE_REGISTRY(populated via__init_subclass__) so__tensor_unflatten__can resolve a concrete storage/wrapper class from its qualname inside an FX graph.Quantizer:alloc_tensors,create_metadata, and the opt-in overrides_describe_buffers,_storage_scalars,_resolve_storage_cls.Float8CurrentScalingQuantizer,MXFP8QuantizerandFloat8BlockQuantizer._FLATTEN_TENSOR_BUFFERSforFloat8TensorStorage,MXFP8TensorStorageandFloat8BlockwiseQTensorStorage.ops/basic/basic_linear.py: Add allocation-free_functional_forward_fake/_functional_backward_fakethat operate onTensorProtoand return output/gradient protos, as a basis for custom-op fake impls (single-device only; TP/SP shape effects not yet modeled).tests/pytorch/test_tensor_proto.py(CPU smoke tests for_describe_buffers/alloc_tensors/create_metadata, flatten round-trip, andto_tensor_proto) andtorch.compilefullgraph tests intest_torch_compile.py.Checklist: