[PyTorch][torch.compile] Add TensorProto mechanism #8

Open
pggPL wants to merge 18 commits into make_qunatizers_opaque from tensor_proto_mechanism

@pggPL pggPL commented Jun 6, 2026

Owner

Description

This PR introduces TensorProto — a data-free prototype of a tensor (or quantized tensor) that captures everything needed to reason about and rebuild a tensor without holding any storage: its logical shape/dtype and, for quantized tensors, the value-opaque quantizer defining the layout.

The key property is that TensorProto.create_tensor() materializes a quantized tensor purely in Python (via Quantizer.alloc_tensors + the storage's __tensor_unflatten__), so it traces under torch.compile(fullgraph=True) with no graph break — unlike make_empty, which goes through the opaque C++ tex.create_empty_quantized_tensor. This is the foundation for writing torch.library custom-op fake implementations of quantized ops.

This builds on the value-opaque quantizer work (so a TensorProto is itself safe to treat as a compile-time constant).
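
To make the idea of a data-free prototype concrete, here is a minimal toy sketch in plain Python. It is not the PR's implementation: `ToyQuantizer`, `ToyTensorProto`, `describe_buffers`, the buffer names, and the column-wise shape convention are all assumptions chosen for illustration. It shows only that a proto can answer "which inner buffers would exist, and with what shapes?" without allocating any storage.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class ToyQuantizer:
    """Hypothetical stand-in for a value-opaque quantizer: layout only, no tensors."""

    rowwise: bool = True
    columnwise: bool = False

    def describe_buffers(self, shape: Tuple[int, ...]) -> Dict[str, tuple]:
        # Map buffer name -> (shape, dtype name). Nothing is allocated here.
        buffers: Dict[str, tuple] = {}
        if self.rowwise:
            buffers["rowwise_data"] = (shape, "uint8")
        if self.columnwise:
            # Toy convention (an assumption): last dim moved to the front.
            buffers["columnwise_data"] = ((shape[-1],) + shape[:-1], "uint8")
        buffers["scale_inv"] = ((1,), "float32")
        return buffers


@dataclass
class ToyTensorProto:
    """Hypothetical mirror of the fields listed for TensorProto."""

    shape: Tuple[int, ...]
    dtype: str
    quantizer: Optional[ToyQuantizer] = None
    requires_grad: bool = False
    device: str = "cpu"

    @property
    def is_quantized(self) -> bool:
        return self.quantizer is not None

    def inner_names(self) -> List[str]:
        # A plain tensor has no inner buffers; a quantized one lists its layout.
        if self.quantizer is None:
            return []
        return list(self.quantizer.describe_buffers(self.shape))


plain = ToyTensorProto(shape=(4, 8), dtype="bfloat16")
quantized = ToyTensorProto(
    shape=(4, 8), dtype="bfloat16", quantizer=ToyQuantizer(columnwise=True)
)
```

In the real mechanism, the step that turns such a description into an actual tensor is `create_tensor()`, which per the description above goes through `Quantizer.alloc_tensors` and `__tensor_unflatten__`; the toy stops before that step.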

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • dynamo.py: Add TensorProto dataclass (shape, dtype, quantizer, requires_grad, device) with is_quantized, inner_names(), create_metadata() and create_tensor(), plus a to_tensor_proto() helper that builds a proto from a plain torch.Tensor or a QuantizedTensorStorage/QuantizedTensor.
  • quantized_tensor.py:
    • Add the PyTorch wrapper-subclass flatten protocol (__tensor_flatten__ / __tensor_unflatten__) to QuantizedTensorStorage, driven by a per-class _FLATTEN_TENSOR_BUFFERS declaration of (attribute_name, constructor_kwarg) pairs.
    • Add a _STORAGE_REGISTRY (populated via __init_subclass__) so __tensor_unflatten__ can resolve a concrete storage/wrapper class from its qualname inside an FX graph.
    • Add pure-Python, traceable allocation hooks to Quantizer: alloc_tensors, create_metadata, and the opt-in overrides _describe_buffers, _storage_scalars, _resolve_storage_cls.
  • Quantizers: Implement the allocation hooks for Float8CurrentScalingQuantizer, MXFP8Quantizer and Float8BlockQuantizer.
  • Storage classes: Declare _FLATTEN_TENSOR_BUFFERS for Float8TensorStorage, MXFP8TensorStorage and Float8BlockwiseQTensorStorage.
  • ops/basic/basic_linear.py: Add allocation-free _functional_forward_fake / _functional_backward_fake that operate on TensorProto and return output/gradient protos, as a basis for custom-op fake impls (single-device only; TP/SP shape effects not yet modeled).
  • Tests: Add tests/pytorch/test_tensor_proto.py (CPU smoke tests for _describe_buffers/alloc_tensors/create_metadata, flatten round-trip, and to_tensor_proto) and torch.compile fullgraph tests in test_torch_compile.py.
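
The flatten protocol and the qualname registry described in the quantized_tensor.py bullets can be sketched with plain Python objects. This is a toy model under stated assumptions: `ToyStorage`, `ToyFloat8Storage`, the attribute names, and the `{"cls": ...}` context layout are hypothetical, and lists stand in for tensors. Only the method names `__tensor_flatten__` / `__tensor_unflatten__` and their general signatures follow PyTorch's wrapper-subclass protocol.

```python
class ToyStorage:
    """Toy base class: lists stand in for tensors."""

    # qualname -> class, filled automatically as subclasses are defined
    _STORAGE_REGISTRY = {}
    # (attribute_name, constructor_kwarg) pairs, declared per subclass
    _FLATTEN_TENSOR_BUFFERS = ()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        ToyStorage._STORAGE_REGISTRY[cls.__qualname__] = cls

    def __tensor_flatten__(self):
        # Return the names of the inner buffers that are present, plus a
        # context that is enough to rebuild the outer object later.
        names = [
            attr
            for attr, _ in self._FLATTEN_TENSOR_BUFFERS
            if getattr(self, attr) is not None
        ]
        ctx = {"cls": type(self).__qualname__}
        return names, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size=None, outer_stride=None):
        # Resolve the concrete class from its qualname, then map each
        # attribute back to the constructor keyword it was declared with.
        cls = ToyStorage._STORAGE_REGISTRY[ctx["cls"]]
        kwargs = {
            kwarg: inner_tensors.get(attr)
            for attr, kwarg in cls._FLATTEN_TENSOR_BUFFERS
        }
        return cls(**kwargs)


class ToyFloat8Storage(ToyStorage):
    _FLATTEN_TENSOR_BUFFERS = (
        ("_data", "data"),
        ("_scale_inv", "scale_inv"),
        ("_transpose", "data_transpose"),
    )

    def __init__(self, data=None, scale_inv=None, data_transpose=None):
        self._data = data
        self._scale_inv = scale_inv
        self._transpose = data_transpose


original = ToyFloat8Storage(data=[1, 2, 3], scale_inv=[0.5])
names, ctx = original.__tensor_flatten__()
inner = {name: getattr(original, name) for name in names}
rebuilt = ToyStorage.__tensor_unflatten__(inner, ctx)
```

Storing the qualname rather than the class object keeps the context a plain string, which is why a registry lookup is needed on the way back.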

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 18 commits June 6, 2026 16:50
…location

Introduce TensorProto: a data-free prototype of a (possibly quantized)
tensor that captures shape/dtype and a value-opaque quantizer, and can
materialize the tensor purely in Python (Quantizer.alloc_tensors +
storage __tensor_unflatten__), so it traces under torch.compile with no
graph break.

- Add the PyTorch wrapper-subclass flatten protocol
  (__tensor_flatten__ / __tensor_unflatten__) to QuantizedTensorStorage,
  driven by a per-class _FLATTEN_TENSOR_BUFFERS declaration, plus a
  storage-class registry for qualname-based reconstruction in FX graphs.
- Add pure-Python, traceable allocation to Quantizer (alloc_tensors,
  create_metadata, _describe_buffers, _storage_scalars,
  _resolve_storage_cls) implemented for FP8 current-scaling, MXFP8 and
  FP8 blockwise quantizers.
- Add allocation-free fake forward/backward for BasicLinear operating on
  TensorProto, as a basis for torch.library custom-op fake impls.
- Add CPU smoke tests and torch.compile fullgraph tests.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The allocation-free TensorProto fakes belong next to the real
_linear_forward_impl / _linear_backward in module/linear.py rather than
in the ops BasicLinear. Add _linear_forward_impl_fake and
_linear_backward_impl_fake that take the LinearFwdArgs/LinearBwdArgs bags
(with TensorProto in the tensor fields) and return protos mirroring the
real impls' output/saved-tensor and (wgrad, dgrad, grad_bias) contracts.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The forward fake now follows _linear_forward_impl's control flow and
returns the full (out, new_weight_workspace, tensors_to_save_from_forward,
None, ctx_attrs) tuple: it models the quantize_weight workspace (returned
only when a fresh FP8 workspace is created and cached), the
(saved_inputmat, weightmat, saved_weight, bias) save layout with alias
dedup, and the save_original_input / FSDP2 branches. The backward fake
quantizes wgrad with grad_weight_quantizer to match the real wgrad GEMM.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is a value-opaque (tensorless) quantizer but was missing
the pure-Python allocation hooks. Add _describe_buffers / _storage_scalars
/ _resolve_storage_cls for it (FP4 data packed 2-per-byte as uint8, E4M3
scales as uint8, FP32 amax -- per-row when row_scaled_nvfp4, else scalar)
and declare NVFP4TensorStorage._FLATTEN_TENSOR_BUFFERS (rowwise/columnwise
data + scale_inv + amax).
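
As a rough illustration of the packing arithmetic mentioned above, the sketch below computes the row-wise data shape for two FP4 values per uint8 byte and the amax shape for the per-row and scalar cases. The helper names are hypothetical, the scalar amax is assumed to have shape `(1,)`, and "per-row" is assumed to mean one entry per flattened leading index; the scale buffer shape is deliberately left out because its padding rules are not stated here.

```python
from math import prod
from typing import Tuple


def packed_fp4_data_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of a uint8 buffer holding two 4-bit values per byte (toy helper)."""
    if shape[-1] % 2 != 0:
        raise ValueError("last dimension must be even to pack 2 values per byte")
    return shape[:-1] + (shape[-1] // 2,)


def amax_shape(shape: Tuple[int, ...], row_scaled: bool) -> Tuple[int, ...]:
    """FP32 amax: one entry per row when row_scaled, else a single scalar slot.

    Both conventions are assumptions made for this sketch.
    """
    if row_scaled:
        return (prod(shape[:-1]),)
    return (1,)


data_shape = packed_fp4_data_shape((128, 64))
per_row = amax_shape((128, 64), row_scaled=True)
scalar = amax_shape((128, 64), row_scaled=False)
```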

Also rework test_tensor_proto_flatten_roundtrip_compiles: instead of
calling __tensor_flatten__ / __tensor_unflatten__ directly inside a
compiled fn (unsupported by Dynamo, so the test was red for every
quantizer), pass the quantized tensor across the compile boundary so
Dynamo/AOTAutograd runs the protocol itself. Cover NVFP4 in the smoke and
compile tests, guarded by CUDA since NVFP4Quantizer builds its RHT matrix
on the current device.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Mirror _linear_forward_impl/_linear_backward more closely in the fake impls:
set_usage on input/weight/output/grad-input quantizers, input/weight pipeline
classification with name-based saved-tensor aliasing, honor
fuse_wgrad_accumulation and use_bias+requires_wgrad in the backward fake, and
reject manual TE FSDP (fsdp_group) with a clear NotImplementedError.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Merge the dynamo/ package split and carve the TensorProto layer into its own
module. This branch owns the data-free tensor description used by the
custom-op fake impls.

  * dynamo/tensor_proto.py -- TensorProto, to_tensor_proto, _contiguous_stride
  * dynamo/quantizer_opaque.py -- inherited unchanged from make_qunatizers_opaque
  * dynamo/__init__.py -- re-exports TensorProto / to_tensor_proto too

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The base Quantizer.create_metadata keeps shape in its signature for API
symmetry with callers, but the default implementation does not use it.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The forward/backward fake implementations rejected tensor-parallel
configs outright. TP/SP communication is opaque to torch.compile (it
runs inside the eager custom op), so the only shape effect to model is
on the output's leading dim: column+sequence-parallel all-gathers the
input (leading *= tp_size) and row+sequence-parallel reduce-scatters the
output (leading //= tp_size). The saved input and gradients stay
rank-local, so the backward fake needs no extra modeling. Also widen
LinearFwdArgs.inp to TensorOrQuantized so quantized inputs flatten
correctly.
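
The leading-dimension rule described above can be written as a small pure function. This is a sketch, not the code in module/linear.py: the function name and the `"column"` / `"row"` mode strings are assumptions, and only the two arithmetic rules (`leading *= tp_size`, `leading //= tp_size`) come from the commit message.

```python
from typing import Optional


def fake_output_leading_dim(
    leading: int,
    parallel_mode: Optional[str],
    sequence_parallel: bool,
    tp_size: int,
) -> int:
    """Model the only shape effect TP/SP has on the fake output (toy helper)."""
    if not sequence_parallel or tp_size == 1:
        return leading
    if parallel_mode == "column":
        # The input is all-gathered along the leading dim before the GEMM.
        return leading * tp_size
    if parallel_mode == "row":
        # The output is reduce-scattered along the leading dim after the GEMM.
        if leading % tp_size != 0:
            raise ValueError("leading dim must be divisible by tp_size")
        return leading // tp_size
    return leading


column_sp = fake_output_leading_dim(512, "column", True, 4)
row_sp = fake_output_leading_dim(2048, "row", True, 4)
no_sp = fake_output_leading_dim(512, "column", False, 4)
```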

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The fake forward (`_linear_forward_impl_fake`) decides how many FP8 inner
buffers it emits from `args.input_requires_grad` / `args.weight_requires_grad`
(captured at op-call time). The real `_linear_forward_impl` was instead reading
the live `inp.requires_grad` / `weight.requires_grad`, so the two could disagree
on the custom-op output arity.

Under `torch.compile` with CUDA-graph trees (`mode="reduce-overhead"`) the static
graph inputs are detached during capture, flipping the live flags to False
mid-capture. The weight then drops its column-wise (`_transpose`) buffer, the op
returns fewer tensors than the fake declared, every following slot shifts, and
`assert_size_stride` fails. Reading the captured `args.*` flags keeps real and
fake in lockstep; eager behavior is unchanged (same value captured).
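
The arity mismatch can be reproduced with a toy model. The helper below and the `_data` / `_scale_inv` buffer names are illustrative assumptions; only the `_transpose` buffer and the "captured flag versus live flag" distinction come from the commit message.

```python
def weight_inner_buffers(needs_columnwise: bool) -> list:
    """Toy model: buffers a quantized weight contributes to the op's outputs."""
    names = ["_data", "_scale_inv"]
    if needs_columnwise:
        names.append("_transpose")  # column-wise copy, kept only if needed
    return names


captured_flag = True  # recorded in args at op-call time
live_flag = False     # what a detached static input reports mid-capture

# The fake declares outputs from the captured flag.
fake_outputs = weight_inner_buffers(captured_flag)

# Before the fix the real impl read the live flag, so it returned fewer tensors.
real_outputs_before = weight_inner_buffers(live_flag)

# After the fix both sides read the captured flag and agree.
real_outputs_after = weight_inner_buffers(captured_flag)
```

Because custom-op outputs are positional, one missing tensor shifts every later slot, which is consistent with the failure surfacing in a size/stride assertion rather than at the point where the buffer was dropped.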

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
TensorProto now carries a rowwise/columnwise usage snapshot (taken from the
quantizer at construction) and an update_usage() that mirrors
QuantizedTensor.update_usage. Buffer description/allocation apply the usage to a
quantizer copy, so the shared (value-opaque) quantizer is never mutated and the
proto stays self-consistent. The Linear fake forward now calls
saved_inputmat.update_usage(...) instead of pre-mirroring set_usage on the
quantizer.

Also move the eager (CPU) TensorProto smoke tests from test_tensor_proto.py into
test_torch_compile.py (already in the L0 QA suite) and add an update_usage test
that checks buffer re-description plus quantizer isolation.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…a; rework proto tests

Consolidate the two opt-in hooks _resolve_storage_cls + _storage_scalars
into a single _storage_metadata returning {cls, scalars}; create_metadata
reads from it. Restructure the TensorProto tests around the quantizer
primitives (create_metadata / alloc_tensors / flatten-unflatten) and
TensorProto.create_tensor (eager, fake and compiled).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…sor_proto_mechanism

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

# Conflicts:
#	tests/pytorch/test_torch_compile.py
#	transformer_engine/pytorch/dynamo/__init__.py
Check the storage flatten/unflatten round-trip by value via dequantize()
instead of only shape/dtype, and gate it per quantizer format using the
is_*_available() flags (mirrors test_numerics), since dequantize runs the
real CUDA kernel and each format has its own hardware requirement.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…ge snapshot

Make TensorProto a plain (non-frozen) dataclass that copies the quantizer in
__post_init__, so update_usage mutates the private copy directly. Drops the
rowwise/columnwise usage fields and the lazy _usage_adjusted_quantizer; the
shared value-opaque quantizer is still never mutated.
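
A minimal sketch of the "copy in `__post_init__`, mutate the private copy" pattern follows. `SketchQuantizer`, `SketchProto`, and the usage field names are hypothetical stand-ins, and a shallow `copy.copy` is an assumption; the real class may copy the quantizer differently.

```python
import copy
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class SketchQuantizer:
    """Hypothetical quantizer with only the usage flags that matter here."""

    rowwise_usage: bool = True
    columnwise_usage: bool = True


@dataclass
class SketchProto:
    """Plain (non-frozen) dataclass that owns a private quantizer copy."""

    shape: Tuple[int, ...]
    quantizer: Optional[SketchQuantizer] = None

    def __post_init__(self):
        # Take a private copy so later usage changes never leak to the caller.
        if self.quantizer is not None:
            self.quantizer = copy.copy(self.quantizer)

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        # None means "leave this flag unchanged".
        if self.quantizer is None:
            return
        if rowwise_usage is not None:
            self.quantizer.rowwise_usage = rowwise_usage
        if columnwise_usage is not None:
            self.quantizer.columnwise_usage = columnwise_usage


shared = SketchQuantizer()
proto = SketchProto(shape=(16, 32), quantizer=shared)
proto.update_usage(columnwise_usage=False)
```

Copying once at construction removes the need for separate usage fields on the proto, since the private quantizer itself carries the current usage.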

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Restore coverage for to_tensor_proto: a plain tensor maps to a non-quantized
proto, and a quantized tensor round-trips back into a proto with matching
shape/dtype/buffer layout.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Invert the `if not backward_needs_input: pass / elif ...` antipattern into a
plain `if backward_needs_input:` guard. Behavior is unchanged.
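
The shape of this refactor can be shown on a toy function. The branch bodies and the `input_is_quantized` condition are hypothetical placeholders, since the commit message elides the real `elif` condition; only the `backward_needs_input` guard is taken from it.

```python
def choose_saved_input_before(backward_needs_input: bool, input_is_quantized: bool):
    """Original shape: an empty branch followed by the real work in elif/else."""
    saved = None
    if not backward_needs_input:
        pass
    elif input_is_quantized:
        saved = "quantized_input"
    else:
        saved = "plain_input"
    return saved


def choose_saved_input_after(backward_needs_input: bool, input_is_quantized: bool):
    """Refactored shape: a single positive guard around the same work."""
    saved = None
    if backward_needs_input:
        saved = "quantized_input" if input_is_quantized else "plain_input"
    return saved


# Exhaustively compare both versions over every flag combination.
all_cases = [(a, b) for a in (False, True) for b in (False, True)]
results_match = all(
    choose_saved_input_before(a, b) == choose_saved_input_after(a, b)
    for a, b in all_cases
)
```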

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The bucket holds all non-tensor constructor kwargs (scalars, dtype, quantizer),
so "scalars" was misleading. Rename the key and the _flatten_scalar_ctx helper
across quantized_tensor and all storage _storage_metadata implementations, and
clarify the TensorProto primitives section header.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>