diff --git a/REFACTOR_PLAN_condition_embedders.md b/REFACTOR_PLAN_condition_embedders.md new file mode 100644 index 000000000000..f78d654d6a03 --- /dev/null +++ b/REFACTOR_PLAN_condition_embedders.md @@ -0,0 +1,124 @@ +# Refactor: Move pipeline-local model classes into `src/diffusers/models/` + +## Motivation + +Several `ModelMixin` / `ConfigMixin` subclasses currently live under `src/diffusers/pipelines//` (e.g. `modeling_audioldm2.py`, `connectors.py`, `vocoder.py`). They are model components, not pipeline glue, and the modular work in PR #13732 forced the question of where to put a new one. Going forward, model classes should live under `src/diffusers/models/`. This refactor introduces a new `condition_embedders/` submodule for projection/conditioning encoders, an `others/` submodule for the long-tail pipeline-local oddities, and migrates the remaining classes into existing submodules (`unets/`, `autoencoders/`, `transformers/`). + +## Scope + +The classes below were located via codebase inventory. Each row gives current source → new home. **Every class** in this list gets a deprecation shim at its old import path — even classes that aren't re-exported from a pipeline `__init__.py`, because users may still do `from diffusers.pipelines.. import ClassName` directly. + +### → `models/unets/` + +| Class | From | New file | +|---|---|---| +| `AudioLDM2UNet2DConditionModel` | `pipelines/audioldm2/modeling_audioldm2.py:163` | `models/unets/unet_2d_condition_audioldm2.py` | + +### → `models/autoencoders/` (mild scope drift to "latent / audio codecs") + +| Class | From | New file | +|---|---|---| +| `LTXLatentUpsamplerModel` | `pipelines/ltx/modeling_latent_upsampler.py:76` | `models/autoencoders/latent_upsampler_ltx.py` | +| `LTX2LatentUpsamplerModel` | `pipelines/ltx2/latent_upsampler.py:170` | `models/autoencoders/latent_upsampler_ltx2.py` | +| `LTX2Vocoder` | `pipelines/ltx2/vocoder.py:279` | `models/autoencoders/vocoder_ltx2.py` | +| `LTX2VocoderWithBWE` | `pipelines/ltx2/vocoder.py:479` | (same file as above) | +| `AceStepAudioTokenizer` | `pipelines/ace_step/modeling_ace_step.py:665` | `models/autoencoders/audio_tokenizer_ace_step.py` | +| `AceStepAudioTokenDetokenizer` | `pipelines/ace_step/modeling_ace_step.py:565` | (same file as above) | + +### → `models/condition_embedders/` (new) + +| Class | From | New file | +|---|---|---| +| `AudioLDM2ProjectionModel` | `pipelines/audioldm2/modeling_audioldm2.py:78` | `models/condition_embedders/projection_audioldm2.py` | +| `StableAudioProjectionModel` | `pipelines/stable_audio/modeling_stable_audio.py:114` | `models/condition_embedders/projection_stable_audio.py` | +| `LTX2TextConnectors` | `pipelines/ltx2/connectors.py:331` | `models/condition_embedders/text_connector_ltx2.py` | +| `ReduxImageEncoder` | `pipelines/flux/modeling_flux.py:31` | `models/condition_embedders/image_encoder_redux.py` | +| `CLIPImageProjection` | `pipelines/stable_diffusion/clip_image_project_model.py:21` | `models/condition_embedders/projection_clip_image.py` | +| `AceStepConditionEncoder` | `pipelines/ace_step/modeling_ace_step.py:752` | `models/condition_embedders/condition_encoder_ace_step.py` | +| `AceStepLyricEncoder` | `pipelines/ace_step/modeling_ace_step.py:127` | (same file as above) | +| `AceStepTimbreEncoder` | `pipelines/ace_step/modeling_ace_step.py:233` | (same file as above) | + +`AceStepLyricEncoder` and `AceStepTimbreEncoder` are not re-exported from any `__init__.py`, but the shim still goes in `modeling_ace_step.py` because the deep-import path `from diffusers.pipelines.ace_step.modeling_ace_step import AceStepLyricEncoder` is part of the implicit public surface. + +### → `models/others/` (new — pipeline-local oddities with no obvious peers) + +| Class | From | New file | +|---|---|---| +| `ShapERenderer` (+ `MLPNeRSTFModel`, `ShapEParamsProjModel`, `MLPNeRFModelOutput`, plus the `BoundingBoxVolume` / `StratifiedRaySampler` / `ImportanceRaySampler` / `VoidNeRFModel` helpers) | `pipelines/shap_e/renderer.py` | `models/others/renderer_shap_e.py` | +| `IFWatermarker` | `pipelines/deepfloyd_if/watermark.py:10` | `models/others/watermark_if.py` | +| `StableUnCLIPImageNormalizer` | `pipelines/stable_diffusion/stable_unclip_image_normalizer.py:22` | `models/others/image_normalizer_stable_unclip.py` | + +`ChatGLMModel` (`pipelines/kolors/text_encoder.py`) is intentionally **excluded** — it is a HuggingFace `transformers` PreTrainedModel re-implementation, not a `ModelMixin`, and the right long-term home is upstream `transformers`. Leaving it in `pipelines/kolors/` keeps that boundary visible. (Open for discussion if reviewers disagree.) + +## Implementation order + +This ships as a **single PR**, but built up via a sequence of self-contained commits so that reviewers can step through the pattern once and then skim the rest. The first commit lands the convention (new submodule scaffold + one full end-to-end migration including the deprecation shim, the `__init__.py` wiring, and the first-party import flip); every subsequent commit mirrors that pattern for the next class or group, so individual commits can be bisected if something breaks. + +Suggested commit sequence (push after each so reviewers see the progression): + +1. **Scaffold + first migration (sets the pattern).** Create `models/condition_embedders/__init__.py` and `models/others/__init__.py`. Migrate `CLIPImageProjection` (small, single-file, public-API export, only one downstream caller) end-to-end: new file, shim with `deprecate()`, top-level `__init__.py` re-export, `make fix-copies`, first-party caller update. This commit's diff is the template for the rest. +2. **AudioLDM2 split.** `modeling_audioldm2.py` hosts both a UNet and a projection model; the file gets split into `models/unets/unet_2d_condition_audioldm2.py` and `models/condition_embedders/projection_audioldm2.py`. The old file becomes a shim that re-exports both classes (plus any internal helpers it still owns that the pipeline imports). Flip `pipeline_audioldm2.py` to import from the new paths. +3. **AceStep split.** `modeling_ace_step.py` hosts five classes across two destinations (autoencoders + condition_embedders). Split into per-destination files; old file becomes a shim for **all five** classes (`AceStepAudioTokenizer`, `AceStepAudioTokenDetokenizer`, `AceStepConditionEncoder`, `AceStepLyricEncoder`, `AceStepTimbreEncoder`). Flip `pipeline_ace_step.py` line 29 from `from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder` to the new `from ...models.autoencoders import ...` / `from ...models.condition_embedders import ...` imports. +4. **Stable Audio + Flux Redux.** `StableAudioProjectionModel` and `ReduxImageEncoder`. `ReduxImageEncoder` is also re-exported from `qwenimage/__init__.py` — update that re-export to point at the new location too. +5. **LTX / LTX2 family.** All five LTX classes (`LTXLatentUpsamplerModel`, `LTX2LatentUpsamplerModel`, `LTX2Vocoder`, `LTX2VocoderWithBWE`, `LTX2TextConnectors`). These are not in the top-level `diffusers.__init__.py` — only in the pipeline `__init__.py` — so the top-level re-export is skipped, but the pipeline-level `__init__.py` should still re-export from the new path. Flip all `ltx*/pipeline_*.py` and `ltx*/__init__.py` relative imports. +6. **UNet AudioLDM2** can be folded into commit 2 (same source file). Listed here as a reminder that the `models/unets/` registration also needs to happen. +7. **`others/` migrations.** `ShapERenderer` trio, `IFWatermarker`, `StableUnCLIPImageNormalizer`. Same recipe. + +A single PR is the right granularity because the changes are mostly mechanical and reviewers benefit from seeing the full move in one place. The commit boundaries are for navigability inside that PR. + +## Per-class change recipe + +For each class being moved, the change is: + +1. **Create the new file** at the target path. Move the class definition verbatim, plus any private helpers it owns (module-level constants, helper functions, internal `nn.Module` subclasses). Adjust relative imports for the new depth — `from ..pipeline_utils import X` (pipelines, depth-2) becomes `from ..modeling_utils import X` / `from ...utils import ...` (models, depth-3 to utils), etc. +2. **Register the new public name.** Add the class to the appropriate `models//__init__.py`, then re-export from `models/__init__.py`, then ensure `src/diffusers/__init__.py` exports it from the new path (if it was previously in the top-level `__init__`). +3. **Turn the old file into a shim** using `diffusers.utils.deprecate`. Keep the file present — do not delete. Replace its body with: + ```python + from ...models.condition_embedders.projection_audioldm2 import AudioLDM2ProjectionModel as _AudioLDM2ProjectionModel + from ...utils import deprecate + + + class AudioLDM2ProjectionModel(_AudioLDM2ProjectionModel): + def __init__(self, *args, **kwargs): + deprecate( + "AudioLDM2ProjectionModel", + "1.0.0", + "Importing `AudioLDM2ProjectionModel` from `diffusers.pipelines.audioldm2.modeling_audioldm2` is " + "deprecated. Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import AudioLDM2ProjectionModel`).", + ) + super().__init__(*args, **kwargs) + ``` + Subclassing rather than re-assigning preserves `isinstance` checks and gives a clean place to fire the warning on instantiation (not on import). For files that hosted multiple classes (e.g. `modeling_audioldm2.py`, `modeling_ace_step.py`), repeat the shim block for each moved class in the same file. The version slot (`"1.0.0"` above) is the pinned removal target — confirm against current `diffusers.__version__` (today: `0.39.0.dev0`) and pick a version that gives at least one full release cycle of warning. +4. **Update first-party imports** to point at the new location: + - **Pipeline files in the same folder.** Concrete example: `src/diffusers/pipelines/ace_step/pipeline_ace_step.py:29` is currently `from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder` — flip to `from ...models.autoencoders import AceStepAudioTokenizer, AceStepAudioTokenDetokenizer` and `from ...models.condition_embedders import AceStepConditionEncoder`. Apply the same flip to every pipeline file in the inventory (audioldm2, stable_audio, ltx, ltx2, flux, qwenimage, shap_e, deepfloyd_if, stable_diffusion). + - Conversion scripts under `scripts/`. + - Tests under `tests/`. + - Cross-pipeline re-exports: `qwenimage/__init__.py` currently re-exports `ReduxImageEncoder` from the flux modeling file — point it at the new path. + + Do not add deprecation imports to first-party code — fix the import sites directly so we are not warning ourselves. +5. **Dummy objects.** Re-run `utils/check_dummies.py` (or `make fix-copies`) so the auto-generated `utils/dummy_pt_objects.py` reflects the new export paths. + +## Deprecation warning policy + +- Use `diffusers.utils.deprecate(class_name, version, message)` (matches the rest of the library). The version slot is meaningful: once `diffusers.__version__ >= version`, `deprecate()` raises a `ValueError` telling whoever sees it that the shim needs to be deleted. That gives us an automatic, in-CI nudge to clean up rather than letting shims rot forever. +- One warning per class, fired in `__init__` of the shim subclass, not at module import time. Importing a module shouldn't spam — only constructing a deprecated class should warn. (Users may legitimately have `from ... import X` in a file they never instantiate; we shouldn't punish them.) +- Warning message format (consistent across the refactor): + > `Importing \`{ClassName}\` from \`{old.dotted.path}\` is deprecated. Import it from \`{new.dotted.path}\` instead (or \`from diffusers import {ClassName}\` when re-exported at the top level).` +- Pin every shim to the **same** removal version so the whole batch can be deleted in one cleanup PR. + +## Things I checked and decided against + +- **Moving `ChatGLMModel`** — see above; it's a `transformers` reimpl, not a `ModelMixin`. +- **Deleting the old files immediately** — would break `from diffusers.pipelines.audioldm2.modeling_audioldm2 import AudioLDM2ProjectionModel` and similar deep imports we can't see in third-party code. The shim is cheap and reversible. +- **Re-exporting via `__getattr__` on the old module** — works, but harder to attach a per-class warning to and confuses static analyzers / IDEs. Subclass + `__init__` warning is clearer. +- **Skipping shims for "internal-only" classes (`AceStepLyricEncoder`, `AceStepTimbreEncoder`)** — rejected. Even without an `__init__.py` re-export, third-party code may import them directly from the modeling module. The shim cost is one extra subclass; the breakage risk is real. Shim them. + +## Validation checklist (run once before pushing the PR; spot-check after each commit) + +- [ ] `make style && make quality` clean. +- [ ] `make fix-copies` regenerates dummies with no leftover diff. +- [ ] For every moved class: `python -c "from import ; ()"` emits the deprecation warning but does not raise (the `__init__` may require args — adapt to a real construction or use `inspect.signature` to assert importability). +- [ ] For every moved class: `python -c "from import "` succeeds with no warning. +- [ ] For every class previously in the top-level `__init__`: `python -c "from diffusers import "` still works. +- [ ] Add a dedicated **loading-only** test file (e.g. `tests/models/test_relocated_class_loading.py`) that does a tiny `from_pretrained` call against a published checkpoint for each moved class — small models, no inference, no slow-marker — purely to confirm config resolution still works after the move. `_class_name` in saved configs resolves to the class name (not the import path), so loading should work transparently as long as the new class is reachable from `diffusers.`; the test exists to catch the case where it isn't. Keeping these in their own file keeps the per-pipeline test suites untouched and the loading check easy to delete once the shims are removed. diff --git a/scripts/smoke_test_relocated_pretrained.py b/scripts/smoke_test_relocated_pretrained.py new file mode 100644 index 000000000000..a61c76fb8a54 --- /dev/null +++ b/scripts/smoke_test_relocated_pretrained.py @@ -0,0 +1,93 @@ +import sys +import traceback + +import torch + +import diffusers + + +# (class_name, repo, subfolder, kwargs) — subfolder=None loads from the repo root. +MODEL_CHECKPOINTS = [ + ("CLIPImageProjection", "anhnct/Gligen_Text_Image", "image_project", {}), + ("AudioLDM2ProjectionModel", "cvssp/audioldm2", "projection_model", {}), + ("AudioLDM2UNet2DConditionModel", "cvssp/audioldm2", "unet", {}), + ("StableAudioProjectionModel", "stabilityai/stable-audio-open-1.0", "projection_model", {}), + ("ReduxImageEncoder", "black-forest-labs/FLUX.1-Redux-dev", "image_embedder", {"torch_dtype": torch.bfloat16}), + ("LTXLatentUpsamplerModel", "Lightricks/ltxv-spatial-upscaler-0.9.7", None, {"torch_dtype": torch.bfloat16}), + ("LTX2LatentUpsamplerModel", "Lightricks/LTX-2", "latent_upsampler", {"torch_dtype": torch.bfloat16}), + # LTX2Vocoder vs LTX2VocoderWithBWE: only one will load per repo depending on the class + # the `vocoder/` config records — expect one of the two rows to fail. + ("LTX2Vocoder", "Lightricks/LTX-2", "vocoder", {"torch_dtype": torch.bfloat16}), + ("LTX2VocoderWithBWE", "diffusers/LTX-2.3-Diffusers", "vocoder", {"torch_dtype": torch.bfloat16}), + ("LTX2TextConnectors", "Lightricks/LTX-2", "connectors", {"torch_dtype": torch.bfloat16}), + ("AceStepConditionEncoder", "ACE-Step/acestep-v15-xl-turbo-diffusers", "condition_encoder", {}), + ("ShapERenderer", "openai/shap-e", "shap_e_renderer", {}), + # DeepFloyd/IF-I-XL-v1.0 is gated: requires accepting the license + `huggingface-cli login`. + ("IFWatermarker", "DeepFloyd/IF-I-XL-v1.0", "watermarker", {}), +] + + +# (pipeline_class_name, repo, kwargs) +PIPELINE_CHECKPOINTS = [ + ("StableDiffusionGLIGENTextImagePipeline", "anhnct/Gligen_Text_Image", {"torch_dtype": torch.float16}), + ("AudioLDM2Pipeline", "cvssp/audioldm2", {"torch_dtype": torch.float16}), + ("StableAudioPipeline", "stabilityai/stable-audio-open-1.0", {"torch_dtype": torch.float16}), + ("FluxPriorReduxPipeline", "black-forest-labs/FLUX.1-Redux-dev", {"torch_dtype": torch.bfloat16}), + ("LTXLatentUpsamplePipeline", "Lightricks/ltxv-spatial-upscaler-0.9.7", {"torch_dtype": torch.bfloat16}), + ("LTX2Pipeline", "Lightricks/LTX-2", {"torch_dtype": torch.bfloat16}), + ("AceStepPipeline", "ACE-Step/acestep-v15-xl-turbo-diffusers", {"torch_dtype": torch.bfloat16}), + ("ShapEPipeline", "openai/shap-e", {"torch_dtype": torch.float16}), + ("IFPipeline", "DeepFloyd/IF-I-XL-v1.0", {"torch_dtype": torch.float16, "variant": "fp16"}), +] + + +def _try_load(class_name: str, repo: str, kwargs: dict, subfolder: str | None = None) -> str | None: + """Load the checkpoint; return None on success, or the full traceback on failure.""" + try: + cls = getattr(diffusers, class_name) + load_kwargs = dict(kwargs) + if subfolder is not None: + load_kwargs["subfolder"] = subfolder + cls.from_pretrained(repo, **load_kwargs) + except Exception: + return traceback.format_exc() + return None + + +def main() -> int: + failures: list[tuple[str, str, str]] = [] # (class_name, target, traceback) + + print("=== Model components ===") + for class_name, repo, subfolder, kwargs in MODEL_CHECKPOINTS: + target = repo + (f":{subfolder}" if subfolder else "") + print(f"\n[{class_name}] {target}") + err = _try_load(class_name, repo, kwargs, subfolder) + if err is None: + print(" PASS") + else: + print(" FAIL") + failures.append((class_name, target, err)) + + print("\n=== Pipelines ===") + for class_name, repo, kwargs in PIPELINE_CHECKPOINTS: + print(f"\n[{class_name}] {repo}") + err = _try_load(class_name, repo, kwargs) + if err is None: + print(" PASS") + else: + print(" FAIL") + failures.append((class_name, repo, err)) + + print() + if failures: + print(f"FAILED: {len(failures)} case(s).\n") + for class_name, target, err in failures: + print(f"--- {class_name} ({target}) ---") + print(err) + return 1 + print("OK: all cases passed.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d120d0a22818..8b1431d15352 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -189,10 +189,15 @@ ] _import_structure["models"].extend( [ + "AceStepAudioTokenDetokenizer", + "AceStepAudioTokenizer", + "AceStepConditionEncoder", "AceStepTransformer1DModel", "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", "AttentionBackendName", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -224,6 +229,7 @@ "CacheMixin", "ChromaTransformer2DModel", "ChronoEditTransformer3DModel", + "CLIPImageProjection", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "CogView4Transformer2DModel", @@ -253,6 +259,7 @@ "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", + "IFWatermarker", "JoyImageEditTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", @@ -260,7 +267,12 @@ "LongCatAudioDiTTransformer", "LongCatAudioDiTVae", "LongCatImageTransformer2DModel", + "LTX2LatentUpsamplerModel", + "LTX2TextConnectors", "LTX2VideoTransformer3DModel", + "LTX2Vocoder", + "LTX2VocoderWithBWE", + "LTXLatentUpsamplerModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", @@ -280,16 +292,20 @@ "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", + "ReduxImageEncoder", "SanaControlNetModel", "SanaTransformer2DModel", "SanaVideoTransformer3DModel", "SD3ControlNetModel", "SD3MultiControlNetModel", "SD3Transformer2DModel", + "ShapERenderer", "SkyReelsV2Transformer3DModel", "SparseControlNetModel", "StableAudioDiTModel", + "StableAudioProjectionModel", "StableCascadeUNet", + "StableUnCLIPImageNormalizer", "T2IAdapter", "T5FilmDecoder", "Transformer2DModel", @@ -494,9 +510,6 @@ ) _import_structure["pipelines"].extend( [ - "AceStepAudioTokenDetokenizer", - "AceStepAudioTokenizer", - "AceStepConditionEncoder", "AceStepPipeline", "AllegroPipeline", "AltDiffusionImg2ImgPipeline", @@ -512,8 +525,6 @@ "AnimateDiffVideoToVideoControlNetPipeline", "AnimateDiffVideoToVideoPipeline", "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", @@ -525,7 +536,6 @@ "ChromaInpaintPipeline", "ChromaPipeline", "ChronoEditPipeline", - "CLIPImageProjection", "CogVideoXFunControlPipeline", "CogVideoXImageToVideoPipeline", "CogVideoXPipeline", @@ -663,7 +673,6 @@ "QwenImageInpaintPipeline", "QwenImageLayeredPipeline", "QwenImagePipeline", - "ReduxImageEncoder", "SanaControlNetPipeline", "SanaImageToVideoPipeline", "SanaPAGPipeline", @@ -681,7 +690,6 @@ "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", "StableAudioPipeline", - "StableAudioProjectionModel", "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", "StableCascadePriorPipeline", @@ -1017,10 +1025,15 @@ VaeImageProcessorLDM3D, ) from .models import ( + AceStepAudioTokenDetokenizer, + AceStepAudioTokenizer, + AceStepConditionEncoder, AceStepTransformer1DModel, AllegroTransformer3DModel, AsymmetricAutoencoderKL, AttentionBackendName, + AudioLDM2ProjectionModel, + AudioLDM2UNet2DConditionModel, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -1052,6 +1065,7 @@ CacheMixin, ChromaTransformer2DModel, ChronoEditTransformer3DModel, + CLIPImageProjection, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, CogView4Transformer2DModel, @@ -1081,6 +1095,7 @@ HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, I2VGenXLUNet, + IFWatermarker, JoyImageEditTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, @@ -1088,7 +1103,12 @@ LongCatAudioDiTTransformer, LongCatAudioDiTVae, LongCatImageTransformer2DModel, + LTX2LatentUpsamplerModel, + LTX2TextConnectors, LTX2VideoTransformer3DModel, + LTX2Vocoder, + LTX2VocoderWithBWE, + LTXLatentUpsamplerModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, @@ -1108,15 +1128,19 @@ QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, + ReduxImageEncoder, SanaControlNetModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, SD3ControlNetModel, SD3MultiControlNetModel, SD3Transformer2DModel, + ShapERenderer, SkyReelsV2Transformer3DModel, SparseControlNetModel, StableAudioDiTModel, + StableAudioProjectionModel, + StableUnCLIPImageNormalizer, T2IAdapter, T5FilmDecoder, Transformer2DModel, @@ -1166,7 +1190,6 @@ AutoPipelineForText2Image, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, - CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, @@ -1299,9 +1322,6 @@ ZImageModularPipeline, ) from .pipelines import ( - AceStepAudioTokenDetokenizer, - AceStepAudioTokenizer, - AceStepConditionEncoder, AceStepPipeline, AllegroPipeline, AltDiffusionImg2ImgPipeline, @@ -1317,8 +1337,6 @@ AnimateDiffVideoToVideoControlNetPipeline, AnimateDiffVideoToVideoPipeline, AudioLDM2Pipeline, - AudioLDM2ProjectionModel, - AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, BriaFiboEditPipeline, @@ -1328,7 +1346,6 @@ ChromaInpaintPipeline, ChromaPipeline, ChronoEditPipeline, - CLIPImageProjection, CogVideoXFunControlPipeline, CogVideoXImageToVideoPipeline, CogVideoXPipeline, @@ -1466,7 +1483,6 @@ QwenImageInpaintPipeline, QwenImageLayeredPipeline, QwenImagePipeline, - ReduxImageEncoder, SanaControlNetPipeline, SanaImageToVideoPipeline, SanaPAGPipeline, @@ -1483,7 +1499,6 @@ SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, StableAudioPipeline, - StableAudioProjectionModel, StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index ff8e16aad447..fe11dd1c9dd3 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -29,6 +29,10 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] + _import_structure["autoencoders.audio_tokenizer_ace_step"] = [ + "AceStepAudioTokenDetokenizer", + "AceStepAudioTokenizer", + ] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] @@ -56,8 +60,17 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_vidtok"] = ["AutoencoderVidTok"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] + _import_structure["autoencoders.latent_upsampler_ltx"] = ["LTXLatentUpsamplerModel"] + _import_structure["autoencoders.latent_upsampler_ltx2"] = ["LTX2LatentUpsamplerModel"] + _import_structure["autoencoders.vocoder_ltx2"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["cache_utils"] = ["CacheMixin"] + _import_structure["condition_embedders.condition_encoder_ace_step"] = ["AceStepConditionEncoder"] + _import_structure["condition_embedders.image_encoder_redux"] = ["ReduxImageEncoder"] + _import_structure["condition_embedders.projection_audioldm2"] = ["AudioLDM2ProjectionModel"] + _import_structure["condition_embedders.projection_clip_image"] = ["CLIPImageProjection"] + _import_structure["condition_embedders.projection_stable_audio"] = ["StableAudioProjectionModel"] + _import_structure["condition_embedders.text_connector_ltx2"] = ["LTX2TextConnectors"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] @@ -79,6 +92,18 @@ _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["others.image_normalizer_stable_unclip"] = ["StableUnCLIPImageNormalizer"] + _import_structure["others.renderer_shap_e"] = [ + "BoundingBoxVolume", + "ImportanceRaySampler", + "MLPNeRFModelOutput", + "MLPNeRSTFModel", + "ShapEParamsProjModel", + "ShapERenderer", + "StratifiedRaySampler", + "VoidNeRFModel", + ] + _import_structure["others.watermark_if"] = ["IFWatermarker"] _import_structure["transformers.ace_step_transformer"] = ["AceStepTransformer1DModel"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] @@ -138,6 +163,7 @@ _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] + _import_structure["unets.unet_2d_condition_audioldm2"] = ["AudioLDM2UNet2DConditionModel"] _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] @@ -159,6 +185,8 @@ from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel from .autoencoders import ( + AceStepAudioTokenDetokenizer, + AceStepAudioTokenizer, AsymmetricAutoencoderKL, AutoencoderDC, AutoencoderKL, @@ -186,9 +214,21 @@ AutoencoderVidTok, ConsistencyDecoderVAE, LongCatAudioDiTVae, + LTX2LatentUpsamplerModel, + LTX2Vocoder, + LTX2VocoderWithBWE, + LTXLatentUpsamplerModel, VQModel, ) from .cache_utils import CacheMixin + from .condition_embedders import ( + AceStepConditionEncoder, + AudioLDM2ProjectionModel, + CLIPImageProjection, + LTX2TextConnectors, + ReduxImageEncoder, + StableAudioProjectionModel, + ) from .controlnets import ( ControlNetModel, ControlNetUnionModel, @@ -211,6 +251,18 @@ ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin + from .others import ( + BoundingBoxVolume, + IFWatermarker, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StableUnCLIPImageNormalizer, + StratifiedRaySampler, + VoidNeRFModel, + ) from .transformers import ( AceStepTransformer1DModel, AllegroTransformer3DModel, @@ -270,6 +322,7 @@ ZImageTransformer2DModel, ) from .unets import ( + AudioLDM2UNet2DConditionModel, I2VGenXLUNet, Kandinsky3UNet, MotionAdapter, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 90dfa31fab6f..128f03fc60c5 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,3 +1,4 @@ +from .audio_tokenizer_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_dc import AutoencoderDC from .autoencoder_kl import AutoencoderKL @@ -25,4 +26,7 @@ from .autoencoder_tiny import AutoencoderTiny from .autoencoder_vidtok import AutoencoderVidTok from .consistency_decoder_vae import ConsistencyDecoderVAE +from .latent_upsampler_ltx import LTXLatentUpsamplerModel +from .latent_upsampler_ltx2 import LTX2LatentUpsamplerModel +from .vocoder_ltx2 import LTX2Vocoder, LTX2VocoderWithBWE from .vq_model import VQModel diff --git a/src/diffusers/models/autoencoders/audio_tokenizer_ace_step.py b/src/diffusers/models/autoencoders/audio_tokenizer_ace_step.py new file mode 100644 index 000000000000..a8501f405734 --- /dev/null +++ b/src/diffusers/models/autoencoders/audio_tokenizer_ace_step.py @@ -0,0 +1,476 @@ +# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""ACE-Step audio tokenizer / detokenizer. + +Converts between the VAE's 25 Hz acoustic latents and the 5 Hz semantic audio tokens used by the cover-conditioning +path. The internal ``AceStepEncoderLayer`` block is duplicated here (and in +:mod:`diffusers.models.condition_embedders.condition_encoder_ace_step`) to keep the two model files independent. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from ..transformers.ace_step_transformer import ( + AceStepAttention, + AceStepMLP, + _ace_step_rotary_freqs, + _create_4d_mask, + _is_flash_attention_backend, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AceStepEncoderLayer(nn.Module): + """Pre-LN transformer block used by the lyric and timbre encoders.""" + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + intermediate_size: int, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: Optional[int] = None, + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bias=attention_bias, + dropout=attention_dropout, + eps=rms_norm_eps, + sliding_window=sliding_window, + is_cross_attention=False, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = AceStepMLP(hidden_size, intermediate_size) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + image_rotary_emb=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class _AceStepResidualFSQ(nn.Module): + """Minimal ResidualFSQ compatible with ACE-Step's saved tokenizer weights.""" + + def __init__( + self, + dim: int = 2048, + levels: Optional[list] = None, + num_quantizers: int = 1, + ): + super().__init__() + + if levels is None: + levels = [8, 8, 8, 5, 5, 5] + + self.levels = levels + self.num_quantizers = num_quantizers + self.codebook_dim = len(levels) + + self.project_in = nn.Linear(dim, self.codebook_dim) + self.project_out = nn.Linear(self.codebook_dim, dim) + + levels_tensor = torch.tensor(levels, dtype=torch.long) + basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.long), dim=0) + scales = torch.stack([levels_tensor.float() ** -i for i in range(num_quantizers)]) + self.register_buffer("_levels", levels_tensor, persistent=False) + self.register_buffer("_basis", basis, persistent=False) + self.register_buffer("scales", scales, persistent=False) + + @property + def codebook_size(self) -> int: + return int(torch.prod(self._levels).item()) + + def _indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor: + levels = self._levels.to(device=indices.device) + basis = self._basis.to(device=indices.device) + level_indices = (indices.long().unsqueeze(-1) // basis) % levels + scale = 2.0 / (levels.to(dtype=torch.float32) - 1.0) + return level_indices.to(dtype=torch.float32) * scale - 1.0 + + def _codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor: + levels = self._levels.to(device=codes.device, dtype=codes.dtype) + basis = self._basis.to(device=codes.device, dtype=codes.dtype) + level_indices = (codes + 1.0) / (2.0 / (levels - 1.0)) + return (level_indices * basis).sum(dim=-1).round().to(torch.long) + + def _quantize(self, x: torch.Tensor) -> torch.Tensor: + levels = self._levels.to(device=x.device, dtype=x.dtype) + levels_minus_one = levels - 1.0 + step = 2.0 / levels_minus_one + bracket = levels_minus_one * (x.clamp(-1.0, 1.0) + 1.0) / 2.0 + 0.5 + return step * torch.floor(bracket) - 1.0 + + def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: + if indices.ndim == 2: + indices = indices.unsqueeze(-1) + if indices.shape[-1] != self.num_quantizers: + raise ValueError( + f"Expected audio code indices with last dimension {self.num_quantizers}, got {indices.shape[-1]}." + ) + + codes = [] + for quantizer_idx in range(self.num_quantizers): + code = self._indices_to_codes(indices[..., quantizer_idx]) + scale = self.scales[quantizer_idx].to(device=code.device, dtype=code.dtype) + codes.append(code * scale) + return torch.stack(codes, dim=0) + + def get_output_from_indices(self, indices: torch.Tensor) -> torch.Tensor: + codes = self.get_codes_from_indices(indices).sum(dim=0) + weight = self.project_out.weight.float() + bias = self.project_out.bias.float() if self.project_out.bias is not None else None + output = F.linear(codes.float(), weight, bias) + return output.to(dtype=self.project_out.weight.dtype) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + input_dtype = hidden_states.dtype + weight = self.project_in.weight.float() + bias = self.project_in.bias.float() if self.project_in.bias is not None else None + hidden_states = F.linear(hidden_states.float(), weight, bias) + + levels = self._levels.to(device=hidden_states.device, dtype=hidden_states.dtype) + soft_clamp = 1.0 + (1.0 / (levels - 1.0)) + hidden_states = (hidden_states / soft_clamp).tanh() * soft_clamp + + quantized_out = torch.zeros_like(hidden_states) + residual = hidden_states + all_indices = [] + for scale in self.scales.to(device=hidden_states.device, dtype=hidden_states.dtype): + quantized = self._quantize(residual / scale) * scale + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + all_indices.append(self._codes_to_indices(quantized / scale)) + + weight = self.project_out.weight.float() + bias = self.project_out.bias.float() if self.project_out.bias is not None else None + quantized_out = F.linear(quantized_out.float(), weight, bias).to(dtype=input_dtype) + all_indices = torch.stack(all_indices, dim=-1) + return quantized_out, all_indices + + +class AceStepAttentionPooler(nn.Module): + """Attention pooler used by the ACE-Step audio tokenizer.""" + + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + num_attention_pooler_hidden_layers: int = 2, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + if layer_types is None: + layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(num_attention_pooler_hidden_layers) + ] + + self.embed_tokens = nn.Linear(hidden_size, hidden_size) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.layers = nn.ModuleList( + [ + AceStepEncoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, + ) + for i in range(num_attention_pooler_hidden_layers) + ] + ) + self._layer_types = layer_types + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_patches, patch_size, _ = hidden_states.shape + hidden_states = self.embed_tokens(hidden_states) + special_token = self.special_token.to(device=hidden_states.device, dtype=hidden_states.dtype) + special_token = special_token.expand(batch_size, num_patches, -1, -1) + hidden_states = torch.cat([special_token, hidden_states], dim=2) + hidden_states = hidden_states.reshape(batch_size * num_patches, patch_size + 1, -1) + + seq_len = hidden_states.shape[1] + dtype = hidden_states.dtype + device = hidden_states.device + position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) + sliding_attn_mask = None + if not _is_flash_attention_backend(self.layers[0].self_attn.processor): + sliding_attn_mask = _create_4d_mask( + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=self.sliding_window, + is_sliding_window=True, + is_causal=False, + ) + + for i, layer_module in enumerate(self.layers): + mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None + hidden_states = layer_module( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=mask, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[:, 0, :] + return hidden_states.reshape(batch_size, num_patches, -1) + + +class AceStepAudioTokenDetokenizer(ModelMixin, ConfigMixin): + """Expands ACE-Step 5 Hz audio tokens back to 25 Hz acoustic conditioning.""" + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + audio_acoustic_hidden_dim: int = 64, + pool_window_size: int = 5, + num_attention_pooler_hidden_layers: int = 2, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + if layer_types is None: + layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(num_attention_pooler_hidden_layers) + ] + + self.embed_tokens = nn.Linear(hidden_size, hidden_size) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02) + self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.pool_window_size = pool_window_size + self.layers = nn.ModuleList( + [ + AceStepEncoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, + ) + for i in range(num_attention_pooler_hidden_layers) + ] + ) + self._layer_types = layer_types + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor`): + Input audio tokens of shape `(batch_size, num_tokens, hidden_size)` to be unpooled back to the 25 Hz + acoustic-latent rate. + """ + batch_size, num_tokens, _ = hidden_states.shape + hidden_states = self.embed_tokens(hidden_states) + hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.pool_window_size, -1) + special_tokens = self.special_tokens.to(device=hidden_states.device, dtype=hidden_states.dtype) + hidden_states = hidden_states + special_tokens.unsqueeze(0) + hidden_states = hidden_states.reshape(batch_size * num_tokens, self.pool_window_size, -1) + + seq_len = hidden_states.shape[1] + dtype = hidden_states.dtype + device = hidden_states.device + position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) + sliding_attn_mask = None + if not _is_flash_attention_backend(self.layers[0].self_attn.processor): + sliding_attn_mask = _create_4d_mask( + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=self.sliding_window, + is_sliding_window=True, + is_causal=False, + ) + + for i, layer_module in enumerate(self.layers): + mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer_module, hidden_states, position_embeddings, mask + ) + else: + hidden_states = layer_module( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=mask, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states.reshape(batch_size, num_tokens * self.pool_window_size, -1) + + +class AceStepAudioTokenizer(ModelMixin, ConfigMixin): + """Converts 25 Hz acoustic latents to ACE-Step 5 Hz audio tokens.""" + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + audio_acoustic_hidden_dim: int = 64, + pool_window_size: int = 5, + fsq_dim: int = 2048, + fsq_input_levels: list = None, + fsq_input_num_quantizers: int = 1, + num_attention_pooler_hidden_layers: int = 2, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + if fsq_input_levels is None: + fsq_input_levels = [8, 8, 8, 5, 5, 5] + + self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size) + self.attention_pooler = AceStepAttentionPooler( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rope_theta=rope_theta, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window, + layer_types=layer_types, + ) + self.quantizer = _AceStepResidualFSQ( + dim=fsq_dim, + levels=fsq_input_levels, + num_quantizers=fsq_input_num_quantizers, + ) + self.pool_window_size = pool_window_size + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + hidden_states (`torch.Tensor`): + Input acoustic latents of shape `(batch_size, latent_length, audio_acoustic_hidden_dim)` to be + quantized into ACE-Step 5 Hz audio tokens. + """ + input_dtype = hidden_states.dtype + hidden_states = self.audio_acoustic_proj(hidden_states) + hidden_states = self.attention_pooler(hidden_states) + quantized, indices = self.quantizer(hidden_states) + return quantized.to(dtype=input_dtype), indices + + def tokenize( + self, + hidden_states: torch.Tensor, + silence_latent: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, latent_length, acoustic_dim = hidden_states.shape + pad_len = (-latent_length) % self.pool_window_size + if pad_len: + if silence_latent is not None and silence_latent.shape[-1] == acoustic_dim: + pad = silence_latent[:, :pad_len, :].to(device=hidden_states.device, dtype=hidden_states.dtype) + pad = pad.expand(batch_size, -1, -1) + else: + pad = torch.zeros( + batch_size, pad_len, acoustic_dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + hidden_states = torch.cat([hidden_states, pad], dim=1) + + num_patches = hidden_states.shape[1] // self.pool_window_size + hidden_states = hidden_states.reshape(batch_size, num_patches, self.pool_window_size, acoustic_dim) + return self(hidden_states) diff --git a/src/diffusers/models/autoencoders/latent_upsampler_ltx.py b/src/diffusers/models/autoencoders/latent_upsampler_ltx.py new file mode 100644 index 000000000000..56e2e7fba284 --- /dev/null +++ b/src/diffusers/models/autoencoders/latent_upsampler_ltx.py @@ -0,0 +1,192 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin + + +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor`): + Input latents of shape `(batch_size, num_channels, num_frames, height, width)` to spatially or + temporally upsample. + """ + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/autoencoders/latent_upsampler_ltx2.py b/src/diffusers/models/autoencoders/latent_upsampler_ltx2.py new file mode 100644 index 000000000000..fb0a42291fd3 --- /dev/null +++ b/src/diffusers/models/autoencoders/latent_upsampler_ltx2.py @@ -0,0 +1,291 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.models.autoencoders.latent_upsampler_ltx.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.models.autoencoders.latent_upsampler_ltx.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + rational_spatial_scale: float = 2.0, + use_rational_resampler: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if use_rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor`): + Input latents of shape `(batch_size, num_channels, num_frames, height, width)` to spatially or + temporally upsample. + """ + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/models/autoencoders/vocoder_ltx2.py b/src/diffusers/models/autoencoders/vocoder_ltx2.py new file mode 100644 index 000000000000..8ede102d46a0 --- /dev/null +++ b/src/diffusers/models/autoencoders/vocoder_ltx2.py @@ -0,0 +1,602 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin + + +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + """ + Creates a Kaiser sinc kernel for low-pass filtering. + + Args: + cutoff (`float`): + Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist + frequency). + half_width (`float`): + Used to determine the Kaiser window's beta parameter. + kernel_size: + Size of the Kaiser window (and ultimately the Kaiser sinc kernel). + + Returns: + `torch.Tensor` of shape `(kernel_size,)`: + The Kaiser sinc kernel. + """ + delta_f = 4 * half_width + half_size = kernel_size // 2 + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + even = kernel_size % 2 == 0 + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + + if cutoff == 0.0: + filter = torch.zeros_like(time) + else: + time = 2 * cutoff * time + sinc = torch.where( + time == 0, + torch.ones_like(time), + torch.sin(math.pi * time) / math.pi / time, + ) + filter = 2 * cutoff * window * sinc + filter = filter / filter.sum() + return filter + + +class DownSample1d(nn.Module): + """1D low-pass filter for antialias downsampling.""" + + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + use_padding: bool = True, + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.kernel_size = kernel_size or int(6 * ratio // 2) * 2 + self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1 + self.pad_right = self.kernel_size // 2 + self.use_padding = use_padding + self.padding_mode = padding_mode + + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size) + self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + if self.use_padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels) + return x_filtered + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + window_type: str = "kaiser", + padding_mode: str = "replicate", + persistent: bool = True, + ): + super().__init__() + self.ratio = ratio + self.padding_mode = padding_mode + + if window_type == "hann": + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser sinc filter is BigVGAN default + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2 + self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2 + + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x expected shape: [batch_size, num_channels, hidden_dim] + num_channels = x.shape[1] + x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode) + low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels) + return x[..., self.pad_left : -self.pad_right] + + +class AntiAliasAct1d(nn.Module): + """ + Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples + to avoid aliasing. + """ + + def __init__( + self, + act_fn: str | nn.Module, + ratio: int = 2, + kernel_size: int = 12, + **kwargs, + ): + super().__init__() + self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size) + if isinstance(act_fn, str): + if act_fn == "snakebeta": + act_fn = SnakeBeta(**kwargs) + elif act_fn == "snake": + act_fn = SnakeBeta(**kwargs) + else: + act_fn = nn.LeakyReLU(**kwargs) + self.act = act_fn + self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +class SnakeBeta(nn.Module): + """ + Implements the Snake and SnakeBeta activations, which help with learning periodic patterns. + """ + + def __init__( + self, + channels: int, + alpha: float = 1.0, + eps: float = 1e-9, + trainable_params: bool = True, + logscale: bool = True, + use_beta: bool = True, + ): + super().__init__() + self.eps = eps + self.logscale = logscale + self.use_beta = use_beta + + self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.alpha.requires_grad = trainable_params + if use_beta: + self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) + self.beta.requires_grad = trainable_params + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + broadcast_shape = [1] * hidden_states.ndim + broadcast_shape[channel_dim] = -1 + alpha = self.alpha.view(broadcast_shape) + if self.use_beta: + beta = self.beta.view(broadcast_shape) + + if self.logscale: + alpha = torch.exp(alpha) + if self.use_beta: + beta = torch.exp(beta) + + amplitude = beta if self.use_beta else alpha + hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2) + return hidden_states + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: tuple[int, ...] = (1, 3, 5), + act_fn: str = "leaky_relu", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) + for dilation in dilations + ] + ) + self.acts1 = nn.ModuleList() + for _ in range(len(self.convs1)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts1.append(act) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) + for _ in range(len(dilations)) + ] + ) + self.acts2 = nn.ModuleList() + for _ in range(len(self.convs2)): + if act_fn == "snakebeta": + act = SnakeBeta(channels, use_beta=True) + elif act_fn == "snake": + act = SnakeBeta(channels, use_beta=False) + else: + act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) + + if antialias: + act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + self.acts2.append(act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2): + xt = act1(x) + xt = conv1(xt) + xt = act2(xt) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4], + upsample_factors: list[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "leaky_relu", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = False, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = "tanh", # tanh, clamp, None + final_bias: bool = True, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.act_fn = act_fn + self.negative_slope = leaky_relu_negative_slope + self.final_act_fn = final_act_fn + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + supported_act_fns = ["snakebeta", "snake", "leaky_relu"] + if self.act_fn not in supported_act_fns: + raise ValueError( + f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are " + f"{supported_act_fns}." + ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + channels=output_channels, + kernel_size=kernel_size, + dilations=dilations, + act_fn=act_fn, + leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, + ) + ) + input_channels = output_channels + + if act_fn == "snakebeta" or act_fn == "snake": + # Always use antialiasing + act_out = SnakeBeta(channels=output_channels, use_beta=True) + self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size) + elif act_fn == "leaky_relu": + # NOTE: does NOT use self.negative_slope, following the original code + self.act_out = nn.LeakyReLU() + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + if self.act_fn == "leaky_relu": + # Other activations are inside each upsampling block + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + hidden_states = self.act_out(hidden_states) + hidden_states = self.conv_out(hidden_states) + if self.final_act_fn == "tanh": + hidden_states = torch.tanh(hidden_states) + elif self.final_act_fn == "clamp": + hidden_states = torch.clamp(hidden_states, -1, 1) + + return hidden_states + + +class CausalSTFT(nn.Module): + """ + Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases + multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact + buffers should be loaded from the checkpoint in bfloat16. + """ + + def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512): + super().__init__() + self.hop_length = hop_length + self.window_length = window_length + n_freqs = filter_length // 2 + 1 + + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if waveform.ndim == 2: + waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples] + + left_pad = max(0, self.window_length - self.hop_length) # causal: left-only + waveform = F.pad(waveform, (left_pad, 0)) + + spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """ + Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be + loaded from the checkpoint in bfloat16. + """ + + def __init__( + self, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + ): + super().__init__() + self.stft_fn = CausalSTFT(filter_length, hop_length, window_length) + + num_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True) + + def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + magnitude, phase = self.stft_fn(waveform) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class LTX2VocoderWithBWE(ModelMixin, ConfigMixin): + """ + LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the + BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same + architecture as the original vocoder. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1536, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4], + upsample_factors: list[int] = [5, 2, 2, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + act_fn: str = "snakebeta", + leaky_relu_negative_slope: float = 0.1, + antialias: bool = True, + antialias_ratio: int = 2, + antialias_kernel_size: int = 12, + final_act_fn: str | None = None, + final_bias: bool = False, + bwe_in_channels: int = 128, + bwe_hidden_channels: int = 512, + bwe_out_channels: int = 2, + bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4], + bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2], + bwe_resnet_kernel_sizes: list[int] = [3, 7, 11], + bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + bwe_act_fn: str = "snakebeta", + bwe_leaky_relu_negative_slope: float = 0.1, + bwe_antialias: bool = True, + bwe_antialias_ratio: int = 2, + bwe_antialias_kernel_size: int = 12, + bwe_final_act_fn: str | None = None, + bwe_final_bias: bool = False, + filter_length: int = 512, + hop_length: int = 80, + window_length: int = 512, + num_mel_channels: int = 64, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + ): + super().__init__() + + self.vocoder = LTX2Vocoder( + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + upsample_kernel_sizes=upsample_kernel_sizes, + upsample_factors=upsample_factors, + resnet_kernel_sizes=resnet_kernel_sizes, + resnet_dilations=resnet_dilations, + act_fn=act_fn, + leaky_relu_negative_slope=leaky_relu_negative_slope, + antialias=antialias, + antialias_ratio=antialias_ratio, + antialias_kernel_size=antialias_kernel_size, + final_act_fn=final_act_fn, + final_bias=final_bias, + output_sampling_rate=input_sampling_rate, + ) + self.bwe_generator = LTX2Vocoder( + in_channels=bwe_in_channels, + hidden_channels=bwe_hidden_channels, + out_channels=bwe_out_channels, + upsample_kernel_sizes=bwe_upsample_kernel_sizes, + upsample_factors=bwe_upsample_factors, + resnet_kernel_sizes=bwe_resnet_kernel_sizes, + resnet_dilations=bwe_resnet_dilations, + act_fn=bwe_act_fn, + leaky_relu_negative_slope=bwe_leaky_relu_negative_slope, + antialias=bwe_antialias, + antialias_ratio=bwe_antialias_ratio, + antialias_kernel_size=bwe_antialias_kernel_size, + final_act_fn=bwe_final_act_fn, + final_bias=bwe_final_bias, + output_sampling_rate=output_sampling_rate, + ) + + self.mel_stft = MelSTFT( + filter_length=filter_length, + hop_length=hop_length, + window_length=window_length, + num_mel_channels=num_mel_channels, + ) + + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, + window_type="hann", + persistent=False, + ) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + """ + Args: + mel_spec (`torch.Tensor`): + Input mel spectrogram of shape `(batch_size, num_channels, num_frames, num_mel_bins)`. + """ + # 1. Run stage 1 vocoder to get low sampling rate waveform + x = self.vocoder(mel_spec) + batch_size, num_channels, num_samples = x.shape + + # Pad to exact multiple of hop_length for exact mel frame count + remainder = num_samples % self.config.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + # 2. Compute mel spectrogram on vocoder output + mel, _, _, _ = self.mel_stft(x.flatten(0, 1)) + mel = mel.unflatten(0, (-1, num_channels)) + + # 3. Run bandwidth extender (BWE) on new mel spectrogram + mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins] + residual = self.bwe_generator(mel_for_bwe) + + # 4. Residual connection with resampler + skip = self.resampler(x) + waveform = torch.clamp(residual + skip, -1, 1) + output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate + waveform = waveform[..., :output_samples] + return waveform diff --git a/src/diffusers/models/condition_embedders/__init__.py b/src/diffusers/models/condition_embedders/__init__.py new file mode 100644 index 000000000000..6a29a02fce47 --- /dev/null +++ b/src/diffusers/models/condition_embedders/__init__.py @@ -0,0 +1,6 @@ +from .condition_encoder_ace_step import AceStepConditionEncoder +from .image_encoder_redux import ReduxImageEncoder +from .projection_audioldm2 import AudioLDM2ProjectionModel +from .projection_clip_image import CLIPImageProjection +from .projection_stable_audio import StableAudioProjectionModel +from .text_connector_ltx2 import LTX2TextConnectors diff --git a/src/diffusers/models/condition_embedders/condition_encoder_ace_step.py b/src/diffusers/models/condition_embedders/condition_encoder_ace_step.py new file mode 100644 index 000000000000..c0ab8c8db58e --- /dev/null +++ b/src/diffusers/models/condition_embedders/condition_encoder_ace_step.py @@ -0,0 +1,499 @@ +# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""ACE-Step condition encoder. + +Fuses text, lyric, and timbre conditioning into the packed sequence used by the DiT's cross-attention. The internal +``AceStepEncoderLayer`` block is duplicated here (and in :mod:`diffusers.models.autoencoders.audio_tokenizer_ace_step`) +to keep the two model files independent. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from ..transformers.ace_step_transformer import ( + AceStepAttention, + AceStepMLP, + _ace_step_rotary_freqs, + _create_4d_mask, + _is_flash_attention_backend, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _pack_sequences( + hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pack two masked sequences into one with all valid tokens first. + + Concatenates ``hidden1`` + ``hidden2`` along the sequence dim, then stably sorts each batch so mask=1 tokens come + before mask=0 tokens. Returns the packed hidden states plus a fresh contiguous mask. + """ + hidden_cat = torch.cat([hidden1, hidden2], dim=1) + mask_cat = torch.cat([mask1, mask2], dim=1) + + B, L, D = hidden_cat.shape + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) + hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D)) + lengths = mask_cat.sum(dim=1) + new_mask = torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1) + return hidden_left, new_mask + + +class AceStepEncoderLayer(nn.Module): + """Pre-LN transformer block used by the lyric and timbre encoders.""" + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + intermediate_size: int, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: Optional[int] = None, + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bias=attention_bias, + dropout=attention_dropout, + eps=rms_norm_eps, + sliding_window=sliding_window, + is_cross_attention=False, + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = AceStepMLP(hidden_size, intermediate_size) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + image_rotary_emb=position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class AceStepLyricEncoder(ModelMixin, ConfigMixin): + """Lyric encoder: projects Qwen3 lyric embeddings and runs a small transformer. + + Output feeds the DiT cross-attention (after packing with text + timbre). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + text_hidden_dim: int = 1024, + num_lyric_encoder_hidden_layers: int = 8, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + if layer_types is None: + layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(num_lyric_encoder_hidden_layers) + ] + + self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.layers = nn.ModuleList( + [ + AceStepEncoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, + ) + for i in range(num_lyric_encoder_hidden_layers) + ] + ) + + self._layer_types = layer_types + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds: torch.FloatTensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + inputs_embeds (`torch.FloatTensor`): + Lyric token ids of shape `(batch_size, sequence_length)` to embed and encode. + attention_mask (`torch.Tensor`): + Attention mask of shape `(batch_size, sequence_length)` indicating which tokens are valid. + """ + inputs_embeds = self.embed_tokens(inputs_embeds) + + seq_len = inputs_embeds.shape[1] + dtype = inputs_embeds.dtype + device = inputs_embeds.device + + cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) + position_embeddings = (cos, sin) + + if _is_flash_attention_backend(self.layers[0].self_attn.processor): + full_attn_mask = attention_mask + sliding_attn_mask = attention_mask + else: + full_attn_mask = _create_4d_mask( + seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, is_causal=False + ) + sliding_attn_mask = _create_4d_mask( + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=attention_mask, + sliding_window=self.sliding_window, + is_sliding_window=True, + is_causal=False, + ) + + hidden_states = inputs_embeds + for i, layer_module in enumerate(self.layers): + mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else full_attn_mask + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer_module, hidden_states, position_embeddings, mask + ) + else: + hidden_states = layer_module( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=mask, + ) + return self.norm(hidden_states) + + +class AceStepTimbreEncoder(ModelMixin, ConfigMixin): + """Timbre encoder: consumes VAE-encoded reference-audio latents and returns a + pooled per-batch timbre embedding (plus a presence mask). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + timbre_hidden_dim: int = 64, + num_timbre_encoder_hidden_layers: int = 4, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + if layer_types is None: + layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(num_timbre_encoder_hidden_layers) + ] + + self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size)) + self.head_dim = head_dim + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.layers = nn.ModuleList( + [ + AceStepEncoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, + ) + for i in range(num_timbre_encoder_hidden_layers) + ] + ) + + self._layer_types = layer_types + self.gradient_checkpointing = False + + @staticmethod + def unpack_timbre_embeddings( + timbre_embs_packed: torch.Tensor, refer_audio_order_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + N, d = timbre_embs_packed.shape + device = timbre_embs_packed.device + dtype = timbre_embs_packed.dtype + + B = int(refer_audio_order_mask.max().item() + 1) + counts = torch.bincount(refer_audio_order_mask, minlength=B) + max_count = counts.max().item() + + sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True) + sorted_batch_ids = refer_audio_order_mask[sorted_indices] + + positions = torch.arange(N, device=device) + batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) + positions_in_sorted = positions - batch_starts[sorted_batch_ids] + + inverse_indices = torch.empty_like(sorted_indices) + inverse_indices[sorted_indices] = torch.arange(N, device=device) + positions_in_batch = positions_in_sorted[inverse_indices] + + indices_2d = refer_audio_order_mask * max_count + positions_in_batch + one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) + + timbre_embs_flat = one_hot.t() @ timbre_embs_packed + timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d) + + mask_flat = (one_hot.sum(dim=0) > 0).long() + new_mask = mask_flat.reshape(B, max_count) + return timbre_embs_unpack, new_mask + + def forward( + self, + refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, + refer_audio_order_mask: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`): + Packed reference-audio acoustic hidden states of shape `(total_tokens, hidden_size)` across all + reference samples in the batch. + refer_audio_order_mask (`torch.LongTensor`): + Batch-index assignment of shape `(total_tokens,)` indicating which reference sample each packed token + belongs to. + """ + inputs_embeds = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) + + seq_len = inputs_embeds.shape[1] + dtype = inputs_embeds.dtype + device = inputs_embeds.device + + cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) + position_embeddings = (cos, sin) + + sliding_attn_mask = None + if not _is_flash_attention_backend(self.layers[0].self_attn.processor): + sliding_attn_mask = _create_4d_mask( + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=self.sliding_window, + is_sliding_window=True, + is_causal=False, + ) + + hidden_states = inputs_embeds + for i, layer_module in enumerate(self.layers): + # No padding mask on timbre input (pre-packed), so full-attention layers see None. + mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer_module, hidden_states, position_embeddings, mask + ) + else: + hidden_states = layer_module( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=mask, + ) + + hidden_states = self.norm(hidden_states) + # CLS-like pooling: first-token embedding per packed sequence. + hidden_states = hidden_states[:, 0, :] + timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask) + return timbre_embs_unpack, timbre_embs_mask + + +class AceStepConditionEncoder(ModelMixin, ConfigMixin): + """Fuses text + lyric + timbre conditioning into the packed sequence used by + the DiT's cross-attention. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + hidden_size: int = 2048, + intermediate_size: int = 6144, + text_hidden_dim: int = 1024, + timbre_hidden_dim: int = 64, + num_lyric_encoder_hidden_layers: int = 8, + num_timbre_encoder_hidden_layers: int = 4, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + rope_theta: float = 1000000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + rms_norm_eps: float = 1e-6, + sliding_window: int = 128, + layer_types: list = None, + ): + super().__init__() + + self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False) + + self.lyric_encoder = AceStepLyricEncoder( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + text_hidden_dim=text_hidden_dim, + num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rope_theta=rope_theta, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window, + layer_types=layer_types, + ) + + self.timbre_encoder = AceStepTimbreEncoder( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + timbre_hidden_dim=timbre_hidden_dim, + num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rope_theta=rope_theta, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rms_norm_eps=rms_norm_eps, + sliding_window=sliding_window, + ) + + # Learned null-condition embedding for classifier-free guidance, trained with + # `cfg_ratio=0.15` in the original model. Broadcast along the sequence dim when used. + self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size)) + + # Silence latent — VAE-encoded audio-silence, stored as (1, T_long, timbre_hidden_dim). + # When no reference audio is provided, the pipeline slices `silence_latent[:, :timbre_fix_frame, :]` + # and feeds that to the timbre encoder. Passing literal zeros puts the timbre encoder + # OOD and produces drone-like audio (observed on all text2music outputs before this fix). + # The placeholder here is overwritten by the converter with the real encoded silence, + # so its shape just needs to match the timbre-encoder input: last dim is + # `timbre_hidden_dim` (so smaller test configs with `timbre_hidden_dim != 64` also load). + self.register_buffer( + "silence_latent", + torch.zeros(1, 15000, timbre_hidden_dim), + persistent=True, + ) + + def forward( + self, + text_hidden_states: torch.FloatTensor, + text_attention_mask: torch.Tensor, + lyric_hidden_states: torch.FloatTensor, + lyric_attention_mask: torch.Tensor, + refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, + refer_audio_order_mask: torch.LongTensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + text_hidden_states (`torch.FloatTensor`): + Text encoder hidden states of shape `(batch_size, text_sequence_length, text_hidden_dim)`. + text_attention_mask (`torch.Tensor`): + Attention mask of shape `(batch_size, text_sequence_length)` for the text hidden states. + lyric_hidden_states (`torch.FloatTensor`): + Lyric token ids of shape `(batch_size, lyric_sequence_length)` to be encoded by the lyric encoder. + lyric_attention_mask (`torch.Tensor`): + Attention mask of shape `(batch_size, lyric_sequence_length)` for the lyric tokens. + refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`): + Packed reference-audio acoustic hidden states of shape `(total_tokens, hidden_size)`. + refer_audio_order_mask (`torch.LongTensor`): + Batch-index assignment of shape `(total_tokens,)` indicating which reference sample each packed token + belongs to. + """ + text_hidden_states = self.text_projector(text_hidden_states) + + lyric_hidden_states = self.lyric_encoder( + inputs_embeds=lyric_hidden_states, attention_mask=lyric_attention_mask + ) + + timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder( + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask + ) + + encoder_hidden_states, encoder_attention_mask = _pack_sequences( + lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask + ) + encoder_hidden_states, encoder_attention_mask = _pack_sequences( + encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask + ) + + return encoder_hidden_states, encoder_attention_mask diff --git a/src/diffusers/models/condition_embedders/image_encoder_redux.py b/src/diffusers/models/condition_embedders/image_encoder_redux.py new file mode 100644 index 000000000000..032fe7925f58 --- /dev/null +++ b/src/diffusers/models/condition_embedders/image_encoder_redux.py @@ -0,0 +1,52 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..modeling_utils import ModelMixin + + +@dataclass +class ReduxImageEncoderOutput(BaseOutput): + image_embeds: torch.Tensor | None = None + + +class ReduxImageEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + ) -> None: + super().__init__() + + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) + + def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: + """ + Args: + x (`torch.Tensor`): + Image embeddings of shape `(batch_size, sequence_length, redux_dim)` produced by the SigLIP image + encoder. + """ + projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) + + return ReduxImageEncoderOutput(image_embeds=projected_x) diff --git a/src/diffusers/models/condition_embedders/projection_audioldm2.py b/src/diffusers/models/condition_embedders/projection_audioldm2.py new file mode 100644 index 000000000000..ff4802324ac9 --- /dev/null +++ b/src/diffusers/models/condition_embedders/projection_audioldm2.py @@ -0,0 +1,155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..modeling_utils import ModelMixin + + +def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token): + batch_size = hidden_states.shape[0] + + if attention_mask is not None: + # Add two more steps to attn mask + new_attn_mask_step = attention_mask.new_ones((batch_size, 1)) + attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1) + + # Add the SOS / EOS tokens at the start / end of the sequence respectively + sos_token = sos_token.expand(batch_size, 1, -1) + eos_token = eos_token.expand(batch_size, 1, -1) + hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1) + return hidden_states, attention_mask + + +@dataclass +class AudioLDM2ProjectionModelOutput(BaseOutput): + """ + Args: + Class for AudioLDM2 projection layer's outputs. + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text + encoders and subsequently concatenating them together. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + """ + + hidden_states: torch.Tensor + attention_mask: torch.LongTensor | None = None + + +class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned + embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with + `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the first text encoder (CLAP). + text_encoder_1_dim (`int`): + Dimensionality of the text embeddings from the second text encoder (T5 or VITS). + langauge_model_dim (`int`): + Dimensionality of the text embeddings from the language model (GPT2). + """ + + @register_to_config + def __init__( + self, + text_encoder_dim, + text_encoder_1_dim, + langauge_model_dim, + use_learned_position_embedding=None, + max_seq_length=None, + ): + super().__init__() + # additional projection layers for each text encoder + self.projection = nn.Linear(text_encoder_dim, langauge_model_dim) + self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim) + + # learnable SOS / EOS token embeddings for each text encoder + self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + + self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + + self.use_learned_position_embedding = use_learned_position_embedding + + # learable positional embedding for vits encoder + if self.use_learned_position_embedding is not None: + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, text_encoder_1_dim, max_seq_length)) + ) + + def forward( + self, + hidden_states: torch.Tensor | None = None, + hidden_states_1: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + attention_mask_1: torch.LongTensor | None = None, + ): + """ + Args: + hidden_states (`torch.Tensor`, *optional*): + Hidden states from the first text encoder of shape `(batch_size, sequence_length, text_encoder_dim)`. + hidden_states_1 (`torch.Tensor`, *optional*): + Hidden states from the second text encoder of shape `(batch_size, sequence_length_1, + text_encoder_1_dim)`. + attention_mask (`torch.LongTensor`, *optional*): + Attention mask of shape `(batch_size, sequence_length)` for `hidden_states`. + attention_mask_1 (`torch.LongTensor`, *optional*): + Attention mask of shape `(batch_size, sequence_length_1)` for `hidden_states_1`. + """ + hidden_states = self.projection(hidden_states) + hidden_states, attention_mask = add_special_tokens( + hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed + ) + + # Add positional embedding for Vits hidden state + if self.use_learned_position_embedding is not None: + hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1) + + hidden_states_1 = self.projection_1(hidden_states_1) + hidden_states_1, attention_mask_1 = add_special_tokens( + hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1 + ) + + # concatenate clap and t5 text encoding + hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1) + + # concatenate attention masks + if attention_mask is None and attention_mask_1 is not None: + attention_mask = attention_mask_1.new_ones((hidden_states[:2])) + elif attention_mask is not None and attention_mask_1 is None: + attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2])) + + if attention_mask is not None and attention_mask_1 is not None: + attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1) + else: + attention_mask = None + + return AudioLDM2ProjectionModelOutput( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) diff --git a/src/diffusers/models/condition_embedders/projection_clip_image.py b/src/diffusers/models/condition_embedders/projection_clip_image.py new file mode 100644 index 000000000000..b2a99120fccb --- /dev/null +++ b/src/diffusers/models/condition_embedders/projection_clip_image.py @@ -0,0 +1,34 @@ +# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin + + +class CLIPImageProjection(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, hidden_size: int = 768): + super().__init__() + self.hidden_size = hidden_size + self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def forward(self, x): + """ + Args: + x (`torch.Tensor`): + Input CLIP image embeddings of shape `(batch_size, hidden_size)`. + """ + return self.project(x) diff --git a/src/diffusers/models/condition_embedders/projection_stable_audio.py b/src/diffusers/models/condition_embedders/projection_stable_audio.py new file mode 100644 index 000000000000..42f36bbcb09d --- /dev/null +++ b/src/diffusers/models/condition_embedders/projection_stable_audio.py @@ -0,0 +1,165 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import pi + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + times = times[..., None] + freqs = times * self.weights[None] * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((times, fouriered), dim=-1) + return fouriered + + +@dataclass +class StableAudioProjectionModelOutput(BaseOutput): + """ + Args: + Class for StableAudio projection layer's outputs. + text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. + seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio start hidden states. + seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio end hidden states. + """ + + text_hidden_states: torch.Tensor | None = None + seconds_start_hidden_states: torch.Tensor | None = None + seconds_end_hidden_states: torch.Tensor | None = None + + +class StableAudioNumberConditioner(nn.Module): + """ + A simple linear projection model to map numbers to a latent space. + + Args: + number_embedding_dim (`int`): + Dimensionality of the number embeddings. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + internal_dim (`int`): + Dimensionality of the intermediate number hidden states. + """ + + def __init__( + self, + number_embedding_dim, + min_value, + max_value, + internal_dim: int | None = 256, + ): + super().__init__() + self.time_positional_embedding = nn.Sequential( + StableAudioPositionalEmbedding(internal_dim), + nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), + ) + + self.number_embedding_dim = number_embedding_dim + self.min_value = min_value + self.max_value = max_value + + def forward( + self, + floats: torch.Tensor, + ): + floats = floats.clamp(self.min_value, self.max_value) + + normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) + + # Cast floats to same type as embedder + embedder_dtype = next(self.time_positional_embedding.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + embedding = self.time_positional_embedding(normalized_floats) + float_embeds = embedding.view(-1, 1, self.number_embedding_dim) + + return float_embeds + + +class StableAudioProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map the conditioning values to a shared latent space. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the text encoder (T5). + conditioning_dim (`int`): + Dimensionality of the output conditioning tensors. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + """ + + @register_to_config + def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): + super().__init__() + self.text_projection = ( + nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) + ) + self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + + def forward( + self, + text_hidden_states: torch.Tensor | None = None, + start_seconds: torch.Tensor | None = None, + end_seconds: torch.Tensor | None = None, + ): + """ + Args: + text_hidden_states (`torch.Tensor`, *optional*): + Hidden states from the text encoder of shape `(batch_size, sequence_length, text_encoder_dim)`. + start_seconds (`torch.Tensor`, *optional*): + Start-time-in-seconds conditioning values of shape `(batch_size,)`. + end_seconds (`torch.Tensor`, *optional*): + End-time-in-seconds conditioning values of shape `(batch_size,)`. + """ + text_hidden_states = ( + text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) + ) + seconds_start_hidden_states = ( + start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) + ) + seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) + + return StableAudioProjectionModelOutput( + text_hidden_states=text_hidden_states, + seconds_start_hidden_states=seconds_start_hidden_states, + seconds_end_hidden_states=seconds_end_hidden_states, + ) diff --git a/src/diffusers/models/condition_embedders/text_connector_ltx2.py b/src/diffusers/models/condition_embedders/text_connector_ltx2.py new file mode 100644 index 000000000000..44f2f89d61d1 --- /dev/null +++ b/src/diffusers/models/condition_embedders/text_connector_ltx2.py @@ -0,0 +1,474 @@ +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ..attention import FeedForward +from ..modeling_utils import ModelMixin +from ..transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +def per_layer_masked_mean_norm( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +): + """ + Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states. + Respects the padding of the hidden states. + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + +def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True) + norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps) + return norm_text_encoder_hidden_states + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ): + super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + + def forward( + self, + batch_size: int, + pos: int, + device: str | torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + rope_type: str = "interleaved", + apply_gated_attention: bool = False, + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + processor=LTX2AudioVideoAttnProcessor(), + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + rotary_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + gated_attention: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + apply_gated_attention=gated_attention, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = ( + self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1) + ) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + # Replace padding positions with learned registers using vectorized masking + mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] + registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] + hidden_states = mask * hidden_states + (1 - mask) * registers_expanded + + # Flip sequence: embeddings move to front, registers to back (from left padding layout) + hidden_states = torch.flip(hidden_states, dims=[1]) + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size + text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B + video_connector_num_attention_heads: int = 30, + video_connector_attention_head_dim: int = 128, + video_connector_num_layers: int = 2, + video_connector_num_learnable_registers: int | None = 128, + video_gated_attn: bool = False, + audio_connector_num_attention_heads: int = 30, + audio_connector_attention_head_dim: int = 128, + audio_connector_num_layers: int = 2, + audio_connector_num_learnable_registers: int | None = 128, + audio_gated_attn: bool = False, + connector_rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + per_modality_projections: bool = False, + video_hidden_dim: int = 4096, + audio_hidden_dim: int = 2048, + proj_bias: bool = False, + ): + super().__init__() + text_encoder_dim = caption_channels * text_proj_in_factor + if per_modality_projections: + self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias) + self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias) + else: + self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias) + + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + gated_attention=video_gated_attn, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + gated_attention=audio_gated_attn, + ) + + def forward( + self, + text_encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "left", + scale_factor: int = 8, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text + embeddings for the LTX-2.X DiT models. + + Args: + text_encoder_hidden_states (`torch.Tensor`)): + Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len, + caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked + positions. + padding_side (`str`, *optional*, defaults to `"left"`): + The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to + `"left"` as this is what the default Gemma3-12B text encoder uses. Only used if + `per_modality_projections` is `False` (LTX-2.0 models). + scale_factor (`int`, *optional*, defaults to `8`): + Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False` + (LTX-2.0 models). + """ + if text_encoder_hidden_states.ndim == 3: + # Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor] + text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1)) + + if self.config.per_modality_projections: + # LTX-2.3 + norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states) + + norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3) + bool_mask = attention_mask.bool().unsqueeze(-1) + norm_text_encoder_hidden_states = torch.where( + bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states) + ) + + # Rescale norms with respect to video and audio dims for feature extractors + video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels) + video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor + audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels) + audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor + + # Per-Modality Feature extractors + video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb) + audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb) + else: + # LTX-2.0 + sequence_lengths = attention_mask.sum(dim=-1) + norm_text_encoder_hidden_states = per_layer_masked_mean_norm( + text_hidden_states=text_encoder_hidden_states, + sequence_lengths=sequence_lengths, + device=text_encoder_hidden_states.device, + padding_side=padding_side, + scale_factor=scale_factor, + ) + + text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states) + video_text_emb_proj = text_emb_proj + audio_text_emb_proj = text_emb_proj + + # Convert to additive attention mask for connectors + text_dtype = video_text_emb_proj.dtype + attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype) + attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + add_attn_mask = attention_mask * torch.finfo(text_dtype).max + + video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask) + + # Convert video attn mask to binary (multiplicative) mask and mask video text embedding + binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64) + binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * binary_attn_mask + + audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask) + + return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1) diff --git a/src/diffusers/models/others/__init__.py b/src/diffusers/models/others/__init__.py new file mode 100644 index 000000000000..cf0c05998d83 --- /dev/null +++ b/src/diffusers/models/others/__init__.py @@ -0,0 +1,12 @@ +from .image_normalizer_stable_unclip import StableUnCLIPImageNormalizer +from .renderer_shap_e import ( + BoundingBoxVolume, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StratifiedRaySampler, + VoidNeRFModel, +) +from .watermark_if import IFWatermarker diff --git a/src/diffusers/models/others/image_normalizer_stable_unclip.py b/src/diffusers/models/others/image_normalizer_stable_unclip.py new file mode 100644 index 000000000000..203efc42cc27 --- /dev/null +++ b/src/diffusers/models/others/image_normalizer_stable_unclip.py @@ -0,0 +1,55 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + ): + super().__init__() + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def to( + self, + torch_device: str | torch.device | None = None, + torch_dtype: torch.dtype | None = None, + ): + self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) + self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) + return self + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = (embeds * self.std) + self.mean + return embeds diff --git a/src/diffusers/models/others/renderer_shap_e.py b/src/diffusers/models/others/renderer_shap_e.py new file mode 100644 index 000000000000..8c5ad12e70ac --- /dev/null +++ b/src/diffusers/models/others/renderer_shap_e.py @@ -0,0 +1,1074 @@ +# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from .. import ModelMixin + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + r""" + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. + + Args: + pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + n_samples: number of samples + + Return: + indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. + + Reference: https://huggingface.co/papers/2210.04628 + """ + if min_deg == max_deg: + return x + + scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +def encode_position(position): + return posenc_nerf(position, min_deg=0, max_deg=15) + + +def encode_direction(position, direction=None): + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + else: + return posenc_nerf(direction, min_deg=0, max_deg=8) + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") + + +def integrate_samples(volume_range, ts, density, channels): + r""" + Function integrating the model output. + + Args: + volume_range: Specifies the integral range [t0, t1] + ts: timesteps + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + returns: + channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density + *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(ts) + ddensity = density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. Integrate channels + channels = torch.sum(channels * weights, dim=-2) + + return channels, weights, transmittance + + +def volume_query_points(volume, grid_size): + indices = torch.arange(grid_size**3, device=volume.bbox_min.device) + zs = indices % grid_size + ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size + xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size + combined = torch.stack([xs, ys, zs], dim=1) + return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min + + +def _convert_srgb_to_linear(u: torch.Tensor): + return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) + + +def _create_flat_edge_indices( + flat_cube_indices: torch.Tensor, + grid_size: tuple[int, int, int], +): + num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] + y_offset = num_xs + num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] + z_offset = num_xs + num_ys + return torch.stack( + [ + # Edges spanning x-axis. + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + # Edges spanning y-axis. + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + # Edges spanning z-axis. + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ], + dim=-1, + ) + + +class VoidNeRFModel(nn.Module): + """ + Implements the default empty space model where all queries are rendered as background. + """ + + def __init__(self, background, channel_scale=255.0): + super().__init__() + background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) + + self.register_buffer("background", background) + + def forward(self, position): + background = self.background[None].to(position.device) + + shape = position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) + + return background + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def partition(self, ts): + """ + Partitions t0 and t1 into n_samples intervals. + + Args: + ts: [batch_size, *shape, n_samples, 1] + + Return: + + lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, + *shape, n_samples, 1] + + where + ts \\in [lower, upper] deltas = upper - lower + """ + + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class BoundingBoxVolume(nn.Module): + """ + Axis-aligned bounding box defined by the two opposite corners. + """ + + def __init__( + self, + *, + bbox_min, + bbox_max, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + ): + """ + Args: + bbox_min: the left/bottommost corner of the bounding box + bbox_max: the other corner of the bounding box + min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.min_dist = min_dist + self.min_t_range = min_t_range + + self.bbox_min = torch.tensor(bbox_min) + self.bbox_max = torch.tensor(bbox_max) + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert min_dist >= 0.0 + assert min_t_range > 0.0 + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: torch.Tensor | None = None, + epsilon=1e-6, + ): + """ + Args: + origin: [batch_size, *shape, 3] + direction: [batch_size, *shape, 3] + t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + params: Optional meta parameters in case Volume is parametric + epsilon: to stabilize calculations + + Return: + A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with + the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to + be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) + + def _safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) + + ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class StratifiedRaySampler(nn.Module): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. + """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + # yiyi notes: add a random seed here for testing, don't forget to remove + torch.manual_seed(0) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(nn.Module): + """ + Given the initial estimate of densities, this samples more from regions/bins expected to have objects. + """ + + def __init__( + self, + volume_range: VolumeRange, + ts: torch.Tensor, + weights: torch.Tensor, + blur_pool: bool = False, + alpha: float = 1e-5, + ): + """ + Args: + volume_range: the range in which a ray intersects the given volume. + ts: earlier samples from the coarse rendering step + weights: discretized version of density * transmittance + blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + alpha: small value to add to weights. + """ + self.volume_range = volume_range + self.ts = ts.clone().detach() + self.weights = weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts + + +@dataclass +class MeshDecoderOutput(BaseOutput): + """ + A 3D triangle mesh with optional data at the vertices and faces. + + Args: + verts (`torch.Tensor` of shape `(N, 3)`): + array of vertext coordinates + faces (`torch.Tensor` of shape `(N, 3)`): + array of triangles, pointing to indices in verts. + vertext_channels (Dict): + vertext coordinates for each color channel + """ + + verts: torch.Tensor + faces: torch.Tensor + vertex_channels: dict[str, torch.Tensor] + + +class MeshDecoder(nn.Module): + """ + Construct meshes from Signed distance functions (SDFs) using marching cubes method + """ + + def __init__(self): + super().__init__() + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + self.register_buffer("cases", cases) + self.register_buffer("masks", masks) + + def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): + """ + For a signed distance field, produce a mesh using marching cubes. + + :param field: a 3D tensor of field values, where negative values correspond + to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. + :param min_point: a tensor of shape [3] containing the point corresponding + to (0, 0, 0) in the field. + :param size: a tensor of shape [3] containing the per-axis distance from the + (0, 0, 0) field corner and the (-1, -1, -1) field corner. + """ + assert len(field.shape) == 3, "input must be a 3D scalar field" + dev = field.device + + cases = self.cases.to(dev) + masks = self.masks.to(dev) + + min_point = min_point.to(dev) + size = size.to(dev) + + grid_size = field.shape + grid_size_tensor = torch.tensor(grid_size).to(size) + + # Create bitmasks between 0 and 255 (inclusive) indicating the state + # of the eight corners of each cube. + bitmasks = (field > 0).to(torch.uint8) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + + # Compute corner coordinates across the entire grid. + corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) + corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[ + :, None, None + ] + corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[ + :, None + ] + corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype) + + # Compute all vertices across all edges in the grid, even though we will + # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. + # These are all midpoints, and don't account for interpolation (which is + # done later based on the used edge midpoints). + edge_midpoints = torch.cat( + [ + ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), + ], + dim=0, + ) + + # Create a flat array of [X, Y, Z] indices for each cube. + cube_indices = torch.zeros( + grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long + ) + cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None] + cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None] + cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) + flat_cube_indices = cube_indices.reshape(-1, 3) + + # Create a flat array mapping each cube to 12 global edge indices. + edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) + + # Apply the LUT to figure out the triangles. + flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask + local_tris = cases[flat_bitmasks] + local_masks = masks[flat_bitmasks] + # Compute the global edge indices for the triangles. + global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( + local_tris.shape + ) + # Select the used triangles for each cube. + selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] + + # Now we have a bunch of indices into the full list of possible vertices, + # but we want to reduce this list to only the used vertices. + used_vertex_indices = torch.unique(selected_tris.view(-1)) + used_edge_midpoints = edge_midpoints[used_vertex_indices] + old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) + old_index_to_new_index[used_vertex_indices] = torch.arange( + len(used_vertex_indices), device=dev, dtype=torch.long + ) + + # Rewrite the triangles to use the new indices + faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape) + + # Compute the actual interpolated coordinates corresponding to edge midpoints. + v1 = torch.floor(used_edge_midpoints).to(torch.long) + v2 = torch.ceil(used_edge_midpoints).to(torch.long) + s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] + s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] + p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point + p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point + # The signs of s1 and s2 should be different. We want to find + # t such that t*s2 + (1-t)*s1 = 0. + t = (s1 / (s1 - s2))[:, None] + verts = t * p2 + (1 - t) * p1 + + return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None) + + +@dataclass +class MLPNeRFModelOutput(BaseOutput): + density: torch.Tensor + signed_distance: torch.Tensor + channels: torch.Tensor + ts: torch.Tensor + + +class MLPNeRSTFModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + ): + super().__init__() + + # Instantiate the MLP + + # Find out the dimension of encoded position and direction + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(position=dummy).shape[-1] + d_posenc_dir = encode_direction(position=dummy).shape[-1] + + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos] + mlp_widths + output_widths = mlp_widths + [n_output] + + if insert_direction_at is not None: + input_widths[insert_direction_at] += d_posenc_dir + + self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) + + if act_fn == "swish": + # self.activation = swish + # yiyi testing: + self.activation = lambda x: F.silu(x) + else: + raise ValueError(f"Unsupported activation function {act_fn}") + + self.sdf_activation = torch.tanh + self.density_activation = torch.nn.functional.relu + self.channel_activation = torch.sigmoid + + def map_indices_to_keys(self, output): + h_map = { + "sdf": (0, 1), + "density_coarse": (1, 2), + "density_fine": (2, 3), + "stf": (3, 6), + "nerf_coarse": (6, 9), + "nerf_fine": (9, 12), + } + + mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} + + return mapped_output + + def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"): + """ + Args: + position (`torch.Tensor`): + 3D query positions of shape `(batch_size, ..., 3)` to evaluate the NeRSTF MLP at. + direction (`torch.Tensor`): + Viewing directions of shape `(batch_size, ..., 3)` used for view-dependent color prediction. + ts (`torch.Tensor`): + Per-ray sample distances of shape `(batch_size, ..., 1)` passed through to the output for downstream + integration. + nerf_level (`str`, *optional*, defaults to `"coarse"`): + Which density/color head to read from — `"coarse"` or `"fine"`. + rendering_mode (`str`, *optional*, defaults to `"nerf"`): + Output head to use: `"nerf"` for radiance-field colors or `"stf"` for the signed-distance/texture + field. + """ + h = encode_position(position) + + h_preact = h + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.config.insert_direction_at: # 4 in the config + h_directionless = h_preact + h_direction = encode_direction(position, direction=direction) + h = torch.cat([h, h_direction], dim=-1) + + h = layer(h) + + h_preact = h + + if i < len(self.mlp) - 1: + h = self.activation(h) + + h_final = h + if h_directionless is None: + h_directionless = h_preact + + activation = self.map_indices_to_keys(h_final) + + if nerf_level == "coarse": + h_density = activation["density_coarse"] + else: + h_density = activation["density_fine"] + + if rendering_mode == "nerf": + if nerf_level == "coarse": + h_channels = activation["nerf_coarse"] + else: + h_channels = activation["nerf_fine"] + + elif rendering_mode == "stf": + h_channels = activation["stf"] + + density = self.density_activation(h_density) + signed_distance = self.sdf_activation(activation["sdf"]) + channels = self.channel_activation(h_channels) + + # yiyi notes: I think signed_distance is not used + return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + vectors: int, + channels: int, + d_latent: int, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels) + self.norm = nn.LayerNorm(channels) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + h = self.norm(h) + + h = h + b_vc + return h + + +class ShapEParamsProjModel(ModelMixin, ConfigMixin): + """ + project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). + + For more details, see the original paper: + """ + + @register_to_config + def __init__( + self, + *, + param_names: tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: tuple[tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + ): + super().__init__() + + # check inputs + if len(param_names) != len(param_shapes): + raise ValueError("Must provide same number of `param_names` as `param_shapes`") + self.projections = nn.ModuleDict({}) + for k, (vectors, channels) in zip(param_names, param_shapes): + self.projections[_sanitize_name(k)] = ChannelsProj( + vectors=vectors, + channels=channels, + d_latent=d_latent, + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (`torch.Tensor`): + Latent representation of a 3D asset of shape `(batch_size, total_vectors, d_latent)`, sliced per + `param_name` and projected to each MLP weight tensor. + """ + out = {} + start = 0 + for k, shape in zip(self.config.param_names, self.config.param_shapes): + vectors, _ = shape + end = start + vectors + x_bvd = x[:, start:end] + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +class ShapERenderer(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + *, + param_names: tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: tuple[tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + background: tuple[float] = ( + 255.0, + 255.0, + 255.0, + ), + ): + super().__init__() + + self.params_proj = ShapEParamsProjModel( + param_names=param_names, + param_shapes=param_shapes, + d_latent=d_latent, + ) + self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) + self.void = VoidNeRFModel(background=background, channel_scale=255.0) + self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + self.mesh_decoder = MeshDecoder() + + @torch.no_grad() + def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): + """ + Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below + with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], + ) for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through + the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are + obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t + where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the + shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and + transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], + math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + + Args: + rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: + number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including + + :return: A tuple of + - `channels` + - A importance samplers for additional fine-grained rendering + - raw model output + """ + origin, direction = rays[..., 0, :], rays[..., 1, :] + + # Integrate over [t[i], t[i + 1]] + + # 1 Intersect the rays with the current volume and sample ts to integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=None) + ts = sampler.sample(vrange.t0, vrange.t1, n_samples) + ts = ts.to(rays.dtype) + + if prev_model_out is not None: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values + + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + directions = directions.to(self.mlp.dtype) + positions = positions.to(self.mlp.dtype) + + optional_directions = directions if render_with_direction else None + + model_out = self.mlp( + position=positions, + direction=optional_directions, + ts=ts, + nerf_level="coarse" if prev_model_out is None else "fine", + ) + + # 3. Integrate the model results + channels, weights, transmittance = integrate_samples( + vrange, model_out.ts, model_out.density, model_out.channels + ) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) + channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) + # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + channels = channels + transmittance * self.void(origin) + + weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) + + return channels, weighted_sampler, model_out + + @torch.no_grad() + def decode_to_image( + self, + latents, + device, + size: int = 64, + ray_batch_size: int = 4096, + n_coarse_samples=64, + n_fine_samples=128, + ): + # project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # create cameras object + # Deferred import: `camera.py` lives under `pipelines/shap_e/`, and a top-level import here + # would trigger a circular import when `diffusers.models` is loaded before `diffusers.pipelines` + # (e.g. under DIFFUSERS_SLOW_IMPORT used by the docs build). + from ...pipelines.shap_e.camera import create_pan_cameras + + camera = create_pan_cameras(size) + rays = camera.camera_rays + rays = rays.to(device) + n_batches = rays.shape[1] // ray_batch_size + + coarse_sampler = StratifiedRaySampler() + + images = [] + + for idx in range(n_batches): + rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + + # render rays with coarse, stratified samples. + _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) + # Then, render with additional importance-weighted ray samples. + channels, _, _ = self.render_rays( + rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out + ) + + images.append(channels) + + images = torch.cat(images, dim=1) + images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) + + return images + + @torch.no_grad() + def decode_to_mesh( + self, + latents, + device, + grid_size: int = 128, + query_batch_size: int = 4096, + texture_channels: tuple = ("R", "G", "B"), + ): + # 1. project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # 2. update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # 3. decoding with STF rendering + # 3.1 query the SDF values at vertices along a regular 128**3 grid + + query_points = volume_query_points(self.volume, grid_size) + query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype) + + fields = [] + + for idx in range(0, query_positions.shape[1], query_batch_size): + query_batch = query_positions[:, idx : idx + query_batch_size] + + model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + fields.append(model_out.signed_distance) + + # predicted SDF values + fields = torch.cat(fields, dim=1) + fields = fields.float() + + assert len(fields.shape) == 3 and fields.shape[-1] == 1, ( + f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + ) + + fields = fields.reshape(1, *([grid_size] * 3)) + + # create grid 128 x 128 x 128 + # - force a negative border around the SDFs to close off all the models. + full_grid = torch.zeros( + 1, + grid_size + 2, + grid_size + 2, + grid_size + 2, + device=fields.device, + dtype=fields.dtype, + ) + full_grid.fill_(-1.0) + full_grid[:, 1:-1, 1:-1, 1:-1] = fields + fields = full_grid + + # apply a differentiable implementation of Marching Cubes to construct meshs + raw_meshes = [] + mesh_mask = [] + + for field in fields: + raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min) + mesh_mask.append(True) + raw_meshes.append(raw_mesh) + + mesh_mask = torch.tensor(mesh_mask, device=fields.device) + max_vertices = max(len(m.verts) for m in raw_meshes) + + # 3.2. query the texture color head at each vertex of the resulting mesh. + texture_query_positions = torch.stack( + [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], + dim=0, + ) + texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype) + + textures = [] + + for idx in range(0, texture_query_positions.shape[1], query_batch_size): + query_batch = texture_query_positions[:, idx : idx + query_batch_size] + + texture_model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + textures.append(texture_model_out.channels) + + # predict texture color + textures = torch.cat(textures, dim=1) + + textures = _convert_srgb_to_linear(textures) + textures = textures.float() + + # 3.3 augment the mesh with texture data + assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), ( + f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + ) + + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1))) + + return raw_meshes[0] diff --git a/src/diffusers/models/others/watermark_if.py b/src/diffusers/models/others/watermark_if.py new file mode 100644 index 000000000000..bc88e8ca018f --- /dev/null +++ b/src/diffusers/models/others/watermark_if.py @@ -0,0 +1,58 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import PIL.Image +import torch +from PIL import Image + +from ...configuration_utils import ConfigMixin +from ...utils import PIL_INTERPOLATION +from ..modeling_utils import ModelMixin + + +class IFWatermarker(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) + self.watermark_image_as_pil = None + + def apply_watermark(self, images: list[PIL.Image.Image], sample_size=None): + # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + + h = images[0].height + w = images[0].width + + sample_size = sample_size or h + + coef = min(h / sample_size, w / sample_size) + img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) + + S1, S2 = 1024**2, img_w * img_h + K = (S2 / S1) ** 0.5 + wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) + + if self.watermark_image_as_pil is None: + watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() + watermark_image = Image.fromarray(watermark_image, mode="RGBA") + self.watermark_image_as_pil = watermark_image + + wm_img = self.watermark_image_as_pil.resize( + (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None + ) + + for pil_img in images: + pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) + + return images diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py index 9ef04fb62606..30a49c638991 100644 --- a/src/diffusers/models/unets/__init__.py +++ b/src/diffusers/models/unets/__init__.py @@ -5,6 +5,7 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel + from .unet_2d_condition_audioldm2 import AudioLDM2UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel from .unet_i2vgen_xl import I2VGenXLUNet from .unet_kandinsky3 import Kandinsky3UNet diff --git a/src/diffusers/models/unets/unet_2d_condition_audioldm2.py b/src/diffusers/models/unets/unet_2d_condition_audioldm2.py new file mode 100644 index 000000000000..e56d5d0c5a64 --- /dev/null +++ b/src/diffusers/models/unets/unet_2d_condition_audioldm2.py @@ -0,0 +1,1301 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..transformers.transformer_2d import Transformer2DModel +from .unet_2d_blocks import DownBlock2D, UpBlock2D +from .unet_2d_condition import UNet2DConditionOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AudioLDM2UNet2DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional + self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up + to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention (`bool` or `tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: str = "UNetMidBlock2DCrossAttn", + up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: bool | tuple[bool] = False, + block_out_channels: tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int] = 1280, + transformer_layers_per_block: int | tuple[int] = 1, + attention_head_dim: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] | None = None, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + class_embeddings_concat: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.") + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError( + f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2." + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ) -> UNet2DConditionOutput | tuple: + r""" + The [`AudioLDM2UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*): + Conditional class labels of shape `(batch,)`. Only used when the model is configured with a + `class_embed_type`. + timestep_cond (`torch.Tensor`, *optional*): + Additional timestep conditioning of shape `(batch, time_cond_proj_dim)`, applied after the timestep + embedding. + attention_mask (`torch.Tensor`, *optional*): + A self-attention mask of shape `(batch, sequence_length)`. If `True` the mask is kept, otherwise if + `False` it is discarded. The mask is converted to a bias added to the attention scores for "discard" + tokens. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + encoder_hidden_states_1 (`torch.Tensor`, *optional*): + A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be + used to condition the model on a different set of embeddings to `encoder_hidden_states`. + encoder_attention_mask_1 (`torch.Tensor`, *optional*): + A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`. + If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if encoder_attention_mask_1 is not None: + encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 + encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ): + output_states = () + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1) + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(len(self.resnets[1:])): + if torch.is_grad_enabled() and self.gradient_checkpointing: + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb) + else: + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = self.resnets[i + 1](hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ): + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/src/diffusers/pipelines/ace_step/modeling_ace_step.py b/src/diffusers/pipelines/ace_step/modeling_ace_step.py index 769b07044420..06242de6d86a 100644 --- a/src/diffusers/pipelines/ace_step/modeling_ace_step.py +++ b/src/diffusers/pipelines/ace_step/modeling_ace_step.py @@ -12,845 +12,89 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Pipeline-specific models for ACE-Step 1.5. - -Holds the condition encoder (lyric + timbre + text packing), the encoder layer (``AceStepEncoderLayer`` — not used by -the DiT itself, hence kept here), the audio tokenizer / detokenizer used by cover conditioning, and the -``_pack_sequences`` helper. The DiT uses the RoPE helper, ``AceStepAttention``, and ``_create_4d_mask`` from -``diffusers/models/transformers/ace_step_transformer.py``. -""" - -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin -from ...models.normalization import RMSNorm -from ...models.transformers.ace_step_transformer import ( - AceStepAttention, - AceStepMLP, - _ace_step_rotary_freqs, - _create_4d_mask, - _is_flash_attention_backend, +from ...models.autoencoders.audio_tokenizer_ace_step import ( + AceStepAttentionPooler, # noqa: F401 re-exported for back-compat + _AceStepResidualFSQ, # noqa: F401 re-exported for back-compat ) -from ...utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# --------------------------------------------------------------------------- # -# helpers used only by condition encoder # -# --------------------------------------------------------------------------- # - - -def _pack_sequences( - hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - """Pack two masked sequences into one with all valid tokens first. - - Concatenates ``hidden1`` + ``hidden2`` along the sequence dim, then stably sorts each batch so mask=1 tokens come - before mask=0 tokens. Returns the packed hidden states plus a fresh contiguous mask. - """ - hidden_cat = torch.cat([hidden1, hidden2], dim=1) - mask_cat = torch.cat([mask1, mask2], dim=1) - - B, L, D = hidden_cat.shape - sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) - hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D)) - lengths = mask_cat.sum(dim=1) - new_mask = torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1) - return hidden_left, new_mask - - -class AceStepEncoderLayer(nn.Module): - """Pre-LN transformer block used by the lyric and timbre encoders.""" - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - head_dim: int, - intermediate_size: int, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: Optional[int] = None, - ): - super().__init__() - self.self_attn = AceStepAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - bias=attention_bias, - dropout=attention_dropout, - eps=rms_norm_eps, - sliding_window=sliding_window, - is_cross_attention=False, - ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.mlp = AceStepMLP(hidden_size, intermediate_size) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - image_rotary_emb=position_embeddings, - attention_mask=attention_mask, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -# --------------------------------------------------------------------------- # -# encoders # -# --------------------------------------------------------------------------- # - - -class AceStepLyricEncoder(ModelMixin, ConfigMixin): - """Lyric encoder: projects Qwen3 lyric embeddings and runs a small transformer. - - Output feeds the DiT cross-attention (after packing with text + timbre). - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - text_hidden_dim: int = 1024, - num_lyric_encoder_hidden_layers: int = 8, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - if layer_types is None: - layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(num_lyric_encoder_hidden_layers) - ] - - self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size) - self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.head_dim = head_dim - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.layers = nn.ModuleList( - [ - AceStepEncoderLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - intermediate_size=intermediate_size, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, - ) - for i in range(num_lyric_encoder_hidden_layers) - ] - ) - - self._layer_types = layer_types - self.gradient_checkpointing = False - - def forward( - self, - inputs_embeds: torch.FloatTensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(inputs_embeds) - - seq_len = inputs_embeds.shape[1] - dtype = inputs_embeds.dtype - device = inputs_embeds.device - - cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) - position_embeddings = (cos, sin) - - if _is_flash_attention_backend(self.layers[0].self_attn.processor): - full_attn_mask = attention_mask - sliding_attn_mask = attention_mask - else: - full_attn_mask = _create_4d_mask( - seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, is_causal=False - ) - sliding_attn_mask = _create_4d_mask( - seq_len=seq_len, - dtype=dtype, - device=device, - attention_mask=attention_mask, - sliding_window=self.sliding_window, - is_sliding_window=True, - is_causal=False, - ) - - hidden_states = inputs_embeds - for i, layer_module in enumerate(self.layers): - mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else full_attn_mask - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - layer_module, hidden_states, position_embeddings, mask - ) - else: - hidden_states = layer_module( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=mask, - ) - return self.norm(hidden_states) - - -class AceStepTimbreEncoder(ModelMixin, ConfigMixin): - """Timbre encoder: consumes VAE-encoded reference-audio latents and returns a - pooled per-batch timbre embedding (plus a presence mask). - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - timbre_hidden_dim: int = 64, - num_timbre_encoder_hidden_layers: int = 4, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - if layer_types is None: - layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(num_timbre_encoder_hidden_layers) - ] - - self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size) - self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size)) - self.head_dim = head_dim - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.layers = nn.ModuleList( - [ - AceStepEncoderLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - intermediate_size=intermediate_size, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, - ) - for i in range(num_timbre_encoder_hidden_layers) - ] - ) - - self._layer_types = layer_types - self.gradient_checkpointing = False - - @staticmethod - def unpack_timbre_embeddings( - timbre_embs_packed: torch.Tensor, refer_audio_order_mask: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - N, d = timbre_embs_packed.shape - device = timbre_embs_packed.device - dtype = timbre_embs_packed.dtype - - B = int(refer_audio_order_mask.max().item() + 1) - counts = torch.bincount(refer_audio_order_mask, minlength=B) - max_count = counts.max().item() - - sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True) - sorted_batch_ids = refer_audio_order_mask[sorted_indices] - - positions = torch.arange(N, device=device) - batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) - positions_in_sorted = positions - batch_starts[sorted_batch_ids] - - inverse_indices = torch.empty_like(sorted_indices) - inverse_indices[sorted_indices] = torch.arange(N, device=device) - positions_in_batch = positions_in_sorted[inverse_indices] - - indices_2d = refer_audio_order_mask * max_count + positions_in_batch - one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype) - - timbre_embs_flat = one_hot.t() @ timbre_embs_packed - timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d) - - mask_flat = (one_hot.sum(dim=0) > 0).long() - new_mask = mask_flat.reshape(B, max_count) - return timbre_embs_unpack, new_mask - - def forward( - self, - refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, - refer_audio_order_mask: torch.LongTensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - inputs_embeds = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) - - seq_len = inputs_embeds.shape[1] - dtype = inputs_embeds.dtype - device = inputs_embeds.device - - cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) - position_embeddings = (cos, sin) - - sliding_attn_mask = None - if not _is_flash_attention_backend(self.layers[0].self_attn.processor): - sliding_attn_mask = _create_4d_mask( - seq_len=seq_len, - dtype=dtype, - device=device, - attention_mask=None, - sliding_window=self.sliding_window, - is_sliding_window=True, - is_causal=False, - ) - - hidden_states = inputs_embeds - for i, layer_module in enumerate(self.layers): - # No padding mask on timbre input (pre-packed), so full-attention layers see None. - mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - layer_module, hidden_states, position_embeddings, mask - ) - else: - hidden_states = layer_module( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=mask, - ) - - hidden_states = self.norm(hidden_states) - # CLS-like pooling: first-token embedding per packed sequence. - hidden_states = hidden_states[:, 0, :] - timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask) - return timbre_embs_unpack, timbre_embs_mask - - -# --------------------------------------------------------------------------- # -# audio tokenizer / detokenizer # -# --------------------------------------------------------------------------- # - - -class _AceStepResidualFSQ(nn.Module): - """Minimal ResidualFSQ compatible with ACE-Step's saved tokenizer weights.""" - - def __init__( - self, - dim: int = 2048, - levels: Optional[list] = None, - num_quantizers: int = 1, - ): - super().__init__() - - if levels is None: - levels = [8, 8, 8, 5, 5, 5] - - self.levels = levels - self.num_quantizers = num_quantizers - self.codebook_dim = len(levels) - - self.project_in = nn.Linear(dim, self.codebook_dim) - self.project_out = nn.Linear(self.codebook_dim, dim) - - levels_tensor = torch.tensor(levels, dtype=torch.long) - basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.long), dim=0) - scales = torch.stack([levels_tensor.float() ** -i for i in range(num_quantizers)]) - self.register_buffer("_levels", levels_tensor, persistent=False) - self.register_buffer("_basis", basis, persistent=False) - self.register_buffer("scales", scales, persistent=False) - - @property - def codebook_size(self) -> int: - return int(torch.prod(self._levels).item()) - - def _indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor: - levels = self._levels.to(device=indices.device) - basis = self._basis.to(device=indices.device) - level_indices = (indices.long().unsqueeze(-1) // basis) % levels - scale = 2.0 / (levels.to(dtype=torch.float32) - 1.0) - return level_indices.to(dtype=torch.float32) * scale - 1.0 - - def _codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor: - levels = self._levels.to(device=codes.device, dtype=codes.dtype) - basis = self._basis.to(device=codes.device, dtype=codes.dtype) - level_indices = (codes + 1.0) / (2.0 / (levels - 1.0)) - return (level_indices * basis).sum(dim=-1).round().to(torch.long) - - def _quantize(self, x: torch.Tensor) -> torch.Tensor: - levels = self._levels.to(device=x.device, dtype=x.dtype) - levels_minus_one = levels - 1.0 - step = 2.0 / levels_minus_one - bracket = levels_minus_one * (x.clamp(-1.0, 1.0) + 1.0) / 2.0 + 0.5 - return step * torch.floor(bracket) - 1.0 - - def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: - if indices.ndim == 2: - indices = indices.unsqueeze(-1) - if indices.shape[-1] != self.num_quantizers: - raise ValueError( - f"Expected audio code indices with last dimension {self.num_quantizers}, got {indices.shape[-1]}." - ) - - codes = [] - for quantizer_idx in range(self.num_quantizers): - code = self._indices_to_codes(indices[..., quantizer_idx]) - scale = self.scales[quantizer_idx].to(device=code.device, dtype=code.dtype) - codes.append(code * scale) - return torch.stack(codes, dim=0) - - def get_output_from_indices(self, indices: torch.Tensor) -> torch.Tensor: - codes = self.get_codes_from_indices(indices).sum(dim=0) - weight = self.project_out.weight.float() - bias = self.project_out.bias.float() if self.project_out.bias is not None else None - output = F.linear(codes.float(), weight, bias) - return output.to(dtype=self.project_out.weight.dtype) - - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - input_dtype = hidden_states.dtype - weight = self.project_in.weight.float() - bias = self.project_in.bias.float() if self.project_in.bias is not None else None - hidden_states = F.linear(hidden_states.float(), weight, bias) - - levels = self._levels.to(device=hidden_states.device, dtype=hidden_states.dtype) - soft_clamp = 1.0 + (1.0 / (levels - 1.0)) - hidden_states = (hidden_states / soft_clamp).tanh() * soft_clamp - - quantized_out = torch.zeros_like(hidden_states) - residual = hidden_states - all_indices = [] - for scale in self.scales.to(device=hidden_states.device, dtype=hidden_states.dtype): - quantized = self._quantize(residual / scale) * scale - residual = residual - quantized.detach() - quantized_out = quantized_out + quantized - all_indices.append(self._codes_to_indices(quantized / scale)) - - weight = self.project_out.weight.float() - bias = self.project_out.bias.float() if self.project_out.bias is not None else None - quantized_out = F.linear(quantized_out.float(), weight, bias).to(dtype=input_dtype) - all_indices = torch.stack(all_indices, dim=-1) - return quantized_out, all_indices - - -class AceStepAttentionPooler(nn.Module): - """Attention pooler used by the ACE-Step audio tokenizer.""" - - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - num_attention_pooler_hidden_layers: int = 2, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - if layer_types is None: - layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(num_attention_pooler_hidden_layers) - ] - - self.embed_tokens = nn.Linear(hidden_size, hidden_size) - self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) - self.head_dim = head_dim - self.rope_theta = rope_theta - self.sliding_window = sliding_window - self.layers = nn.ModuleList( - [ - AceStepEncoderLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - intermediate_size=intermediate_size, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, - ) - for i in range(num_attention_pooler_hidden_layers) - ] - ) - self._layer_types = layer_types - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, num_patches, patch_size, _ = hidden_states.shape - hidden_states = self.embed_tokens(hidden_states) - special_token = self.special_token.to(device=hidden_states.device, dtype=hidden_states.dtype) - special_token = special_token.expand(batch_size, num_patches, -1, -1) - hidden_states = torch.cat([special_token, hidden_states], dim=2) - hidden_states = hidden_states.reshape(batch_size * num_patches, patch_size + 1, -1) - - seq_len = hidden_states.shape[1] - dtype = hidden_states.dtype - device = hidden_states.device - position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) - sliding_attn_mask = None - if not _is_flash_attention_backend(self.layers[0].self_attn.processor): - sliding_attn_mask = _create_4d_mask( - seq_len=seq_len, - dtype=dtype, - device=device, - attention_mask=None, - sliding_window=self.sliding_window, - is_sliding_window=True, - is_causal=False, - ) - - for i, layer_module in enumerate(self.layers): - mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None - hidden_states = layer_module( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=mask, - ) - - hidden_states = self.norm(hidden_states) - hidden_states = hidden_states[:, 0, :] - return hidden_states.reshape(batch_size, num_patches, -1) - - -class AceStepAudioTokenDetokenizer(ModelMixin, ConfigMixin): - """Expands ACE-Step 5 Hz audio tokens back to 25 Hz acoustic conditioning.""" - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - audio_acoustic_hidden_dim: int = 64, - pool_window_size: int = 5, - num_attention_pooler_hidden_layers: int = 2, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - if layer_types is None: - layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(num_attention_pooler_hidden_layers) - ] - - self.embed_tokens = nn.Linear(hidden_size, hidden_size) - self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02) - self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim) - self.head_dim = head_dim - self.rope_theta = rope_theta - self.sliding_window = sliding_window - self.pool_window_size = pool_window_size - self.layers = nn.ModuleList( - [ - AceStepEncoderLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - intermediate_size=intermediate_size, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None, - ) - for i in range(num_attention_pooler_hidden_layers) - ] - ) - self._layer_types = layer_types - self.gradient_checkpointing = False - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, num_tokens, _ = hidden_states.shape - hidden_states = self.embed_tokens(hidden_states) - hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.pool_window_size, -1) - special_tokens = self.special_tokens.to(device=hidden_states.device, dtype=hidden_states.dtype) - hidden_states = hidden_states + special_tokens.unsqueeze(0) - hidden_states = hidden_states.reshape(batch_size * num_tokens, self.pool_window_size, -1) - - seq_len = hidden_states.shape[1] - dtype = hidden_states.dtype - device = hidden_states.device - position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype) - sliding_attn_mask = None - if not _is_flash_attention_backend(self.layers[0].self_attn.processor): - sliding_attn_mask = _create_4d_mask( - seq_len=seq_len, - dtype=dtype, - device=device, - attention_mask=None, - sliding_window=self.sliding_window, - is_sliding_window=True, - is_causal=False, - ) - - for i, layer_module in enumerate(self.layers): - mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - layer_module, hidden_states, position_embeddings, mask - ) - else: - hidden_states = layer_module( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=mask, - ) - - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_out(hidden_states) - return hidden_states.reshape(batch_size, num_tokens * self.pool_window_size, -1) - - -class AceStepAudioTokenizer(ModelMixin, ConfigMixin): - """Converts 25 Hz acoustic latents to ACE-Step 5 Hz audio tokens.""" - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - audio_acoustic_hidden_dim: int = 64, - pool_window_size: int = 5, - fsq_dim: int = 2048, - fsq_input_levels: list = None, - fsq_input_num_quantizers: int = 1, - num_attention_pooler_hidden_layers: int = 2, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - if fsq_input_levels is None: - fsq_input_levels = [8, 8, 8, 5, 5, 5] - - self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size) - self.attention_pooler = AceStepAttentionPooler( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - rope_theta=rope_theta, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window, - layer_types=layer_types, - ) - self.quantizer = _AceStepResidualFSQ( - dim=fsq_dim, - levels=fsq_input_levels, - num_quantizers=fsq_input_num_quantizers, +from ...models.autoencoders.audio_tokenizer_ace_step import ( + AceStepAudioTokenDetokenizer as _AceStepAudioTokenDetokenizer, +) +from ...models.autoencoders.audio_tokenizer_ace_step import ( + AceStepAudioTokenizer as _AceStepAudioTokenizer, +) +from ...models.condition_embedders.condition_encoder_ace_step import ( + AceStepConditionEncoder as _AceStepConditionEncoder, +) +from ...models.condition_embedders.condition_encoder_ace_step import ( + AceStepEncoderLayer, # noqa: F401 re-exported for back-compat + _pack_sequences, # noqa: F401 re-exported for back-compat +) +from ...models.condition_embedders.condition_encoder_ace_step import ( + AceStepLyricEncoder as _AceStepLyricEncoder, +) +from ...models.condition_embedders.condition_encoder_ace_step import ( + AceStepTimbreEncoder as _AceStepTimbreEncoder, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class AceStepAudioTokenizer(_AceStepAudioTokenizer): + def __new__(cls, *args, **kwargs): + deprecate( + "AceStepAudioTokenizer", + "1.0.0", + "Importing `AceStepAudioTokenizer` from `diffusers.pipelines.ace_step.modeling_ace_step` is deprecated. " + "Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import AceStepAudioTokenizer`).", ) - self.pool_window_size = pool_window_size - - def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - input_dtype = hidden_states.dtype - hidden_states = self.audio_acoustic_proj(hidden_states) - hidden_states = self.attention_pooler(hidden_states) - quantized, indices = self.quantizer(hidden_states) - return quantized.to(dtype=input_dtype), indices - - def tokenize( - self, - hidden_states: torch.Tensor, - silence_latent: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, latent_length, acoustic_dim = hidden_states.shape - pad_len = (-latent_length) % self.pool_window_size - if pad_len: - if silence_latent is not None and silence_latent.shape[-1] == acoustic_dim: - pad = silence_latent[:, :pad_len, :].to(device=hidden_states.device, dtype=hidden_states.dtype) - pad = pad.expand(batch_size, -1, -1) - else: - pad = torch.zeros( - batch_size, pad_len, acoustic_dim, device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_states = torch.cat([hidden_states, pad], dim=1) - - num_patches = hidden_states.shape[1] // self.pool_window_size - hidden_states = hidden_states.reshape(batch_size, num_patches, self.pool_window_size, acoustic_dim) - return self(hidden_states) - - -# --------------------------------------------------------------------------- # -# condition encoder # -# --------------------------------------------------------------------------- # - - -class AceStepConditionEncoder(ModelMixin, ConfigMixin): - """Fuses text + lyric + timbre conditioning into the packed sequence used by - the DiT's cross-attention. - """ + return super().__new__(cls) - _supports_gradient_checkpointing = True - @register_to_config - def __init__( - self, - hidden_size: int = 2048, - intermediate_size: int = 6144, - text_hidden_dim: int = 1024, - timbre_hidden_dim: int = 64, - num_lyric_encoder_hidden_layers: int = 8, - num_timbre_encoder_hidden_layers: int = 4, - num_attention_heads: int = 16, - num_key_value_heads: int = 8, - head_dim: int = 128, - rope_theta: float = 1000000.0, - attention_bias: bool = False, - attention_dropout: float = 0.0, - rms_norm_eps: float = 1e-6, - sliding_window: int = 128, - layer_types: list = None, - ): - super().__init__() - - self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False) - - self.lyric_encoder = AceStepLyricEncoder( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - text_hidden_dim=text_hidden_dim, - num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - rope_theta=rope_theta, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window, - layer_types=layer_types, - ) - - self.timbre_encoder = AceStepTimbreEncoder( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - timbre_hidden_dim=timbre_hidden_dim, - num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - rope_theta=rope_theta, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - rms_norm_eps=rms_norm_eps, - sliding_window=sliding_window, +class AceStepAudioTokenDetokenizer(_AceStepAudioTokenDetokenizer): + def __new__(cls, *args, **kwargs): + deprecate( + "AceStepAudioTokenDetokenizer", + "1.0.0", + "Importing `AceStepAudioTokenDetokenizer` from `diffusers.pipelines.ace_step.modeling_ace_step` is deprecated. " + "Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import AceStepAudioTokenDetokenizer`).", ) + return super().__new__(cls) - # Learned null-condition embedding for classifier-free guidance, trained with - # `cfg_ratio=0.15` in the original model. Broadcast along the sequence dim when used. - self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size)) - # Silence latent — VAE-encoded audio-silence, stored as (1, T_long, timbre_hidden_dim). - # When no reference audio is provided, the pipeline slices `silence_latent[:, :timbre_fix_frame, :]` - # and feeds that to the timbre encoder. Passing literal zeros puts the timbre encoder - # OOD and produces drone-like audio (observed on all text2music outputs before this fix). - # The placeholder here is overwritten by the converter with the real encoded silence, - # so its shape just needs to match the timbre-encoder input: last dim is - # `timbre_hidden_dim` (so smaller test configs with `timbre_hidden_dim != 64` also load). - self.register_buffer( - "silence_latent", - torch.zeros(1, 15000, timbre_hidden_dim), - persistent=True, +class AceStepConditionEncoder(_AceStepConditionEncoder): + def __new__(cls, *args, **kwargs): + deprecate( + "AceStepConditionEncoder", + "1.0.0", + "Importing `AceStepConditionEncoder` from `diffusers.pipelines.ace_step.modeling_ace_step` is deprecated. " + "Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import AceStepConditionEncoder`).", ) + return super().__new__(cls) - def forward( - self, - text_hidden_states: torch.FloatTensor, - text_attention_mask: torch.Tensor, - lyric_hidden_states: torch.FloatTensor, - lyric_attention_mask: torch.Tensor, - refer_audio_acoustic_hidden_states_packed: torch.FloatTensor, - refer_audio_order_mask: torch.LongTensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - text_hidden_states = self.text_projector(text_hidden_states) - lyric_hidden_states = self.lyric_encoder( - inputs_embeds=lyric_hidden_states, attention_mask=lyric_attention_mask +class AceStepLyricEncoder(_AceStepLyricEncoder): + def __new__(cls, *args, **kwargs): + deprecate( + "AceStepLyricEncoder", + "1.0.0", + "Importing `AceStepLyricEncoder` from `diffusers.pipelines.ace_step.modeling_ace_step` is deprecated. " + "Import it from `diffusers.models.condition_embedders.condition_encoder_ace_step` instead.", ) + return super().__new__(cls) - timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder( - refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask - ) - encoder_hidden_states, encoder_attention_mask = _pack_sequences( - lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask - ) - encoder_hidden_states, encoder_attention_mask = _pack_sequences( - encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask +class AceStepTimbreEncoder(_AceStepTimbreEncoder): + def __new__(cls, *args, **kwargs): + deprecate( + "AceStepTimbreEncoder", + "1.0.0", + "Importing `AceStepTimbreEncoder` from `diffusers.pipelines.ace_step.modeling_ace_step` is deprecated. " + "Import it from `diffusers.models.condition_embedders.condition_encoder_ace_step` instead.", ) - - return encoder_hidden_states, encoder_attention_mask + return super().__new__(cls) diff --git a/src/diffusers/pipelines/ace_step/pipeline_ace_step.py b/src/diffusers/pipelines/ace_step/pipeline_ace_step.py index 1946f148f390..2d3c33ed1a31 100644 --- a/src/diffusers/pipelines/ace_step/pipeline_ace_step.py +++ b/src/diffusers/pipelines/ace_step/pipeline_ace_step.py @@ -21,12 +21,13 @@ from ...guiders.adaptive_projected_guidance import MomentumBuffer, normalized_guidance from ...models import AutoencoderOobleck +from ...models.autoencoders.audio_tokenizer_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer +from ...models.condition_embedders.condition_encoder_ace_step import AceStepConditionEncoder from ...models.transformers.ace_step_transformer import AceStepTransformer1DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline -from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index 09aa0ad17003..6e7e676ce2ad 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -12,1400 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from typing import Any - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import UNet2DConditionLoadersMixin -from ...models.activations import get_activation -from ...models.attention import AttentionMixin -from ...models.attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttnAddedKVProcessor, - AttnProcessor, +from ...models.condition_embedders.projection_audioldm2 import ( + AudioLDM2ProjectionModel as _AudioLDM2ProjectionModel, ) -from ...models.embeddings import TimestepEmbedding, Timesteps -from ...models.modeling_utils import ModelMixin -from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from ...models.transformers.transformer_2d import Transformer2DModel -from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D -from ...models.unets.unet_2d_condition import UNet2DConditionOutput -from ...utils import BaseOutput, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token): - batch_size = hidden_states.shape[0] - - if attention_mask is not None: - # Add two more steps to attn mask - new_attn_mask_step = attention_mask.new_ones((batch_size, 1)) - attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1) - - # Add the SOS / EOS tokens at the start / end of the sequence respectively - sos_token = sos_token.expand(batch_size, 1, -1) - eos_token = eos_token.expand(batch_size, 1, -1) - hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1) - return hidden_states, attention_mask - - -@dataclass -class AudioLDM2ProjectionModelOutput(BaseOutput): - """ - Args: - Class for AudioLDM2 projection layer's outputs. - hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text - encoders and subsequently concatenating them together. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks - for the two text encoders together. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - """ - - hidden_states: torch.Tensor - attention_mask: torch.LongTensor | None = None - - -class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin): - """ - A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned - embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with - `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first. - - Args: - text_encoder_dim (`int`): - Dimensionality of the text embeddings from the first text encoder (CLAP). - text_encoder_1_dim (`int`): - Dimensionality of the text embeddings from the second text encoder (T5 or VITS). - langauge_model_dim (`int`): - Dimensionality of the text embeddings from the language model (GPT2). - """ - - @register_to_config - def __init__( - self, - text_encoder_dim, - text_encoder_1_dim, - langauge_model_dim, - use_learned_position_embedding=None, - max_seq_length=None, - ): - super().__init__() - # additional projection layers for each text encoder - self.projection = nn.Linear(text_encoder_dim, langauge_model_dim) - self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim) - - # learnable SOS / EOS token embeddings for each text encoder - self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim)) - self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim)) - - self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) - self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) - - self.use_learned_position_embedding = use_learned_position_embedding - - # learable positional embedding for vits encoder - if self.use_learned_position_embedding is not None: - self.learnable_positional_embedding = torch.nn.Parameter( - torch.zeros((1, text_encoder_1_dim, max_seq_length)) - ) - - def forward( - self, - hidden_states: torch.Tensor | None = None, - hidden_states_1: torch.Tensor | None = None, - attention_mask: torch.LongTensor | None = None, - attention_mask_1: torch.LongTensor | None = None, - ): - hidden_states = self.projection(hidden_states) - hidden_states, attention_mask = add_special_tokens( - hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed - ) - - # Add positional embedding for Vits hidden state - if self.use_learned_position_embedding is not None: - hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1) - - hidden_states_1 = self.projection_1(hidden_states_1) - hidden_states_1, attention_mask_1 = add_special_tokens( - hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1 - ) - - # concatenate clap and t5 text encoding - hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1) - - # concatenate attention masks - if attention_mask is None and attention_mask_1 is not None: - attention_mask = attention_mask_1.new_ones((hidden_states[:2])) - elif attention_mask is not None and attention_mask_1 is None: - attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2])) - - if attention_mask is not None and attention_mask_1 is not None: - attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1) - else: - attention_mask = None - - return AudioLDM2ProjectionModelOutput( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - - -class AudioLDM2UNet2DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional - self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up - to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2. - up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention (`bool` or `tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: int | None = None, - in_channels: int = 4, - out_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: str = "UNetMidBlock2DCrossAttn", - up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: bool | tuple[bool] = False, - block_out_channels: tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int | tuple[int] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int | None = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int | tuple[int] = 1280, - transformer_layers_per_block: int | tuple[int] = 1, - attention_head_dim: int | tuple[int] = 8, - num_attention_heads: int | tuple[int] | None = None, - use_linear_projection: bool = False, - class_embed_type: str | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", - time_embedding_dim: int | None = None, - time_embedding_act_fn: str | None = None, - timestep_post_act: str | None = None, - time_cond_proj_dim: int | None = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: int | None = None, - class_embeddings_concat: bool = False, - ): - super().__init__() - - self.sample_size = sample_size - - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.") - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - else: - raise ValueError( - f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2." - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - - self.conv_act = get_activation(act_fn) - - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def forward( - self, - sample: torch.Tensor, - timestep: torch.Tensor | float | int, - encoder_hidden_states: torch.Tensor, - class_labels: torch.Tensor | None = None, - timestep_cond: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - cross_attention_kwargs: dict[str, Any] | None = None, - encoder_attention_mask: torch.Tensor | None = None, - return_dict: bool = True, - encoder_hidden_states_1: torch.Tensor | None = None, - encoder_attention_mask_1: torch.Tensor | None = None, - ) -> UNet2DConditionOutput | tuple: - r""" - The [`AudioLDM2UNet2DConditionModel`] forward method. - - Args: - sample (`torch.Tensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - encoder_hidden_states_1 (`torch.Tensor`, *optional*): - A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be - used to condition the model on a different set of embeddings to `encoder_hidden_states`. - encoder_attention_mask_1 (`torch.Tensor`, *optional*): - A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`. - If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - - Returns: - [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, - otherwise a `tuple` is returned where the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - if encoder_attention_mask_1 is not None: - encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 - encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - is_npu = sample.device.type == "npu" - if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb - - emb = emb + aug_emb if aug_emb is not None else emb - - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states_1=encoder_hidden_states_1, - encoder_attention_mask_1=encoder_attention_mask_1, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states_1=encoder_hidden_states_1, - encoder_attention_mask_1=encoder_attention_mask_1, - ) - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states_1=encoder_hidden_states_1, - encoder_attention_mask_1=encoder_attention_mask_1, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - transformer_layers_per_block=1, - num_attention_heads=None, - resnet_groups=None, - cross_attention_dim=None, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{up_block_type} does not exist.") - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) - if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: - raise ValueError( - "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " - f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" - ) - self.cross_attention_dim = cross_attention_dim - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - for j in range(len(cross_attention_dim)): - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim[j], - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - double_self_attention=True if cross_attention_dim[j] is None else False, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor | None = None, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - cross_attention_kwargs: dict[str, Any] | None = None, - encoder_attention_mask: torch.Tensor | None = None, - encoder_hidden_states_1: torch.Tensor | None = None, - encoder_attention_mask_1: torch.Tensor | None = None, - ): - output_states = () - num_layers = len(self.resnets) - num_attention_per_layer = len(self.attentions) // num_layers - - encoder_hidden_states_1 = ( - encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states - ) - encoder_attention_mask_1 = ( - encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask - ) - - for i in range(num_layers): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self._gradient_checkpointing_func( - self.attentions[i * num_attention_per_layer + idx], - hidden_states, - forward_encoder_hidden_states, - None, # timestep - None, # class_labels - cross_attention_kwargs, - attention_mask, - forward_encoder_attention_mask, - )[0] - else: - hidden_states = self.resnets[i](hidden_states, temb) - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self.attentions[i * num_attention_per_layer + idx]( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=forward_encoder_hidden_states, - encoder_attention_mask=forward_encoder_attention_mask, - return_dict=False, - )[0] - - output_states = output_states + (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states = output_states + (hidden_states,) - - return hidden_states, output_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) - if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: - raise ValueError( - "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " - f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" - ) - self.cross_attention_dim = cross_attention_dim - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for i in range(num_layers): - for j in range(len(cross_attention_dim)): - attentions.append( - Transformer2DModel( - num_attention_heads, - in_channels // num_attention_heads, - in_channels=in_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim[j], - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - double_self_attention=True if cross_attention_dim[j] is None else False, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: torch.Tensor | None = None, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - cross_attention_kwargs: dict[str, Any] | None = None, - encoder_attention_mask: torch.Tensor | None = None, - encoder_hidden_states_1: torch.Tensor | None = None, - encoder_attention_mask_1: torch.Tensor | None = None, - ) -> torch.Tensor: - hidden_states = self.resnets[0](hidden_states, temb) - num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1) - - encoder_hidden_states_1 = ( - encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states - ) - encoder_attention_mask_1 = ( - encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask - ) - - for i in range(len(self.resnets[1:])): - if torch.is_grad_enabled() and self.gradient_checkpointing: - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self._gradient_checkpointing_func( - self.attentions[i * num_attention_per_layer + idx], - hidden_states, - forward_encoder_hidden_states, - None, # timestep - None, # class_labels - cross_attention_kwargs, - attention_mask, - forward_encoder_attention_mask, - )[0] - hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb) - else: - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self.attentions[i * num_attention_per_layer + idx]( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=forward_encoder_hidden_states, - encoder_attention_mask=forward_encoder_attention_mask, - return_dict=False, - )[0] - - hidden_states = self.resnets[i + 1](hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - num_attention_heads=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) - if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: - raise ValueError( - "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " - f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" - ) - self.cross_attention_dim = cross_attention_dim - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - for j in range(len(cross_attention_dim)): - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block, - cross_attention_dim=cross_attention_dim[j], - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - double_self_attention=True if cross_attention_dim[j] is None else False, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: tuple[torch.Tensor, ...], - temb: torch.Tensor | None = None, - encoder_hidden_states: torch.Tensor | None = None, - cross_attention_kwargs: dict[str, Any] | None = None, - upsample_size: int | None = None, - attention_mask: torch.Tensor | None = None, - encoder_attention_mask: torch.Tensor | None = None, - encoder_hidden_states_1: torch.Tensor | None = None, - encoder_attention_mask_1: torch.Tensor | None = None, - ): - num_layers = len(self.resnets) - num_attention_per_layer = len(self.attentions) // num_layers - - encoder_hidden_states_1 = ( - encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states - ) - encoder_attention_mask_1 = ( - encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask - ) - - for i in range(num_layers): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self._gradient_checkpointing_func( - self.attentions[i * num_attention_per_layer + idx], - hidden_states, - forward_encoder_hidden_states, - None, # timestep - None, # class_labels - cross_attention_kwargs, - attention_mask, - forward_encoder_attention_mask, - )[0] - else: - hidden_states = self.resnets[i](hidden_states, temb) - for idx, cross_attention_dim in enumerate(self.cross_attention_dim): - if cross_attention_dim is not None and idx <= 1: - forward_encoder_hidden_states = encoder_hidden_states - forward_encoder_attention_mask = encoder_attention_mask - elif cross_attention_dim is not None and idx > 1: - forward_encoder_hidden_states = encoder_hidden_states_1 - forward_encoder_attention_mask = encoder_attention_mask_1 - else: - forward_encoder_hidden_states = None - forward_encoder_attention_mask = None - hidden_states = self.attentions[i * num_attention_per_layer + idx]( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=forward_encoder_hidden_states, - encoder_attention_mask=forward_encoder_attention_mask, - return_dict=False, - )[0] - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states +from ...models.condition_embedders.projection_audioldm2 import ( + AudioLDM2ProjectionModelOutput, # noqa: F401 re-exported for back-compat + add_special_tokens, # noqa: F401 re-exported for back-compat +) +from ...models.unets.unet_2d_condition_audioldm2 import ( + AudioLDM2UNet2DConditionModel as _AudioLDM2UNet2DConditionModel, +) +from ...models.unets.unet_2d_condition_audioldm2 import ( + CrossAttnDownBlock2D, # noqa: F401 re-exported for back-compat + CrossAttnUpBlock2D, # noqa: F401 + UNetMidBlock2DCrossAttn, # noqa: F401 + get_down_block, # noqa: F401 + get_up_block, # noqa: F401 +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class AudioLDM2ProjectionModel(_AudioLDM2ProjectionModel): + def __new__(cls, *args, **kwargs): + deprecate( + "AudioLDM2ProjectionModel", + "1.0.0", + "Importing `AudioLDM2ProjectionModel` from `diffusers.pipelines.audioldm2.modeling_audioldm2` is " + "deprecated. Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import AudioLDM2ProjectionModel`).", + ) + return super().__new__(cls) + + +class AudioLDM2UNet2DConditionModel(_AudioLDM2UNet2DConditionModel): + def __new__(cls, *args, **kwargs): + deprecate( + "AudioLDM2UNet2DConditionModel", + "1.0.0", + "Importing `AudioLDM2UNet2DConditionModel` from `diffusers.pipelines.audioldm2.modeling_audioldm2` is " + "deprecated. Import it from `diffusers.models.unets` instead " + "(or `from diffusers import AudioLDM2UNet2DConditionModel`).", + ) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index 6fb02433dace..acb3302646a6 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -32,6 +32,8 @@ ) from ...models import AutoencoderKL +from ...models.condition_embedders import AudioLDM2ProjectionModel +from ...models.unets import AudioLDM2UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -44,7 +46,6 @@ from ...utils.import_utils import is_transformers_version from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline -from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel if is_librosa_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py index b8c70fc6528c..8ec1521bff7b 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -9,6 +9,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -22,7 +23,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index 3dadc63f4952..a36c2f6211f7 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -11,6 +11,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -25,7 +26,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index 4839a0860462..d96a85947951 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -12,6 +12,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -25,7 +26,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_bs4_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index 03a9d6f7c5e8..b1522d628736 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -11,6 +11,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -25,7 +26,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index 841382ad9c63..cfa522ccb1ca 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -12,6 +12,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -25,7 +26,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_bs4_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index 52ebebb6f9b4..02088aaa8524 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -12,6 +12,7 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel +from ...models.others import IFWatermarker from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, @@ -24,7 +25,6 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker -from .watermark import IFWatermarker if is_bs4_available(): diff --git a/src/diffusers/pipelines/deepfloyd_if/safety_checker.py b/src/diffusers/pipelines/deepfloyd_if/safety_checker.py index 8ffeed580bbe..bdb21fa1e2e4 100644 --- a/src/diffusers/pipelines/deepfloyd_if/safety_checker.py +++ b/src/diffusers/pipelines/deepfloyd_if/safety_checker.py @@ -22,6 +22,8 @@ def __init__(self, config: CLIPConfig): self.p_head = nn.Linear(config.vision_config.projection_dim, 1) self.w_head = nn.Linear(config.vision_config.projection_dim, 1) + self.post_init() + @torch.no_grad() def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): image_embeds = self.vision_model(clip_input)[0] diff --git a/src/diffusers/pipelines/deepfloyd_if/watermark.py b/src/diffusers/pipelines/deepfloyd_if/watermark.py index d5fe99f681f7..ff51fe4ce211 100644 --- a/src/diffusers/pipelines/deepfloyd_if/watermark.py +++ b/src/diffusers/pipelines/deepfloyd_if/watermark.py @@ -1,44 +1,31 @@ -import PIL.Image -import torch -from PIL import Image - -from ...configuration_utils import ConfigMixin -from ...models.modeling_utils import ModelMixin -from ...utils import PIL_INTERPOLATION - - -class IFWatermarker(ModelMixin, ConfigMixin): - def __init__(self): - super().__init__() - - self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) - self.watermark_image_as_pil = None - - def apply_watermark(self, images: list[PIL.Image.Image], sample_size=None): - # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 - - h = images[0].height - w = images[0].width - - sample_size = sample_size or h - - coef = min(h / sample_size, w / sample_size) - img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) - - S1, S2 = 1024**2, img_w * img_h - K = (S2 / S1) ** 0.5 - wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) - - if self.watermark_image_as_pil is None: - watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() - watermark_image = Image.fromarray(watermark_image, mode="RGBA") - self.watermark_image_as_pil = watermark_image - - wm_img = self.watermark_image_as_pil.resize( - (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...models.others.watermark_if import IFWatermarker as _IFWatermarker +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class IFWatermarker(_IFWatermarker): + def __new__(cls, *args, **kwargs): + deprecate( + "IFWatermarker", + "1.0.0", + "Importing `IFWatermarker` from `diffusers.pipelines.deepfloyd_if.watermark` is deprecated. " + "Import it from `diffusers.models.others` instead.", ) - - for pil_img in images: - pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) - - return images + return super().__new__(cls) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/deprecated/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py index d72d12a64945..00ebcda4defe 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -30,6 +30,7 @@ from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ....models import AutoencoderKL, UNet2DConditionModel from ....models.attention import GatedSelfAttentionDense +from ....models.condition_embedders import CLIPImageProjection from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import KarrasDiffusionSchedulers from ....utils import ( @@ -43,7 +44,6 @@ from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin from ...stable_diffusion import StableDiffusionPipelineOutput -from ...stable_diffusion.clip_image_project_model import CLIPImageProjection from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/flux/modeling_flux.py b/src/diffusers/pipelines/flux/modeling_flux.py index 916e3fbe5953..d78f280ad9bb 100644 --- a/src/diffusers/pipelines/flux/modeling_flux.py +++ b/src/diffusers/pipelines/flux/modeling_flux.py @@ -13,34 +13,26 @@ # limitations under the License. -from dataclasses import dataclass - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin -from ...utils import BaseOutput - - -@dataclass -class ReduxImageEncoderOutput(BaseOutput): - image_embeds: torch.Tensor | None = None - - -class ReduxImageEncoder(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - redux_dim: int = 1152, - txt_in_features: int = 4096, - ) -> None: - super().__init__() - - self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) - self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) - - def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: - projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) - - return ReduxImageEncoderOutput(image_embeds=projected_x) +from ...models.condition_embedders.image_encoder_redux import ( + ReduxImageEncoder as _ReduxImageEncoder, +) +from ...models.condition_embedders.image_encoder_redux import ( + ReduxImageEncoderOutput, # noqa: F401 re-exported for back-compat +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class ReduxImageEncoder(_ReduxImageEncoder): + def __new__(cls, *args, **kwargs): + deprecate( + "ReduxImageEncoder", + "1.0.0", + "Importing `ReduxImageEncoder` from `diffusers.pipelines.flux.modeling_flux` is " + "deprecated. Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import ReduxImageEncoder`).", + ) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 94c7bcc80782..11e3e5ca6382 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -26,6 +26,7 @@ from ...image_processor import PipelineImageInput from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...models.condition_embedders import ReduxImageEncoder from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, @@ -35,7 +36,6 @@ unscale_lora_layers, ) from ..pipeline_utils import DiffusionPipeline -from .modeling_flux import ReduxImageEncoder from .pipeline_output import FluxPriorReduxPipelineOutput diff --git a/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py index f579cf00dbe7..47f4dd8a48b5 100644 --- a/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py +++ b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py @@ -12,175 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - - -class ResBlock(torch.nn.Module): - def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): - super().__init__() - if mid_channels is None: - mid_channels = channels - - Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d - - self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) - self.norm1 = torch.nn.GroupNorm(32, mid_channels) - self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) - self.norm2 = torch.nn.GroupNorm(32, channels) - self.activation = torch.nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residual = hidden_states - hidden_states = self.conv1(hidden_states) - hidden_states = self.norm1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.conv2(hidden_states) - hidden_states = self.norm2(hidden_states) - hidden_states = self.activation(hidden_states + residual) - return hidden_states - - -class PixelShuffleND(torch.nn.Module): - def __init__(self, dims, upscale_factors=(2, 2, 2)): - super().__init__() - - self.dims = dims - self.upscale_factors = upscale_factors - - if dims not in [1, 2, 3]: - raise ValueError("dims must be 1, 2, or 3") - - def forward(self, x): - if self.dims == 3: - # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) - return ( - x.unflatten(1, (-1, *self.upscale_factors[:3])) - .permute(0, 1, 5, 2, 6, 3, 7, 4) - .flatten(6, 7) - .flatten(4, 5) - .flatten(2, 3) - ) - elif self.dims == 2: - # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) - return ( - x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) - ) - elif self.dims == 1: - # temporal: b (c p1) f h w -> b c (f p1) h w - return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) - - -class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin): - """ - Model to spatially upsample VAE latents. - - Args: - in_channels (`int`, defaults to `128`): - Number of channels in the input latent - mid_channels (`int`, defaults to `512`): - Number of channels in the middle layers - num_blocks_per_stage (`int`, defaults to `4`): - Number of ResBlocks to use in each stage (pre/post upsampling) - dims (`int`, defaults to `3`): - Number of dimensions for convolutions (2 or 3) - spatial_upsample (`bool`, defaults to `True`): - Whether to spatially upsample the latent - temporal_upsample (`bool`, defaults to `False`): - Whether to temporally upsample the latent - """ - - @register_to_config - def __init__( - self, - in_channels: int = 128, - mid_channels: int = 512, - num_blocks_per_stage: int = 4, - dims: int = 3, - spatial_upsample: bool = True, - temporal_upsample: bool = False, - ): - super().__init__() - - self.in_channels = in_channels - self.mid_channels = mid_channels - self.num_blocks_per_stage = num_blocks_per_stage - self.dims = dims - self.spatial_upsample = spatial_upsample - self.temporal_upsample = temporal_upsample - - ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d - - self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) - self.initial_norm = torch.nn.GroupNorm(32, mid_channels) - self.initial_activation = torch.nn.SiLU() - - self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) - - if spatial_upsample and temporal_upsample: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(3), - ) - elif spatial_upsample: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(2), - ) - elif temporal_upsample: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(1), - ) - else: - raise ValueError("Either spatial_upsample or temporal_upsample must be True") - - self.post_upsample_res_blocks = torch.nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] +from ...models.autoencoders.latent_upsampler_ltx import ( + LTXLatentUpsamplerModel as _LTXLatentUpsamplerModel, +) +from ...models.autoencoders.latent_upsampler_ltx import ( + PixelShuffleND, # noqa: F401 re-exported for back-compat + ResBlock, # noqa: F401 re-exported for back-compat +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class LTXLatentUpsamplerModel(_LTXLatentUpsamplerModel): + def __new__(cls, *args, **kwargs): + deprecate( + "LTXLatentUpsamplerModel", + "1.0.0", + "Importing `LTXLatentUpsamplerModel` from `diffusers.pipelines.ltx.modeling_latent_upsampler` is " + "deprecated. Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import LTXLatentUpsamplerModel`).", ) - - self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - - if self.dims == 2: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.initial_conv(hidden_states) - hidden_states = self.initial_norm(hidden_states) - hidden_states = self.initial_activation(hidden_states) - - for block in self.res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.upsampler(hidden_states) - - for block in self.post_upsample_res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.final_conv(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - else: - hidden_states = self.initial_conv(hidden_states) - hidden_states = self.initial_norm(hidden_states) - hidden_states = self.initial_activation(hidden_states) - - for block in self.res_blocks: - hidden_states = block(hidden_states) - - if self.temporal_upsample: - hidden_states = self.upsampler(hidden_states) - hidden_states = hidden_states[:, :, 1:, :, :] - else: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.upsampler(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - - for block in self.post_upsample_res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.final_conv(hidden_states) - - return hidden_states + return super().__new__(cls) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py index 315dcc04cb30..59a2253e68d9 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py @@ -16,11 +16,11 @@ from ...image_processor import PipelineImageInput from ...models import AutoencoderKLLTXVideo +from ...models.autoencoders import LTXLatentUpsamplerModel from ...utils import deprecate, get_logger from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .modeling_latent_upsampler import LTXLatentUpsamplerModel from .pipeline_output import LTXPipelineOutput diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 8a00a0c6b452..88d27a9838f1 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -1,474 +1,41 @@ -import math - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...models.attention import FeedForward -from ...models.modeling_utils import ModelMixin -from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor - - -def per_layer_masked_mean_norm( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: str | torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, -): - """ - Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states. - Respects the padding of the hidden states. - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. - """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - - -def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: - variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True) - norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps) - return norm_text_encoder_hidden_states - - -class LTX2RotaryPosEmbed1d(nn.Module): - """ - 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. - """ - - def __init__( - self, - dim: int, - base_seq_len: int = 4096, - theta: float = 10000.0, - double_precision: bool = True, - rope_type: str = "interleaved", - num_attention_heads: int = 32, - ): - super().__init__() - if rope_type not in ["interleaved", "split"]: - raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") - - self.dim = dim - self.base_seq_len = base_seq_len - self.theta = theta - self.double_precision = double_precision - self.rope_type = rope_type - self.num_attention_heads = num_attention_heads - - def forward( - self, - batch_size: int, - pos: int, - device: str | torch.device, - ) -> tuple[torch.Tensor, torch.Tensor]: - # 1. Get 1D position ids - grid_1d = torch.arange(pos, dtype=torch.float32, device=device) - # Get fractional indices relative to self.base_seq_len - grid_1d = grid_1d / self.base_seq_len - grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] - - # 2. Calculate 1D RoPE frequencies - num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 - freqs_dtype = torch.float64 if self.double_precision else torch.float32 - pow_indices = torch.pow( - self.theta, - torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), - ) - freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) - - # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape - # (self.dim // 2,). - freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] - - # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim - if self.rope_type == "interleaved": - cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) - sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) - - if self.dim % num_rope_elems != 0: - cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) - cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) - sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - - elif self.rope_type == "split": - expected_freqs = self.dim // 2 - current_freqs = freqs.shape[-1] - pad_size = expected_freqs - current_freqs - cos_freq = freqs.cos() - sin_freq = freqs.sin() - - if pad_size != 0: - cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) - sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - - cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) - - # Reshape freqs to be compatible with multi-head attention - b = cos_freq.shape[0] - t = cos_freq.shape[1] - - cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) - sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) - - cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) - sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) - - return cos_freqs, sin_freqs - - -class LTX2TransformerBlock1d(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - activation_fn: str = "gelu-approximate", - eps: float = 1e-6, - rope_type: str = "interleaved", - apply_gated_attention: bool = False, - ): - super().__init__() - - self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.attn1 = LTX2Attention( - query_dim=dim, - heads=num_attention_heads, - kv_heads=num_attention_heads, - dim_head=attention_head_dim, - rope_type=rope_type, - apply_gated_attention=apply_gated_attention, - processor=LTX2AudioVideoAttnProcessor(), - ) - - self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.ff = FeedForward(dim, activation_fn=activation_fn) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - rotary_emb: torch.Tensor | None = None, - ) -> torch.Tensor: - norm_hidden_states = self.norm1(hidden_states) - attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) - hidden_states = hidden_states + attn_hidden_states - - norm_hidden_states = self.norm2(hidden_states) - ff_hidden_states = self.ff(norm_hidden_states) - hidden_states = hidden_states + ff_hidden_states - - return hidden_states - - -class LTX2ConnectorTransformer1d(nn.Module): - """ - A 1D sequence transformer for modalities such as text. - - In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. - """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 128, - num_layers: int = 2, - num_learnable_registers: int | None = 128, - rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - eps: float = 1e-6, - causal_temporal_positioning: bool = False, - rope_type: str = "interleaved", - gated_attention: bool = False, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.inner_dim = num_attention_heads * attention_head_dim - self.causal_temporal_positioning = causal_temporal_positioning - - self.num_learnable_registers = num_learnable_registers - self.learnable_registers = None - if num_learnable_registers is not None: - init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 - self.learnable_registers = torch.nn.Parameter(init_registers) - - self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, - base_seq_len=rope_base_seq_len, - theta=rope_theta, - double_precision=rope_double_precision, - rope_type=rope_type, - num_attention_heads=num_attention_heads, - ) - - self.transformer_blocks = torch.nn.ModuleList( - [ - LTX2TransformerBlock1d( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - rope_type=rope_type, - apply_gated_attention=gated_attention, - ) - for _ in range(num_layers) - ] +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...models.condition_embedders.text_connector_ltx2 import ( + LTX2ConnectorTransformer1d, # noqa: F401 re-exported for back-compat + LTX2RotaryPosEmbed1d, # noqa: F401 + LTX2TransformerBlock1d, # noqa: F401 + per_layer_masked_mean_norm, # noqa: F401 + per_token_rms_norm, # noqa: F401 +) +from ...models.condition_embedders.text_connector_ltx2 import ( + LTX2TextConnectors as _LTX2TextConnectors, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class LTX2TextConnectors(_LTX2TextConnectors): + def __new__(cls, *args, **kwargs): + deprecate( + "LTX2TextConnectors", + "1.0.0", + "Importing `LTX2TextConnectors` from `diffusers.pipelines.ltx2.connectors` is deprecated. " + "Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import LTX2TextConnectors`).", ) - - self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - attn_mask_binarize_threshold: float = -9000.0, - ) -> tuple[torch.Tensor, torch.Tensor]: - # hidden_states shape: [batch_size, seq_len, hidden_dim] - # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] - batch_size, seq_len, _ = hidden_states.shape - - # 1. Replace padding with learned registers, if using - if self.learnable_registers is not None: - if seq_len % self.num_learnable_registers != 0: - raise ValueError( - f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" - f" of learnable registers {self.num_learnable_registers}" - ) - - num_register_repeats = seq_len // self.num_learnable_registers - registers = ( - self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1) - ) # [seq_len, inner_dim] - - binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() - if binary_attn_mask.ndim == 4: - binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - - # Replace padding positions with learned registers using vectorized masking - mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] - registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] - hidden_states = mask * hidden_states + (1 - mask) * registers_expanded - - # Flip sequence: embeddings move to front, registers to back (from left padding layout) - hidden_states = torch.flip(hidden_states, dims=[1]) - - # Overwrite attention_mask with an all-zeros mask if using registers. - attention_mask = torch.zeros_like(attention_mask) - - # 2. Calculate 1D RoPE positional embeddings - rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) - - # 3. Run 1D transformer blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) - else: - hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) - - hidden_states = self.norm_out(hidden_states) - - return hidden_states, attention_mask - - -class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): - """ - Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio - streams. - """ - - @register_to_config - def __init__( - self, - caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size - text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B - video_connector_num_attention_heads: int = 30, - video_connector_attention_head_dim: int = 128, - video_connector_num_layers: int = 2, - video_connector_num_learnable_registers: int | None = 128, - video_gated_attn: bool = False, - audio_connector_num_attention_heads: int = 30, - audio_connector_attention_head_dim: int = 128, - audio_connector_num_layers: int = 2, - audio_connector_num_learnable_registers: int | None = 128, - audio_gated_attn: bool = False, - connector_rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - causal_temporal_positioning: bool = False, - rope_type: str = "interleaved", - per_modality_projections: bool = False, - video_hidden_dim: int = 4096, - audio_hidden_dim: int = 2048, - proj_bias: bool = False, - ): - super().__init__() - text_encoder_dim = caption_channels * text_proj_in_factor - if per_modality_projections: - self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias) - self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias) - else: - self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias) - - self.video_connector = LTX2ConnectorTransformer1d( - num_attention_heads=video_connector_num_attention_heads, - attention_head_dim=video_connector_attention_head_dim, - num_layers=video_connector_num_layers, - num_learnable_registers=video_connector_num_learnable_registers, - rope_base_seq_len=connector_rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type, - gated_attention=video_gated_attn, - ) - self.audio_connector = LTX2ConnectorTransformer1d( - num_attention_heads=audio_connector_num_attention_heads, - attention_head_dim=audio_connector_attention_head_dim, - num_layers=audio_connector_num_layers, - num_learnable_registers=audio_connector_num_learnable_registers, - rope_base_seq_len=connector_rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - rope_type=rope_type, - gated_attention=audio_gated_attn, - ) - - def forward( - self, - text_encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - padding_side: str = "left", - scale_factor: int = 8, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text - embeddings for the LTX-2.X DiT models. - - Args: - text_encoder_hidden_states (`torch.Tensor`)): - Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len, - caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened. - attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): - Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked - positions. - padding_side (`str`, *optional*, defaults to `"left"`): - The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to - `"left"` as this is what the default Gemma3-12B text encoder uses. Only used if - `per_modality_projections` is `False` (LTX-2.0 models). - scale_factor (`int`, *optional*, defaults to `8`): - Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False` - (LTX-2.0 models). - """ - if text_encoder_hidden_states.ndim == 3: - # Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor] - text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1)) - - if self.config.per_modality_projections: - # LTX-2.3 - norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states) - - norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3) - bool_mask = attention_mask.bool().unsqueeze(-1) - norm_text_encoder_hidden_states = torch.where( - bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states) - ) - - # Rescale norms with respect to video and audio dims for feature extractors - video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels) - video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor - audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels) - audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor - - # Per-Modality Feature extractors - video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb) - audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb) - else: - # LTX-2.0 - sequence_lengths = attention_mask.sum(dim=-1) - norm_text_encoder_hidden_states = per_layer_masked_mean_norm( - text_hidden_states=text_encoder_hidden_states, - sequence_lengths=sequence_lengths, - device=text_encoder_hidden_states.device, - padding_side=padding_side, - scale_factor=scale_factor, - ) - - text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states) - video_text_emb_proj = text_emb_proj - audio_text_emb_proj = text_emb_proj - - # Convert to additive attention mask for connectors - text_dtype = video_text_emb_proj.dtype - attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype) - attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) - add_attn_mask = attention_mask * torch.finfo(text_dtype).max - - video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask) - - # Convert video attn mask to binary (multiplicative) mask and mask video text embedding - binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64) - binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) - video_text_embedding = video_text_embedding * binary_attn_mask - - audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask) - - return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py index 329ced36d45b..2b0e23d45267 100644 --- a/src/diffusers/pipelines/ltx2/latent_upsampler.py +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -12,274 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - -import torch -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - - -RATIONAL_RESAMPLER_SCALE_MAPPING = { - 0.75: (3, 4), - 1.5: (3, 2), - 2.0: (2, 1), - 4.0: (4, 1), -} - - -# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock -class ResBlock(torch.nn.Module): - def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): - super().__init__() - if mid_channels is None: - mid_channels = channels - - Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d - - self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) - self.norm1 = torch.nn.GroupNorm(32, mid_channels) - self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) - self.norm2 = torch.nn.GroupNorm(32, channels) - self.activation = torch.nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residual = hidden_states - hidden_states = self.conv1(hidden_states) - hidden_states = self.norm1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.conv2(hidden_states) - hidden_states = self.norm2(hidden_states) - hidden_states = self.activation(hidden_states + residual) - return hidden_states - - -# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND -class PixelShuffleND(torch.nn.Module): - def __init__(self, dims, upscale_factors=(2, 2, 2)): - super().__init__() - - self.dims = dims - self.upscale_factors = upscale_factors - - if dims not in [1, 2, 3]: - raise ValueError("dims must be 1, 2, or 3") - - def forward(self, x): - if self.dims == 3: - # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) - return ( - x.unflatten(1, (-1, *self.upscale_factors[:3])) - .permute(0, 1, 5, 2, 6, 3, 7, 4) - .flatten(6, 7) - .flatten(4, 5) - .flatten(2, 3) - ) - elif self.dims == 2: - # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) - return ( - x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) - ) - elif self.dims == 1: - # temporal: b (c p1) f h w -> b c (f p1) h w - return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) - - -class BlurDownsample(torch.nn.Module): - """ - Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. - Works for dims=2 or dims=3 (per-frame). - """ - - def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: - super().__init__() - - if dims not in (2, 3): - raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") - if kernel_size < 3 or kernel_size % 2 != 1: - raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") - - self.dims = dims - self.stride = stride - self.kernel_size = kernel_size - - # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from - # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and - # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). - # The 2D kernel is constructed as the outer product and normalized. - k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) - k2d = k[:, None] @ k[None, :] - k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) - self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.stride == 1: - return x - - if self.dims == 2: - c = x.shape[1] - weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise - x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) - else: - # dims == 3: apply per-frame on H,W - b, c, f, _, _ = x.shape - x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] - - weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise - x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) - - h2, w2 = x.shape[-2:] - x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] - return x - - -class SpatialRationalResampler(torch.nn.Module): - """ - Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample - by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the - input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the - (integer) denominator. - """ - - def __init__(self, mid_channels: int = 1024, scale: float = 2.0): - super().__init__() - self.scale = float(scale) - num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) - if num_denom is None: - raise ValueError( - f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" - ) - self.num, self.den = num_denom - - self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) - self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) - self.blur_down = BlurDownsample(dims=2, stride=self.den) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Expected x shape: [B * F, C, H, W] - # b, _, f, h, w = x.shape - # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] - x = self.conv(x) - x = self.pixel_shuffle(x) - x = self.blur_down(x) - # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] - return x - - -class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): - """ - Model to spatially upsample VAE latents. - - Args: - in_channels (`int`, defaults to `128`): - Number of channels in the input latent - mid_channels (`int`, defaults to `512`): - Number of channels in the middle layers - num_blocks_per_stage (`int`, defaults to `4`): - Number of ResBlocks to use in each stage (pre/post upsampling) - dims (`int`, defaults to `3`): - Number of dimensions for convolutions (2 or 3) - spatial_upsample (`bool`, defaults to `True`): - Whether to spatially upsample the latent - temporal_upsample (`bool`, defaults to `False`): - Whether to temporally upsample the latent - """ - - @register_to_config - def __init__( - self, - in_channels: int = 128, - mid_channels: int = 1024, - num_blocks_per_stage: int = 4, - dims: int = 3, - spatial_upsample: bool = True, - temporal_upsample: bool = False, - rational_spatial_scale: float = 2.0, - use_rational_resampler: bool = True, - ): - super().__init__() - - self.in_channels = in_channels - self.mid_channels = mid_channels - self.num_blocks_per_stage = num_blocks_per_stage - self.dims = dims - self.spatial_upsample = spatial_upsample - self.temporal_upsample = temporal_upsample - - ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d - - self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) - self.initial_norm = torch.nn.GroupNorm(32, mid_channels) - self.initial_activation = torch.nn.SiLU() - - self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) - - if spatial_upsample and temporal_upsample: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(3), - ) - elif spatial_upsample: - if use_rational_resampler: - self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) - else: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(2), - ) - elif temporal_upsample: - self.upsampler = torch.nn.Sequential( - torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), - PixelShuffleND(1), - ) - else: - raise ValueError("Either spatial_upsample or temporal_upsample must be True") - - self.post_upsample_res_blocks = torch.nn.ModuleList( - [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] +from ...models.autoencoders.latent_upsampler_ltx2 import ( + RATIONAL_RESAMPLER_SCALE_MAPPING, # noqa: F401 re-exported for back-compat + PixelShuffleND, # noqa: F401 re-exported for back-compat + ResBlock, # noqa: F401 re-exported for back-compat +) +from ...models.autoencoders.latent_upsampler_ltx2 import ( + LTX2LatentUpsamplerModel as _LTX2LatentUpsamplerModel, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class LTX2LatentUpsamplerModel(_LTX2LatentUpsamplerModel): + def __new__(cls, *args, **kwargs): + deprecate( + "LTX2LatentUpsamplerModel", + "1.0.0", + "Importing `LTX2LatentUpsamplerModel` from `diffusers.pipelines.ltx2.latent_upsampler` is " + "deprecated. Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import LTX2LatentUpsamplerModel`).", ) - - self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, num_channels, num_frames, height, width = hidden_states.shape - - if self.dims == 2: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.initial_conv(hidden_states) - hidden_states = self.initial_norm(hidden_states) - hidden_states = self.initial_activation(hidden_states) - - for block in self.res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.upsampler(hidden_states) - - for block in self.post_upsample_res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.final_conv(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - else: - hidden_states = self.initial_conv(hidden_states) - hidden_states = self.initial_norm(hidden_states) - hidden_states = self.initial_activation(hidden_states) - - for block in self.res_blocks: - hidden_states = block(hidden_states) - - if self.temporal_upsample: - hidden_states = self.upsampler(hidden_states) - hidden_states = hidden_states[:, :, 1:, :, :] - else: - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) - hidden_states = self.upsampler(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) - - for block in self.post_upsample_res_blocks: - hidden_states = block(hidden_states) - - hidden_states = self.final_conv(hidden_states) - - return hidden_states + return super().__new__(cls) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index ba32f6ed4c0c..0162c9df4fed 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -22,16 +22,15 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2Vocoder, LTX2VocoderWithBWE +from ...models.condition_embedders import LTX2TextConnectors from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py index 600665966f13..eb5dbdf7018e 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -25,16 +25,15 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2Vocoder, LTX2VocoderWithBWE +from ...models.condition_embedders import LTX2TextConnectors from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py index 38cd69b66c64..44a54528d697 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_hdr_lora.py @@ -24,16 +24,15 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2Vocoder, LTX2VocoderWithBWE +from ...models.condition_embedders import LTX2TextConnectors from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .connectors import LTX2TextConnectors from .image_processor import LTX2VideoHDRProcessor from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py index 8f2e3504e777..066e756ee291 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_ic_lora.py @@ -25,17 +25,16 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2Vocoder, LTX2VocoderWithBWE +from ...models.condition_embedders import LTX2TextConnectors from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .connectors import LTX2TextConnectors from .pipeline_ltx2_condition import LTX2VideoCondition from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index bf27927ec8cd..caa0cf4573f2 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -23,16 +23,15 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2Vocoder, LTX2VocoderWithBWE +from ...models.condition_embedders import LTX2TextConnectors from ...models.transformers import LTX2VideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 00d81dfd11c3..a013b7c3aa37 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -17,12 +17,12 @@ from ...image_processor import PipelineImageInput from ...models import AutoencoderKLLTX2Video +from ...models.autoencoders import LTX2LatentUpsamplerModel from ...utils import get_logger, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..ltx.pipeline_output import LTXPipelineOutput from ..pipeline_utils import DiffusionPipeline -from .latent_upsampler import LTX2LatentUpsamplerModel logger = get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py index f0004f2ec02d..9ca795fe1639 100644 --- a/src/diffusers/pipelines/ltx2/vocoder.py +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -1,597 +1,59 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - - -def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: - """ - Creates a Kaiser sinc kernel for low-pass filtering. - - Args: - cutoff (`float`): - Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist - frequency). - half_width (`float`): - Used to determine the Kaiser window's beta parameter. - kernel_size: - Size of the Kaiser window (and ultimately the Kaiser sinc kernel). - - Returns: - `torch.Tensor` of shape `(kernel_size,)`: - The Kaiser sinc kernel. - """ - delta_f = 4 * half_width - half_size = kernel_size // 2 - amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 - if amplitude > 50.0: - beta = 0.1102 * (amplitude - 8.7) - elif amplitude >= 21.0: - beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) - else: - beta = 0.0 - - window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) - - even = kernel_size % 2 == 0 - time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size - - if cutoff == 0.0: - filter = torch.zeros_like(time) - else: - time = 2 * cutoff * time - sinc = torch.where( - time == 0, - torch.ones_like(time), - torch.sin(math.pi * time) / math.pi / time, - ) - filter = 2 * cutoff * window * sinc - filter = filter / filter.sum() - return filter - - -class DownSample1d(nn.Module): - """1D low-pass filter for antialias downsampling.""" - - def __init__( - self, - ratio: int = 2, - kernel_size: int | None = None, - use_padding: bool = True, - padding_mode: str = "replicate", - persistent: bool = True, - ): - super().__init__() - self.ratio = ratio - self.kernel_size = kernel_size or int(6 * ratio // 2) * 2 - self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1 - self.pad_right = self.kernel_size // 2 - self.use_padding = use_padding - self.padding_mode = padding_mode - - cutoff = 0.5 / ratio - half_width = 0.6 / ratio - low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size) - self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x expected shape: [batch_size, num_channels, hidden_dim] - num_channels = x.shape[1] - if self.use_padding: - x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) - x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels) - return x_filtered - - -class UpSample1d(nn.Module): - def __init__( - self, - ratio: int = 2, - kernel_size: int | None = None, - window_type: str = "kaiser", - padding_mode: str = "replicate", - persistent: bool = True, - ): - super().__init__() - self.ratio = ratio - self.padding_mode = padding_mode - - if window_type == "hann": - rolloff = 0.99 - lowpass_filter_width = 6 - width = math.ceil(lowpass_filter_width / rolloff) - self.kernel_size = 2 * width * ratio + 1 - self.pad = width - self.pad_left = 2 * width * ratio - self.pad_right = self.kernel_size - ratio - - time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff - time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) - window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 - sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) - else: - # Kaiser sinc filter is BigVGAN default - self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size - self.pad = self.kernel_size // ratio - 1 - self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2 - self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2 - - sinc_filter = kaiser_sinc_filter1d( - cutoff=0.5 / ratio, - half_width=0.6 / ratio, - kernel_size=self.kernel_size, - ) - - self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x expected shape: [batch_size, num_channels, hidden_dim] - num_channels = x.shape[1] - x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode) - low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1) - x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels) - return x[..., self.pad_left : -self.pad_right] - - -class AntiAliasAct1d(nn.Module): - """ - Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples - to avoid aliasing. - """ - - def __init__( - self, - act_fn: str | nn.Module, - ratio: int = 2, - kernel_size: int = 12, - **kwargs, - ): - super().__init__() - self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size) - if isinstance(act_fn, str): - if act_fn == "snakebeta": - act_fn = SnakeBeta(**kwargs) - elif act_fn == "snake": - act_fn = SnakeBeta(**kwargs) - else: - act_fn = nn.LeakyReLU(**kwargs) - self.act = act_fn - self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.upsample(x) - x = self.act(x) - x = self.downsample(x) - return x - - -class SnakeBeta(nn.Module): - """ - Implements the Snake and SnakeBeta activations, which help with learning periodic patterns. - """ - - def __init__( - self, - channels: int, - alpha: float = 1.0, - eps: float = 1e-9, - trainable_params: bool = True, - logscale: bool = True, - use_beta: bool = True, - ): - super().__init__() - self.eps = eps - self.logscale = logscale - self.use_beta = use_beta - - self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) - self.alpha.requires_grad = trainable_params - if use_beta: - self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha) - self.beta.requires_grad = trainable_params - - def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: - broadcast_shape = [1] * hidden_states.ndim - broadcast_shape[channel_dim] = -1 - alpha = self.alpha.view(broadcast_shape) - if self.use_beta: - beta = self.beta.view(broadcast_shape) - - if self.logscale: - alpha = torch.exp(alpha) - if self.use_beta: - beta = torch.exp(beta) - - amplitude = beta if self.use_beta else alpha - hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2) - return hidden_states - - -class ResBlock(nn.Module): - def __init__( - self, - channels: int, - kernel_size: int = 3, - stride: int = 1, - dilations: tuple[int, ...] = (1, 3, 5), - act_fn: str = "leaky_relu", - leaky_relu_negative_slope: float = 0.1, - antialias: bool = False, - antialias_ratio: int = 2, - antialias_kernel_size: int = 12, - padding_mode: str = "same", - ): - super().__init__() - self.dilations = dilations - - self.convs1 = nn.ModuleList( - [ - nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) - for dilation in dilations - ] - ) - self.acts1 = nn.ModuleList() - for _ in range(len(self.convs1)): - if act_fn == "snakebeta": - act = SnakeBeta(channels, use_beta=True) - elif act_fn == "snake": - act = SnakeBeta(channels, use_beta=False) - else: - act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) - - if antialias: - act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) - self.acts1.append(act) - - self.convs2 = nn.ModuleList( - [ - nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) - for _ in range(len(dilations)) - ] - ) - self.acts2 = nn.ModuleList() - for _ in range(len(self.convs2)): - if act_fn == "snakebeta": - act = SnakeBeta(channels, use_beta=True) - elif act_fn == "snake": - act = SnakeBeta(channels, use_beta=False) - else: - act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) - - if antialias: - act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size) - self.acts2.append(act) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2): - xt = act1(x) - xt = conv1(xt) - xt = act2(xt) - xt = conv2(xt) - x = x + xt - return x - - -class LTX2Vocoder(ModelMixin, ConfigMixin): - r""" - LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 128, - hidden_channels: int = 1024, - out_channels: int = 2, - upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4], - upsample_factors: list[int] = [6, 5, 2, 2, 2], - resnet_kernel_sizes: list[int] = [3, 7, 11], - resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - act_fn: str = "leaky_relu", - leaky_relu_negative_slope: float = 0.1, - antialias: bool = False, - antialias_ratio: int = 2, - antialias_kernel_size: int = 12, - final_act_fn: str | None = "tanh", # tanh, clamp, None - final_bias: bool = True, - output_sampling_rate: int = 24000, - ): - super().__init__() - self.num_upsample_layers = len(upsample_kernel_sizes) - self.resnets_per_upsample = len(resnet_kernel_sizes) - self.out_channels = out_channels - self.total_upsample_factor = math.prod(upsample_factors) - self.act_fn = act_fn - self.negative_slope = leaky_relu_negative_slope - self.final_act_fn = final_act_fn - - if self.num_upsample_layers != len(upsample_factors): - raise ValueError( - f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" - f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." - ) - - if self.resnets_per_upsample != len(resnet_dilations): - raise ValueError( - f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" - f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." - ) - - supported_act_fns = ["snakebeta", "snake", "leaky_relu"] - if self.act_fn not in supported_act_fns: - raise ValueError( - f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are " - f"{supported_act_fns}." - ) - - self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) - - self.upsamplers = nn.ModuleList() - self.resnets = nn.ModuleList() - input_channels = hidden_channels - for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): - output_channels = input_channels // 2 - self.upsamplers.append( - nn.ConvTranspose1d( - input_channels, # hidden_channels // (2 ** i) - output_channels, # hidden_channels // (2 ** (i + 1)) - kernel_size, - stride=stride, - padding=(kernel_size - stride) // 2, - ) - ) - - for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): - self.resnets.append( - ResBlock( - channels=output_channels, - kernel_size=kernel_size, - dilations=dilations, - act_fn=act_fn, - leaky_relu_negative_slope=leaky_relu_negative_slope, - antialias=antialias, - antialias_ratio=antialias_ratio, - antialias_kernel_size=antialias_kernel_size, - ) - ) - input_channels = output_channels - - if act_fn == "snakebeta" or act_fn == "snake": - # Always use antialiasing - act_out = SnakeBeta(channels=output_channels, use_beta=True) - self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size) - elif act_fn == "leaky_relu": - # NOTE: does NOT use self.negative_slope, following the original code - self.act_out = nn.LeakyReLU() - - self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias) - - def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: - r""" - Forward pass of the vocoder. - - Args: - hidden_states (`torch.Tensor`): - Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` - is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is - `True`. - time_last (`bool`, *optional*, defaults to `False`): - Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. - - Returns: - `torch.Tensor`: - Audio waveform tensor of shape (batch_size, out_channels, audio_length) - """ - - # Ensure that the time/frame dimension is last - if not time_last: - hidden_states = hidden_states.transpose(2, 3) - # Combine channels and frequency (mel bins) dimensions - hidden_states = hidden_states.flatten(1, 2) - - hidden_states = self.conv_in(hidden_states) - - for i in range(self.num_upsample_layers): - if self.act_fn == "leaky_relu": - # Other activations are inside each upsampling block - hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) - hidden_states = self.upsamplers[i](hidden_states) - - # Run all resnets in parallel on hidden_states - start = i * self.resnets_per_upsample - end = (i + 1) * self.resnets_per_upsample - resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) - - hidden_states = torch.mean(resnet_outputs, dim=0) - - hidden_states = self.act_out(hidden_states) - hidden_states = self.conv_out(hidden_states) - if self.final_act_fn == "tanh": - hidden_states = torch.tanh(hidden_states) - elif self.final_act_fn == "clamp": - hidden_states = torch.clamp(hidden_states, -1, 1) - - return hidden_states - - -class CausalSTFT(nn.Module): - """ - Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases - multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact - buffers should be loaded from the checkpoint in bfloat16. - """ - - def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512): - super().__init__() - self.hop_length = hop_length - self.window_length = window_length - n_freqs = filter_length // 2 + 1 - - self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) - self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True) - - def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - if waveform.ndim == 2: - waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples] - - left_pad = max(0, self.window_length - self.hop_length) # causal: left-only - waveform = F.pad(waveform, (left_pad, 0)) - - spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0) - n_freqs = spec.shape[1] // 2 - real, imag = spec[:, :n_freqs], spec[:, n_freqs:] - magnitude = torch.sqrt(real**2 + imag**2) - phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype) - return magnitude, phase - - -class MelSTFT(nn.Module): - """ - Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be - loaded from the checkpoint in bfloat16. - """ - - def __init__( - self, - filter_length: int = 512, - hop_length: int = 80, - window_length: int = 512, - num_mel_channels: int = 64, - ): - super().__init__() - self.stft_fn = CausalSTFT(filter_length, hop_length, window_length) - - num_freqs = filter_length // 2 + 1 - self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True) - - def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - magnitude, phase = self.stft_fn(waveform) - energy = torch.norm(magnitude, dim=1) - mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) - log_mel = torch.log(torch.clamp(mel, min=1e-5)) - return log_mel, magnitude, phase, energy - - -class LTX2VocoderWithBWE(ModelMixin, ConfigMixin): - """ - LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the - BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same - architecture as the original vocoder. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 128, - hidden_channels: int = 1536, - out_channels: int = 2, - upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4], - upsample_factors: list[int] = [5, 2, 2, 2, 2, 2], - resnet_kernel_sizes: list[int] = [3, 7, 11], - resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - act_fn: str = "snakebeta", - leaky_relu_negative_slope: float = 0.1, - antialias: bool = True, - antialias_ratio: int = 2, - antialias_kernel_size: int = 12, - final_act_fn: str | None = None, - final_bias: bool = False, - bwe_in_channels: int = 128, - bwe_hidden_channels: int = 512, - bwe_out_channels: int = 2, - bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4], - bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2], - bwe_resnet_kernel_sizes: list[int] = [3, 7, 11], - bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - bwe_act_fn: str = "snakebeta", - bwe_leaky_relu_negative_slope: float = 0.1, - bwe_antialias: bool = True, - bwe_antialias_ratio: int = 2, - bwe_antialias_kernel_size: int = 12, - bwe_final_act_fn: str | None = None, - bwe_final_bias: bool = False, - filter_length: int = 512, - hop_length: int = 80, - window_length: int = 512, - num_mel_channels: int = 64, - input_sampling_rate: int = 16000, - output_sampling_rate: int = 48000, - ): - super().__init__() - - self.vocoder = LTX2Vocoder( - in_channels=in_channels, - hidden_channels=hidden_channels, - out_channels=out_channels, - upsample_kernel_sizes=upsample_kernel_sizes, - upsample_factors=upsample_factors, - resnet_kernel_sizes=resnet_kernel_sizes, - resnet_dilations=resnet_dilations, - act_fn=act_fn, - leaky_relu_negative_slope=leaky_relu_negative_slope, - antialias=antialias, - antialias_ratio=antialias_ratio, - antialias_kernel_size=antialias_kernel_size, - final_act_fn=final_act_fn, - final_bias=final_bias, - output_sampling_rate=input_sampling_rate, - ) - self.bwe_generator = LTX2Vocoder( - in_channels=bwe_in_channels, - hidden_channels=bwe_hidden_channels, - out_channels=bwe_out_channels, - upsample_kernel_sizes=bwe_upsample_kernel_sizes, - upsample_factors=bwe_upsample_factors, - resnet_kernel_sizes=bwe_resnet_kernel_sizes, - resnet_dilations=bwe_resnet_dilations, - act_fn=bwe_act_fn, - leaky_relu_negative_slope=bwe_leaky_relu_negative_slope, - antialias=bwe_antialias, - antialias_ratio=bwe_antialias_ratio, - antialias_kernel_size=bwe_antialias_kernel_size, - final_act_fn=bwe_final_act_fn, - final_bias=bwe_final_bias, - output_sampling_rate=output_sampling_rate, +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...models.autoencoders.vocoder_ltx2 import ( + AntiAliasAct1d, # noqa: F401 re-exported for back-compat + CausalSTFT, # noqa: F401 + DownSample1d, # noqa: F401 + MelSTFT, # noqa: F401 + ResBlock, # noqa: F401 + SnakeBeta, # noqa: F401 + UpSample1d, # noqa: F401 + kaiser_sinc_filter1d, # noqa: F401 +) +from ...models.autoencoders.vocoder_ltx2 import ( + LTX2Vocoder as _LTX2Vocoder, +) +from ...models.autoencoders.vocoder_ltx2 import ( + LTX2VocoderWithBWE as _LTX2VocoderWithBWE, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class LTX2Vocoder(_LTX2Vocoder): + def __new__(cls, *args, **kwargs): + deprecate( + "LTX2Vocoder", + "1.0.0", + "Importing `LTX2Vocoder` from `diffusers.pipelines.ltx2.vocoder` is deprecated. " + "Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import LTX2Vocoder`).", ) + return super().__new__(cls) - self.mel_stft = MelSTFT( - filter_length=filter_length, - hop_length=hop_length, - window_length=window_length, - num_mel_channels=num_mel_channels, - ) - self.resampler = UpSample1d( - ratio=output_sampling_rate // input_sampling_rate, - window_type="hann", - persistent=False, +class LTX2VocoderWithBWE(_LTX2VocoderWithBWE): + def __new__(cls, *args, **kwargs): + deprecate( + "LTX2VocoderWithBWE", + "1.0.0", + "Importing `LTX2VocoderWithBWE` from `diffusers.pipelines.ltx2.vocoder` is deprecated. " + "Import it from `diffusers.models.autoencoders` instead " + "(or `from diffusers import LTX2VocoderWithBWE`).", ) - - def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: - # 1. Run stage 1 vocoder to get low sampling rate waveform - x = self.vocoder(mel_spec) - batch_size, num_channels, num_samples = x.shape - - # Pad to exact multiple of hop_length for exact mel frame count - remainder = num_samples % self.config.hop_length - if remainder != 0: - x = F.pad(x, (0, self.hop_length - remainder)) - - # 2. Compute mel spectrogram on vocoder output - mel, _, _, _ = self.mel_stft(x.flatten(0, 1)) - mel = mel.unflatten(0, (-1, num_channels)) - - # 3. Run bandwidth extender (BWE) on new mel spectrogram - mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins] - residual = self.bwe_generator(mel_for_bwe) - - # 4. Residual connection with resampler - skip = self.resampler(x) - waveform = torch.clamp(residual + skip, -1, 1) - output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate - waveform = waveform[..., :output_samples] - return waveform + return super().__new__(cls) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d695f5e7284d..279409059489 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -388,10 +388,26 @@ def maybe_raise_or_warn( ) +# Reroutes pretrained loads past pipeline-local deprecation shims onto the canonical +# top-level diffusers export. +_RELOCATED_PIPELINE_CLASSES: dict[tuple[str, str], tuple[str, str]] = { + ("audioldm2", "AudioLDM2ProjectionModel"): ("diffusers", "AudioLDM2ProjectionModel"), + ("audioldm2", "AudioLDM2UNet2DConditionModel"): ("diffusers", "AudioLDM2UNet2DConditionModel"), + ("stable_audio", "StableAudioProjectionModel"): ("diffusers", "StableAudioProjectionModel"), + ("deepfloyd_if", "IFWatermarker"): ("diffusers", "IFWatermarker"), + ("ltx2", "LTX2TextConnectors"): ("diffusers", "LTX2TextConnectors"), + ("ltx2", "LTX2Vocoder"): ("diffusers", "LTX2Vocoder"), +} + + # a simpler version of get_class_obj_and_candidates, it won't work with custom code def simple_get_class_obj(library_name, class_name): from diffusers import pipelines + remapped = _RELOCATED_PIPELINE_CLASSES.get((library_name, class_name)) + if remapped is not None: + library_name, class_name = remapped + is_pipeline_module = hasattr(pipelines, library_name) if is_pipeline_module: @@ -425,6 +441,11 @@ def get_class_obj_and_candidates( if class_name.startswith("FlashPack"): class_name = class_name.removeprefix("FlashPack") + remapped = _RELOCATED_PIPELINE_CLASSES.get((library_name, class_name)) + if remapped is not None: + library_name, class_name = remapped + is_pipeline_module = False + if is_pipeline_module: pipeline_module = getattr(pipelines, library_name) diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py index 3f43d0ebb0b9..465b6320349c 100644 --- a/src/diffusers/pipelines/qwenimage/__init__.py +++ b/src/diffusers/pipelines/qwenimage/__init__.py @@ -22,7 +22,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"] _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"] _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"] diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index eea83aff9e10..8b4b879f467b 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -21,6 +21,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer from ...models import PriorTransformer +from ...models.others import ShapERenderer from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, @@ -30,7 +31,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .renderer import ShapERenderer if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index f59fd298c684..8a9f7e88a4f2 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -20,6 +20,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModel from ...models import PriorTransformer +from ...models.others import ShapERenderer from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, @@ -29,7 +30,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .renderer import ShapERenderer if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py index 0c2058c887fc..84eda719e4b6 100644 --- a/src/diffusers/pipelines/shap_e/renderer.py +++ b/src/diffusers/pipelines/shap_e/renderer.py @@ -12,1038 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from dataclasses import dataclass - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models import ModelMixin -from ...utils import BaseOutput -from .camera import create_pan_cameras - - -def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: - r""" - Sample from the given discrete probability distribution with replacement. - - The i-th bin is assumed to have mass pmf[i]. - - Args: - pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() - n_samples: number of samples - - Return: - indices sampled with replacement - """ - - *shape, support_size, last_dim = pmf.shape - assert last_dim == 1 - - cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) - inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) - - return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) - - -def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: - """ - Concatenate x and its positional encodings, following NeRF. - - Reference: https://huggingface.co/papers/2210.04628 - """ - if min_deg == max_deg: - return x - - scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) - *shape, dim = x.shape - xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) - assert xb.shape[-1] == dim * (max_deg - min_deg) - emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() - return torch.cat([x, emb], dim=-1) - - -def encode_position(position): - return posenc_nerf(position, min_deg=0, max_deg=15) - - -def encode_direction(position, direction=None): - if direction is None: - return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) - else: - return posenc_nerf(direction, min_deg=0, max_deg=8) - - -def _sanitize_name(x: str) -> str: - return x.replace(".", "__") - - -def integrate_samples(volume_range, ts, density, channels): - r""" - Function integrating the model output. - - Args: - volume_range: Specifies the integral range [t0, t1] - ts: timesteps - density: torch.Tensor [batch_size, *shape, n_samples, 1] - channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] - returns: - channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density - *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume - ) - """ - - # 1. Calculate the weights - _, _, dt = volume_range.partition(ts) - ddensity = density * dt - - mass = torch.cumsum(ddensity, dim=-2) - transmittance = torch.exp(-mass[..., -1, :]) - - alphas = 1.0 - torch.exp(-ddensity) - Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) - # This is the probability of light hitting and reflecting off of - # something at depth [..., i, :]. - weights = alphas * Ts - - # 2. Integrate channels - channels = torch.sum(channels * weights, dim=-2) - - return channels, weights, transmittance - - -def volume_query_points(volume, grid_size): - indices = torch.arange(grid_size**3, device=volume.bbox_min.device) - zs = indices % grid_size - ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size - xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size - combined = torch.stack([xs, ys, zs], dim=1) - return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min - - -def _convert_srgb_to_linear(u: torch.Tensor): - return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) - - -def _create_flat_edge_indices( - flat_cube_indices: torch.Tensor, - grid_size: tuple[int, int, int], -): - num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] - y_offset = num_xs - num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] - z_offset = num_xs + num_ys - return torch.stack( - [ - # Edges spanning x-axis. - flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2], - flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] - + (flat_cube_indices[:, 1] + 1) * grid_size[2] - + flat_cube_indices[:, 2], - flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2] - + 1, - flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] - + (flat_cube_indices[:, 1] + 1) * grid_size[2] - + flat_cube_indices[:, 2] - + 1, - # Edges spanning y-axis. - ( - y_offset - + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2] - ), - ( - y_offset - + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2] - ), - ( - y_offset - + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2] - + 1 - ), - ( - y_offset - + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] - + flat_cube_indices[:, 1] * grid_size[2] - + flat_cube_indices[:, 2] - + 1 - ), - # Edges spanning z-axis. - ( - z_offset - + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) - + flat_cube_indices[:, 1] * (grid_size[2] - 1) - + flat_cube_indices[:, 2] - ), - ( - z_offset - + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) - + flat_cube_indices[:, 1] * (grid_size[2] - 1) - + flat_cube_indices[:, 2] - ), - ( - z_offset - + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) - + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) - + flat_cube_indices[:, 2] - ), - ( - z_offset - + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) - + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) - + flat_cube_indices[:, 2] - ), - ], - dim=-1, - ) - - -class VoidNeRFModel(nn.Module): - """ - Implements the default empty space model where all queries are rendered as background. - """ - - def __init__(self, background, channel_scale=255.0): - super().__init__() - background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) - - self.register_buffer("background", background) - - def forward(self, position): - background = self.background[None].to(position.device) - - shape = position.shape[:-1] - ones = [1] * (len(shape) - 1) - n_channels = background.shape[-1] - background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) - - return background - - -@dataclass -class VolumeRange: - t0: torch.Tensor - t1: torch.Tensor - intersected: torch.Tensor - - def __post_init__(self): - assert self.t0.shape == self.t1.shape == self.intersected.shape - - def partition(self, ts): - """ - Partitions t0 and t1 into n_samples intervals. - - Args: - ts: [batch_size, *shape, n_samples, 1] - - Return: - - lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, - *shape, n_samples, 1] - - where - ts \\in [lower, upper] deltas = upper - lower - """ - - mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 - lower = torch.cat([self.t0[..., None, :], mids], dim=-2) - upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) - delta = upper - lower - assert lower.shape == upper.shape == delta.shape == ts.shape - return lower, upper, delta - - -class BoundingBoxVolume(nn.Module): - """ - Axis-aligned bounding box defined by the two opposite corners. - """ - - def __init__( - self, - *, - bbox_min, - bbox_max, - min_dist: float = 0.0, - min_t_range: float = 1e-3, - ): - """ - Args: - bbox_min: the left/bottommost corner of the bounding box - bbox_max: the other corner of the bounding box - min_dist: all rays should start at least this distance away from the origin. - """ - super().__init__() - - self.min_dist = min_dist - self.min_t_range = min_t_range - - self.bbox_min = torch.tensor(bbox_min) - self.bbox_max = torch.tensor(bbox_max) - self.bbox = torch.stack([self.bbox_min, self.bbox_max]) - assert self.bbox.shape == (2, 3) - assert min_dist >= 0.0 - assert min_t_range > 0.0 - - def intersect( - self, - origin: torch.Tensor, - direction: torch.Tensor, - t0_lower: torch.Tensor | None = None, - epsilon=1e-6, - ): - """ - Args: - origin: [batch_size, *shape, 3] - direction: [batch_size, *shape, 3] - t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. - params: Optional meta parameters in case Volume is parametric - epsilon: to stabilize calculations - - Return: - A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with - the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to - be on the boundary of the volume. - """ - - batch_size, *shape, _ = origin.shape - ones = [1] * len(shape) - bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) - - def _safe_divide(a, b, epsilon=1e-6): - return a / torch.where(b < 0, b - epsilon, b + epsilon) - - ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) - - # Cases to think about: - # - # 1. t1 <= t0: the ray does not pass through the AABB. - # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. - # 3. t0 <= 0 <= t1: the ray starts from inside the BB - # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. - # - # 1 and 4 are clearly handled from t0 < t1 below. - # Making t0 at least min_dist (>= 0) takes care of 2 and 3. - t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) - t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values - assert t0.shape == t1.shape == (batch_size, *shape, 1) - if t0_lower is not None: - assert t0.shape == t0_lower.shape - t0 = torch.maximum(t0, t0_lower) - - intersected = t0 + self.min_t_range < t1 - t0 = torch.where(intersected, t0, torch.zeros_like(t0)) - t1 = torch.where(intersected, t1, torch.ones_like(t1)) - - return VolumeRange(t0=t0, t1=t1, intersected=intersected) - - -class StratifiedRaySampler(nn.Module): - """ - Instead of fixed intervals, a sample is drawn uniformly at random from each interval. - """ - - def __init__(self, depth_mode: str = "linear"): - """ - :param depth_mode: linear samples ts linearly in depth. harmonic ensures - closer points are sampled more densely. - """ - self.depth_mode = depth_mode - assert self.depth_mode in ("linear", "geometric", "harmonic") - - def sample( - self, - t0: torch.Tensor, - t1: torch.Tensor, - n_samples: int, - epsilon: float = 1e-3, - ) -> torch.Tensor: - """ - Args: - t0: start time has shape [batch_size, *shape, 1] - t1: finish time has shape [batch_size, *shape, 1] - n_samples: number of ts to sample - Return: - sampled ts of shape [batch_size, *shape, n_samples, 1] - """ - ones = [1] * (len(t0.shape) - 1) - ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) - - if self.depth_mode == "linear": - ts = t0 * (1.0 - ts) + t1 * ts - elif self.depth_mode == "geometric": - ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() - elif self.depth_mode == "harmonic": - # The original NeRF recommends this interpolation scheme for - # spherical scenes, but there could be some weird edge cases when - # the observer crosses from the inner to outer volume. - ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) - - mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) - upper = torch.cat([mids, t1], dim=-1) - lower = torch.cat([t0, mids], dim=-1) - # yiyi notes: add a random seed here for testing, don't forget to remove - torch.manual_seed(0) - t_rand = torch.rand_like(ts) - - ts = lower + (upper - lower) * t_rand - return ts.unsqueeze(-1) - - -class ImportanceRaySampler(nn.Module): - """ - Given the initial estimate of densities, this samples more from regions/bins expected to have objects. - """ - - def __init__( - self, - volume_range: VolumeRange, - ts: torch.Tensor, - weights: torch.Tensor, - blur_pool: bool = False, - alpha: float = 1e-5, - ): - """ - Args: - volume_range: the range in which a ray intersects the given volume. - ts: earlier samples from the coarse rendering step - weights: discretized version of density * transmittance - blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. - alpha: small value to add to weights. - """ - self.volume_range = volume_range - self.ts = ts.clone().detach() - self.weights = weights.clone().detach() - self.blur_pool = blur_pool - self.alpha = alpha - - @torch.no_grad() - def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: - """ - Args: - t0: start time has shape [batch_size, *shape, 1] - t1: finish time has shape [batch_size, *shape, 1] - n_samples: number of ts to sample - Return: - sampled ts of shape [batch_size, *shape, n_samples, 1] - """ - lower, upper, _ = self.volume_range.partition(self.ts) - - batch_size, *shape, n_coarse_samples, _ = self.ts.shape - - weights = self.weights - if self.blur_pool: - padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) - maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) - weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) - weights = weights + self.alpha - pmf = weights / weights.sum(dim=-2, keepdim=True) - inds = sample_pmf(pmf, n_samples) - assert inds.shape == (batch_size, *shape, n_samples, 1) - assert (inds >= 0).all() and (inds < n_coarse_samples).all() - - t_rand = torch.rand(inds.shape, device=inds.device) - lower_ = torch.gather(lower, -2, inds) - upper_ = torch.gather(upper, -2, inds) - - ts = lower_ + (upper_ - lower_) * t_rand - ts = torch.sort(ts, dim=-2).values - return ts - - -@dataclass -class MeshDecoderOutput(BaseOutput): - """ - A 3D triangle mesh with optional data at the vertices and faces. - - Args: - verts (`torch.Tensor` of shape `(N, 3)`): - array of vertext coordinates - faces (`torch.Tensor` of shape `(N, 3)`): - array of triangles, pointing to indices in verts. - vertext_channels (Dict): - vertext coordinates for each color channel - """ - - verts: torch.Tensor - faces: torch.Tensor - vertex_channels: dict[str, torch.Tensor] - - -class MeshDecoder(nn.Module): - """ - Construct meshes from Signed distance functions (SDFs) using marching cubes method - """ - - def __init__(self): - super().__init__() - cases = torch.zeros(256, 5, 3, dtype=torch.long) - masks = torch.zeros(256, 5, dtype=torch.bool) - - self.register_buffer("cases", cases) - self.register_buffer("masks", masks) - - def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): - """ - For a signed distance field, produce a mesh using marching cubes. - - :param field: a 3D tensor of field values, where negative values correspond - to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. - :param min_point: a tensor of shape [3] containing the point corresponding - to (0, 0, 0) in the field. - :param size: a tensor of shape [3] containing the per-axis distance from the - (0, 0, 0) field corner and the (-1, -1, -1) field corner. - """ - assert len(field.shape) == 3, "input must be a 3D scalar field" - dev = field.device - - cases = self.cases.to(dev) - masks = self.masks.to(dev) - - min_point = min_point.to(dev) - size = size.to(dev) - - grid_size = field.shape - grid_size_tensor = torch.tensor(grid_size).to(size) - - # Create bitmasks between 0 and 255 (inclusive) indicating the state - # of the eight corners of each cube. - bitmasks = (field > 0).to(torch.uint8) - bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) - bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) - bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) - - # Compute corner coordinates across the entire grid. - corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) - corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[ - :, None, None - ] - corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[ - :, None - ] - corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype) - - # Compute all vertices across all edges in the grid, even though we will - # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. - # These are all midpoints, and don't account for interpolation (which is - # done later based on the used edge midpoints). - edge_midpoints = torch.cat( - [ - ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), - ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), - ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), - ], - dim=0, - ) - - # Create a flat array of [X, Y, Z] indices for each cube. - cube_indices = torch.zeros( - grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long - ) - cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None] - cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None] - cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) - flat_cube_indices = cube_indices.reshape(-1, 3) - - # Create a flat array mapping each cube to 12 global edge indices. - edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) - - # Apply the LUT to figure out the triangles. - flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask - local_tris = cases[flat_bitmasks] - local_masks = masks[flat_bitmasks] - # Compute the global edge indices for the triangles. - global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( - local_tris.shape - ) - # Select the used triangles for each cube. - selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] - - # Now we have a bunch of indices into the full list of possible vertices, - # but we want to reduce this list to only the used vertices. - used_vertex_indices = torch.unique(selected_tris.view(-1)) - used_edge_midpoints = edge_midpoints[used_vertex_indices] - old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) - old_index_to_new_index[used_vertex_indices] = torch.arange( - len(used_vertex_indices), device=dev, dtype=torch.long - ) - - # Rewrite the triangles to use the new indices - faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape) - - # Compute the actual interpolated coordinates corresponding to edge midpoints. - v1 = torch.floor(used_edge_midpoints).to(torch.long) - v2 = torch.ceil(used_edge_midpoints).to(torch.long) - s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] - s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] - p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point - p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point - # The signs of s1 and s2 should be different. We want to find - # t such that t*s2 + (1-t)*s1 = 0. - t = (s1 / (s1 - s2))[:, None] - verts = t * p2 + (1 - t) * p1 - - return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None) - - -@dataclass -class MLPNeRFModelOutput(BaseOutput): - density: torch.Tensor - signed_distance: torch.Tensor - channels: torch.Tensor - ts: torch.Tensor - - -class MLPNeRSTFModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - d_hidden: int = 256, - n_output: int = 12, - n_hidden_layers: int = 6, - act_fn: str = "swish", - insert_direction_at: int = 4, - ): - super().__init__() - - # Instantiate the MLP - - # Find out the dimension of encoded position and direction - dummy = torch.eye(1, 3) - d_posenc_pos = encode_position(position=dummy).shape[-1] - d_posenc_dir = encode_direction(position=dummy).shape[-1] - - mlp_widths = [d_hidden] * n_hidden_layers - input_widths = [d_posenc_pos] + mlp_widths - output_widths = mlp_widths + [n_output] - - if insert_direction_at is not None: - input_widths[insert_direction_at] += d_posenc_dir - - self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) - - if act_fn == "swish": - # self.activation = swish - # yiyi testing: - self.activation = lambda x: F.silu(x) - else: - raise ValueError(f"Unsupported activation function {act_fn}") - - self.sdf_activation = torch.tanh - self.density_activation = torch.nn.functional.relu - self.channel_activation = torch.sigmoid - - def map_indices_to_keys(self, output): - h_map = { - "sdf": (0, 1), - "density_coarse": (1, 2), - "density_fine": (2, 3), - "stf": (3, 6), - "nerf_coarse": (6, 9), - "nerf_fine": (9, 12), - } - - mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} - - return mapped_output - - def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"): - h = encode_position(position) - - h_preact = h - h_directionless = None - for i, layer in enumerate(self.mlp): - if i == self.config.insert_direction_at: # 4 in the config - h_directionless = h_preact - h_direction = encode_direction(position, direction=direction) - h = torch.cat([h, h_direction], dim=-1) - - h = layer(h) - - h_preact = h - - if i < len(self.mlp) - 1: - h = self.activation(h) - - h_final = h - if h_directionless is None: - h_directionless = h_preact - - activation = self.map_indices_to_keys(h_final) - - if nerf_level == "coarse": - h_density = activation["density_coarse"] - else: - h_density = activation["density_fine"] - - if rendering_mode == "nerf": - if nerf_level == "coarse": - h_channels = activation["nerf_coarse"] - else: - h_channels = activation["nerf_fine"] - - elif rendering_mode == "stf": - h_channels = activation["stf"] - - density = self.density_activation(h_density) - signed_distance = self.sdf_activation(activation["sdf"]) - channels = self.channel_activation(h_channels) - - # yiyi notes: I think signed_distance is not used - return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) - - -class ChannelsProj(nn.Module): - def __init__( - self, - *, - vectors: int, - channels: int, - d_latent: int, - ): - super().__init__() - self.proj = nn.Linear(d_latent, vectors * channels) - self.norm = nn.LayerNorm(channels) - self.d_latent = d_latent - self.vectors = vectors - self.channels = channels - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_bvd = x - w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) - b_vc = self.proj.bias.view(1, self.vectors, self.channels) - h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) - h = self.norm(h) - - h = h + b_vc - return h - - -class ShapEParamsProjModel(ModelMixin, ConfigMixin): - """ - project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). - - For more details, see the original paper: - """ - - @register_to_config - def __init__( - self, - *, - param_names: tuple[str] = ( - "nerstf.mlp.0.weight", - "nerstf.mlp.1.weight", - "nerstf.mlp.2.weight", - "nerstf.mlp.3.weight", - ), - param_shapes: tuple[tuple[int]] = ( - (256, 93), - (256, 256), - (256, 256), - (256, 256), - ), - d_latent: int = 1024, - ): - super().__init__() - - # check inputs - if len(param_names) != len(param_shapes): - raise ValueError("Must provide same number of `param_names` as `param_shapes`") - self.projections = nn.ModuleDict({}) - for k, (vectors, channels) in zip(param_names, param_shapes): - self.projections[_sanitize_name(k)] = ChannelsProj( - vectors=vectors, - channels=channels, - d_latent=d_latent, - ) - - def forward(self, x: torch.Tensor): - out = {} - start = 0 - for k, shape in zip(self.config.param_names, self.config.param_shapes): - vectors, _ = shape - end = start + vectors - x_bvd = x[:, start:end] - out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) - start = end - return out - - -class ShapERenderer(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - *, - param_names: tuple[str] = ( - "nerstf.mlp.0.weight", - "nerstf.mlp.1.weight", - "nerstf.mlp.2.weight", - "nerstf.mlp.3.weight", - ), - param_shapes: tuple[tuple[int]] = ( - (256, 93), - (256, 256), - (256, 256), - (256, 256), - ), - d_latent: int = 1024, - d_hidden: int = 256, - n_output: int = 12, - n_hidden_layers: int = 6, - act_fn: str = "swish", - insert_direction_at: int = 4, - background: tuple[float] = ( - 255.0, - 255.0, - 255.0, - ), - ): - super().__init__() - - self.params_proj = ShapEParamsProjModel( - param_names=param_names, - param_shapes=param_shapes, - d_latent=d_latent, - ) - self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) - self.void = VoidNeRFModel(background=background, channel_scale=255.0) - self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) - self.mesh_decoder = MeshDecoder() - - @torch.no_grad() - def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): - """ - Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below - with some abuse of notations) - - C(r) := sum( - transmittance(t[i]) * integrate( - lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], - ) for i in range(len(parts)) - ) + transmittance(t[-1]) * void_model(t[-1]).channels - - where - - 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through - the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are - obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t - where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the - shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and - transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], - math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). - - Args: - rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: - number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including - - :return: A tuple of - - `channels` - - A importance samplers for additional fine-grained rendering - - raw model output - """ - origin, direction = rays[..., 0, :], rays[..., 1, :] - - # Integrate over [t[i], t[i + 1]] - - # 1 Intersect the rays with the current volume and sample ts to integrate along. - vrange = self.volume.intersect(origin, direction, t0_lower=None) - ts = sampler.sample(vrange.t0, vrange.t1, n_samples) - ts = ts.to(rays.dtype) - - if prev_model_out is not None: - # Append the previous ts now before fprop because previous - # rendering used a different model and we can't reuse the output. - ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values - - batch_size, *_shape, _t0_dim = vrange.t0.shape - _, *ts_shape, _ts_dim = ts.shape - - # 2. Get the points along the ray and query the model - directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) - positions = origin.unsqueeze(-2) + ts * directions - - directions = directions.to(self.mlp.dtype) - positions = positions.to(self.mlp.dtype) - - optional_directions = directions if render_with_direction else None - - model_out = self.mlp( - position=positions, - direction=optional_directions, - ts=ts, - nerf_level="coarse" if prev_model_out is None else "fine", - ) - - # 3. Integrate the model results - channels, weights, transmittance = integrate_samples( - vrange, model_out.ts, model_out.density, model_out.channels - ) - - # 4. Clean up results that do not intersect with the volume. - transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) - channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) - # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). - channels = channels + transmittance * self.void(origin) - - weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) - - return channels, weighted_sampler, model_out - - @torch.no_grad() - def decode_to_image( - self, - latents, - device, - size: int = 64, - ray_batch_size: int = 4096, - n_coarse_samples=64, - n_fine_samples=128, - ): - # project the parameters from the generated latents - projected_params = self.params_proj(latents) - - # update the mlp layers of the renderer - for name, param in self.mlp.state_dict().items(): - if f"nerstf.{name}" in projected_params.keys(): - param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) - - # create cameras object - camera = create_pan_cameras(size) - rays = camera.camera_rays - rays = rays.to(device) - n_batches = rays.shape[1] // ray_batch_size - - coarse_sampler = StratifiedRaySampler() - - images = [] - - for idx in range(n_batches): - rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] - - # render rays with coarse, stratified samples. - _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) - # Then, render with additional importance-weighted ray samples. - channels, _, _ = self.render_rays( - rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out - ) - - images.append(channels) - - images = torch.cat(images, dim=1) - images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) - - return images - - @torch.no_grad() - def decode_to_mesh( - self, - latents, - device, - grid_size: int = 128, - query_batch_size: int = 4096, - texture_channels: tuple = ("R", "G", "B"), - ): - # 1. project the parameters from the generated latents - projected_params = self.params_proj(latents) - - # 2. update the mlp layers of the renderer - for name, param in self.mlp.state_dict().items(): - if f"nerstf.{name}" in projected_params.keys(): - param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) - - # 3. decoding with STF rendering - # 3.1 query the SDF values at vertices along a regular 128**3 grid - - query_points = volume_query_points(self.volume, grid_size) - query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype) - - fields = [] - - for idx in range(0, query_positions.shape[1], query_batch_size): - query_batch = query_positions[:, idx : idx + query_batch_size] - - model_out = self.mlp( - position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" - ) - fields.append(model_out.signed_distance) - - # predicted SDF values - fields = torch.cat(fields, dim=1) - fields = fields.float() - - assert len(fields.shape) == 3 and fields.shape[-1] == 1, ( - f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" - ) - - fields = fields.reshape(1, *([grid_size] * 3)) - - # create grid 128 x 128 x 128 - # - force a negative border around the SDFs to close off all the models. - full_grid = torch.zeros( - 1, - grid_size + 2, - grid_size + 2, - grid_size + 2, - device=fields.device, - dtype=fields.dtype, - ) - full_grid.fill_(-1.0) - full_grid[:, 1:-1, 1:-1, 1:-1] = fields - fields = full_grid - - # apply a differentiable implementation of Marching Cubes to construct meshs - raw_meshes = [] - mesh_mask = [] - - for field in fields: - raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min) - mesh_mask.append(True) - raw_meshes.append(raw_mesh) - - mesh_mask = torch.tensor(mesh_mask, device=fields.device) - max_vertices = max(len(m.verts) for m in raw_meshes) - - # 3.2. query the texture color head at each vertex of the resulting mesh. - texture_query_positions = torch.stack( - [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], - dim=0, - ) - texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype) - - textures = [] - - for idx in range(0, texture_query_positions.shape[1], query_batch_size): - query_batch = texture_query_positions[:, idx : idx + query_batch_size] - - texture_model_out = self.mlp( - position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" - ) - textures.append(texture_model_out.channels) - - # predict texture color - textures = torch.cat(textures, dim=1) - - textures = _convert_srgb_to_linear(textures) - textures = textures.float() - - # 3.3 augment the mesh with texture data - assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), ( - f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" - ) - - for m, texture in zip(raw_meshes, textures): - texture = texture[: len(m.verts)] - m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1))) - - return raw_meshes[0] +from ...models.others.renderer_shap_e import ( + BoundingBoxVolume, # noqa: F401 re-exported for back-compat + ImportanceRaySampler, # noqa: F401 + MLPNeRFModelOutput, # noqa: F401 + MLPNeRSTFModel, # noqa: F401 + ShapEParamsProjModel, # noqa: F401 + StratifiedRaySampler, # noqa: F401 + VoidNeRFModel, # noqa: F401 +) +from ...models.others.renderer_shap_e import ( + ShapERenderer as _ShapERenderer, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class ShapERenderer(_ShapERenderer): + def __new__(cls, *args, **kwargs): + deprecate( + "ShapERenderer", + "1.0.0", + "Importing `ShapERenderer` from `diffusers.pipelines.shap_e.renderer` is deprecated. " + "Import it from `diffusers.models.others` instead.", + ) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py index d40269411bc0..4cfb301bcbc1 100644 --- a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py @@ -12,145 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from math import pi - -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin -from ...utils import BaseOutput, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableAudioPositionalEmbedding(nn.Module): - """Used for continuous time""" - - def __init__(self, dim: int): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim)) - - def forward(self, times: torch.Tensor) -> torch.Tensor: - times = times[..., None] - freqs = times * self.weights[None] * 2 * pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) - fouriered = torch.cat((times, fouriered), dim=-1) - return fouriered - - -@dataclass -class StableAudioProjectionModelOutput(BaseOutput): - """ - Args: - Class for StableAudio projection layer's outputs. - text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. - seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): - Sequence of hidden-states obtained by linearly projecting the audio start hidden states. - seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): - Sequence of hidden-states obtained by linearly projecting the audio end hidden states. - """ - - text_hidden_states: torch.Tensor | None = None - seconds_start_hidden_states: torch.Tensor | None = None - seconds_end_hidden_states: torch.Tensor | None = None - - -class StableAudioNumberConditioner(nn.Module): - """ - A simple linear projection model to map numbers to a latent space. - - Args: - number_embedding_dim (`int`): - Dimensionality of the number embeddings. - min_value (`int`): - The minimum value of the seconds number conditioning modules. - max_value (`int`): - The maximum value of the seconds number conditioning modules - internal_dim (`int`): - Dimensionality of the intermediate number hidden states. - """ - - def __init__( - self, - number_embedding_dim, - min_value, - max_value, - internal_dim: int | None = 256, - ): - super().__init__() - self.time_positional_embedding = nn.Sequential( - StableAudioPositionalEmbedding(internal_dim), - nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), - ) - - self.number_embedding_dim = number_embedding_dim - self.min_value = min_value - self.max_value = max_value - - def forward( - self, - floats: torch.Tensor, - ): - floats = floats.clamp(self.min_value, self.max_value) - - normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) - - # Cast floats to same type as embedder - embedder_dtype = next(self.time_positional_embedding.parameters()).dtype - normalized_floats = normalized_floats.to(embedder_dtype) - - embedding = self.time_positional_embedding(normalized_floats) - float_embeds = embedding.view(-1, 1, self.number_embedding_dim) - - return float_embeds - - -class StableAudioProjectionModel(ModelMixin, ConfigMixin): - """ - A simple linear projection model to map the conditioning values to a shared latent space. - - Args: - text_encoder_dim (`int`): - Dimensionality of the text embeddings from the text encoder (T5). - conditioning_dim (`int`): - Dimensionality of the output conditioning tensors. - min_value (`int`): - The minimum value of the seconds number conditioning modules. - max_value (`int`): - The maximum value of the seconds number conditioning modules - """ - - @register_to_config - def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): - super().__init__() - self.text_projection = ( - nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) - ) - self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) - self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) - - def forward( - self, - text_hidden_states: torch.Tensor | None = None, - start_seconds: torch.Tensor | None = None, - end_seconds: torch.Tensor | None = None, - ): - text_hidden_states = ( - text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) - ) - seconds_start_hidden_states = ( - start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) - ) - seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) - - return StableAudioProjectionModelOutput( - text_hidden_states=text_hidden_states, - seconds_start_hidden_states=seconds_start_hidden_states, - seconds_end_hidden_states=seconds_end_hidden_states, +from ...models.condition_embedders.projection_stable_audio import ( + StableAudioNumberConditioner, # noqa: F401 re-exported for back-compat + StableAudioPositionalEmbedding, # noqa: F401 re-exported for back-compat + StableAudioProjectionModelOutput, # noqa: F401 re-exported for back-compat +) +from ...models.condition_embedders.projection_stable_audio import ( + StableAudioProjectionModel as _StableAudioProjectionModel, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class StableAudioProjectionModel(_StableAudioProjectionModel): + def __new__(cls, *args, **kwargs): + deprecate( + "StableAudioProjectionModel", + "1.0.0", + "Importing `StableAudioProjectionModel` from `diffusers.pipelines.stable_audio.modeling_stable_audio` is " + "deprecated. Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import StableAudioProjectionModel`).", ) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index 351c8b65de0e..4f69d718f52e 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -23,12 +23,12 @@ ) from ...models import AutoencoderOobleck, StableAudioDiTModel +from ...models.condition_embedders import StableAudioProjectionModel from ...models.embeddings import get_1d_rotary_pos_embed from ...schedulers import EDMDPMSolverMultistepScheduler from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline -from .modeling_stable_audio import StableAudioProjectionModel if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py b/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py index 30dd90242d07..ae7b5b68269b 100644 --- a/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py +++ b/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torch import nn +from ...models.condition_embedders.projection_clip_image import CLIPImageProjection as _CLIPImageProjection +from ...utils import deprecate -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - -class CLIPImageProjection(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, hidden_size: int = 768): - super().__init__() - self.hidden_size = hidden_size - self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - - def forward(self, x): - return self.project(x) +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class CLIPImageProjection(_CLIPImageProjection): + def __new__(cls, *args, **kwargs): + deprecate( + "CLIPImageProjection", + "1.0.0", + "Importing `CLIPImageProjection` from `diffusers.pipelines.stable_diffusion.clip_image_project_model` is " + "deprecated. Import it from `diffusers.models.condition_embedders` instead " + "(or `from diffusers import CLIPImageProjection`).", + ) + return super().__new__(cls) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 0c8fd842fcba..e6059ce40e23 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -39,6 +39,7 @@ PriorTransformer, UNet2DConditionModel, ) +from ...models.others import StableUnCLIPImageNormalizer from ...schedulers import ( DDIMScheduler, DDPMScheduler, @@ -57,7 +58,6 @@ from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..pipeline_utils import DiffusionPipeline from .safety_checker import StableDiffusionSafetyChecker -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer if is_accelerate_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 7015e9727ea5..ce39a540b889 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -24,6 +24,7 @@ from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...models.lora import adjust_lora_scale_text_encoder +from ...models.others import StableUnCLIPImageNormalizer from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, @@ -36,7 +37,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index bb96e5db0295..b361701a6ce7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -24,6 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...models.lora import adjust_lora_scale_text_encoder +from ...models.others import StableUnCLIPImageNormalizer from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, @@ -36,7 +37,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer if is_torch_xla_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py index ba91a0f23923..c317b871021c 100644 --- a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +++ b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -12,44 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - - -class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): - """ - This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. - - It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image - embeddings. - """ - - @register_to_config - def __init__( - self, - embedding_dim: int = 768, - ): - super().__init__() - - self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.std = nn.Parameter(torch.ones(1, embedding_dim)) - - def to( - self, - torch_device: str | torch.device | None = None, - torch_dtype: torch.dtype | None = None, - ): - self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) - self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) - return self - - def scale(self, embeds): - embeds = (embeds - self.mean) * 1.0 / self.std - return embeds - - def unscale(self, embeds): - embeds = (embeds * self.std) + self.mean - return embeds +from ...models.others.image_normalizer_stable_unclip import ( + StableUnCLIPImageNormalizer as _StableUnCLIPImageNormalizer, +) +from ...utils import deprecate + + +# The deprecation warning is emitted from ``__new__`` rather than ``__init__`` so the shim does not +# override the parent's ``__init__`` signature — ``ConfigMixin.extract_init_dict`` reflects on +# ``inspect.signature(cls.__init__)`` to decide which saved config keys to forward at +# ``from_pretrained`` time, and an ``__init__(self, *args, **kwargs)`` override would erase them all. +class StableUnCLIPImageNormalizer(_StableUnCLIPImageNormalizer): + def __new__(cls, *args, **kwargs): + deprecate( + "StableUnCLIPImageNormalizer", + "1.0.0", + "Importing `StableUnCLIPImageNormalizer` from `diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer` " + "is deprecated. Import it from `diffusers.models.others` instead.", + ) + return super().__new__(cls) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0ce20a4f7d97..103f0475c331 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -405,6 +405,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AceStepAudioTokenDetokenizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AceStepAudioTokenizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AceStepConditionEncoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AceStepTransformer1DModel(metaclass=DummyObject): _backends = ["torch"] @@ -465,6 +510,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AudioLDM2ProjectionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AudioLDM2UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -930,6 +1005,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CLIPImageProjection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1365,6 +1455,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class IFWatermarker(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class JoyImageEditTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1470,6 +1575,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTX2LatentUpsamplerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LTX2TextConnectors(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTX2VideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1485,6 +1620,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTX2Vocoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LTX2VocoderWithBWE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LTXLatentUpsamplerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1770,6 +1950,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ReduxImageEncoder(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SanaControlNetModel(metaclass=DummyObject): _backends = ["torch"] @@ -1860,6 +2055,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ShapERenderer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SkyReelsV2Transformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1905,6 +2115,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class StableAudioProjectionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class StableUnCLIPImageNormalizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] @@ -2462,21 +2702,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class CLIPImageProjection(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class ConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e9bb67a768a..545bc78045ec 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -662,51 +662,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class AceStepAudioTokenDetokenizer(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class AceStepAudioTokenizer(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class AceStepConditionEncoder(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class AceStepPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -932,36 +887,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class AudioLDM2ProjectionModel(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class AudioLDM2UNet2DConditionModel(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class AudioLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1097,21 +1022,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CLIPImageProjection(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class CogVideoXFunControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -3167,21 +3077,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class ReduxImageEncoder(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class SanaControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -3422,21 +3317,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableAudioProjectionModel(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableCascadeCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/test_relocated_class_loading.py b/tests/models/test_relocated_class_loading.py new file mode 100644 index 000000000000..003850bef319 --- /dev/null +++ b/tests/models/test_relocated_class_loading.py @@ -0,0 +1,78 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import warnings + +import pytest +import torch + + +# (deprecated module path, class name, kwargs). Most ctors have full defaults so ``kwargs`` is +# empty; only ``AudioLDM2ProjectionModel`` and ``StableAudioProjectionModel`` have required +# positional args, so we pass the smallest values that satisfy the signature. +DEPRECATED_PATHS = [ + ("diffusers.pipelines.stable_diffusion.clip_image_project_model", "CLIPImageProjection", {}), + ( + "diffusers.pipelines.audioldm2.modeling_audioldm2", + "AudioLDM2ProjectionModel", + {"text_encoder_dim": 8, "text_encoder_1_dim": 8, "langauge_model_dim": 16}, + ), + ("diffusers.pipelines.audioldm2.modeling_audioldm2", "AudioLDM2UNet2DConditionModel", {}), + ( + "diffusers.pipelines.stable_audio.modeling_stable_audio", + "StableAudioProjectionModel", + {"text_encoder_dim": 8, "conditioning_dim": 8, "min_value": 0, "max_value": 10}, + ), + ("diffusers.pipelines.flux.modeling_flux", "ReduxImageEncoder", {}), + ("diffusers.pipelines.ltx.modeling_latent_upsampler", "LTXLatentUpsamplerModel", {}), + ("diffusers.pipelines.ltx2.latent_upsampler", "LTX2LatentUpsamplerModel", {}), + ("diffusers.pipelines.ltx2.vocoder", "LTX2Vocoder", {}), + ("diffusers.pipelines.ltx2.vocoder", "LTX2VocoderWithBWE", {}), + ("diffusers.pipelines.ltx2.connectors", "LTX2TextConnectors", {}), + ("diffusers.pipelines.ace_step.modeling_ace_step", "AceStepAudioTokenizer", {}), + ("diffusers.pipelines.ace_step.modeling_ace_step", "AceStepAudioTokenDetokenizer", {}), + ("diffusers.pipelines.ace_step.modeling_ace_step", "AceStepConditionEncoder", {}), + ("diffusers.pipelines.ace_step.modeling_ace_step", "AceStepLyricEncoder", {}), + ("diffusers.pipelines.ace_step.modeling_ace_step", "AceStepTimbreEncoder", {}), + ("diffusers.pipelines.shap_e.renderer", "ShapERenderer", {}), + ("diffusers.pipelines.deepfloyd_if.watermark", "IFWatermarker", {}), + ("diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer", "StableUnCLIPImageNormalizer", {}), +] + + +@pytest.mark.parametrize( + "module, name, kwargs", + DEPRECATED_PATHS, + ids=[name for _, name, _ in DEPRECATED_PATHS], +) +def test_deprecated_path_warns_on_use(module, name, kwargs): + """Constructing the relocated class through its deprecated pipeline path emits FutureWarning. + + Instantiation runs under ``torch.device("meta")`` so the test stays fast and CPU-only — the + parameters are meta tensors and no real memory is allocated. We only verify the deprecation + signal here; functional behaviour of each class is covered by its own dedicated tests at the + canonical model path. + """ + mod = importlib.import_module(module) + cls = getattr(mod, name) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + with torch.device("meta"): + cls(**kwargs) + + assert any(issubclass(w.category, FutureWarning) and "deprecated" in str(w.message).lower() for w in caught), ( + f"expected a FutureWarning containing 'deprecated' for {module}.{name}, got: {[str(w.message) for w in caught]}" + ) diff --git a/tests/pipelines/ace_step/test_ace_step.py b/tests/pipelines/ace_step/test_ace_step.py index 6be8bfd155f0..8034a4c8efd3 100644 --- a/tests/pipelines/ace_step/test_ace_step.py +++ b/tests/pipelines/ace_step/test_ace_step.py @@ -20,13 +20,14 @@ import torch from transformers import AutoTokenizer, Qwen3Config, Qwen3Model -from diffusers import AutoencoderOobleck, FlowMatchEulerDiscreteScheduler -from diffusers.models.transformers.ace_step_transformer import AceStepTransformer1DModel -from diffusers.pipelines.ace_step import ( +from diffusers import ( AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder, AceStepPipeline, + AceStepTransformer1DModel, + AutoencoderOobleck, + FlowMatchEulerDiscreteScheduler, ) from ...testing_utils import enable_full_determinism diff --git a/tests/pipelines/ltx/test_ltx_latent_upsample.py b/tests/pipelines/ltx/test_ltx_latent_upsample.py index 0044a85c644b..874c53d3535d 100644 --- a/tests/pipelines/ltx/test_ltx_latent_upsample.py +++ b/tests/pipelines/ltx/test_ltx_latent_upsample.py @@ -18,7 +18,7 @@ import torch from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline -from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel +from diffusers.models.autoencoders import LTXLatentUpsamplerModel from ...testing_utils import enable_full_determinism from ..test_pipelines_common import PipelineTesterMixin, to_np diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py index 0941ae550989..bf39dc3bc8c7 100644 --- a/tests/pipelines/ltx2/test_ltx2.py +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -24,8 +24,8 @@ LTX2Pipeline, LTX2VideoTransformer3DModel, ) -from diffusers.pipelines.ltx2 import LTX2TextConnectors -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.models.autoencoders import LTX2Vocoder +from diffusers.models.condition_embedders import LTX2TextConnectors from ...testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/ltx2/test_ltx2_hdr.py b/tests/pipelines/ltx2/test_ltx2_hdr.py index f92f2535f34e..8157487de03b 100644 --- a/tests/pipelines/ltx2/test_ltx2_hdr.py +++ b/tests/pipelines/ltx2/test_ltx2_hdr.py @@ -27,9 +27,8 @@ LTX2HDRPipeline, LTX2VideoTransformer3DModel, ) +from diffusers.models.autoencoders import LTX2LatentUpsamplerModel, LTX2Vocoder from diffusers.pipelines.ltx2 import LTX2HDRReferenceCondition, LTX2TextConnectors -from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder from diffusers.utils import logging from ...testing_utils import enable_full_determinism, require_accelerator, torch_device diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py index a0e4cb803084..3baf8f67e84f 100644 --- a/tests/pipelines/ltx2/test_ltx2_image2video.py +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -24,9 +24,8 @@ LTX2ImageToVideoPipeline, LTX2VideoTransformer3DModel, ) +from diffusers.models.autoencoders import LTX2LatentUpsamplerModel, LTX2Vocoder from diffusers.pipelines.ltx2 import LTX2LatentUpsamplePipeline, LTX2TextConnectors -from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder from ...testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py index 99fd28692981..77ffbcdabb96 100644 --- a/tests/pipelines/shap_e/test_shap_e.py +++ b/tests/pipelines/shap_e/test_shap_e.py @@ -20,7 +20,7 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline -from diffusers.pipelines.shap_e import ShapERenderer +from diffusers.models.others import ShapERenderer from ...testing_utils import ( backend_empty_cache, diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 8923c2f63cee..4d68fb798b69 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -12,7 +12,7 @@ StableUnCLIPPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from diffusers.models.others import StableUnCLIPImageNormalizer from ...testing_utils import ( backend_empty_cache, diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index e7a0fbccef67..739d5a3f94f8 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -14,8 +14,8 @@ ) from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableUnCLIPImg2ImgPipeline, UNet2DConditionModel +from diffusers.models.others import StableUnCLIPImageNormalizer from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.import_utils import is_xformers_available from ...testing_utils import ( diff --git a/tests/pipelines/test_relocated_class_loading.py b/tests/pipelines/test_relocated_class_loading.py new file mode 100644 index 000000000000..d88103655e2f --- /dev/null +++ b/tests/pipelines/test_relocated_class_loading.py @@ -0,0 +1,56 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import pytest + +import diffusers + + +TINY_PIPELINE_CLASSES = [ + "LTXLatentUpsamplePipeline", + "LTX2Pipeline", + "AudioLDM2Pipeline", + "StableAudioPipeline", + "ShapEPipeline", + "AceStepPipeline", + "IFPipeline", +] + + +@pytest.mark.parametrize("pipeline_class_name", TINY_PIPELINE_CLASSES) +def test_tiny_pipeline_loads_without_relocation_warning(pipeline_class_name): + """ + Loading a pipeline from a pretrained checkpoint must not trigger any relocation-shim deprecation. + """ + pipeline_class = getattr(diffusers, pipeline_class_name) + repo_id = f"hf-internal-testing/tiny-{pipeline_class_name}" + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + pipeline_class.from_pretrained(repo_id) + + relocation_warnings = [ + w + for w in caught + if issubclass(w.category, FutureWarning) + and "Importing " in str(w.message) + and "diffusers.pipelines." in str(w.message) + and "is deprecated" in str(w.message) + ] + assert not relocation_warnings, ( + f"Loading {pipeline_class_name} from {repo_id} triggered relocation shim FutureWarning(s):\n" + + "\n".join(f" - {w.message}" for w in relocation_warnings) + )