Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,12 +850,12 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
[param.data for param in parameters],
[c_param.to(param.device).data for c_param, param in zip(self.temp_stored_params, parameters)],
)
else:
for c_param, param in zip(self.temp_stored_params, parameters):
param.data.copy_(c_param.data)

# Better memory-wise.
self.temp_stored_params = None

Expand Down
40 changes: 40 additions & 0 deletions tests/others/test_ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,29 @@ def test_serialization(self):

assert torch.allclose(output, output_loaded, atol=1e-4)

def test_store_restore(self):
    """Check that store() -> copy_to() -> restore() round-trips the model weights.

    restore() must leave the model holding exactly the values that were
    passed to store(). The non-foreach restore path copies with
    ``param.data.copy_()``; this test guards that the round trip recovers
    the stored weights on whatever device the model lives on.
    """
    unet, ema_unet = self.get_models()

    # Simulate one EMA step so shadow params differ from model params.
    unet = self.simulate_backprop(unet)
    ema_unet.step(unet.parameters())

    # Take the reference snapshot immediately before store(). These are the
    # values store() saves and restore() must bring back. Snapshotting before
    # simulate_backprop() would compare against weights the model no longer
    # held when store() was called.
    expected_params = [p.data.clone() for p in unet.parameters()]

    # Standard EMA validation pattern: store -> copy_to -> restore.
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())
    ema_unet.restore(unet.parameters())

    # After restore(), model weights must equal the pre-copy_to values.
    for restored, expected in zip(unet.parameters(), expected_params):
        assert torch.allclose(restored.data, expected.to(restored.device), atol=1e-6), (
            "restore() did not correctly recover the stored parameters"
        )


class EMAModelTestsForeach(unittest.TestCase):
model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
Expand Down Expand Up @@ -333,3 +356,20 @@ def test_serialization(self):
output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

assert torch.allclose(output, output_loaded, atol=1e-4)

def test_store_restore(self):
    """Check the foreach restore() path round-trips the stored weights.

    The foreach branch of restore() hands the stored tensors to
    ``torch._foreach_copy_()``. The accompanying change moves each stored
    tensor to ``param.device`` before that call; this test exercises the
    store -> copy_to -> restore sequence and verifies the model ends up with
    the values it held when store() was called.
    """
    unet, ema_unet = self.get_models()

    # Simulate one EMA step so shadow params differ from model params.
    unet = self.simulate_backprop(unet)
    ema_unet.step(unet.parameters())

    # Take the reference snapshot immediately before store(). These are the
    # values restore() must bring back. A snapshot taken before
    # simulate_backprop() would not match what store() actually saved.
    expected_params = [p.data.clone() for p in unet.parameters()]

    # parameters() is called afresh for each step, so every call below
    # receives its own iterator.
    ema_unet.store(unet.parameters())
    ema_unet.copy_to(unet.parameters())
    ema_unet.restore(unet.parameters())

    for restored, expected in zip(unet.parameters(), expected_params):
        assert torch.allclose(restored.data, expected.to(restored.device), atol=1e-6), (
            "restore() foreach path did not correctly recover the stored parameters"
        )
Loading