Add examples on how to profile a pipeline #13356

Open
sayakpaul wants to merge 29 commits into main from profiling-workflow
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Mar 28, 2026

What does this PR do?

TL;DR: Adds a guide on how to profile a pipeline and fix issues like CPU overhead, CPU<->GPU syncs, etc.

Motivation

Since we provide first-class torch.compile support, it's important that our pipelines are set up for optimal success with it. This includes spotting any obvious issues that plague torch.compile performance -- CPU overhead, CPU<->GPU syncs, graph breaks, kernel launch delays, etc.

The best way to spot these bugs is to profile a pipeline, as it gives a granular measurement of where the GPU is spending time and if it is doing so in an expected manner. We can then uncover any unexpected issues and eventually fix them.

Workflow

The README.md added in the PR has all the descriptions, but in summary:

  • take a popular pipeline like Flux/Flux2/QwenImage/Wan/LTX2
  • run the profile with 2 inference steps
  • load the trace on Perfetto
  • spot the potential suspects
  • feed that back to Claude along with the trace
    • ask it to attempt a fix
    • review the fix
    • compare the results

With this workflow, I was able to fix some issues in the Flux2 Klein pipeline and the Wan pipeline. All changes look quite harmless to me.

Plan

Not only is it helpful to profile pipelines to get a ceiling on performance, but the community could also help us improve our pipelines should this workflow prove to be useful.

Note to reviewers

Please review the changes in `src/diffusers/*`. You can skip straight to the "Afterwards" section in the README.md document.

The tutorial is currently available here. Some inline comments.

@sayakpaul sayakpaul requested review from DN6, dg845 and stevhliu March 28, 2026 04:41
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@stevhliu stevhliu left a comment


super educational, i enjoyed reading this a lot!

  • maybe rename "Approach" to something like "How the tooling works" because it describes how it works rather than what the user should do
  • it seems like "Afterwards" may be more effective as a blog post as it tells a story about issues 1 and 2 in the "What to look for" section
  • could be useful to add a link to this doc from our torch.compile docs


To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).

### Quick checklist per pipeline
Member

very helpful!

@sayakpaul sayakpaul requested a review from stevhliu March 31, 2026 04:24
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
- Spot unexpected CPU work between annotated regions
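The first check above can be made concrete by comparing the aggregated times of the annotated regions. A minimal CPU-only sketch, where the two regions are toy stand-ins for the real annotated pipeline methods:

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

x = torch.randn(256, 256)

# Stand-ins for the annotated pipeline regions; the real script wraps the
# actual transformer forward pass and scheduler step the same way.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("transformer_forward"):
        for _ in range(10):
            _ = x @ x.T
    with record_function("scheduler_step"):
        _ = x * 0.5

# scheduler_step should be negligible next to transformer_forward.
times = {
    e.key: e.cpu_time_total  # microseconds
    for e in prof.key_averages()
    if e.key in ("transformer_forward", "scheduler_step")
}
print(times)
```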

**4. Eager vs compile comparison**
Contributor

maybe we can mention cuda graphs here?

Member Author

Should it be included under "Smaller CPU gaps"? If so, it's being mentioned a bit later since I wanted to keep the scope specific.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Contributor

@jbschlosser jbschlosser left a comment


Nice work! Love to see this


# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
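For context on why `set_begin_index` helps: deriving the begin index by searching the timestep schedule ends in an `.item()` call, which reads a device tensor back to the host. A minimal sketch of the two paths, using a hypothetical schedule (runs on CPU too):

```python
import torch

timesteps = torch.linspace(1000, 0, 28)  # hypothetical denoising schedule

def begin_index_via_search(t: torch.Tensor) -> int:
    # nonzero() + .item() copy data back to the host; on GPU this issues a
    # DtoH transfer and a stream sync inside the denoising loop.
    return (timesteps == t).nonzero()[0].item()

# Setting the begin index up front (scheduler.set_begin_index(0)) replaces
# the search with a known integer, so no sync is triggered per step.
begin_index = 0
print(begin_index_via_search(timesteps[0]), begin_index)
```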
Contributor

looks good, still feel like we need to broaden the scope of that fix within diffusers at some point :) I'll be out on sabbatical for the next month but I can help when I get back

Member Author

This would be awesome. Created #13375 to track. Thanks for offering to help.

Comment on lines +316 to +318
* Use of CUDA Graphs can also help mitigate CPU overhead related issues. When
using "reduce-overhead" and "max-autotune" in `torch.compile` triggers the
use of CUDA Graphs.
Contributor

I'm glad you mentioned this here - I wonder if it's worth clarifying in the respective sections of this doc that CUDAGraph usage is the reason why we expect gap removal from using torch.compile
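For readers following along, CUDA Graph capture is opted into purely through the `torch.compile` mode. A sketch, where `denoise_step` is a toy stand-in for a transformer forward:

```python
import torch

def denoise_step(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) + torch.cos(x)

# "default" compiles without CUDA Graphs; "reduce-overhead" and
# "max-autotune" additionally capture CUDA Graphs on GPU, which is what
# collapses the per-kernel launch gaps discussed in this thread.
eager_like = torch.compile(denoise_step, mode="default")
graphed = torch.compile(denoise_step, mode="reduce-overhead")
```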

Member Author

The results and the graph presented in this doc were obtained with the "default" compilation mode (along with regional compilation):

COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"

So, not sure. There are also changes being shipped in this PR that helped mitigate the stalling issues 👀

@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Apr 1, 2026

Education materials to strategically profile pipelines to potentially improve their
runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`,
we often have to get rid of DtoH syncs, CPU overheads, kernel launch delays, and
Collaborator

Suggested change
we often have to get rid of DtoH syncs, CPU overheads, kernel launch delays, and
we often have to get rid of device-to-host (DtoH) syncs, CPU overheads, kernel launch delays, and

I'm not sure if this terminology will be familiar to the target audience, I had to look it up the first time to find out what it meant.


## Context

We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).
Collaborator

Suggested change
We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).
We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial when using [`torch.compile`](https://docs.pytorch.org/docs/stable/generated/torch.compile.html). The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses [`torch.profiler`](https://docs.pytorch.org/docs/stable/profiler.html) with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).

I think adding links to the torch.compile and torch.profiler docs could be useful for following along, especially if readers aren't familiar with them.


## How the Tooling Works

Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace.
Collaborator

Suggested change
Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace.
Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome JSON trace.
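The annotate-then-profile pattern above, condensed into a minimal CPU-safe sketch (`transformer_forward` here is a toy stand-in for a real pipeline method):

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

def transformer_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Annotate the method so it shows up as a named slice in the trace.
    with record_function("transformer_forward"):
        return x @ w

x, w = torch.randn(64, 64), torch.randn(64, 64)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    for _ in range(2):  # mirror the guide's 2-step profiling runs
        transformer_forward(x, w)

# Chrome JSON trace, viewable at ui.perfetto.dev.
prof.export_chrome_trace("trace.json")
```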


## Verification

1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
Collaborator

Suggested change
1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
1. Run: `python examples/profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`

I think running the script from the diffusers root directory might be more common?


Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows.

**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows.
Collaborator

I also found `,` and `.`, which select the previous/next track event, useful for e.g. finding the next `transformer_forward` event, or finding the next GPU kernel on the GPU kernel track.

Open both traces side by side (two Perfetto tabs). Key differences to look for:
- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks).
Collaborator

When I tried comparing an eager and compile trace side-by-side, I found that it was difficult to find corresponding events because if `--mode compile` is used, there appear to be no `transformer_forward` events in the trace, but rather one large `## Call CompiledFxGraph...` event. (I think if regional compilation is used via `--compile_regional`, the `transformer_forward` event does appear again, with the `## CompiledFxGraph...` events under the `transformer_forward` events.)
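One way to take the manual counting out of the comparison above: Chrome traces are plain JSON, so kernel slices can be counted directly. A stdlib-only sketch; the `"kernel"` category string is an assumption that may vary across PyTorch versions:

```python
import json

def count_cuda_kernels(trace: dict) -> int:
    # torch.profiler's Chrome export tags GPU kernel slices with the
    # "kernel" category; eager traces show many narrow slices, compiled
    # traces should show fewer, wider (fused) ones.
    return sum(1 for e in trace.get("traceEvents", []) if e.get("cat") == "kernel")

# Inline stand-in for json.load(open("trace.json")):
sample = {
    "traceEvents": [
        {"name": "elementwise_kernel", "cat": "kernel", "dur": 12},
        {"name": "gemm_kernel", "cat": "kernel", "dur": 480},
        {"name": "cudaLaunchKernel", "cat": "cuda_runtime", "dur": 6},
    ]
}
print(count_cuda_kernels(sample))  # → 2
```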

- There may be implicit syncs forcing serialization
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm

To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
Collaborator

Suggested change
To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it (there should be an arrow pointing from the CPU launch slice to the GPU kernel slice). The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).

Not sure about the exact wording, but I think mentioning this is helpful, especially if there is a big temporal gap between the CPU cudaLaunchKernel and corresponding GPU kernel execution.

May also be worth mentioning that when the GPU kernel is selected, the corresponding cudaLaunchKernel event should be in the "Preceding Flows" section; I believe the "Delay" column then gives the exact launch latency.

### Spotting gaps between launches

Then a reasonable next step is to spot frequent gaps between kernel executions. In the compiled
case, we don't spot any on the surface. But if we zone in, some become apparent.
Collaborator

Suggested change
case, we don't spot any on the surface. But if we zone in, some become apparent.
case, we don't spot any on the surface. But if we zoom in, some become apparent.

nit: typo

</table>

So, we provided the profile trace file (with compilation) to Claude, asked it to find the instances of
"cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes.
Collaborator

Suggested change
"cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes.
`cudaStreamSynchronize` and `cudaDeviceSynchronize`, and to come up with some potential fixes.

nit: formatting

</tr>
</table>

### Spotting gaps between launches
Collaborator

I find this section somewhat unsatisfying because it glosses over the reasoning behind each step in the profiling process. I think it would be more instructive if the chain of reasoning that led us (and Claude) to

  1. Find the `tqdm` and `_unpack_latents_with_ids` fixes
  2. Find out that these fixes weren't sufficient, and why they weren't sufficient
  3. Discover that `cache_context` was the bottleneck

was discussed at greater length in this section.

```

The changes looked reasonable based on our past experience. So, we asked Claude to apply these changes to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled
the updated pipeline. It still didn't eliminate the gaps as expected, so we fed that back to Claude and
Collaborator

I think it would be nice to have a demonstration using the profiling outputs that the _unpack_latents_with_ids fix indeed eliminates a DtoH sync because the current wording

It still didn't eliminate the gaps as expected

makes it unclear if the change is effective.

|------------------------|------------------------------|-----------------------------|
| `_set_context` total | 21.6ms (8 calls) | 0.0ms (8 calls) |
| `cache_context` total | 21.7ms | 0.1ms |
| CPU gaps | 5,523us / 8,007us / 5,508us | 158us / 2,777us / 136us |
Collaborator

I think adding another row with the total wall clock time (or another "overall performance" metric) before and after would be useful here because it's not obvious to me that reducing the CPU gaps here necessarily leads to better performance overall.


![GPU idle](https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%202026-03-27%20at%205.56.39%E2%80%AFPM.png)

The UniPC scheduler (used in Wan) creates small constant tensors via `torch.tensor([0.5], dtype=x.dtype, device=device)` during `step()`. This triggers a "cudaMemcpyAsync + cudaStreamSynchronize" to copy
Collaborator

Suggested change
The UniPC scheduler (used in Wan) creates small constant tensors via `torch.tensor([0.5], dtype=x.dtype, device=device)` during `step()`. This triggers a "cudaMemcpyAsync + cudaStreamSynchronize" to copy
The UniPC scheduler (used in Wan) creates small constant tensors via `torch.tensor([0.5], dtype=x.dtype, device=device)` during `step()`. This triggers a `cudaMemcpyAsync` + `cudaStreamSynchronize` to copy

nit: formatting
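To illustrate the class of fix for the UniPC sync described above (a sketch, not the exact patch in this PR): the per-step host constant can be replaced with a Python scalar, or cached once outside the loop:

```python
import torch

def step_with_sync(x: torch.Tensor) -> torch.Tensor:
    # Building a tensor from a Python list inside step() copies host memory
    # to the device on every call (cudaMemcpyAsync + cudaStreamSynchronize
    # on GPU).
    half = torch.tensor([0.5], dtype=x.dtype, device=x.device)
    return x * half

def step_without_sync(x: torch.Tensor) -> torch.Tensor:
    # Multiplying by a Python scalar needs no host->device tensor copy.
    return x * 0.5

x = torch.ones(4)
print(torch.equal(step_with_sync(x), step_without_sync(x)))  # → True
```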

Comment on lines +316 to +318
* Use of CUDA Graphs can also help mitigate CPU overhead related issues. When
using "reduce-overhead" and "max-autotune" in `torch.compile` triggers the
use of CUDA Graphs.
Collaborator

Suggested change
* Use of CUDA Graphs can also help mitigate CPU overhead related issues. When
using "reduce-overhead" and "max-autotune" in `torch.compile` triggers the
use of CUDA Graphs.
* Use of CUDA Graphs can also help mitigate CPU overhead related issues. CUDA Graphs can be enabled by setting the `torch.compile` mode to `"reduce-overhead"` or `"max-autotune"`.

nit: I think the wording here is awkward

return latents

@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
Collaborator

Since `_unpack_latents_with_ids` was originally marked `# Copied from Flux2Pipeline`, should the changes here also be propagated to that pipeline?

Comment on lines +906 to +907
rks.append(torch.ones((), device=device))
rks = torch.stack(rks)
Collaborator

Should this change also be explained in the docs (examples/profiling/README.md)?

Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left some questions/suggestions :).
