diff --git a/docs/source/en/api/models/anyflow_far_transformer3d.md b/docs/source/en/api/models/anyflow_far_transformer3d.md index 3a9909b4887a..7f818c44ef20 100644 --- a/docs/source/en/api/models/anyflow_far_transformer3d.md +++ b/docs/source/en/api/models/anyflow_far_transformer3d.md @@ -13,19 +13,22 @@ specific language governing permissions and limitations under the License. # AnyFlowFARTransformer3DModel The causal (FAR) 3D Transformer used by [`AnyFlowFARPipeline`](../pipelines/anyflow#anyflowfarpipeline) — -the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS -ShowLab × NVIDIA). It extends the v0.35.1 Wan2.1 backbone with three additions: +the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724). See the +[`AnyFlowFARPipeline`](../pipelines/anyflow) page for paper, authors, and released checkpoints. It extends +the v0.35.1 Wan2.1 backbone with three additions: -1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting frame-level autoregressive - generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325). +1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting chunk-wise autoregressive + generation as introduced in [FAR](https://huggingface.co/papers/2503.19325). 2. **Compressed-frame patch embedding** (`far_patch_embedding`) for context (already-generated) frames, warm-started from the full-resolution `patch_embedding` at construction time via trilinear interpolation. 3. **Dual-timestep flow-map embedding** (same as [`AnyFlowTransformer3DModel`](anyflow_transformer3d)) — every forward call conditions on both the source timestep ``t`` and the target timestep ``r``. -The chunk schedule (`chunk_partition`) is **not** baked into the model config. It is a per-call argument to -`forward`, so the same checkpoint handles different `num_frames` configurations without retraining. 
+The default chunk schedule (`chunk_partition`) is stored in the model config; the released NVIDIA AnyFlow-FAR +checkpoints use `[1, 3, 3, 3, 3, 3, 3, 2]` for the canonical 81-frame setting. `forward` accepts a per-call +`chunk_partition` override, so the same checkpoint also handles other `num_frames` configurations without +retraining. ```python from diffusers import AnyFlowFARTransformer3DModel diff --git a/docs/source/en/api/models/anyflow_transformer3d.md b/docs/source/en/api/models/anyflow_transformer3d.md index 95888080c0ce..d37f7fba62fb 100644 --- a/docs/source/en/api/models/anyflow_transformer3d.md +++ b/docs/source/en/api/models/anyflow_transformer3d.md @@ -16,10 +16,11 @@ The bidirectional 3D Transformer used by [`AnyFlowPipeline`](../pipelines/anyflo v0.35.1 Wan2.1 backbone with one structural change: the timestep embedder is replaced by ``AnyFlowDualTimestepTextImageEmbedding``, so every forward call conditions on both the source timestep ``t`` and the target timestep ``r``. This is the embedding required to learn the flow map -:math:`\Phi_{r\leftarrow t}` introduced in -[AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS ShowLab × NVIDIA). +$\Phi_{r\leftarrow t}$ introduced in +[AnyFlow](https://huggingface.co/papers/2605.13724). See the [`AnyFlowPipeline`](../pipelines/anyflow) page +for paper, authors, and released checkpoints. -For frame-level autoregressive (FAR causal) generation, use +For chunk-wise autoregressive (FAR causal) generation, use [`AnyFlowFARTransformer3DModel`](anyflow_far_transformer3d) instead. ```python diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md index 9358b8d454fc..9e496a61113f 100644 --- a/docs/source/en/api/pipelines/anyflow.md +++ b/docs/source/en/api/pipelines/anyflow.md @@ -20,68 +20,28 @@ specific language governing permissions and limitations under the License. 
# AnyFlow -[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA. +[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) from NVIDIA, National University of Singapore, and Massachusetts Institute of Technology, by Yuchao Gu, Guian Fang, Yuxin Jiang, Weijia Mao, Song Han, Han Cai, Mike Zheng Shou. + +> **TL;DR:** AnyFlow is the first any-step video diffusion framework built on flow maps, which enables a single model (bidirectional or causal) to adapt to arbitrary inference budgets. *Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.* -The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow). 
+The AnyFlow pipelines were contributed by the AnyFlow Team. The original code is available on [GitHub](https://github.com/NVlabs/AnyFlow), the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and pretrained models can be found in the [nvidia/anyflow](https://huggingface.co/collections/nvidia/anyflow) collection on Hugging Face. -The following AnyFlow checkpoints are supported: +Available Models: | Checkpoint | Backbone | Description | -|------------|----------|-------------| -| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight | -| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality | +|---|---|---| +| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V | +| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V | | [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V | | [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / V2V | -All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection. - > [!TIP] -> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling. - -> [!TIP] -> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. 
Quality scales monotonically with steps in our benchmarks. - -### Optimizing Memory and Inference Speed - - - - -```py -import torch -from diffusers import AnyFlowPipeline -from diffusers.hooks import apply_group_offloading - -pipe = AnyFlowPipeline.from_pretrained( - "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 -) -apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") -pipe.vae.enable_slicing() -pipe.vae.enable_tiling() -``` - - - - -```py -import torch -from diffusers import AnyFlowPipeline - -pipe = AnyFlowPipeline.from_pretrained( - "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 -).to("cuda") -pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") -``` - - - +> `AnyFlowPipeline` is designed for bidirectional diffusion models in text-to-video (T2V) generation. `AnyFlowFARPipeline` is a chunk-wise causal diffusion model that supports text-to-video (T2V) generation, image-to-video (I2V) generation, and video continuation (V2V). ### Generation with AnyFlow (Bidirectional T2V) - - - ```py import torch from diffusers import AnyFlowPipeline @@ -91,14 +51,16 @@ pipe = AnyFlowPipeline.from_pretrained( "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -prompt = "A red panda eating bamboo in a forest, cinematic lighting" -video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0] -export_to_video(video, "out.mp4", fps=16) +prompt = ( + "An astronaut runs smoothly and appears almost weightless on the lunar surface, " + "as seen from a low-angle shot that highlights the vast, desolate background of the moon. " + "The moon's craters and rocky terrain are clearly visible, creating a stark contrast against " + "the running astronaut who moves with graceful, fluid motions." 
+) +video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0] +export_to_video(video, "anyflow_t2v.mp4", fps=16) ``` - - - ### Generation with AnyFlow (FAR Causal) The causal pipeline selects between T2V / I2V / V2V via the ``video`` (or ``video_latents``) argument: @@ -108,10 +70,10 @@ clip for V2V continuation. If you already have pre-encoded latents in the model ``video_latents=`` to skip VAE encoding. ``video`` and ``video_latents`` are mutually exclusive. > [!IMPORTANT] -> `AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) is matched to the -> released checkpoints' canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When -> you change `num_frames`, you must also pass a matching `chunk_partition` summing to -> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`. +> The released checkpoints bake `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) into the transformer +> config, matched to the canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When +> you change `num_frames`, pass a matching `chunk_partition` summing to `(num_frames - 1) // 4 + 1`, +> otherwise the pipeline raises a `ValueError`. @@ -125,12 +87,12 @@ pipe = AnyFlowFARPipeline.from_pretrained( "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -video = pipe( - prompt="A cat surfing a wave, sunset", - num_inference_steps=4, - num_frames=81, -).frames[0] -export_to_video(video, "out.mp4", fps=16) +prompt = ( + "An astronaut runs smoothly and appears almost weightless on the lunar surface, " + "as seen from a low-angle shot that highlights the vast, desolate background of the moon." 
+) +video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0] +export_to_video(video, "anyflow_far_t2v.mp4", fps=16) ``` @@ -146,18 +108,25 @@ pipe = AnyFlowFARPipeline.from_pretrained( "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -# Wrap the conditioning image as a one-frame video tensor: (1, 1, 3, H, W) in [0, 1]. -first_frame = load_image("path/to/first_frame.png").resize((832, 480)) +# Example conditioning image from the AnyFlow repo. +first_frame = load_image( + "https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/images/1.jpg" +).resize((832, 480)) arr = np.asarray(first_frame).astype("float32") / 255.0 # (480, 832, 3) -context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda") +context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda") # (1, 1, 3, 480, 832) +prompt = ( + "A towering, battle-scarred humanoid robot, reminiscent of a Transformer with powerful, segmented armor " + "and glowing red optics, walking through the skeletal remains of a city ruin. Twisted metal and shattered " + "concrete crunch under its heavy steps, as the robot scans the desolate, dust-choked skyline under a dark sky." +) video = pipe( - prompt="a cat walks across a sunlit lawn", + prompt=prompt, video=context_tensor, num_inference_steps=4, num_frames=81, ).frames[0] -export_to_video(video, "out.mp4", fps=16) +export_to_video(video, "anyflow_far_i2v.mp4", fps=16) ``` @@ -173,21 +142,26 @@ pipe = AnyFlowFARPipeline.from_pretrained( "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16 ).to("cuda") -# Context clip — 9 raw frames map to 3 latent frames (9 = 4·2 + 1, 3 = 2 + 1). -context_frames = load_video("path/to/context.mp4")[:9] +# Example conditioning clip from the AnyFlow repo — take the first 9 frames (3 latent frames at VAE temporal stride 4). 
+context_frames = load_video( + "https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/videos/2.mp4" +)[:9] arr = np.stack([np.asarray(f.resize((832, 480))) for f in context_frames]).astype("float32") / 255.0 -# np.stack gives (T, H, W, C) = (9, 480, 832, 3) → permute to (T, C, H, W) then add batch. context_tensor = torch.from_numpy(arr).permute(0, 3, 1, 2).unsqueeze(0).to("cuda") # (1, 9, 3, 480, 832) +prompt = ( + "A focused trail runner's powerful strides through a dense, sun-dappled forest. " + "The camera tracks alongside, highlighting muscular exertion, sweat, and determined facial expression." +) video = pipe( - prompt="continue the story", + prompt=prompt, video=context_tensor, num_inference_steps=4, num_frames=81, # Override chunk_partition so the first chunk covers exactly the 3 latent context frames. chunk_partition=[3, 3, 3, 3, 3, 3, 3], ).frames[0] -export_to_video(video, "out.mp4", fps=16) +export_to_video(video, "anyflow_far_v2v.mp4", fps=16) ``` diff --git a/docs/source/zh/using-diffusers/anyflow.md b/docs/source/zh/using-diffusers/anyflow.md index 575cdb1c1cb8..e9c925a85256 100644 --- a/docs/source/zh/using-diffusers/anyflow.md +++ b/docs/source/zh/using-diffusers/anyflow.md @@ -22,7 +22,7 @@ NFE 增加反而经常掉点。 采样步之间的 re-noising;on-policy 蒸馏阶段额外用 **DMD 反向散度监督** + **Flow-Map backward simulation** (3 段 shortcut)补上 consistency 蒸馏遗留的 exposure-bias 缺口。 -AnyFlow 由 Yuchao Gu、Guian Fang 等人在 [NUS ShowLab](https://sites.google.com/view/showlab) 与 NVIDIA 合作完成。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow)。4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。 +AnyFlow 由 NVIDIA、新加坡国立大学(NUS)和 MIT 合作完成,作者为 Yuchao Gu、Guian Fang、Yuxin Jiang、Weijia Mao、Song Han、Han Cai、Mike Zheng Shou。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 
[nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow),4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。 本文档梳理实战要点:怎么选 pipeline、怎么用 any-step 采样、怎么把 AnyFlow 嵌进 T2V / I2V / V2V 工作流。 @@ -100,7 +100,7 @@ prompt = "森林里一只小熊猫在啃竹子,电影感光照" for nfe in [1, 2, 4, 8, 16, 32]: # 每轮重建 generator —— 这样跨步数对比时唯一变量是 NFE。 generator = torch.Generator("cuda").manual_seed(0) - video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0] + video = pipe(prompt, num_inference_steps=nfe, num_frames=81, generator=generator).frames[0] export_to_video(video, f"out_nfe{nfe}.mp4", fps=16) ``` @@ -125,11 +125,11 @@ Causal pipeline 用同一个蒸馏模型支持三种任务模式,**通过 `vid Context tensor 的帧数必须满足 `T = 4n + 1`,跟 VAE 时间步长对齐。 > [!IMPORTANT] -> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。默认 -> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21)对应发布 checkpoint 的标准 `num_frames=81` -> (21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`,使其求和等于 -> `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `AssertionError`。比如 `num_frames=33` 对应 9 个 latent -> 帧,可用 `chunk_partition=[1, 4, 4]`。 +> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。发布的 checkpoint 在 +> transformer config 里写入 `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21),对应标准 +> `num_frames=81`(21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`, +> 使其求和等于 `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `ValueError`。比如 `num_frames=33` 对应 +> 9 个 latent 帧,可用 `chunk_partition=[1, 4, 4]`。 ```py import numpy as np @@ -183,33 +183,6 @@ export_to_video(video, "v2v.mp4", fps=16) 如果你已经有 VAE 编码过的 latent,可以直接传 `video_latents=` 跳过 `vae_encode` 步骤 (和 `video` 互斥)。 -## 显存与推理速度 - -14B 的 AnyFlow 模型用 group offload + VAE slicing 单卡 40 GB 能跑: - -```py -import torch -from diffusers import AnyFlowPipeline -from diffusers.hooks import apply_group_offloading - -pipe = AnyFlowPipeline.from_pretrained( - 
"nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16 -) -apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level") -pipe.vae.enable_slicing() -pipe.vae.enable_tiling() -``` - -延迟方面,`torch.compile` 对 transformer(最重的模块)效果很好: - -```py -pipe = pipe.to("cuda") -pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") -``` - -编译开销跑几步就摊销掉;配合 AnyFlow 的低 NFE(4-8 步),`torch.compile` 在 14B 上相比 eager -模式有明显加速。 - ## LoRA 微调 两个 pipeline 都复用 [`WanLoraLoaderMixin`](../api/loaders/lora),因此为对应 Wan2.1 backbone 训练的 diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py index 60574ca23a1e..229d286c4701 100644 --- a/scripts/convert_anyflow_to_diffusers.py +++ b/scripts/convert_anyflow_to_diffusers.py @@ -57,13 +57,21 @@ "AnyFlow-FAR-Wan2.1-1.3B-Diffusers": { "base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "transformer_cls": AnyFlowFARTransformer3DModel, - "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]}, + "transformer_kwargs": { + "full_chunk_limit": 3, + "compressed_patch_size": [1, 4, 4], + "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2], + }, "pipeline_cls": AnyFlowFARPipeline, }, "AnyFlow-FAR-Wan2.1-14B-Diffusers": { "base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "transformer_cls": AnyFlowFARTransformer3DModel, - "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]}, + "transformer_kwargs": { + "full_chunk_limit": 3, + "compressed_patch_size": [1, 4, 4], + "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2], + }, "pipeline_cls": AnyFlowFARPipeline, }, "AnyFlow-Wan2.1-T2V-1.3B-Diffusers": { diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py index 2ac554419e5e..873e0b095b33 100644 --- a/src/diffusers/models/transformers/transformer_anyflow.py +++ b/src/diffusers/models/transformers/transformer_anyflow.py @@ -12,9 +12,9 @@ # See the License 
for the specific language governing permissions and # limitations under the License. # -# This file derives from the FAR architecture (Gu et al., 2025, arXiv:2503.19325) and adds the -# AnyFlow dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding) introduced by -# Yuchao Gu, Guian Fang et al. (arXiv:2605.13724). The base 3D DiT structure is adapted from the +# This file derives from the FAR architecture (arXiv:2503.19325) and adds the +# AnyFlow dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding) introduced in +# AnyFlow (arXiv:2605.13724). The base 3D DiT structure is adapted from the # v0.35.1 Wan2.1 transformer (transformer_wan.py); upstream Wan has since been refactored, so # this file is intentionally self-contained rather than annotated with `# Copied from`. @@ -334,8 +334,11 @@ def __init__( self._freqs_cache: Optional[Tuple[Any, torch.Tensor]] = None def _build_freqs(self, device: torch.device) -> torch.Tensor: + # Skip the cache read/write inside torch.compile: mutating ``self._freqs_cache`` between calls + # becomes a Dynamo guard and forces recompilation on the second invocation. 
+ is_compiling = torch.compiler.is_compiling() cache_key = (device.type, str(device)) - if self._freqs_cache is not None and self._freqs_cache[0] == cache_key: + if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] is_mps = device.type == "mps" @@ -357,7 +360,8 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: ) freqs_list.append(f.to(device)) freqs = torch.cat(freqs_list, dim=1) - self._freqs_cache = (cache_key, freqs) + if not is_compiling: + self._freqs_cache = (cache_key, freqs) return freqs def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor: @@ -510,10 +514,9 @@ class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO The architecture is the v0.35.1 Wan2.1 3D DiT backbone with one structural change: the timestep embedder is replaced by ``AnyFlowDualTimestepTextImageEmbedding`` so that every forward call conditions on both the source timestep ``t`` and the target timestep ``r``. This is the embedding required to learn the flow map - :math:`\Phi_{r\leftarrow t}` introduced in [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian - Fang et al. + :math:`\Phi_{r\leftarrow t}` introduced in [AnyFlow](https://huggingface.co/papers/2605.13724). - For frame-level autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead; that variant + For chunk-wise autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead; that variant adds the FAR causal block-mask and a compressed-frame patch embedding on top of the same backbone. 
Args: diff --git a/src/diffusers/models/transformers/transformer_anyflow_far.py b/src/diffusers/models/transformers/transformer_anyflow_far.py index a40e2fafcb61..4b418f129636 100644 --- a/src/diffusers/models/transformers/transformer_anyflow_far.py +++ b/src/diffusers/models/transformers/transformer_anyflow_far.py @@ -14,9 +14,9 @@ # # This file is the FAR causal sibling of `transformer_anyflow.py`. Shared submodules are duplicated # via `# Copied from` so `make fix-copies` keeps both files in sync; this keeps each transformer -# variant readable in isolation. The FAR architecture comes from Gu et al., 2025 +# variant readable in isolation. The FAR architecture comes from the FAR paper # (arXiv:2503.19325); the dual-timestep flow-map embedding is AnyFlow's contribution -# (Yuchao Gu, Guian Fang et al., arXiv:2605.13724). +# (arXiv:2605.13724). import math from dataclasses import dataclass @@ -25,7 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention.flex_attention import create_block_mask +from torch.nn.attention.flex_attention import BlockMask, create_block_mask from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -77,7 +77,7 @@ class AnyFlowCausalAttnProcessor: autoregressive read (cache-read step). Requires the ``flex`` attention backend — the ``BlockMask`` produced by - :class:`AnyFlowFARTransformer3DModel._build_causal_mask` is consumed only by the flex backend. A clear + :meth:`AnyFlowFARTransformer3DModel.build_attention_mask` is consumed only by the flex backend. A clear :class:`ValueError` is raised if a non-flex backend is configured via ``_attention_backend``. 
""" @@ -643,8 +643,11 @@ def __init__( # Copied from diffusers.models.transformers.transformer_anyflow.AnyFlowRotaryPosEmbed._build_freqs def _build_freqs(self, device: torch.device) -> torch.Tensor: + # Skip the cache read/write inside torch.compile: mutating ``self._freqs_cache`` between calls + # becomes a Dynamo guard and forces recompilation on the second invocation. + is_compiling = torch.compiler.is_compiling() cache_key = (device.type, str(device)) - if self._freqs_cache is not None and self._freqs_cache[0] == cache_key: + if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key: return self._freqs_cache[1] is_mps = device.type == "mps" @@ -666,7 +669,8 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor: ) freqs_list.append(f.to(device)) freqs = torch.cat(freqs_list, dim=1) - self._freqs_cache = (cache_key, freqs) + if not is_compiling: + self._freqs_cache = (cache_key, freqs) return freqs def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int): @@ -774,14 +778,191 @@ def forward(self, far_cfg, device, clean_hidden_states=None): return {"query": freqs, "key": freqs} +def _build_anyflow_far_causal_block_mask( + chunk_partition: List[int], + height: int, + width: int, + patch_size: Tuple[int, int, int], + compressed_patch_size: Tuple[int, int, int], + full_chunk_limit: int, + *, + mode: str = "train", + has_clean_context: bool = False, + device: Optional[torch.device] = None, +) -> BlockMask: + r"""Build the causal :class:`~torch.nn.attention.flex_attention.BlockMask` for the FAR transformer. + + Provided as a standalone function so callers can construct the mask *outside* the transformer's compiled region, + which is required to wrap the forward in ``torch.compile(fullgraph=True)`` (``flex_attention.create_block_mask`` + itself uses ``_compile=False`` internally and breaks the graph when invoked inside the compiled scope). 
+ + Two modes are exposed, mirroring the FAR forward paths that actually consume a mask. The autoregressive + ``_forward_inference`` path attends through the KV cache and does not use a full BlockMask, so it has no + corresponding mode here. + + Args: + chunk_partition: per-chunk frame counts; must sum to the number of latent frames. + height, width: latent spatial dimensions. + patch_size, compressed_patch_size, full_chunk_limit: must match the transformer config. + mode: ``"train"`` (strict ``>`` comparison against ``full_chunk_limit``, matches + :meth:`AnyFlowFARTransformer3DModel._forward_train`) or ``"cache"`` (``>=`` comparison via the + ``full_chunk_limit - 1`` offset used by :meth:`AnyFlowFARTransformer3DModel._forward_cache`). + has_clean_context: ``True`` when ``clean_hidden_states`` is being threaded through the + transformer (training V2V/I2V). + device: device for the resulting BlockMask. Defaults to CPU. + """ + if mode not in {"train", "cache"}: + raise ValueError(f"Unknown mode {mode!r}; expected 'train' or 'cache'.") + full_token_per_frame = (height // patch_size[1]) * (width // patch_size[2]) + compressed_token_per_frame = (height // compressed_patch_size[1]) * (width // compressed_patch_size[2]) + + # `cache` uses `full_chunk_limit - 1` (an effective `>= full_chunk_limit` comparison); `train` uses a strict `>`. 
+ total_chunks = len(chunk_partition) + threshold = full_chunk_limit - 1 if mode == "cache" else full_chunk_limit + if total_chunks > threshold: + num_full_chunk = threshold + num_compressed_chunk = total_chunks - threshold + else: + num_full_chunk, num_compressed_chunk = total_chunks, 0 + + far_cfg = { + "num_full_chunk": num_full_chunk, + "num_compressed_chunk": num_compressed_chunk, + "num_full_frames": sum(chunk_partition[num_compressed_chunk:]), + "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]), + "full_token_per_frame": full_token_per_frame, + "compressed_token_per_frame": compressed_token_per_frame, + "chunk_partition": chunk_partition, + } + return _build_far_block_mask_from_far_cfg(far_cfg, has_clean=has_clean_context, device=device) + + +def _build_far_block_mask_from_far_cfg(far_cfg, has_clean, device): + """Internal: build a BlockMask given an already-computed ``far_cfg`` dict. + + Factored out of :class:`AnyFlowFARTransformer3DModel` so it can be shared between + :func:`_build_anyflow_far_causal_block_mask` (the user-facing entry point) and the in-forward fallback path used + when no pre-built ``attention_mask`` is passed. 
+ """ + chunk_partition = far_cfg["chunk_partition"] + + noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] + context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] + + noise_start = context_seq_len + noise_end = noise_start + noise_seq_len + + clean_start = context_seq_len + noise_seq_len + clean_end = clean_start + clean_seq_len + + if has_clean: + real_seq_len = context_seq_len + noise_seq_len + clean_seq_len + else: + real_seq_len = context_seq_len + noise_seq_len + + padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0) + + context_chunk_partition, noise_chunk_partition = ( + chunk_partition[: far_cfg["num_compressed_chunk"]], + chunk_partition[far_cfg["num_compressed_chunk"] :], + ) + + if len(context_chunk_partition) != 0: + context_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx + for chunk_idx, chunk_len in enumerate(context_chunk_partition) + ] + ) + else: + context_frame_idx = None + + if has_clean: + noise_frame_idx = clean_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + # 1) is padding + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + + # 2) chunk causal + base = frame_idx[q_idx] >= frame_idx[kv_idx] + + # 3) interval mask + q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end) + q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end) + + k_is_noise = (kv_idx >= 
noise_start) & (kv_idx < noise_end) + k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end) + + # 4) clean -> noise: disallowed + is_clean_to_noise = q_is_clean & k_is_noise + + # 5) noise -> noise: only same frame + same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx] + + noise_to_noise = q_is_noise & k_is_noise + noise_to_clean = q_is_noise & k_is_clean + + noise_to_noise_allow = noise_to_noise & same_frame_idx + noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow + + noise_to_clean_same = noise_to_clean & same_frame_idx + noise_to_clean_disallow = noise_to_clean_same + + allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow + return allowed + + else: + noise_frame_idx = torch.cat( + [ + torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) + * (chunk_idx + len(context_chunk_partition)) + for chunk_idx, chunk_len in enumerate(noise_chunk_partition) + ] + ) + pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + + if len(context_chunk_partition) != 0: + frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0) + else: + frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0) + + def mask_mod(b, h, q_idx, kv_idx): + is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) + base = frame_idx[q_idx] >= frame_idx[kv_idx] + return base & ~is_padding + + return create_block_mask( + mask_mod, + B=None, + H=None, + Q_LEN=padded_seq_len, + KV_LEN=padded_seq_len, + device=device, + _compile=False, + ) + + class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" - Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with frame-level autoregressive generation. + Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with chunk-wise autoregressive generation. 
Extends the v0.35.1 Wan2.1 backbone with: - * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting frame-level autoregressive - generation (FAR; [Gu et al., 2025](https://arxiv.org/abs/2503.19325)). + * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting chunk-wise autoregressive + generation ([FAR](https://huggingface.co/papers/2503.19325)). * **Compressed-frame patch embedding** ``far_patch_embedding`` for context (already-generated) frames, initialized from ``patch_embedding`` via trilinear interpolation so a freshly constructed model is already at a reasonable starting point even before LoRA fine-tuning. @@ -826,11 +1007,11 @@ class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr Mixing gate between source-timestep and delta-timestep embeddings. deltatime_type (`str`, defaults to `'r'`): Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval). - - .. note:: - ``chunk_partition`` is **not** a model config field — it is a per-call argument passed to :meth:`forward`. - Different inference setups (varying ``num_frames`` or full-vs-compressed schedules) therefore do not require - separate checkpoints. + chunk_partition (`Tuple[int, ...]`, defaults to `(1, 3, 3, 3, 3, 3, 3, 2)`): + Default per-chunk frame counts used by the pipeline. The released NVIDIA AnyFlow-FAR checkpoints target + ``num_frames=81`` (21 latent frames at VAE temporal stride 4) split as ``1 + 3*6 + 2``. A different + ``num_frames`` requires a matching ``chunk_partition`` override passed to + :meth:`AnyFlowFARPipeline.__call__` (and likewise to :meth:`forward`). """ _supports_gradient_checkpointing = True @@ -859,6 +1040,7 @@ def __init__( rope_max_seq_len: int = 1024, gate_value: float = 0.25, deltatime_type: str = "r", + chunk_partition: Tuple[int, ...] 
= (1, 3, 3, 3, 3, 3, 3, 2), ) -> None: super().__init__() @@ -923,6 +1105,7 @@ def forward( clean_timestep: Optional[torch.Tensor] = None, kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None, kv_cache_flag: Optional[Dict[str, Any]] = None, + attention_mask: Optional[BlockMask] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Transformer2DModelOutput, AnyFlowFARTransformerOutput, Tuple]: @@ -955,6 +1138,12 @@ def forward( Per-block KV cache for autoregressive inference. `None` selects the training path. kv_cache_flag (`Dict[str, Any]`, *optional*): KV-cache metadata (e.g. ``is_cache_step`` flag and token counts). + attention_mask (`BlockMask`, *optional*): + Pre-built causal mask, typically constructed via :meth:`build_attention_mask`. Consumed by the train + and KV-cache prefill paths; the autoregressive inference path attends through the KV cache and does not + use a full mask. When ``None``, the train / cache paths build the mask internally; that fallback is not + compile-safe (the underlying ``flex_attention.create_block_mask`` breaks the graph under + ``fullgraph=True``), so pass a pre-built mask whenever wrapping ``forward`` in ``torch.compile``. attention_kwargs (`dict`, *optional*): Forwarded to the attention processors. 
return_dict (`bool`, *optional*, defaults to `True`): @@ -978,12 +1167,14 @@ def forward( return self._forward_cache( clean_hidden_states=clean_hidden_states, clean_timestep=clean_timestep, + attention_mask=attention_mask, **common, ) return self._forward_inference(**common) return self._forward_train( clean_hidden_states=clean_hidden_states, clean_timestep=clean_timestep, + attention_mask=attention_mask, **common, ) @@ -1029,142 +1220,50 @@ def _forward_far_patchify_inference(self, hidden_states): hidden_states = self.patch_embedding(hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2) return hidden_states - def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype): - chunk_partition = far_cfg["chunk_partition"] - - noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"] - context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] - - noise_start = context_seq_len - noise_end = noise_start + noise_seq_len - - clean_start = context_seq_len + noise_seq_len - clean_end = clean_start + clean_seq_len - - if clean_hidden_states is not None: - real_seq_len = context_seq_len + noise_seq_len + clean_seq_len - else: - real_seq_len = context_seq_len + noise_seq_len - - padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0) - - if clean_hidden_states is not None: - context_chunk_partition, noise_chunk_partition = ( - chunk_partition[: far_cfg["num_compressed_chunk"]], - chunk_partition[far_cfg["num_compressed_chunk"] :], - ) # noqa: E501 - - if len(context_chunk_partition) != 0: - context_frame_idx = torch.cat( - [ - torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx - for chunk_idx, chunk_len in enumerate(context_chunk_partition) - ] - ) # noqa: E501 - else: - context_frame_idx = None - noise_frame_idx = clean_frame_idx = torch.cat( - [ - torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) - * (chunk_idx + 
len(context_chunk_partition)) - for chunk_idx, chunk_len in enumerate(noise_chunk_partition) - ] - ) # noqa: E501 - pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) - - if len(context_chunk_partition) != 0: - frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) - else: - frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0) - - def mask_mod(b, h, q_idx, kv_idx): - # q_idx, kv_idx: LongTensor, range: [0, padded_seq_len) - - # 1) whether is padding - is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) - - # 2) chunk causal - base = frame_idx[q_idx] >= frame_idx[kv_idx] - - # 3) interval mask - q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end) - q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end) - - k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end) - k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end) - - # 4) clean -> noise: disallowed - is_clean_to_noise = q_is_clean & k_is_noise - - # 5) noise -> noise: only same frame - same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx] - - noise_to_noise = q_is_noise & k_is_noise - noise_to_clean = q_is_noise & k_is_clean - - noise_to_noise_allow = noise_to_noise & same_frame_idx - noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow - - noise_to_clean_same = noise_to_clean & same_frame_idx - noise_to_clean_disallow = noise_to_clean_same - - # attention mask is chunk casual - allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow - return allowed - - return create_block_mask( - mask_mod, - B=None, - H=None, - Q_LEN=padded_seq_len, - KV_LEN=padded_seq_len, - device=device, - _compile=False, - ) - else: - context_chunk_partition, noise_chunk_partition = ( - chunk_partition[: far_cfg["num_compressed_chunk"]], - chunk_partition[far_cfg["num_compressed_chunk"] :], - ) # noqa: E501 - - if len(context_chunk_partition) != 0: - 
context_frame_idx = torch.cat( - [ - torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx - for chunk_idx, chunk_len in enumerate(context_chunk_partition) - ] - ) # noqa: E501 - else: - context_frame_idx = None - - noise_frame_idx = torch.cat( - [ - torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device) - * (chunk_idx + len(context_chunk_partition)) - for chunk_idx, chunk_len in enumerate(noise_chunk_partition) - ] - ) # noqa: E501 - pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device) + def build_attention_mask( + self, + *, + chunk_partition: List[int], + height: int, + width: int, + has_clean_context: bool = False, + device: Optional[torch.device] = None, + mode: str = "train", + ) -> BlockMask: + r"""Pre-build the causal :class:`~torch.nn.attention.flex_attention.BlockMask` outside ``forward``. + + Pass the result via :meth:`forward`'s ``attention_mask`` kwarg to make the whole transformer compatible with + ``torch.compile(fullgraph=True)``. Without a pre-built mask, ``forward`` falls back to constructing it + internally — that path uses ``flex_attention.create_block_mask(_compile=False)`` and breaks the compile graph. - if len(context_chunk_partition) != 0: - frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0) - else: - frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0) - - def mask_mod(b, h, q_idx, kv_idx): - is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len) - base = frame_idx[q_idx] >= frame_idx[kv_idx] - return base & ~is_padding - - return create_block_mask( - mask_mod, - B=None, - H=None, - Q_LEN=padded_seq_len, - KV_LEN=padded_seq_len, - device=device, - _compile=False, - ) + Args: + chunk_partition: per-chunk frame counts (must sum to the number of latent frames). + height, width: latent spatial dimensions. 
+ has_clean_context: ``True`` when ``clean_hidden_states`` will be threaded through :meth:`forward` + (training V2V/I2V); only this presence flag affects the mask layout. + device: device for the resulting :class:`BlockMask`. The mask is not auto-moved by + ``device_map="auto"``; build it on the same device the transformer's inputs will live on. + mode: ``"train"`` (matches :meth:`_forward_train`) or ``"cache"`` (matches :meth:`_forward_cache`). + The autoregressive ``_forward_inference`` path attends through the KV cache and has no mode here. + + Returns: + :class:`~torch.nn.attention.flex_attention.BlockMask`: causal mask spanning the FAR layout, padded to a + multiple of 128 along the sequence dimension (the BlockMask block-size requirement). + + Raises: + ValueError: if ``mode`` is neither ``"train"`` nor ``"cache"``. + """ + return _build_anyflow_far_causal_block_mask( + chunk_partition=chunk_partition, + height=height, + width=width, + patch_size=self.config.patch_size, + compressed_patch_size=self.config.compressed_patch_size, + full_chunk_limit=self.config.full_chunk_limit, + mode=mode, + has_clean_context=has_clean_context, + device=device, + ) def _forward_inference( self, @@ -1291,6 +1390,7 @@ def _forward_cache( r_timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + attention_mask: Optional[BlockMask] = None, return_dict: bool = True, clean_hidden_states=None, clean_timestep=None, @@ -1337,9 +1437,10 @@ def _forward_cache( far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"] ) - attention_mask = self._build_causal_mask( - far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype - ) + if attention_mask is None: + attention_mask = _build_far_block_mask_from_far_cfg( + far_cfg, has_clean=clean_hidden_states is not None, device=hidden_states.device + ) rotary_emb = self.rope(far_cfg=far_cfg, 
clean_hidden_states=clean_hidden_states, device=hidden_states.device) hidden_states = self._forward_far_patchify( @@ -1396,6 +1497,7 @@ def _forward_train( r_timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + attention_mask: Optional[BlockMask] = None, return_dict: bool = True, clean_hidden_states=None, clean_timestep=None, @@ -1436,9 +1538,12 @@ def _forward_train( "chunk_partition": chunk_partition, } - attention_mask = self._build_causal_mask( - far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype - ) + if attention_mask is None: + # Fallback for callers that don't pre-build (e.g. training scripts). Not compile-safe; + # use :meth:`build_attention_mask` upstream when wrapping `forward` in `torch.compile`. + attention_mask = _build_far_block_mask_from_far_cfg( + far_cfg, has_clean=clean_hidden_states is not None, device=hidden_states.device + ) rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device) diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py index 0eb60b525a0f..c3e1dbf3a459 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py @@ -80,11 +80,11 @@ def prompt_clean(text): class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints, introduced in - [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang et al. + [AnyFlow](https://huggingface.co/papers/2605.13724). AnyFlow learns arbitrary-interval transitions :math:`z_t \to z_r` rather than the fixed :math:`z_t \to z_0` mapping of consistency models, so a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without - retraining. 
This pipeline operates over the full video tensor in one bidirectional pass; for frame-level + retraining. This pipeline operates over the full video tensor in one bidirectional pass; for chunk-wise autoregressive (causal) generation use ``AnyFlowFARPipeline``. Sampling is plain Euler in mean-velocity form (``z_r = z_t - (t - r) * u``) with no re-noising. The released NVIDIA diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py index e73c44b2fde3..e33f8b7c3873 100644 --- a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py +++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py @@ -92,11 +92,10 @@ def prompt_clean(text): class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" Causal (FAR-based) text-to-video / image-to-video / video-to-video pipeline for AnyFlow checkpoints, introduced in - [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang et al. + [AnyFlow](https://huggingface.co/papers/2605.13724). - The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with flow-map - steps while attending only to past chunks via block-sparse causal attention, and intermediate KV cache is reused - across chunks. + The pipeline drives a chunk-wise autoregressive sampling loop: each chunk is denoised with flow-map steps while + attending only to past chunks via block-sparse causal attention, and intermediate KV cache is reused across chunks. The task mode (T2V / I2V / V2V) is selected by which conditioning argument is passed to ``__call__``: @@ -106,9 +105,9 @@ class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin): - ``video_latents=`` — already-encoded latents in the FAR layout (skips the VAE encode step). - The FAR backbone is the causal Wan2.1 variant introduced by FAR (Gu et al., 2025; arXiv:2503.19325). Inference is - plain Euler in mean-velocity form per chunk with no re-noising. 
Joint T2V / I2V / V2V is supported by a single - distilled model. + The FAR backbone is the causal Wan2.1 variant introduced by [FAR](https://huggingface.co/papers/2503.19325). + Inference is plain Euler in mean-velocity form per chunk with no re-noising. Joint T2V / I2V / V2V is supported by + a single distilled model. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -129,11 +128,6 @@ class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin): model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - # Default chunk partition for the released NVIDIA AnyFlow-FAR checkpoints (81 frames at the diffusers - # VAE temporal stride of 4 → 21 latent frames split into 1 + 3*6 + 2 = [1, 3, 3, 3, 3, 3, 3, 2]). Override - # via the ``chunk_partition`` argument to ``__call__`` for other frame counts. 
- default_chunk_partition: List[int] = [1, 3, 3, 3, 3, 3, 3, 2] - def __init__( self, tokenizer: AutoTokenizer, @@ -423,12 +417,21 @@ def encode_kv_cache( r_timestep = torch.tensor([0], device=latents.device).expand(latent_model_input.shape[0]).unsqueeze(-1) r_timestep = r_timestep.repeat((1, latent_model_input.shape[1])) + attention_mask = self.transformer.build_attention_mask( + chunk_partition=chunk_partition, + height=latent_model_input.shape[-2], + width=latent_model_input.shape[-1], + device=latent_model_input.device, + mode="cache", + ) + _, kv_cache = self.transformer( hidden_states=latent_model_input, chunk_partition=chunk_partition, timestep=timestep, r_timestep=r_timestep, encoder_hidden_states=prompt_embeds, + attention_mask=attention_mask, attention_kwargs=self.attention_kwargs, return_dict=False, # kv-cache related @@ -538,9 +541,9 @@ def __call__( use_kv_cache (`bool`, defaults to `True`): Reuse the FAR attention KV cache across causal chunks. Disable only for debugging. chunk_partition (`List[int]`, *optional*): - Per-chunk frame counts. Defaults to `default_chunk_partition` (matched to the released 81-frame - checkpoints). When you change `num_frames`, supply a `chunk_partition` that sums to `(num_frames - 1) - // vae_scale_factor_temporal + 1`. + Per-chunk frame counts. Defaults to `self.transformer.config.chunk_partition` (matched to the released + 81-frame checkpoints). When you change `num_frames`, supply a `chunk_partition` that sums to + `(num_frames - 1) // vae_scale_factor_temporal + 1`. 
Examples: @@ -629,7 +632,7 @@ def __call__( video_latents = self.encode_video(video, height=height, width=width) if chunk_partition is None: - chunk_partition = list(self.default_chunk_partition) + chunk_partition = list(self.transformer.config.chunk_partition) if init_latents.shape[1] != sum(chunk_partition): raise ValueError( f"chunk_partition={chunk_partition} sums to {sum(chunk_partition)}, but the input latent " @@ -713,14 +716,14 @@ def __call__( this_chunk_partition = chunk_partition[: chunk_idx + 1] self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, timesteps=timesteps) - timesteps = self.scheduler.timesteps + scheduler_timesteps = self.scheduler.timesteps inner_progress_bar_config = { **outer_progress_bar_config, "position": 1, "leave": False, "desc": f"Chunk {chunk_idx} Inference Steps", } - for i, t in enumerate(tqdm(timesteps, **inner_progress_bar_config)): + for i, t in enumerate(tqdm(scheduler_timesteps, **inner_progress_bar_config)): r = self.scheduler.sigmas[i + 1] * self.scheduler.config.num_train_timesteps if t == r: continue diff --git a/tests/models/transformers/test_models_transformer_anyflow_far.py b/tests/models/transformers/test_models_transformer_anyflow_far.py index d3631e361c09..6ebbe6596f51 100644 --- a/tests/models/transformers/test_models_transformer_anyflow_far.py +++ b/tests/models/transformers/test_models_transformer_anyflow_far.py @@ -21,6 +21,7 @@ from diffusers.models.transformers.transformer_anyflow_far import ( AnyFlowCausalAttnProcessor, AnyFlowFARTransformerOutput, + _build_anyflow_far_causal_block_mask, ) from diffusers.utils.torch_utils import randn_tensor @@ -30,6 +31,7 @@ BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, + TorchCompileTesterMixin, TrainingTesterMixin, ) @@ -88,6 +90,9 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]: text_seq_len = 12 text_dim = 16 + # No `attention_mask` here: the model accepts `attention_mask=None` and builds the mask + # internally — 
exercising that fallback in every non-compile test is the point. The + # compile test class overrides this method to inject a pre-built BlockMask. return { "hidden_states": randn_tensor( (batch_size, num_frames, num_channels, height, width), @@ -149,11 +154,36 @@ class TestAnyFlowFARTransformer3DAttention(AnyFlowFARTransformer3DTesterConfig, """Attention processor tests for AnyFlow FAR Transformer 3D.""" -# Torch-compile mixin intentionally skipped: FAR's `_build_causal_mask` uses -# `flex_attention.create_block_mask(_compile=False)`, which conflicts with the tracer -# assumptions made by the standard TorchCompileTesterMixin. The bidi transformer test file -# covers compile behavior; the FAR causal path is bit-exact-validated end-to-end on H200 -# through the pipeline replay rather than per-module compile. +class TestAnyFlowFARTransformer3DCompile(AnyFlowFARTransformer3DTesterConfig, TorchCompileTesterMixin): + """torch.compile tests for AnyFlow FAR Transformer 3D. + + Pre-builds the BlockMask via the standalone helper and injects it as ``attention_mask`` so the + transformer forward never calls ``flex_attention.create_block_mask(_compile=False)`` inside the + compiled scope. 
+ """ + + def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]: + inputs = super().get_dummy_inputs() + init_dict = self.get_init_dict() + inputs["attention_mask"] = _build_anyflow_far_causal_block_mask( + chunk_partition=inputs["chunk_partition"], + height=inputs["hidden_states"].shape[-2], + width=inputs["hidden_states"].shape[-1], + patch_size=init_dict["patch_size"], + compressed_patch_size=init_dict["compressed_patch_size"], + full_chunk_limit=init_dict["full_chunk_limit"], + mode="train", + has_clean_context=False, + device=torch_device, + ) + return inputs + + @pytest.mark.skip(reason="torch.export does not accept BlockMask as a pytree input.") + def test_compile_works_with_aot(self, tmp_path): + # BlockMask is a custom NamedTuple containing tensors plus a Python callable `mask_mod`, + # which `torch.export` cannot lift into a pytree. `torch.compile(fullgraph=True)` and + # `compile_repeated_blocks` both work; only AOT export is blocked. + super().test_compile_works_with_aot(tmp_path) class AnyFlowCausalAttnProcessorTest(unittest.TestCase): diff --git a/tests/pipelines/anyflow/test_anyflow_far.py b/tests/pipelines/anyflow/test_anyflow_far.py index 8086afef6d65..de244d563ec6 100644 --- a/tests/pipelines/anyflow/test_anyflow_far.py +++ b/tests/pipelines/anyflow/test_anyflow_far.py @@ -90,6 +90,7 @@ def get_dummy_components(self): rope_max_seq_len=32, gate_value=0.25, deltatime_type="r", + chunk_partition=(1, 1, 1), ) components = { @@ -106,8 +107,8 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - # num_frames=9 -> 3 latent frames (VAE temporal stride 4); use a matching - # chunk_partition so the FAR pipeline's pre-flight assertion passes. + # num_frames=9 -> 3 latent frames (VAE temporal stride 4); the transformer config above + # has chunk_partition=(1, 1, 1) (sum 3) baked in, so __call__ picks it up automatically. 
inputs = { "prompt": "dance monkey", "negative_prompt": "negative", @@ -119,7 +120,6 @@ def get_dummy_inputs(self, device, seed=0): "num_frames": 9, "max_sequence_length": 16, "output_type": "pt", - "chunk_partition": [1, 1, 1], } return inputs
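
Note on the `chunk_partition` bookkeeping touched by this diff: the sketch below restates, in plain Python, the frame-count formula from the `__call__` docstring, the sum check performed before sampling, the chunk-causal rule used by the simpler `mask_mod` (no clean context), and the rounding of the sequence length up to a multiple of 128. It is an illustration under stated assumptions, not the library code: the helper names (`latent_frame_count`, `check_chunk_partition`, `token_chunk_ids`, `can_attend`, `padded_seq_len`) are hypothetical, and it assumes every frame is full resolution (no compressed context chunks) and that no clean tokens are present.

```python
import math
from typing import List, Sequence


def latent_frame_count(num_frames: int, temporal_stride: int = 4) -> int:
    """Latent frames for a pixel-space frame count: (num_frames - 1) // stride + 1."""
    return (num_frames - 1) // temporal_stride + 1


def check_chunk_partition(chunk_partition: Sequence[int], num_frames: int, temporal_stride: int = 4) -> None:
    """Hypothetical helper: raise if the partition does not sum to the latent frame count."""
    expected = latent_frame_count(num_frames, temporal_stride)
    if sum(chunk_partition) != expected:
        raise ValueError(
            f"chunk_partition={list(chunk_partition)} sums to {sum(chunk_partition)}, "
            f"but num_frames={num_frames} gives {expected} latent frames"
        )


def token_chunk_ids(chunk_partition: Sequence[int], tokens_per_frame: int) -> List[int]:
    """Chunk index of every token, assuming all frames are full resolution (an assumption of this sketch)."""
    ids: List[int] = []
    for chunk_idx, chunk_len in enumerate(chunk_partition):
        ids.extend([chunk_idx] * (chunk_len * tokens_per_frame))
    return ids


def padded_seq_len(real_seq_len: int, block: int = 128) -> int:
    """Round the sequence length up to a multiple of the block size."""
    return int(math.ceil(real_seq_len / block) * block)


def can_attend(ids: Sequence[int], q_idx: int, kv_idx: int) -> bool:
    """Chunk-causal rule without clean context: no padding, and the query's chunk is not earlier than the key's."""
    real_seq_len = len(ids)
    if q_idx >= real_seq_len or kv_idx >= real_seq_len:
        return False  # padding positions never take part in attention
    return ids[q_idx] >= ids[kv_idx]


# Partition quoted in the diff for the released checkpoints: 81 frames, 21 latent frames.
released_partition = (1, 3, 3, 3, 3, 3, 3, 2)
check_chunk_partition(released_partition, num_frames=81)

# Tiny partition from the pipeline test: 9 frames, 3 latent frames.
check_chunk_partition((1, 1, 1), num_frames=9)

# Toy layout: 2 tokens per frame, chunks of 1 and 2 frames.
toy_ids = token_chunk_ids((1, 2), tokens_per_frame=2)
```

With the toy layout, a token in the second chunk may attend to the first chunk, while the reverse is masked out. Changing `num_frames` without supplying a matching partition makes `check_chunk_partition` raise, which mirrors the `ValueError` raised in `__call__`.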