[PyTorch] Add torch.compile custom-op path for Linear #9

Open
pggPL wants to merge 25 commits into tensor_proto_mechanism from linear_compile

pggPL (Owner) commented Jun 7, 2026

Build a torch.compile custom-op framework in dynamo.py that traces Linear forward+backward as single graph nodes (no graph break into the eager autograd.Function):

  • OpaqueSimpleMetadata bundle (with nested-dict support) and per-field buckets mapping the LinearFwd/BwdArgs dataclasses to op schema slots; quantizers ride through as value-opaque objects (own slot each).
  • Fakes stay TensorProto -> TensorProto; the framework converts tensor inputs to TensorProto at the boundary (via dataclasses.replace, which Dynamo can trace) and materializes output protos into fake tensors in register_fake. Saved-tensor aliases resolved by name from ctx_attrs.
  • Two-tier op + register_torch_dispatch flattens Float8Tensor weight inputs; quantized outputs are rebuilt via _ToSubclassFn (autograd-aware).
  • Mirror _linear_forward_impl set_usage / input+weight pipeline in the forward fake so register_fake layout matches eager.
  • Replace LinearBwdArgs.fp8_recipe with precomputed split-accumulator bools (recipe object is not compile-safe across the op boundary).
  • Dispatch through the op under torch.compiler.is_compiling(); drop @no_torch_dynamo from Linear.forward.

Tests: test_te_linear_compiles (bf16 + every recipe), quantized FP8 weight input. Backward through a Float8Tensor output is a strict xfail (AOTAutograd demands a subclass cotangent and the linear backward has no FP8-cotangent path).
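
The boundary conversion described above (tensors become TensorProto via dataclasses.replace, the fake stays proto-in/proto-out, and output protos are materialized afterwards) can be sketched without PyTorch. This is a minimal illustration, not the code in dynamo.py: `ToyTensor`, `FwdArgsSketch`, `to_protos`, `fake_forward`, `materialize`, and the two-field `TensorProto` are hypothetical names, and the weight is assumed to be laid out as [out_features, in_features].

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class TensorProto:
    """Shape/dtype-only description of a tensor (hypothetical fields)."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class ToyTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class FwdArgsSketch:
    """Illustrative subset of a forward-args dataclass; field names are assumed."""
    inp: Any
    weight: Any
    bias: Optional[Any] = None
    out_dtype: str = "bfloat16"


def to_protos(args: FwdArgsSketch) -> FwdArgsSketch:
    """Return a copy of args with every ToyTensor field replaced by a TensorProto."""
    updates = {}
    for f in fields(args):
        value = getattr(args, f.name)
        if isinstance(value, ToyTensor):
            updates[f.name] = TensorProto(value.shape, value.dtype)
    # dataclasses.replace builds a new instance; the original args are untouched.
    return replace(args, **updates)


def fake_forward(args: FwdArgsSketch) -> TensorProto:
    """Proto-in, proto-out shape inference for y = x @ W^T."""
    out_features = args.weight.shape[0]  # weight assumed [out_features, in_features]
    return TensorProto(args.inp.shape[:-1] + (out_features,), args.out_dtype)


def materialize(proto: TensorProto) -> ToyTensor:
    """Turn an output proto back into a (toy) tensor, as a fake-registration hook might."""
    return ToyTensor(proto.shape, proto.dtype)
```

Keeping the fake free of real tensors means the same shape logic can be unit-tested on protos alone; only the thin `to_protos` / `materialize` layer touches tensor objects.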


pggPL and others added 25 commits June 7, 2026 14:42
Build a torch.compile custom-op framework in dynamo.py that traces Linear
forward+backward as single graph nodes (no graph break into the eager
autograd.Function):

- OpaqueSimpleMetadata bundle (with nested-dict support) and per-field
  buckets mapping the LinearFwd/BwdArgs dataclasses to op schema slots;
  quantizers ride through as value-opaque objects (own slot each).
- Fakes stay TensorProto -> TensorProto; the framework converts tensor
  inputs to TensorProto at the boundary (via dataclasses.replace, which
  Dynamo can trace) and materializes output protos into fake tensors in
  register_fake. Saved-tensor aliases resolved by name from ctx_attrs.
- Two-tier op + register_torch_dispatch flattens Float8Tensor weight
  inputs; quantized outputs are rebuilt via _ToSubclassFn (autograd-aware).
- Mirror _linear_forward_impl set_usage / input+weight pipeline in the
  forward fake so register_fake layout matches eager.
- Replace LinearBwdArgs.fp8_recipe with precomputed split-accumulator
  bools (recipe object is not compile-safe across the op boundary).
- Dispatch through the op under torch.compiler.is_compiling(); drop
  @no_torch_dynamo from Linear.forward.

Tests: test_te_linear_compiles (bf16 + every recipe), quantized FP8 weight
input. Backward through a Float8Tensor output is a strict xfail (AOTAutograd
demands a subclass cotangent and the linear backward has no FP8-cotangent
path).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove dead/duplicated code (unused OpaqueSimpleMetadata helpers, the dead
`aliased` slot path, duplicated create_tensor allocation, duplicated outer
kernel/fake closures), always use the two-tier op (the single-tier branch was
unreachable since TE always has quantized subclasses), and tighten the
docstrings. No behavior change; compile tests stay green.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Merge the dynamo/ package split and move the torch.compile custom-op
framework into its own module. This branch owns only the compile wiring;
the quantizer-opaque and TensorProto layers are inherited unchanged from the
lower branches.

  * dynamo/custom_op.py -- _te_register_custom_op and the framework
  * dynamo/{quantizer_opaque,tensor_proto}.py -- inherited from tensor_proto_mechanism
  * dynamo/__init__.py -- re-exports _te_register_custom_op too

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Clean up lint findings introduced by the dynamo/custom_op split and the
Linear torch.compile dispatch: remove unused imports, add missing docstrings,
use targeted pylint disables where intentional, and scope eager autograd
dispatch variables to the non-compile branch.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Covers linear1(fp8_output=True) feeding a Float8Tensor straight into
linear2 under a single fullgraph torch.compile, where the quantized
subclass lives only inside the graph (plain-tensor graph output).
Verifies forward and backward grads.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Differentiating w.r.t. a quantized (FP8) output is unsupported under
torch.compile: AOTAutograd demands a Float8Tensor cotangent that eager
.backward() never supplies, and the linear backward cannot consume an
FP8 grad_output (amax over raw FP8 bytes). Raise a clear
NotImplementedError up front in Linear.forward when fp8_output requires
grad, instead of failing deep in backward.

FP8 output without grad (inference / torch.no_grad) stays supported; fix
the compiled forward to only route subclass outputs through the
autograd-aware rewrap when grad is enabled (the wrap was untraceable
under no_grad). Replace the xfail backward test with a no-grad FP8-output
test and a test asserting the new guard fires; drop the FP8-intermediate
two-linear test (also a differentiable FP8 output, now rejected).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove test_te_linear_compile_fp8_output_requires_grad_raises per
review; the guard in Linear.forward stays. Keep the no-grad FP8-output
coverage and drop the stale cross-reference to the removed test.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…ees)

Parametrize the te.Linear torch.compile tests over the default backend and
`mode="reduce-overhead"`, iterating a few steps so the captured CUDA graph is
actually replayed. Drop the standalone reduce-overhead test in favor of this
shared coverage.

`reduce-overhead` runs its warmup with the cudagraph memory pool active, so TE's
lazily-allocated global scratch (e.g. the lru_cache'd cuBLAS workspace) would
otherwise land in that pool and trip the "live allocation not tracked as output"
check. Add a `_cudagraph_warmup` helper that runs the layer once eagerly before
compiling, forcing those globals into the default allocator.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…h_tensor_proto_mechanism/TE into linear_compile
…ear_compile

# Conflicts:
#	tests/pytorch/test_torch_compile.py
#	transformer_engine/pytorch/dynamo/__init__.py
Rename register_custom_op / _linear_op for clarity and shorten the FP8-output
error path. Make register_custom_op return None instead of raising when the
torch.compile APIs are unavailable, and route te.Linear through a
no_torch_dynamo-wrapped eager path (_linear_eager) when the compiled op is
unavailable or cannot serve a case (differentiable FP8 output), instead of
breaking import or raising.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
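
A minimal sketch of the fallback shape this commit describes: registration returns None rather than raising, and the dispatcher takes the eager path whenever the op is missing or cannot serve the case. All names here (`register_custom_op_sketch`, `linear`, `linear_eager`, the `differentiable_fp8_output` flag) are illustrative, and plain Python lists stand in for tensors.

```python
from typing import Callable, List, Optional

Vector = List[float]
Matrix = List[List[float]]


def _matvec(x: Vector, w: Matrix) -> Vector:
    """y = W x for a weight stored as [out_features][in_features]."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def register_custom_op_sketch(compile_apis_available: bool) -> Optional[Callable]:
    """Return the compiled-path callable, or None instead of raising."""
    if not compile_apis_available:
        return None

    def compiled_linear(x: Vector, w: Matrix):
        return "custom_op", _matvec(x, w)

    return compiled_linear


def linear_eager(x: Vector, w: Matrix):
    """Eager path; always available."""
    return "eager", _matvec(x, w)


def linear(x, w, *, op, compiling, differentiable_fp8_output=False):
    """Use the custom op only when compiling, registered, and able to serve the case."""
    use_op = compiling and op is not None and not differentiable_fp8_output
    return op(x, w) if use_op else linear_eager(x, w)
```

Because a missing op is represented as None, importing the module never fails on older PyTorch builds; the cost is one extra `is not None` check on the dispatch path.
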
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Replace the lazy `_quantized_tensor_storage_cls` / `_quantizer_cls`
resolvers (and their dead `is None` guards) with a direct top-level
import of `QuantizedTensorStorage` and `Quantizer`; no import cycle
exists since `quantized_tensor` never imports `dynamo`. Also add an
eager-vs-compiled numerical check helper in the torch.compile tests.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove dead/redundant machinery and tighten the custom-op glue:

- Drop _ToSubclassFn and the with_autograd reassembly path; differentiable
  subclass outputs already fall back to eager, so plain __tensor_unflatten__
  is sufficient.
- Drop the single-use _prepare_for_saving lazy wrapper.
- Drop the always-False as_list grad-target plumbing.
- Move per-bucket grad-slot knowledge into a polymorphic _Bucket.grad_slot(),
  removing the isinstance switch in _resolve_grad_targets.
- Derive quantized wrapper subclasses by filtering _STORAGE_REGISTRY instead
  of walking __subclasses__ recursively.
- Consolidate redundant deferred ..quantized_tensor imports to top level.
- Drop the needless `if subclass_list:` guard (loops already no-op on empty).
- Rename the real kernel _eager -> _impl for clarity vs the fake kernel.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
….compile configs

Route te.Linear to eager execution under torch.compile for configurations the
compiled custom op / its fake impl does not model, instead of letting them hit
an opaque error deeper in tracing. LinearFwdArgs.compile_unsupported_reason()
centralizes the checks (debug instrumentation, manual TE FSDP, differentiable
fp8_output, CPU offloading, delayed wgrad compute, fuse_wgrad_accumulation, and
quantizers not registered as value-opaque types such as delayed scaling), and a
shared utils.warn_compile_eager_fallback() surfaces the reason.

Value-opaque quantizer classes are stamped with a class attribute at
registration so the check reads it off the instance (dynamo-traceable), avoiding
the graph break that type()-based introspection of opaque objects causes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
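
The centralized check and the class-attribute stamp from this commit might look roughly like the following. It is a sketch, not the PR's code: the attribute name `_te_value_opaque`, the `FwdArgsChecks` fields, the reason strings, and both quantizer classes are assumptions, and only a subset of the listed conditions is modelled.

```python
import warnings
from dataclasses import dataclass
from typing import Any, Optional


def register_value_opaque(cls):
    """Stamp the class at registration so instances can be checked by attribute."""
    cls._te_value_opaque = True  # hypothetical attribute name
    return cls


@register_value_opaque
class CurrentScalingQuantizerSketch:
    """Toy quantizer that is registered as value-opaque."""


class DelayedScalingQuantizerSketch:
    """Toy quantizer that is deliberately left unregistered."""


@dataclass
class FwdArgsChecks:
    """Illustrative subset of the flags the real dataclass would carry."""
    input_quantizer: Any = None
    debug: bool = False
    cpu_offloading: bool = False
    fuse_wgrad_accumulation: bool = False
    differentiable_fp8_output: bool = False

    def compile_unsupported_reason(self) -> Optional[str]:
        """Return why the compiled op cannot serve this call, or None if it can."""
        if self.debug:
            return "debug instrumentation"
        if self.cpu_offloading:
            return "CPU offloading"
        if self.fuse_wgrad_accumulation:
            return "fuse_wgrad_accumulation"
        if self.differentiable_fp8_output:
            return "differentiable fp8_output"
        q = self.input_quantizer
        # Read the stamp off the instance; no type()-based introspection.
        if q is not None and not getattr(q, "_te_value_opaque", False):
            return "quantizer is not registered as a value-opaque type"
        return None


def warn_compile_eager_fallback(reason: str) -> None:
    """Surface the fallback reason to the user."""
    warnings.warn(f"te.Linear falls back to eager under torch.compile: {reason}")
```

Returning a reason string rather than a bool lets the caller both decide the path and report why, with one pass over the conditions.
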
…ix FP8 weight caching

- run_numerics: add torch.compile coverage (default + reduce-overhead) for
  tensor/sequence-parallel te.Linear.
- comm_gemm_overlap: add Userbuffers te.Linear torch.compile tests (--compile,
  --compile-mode; default + reduce-overhead).
- test_torch_compile: add is_first_microbatch FP8 weight-caching test and a
  reduce-overhead guard asserting no cudagraph skips (eager fallback).
- linear.py: type LinearFwdArgs.weight_workspace as the tensor/quantized union
  so a cached FP8 weight is flattened across the custom-op boundary, fixing
  is_first_microbatch cache reuse under torch.compile.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…ake buffer layout

Add a reference-opaque custom-op bucket so a torch.distributed ProcessGroup
(tp_group) can cross the torch.compile boundary. Unlike value-opaque quantizers
(baked into the graph as constants), a process group is live state and is carried
through its own nullable schema slot as a graph input. tp_group is re-annotated
Optional[dist_group_type] so the bucket can recognize it, and ProcessGroup is
registered as a reference opaque type on import (PyTorch only auto-registers it
when DTensor is imported).

Also align the torch.compile fake/proto buffer layout with the real op:
emit flat buffers in canonical __tensor_flatten__ order, and mirror the C++
non-TN FP8 GEMM allocation (single _data buffer on Blackwell+, no separate
transpose) so the fake layout matches the real kernel slot-for-slot.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
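
The layout alignment in this commit comes down to emitting buffers in one fixed order and omitting any the real kernel would not allocate. A small sketch of that idea follows; `FLATTEN_ORDER` is an assumed order (the real one comes from the storage class's `__tensor_flatten__`), and `describe_buffers` / `flat_buffers` are hypothetical helpers that track shapes only.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]

# Assumed canonical order for an FP8 storage; illustrative only.
FLATTEN_ORDER = ["_data", "_scale_inv", "_transpose"]


def describe_buffers(shape: Tuple[int, int], *, separate_transpose: bool) -> Dict[str, Optional[Shape]]:
    """Describe the buffers a toy FP8 storage would hold, in arbitrary insertion order."""
    rows, cols = shape
    return {
        "_transpose": (cols, rows) if separate_transpose else None,
        "_scale_inv": (1,),
        "_data": (rows, cols),
    }


def flat_buffers(buffers: Dict[str, Optional[Shape]], order: List[str] = FLATTEN_ORDER) -> List[Tuple[str, Shape]]:
    """Emit (name, shape) pairs in canonical order, skipping buffers that are absent."""
    return [(name, buffers[name]) for name in order if buffers.get(name) is not None]
```

Sorting by the canonical list rather than by dict insertion order is what keeps the fake output slot-for-slot aligned with the real op, whichever order the description was built in.
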
… shapes in distributed compile

TensorProto.inner_names now follows the storage's __tensor_flatten__ order
(the order the real op flattens outputs to), so the proto unit tests assert
against that canonical layout rather than _describe_buffers insertion order
(which interleaves NVFP4 amax buffers differently).

Force dynamic=False for torch.compile in the distributed numerics/overlap
tests: a symbolic shape would land in a value-opaque OpaqueValueBundle op arg
whose hash chokes on non-nested SymInt. Use static shapes (recompile per
shape) until the bundle handles symbolic shapes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Each parametrized _test_linear case compiles the same module.forward code
object with a different shape/recipe. With dynamic=False the guards accumulate
across cases and eventually trip Dynamo's recompile_limit (hard failure under
fullgraph=True, surfacing first in reduce-overhead mode). Reset the compile
cache before each case so it starts clean, mirroring the single-GPU tests.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…stants; dynamic shapes support

Wrap CommOverlapCore pybind11 methods that return compile-time constants so
torch.compile(fullgraph=True) can trace through them without graph breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py
and userbuffers_backward_linear.py updated.

Dynamic-shapes support for te.Linear under torch.compile(dynamic=True):
- _linear_setup_ctx: don't store inp_shape in OpaqueValueBundle (SymInt is not
  hashable); backward reconstructs it from grad_output + weight + SP config.
- _linear_backward_impl_fake: derive dgrad shape the same way.
- Linear.forward: replace bare asserts with torch._check() so shape constraints
  are visible to the compiler as guards rather than hard errors.
- Add test_te_linear_dynamic_shapes covering varying batch sizes with
  recompilation detection.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
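
Reconstructing the input shape in backward, instead of storing it in the hashed bundle, can be sketched as plain shape arithmetic. The rules below are assumptions made for illustration: the weight is taken as [out_features, in_features], column-parallel sequence parallelism is assumed to all-gather the input along dim 0 in forward, and row-parallel sequence parallelism is assumed to reduce-scatter the output along dim 0. The function name and parameters are hypothetical.

```python
from typing import Optional, Tuple


def reconstruct_inp_shape(
    grad_output_shape: Tuple[int, ...],
    weight_shape: Tuple[int, int],
    *,
    sequence_parallel: bool = False,
    parallel_mode: Optional[str] = None,
    tp_size: int = 1,
) -> Tuple[int, ...]:
    """Derive the local input shape from grad_output, weight, and the SP config."""
    in_features = weight_shape[1]  # weight assumed [out_features, in_features]
    leading = list(grad_output_shape[:-1])
    if sequence_parallel and tp_size > 1:
        if parallel_mode == "column":
            # Input was gathered along dim 0, so the local input is tp_size times smaller.
            leading[0] //= tp_size
        elif parallel_mode == "row":
            # Output was scattered along dim 0, so the local input is tp_size times larger.
            leading[0] *= tp_size
    return (*leading, in_features)
```

Since every term is derived from tensors that are already op inputs, the same arithmetic works when the leading dimension is symbolic, and nothing shape-dependent needs to be hashed.
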
In column-parallel + sequence-parallel backward, grad_output is quantized
to Float8TensorStorage (_data, _scale_inv, _transpose) but never explicitly
freed before the backward function returns. Under torch.compile reduce-overhead,
these 3 live pool tensors at recording end trigger
"Detected 3 tensor(s) in the cudagraph pool not tracked as outputs".

Row-parallel SP already calls clear_tensor_data(grad_output) to free the
gathered tensor early. Extend it to cover column-SP where grad_output is
the quantized (non-gathered) internal tensor.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>