diff --git a/docs/source/en/api/models/anyflow_far_transformer3d.md b/docs/source/en/api/models/anyflow_far_transformer3d.md
index 3a9909b4887a..7f818c44ef20 100644
--- a/docs/source/en/api/models/anyflow_far_transformer3d.md
+++ b/docs/source/en/api/models/anyflow_far_transformer3d.md
@@ -13,19 +13,22 @@ specific language governing permissions and limitations under the License.
# AnyFlowFARTransformer3DModel
The causal (FAR) 3D Transformer used by [`AnyFlowFARPipeline`](../pipelines/anyflow#anyflowfarpipeline) —
-the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS
-ShowLab × NVIDIA). It extends the v0.35.1 Wan2.1 backbone with three additions:
+the FAR variant of [AnyFlow](https://huggingface.co/papers/2605.13724). See the
+[`AnyFlowFARPipeline`](../pipelines/anyflow) page for paper, authors, and released checkpoints. It extends
+the v0.35.1 Wan2.1 backbone with three additions:
-1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting frame-level autoregressive
- generation as introduced in [FAR (Gu et al., 2025)](https://arxiv.org/abs/2503.19325).
+1. **FAR causal block-mask** via `torch.nn.attention.flex_attention`, supporting chunk-wise autoregressive
+ generation as introduced in [FAR](https://huggingface.co/papers/2503.19325).
2. **Compressed-frame patch embedding** (`far_patch_embedding`) for context (already-generated) frames,
warm-started from the full-resolution `patch_embedding` at construction time via trilinear interpolation.
3. **Dual-timestep flow-map embedding** (same as
[`AnyFlowTransformer3DModel`](anyflow_transformer3d)) — every forward call conditions on both the source
timestep ``t`` and the target timestep ``r``.
-The chunk schedule (`chunk_partition`) is **not** baked into the model config. It is a per-call argument to
-`forward`, so the same checkpoint handles different `num_frames` configurations without retraining.
+The default chunk schedule (`chunk_partition`) is stored in the model config; the released NVIDIA AnyFlow-FAR
+checkpoints use `[1, 3, 3, 3, 3, 3, 3, 2]` for the canonical 81-frame setting. `forward` accepts a per-call
+`chunk_partition` override, so the same checkpoint also handles other `num_frames` configurations without
+retraining.
```python
from diffusers import AnyFlowFARTransformer3DModel
diff --git a/docs/source/en/api/models/anyflow_transformer3d.md b/docs/source/en/api/models/anyflow_transformer3d.md
index 95888080c0ce..d37f7fba62fb 100644
--- a/docs/source/en/api/models/anyflow_transformer3d.md
+++ b/docs/source/en/api/models/anyflow_transformer3d.md
@@ -16,10 +16,11 @@ The bidirectional 3D Transformer used by [`AnyFlowPipeline`](../pipelines/anyflo
v0.35.1 Wan2.1 backbone with one structural change: the timestep embedder is replaced by
``AnyFlowDualTimestepTextImageEmbedding``, so every forward call conditions on both the source timestep
``t`` and the target timestep ``r``. This is the embedding required to learn the flow map
-:math:`\Phi_{r\leftarrow t}` introduced in
-[AnyFlow](https://huggingface.co/papers/2605.13724) (Yuchao Gu, Guian Fang et al., NUS ShowLab × NVIDIA).
+$\Phi_{r\leftarrow t}$ introduced in
+[AnyFlow](https://huggingface.co/papers/2605.13724). See the [`AnyFlowPipeline`](../pipelines/anyflow) page
+for paper, authors, and released checkpoints.
-For frame-level autoregressive (FAR causal) generation, use
+For chunk-wise autoregressive (FAR causal) generation, use
[`AnyFlowFARTransformer3DModel`](anyflow_far_transformer3d) instead.
```python
diff --git a/docs/source/en/api/pipelines/anyflow.md b/docs/source/en/api/pipelines/anyflow.md
index 9358b8d454fc..9e496a61113f 100644
--- a/docs/source/en/api/pipelines/anyflow.md
+++ b/docs/source/en/api/pipelines/anyflow.md
@@ -20,68 +20,28 @@ specific language governing permissions and limitations under the License.
# AnyFlow
-[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang and collaborators at [NUS ShowLab](https://sites.google.com/view/showlab) in collaboration with NVIDIA.
+[AnyFlow: Any-Step Video Diffusion Model with On-Policy Flow Map Distillation](https://huggingface.co/papers/2605.13724) from NVIDIA, National University of Singapore, and Massachusetts Institute of Technology, by Yuchao Gu, Guian Fang, Yuxin Jiang, Weijia Mao, Song Han, Han Cai, Mike Zheng Shou.
+
+> **TL;DR:** AnyFlow is the first any-step video diffusion framework built on flow maps, which enables a single model (bidirectional or causal) to adapt to arbitrary inference budgets.
*Few-step video generation has been significantly advanced by consistency models. However, their performance often degrades in any-step video diffusion models due to the fixed-point formulation. To address this limitation, we present AnyFlow, the first any-step video diffusion distillation framework built on flow maps. Instead of learning only the mapping z_t → z_0, AnyFlow learns transitions z_t → z_r over arbitrary time intervals, enabling a single model to adapt to different inference budgets. We design an improved forward flow map training recipe that fine-tunes pretrained video diffusion models into flow map models, and introduce Flow Map Backward Simulation to enable on-policy distillation for flow map models. Extensive experiments across both bidirectional and causal architectures, at scales ranging from 1.3B to 14B, on text-to-video and image-to-video tasks demonstrate that AnyFlow outperforms consistency-based baselines while preserving high fidelity and flexible sampling under varying step budgets.*
-The original training code is at [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow). The project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow).
+The AnyFlow pipelines were contributed by the AnyFlow Team. The original code is available on [GitHub](https://github.com/NVlabs/AnyFlow), the project page is at [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow), and pretrained models can be found in the [nvidia/anyflow](https://huggingface.co/collections/nvidia/anyflow) collection on Hugging Face.
-The following AnyFlow checkpoints are supported:
+Available Models:
| Checkpoint | Backbone | Description |
-|------------|----------|-------------|
-| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V, lightweight |
-| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V, full quality |
+|---|---|---|
+| [`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers) | Wan2.1 1.3B | Bidirectional T2V |
+| [`nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers) | Wan2.1 14B | Bidirectional T2V |
| [`nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers) | FAR + Wan2.1 1.3B | Causal T2V / I2V / V2V |
| [`nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers`](https://huggingface.co/nvidia/AnyFlow-FAR-Wan2.1-14B-Diffusers) | FAR + Wan2.1 14B | Causal T2V / I2V / V2V |
-All four are grouped under the [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection.
-
> [!TIP]
-> Choose `AnyFlowPipeline` for traditional bidirectional text-to-video generation. Choose `AnyFlowFARPipeline` for streaming I2V, video continuation (V2V), or any setup that benefits from frame-by-frame autoregressive sampling.
-
-> [!TIP]
-> AnyFlow supports any-step sampling: a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without retraining. Quality scales monotonically with steps in our benchmarks.
-
-### Optimizing Memory and Inference Speed
-
-
-
-
-```py
-import torch
-from diffusers import AnyFlowPipeline
-from diffusers.hooks import apply_group_offloading
-
-pipe = AnyFlowPipeline.from_pretrained(
- "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
-)
-apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-```
-
-
-
-
-```py
-import torch
-from diffusers import AnyFlowPipeline
-
-pipe = AnyFlowPipeline.from_pretrained(
- "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
-).to("cuda")
-pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
-```
-
-
-
+> `AnyFlowPipeline` is designed for bidirectional diffusion models in text-to-video (T2V) generation. `AnyFlowFARPipeline` is a chunk-wise causal diffusion model that supports text-to-video (T2V) generation, image-to-video (I2V) generation, and video continuation (V2V).
### Generation with AnyFlow (Bidirectional T2V)
-
-
-
```py
import torch
from diffusers import AnyFlowPipeline
@@ -91,14 +51,16 @@ pipe = AnyFlowPipeline.from_pretrained(
"nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
-prompt = "A red panda eating bamboo in a forest, cinematic lighting"
-video = pipe(prompt, num_inference_steps=4, num_frames=33).frames[0]
-export_to_video(video, "out.mp4", fps=16)
+prompt = (
+ "An astronaut runs smoothly and appears almost weightless on the lunar surface, "
+ "as seen from a low-angle shot that highlights the vast, desolate background of the moon. "
+ "The moon's craters and rocky terrain are clearly visible, creating a stark contrast against "
+ "the running astronaut who moves with graceful, fluid motions."
+)
+video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0]
+export_to_video(video, "anyflow_t2v.mp4", fps=16)
```
-
-
-
### Generation with AnyFlow (FAR Causal)
The causal pipeline selects between T2V / I2V / V2V via the ``video`` (or ``video_latents``) argument:
@@ -108,10 +70,10 @@ clip for V2V continuation. If you already have pre-encoded latents in the model
``video_latents=`` to skip VAE encoding. ``video`` and ``video_latents`` are mutually exclusive.
> [!IMPORTANT]
-> `AnyFlowFARPipeline.default_chunk_partition = [1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) is matched to the
-> released checkpoints' canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When
-> you change `num_frames`, you must also pass a matching `chunk_partition` summing to
-> `(num_frames - 1) // 4 + 1`, otherwise the pipeline raises an `AssertionError`.
+> The released checkpoints bake `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]` (sum 21) into the transformer
+> config, matched to the canonical 81 raw frames (21 latent frames at the VAE temporal stride of 4). When
+> you change `num_frames`, pass a matching `chunk_partition` summing to `(num_frames - 1) // 4 + 1`,
+> otherwise the pipeline raises a `ValueError`.
@@ -125,12 +87,12 @@ pipe = AnyFlowFARPipeline.from_pretrained(
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
-video = pipe(
- prompt="A cat surfing a wave, sunset",
- num_inference_steps=4,
- num_frames=81,
-).frames[0]
-export_to_video(video, "out.mp4", fps=16)
+prompt = (
+ "An astronaut runs smoothly and appears almost weightless on the lunar surface, "
+ "as seen from a low-angle shot that highlights the vast, desolate background of the moon."
+)
+video = pipe(prompt, num_inference_steps=4, num_frames=81).frames[0]
+export_to_video(video, "anyflow_far_t2v.mp4", fps=16)
```
@@ -146,18 +108,25 @@ pipe = AnyFlowFARPipeline.from_pretrained(
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
-# Wrap the conditioning image as a one-frame video tensor: (1, 1, 3, H, W) in [0, 1].
-first_frame = load_image("path/to/first_frame.png").resize((832, 480))
+# Example conditioning image from the AnyFlow repo.
+first_frame = load_image(
+ "https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/images/1.jpg"
+).resize((832, 480))
arr = np.asarray(first_frame).astype("float32") / 255.0 # (480, 832, 3)
-context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda")
+context_tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda") # (1, 1, 3, 480, 832)
+prompt = (
+ "A towering, battle-scarred humanoid robot, reminiscent of a Transformer with powerful, segmented armor "
+ "and glowing red optics, walking through the skeletal remains of a city ruin. Twisted metal and shattered "
+ "concrete crunch under its heavy steps, as the robot scans the desolate, dust-choked skyline under an dark sky."
+)
video = pipe(
- prompt="a cat walks across a sunlit lawn",
+ prompt=prompt,
video=context_tensor,
num_inference_steps=4,
num_frames=81,
).frames[0]
-export_to_video(video, "out.mp4", fps=16)
+export_to_video(video, "anyflow_far_i2v.mp4", fps=16)
```
@@ -173,21 +142,26 @@ pipe = AnyFlowFARPipeline.from_pretrained(
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
-# Context clip — 9 raw frames map to 3 latent frames (9 = 4·2 + 1, 3 = 2 + 1).
-context_frames = load_video("path/to/context.mp4")[:9]
+# Example conditioning clip from the AnyFlow repo — take the first 9 frames (3 latent frames at VAE temporal stride 4).
+context_frames = load_video(
+ "https://raw.githubusercontent.com/NVlabs/AnyFlow/main/assets/evaluation/example/videos/2.mp4"
+)[:9]
arr = np.stack([np.asarray(f.resize((832, 480))) for f in context_frames]).astype("float32") / 255.0
-# np.stack gives (T, H, W, C) = (9, 480, 832, 3) → permute to (T, C, H, W) then add batch.
context_tensor = torch.from_numpy(arr).permute(0, 3, 1, 2).unsqueeze(0).to("cuda") # (1, 9, 3, 480, 832)
+prompt = (
+ "A focused trail runner's powerful strides through a dense, sun-dappled forest. "
+ "The camera tracks alongside, highlighting muscular exertion, sweat, and determined facial expression."
+)
video = pipe(
- prompt="continue the story",
+ prompt=prompt,
video=context_tensor,
num_inference_steps=4,
num_frames=81,
# Override chunk_partition so the first chunk covers exactly the 3 latent context frames.
chunk_partition=[3, 3, 3, 3, 3, 3, 3],
).frames[0]
-export_to_video(video, "out.mp4", fps=16)
+export_to_video(video, "anyflow_far_v2v.mp4", fps=16)
```
diff --git a/docs/source/zh/using-diffusers/anyflow.md b/docs/source/zh/using-diffusers/anyflow.md
index 575cdb1c1cb8..e9c925a85256 100644
--- a/docs/source/zh/using-diffusers/anyflow.md
+++ b/docs/source/zh/using-diffusers/anyflow.md
@@ -22,7 +22,7 @@ NFE 增加反而经常掉点。
采样步之间的 re-noising;on-policy 蒸馏阶段额外用 **DMD 反向散度监督** + **Flow-Map backward simulation**
(3 段 shortcut)补上 consistency 蒸馏遗留的 exposure-bias 缺口。
-AnyFlow 由 Yuchao Gu、Guian Fang 等人在 [NUS ShowLab](https://sites.google.com/view/showlab) 与 NVIDIA 合作完成。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow)。4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。
+AnyFlow 由 NVIDIA、新加坡国立大学(NUS)和 MIT 合作完成,作者为 Yuchao Gu、Guian Fang、Yuxin Jiang、Weijia Mao、Song Han、Han Cai、Mike Zheng Shou。原始训练代码在 [`NVlabs/AnyFlow`](https://github.com/NVlabs/AnyFlow),项目主页是 [nvlabs.github.io/AnyFlow](https://nvlabs.github.io/AnyFlow),4 个发布 checkpoint 归在 [`nvidia/anyflow`](https://huggingface.co/collections/nvidia/anyflow) Hugging Face collection 里。
本文档梳理实战要点:怎么选 pipeline、怎么用 any-step 采样、怎么把 AnyFlow 嵌进 T2V / I2V / V2V 工作流。
@@ -100,7 +100,7 @@ prompt = "森林里一只小熊猫在啃竹子,电影感光照"
for nfe in [1, 2, 4, 8, 16, 32]:
# 每轮重建 generator —— 这样跨步数对比时唯一变量是 NFE。
generator = torch.Generator("cuda").manual_seed(0)
- video = pipe(prompt, num_inference_steps=nfe, num_frames=33, generator=generator).frames[0]
+ video = pipe(prompt, num_inference_steps=nfe, num_frames=81, generator=generator).frames[0]
export_to_video(video, f"out_nfe{nfe}.mp4", fps=16)
```
@@ -125,11 +125,11 @@ Causal pipeline 用同一个蒸馏模型支持三种任务模式,**通过 `vid
Context tensor 的帧数必须满足 `T = 4n + 1`,跟 VAE 时间步长对齐。
> [!IMPORTANT]
-> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。默认
-> `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21)对应发布 checkpoint 的标准 `num_frames=81`
-> (21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`,使其求和等于
-> `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `AssertionError`。比如 `num_frames=33` 对应 9 个 latent
-> 帧,可用 `chunk_partition=[1, 4, 4]`。
+> FAR pipeline 是分块 (chunk) rollout,`num_frames` 必须配合 chunk 调度。发布的 checkpoint 在
+> transformer config 里写入 `chunk_partition=[1, 3, 3, 3, 3, 3, 3, 2]`(求和 21),对应标准
+> `num_frames=81`(21 = (81 − 1) // 4 + 1)。改 `num_frames` 时**必须**显式传匹配的 `chunk_partition`,
+> 使其求和等于 `(num_frames - 1) // 4 + 1`,否则 pipeline 会抛 `ValueError`。比如 `num_frames=33` 对应
+> 9 个 latent 帧,可用 `chunk_partition=[1, 4, 4]`。
```py
import numpy as np
@@ -183,33 +183,6 @@ export_to_video(video, "v2v.mp4", fps=16)
如果你已经有 VAE 编码过的 latent,可以直接传 `video_latents=` 跳过 `vae_encode` 步骤
(和 `video` 互斥)。
-## 显存与推理速度
-
-14B 的 AnyFlow 模型用 group offload + VAE slicing 单卡 40 GB 能跑:
-
-```py
-import torch
-from diffusers import AnyFlowPipeline
-from diffusers.hooks import apply_group_offloading
-
-pipe = AnyFlowPipeline.from_pretrained(
- "nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
-)
-apply_group_offloading(pipe.transformer, onload_device="cuda", offload_type="leaf_level")
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-```
-
-延迟方面,`torch.compile` 对 transformer(最重的模块)效果很好:
-
-```py
-pipe = pipe.to("cuda")
-pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
-```
-
-编译开销跑几步就摊销掉;配合 AnyFlow 的低 NFE(4-8 步),`torch.compile` 在 14B 上相比 eager
-模式有明显加速。
-
## LoRA 微调
两个 pipeline 都复用 [`WanLoraLoaderMixin`](../api/loaders/lora),因此为对应 Wan2.1 backbone 训练的
diff --git a/scripts/convert_anyflow_to_diffusers.py b/scripts/convert_anyflow_to_diffusers.py
index 60574ca23a1e..229d286c4701 100644
--- a/scripts/convert_anyflow_to_diffusers.py
+++ b/scripts/convert_anyflow_to_diffusers.py
@@ -57,13 +57,21 @@
"AnyFlow-FAR-Wan2.1-1.3B-Diffusers": {
"base_model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"transformer_cls": AnyFlowFARTransformer3DModel,
- "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
+ "transformer_kwargs": {
+ "full_chunk_limit": 3,
+ "compressed_patch_size": [1, 4, 4],
+ "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2],
+ },
"pipeline_cls": AnyFlowFARPipeline,
},
"AnyFlow-FAR-Wan2.1-14B-Diffusers": {
"base_model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"transformer_cls": AnyFlowFARTransformer3DModel,
- "transformer_kwargs": {"full_chunk_limit": 3, "compressed_patch_size": [1, 4, 4]},
+ "transformer_kwargs": {
+ "full_chunk_limit": 3,
+ "compressed_patch_size": [1, 4, 4],
+ "chunk_partition": [1, 3, 3, 3, 3, 3, 3, 2],
+ },
"pipeline_cls": AnyFlowFARPipeline,
},
"AnyFlow-Wan2.1-T2V-1.3B-Diffusers": {
diff --git a/src/diffusers/models/transformers/transformer_anyflow.py b/src/diffusers/models/transformers/transformer_anyflow.py
index 2ac554419e5e..873e0b095b33 100644
--- a/src/diffusers/models/transformers/transformer_anyflow.py
+++ b/src/diffusers/models/transformers/transformer_anyflow.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-# This file derives from the FAR architecture (Gu et al., 2025, arXiv:2503.19325) and adds the
-# AnyFlow dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding) introduced by
-# Yuchao Gu, Guian Fang et al. (arXiv:2605.13724). The base 3D DiT structure is adapted from the
+# This file derives from the FAR architecture (arXiv:2503.19325) and adds the
+# AnyFlow dual-timestep flow-map embedding (AnyFlowDualTimestepTextImageEmbedding) introduced in
+# AnyFlow (arXiv:2605.13724). The base 3D DiT structure is adapted from the
# v0.35.1 Wan2.1 transformer (transformer_wan.py); upstream Wan has since been refactored, so
# this file is intentionally self-contained rather than annotated with `# Copied from`.
@@ -334,8 +334,11 @@ def __init__(
self._freqs_cache: Optional[Tuple[Any, torch.Tensor]] = None
def _build_freqs(self, device: torch.device) -> torch.Tensor:
+ # Skip the cache read/write inside torch.compile: mutating ``self._freqs_cache`` between calls
+ # becomes a Dynamo guard and forces recompilation on the second invocation.
+ is_compiling = torch.compiler.is_compiling()
cache_key = (device.type, str(device))
- if self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
+ if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
return self._freqs_cache[1]
is_mps = device.type == "mps"
@@ -357,7 +360,8 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
)
freqs_list.append(f.to(device))
freqs = torch.cat(freqs_list, dim=1)
- self._freqs_cache = (cache_key, freqs)
+ if not is_compiling:
+ self._freqs_cache = (cache_key, freqs)
return freqs
def _forward_full_frame(self, num_frames, height, width, device) -> torch.Tensor:
@@ -510,10 +514,9 @@ class AnyFlowTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
The architecture is the v0.35.1 Wan2.1 3D DiT backbone with one structural change: the timestep embedder is
replaced by ``AnyFlowDualTimestepTextImageEmbedding`` so that every forward call conditions on both the source
timestep ``t`` and the target timestep ``r``. This is the embedding required to learn the flow map
- :math:`\Phi_{r\leftarrow t}` introduced in [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian
- Fang et al.
+ :math:`\Phi_{r\leftarrow t}` introduced in [AnyFlow](https://huggingface.co/papers/2605.13724).
- For frame-level autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead; that variant
+ For chunk-wise autoregressive (FAR causal) generation, use ``AnyFlowFARTransformer3DModel`` instead; that variant
adds the FAR causal block-mask and a compressed-frame patch embedding on top of the same backbone.
Args:
diff --git a/src/diffusers/models/transformers/transformer_anyflow_far.py b/src/diffusers/models/transformers/transformer_anyflow_far.py
index a40e2fafcb61..4b418f129636 100644
--- a/src/diffusers/models/transformers/transformer_anyflow_far.py
+++ b/src/diffusers/models/transformers/transformer_anyflow_far.py
@@ -14,9 +14,9 @@
#
# This file is the FAR causal sibling of `transformer_anyflow.py`. Shared submodules are duplicated
# via `# Copied from` so `make fix-copies` keeps both files in sync; this keeps each transformer
-# variant readable in isolation. The FAR architecture comes from Gu et al., 2025
+# variant readable in isolation. The FAR architecture comes from FAR
# (arXiv:2503.19325); the dual-timestep flow-map embedding is AnyFlow's contribution
-# (Yuchao Gu, Guian Fang et al., arXiv:2605.13724).
+# (arXiv:2605.13724).
import math
from dataclasses import dataclass
@@ -25,7 +25,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -77,7 +77,7 @@ class AnyFlowCausalAttnProcessor:
autoregressive read (cache-read step).
Requires the ``flex`` attention backend — the ``BlockMask`` produced by
- :class:`AnyFlowFARTransformer3DModel._build_causal_mask` is consumed only by the flex backend. A clear
+ :meth:`AnyFlowFARTransformer3DModel.build_attention_mask` is consumed only by the flex backend. A clear
:class:`ValueError` is raised if a non-flex backend is configured via ``_attention_backend``.
"""
@@ -643,8 +643,11 @@ def __init__(
# Copied from diffusers.models.transformers.transformer_anyflow.AnyFlowRotaryPosEmbed._build_freqs
def _build_freqs(self, device: torch.device) -> torch.Tensor:
+ # Skip the cache read/write inside torch.compile: mutating ``self._freqs_cache`` between calls
+ # becomes a Dynamo guard and forces recompilation on the second invocation.
+ is_compiling = torch.compiler.is_compiling()
cache_key = (device.type, str(device))
- if self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
+ if not is_compiling and self._freqs_cache is not None and self._freqs_cache[0] == cache_key:
return self._freqs_cache[1]
is_mps = device.type == "mps"
@@ -666,7 +669,8 @@ def _build_freqs(self, device: torch.device) -> torch.Tensor:
)
freqs_list.append(f.to(device))
freqs = torch.cat(freqs_list, dim=1)
- self._freqs_cache = (cache_key, freqs)
+ if not is_compiling:
+ self._freqs_cache = (cache_key, freqs)
return freqs
def avg_pool_complex(self, freq: torch.Tensor, kernel_size: int, stride: int):
@@ -774,14 +778,191 @@ def forward(self, far_cfg, device, clean_hidden_states=None):
return {"query": freqs, "key": freqs}
+def _build_anyflow_far_causal_block_mask(
+ chunk_partition: List[int],
+ height: int,
+ width: int,
+ patch_size: Tuple[int, int, int],
+ compressed_patch_size: Tuple[int, int, int],
+ full_chunk_limit: int,
+ *,
+ mode: str = "train",
+ has_clean_context: bool = False,
+ device: Optional[torch.device] = None,
+) -> BlockMask:
+ r"""Build the causal :class:`~torch.nn.attention.flex_attention.BlockMask` for the FAR transformer.
+
+ Provided as a standalone function so callers can construct the mask *outside* the transformer's compiled region,
+ which is required to wrap the forward in ``torch.compile(fullgraph=True)`` (``flex_attention.create_block_mask``
+ itself uses ``_compile=False`` internally and breaks the graph when invoked inside the compiled scope).
+
+ Two modes are exposed, mirroring the FAR forward paths that actually consume a mask. The autoregressive
+ ``_forward_inference`` path attends through the KV cache and does not use a full BlockMask, so it has no
+ corresponding mode here.
+
+ Args:
+ chunk_partition: per-chunk frame counts; must sum to the number of latent frames.
+ height, width: latent spatial dimensions.
+ patch_size, compressed_patch_size, full_chunk_limit: must match the transformer config.
+ mode: ``"train"`` (strict ``>`` comparison against ``full_chunk_limit``, matches
+ :meth:`AnyFlowFARTransformer3DModel._forward_train`) or ``"cache"`` (``>=`` comparison via the
+ ``full_chunk_limit - 1`` offset used by :meth:`AnyFlowFARTransformer3DModel._forward_cache`).
+ has_clean_context: ``True`` when ``clean_hidden_states`` is being threaded through the
+ transformer (training V2V/I2V).
+ device: device for the resulting BlockMask. Defaults to CPU.
+ """
+ if mode not in {"train", "cache"}:
+ raise ValueError(f"Unknown mode {mode!r}; expected 'train' or 'cache'.")
+ full_token_per_frame = (height // patch_size[1]) * (width // patch_size[2])
+ compressed_token_per_frame = (height // compressed_patch_size[1]) * (width // compressed_patch_size[2])
+
+ # `cache` uses `full_chunk_limit - 1` (an effective `>= full_chunk_limit` comparison); `train` uses a strict `>`.
+ total_chunks = len(chunk_partition)
+ threshold = full_chunk_limit - 1 if mode == "cache" else full_chunk_limit
+ if total_chunks > threshold:
+ num_full_chunk = threshold
+ num_compressed_chunk = total_chunks - threshold
+ else:
+ num_full_chunk, num_compressed_chunk = total_chunks, 0
+
+ far_cfg = {
+ "num_full_chunk": num_full_chunk,
+ "num_compressed_chunk": num_compressed_chunk,
+ "num_full_frames": sum(chunk_partition[num_compressed_chunk:]),
+ "num_compressed_frames": sum(chunk_partition[:num_compressed_chunk]),
+ "full_token_per_frame": full_token_per_frame,
+ "compressed_token_per_frame": compressed_token_per_frame,
+ "chunk_partition": chunk_partition,
+ }
+ return _build_far_block_mask_from_far_cfg(far_cfg, has_clean=has_clean_context, device=device)
+
+
+def _build_far_block_mask_from_far_cfg(far_cfg, has_clean, device):
+ """Internal: build a BlockMask given an already-computed ``far_cfg`` dict.
+
+ Factored out of :class:`AnyFlowFARTransformer3DModel` so it can be shared between
+ :func:`_build_anyflow_far_causal_block_mask` (the user-facing entry point) and the in-forward fallback path used
+ when no pre-built ``attention_mask`` is passed.
+ """
+ chunk_partition = far_cfg["chunk_partition"]
+
+ noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]
+ context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"]
+
+ noise_start = context_seq_len
+ noise_end = noise_start + noise_seq_len
+
+ clean_start = context_seq_len + noise_seq_len
+ clean_end = clean_start + clean_seq_len
+
+ if has_clean:
+ real_seq_len = context_seq_len + noise_seq_len + clean_seq_len
+ else:
+ real_seq_len = context_seq_len + noise_seq_len
+
+ padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0)
+
+ context_chunk_partition, noise_chunk_partition = (
+ chunk_partition[: far_cfg["num_compressed_chunk"]],
+ chunk_partition[far_cfg["num_compressed_chunk"] :],
+ )
+
+ if len(context_chunk_partition) != 0:
+ context_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx
+ for chunk_idx, chunk_len in enumerate(context_chunk_partition)
+ ]
+ )
+ else:
+ context_frame_idx = None
+
+ if has_clean:
+ noise_frame_idx = clean_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
+ * (chunk_idx + len(context_chunk_partition))
+ for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
+ ]
+ )
+ pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
+
+ if len(context_chunk_partition) != 0:
+ frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
+ else:
+ frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ # 1) is padding
+ is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
+
+ # 2) chunk causal
+ base = frame_idx[q_idx] >= frame_idx[kv_idx]
+
+ # 3) interval mask
+ q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end)
+ q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end)
+
+ k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end)
+ k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end)
+
+ # 4) clean -> noise: disallowed
+ is_clean_to_noise = q_is_clean & k_is_noise
+
+ # 5) noise -> noise: only same frame
+ same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx]
+
+ noise_to_noise = q_is_noise & k_is_noise
+ noise_to_clean = q_is_noise & k_is_clean
+
+ noise_to_noise_allow = noise_to_noise & same_frame_idx
+ noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow
+
+ noise_to_clean_same = noise_to_clean & same_frame_idx
+ noise_to_clean_disallow = noise_to_clean_same
+
+ allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow
+ return allowed
+
+ else:
+ noise_frame_idx = torch.cat(
+ [
+ torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
+ * (chunk_idx + len(context_chunk_partition))
+ for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
+ ]
+ )
+ pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
+
+ if len(context_chunk_partition) != 0:
+ frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0)
+ else:
+ frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0)
+
+ def mask_mod(b, h, q_idx, kv_idx):
+ is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
+ base = frame_idx[q_idx] >= frame_idx[kv_idx]
+ return base & ~is_padding
+
+ return create_block_mask(
+ mask_mod,
+ B=None,
+ H=None,
+ Q_LEN=padded_seq_len,
+ KV_LEN=padded_seq_len,
+ device=device,
+ _compile=False,
+ )
+
+
class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
- Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with frame-level autoregressive generation.
+ Causal (FAR) 3D Transformer for AnyFlow flow-map sampling with chunk-wise autoregressive generation.
Extends the v0.35.1 Wan2.1 backbone with:
- * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting frame-level autoregressive
- generation (FAR; [Gu et al., 2025](https://arxiv.org/abs/2503.19325)).
+ * **FAR causal block-mask** via :func:`torch.nn.attention.flex_attention`, supporting chunk-wise autoregressive
+ generation ([FAR](https://huggingface.co/papers/2503.19325)).
* **Compressed-frame patch embedding** ``far_patch_embedding`` for context (already-generated) frames, initialized
from ``patch_embedding`` via trilinear interpolation so a freshly constructed model is already at a reasonable
starting point even before LoRA fine-tuning.
@@ -826,11 +1007,11 @@ class AnyFlowFARTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
Mixing gate between source-timestep and delta-timestep embeddings.
deltatime_type (`str`, defaults to `'r'`):
Either ``"r"`` (delta is the target timestep) or ``"t-r"`` (delta is the absolute interval).
-
- .. note::
- ``chunk_partition`` is **not** a model config field — it is a per-call argument passed to :meth:`forward`.
- Different inference setups (varying ``num_frames`` or full-vs-compressed schedules) therefore do not require
- separate checkpoints.
+ chunk_partition (`Tuple[int, ...]`, defaults to `(1, 3, 3, 3, 3, 3, 3, 2)`):
+ Default per-chunk frame counts used by the pipeline. The released NVIDIA AnyFlow-FAR checkpoints target
+ ``num_frames=81`` (21 latent frames at VAE temporal stride 4) split as ``1 + 3*6 + 2``. A different
+ ``num_frames`` requires a matching ``chunk_partition`` override passed to
+ :meth:`AnyFlowFARPipeline.__call__` (and likewise to :meth:`forward`).
"""
_supports_gradient_checkpointing = True
@@ -859,6 +1040,7 @@ def __init__(
rope_max_seq_len: int = 1024,
gate_value: float = 0.25,
deltatime_type: str = "r",
+ chunk_partition: Tuple[int, ...] = (1, 3, 3, 3, 3, 3, 3, 2),
) -> None:
super().__init__()
@@ -923,6 +1105,7 @@ def forward(
clean_timestep: Optional[torch.Tensor] = None,
kv_cache: Optional[List[Dict[str, torch.Tensor]]] = None,
kv_cache_flag: Optional[Dict[str, Any]] = None,
+ attention_mask: Optional[BlockMask] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Transformer2DModelOutput, AnyFlowFARTransformerOutput, Tuple]:
@@ -955,6 +1138,12 @@ def forward(
Per-block KV cache for autoregressive inference. `None` selects the training path.
kv_cache_flag (`Dict[str, Any]`, *optional*):
KV-cache metadata (e.g. ``is_cache_step`` flag and token counts).
+ attention_mask (`BlockMask`, *optional*):
+ Pre-built causal mask, typically constructed via :meth:`build_attention_mask`. Consumed by the train
+ and KV-cache prefill paths; the autoregressive inference path attends through the KV cache and does not
+ use a full mask. When ``None``, the train / cache paths build the mask internally; that fallback is not
+ compile-safe (the underlying ``flex_attention.create_block_mask`` breaks the graph under
+ ``fullgraph=True``), so pass a pre-built mask whenever wrapping ``forward`` in ``torch.compile``.
attention_kwargs (`dict`, *optional*):
Forwarded to the attention processors.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -978,12 +1167,14 @@ def forward(
return self._forward_cache(
clean_hidden_states=clean_hidden_states,
clean_timestep=clean_timestep,
+ attention_mask=attention_mask,
**common,
)
return self._forward_inference(**common)
return self._forward_train(
clean_hidden_states=clean_hidden_states,
clean_timestep=clean_timestep,
+ attention_mask=attention_mask,
**common,
)
@@ -1029,142 +1220,50 @@ def _forward_far_patchify_inference(self, hidden_states):
hidden_states = self.patch_embedding(hidden_states).flatten(start_dim=2, end_dim=4).transpose(1, 2)
return hidden_states
- def _build_causal_mask(self, far_cfg, clean_hidden_states, device, dtype):
- chunk_partition = far_cfg["chunk_partition"]
-
- noise_seq_len = clean_seq_len = far_cfg["num_full_frames"] * far_cfg["full_token_per_frame"]
- context_seq_len = far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"]
-
- noise_start = context_seq_len
- noise_end = noise_start + noise_seq_len
-
- clean_start = context_seq_len + noise_seq_len
- clean_end = clean_start + clean_seq_len
-
- if clean_hidden_states is not None:
- real_seq_len = context_seq_len + noise_seq_len + clean_seq_len
- else:
- real_seq_len = context_seq_len + noise_seq_len
-
- padded_seq_len = int(math.ceil(real_seq_len / 128.0) * 128.0)
-
- if clean_hidden_states is not None:
- context_chunk_partition, noise_chunk_partition = (
- chunk_partition[: far_cfg["num_compressed_chunk"]],
- chunk_partition[far_cfg["num_compressed_chunk"] :],
- ) # noqa: E501
-
- if len(context_chunk_partition) != 0:
- context_frame_idx = torch.cat(
- [
- torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx
- for chunk_idx, chunk_len in enumerate(context_chunk_partition)
- ]
- ) # noqa: E501
- else:
- context_frame_idx = None
- noise_frame_idx = clean_frame_idx = torch.cat(
- [
- torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
- * (chunk_idx + len(context_chunk_partition))
- for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
- ]
- ) # noqa: E501
- pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
-
- if len(context_chunk_partition) != 0:
- frame_idx = torch.cat([context_frame_idx, noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
- else:
- frame_idx = torch.cat([noise_frame_idx, clean_frame_idx, pad_frame_idx], dim=0)
-
- def mask_mod(b, h, q_idx, kv_idx):
- # q_idx, kv_idx: LongTensor, range: [0, padded_seq_len)
-
- # 1) whether is padding
- is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
-
- # 2) chunk causal
- base = frame_idx[q_idx] >= frame_idx[kv_idx]
-
- # 3) interval mask
- q_is_noise = (q_idx >= noise_start) & (q_idx < noise_end)
- q_is_clean = (q_idx >= clean_start) & (q_idx < clean_end)
-
- k_is_noise = (kv_idx >= noise_start) & (kv_idx < noise_end)
- k_is_clean = (kv_idx >= clean_start) & (kv_idx < clean_end)
-
- # 4) clean -> noise: disallowed
- is_clean_to_noise = q_is_clean & k_is_noise
-
- # 5) noise -> noise: only same frame
- same_frame_idx = frame_idx[q_idx] == frame_idx[kv_idx]
-
- noise_to_noise = q_is_noise & k_is_noise
- noise_to_clean = q_is_noise & k_is_clean
-
- noise_to_noise_allow = noise_to_noise & same_frame_idx
- noise_to_noise_mask = (~noise_to_noise) | noise_to_noise_allow
-
- noise_to_clean_same = noise_to_clean & same_frame_idx
- noise_to_clean_disallow = noise_to_clean_same
-
- # attention mask is chunk casual
- allowed = base & ~is_padding & ~is_clean_to_noise & noise_to_noise_mask & ~noise_to_clean_disallow
- return allowed
-
- return create_block_mask(
- mask_mod,
- B=None,
- H=None,
- Q_LEN=padded_seq_len,
- KV_LEN=padded_seq_len,
- device=device,
- _compile=False,
- )
- else:
- context_chunk_partition, noise_chunk_partition = (
- chunk_partition[: far_cfg["num_compressed_chunk"]],
- chunk_partition[far_cfg["num_compressed_chunk"] :],
- ) # noqa: E501
-
- if len(context_chunk_partition) != 0:
- context_frame_idx = torch.cat(
- [
- torch.ones(chunk_len * far_cfg["compressed_token_per_frame"], device=device) * chunk_idx
- for chunk_idx, chunk_len in enumerate(context_chunk_partition)
- ]
- ) # noqa: E501
- else:
- context_frame_idx = None
-
- noise_frame_idx = torch.cat(
- [
- torch.ones(chunk_len * far_cfg["full_token_per_frame"], device=device)
- * (chunk_idx + len(context_chunk_partition))
- for chunk_idx, chunk_len in enumerate(noise_chunk_partition)
- ]
- ) # noqa: E501
- pad_frame_idx = torch.zeros(padded_seq_len - real_seq_len, device=device)
+ def build_attention_mask(
+ self,
+ *,
+ chunk_partition: List[int],
+ height: int,
+ width: int,
+ has_clean_context: bool = False,
+ device: Optional[torch.device] = None,
+ mode: str = "train",
+ ) -> BlockMask:
+ r"""Pre-build the causal :class:`~torch.nn.attention.flex_attention.BlockMask` outside ``forward``.
+
+ Pass the result via :meth:`forward`'s ``attention_mask`` kwarg to make the whole transformer compatible with
+ ``torch.compile(fullgraph=True)``. Without a pre-built mask, ``forward`` falls back to constructing it
+ internally — that path uses ``flex_attention.create_block_mask(_compile=False)`` and breaks the compile graph.
- if len(context_chunk_partition) != 0:
- frame_idx = torch.cat([context_frame_idx, noise_frame_idx, pad_frame_idx], dim=0)
- else:
- frame_idx = torch.cat([noise_frame_idx, pad_frame_idx], dim=0)
-
- def mask_mod(b, h, q_idx, kv_idx):
- is_padding = (q_idx >= real_seq_len) | (kv_idx >= real_seq_len)
- base = frame_idx[q_idx] >= frame_idx[kv_idx]
- return base & ~is_padding
-
- return create_block_mask(
- mask_mod,
- B=None,
- H=None,
- Q_LEN=padded_seq_len,
- KV_LEN=padded_seq_len,
- device=device,
- _compile=False,
- )
+ Args:
+ chunk_partition: per-chunk frame counts (must sum to the number of latent frames).
+ height, width: latent spatial dimensions.
+ has_clean_context: ``True`` when ``clean_hidden_states`` will be threaded through :meth:`forward`
+ (training V2V/I2V); only this presence flag affects the mask layout.
+ device: device for the resulting :class:`BlockMask`. The mask is not auto-moved by
+ ``device_map="auto"``; build it on the same device the transformer's inputs will live on.
+ mode: ``"train"`` (matches :meth:`_forward_train`) or ``"cache"`` (matches :meth:`_forward_cache`).
+ The autoregressive ``_forward_inference`` path attends through the KV cache and has no mode here.
+
+ Returns:
+ :class:`~torch.nn.attention.flex_attention.BlockMask`: causal mask spanning the FAR layout, padded to a
+ multiple of 128 along the sequence dimension (the BlockMask block-size requirement).
+
+ Raises:
+ ValueError: if ``mode`` is neither ``"train"`` nor ``"cache"``.
+ """
+ return _build_anyflow_far_causal_block_mask(
+ chunk_partition=chunk_partition,
+ height=height,
+ width=width,
+ patch_size=self.config.patch_size,
+ compressed_patch_size=self.config.compressed_patch_size,
+ full_chunk_limit=self.config.full_chunk_limit,
+ mode=mode,
+ has_clean_context=has_clean_context,
+ device=device,
+ )
def _forward_inference(
self,
@@ -1291,6 +1390,7 @@ def _forward_cache(
r_timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ attention_mask: Optional[BlockMask] = None,
return_dict: bool = True,
clean_hidden_states=None,
clean_timestep=None,
@@ -1337,9 +1437,10 @@ def _forward_cache(
far_cfg["num_compressed_frames"] * far_cfg["compressed_token_per_frame"]
)
- attention_mask = self._build_causal_mask(
- far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype
- )
+ if attention_mask is None:
+ attention_mask = _build_far_block_mask_from_far_cfg(
+ far_cfg, has_clean=clean_hidden_states is not None, device=hidden_states.device
+ )
rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device)
hidden_states = self._forward_far_patchify(
@@ -1396,6 +1497,7 @@ def _forward_train(
r_timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ attention_mask: Optional[BlockMask] = None,
return_dict: bool = True,
clean_hidden_states=None,
clean_timestep=None,
@@ -1436,9 +1538,12 @@ def _forward_train(
"chunk_partition": chunk_partition,
}
- attention_mask = self._build_causal_mask(
- far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device, dtype=hidden_states.dtype
- )
+ if attention_mask is None:
+ # Fallback for callers that don't pre-build (e.g. training scripts). Not compile-safe;
+ # use :meth:`build_attention_mask` upstream when wrapping `forward` in `torch.compile`.
+ attention_mask = _build_far_block_mask_from_far_cfg(
+ far_cfg, has_clean=clean_hidden_states is not None, device=hidden_states.device
+ )
rotary_emb = self.rope(far_cfg=far_cfg, clean_hidden_states=clean_hidden_states, device=hidden_states.device)
diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py
index 0eb60b525a0f..c3e1dbf3a459 100644
--- a/src/diffusers/pipelines/anyflow/pipeline_anyflow.py
+++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow.py
@@ -80,11 +80,11 @@ def prompt_clean(text):
class AnyFlowPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Bidirectional text-to-video generation pipeline for AnyFlow flow-map-distilled checkpoints, introduced in
- [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang et al.
+ [AnyFlow](https://huggingface.co/papers/2605.13724).
AnyFlow learns arbitrary-interval transitions :math:`z_t \to z_r` rather than the fixed :math:`z_t \to z_0` mapping
of consistency models, so a single distilled checkpoint can be evaluated at 1, 2, 4, 8, 16... NFE without
- retraining. This pipeline operates over the full video tensor in one bidirectional pass; for frame-level
+ retraining. This pipeline operates over the full video tensor in one bidirectional pass; for chunk-wise
autoregressive (causal) generation use ``AnyFlowFARPipeline``.
Sampling is plain Euler in mean-velocity form (``z_r = z_t - (t - r) * u``) with no re-noising. The released NVIDIA
diff --git a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py
index e73c44b2fde3..e33f8b7c3873 100644
--- a/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py
+++ b/src/diffusers/pipelines/anyflow/pipeline_anyflow_far.py
@@ -92,11 +92,10 @@ def prompt_clean(text):
class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Causal (FAR-based) text-to-video / image-to-video / video-to-video pipeline for AnyFlow checkpoints, introduced in
- [AnyFlow](https://huggingface.co/papers/2605.13724) by Yuchao Gu, Guian Fang et al.
+ [AnyFlow](https://huggingface.co/papers/2605.13724).
- The pipeline drives a frame-level autoregressive sampling loop over chunks: each chunk is denoised with flow-map
- steps while attending only to past chunks via block-sparse causal attention, and intermediate KV cache is reused
- across chunks.
+ The pipeline drives a chunk-wise autoregressive sampling loop: each chunk is denoised with flow-map steps while
+ attending only to past chunks via block-sparse causal attention, and intermediate KV cache is reused across chunks.
The task mode (T2V / I2V / V2V) is selected by which conditioning argument is passed to ``__call__``:
@@ -106,9 +105,9 @@ class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin):
- ``video_latents=`` — already-encoded latents in the
FAR layout (skips the VAE encode step).
- The FAR backbone is the causal Wan2.1 variant introduced by FAR (Gu et al., 2025; arXiv:2503.19325). Inference is
- plain Euler in mean-velocity form per chunk with no re-noising. Joint T2V / I2V / V2V is supported by a single
- distilled model.
+ The FAR backbone is the causal Wan2.1 variant introduced by [FAR](https://huggingface.co/papers/2503.19325).
+ Inference is plain Euler in mean-velocity form per chunk with no re-noising. Joint T2V / I2V / V2V is supported by
+ a single distilled model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -129,11 +128,6 @@ class AnyFlowFARPipeline(DiffusionPipeline, WanLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
- # Default chunk partition for the released NVIDIA AnyFlow-FAR checkpoints (81 frames at the diffusers
- # VAE temporal stride of 4 → 21 latent frames split into 1 + 3*6 + 2 = [1, 3, 3, 3, 3, 3, 3, 2]). Override
- # via the ``chunk_partition`` argument to ``__call__`` for other frame counts.
- default_chunk_partition: List[int] = [1, 3, 3, 3, 3, 3, 3, 2]
-
def __init__(
self,
tokenizer: AutoTokenizer,
@@ -423,12 +417,21 @@ def encode_kv_cache(
r_timestep = torch.tensor([0], device=latents.device).expand(latent_model_input.shape[0]).unsqueeze(-1)
r_timestep = r_timestep.repeat((1, latent_model_input.shape[1]))
+ attention_mask = self.transformer.build_attention_mask(
+ chunk_partition=chunk_partition,
+ height=latent_model_input.shape[-2],
+ width=latent_model_input.shape[-1],
+ device=latent_model_input.device,
+ mode="cache",
+ )
+
_, kv_cache = self.transformer(
hidden_states=latent_model_input,
chunk_partition=chunk_partition,
timestep=timestep,
r_timestep=r_timestep,
encoder_hidden_states=prompt_embeds,
+ attention_mask=attention_mask,
attention_kwargs=self.attention_kwargs,
return_dict=False,
# kv-cache related
@@ -538,9 +541,9 @@ def __call__(
use_kv_cache (`bool`, defaults to `True`):
Reuse the FAR attention KV cache across causal chunks. Disable only for debugging.
chunk_partition (`List[int]`, *optional*):
- Per-chunk frame counts. Defaults to `default_chunk_partition` (matched to the released 81-frame
- checkpoints). When you change `num_frames`, supply a `chunk_partition` that sums to `(num_frames - 1)
- // vae_scale_factor_temporal + 1`.
+ Per-chunk frame counts. Defaults to `self.transformer.config.chunk_partition` (matched to the released
+ 81-frame checkpoints). When you change `num_frames`, supply a `chunk_partition` that sums to
+ `(num_frames - 1) // vae_scale_factor_temporal + 1`.
Examples:
@@ -629,7 +632,7 @@ def __call__(
video_latents = self.encode_video(video, height=height, width=width)
if chunk_partition is None:
- chunk_partition = list(self.default_chunk_partition)
+ chunk_partition = list(self.transformer.config.chunk_partition)
if init_latents.shape[1] != sum(chunk_partition):
raise ValueError(
f"chunk_partition={chunk_partition} sums to {sum(chunk_partition)}, but the input latent "
@@ -713,14 +716,14 @@ def __call__(
this_chunk_partition = chunk_partition[: chunk_idx + 1]
self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, timesteps=timesteps)
- timesteps = self.scheduler.timesteps
+ scheduler_timesteps = self.scheduler.timesteps
inner_progress_bar_config = {
**outer_progress_bar_config,
"position": 1,
"leave": False,
"desc": f"Chunk {chunk_idx} Inference Steps",
}
- for i, t in enumerate(tqdm(timesteps, **inner_progress_bar_config)):
+ for i, t in enumerate(tqdm(scheduler_timesteps, **inner_progress_bar_config)):
r = self.scheduler.sigmas[i + 1] * self.scheduler.config.num_train_timesteps
if t == r:
continue
diff --git a/tests/models/transformers/test_models_transformer_anyflow_far.py b/tests/models/transformers/test_models_transformer_anyflow_far.py
index d3631e361c09..6ebbe6596f51 100644
--- a/tests/models/transformers/test_models_transformer_anyflow_far.py
+++ b/tests/models/transformers/test_models_transformer_anyflow_far.py
@@ -21,6 +21,7 @@
from diffusers.models.transformers.transformer_anyflow_far import (
AnyFlowCausalAttnProcessor,
AnyFlowFARTransformerOutput,
+ _build_anyflow_far_causal_block_mask,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -30,6 +31,7 @@
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
+ TorchCompileTesterMixin,
TrainingTesterMixin,
)
@@ -88,6 +90,9 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]:
text_seq_len = 12
text_dim = 16
+ # No `attention_mask` here: the model accepts `attention_mask=None` and builds the mask
+ # internally — exercising that fallback in every non-compile test is the point. The
+ # compile test class overrides this method to inject a pre-built BlockMask.
return {
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width),
@@ -149,11 +154,36 @@ class TestAnyFlowFARTransformer3DAttention(AnyFlowFARTransformer3DTesterConfig,
"""Attention processor tests for AnyFlow FAR Transformer 3D."""
-# Torch-compile mixin intentionally skipped: FAR's `_build_causal_mask` uses
-# `flex_attention.create_block_mask(_compile=False)`, which conflicts with the tracer
-# assumptions made by the standard TorchCompileTesterMixin. The bidi transformer test file
-# covers compile behavior; the FAR causal path is bit-exact-validated end-to-end on H200
-# through the pipeline replay rather than per-module compile.
+class TestAnyFlowFARTransformer3DCompile(AnyFlowFARTransformer3DTesterConfig, TorchCompileTesterMixin):
+ """torch.compile tests for AnyFlow FAR Transformer 3D.
+
+ Pre-builds the BlockMask via the standalone helper and injects it as ``attention_mask`` so the
+ transformer forward never calls ``flex_attention.create_block_mask(_compile=False)`` inside the
+ compiled scope.
+ """
+
+ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]:
+ inputs = super().get_dummy_inputs()
+ init_dict = self.get_init_dict()
+ inputs["attention_mask"] = _build_anyflow_far_causal_block_mask(
+ chunk_partition=inputs["chunk_partition"],
+ height=inputs["hidden_states"].shape[-2],
+ width=inputs["hidden_states"].shape[-1],
+ patch_size=init_dict["patch_size"],
+ compressed_patch_size=init_dict["compressed_patch_size"],
+ full_chunk_limit=init_dict["full_chunk_limit"],
+ mode="train",
+ has_clean_context=False,
+ device=torch_device,
+ )
+ return inputs
+
+ @pytest.mark.skip(reason="torch.export does not accept BlockMask as a pytree input.")
+ def test_compile_works_with_aot(self, tmp_path):
+ # BlockMask is a custom NamedTuple containing tensors plus a Python callable `mask_mod`,
+ # which `torch.export` cannot lift into a pytree. `torch.compile(fullgraph=True)` and
+ # `compile_repeated_blocks` both work; only AOT export is blocked.
+ super().test_compile_works_with_aot(tmp_path)
class AnyFlowCausalAttnProcessorTest(unittest.TestCase):
diff --git a/tests/pipelines/anyflow/test_anyflow_far.py b/tests/pipelines/anyflow/test_anyflow_far.py
index 8086afef6d65..de244d563ec6 100644
--- a/tests/pipelines/anyflow/test_anyflow_far.py
+++ b/tests/pipelines/anyflow/test_anyflow_far.py
@@ -90,6 +90,7 @@ def get_dummy_components(self):
rope_max_seq_len=32,
gate_value=0.25,
deltatime_type="r",
+ chunk_partition=(1, 1, 1),
)
components = {
@@ -106,8 +107,8 @@ def get_dummy_inputs(self, device, seed=0):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
- # num_frames=9 -> 3 latent frames (VAE temporal stride 4); use a matching
- # chunk_partition so the FAR pipeline's pre-flight assertion passes.
+ # num_frames=9 -> 3 latent frames (VAE temporal stride 4); the transformer config above
+ # has chunk_partition=(1, 1, 1) (sum 3) baked in, so __call__ picks it up automatically.
inputs = {
"prompt": "dance monkey",
"negative_prompt": "negative",
@@ -119,7 +120,6 @@ def get_dummy_inputs(self, device, seed=0):
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
- "chunk_partition": [1, 1, 1],
}
return inputs