diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 44773100995e..e10f101ad63e 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -850,12 +850,15 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
+        # `parameters` may be a one-shot iterator (e.g. `model.parameters()`). The foreach
+        # branch below walks it twice, so materialize it once before either branch runs.
+        parameters = list(parameters)
         if self.foreach:
             torch._foreach_copy_(
-                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
+                [param.data for param in parameters],
+                [c_param.to(param.device).data for c_param, param in zip(self.temp_stored_params, parameters)],
             )
         else:
             for c_param, param in zip(self.temp_stored_params, parameters):
                 param.data.copy_(c_param.data)
-
         # Better memory-wise.
         self.temp_stored_params = None
 
diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py
index 436bbe1d53ff..31c95d1d9bb9 100644
--- a/tests/others/test_ema.py
+++ b/tests/others/test_ema.py
@@ -179,6 +179,29 @@ def test_serialization(self):
 
         assert torch.allclose(output, output_loaded, atol=1e-4)
 
+    def test_store_restore(self):
+        # store() saves params to CPU; restore() must move them back to the model's
+        # device before copying. The non-foreach path uses copy_() which is
+        # cross-device safe; this test guards that restore() actually round-trips
+        # the original weights regardless of device.
+        unet, ema_unet = self.get_models()
+        # Simulate one EMA step so shadow params differ from model params.
+        unet = self.simulate_backprop(unet)
+        ema_unet.step(unet.parameters())
+        # Snapshot right before store() so the reference is exactly what store() saves.
+        original_params = [p.data.clone() for p in unet.parameters()]
+
+        # Standard EMA validation pattern: store → copy_to → restore.
+        ema_unet.store(unet.parameters())
+        ema_unet.copy_to(unet.parameters())
+        ema_unet.restore(unet.parameters())
+
+        # After restore(), model weights must equal the pre-copy_to values.
+        for restored, original in zip(unet.parameters(), original_params):
+            assert torch.allclose(restored.data, original.to(restored.device), atol=1e-6), (
+                "restore() did not correctly recover the stored parameters"
+            )
+
 
 class EMAModelTestsForeach(unittest.TestCase):
     model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
@@ -333,3 +356,21 @@ def test_serialization(self):
         output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
         assert torch.allclose(output, output_loaded, atol=1e-4)
+
+    def test_store_restore(self):
+        # The foreach restore() path passed raw CPU tensors from store() to
+        # torch._foreach_copy_(), which requires same-device tensors and therefore
+        # crashes on GPU. Fix: mirror copy_to()'s pattern of moving each stored
+        # tensor to param.device before the foreach copy.
+        unet, ema_unet = self.get_models()
+        unet = self.simulate_backprop(unet)
+        ema_unet.step(unet.parameters())
+        # Snapshot right before store() so the reference is exactly what store() saves.
+        original_params = [p.data.clone() for p in unet.parameters()]
+        ema_unet.store(unet.parameters())
+        ema_unet.copy_to(unet.parameters())
+        ema_unet.restore(unet.parameters())
+        for restored, original in zip(unet.parameters(), original_params):
+            assert torch.allclose(restored.data, original.to(restored.device), atol=1e-6), (
+                "restore() foreach path did not correctly recover the stored parameters"
+            )