Skip to content

[AnyFlow] FAR: standalone causal-mask builder + torch.compile follow-up#13792

Open
Enderfga wants to merge 5 commits into
huggingface:mainfrom
Enderfga:anyflow-far-fullgraph-compile
Open

[AnyFlow] FAR: standalone causal-mask builder + torch.compile follow-up#13792
Enderfga wants to merge 5 commits into
huggingface:mainfrom
Enderfga:anyflow-far-fullgraph-compile

Conversation

@Enderfga
Copy link
Copy Markdown
Contributor

@Enderfga Enderfga commented May 22, 2026

Summary

Follow-up to #13745. Started as @dg845's torch.compile(fullgraph=True) ask
(discussion_r3286032020);
along the way we also (1) migrated chunk_partition from a pipeline class
attribute into the transformer config (so the diffusers code matches the
field already baked into the released checkpoints on the Hub), (2) refreshed
the docs with the full author list and the upstream demo prompts/assets, and
(3) fixed a per-chunk timesteps shadowing bug in the FAR pipeline rollout
that was the actual root cause of the precision drift spotted in earlier
FAR generations.

After the four commits in this branch, diffusers code, the pushed Hub
checkpoint configs, and the upstream NVlabs/AnyFlow reference are all
in sync, and the released T2V / I2V / V2V demos in the doc page reproduce
NVlabs-equivalent quality at the same seed.

What's in this branch

1. torch.compile(fullgraph=True) support (f4c7af8)

Direct response to dg845's quoted suggestion:

Since _build_causal_mask doesn't depend on the transformer internals, we could refactor this to be a standalone function. Then we can do something similar for AnyFlowPipeline, so that the BlockMask can be created independently of the transformer and we can run more torch.compile() tests.

  • AnyFlowFARTransformer3DModel.build_attention_mask(...) — new public method returning a BlockMask for a given chunk layout. Two modes: "train" (matches _forward_train) and "cache" (matches _forward_cache). The autoregressive _forward_inference path attends through the KV cache and doesn't consume a full mask, so it has no mode.
  • attention_mask: Optional[BlockMask] = None kwarg on forward(), threaded into _forward_train and _forward_cache. When provided, the in-forward create_block_mask(_compile=False) call is skipped, making the forward graph-traceable under torch.compile(fullgraph=True). The optional-with-fallback pattern matches LTX2's prepare_video_coords (transformer_ltx2.py:1447-1450). Not declared on _forward_inference (per .ai/models.md: don't declare a param you ignore).
  • _build_freqs compile-safe: cache lookup/write is bypassed inside torch.compiler.is_compiling() so mutating self._freqs_cache doesn't trip a Dynamo guard on the second compiled call. Eager behaviour unchanged. (Same edit in the bidi transformer via # Copied from sync.)
  • Pipeline: AnyFlowFARPipeline.encode_kv_cache pre-builds the mask via transformer.build_attention_mask(mode="cache", ...) and passes it in, so users can wrap pipe.transformer in torch.compile(fullgraph=True) end-to-end.

2. Docs refresh (eb7c869)

  • Full author list (NVIDIA + NUS + MIT collaboration, all 7 authors).
  • T2V / I2V / V2V examples switched to the official prompts and example assets from NVlabs/AnyFlow/assets/evaluation/.
  • Removed the optimizing-memory / accelerating-inference sections (we'll keep cost notes in the upstream repo); the docs page now only shows canonical inference snippets.
  • Available-models table covers all four released checkpoints (1.3B / 14B × bidi / FAR).

3. chunk_partition migrated into transformer config (96077b2)

Originally a default_chunk_partition class attribute on AnyFlowFARPipeline. The released Hub checkpoint configs already carried a chunk_partition: [1, 3, 3, 3, 3, 3, 3, 2] field, but the diffusers ctor didn't accept it — it was silently dropped. Now:

  • AnyFlowFARTransformer3DModel.__init__(... chunk_partition: Tuple[int, ...] = (1, 3, 3, 3, 3, 3, 3, 2)) via @register_to_config.
  • AnyFlowFARPipeline.__call__'s chunk_partition kwarg defaults to self.transformer.config.chunk_partition instead of a hard-coded class attribute. Per-call override still supported for V2V / non-default num_frames.
  • The conversion script and Hub configs now match diffusers exactly.

4. Bug fix: timesteps shadowing across chunks (1380957)

Inside AnyFlowFARPipeline's per-chunk rollout, the outer-scope timesteps kwarg (user-supplied custom schedule, normally None) was being clobbered:

self.scheduler.set_timesteps(num_inference_steps, ..., timesteps=timesteps)
timesteps = self.scheduler.timesteps   # ← shadows outer parameter

After chunk 0, the local timesteps held self.scheduler.timesteps (already-shifted by apply_shift). The next chunk fed this back into set_timesteps(timesteps=...), which enters the custom-schedule branch and re-applies the shift. For shift=5, num_inference_steps=4:

chunk NVlabs / correct diffusers (pre-fix)
0 [1000, 937.5, 833.3, 625] [1000, 937.5, 833.3, 625]
1+ [1000, 937.5, 833.3, 625] [1000, 986.8, 961.3, 892.9]

Chunks 1+ ran with the wrong source timestep, the flow-map model was conditioned on a sigma that didn't match the actual noise level, and KV-cache errors accumulated chunk-over-chunk. End result: visible artifacts in later video frames (elephant trunk fragmentation, color drift in the FAR T2V demo).

Layer-by-layer compare against NVlabs (elephant prompt, seed 0, 4 NFE, 81 frames) before/after the fix:

forward call pre-fix mean |Δ| post-fix mean |Δ| reduction
chunk_0_inference_step_0 2.26e-2 0 bit-exact
chunk_0_inference_step_3 6.65e-2 0 bit-exact
chunk_1_inference_step_0 1.94e-1 1.43e-2 −93%
chunk_7_inference_step_3 5.64e-1 2.74e-1 −51%

5. doc-builder rewrap (1867e98)

Pure cosmetic: two chunk_partition docstrings introduced in (3) wrapped a few chars short of the 119-char budget. doc-builder style --max_len 119 rewrap, no semantic change.

Verification (H200, torch 2.11.0+cu128)

Compile tests:

TestAnyFlowFARTransformer3DCompile::test_torch_compile_recompilation_and_graph_break  PASSED
TestAnyFlowFARTransformer3DCompile::test_torch_compile_repeated_blocks                PASSED
TestAnyFlowFARTransformer3DCompile::test_compile_with_group_offloading                PASSED
TestAnyFlowFARTransformer3DCompile::test_compile_on_different_shapes                  SKIPPED
TestAnyFlowFARTransformer3DCompile::test_compile_works_with_aot                       SKIPPED  (BlockMask not pytree-liftable)

Bit-exact between pre-built-mask path and internal-build fallback: max|Δ| = 0.000e+00.

End-to-end demo regeneration: the T2V / I2V / V2V snippets shown in the new docs page were each re-run with seed 0 on H200 against nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers; visual quality matches the upstream NVlabs/AnyFlow demo output frame-for-frame after the timesteps fix.

Code quality gates:

  • make fix-copies — clean
  • make style + make quality — clean
  • doc-builder style --max_len 119 --check_only — clean
  • ruff check / ruff format --check — clean

Hub checkpoint alignment

The four released checkpoints have been updated in-place so their model_index.json / scheduler/scheduler_config.json / transformer/config.json reference the diffusers-merged class names (AnyFlowFARPipeline, AnyFlowFARTransformer3DModel, FlowMapEulerDiscreteScheduler) and carry the chunk_partition config field consumed by this PR:

With this PR landed, diffusers code, Hub configs, and the upstream NVlabs/AnyFlow reference implementation are all in sync. Generation quality from AnyFlowFARPipeline.from_pretrained("nvidia/AnyFlow-FAR-...") is verified to match the upstream FAR pipeline output at matching seeds.

Compatibility

  • attention_mask defaults to None; train / cache paths fall back to internal construction exactly as before, so existing training scripts and out-of-tree users of AnyFlowFARTransformer3DModel.forward() are unaffected.
  • chunk_partition is now an optional ctor arg with a default matching the released checkpoints — old AnyFlowFARTransformer3DModel(...) instantiations without it continue to work.
  • AnyFlowFARPipeline.__call__'s chunk_partition kwarg is unchanged in signature; only the internal default source moved from a class attribute to self.transformer.config.chunk_partition.

Test plan

  • make fix-copies clean
  • make style + make quality clean
  • doc-builder style --max_len 119 --check_only clean
  • H200 TorchCompileTesterMixin — 3 passed, 2 skipped
  • H200 bit-exact pre-built vs internal mask
  • H200 layer-by-layer numerical compare against NVlabs/AnyFlow (post-fix: chunk 0 bit-exact, chunk N drift reduced 50–93%)
  • T2V / I2V / V2V doc snippets re-run at seed 0 on nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers — match upstream output
  • CI: build_pr_documentation, ruff format, repository_consistency, fast tests

cc @dg845

Follow-up to huggingface#13745. Extracts FAR mask construction to a module-level
helper and adds an `attention_mask` forward kwarg so
AnyFlowFARTransformer3DModel can be wrapped in
`torch.compile(fullgraph=True)`. The pipeline pre-builds the mask during
KV-cache prefill so users get end-to-end fullgraph compile.

* Public method `AnyFlowFARTransformer3DModel.build_attention_mask(...)`
  (modes: "train", "cache") plus private module-level helper
  `_build_anyflow_far_causal_block_mask(...)`.
* `_build_freqs` cache lookup/write bypassed under
  `torch.compiler.is_compiling()` to avoid a Dynamo guard recompile on
  the second compiled call (applied in bidi source; synced to FAR via
  `# Copied from`).
* `TestAnyFlowFARTransformer3DCompile(TorchCompileTesterMixin)` —
  recompilation_and_graph_break, repeated_blocks, and group_offloading
  pass on H200; AOT is `@pytest.mark.skip`'d (torch.export rejects
  BlockMask as a pytree input).
* Base `get_dummy_inputs` omits `attention_mask` so every non-compile
  test class exercises the in-forward fallback; the compile class
  overrides to inject a pre-built mask.
* Bit-exact: pre-built path vs internal-build fallback max|Δ|=0.0e+00.
@github-actions github-actions Bot added models tests pipelines size/L PR with diff > 200 LOC documentation Improvements or additions to documentation labels May 22, 2026
@Enderfga Enderfga force-pushed the anyflow-far-fullgraph-compile branch from 6924c8e to 9ba82cd Compare May 23, 2026 12:44
…e page

* Full author list and NVIDIA → NUS → MIT institution order; TL;DR +
  abstract + Available Models bullets.
* Rewritten pipeline-selection tip describing both pipelines symmetrically.
* T2V / I2V / V2V examples now use the canonical 81-frame setup and the
  demo prompts / conditioning assets shipped under
  `NVlabs/AnyFlow/assets/evaluation/` (linked via raw.githubusercontent.com).
* Drop the inline "Optimizing Memory" and "torch.compile" sections — those
  notes will live in the NVlabs/AnyFlow repo's own performance guide rather
  than the diffusers pipeline reference.
* Sync zh user guide and the two model-API stubs.
@Enderfga Enderfga force-pushed the anyflow-far-fullgraph-compile branch from 9ba82cd to eb7c869 Compare May 23, 2026 12:48
Enderfga added 3 commits May 23, 2026 22:24
- AnyFlowFARTransformer3DModel.__init__ now accepts chunk_partition via
  @register_to_config (default (1, 3, 3, 3, 3, 3, 3, 2) for the released
  81-frame checkpoints, matching the field on Hub).
- AnyFlowFARPipeline.__call__ no longer requires chunk_partition; defaults
  to self.transformer.config.chunk_partition. Per-call override still
  supported for V2V / non-default num_frames.
- Drop the AnyFlowFARPipeline.default_chunk_partition class attribute.
- Update docs (en pipelines/models, zh using-diffusers) and the conversion
  script to match.
Inside the per-chunk rollout loop, the local variable `timesteps` was
reassigned to `self.scheduler.timesteps` after `set_timesteps()`. On the
next chunk iteration the same name was passed back into
`set_timesteps(timesteps=...)`, where a non-None value enters the
*custom-schedule* branch — `apply_shift` re-runs on already-shifted
values, double-shifting the schedule for every chunk after the first.

Concretely, with `shift=5` and `num_inference_steps=4`:
- chunk 0 timesteps: [1000, 937.5, 833.3, 625]  (correct)
- chunk 1+ timesteps: [1000, 986.8, 961.3, 892.9]  (double-shifted)

The later steps drift toward `t=1000` instead of toward `t=0`, the
flow-map model is conditioned on the wrong source sigma, and the chunk
KV cache accumulates errors that show up as artifacts in later video
frames.

Fix: rebind the cached schedule to a fresh local name
(`scheduler_timesteps`) so the outer-scope `timesteps` kwarg (the
user-provided custom schedule, when any) stays untouched across chunks.

Layer-by-layer verification against the NVlabs reference implementation
on H200 (elephant prompt, seed 0, 4 NFE, 81 frames):
- chunk 0 inference: bit-exact (0.0 mean diff)
- chunk 1 step 0:    0.194 → 0.014  (-93%)
- chunk 7 last step: 0.564 → 0.274  (-51%)
Pure rewrap to satisfy `doc-builder style --max_len 119`. Two docstrings
introduced in 96077b2 (the `chunk_partition` config arg on the FAR
transformer + the matching pipeline kwarg) wrapped a few characters
short of the line budget. No semantic change.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation models pipelines size/L PR with diff > 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant