Skip to content
123 changes: 61 additions & 62 deletions tests/models/autoencoders/test_models_autoencoder_kl_kvae_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLKVAEVideo
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def _run_nondeterministic(fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for tests that do backward passes.
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)

def get_autoencoder_kl_kvae_video_config(self):

class AutoencoderKLKVAEVideoTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLKVAEVideo

@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 3, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"ch": 32,
"ch_mult": (1, 2),
Expand All @@ -41,78 +65,53 @@ def get_autoencoder_kl_kvae_video_config(self):
"temporal_compress_times": 2,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)

video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

video = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": video}

@property
def input_shape(self):
return (3, 3, 16, 16)

@property
def output_shape(self):
return (3, 3, 16, 16)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass

@unittest.skip(
class TestAutoencoderKLKVAEVideo(AutoencoderKLKVAEVideoTesterConfig, ModelTesterMixin):
@pytest.mark.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
super().test_model_parallelism()

@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass

def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
class TestAutoencoderKLKVAEVideoTraining(AutoencoderKLKVAEVideoTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLKVAEVideo."""

torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"KVAECachedEncoder3D", "KVAECachedDecoder3D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_training(self):
self._run_nondeterministic(super().test_training)
_run_nondeterministic(super().test_training)

def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
def test_training_with_ema(self):
_run_nondeterministic(super().test_training_with_ema)

@unittest.skip(
@pytest.mark.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
"causing a different execution path and numerically different gradients."
)
def test_effective_gradient_checkpointing(self):
pass
def test_gradient_checkpointing_equivalence(self):
super().test_gradient_checkpointing_equivalence()

def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)
_run_nondeterministic(super().test_layerwise_casting_training)


class TestAutoencoderKLKVAEVideoMemory(AutoencoderKLKVAEVideoTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLKVAEVideo."""


class TestAutoencoderKLKVAEVideoSlicingTiling(AutoencoderKLKVAEVideoTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLKVAEVideo."""
Loading
Loading