Fix EMAModel.restore() foreach path crashing with device mismatch when model is on GPU#13782
Fix EMAModel.restore() foreach path crashing with device mismatch when model is on GPU#13782Dev-X25874 wants to merge 6 commits into
EMAModel.restore() foreach path crashing with device mismatch when model is on GPU#13782Conversation
…ice mismatch on GPU
…n-foreach EMAModel
|
How can I minimally reproduce the bug? |
|
Hi @sayakpaul, here's a minimal repro (requires a CUDA GPU): import torch
import torch.nn as nn
from diffusers.training_utils import EMAModel
model = nn.Linear(4, 4).cuda()
ema = EMAModel(model.parameters(), foreach=True)
# Simulate a training step so shadow params differ from model params
with torch.no_grad():
for p in model.parameters():
p.add_(torch.randn_like(p))
ema.step(model.parameters())
# Standard EMA validation pattern
ema.store(model.parameters()) # saves to CPU
ema.copy_to(model.parameters()) # works fine
ema.restore(model.parameters()) # RuntimeError: Expected all tensors to be on same deviceThe crash happens because |
| assert torch.allclose(restored.data, original.to(restored.device), atol=1e-6), ( | ||
| "restore() foreach path did not correctly recover the stored parameters" | ||
| ) | ||
| def test_store_restore(self): |
There was a problem hiding this comment.
My bad, the foreach test_store_restore was mistakenly placed inside EMAModelTests instead of EMAModelTestsForeach. Fixed now.
|
How come this works? import torch
gpu = [torch.zeros(3, device="cuda")]
cpu = [torch.arange(3, dtype=torch.float32)]
torch._foreach_copy_(gpu, cpu)
print(gpu[0])Prints: tensor([0., 1., 2.], device='cuda:0') |
|
Sorry for the noise, converting this to draft while I investigate further. |
|
Verified on PyTorch 2.10.0+cu128 — no crash. You're right, |
What does this PR do?
Fixes a runtime crash in
EMAModel.restore()whenforeach=Trueand the model lives on a non-CPU device (e.g. CUDA).store()always saves parameters to CPU (param.detach().cpu().clone()). Theforeachpath inrestore()then passed those raw CPU tensors directly totorch._foreach_copy_(), which requires all tensors to be on the same device:This raises
RuntimeError: Expected all tensors to be on same devicefor any user who calls the standard EMA validation pattern (store → copy_to → restore) withforeach=Trueon a GPU machine.The fix mirrors the pattern already used correctly in
copy_to()'s foreach path (line 780), which moves each shadow param to the target device before the copy:Also adds
test_store_restoreto bothEMAModelTestsandEMAModelTestsForeach— the store/restore round-trip was completely untested prior to this PR.Before submitting
Who can review?
@sayakpaul