Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 128 additions & 102 deletions tests/models/transformers/test_models_transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import SD3Transformer2DModel
from diffusers.utils.import_utils import is_xformers_available

from ...testing_utils import (
enable_full_determinism,
torch_device,
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
Comment on lines +26 to +27
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we ever had these in the legacy tests. If that's the case, let's not include them.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this points to a gap in the previous testing. Let's keep the coverage since SD3.5 is reasonably popular.

TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
# ======================== SD3 Transformer ========================


class SD3TransformerTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
def model_class(self):
return SD3Transformer2DModel

@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-sd3-pipe"

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
@property
def main_input_name(self) -> str:
return "hidden_states"

@property
def model_split_percents(self) -> list:
return [0.8, 0.8, 0.9]

@property
def input_shape(self):
def output_shape(self) -> tuple:
return (4, 32, 32)

@property
def output_shape(self):
def input_shape(self) -> tuple:
return (4, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
Expand All @@ -79,67 +84,79 @@ def prepare_init_args_and_inputs_for_common(self):
"dual_attention_layers": (),
"qk_norm": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154

return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}

model.enable_xformers_memory_efficient_attention()

assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
"xformers is not enabled"
)
class TestSD3Transformer(SD3TransformerTesterConfig, ModelTesterMixin):
    """Model tests for the SD3 transformer.

    Defines nothing of its own: everything is inherited from
    ``SD3TransformerTesterConfig`` and ``ModelTesterMixin``.
    """

@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass

class TestSD3TransformerTraining(SD3TransformerTesterConfig, TrainingTesterMixin):
    """Training tests for the SD3 transformer tester config."""

    def test_gradient_checkpointing_is_applied(self):
        """Run the inherited check with ``SD3Transformer2DModel`` as the expected set."""
        super().test_gradient_checkpointing_is_applied(expected_set={"SD3Transformer2DModel"})


class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
model_split_percents = [0.8, 0.8, 0.9]
class TestSD3TransformerCompile(SD3TransformerTesterConfig, TorchCompileTesterMixin):
    """Compile tests for the SD3 transformer.

    Defines nothing of its own: everything is inherited from
    ``SD3TransformerTesterConfig`` and ``TorchCompileTesterMixin``.
    """


# ======================== SD3.5 Transformer ========================


class SD35TransformerTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
def model_class(self):
return SD3Transformer2DModel

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-sd35-pipe"

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}

@property
def main_input_name(self) -> str:
return "hidden_states"

@property
def model_split_percents(self) -> list:
return [0.8, 0.8, 0.9]

@property
def input_shape(self):
def output_shape(self) -> tuple:
return (4, 32, 32)

@property
def output_shape(self):
def input_shape(self) -> tuple:
return (4, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self) -> dict:
return {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
Expand All @@ -154,47 +171,56 @@ def prepare_init_args_and_inputs_for_common(self):
"dual_attention_layers": (0,),
"qk_norm": "rms_norm",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

model.enable_xformers_memory_efficient_attention()

assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
"xformers is not enabled"
)
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154

@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

class TestSD35Transformer(SD35TransformerTesterConfig, ModelTesterMixin):
    """Model tests for the SD3.5 transformer tester config, plus a ``skip_layers`` check."""

    def test_skip_layers(self):
        """Check that ``skip_layers=[0]`` changes the output values but not the output shape."""
        # Build the model and inputs from the tester config's own accessors.
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        # Forward pass without skipping layers
        output_full = model(**inputs_dict).sample

        # Forward pass skipping layer 0 (there is only one layer in this test setup)
        inputs_dict_with_skip = inputs_dict.copy()
        inputs_dict_with_skip["skip_layers"] = [0]
        output_skip = model(**inputs_dict_with_skip).sample

        # Skipping a layer must alter the values while leaving the shape intact.
        assert not torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
        assert output_full.shape == output_skip.shape, "Outputs should have the same shape"


class TestSD35TransformerTraining(SD35TransformerTesterConfig, TrainingTesterMixin):
    """Training tests for the SD3.5 transformer tester config."""

    def test_gradient_checkpointing_is_applied(self):
        """Run the inherited check with ``SD3Transformer2DModel`` as the expected set."""
        super().test_gradient_checkpointing_is_applied(expected_set={"SD3Transformer2DModel"})


class TestSD35TransformerCompile(SD35TransformerTesterConfig, TorchCompileTesterMixin):
    """Compile tests for the SD3.5 transformer.

    Defines nothing of its own: everything is inherited from
    ``SD35TransformerTesterConfig`` and ``TorchCompileTesterMixin``.
    """


class TestSD35TransformerBitsAndBytes(SD35TransformerTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for the SD3.5 transformer.

    Defines nothing of its own: everything is inherited from
    ``SD35TransformerTesterConfig`` and ``BitsAndBytesTesterMixin``.
    """


# Check that the outputs have the same shape
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
class TestSD35TransformerTorchAo(SD35TransformerTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for the SD3.5 transformer.

    Defines nothing of its own: everything is inherited from
    ``SD35TransformerTesterConfig`` and ``TorchAoTesterMixin``.
    """
Loading