[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
Conversation
Greptile SummaryThis PR fixes several DCP checkpoint loading issues for FSDP2 with quantized model initialization, covering MXFP8, NVFP4, and Float8 tensor types. It also corrects NVFP4 allgather numerical errors under FSDP2 by properly setting the
Confidence Score: 4/5The changes are targeted and well-reasoned; the DCP async/sync checkpoint flows for all quantization recipes should now work correctly with weights_only=True. The backward-compat classmethods (_make_in_reduce_ex) are documented as supporting re-loading of old pickle streams, but they silently break under weights_only=True because getattr is not in add_safe_globals. Any operator who saved checkpoints with a previous build and tries to load them through DCP will hit an opaque WeightsUnpickler error. This is a narrow edge case for pre-existing checkpoints, not the new format, but the misleading comment could lead to wasted debugging time. transformer_engine/pytorch/tensor/float8_tensor.py, mxfp8_tensor.py, nvfp4_tensor.py, float8_blockwise_tensor.py — the backward-compat _make_in_reduce_ex docstrings should be updated to reflect the weights_only=True limitation. Important Files Changed
Sequence DiagramsequenceDiagram
participant FSDP2
participant QT as QuantizedTensor
participant F8 as Float8Tensor
participant DCP
Note over FSDP2,DCP: DCP Async Save Flow
FSDP2->>QT: "aten._to_copy(device=cpu)"
QT->>QT: __torch_dispatch__(_to_copy)
Note right of QT: dtype unchanged, inner branch taken
QT->>QT: get_metadata() move tensors to CPU
QT-->>FSDP2: CPU QuantizedTensor subclass preserved
FSDP2->>DCP: stage CPU tensor to disk
DCP->>QT: __reduce_ex__(protocol)
QT-->>DCP: _make_in_reduce_ex with inner_buffers
Note right of DCP: quantizer.__getstate__ strips amax_reduction_group
DCP->>DCP: "torch.save weights_only=True compatible"
Note over FSDP2,DCP: DCP Sync Load Flow
DCP->>DCP: "torch.load weights_only=True"
DCP->>QT: _make_star_in_reduce_ex buffers
Note right of DCP: module-level fn single GLOBAL opcode
DCP-->>FSDP2: reconstructed CPU QuantizedTensor
FSDP2->>F8: copy_ loaded_cpu_tensor
F8->>F8: direct FP8 buffer copy dtype match
F8-->>FSDP2: model param updated
Note over FSDP2: NVFP4 allgather fix
FSDP2->>QT: fsdp_pre_all_gather
QT->>QT: "NVFP4Quantizer.amax_reduction_group = shard_group"
Note right of QT: PR fix was previously missing for NVFP4
Reviews (11): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile |
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
Empty storage breaks shared-storage detection in existing callers
QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
There was a problem hiding this comment.
Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.
There was a problem hiding this comment.
Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors
There was a problem hiding this comment.
The behavior is unchanged with the change. And I would argue the implementation now is more correct with the change. untyped_storage() default implementation from QuantizedTensor(torch.Tensor) before this change, gives a storage with two properties.
-
storage.nbytes() returns bytes based on the fake_dtype that we use to register our QuantizedTensor as a torchTensor using make_wrapper_subclass method of torch.
-
storage.data_ptr() gives an error saying it is an invalid storage and there is no data_ptr()
Both of them is not ideal.
The first one is grossly incrorrect due to two reasons. First we manage the backing storage for the inner tensors of QuantizedTensor and torch has no idea about it. Second nbytes based on fake_dtype is misleading since that might not actually be the number of bytes we actually allocate.
Second one is causing problems with FSDP2 now since it expects some storage for identity check.
For QuantizedTensor, noop_cat today always returns an actual torch.cat which goes through a dequantization luckily due to this condition being true. This condition is going to be true now with the change as well since nbytes() would return 0.
If we do QuantizedTensor.data_ptr() today it gives you 0. QuantizedTensor.untyped_storage().data_ptr() will give invalid storage error which is inconsistent. And giving empty storage as empty storage will fix this inconsitency.
As far as idenity checking goes, FSDP2 does all the comparisong logic only if data_ptr() is not 0. And it also doesnt really make sense to compare two empty storages.
|
/te-ci L1 pytorch |
| msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", | ||
| ) | ||
| elif recipe_name == "NVFP4BlockScaling": | ||
| # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances |
There was a problem hiding this comment.
Why do we need dequant + quant here?
There was a problem hiding this comment.
We are doing it anymore
| # torch DCP staging via ``x.new_empty(..., device="cpu")``), we | ||
| # save the high-precision values in a plain CPU dense tensor. | ||
| # For the DCP load path, we will re-quantize the high-precision values. | ||
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() |
There was a problem hiding this comment.
An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).
>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() | |
| target_size = size |
There was a problem hiding this comment.
Changed the torch dispatch function now. So we dont have size here
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
3589ffa to
4197bee
Compare
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci L1 pytorch |
|
/te-ci L1 pytorch |
|
/te-ci L1 pytorch |
Description
Fixes DCP Sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the fsdp group as the amax reduction group in the quantizer
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Untyped_storage implementation needed for FSDP2 + DCP
DCP Aync/Sync Checkpoint loading
NVFP4 Allgather Correctness issues
TE_DType Serialization issues with DCP Checkpointing
Checklist: