Fix EMAModel.restore() foreach path crashing with device mismatch when model is on GPU #13782

Closed
Dev-X25874 wants to merge 6 commits into huggingface:main from Dev-X25874:fix/ema-restore-foreach-device-mismatch

Conversation

@Dev-X25874
Contributor

What does this PR do?

Fixes a runtime crash in EMAModel.restore() when foreach=True and the model lives on a non-CPU device (e.g. CUDA).

store() always saves parameters to CPU (param.detach().cpu().clone()). The foreach path in restore() then passed those raw CPU tensors directly to torch._foreach_copy_(), which requires all tensors to be on the same device:

# before (broken on GPU)
torch._foreach_copy_(
    [param.data for param in parameters],
    [c_param.data for c_param in self.temp_stored_params],  # always CPU
)

This raises RuntimeError: Expected all tensors to be on same device for any user who calls the standard EMA validation pattern (store → copy_to → restore) with foreach=True on a GPU machine.

The fix mirrors the pattern already used correctly in copy_to()'s foreach path (line 780), which moves each shadow param to the target device before the copy:

# after (matches copy_to() pattern)
torch._foreach_copy_(
    [param.data for param in parameters],
    [c_param.to(param.device).data for c_param, param in zip(self.temp_stored_params, parameters)],
)

Also adds test_store_restore to both EMAModelTests and EMAModelTestsForeach, since the store/restore round-trip was completely untested prior to this PR.

Before submitting

Who can review?

@sayakpaul

github-actions bot added the tests and size/S (PR with diff < 50 LOC) labels on May 21, 2026
@sayakpaul
Member

How can I minimally reproduce the bug?

@Dev-X25874
Contributor Author

Hi @sayakpaul, here's a minimal repro (requires a CUDA GPU):

import torch
import torch.nn as nn
from diffusers.training_utils import EMAModel

model = nn.Linear(4, 4).cuda()
ema = EMAModel(model.parameters(), foreach=True)

# Simulate a training step so shadow params differ from model params
with torch.no_grad():
    for p in model.parameters():
        p.add_(torch.randn_like(p))
ema.step(model.parameters())

# Standard EMA validation pattern
ema.store(model.parameters())     # saves to CPU
ema.copy_to(model.parameters())  # works fine
ema.restore(model.parameters())  # RuntimeError: Expected all tensors to be on same device

The crash happens because store() always clones params to CPU, but the foreach path in restore() feeds those raw CPU tensors into torch._foreach_copy_() alongside the GPU model params. The non-foreach path is unaffected since copy_() handles cross-device copies natively.
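
The per-tensor behaviour mentioned for the non-foreach path can be checked in isolation. This is a small sketch with stand-in tensors rather than a real model; the variable names are illustrative, and it falls back to CPU (where the copy is trivially same-device) if no GPU is available.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stands in for a model parameter living on the training device.
param = torch.zeros(3, device=device)
# Stands in for a snapshot that was saved to CPU.
stored = torch.arange(3, dtype=torch.float32)

# Tensor.copy_ accepts a source on a different device and transfers the data.
param.copy_(stored)
```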

Comment thread tests/others/test_ema.py

        assert torch.allclose(restored.data, original.to(restored.device), atol=1e-6), (
            "restore() foreach path did not correctly recover the stored parameters"
        )

    def test_store_restore(self):
Member

This seems duplicated.

Contributor Author

My bad, the foreach test_store_restore was mistakenly placed inside EMAModelTests instead of EMAModelTestsForeach. Fixed now.

github-actions bot added the size/M (PR with diff < 200 LOC) label and removed the size/S (PR with diff < 50 LOC) label on May 22, 2026
github-actions bot added the size/S (PR with diff < 50 LOC) label and removed the size/M (PR with diff < 200 LOC) label on May 22, 2026
@sayakpaul
Member

sayakpaul commented May 22, 2026

How come this works?

import torch

gpu = [torch.zeros(3, device="cuda")]
cpu = [torch.arange(3, dtype=torch.float32)]
torch._foreach_copy_(gpu, cpu)
print(gpu[0])

Prints:

tensor([0., 1., 2.], device='cuda:0')

@Dev-X25874
Contributor Author

Sorry for the noise, converting this to draft while I investigate further.

Dev-X25874 marked this pull request as draft on May 22, 2026, 09:58
@Dev-X25874
Contributor Author

Verified on PyTorch 2.10.0+cu128: no crash. You're right, torch._foreach_copy_() handles cross-device copies fine on modern PyTorch. Closing this PR. Sorry for the noise.
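
For anyone who wants to repeat the verification on their own install, a probe along these lines reports the PyTorch version together with the outcome of a cross-device `torch._foreach_copy_`. It is a hedged sketch: `foreach_copy_crosses_devices` is a hypothetical helper name, and it returns None when no CUDA device is available, since the question cannot be answered on a CPU-only machine.

```python
import torch


def foreach_copy_crosses_devices():
    """Return True or False for the probe, or None if no CUDA device is available."""
    if not torch.cuda.is_available():
        return None
    dst = [torch.zeros(3, device="cuda")]
    src = [torch.arange(3, dtype=torch.float32)]  # lives on CPU
    try:
        torch._foreach_copy_(dst, src)
    except RuntimeError:
        # The installed build rejected the mixed-device call.
        return False
    # The call succeeded; confirm the values actually arrived.
    return torch.equal(dst[0].cpu(), src[0])


print(torch.__version__, foreach_copy_crosses_devices())
```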

Dev-X25874 closed this on May 22, 2026
Dev-X25874 deleted the fix/ema-restore-foreach-device-mismatch branch on May 23, 2026, 09:04
Labels

size/S PR with diff < 50 LOC tests

2 participants