diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py index 9056590edffe..10ceeae61bad 100644 --- a/tests/models/transformers/test_models_transformer_bria.py +++ b/tests/models/transformers/test_models_transformer_bria.py @@ -13,23 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from typing import Any import torch from diffusers import BriaTransformer2DModel from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 from diffusers.models.embeddings import ImageProjection +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + IPAdapterTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -def create_bria_ip_adapter_state_dict(model): - # "ip_adapter" (cross-attention weights) +def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]: ip_cross_attn_state_dict = {} key_id = 0 @@ -50,11 +57,8 @@ def create_bria_ip_adapter_state_dict(model): f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], } ) - key_id += 1 - # "image_proj" (ImageProjection layer weights) - image_projection = ImageProjection( cross_attention_dim=model.config["joint_attention_dim"], image_embed_dim=model.config["pooled_projection_dim"], @@ -73,53 +77,36 @@ def create_bria_ip_adapter_state_dict(model): ) del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict - + return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict} -class BriaTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = BriaTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.8, 0.7, 0.7] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True +class BriaTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - height = width = 4 - sequence_length = 48 - embedding_dim = 32 + def model_class(self): + return BriaTransformer2DModel - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + @property + def main_input_name(self) -> str: + return "hidden_states" - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "timestep": timestep, - } + @property + def model_split_percents(self) -> list: + return [0.8, 0.7, 0.7] @property - def input_shape(self): + def output_shape(self) -> tuple: return (16, 4) @property - def output_shape(self): + def input_shape(self) -> tuple: return (16, 4) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -131,11 +118,35 @@ def prepare_init_args_and_inputs_for_common(self): "axes_dims_rope": [0, 4, 4], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "img_ids": randn_tensor( + (height * width, num_image_channels), generator=self.generator, device=torch_device + ), + "txt_ids": randn_tensor( + (sequence_length, num_image_channels), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + +class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin): def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) model.to(torch_device) model.eval() @@ -143,7 +154,6 @@ def test_deprecated_inputs_img_txt_ids_3d(self): with torch.no_grad(): output_1 = model(**inputs_dict).to_tuple()[0] - # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated) text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0) image_ids_3d = inputs_dict["img_ids"].unsqueeze(0) @@ -156,26 +166,63 @@ def test_deprecated_inputs_img_txt_ids_3d(self): with torch.no_grad(): output_2 = model(**inputs_dict).to_tuple()[0] - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + assert output_1.shape == output_2.shape + assert torch.allclose(output_1, output_2, atol=1e-5), ( + "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) " + "are not equal as them as 2d inputs" ) + +class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"BriaTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) -class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = BriaTransformer2DModel +class TestBriaTransformerIPAdapter(BriaTransformerTesterConfig, IPAdapterTesterMixin): + @property + def ip_adapter_processor_cls(self): + return FluxIPAdapterJointAttnProcessor2_0 + + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + torch.manual_seed(0) + cross_attention_dim = getattr(model.config, "joint_attention_dim", 32) + image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device) + inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}}) + return inputs_dict + + def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: + return create_bria_ip_adapter_state_dict(model) + - def prepare_init_args_and_inputs_for_common(self): - return BriaTransformerTests().prepare_init_args_and_inputs_for_common() +class TestBriaTransformerLoRA(BriaTransformerTesterConfig, LoraTesterMixin): + pass -class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = BriaTransformer2DModel +class TestBriaTransformerLoRAHotSwap(BriaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 32 - def prepare_init_args_and_inputs_for_common(self): - return BriaTransformerTests().prepare_init_args_and_inputs_for_common() + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "img_ids": randn_tensor( + (height * width, num_image_channels), generator=self.generator, device=torch_device + ), + "txt_ids": randn_tensor( + (sequence_length, num_image_channels), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py index f859f4608bd5..0b220f1695b6 100644 --- a/tests/models/transformers/test_models_transformer_bria_fibo.py +++ b/tests/models/transformers/test_models_transformer_bria_fibo.py @@ -13,62 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import BriaFiboTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + ModelTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = BriaFiboTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.8, 0.7, 0.7] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True +class BriaFiboTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return BriaFiboTransformer2DModel @property - def dummy_input(self): - batch_size = 1 - num_latent_channels = 48 - num_image_channels = 3 - height = width = 16 - sequence_length = 32 - embedding_dim = 64 + def main_input_name(self) -> str: + return "hidden_states" - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + @property + def model_split_percents(self) -> list: + return [0.8, 0.7, 0.7] - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "timestep": timestep, - "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]], - } + @property + def output_shape(self) -> tuple: + return (256, 48) @property - def input_shape(self): + def input_shape(self) -> tuple: return (16, 16) @property - def output_shape(self): - return (256, 48) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "patch_size": 1, "in_channels": 48, "num_layers": 1, @@ -81,9 +68,37 @@ def prepare_init_args_and_inputs_for_common(self): "axes_dims_rope": [0, 4, 4], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + num_latent_channels = 48 + num_image_channels = 3 + height = width = 16 + sequence_length = 32 + embedding_dim = 64 + + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ), + "encoder_hidden_states": encoder_hidden_states, + "img_ids": randn_tensor( + (height * width, num_image_channels), generator=self.generator, device=torch_device + ), + "txt_ids": randn_tensor( + (sequence_length, num_image_channels), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]], + } + + +class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin): + pass + +class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"BriaFiboTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set)