[#14225][perf] AutoDeploy MTP + ADP enablement and MoE all-to-all optimization #15063
MrGeva wants to merge 1 commit
/bot run
📝 Walkthrough
This PR removes a DP-aware token-info slot from […]
Changes
Batch-info slot removal and MoE dispatch refactoring
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
405-434: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Sync the remaining `batch_info` docs with the new 14-slot contract.
This docstring is updated, but `SequenceInfo` still says `batch_info_host` includes "DP-aware token info", and `nest_sequences()` still documents `batch_info` as a 3-element shape. That leaves this file with multiple incompatible specs for the same tensor.
Based on learnings, ensure that `batch_info`'s format matches its documented spec.
🧹 Nitpick comments (1)
tests/integration/test_lists/test-db/l0_dgx_b200.yml (1)
397-397: Coverage sufficiency looks good for this PR scope.
Adding `perf/test_perf_sanity.py::test_e2e[aggr_upload-super_mtp_ad_nvfp4_blackwell-super_mtp_ad_nvfp4_ws4_1k1k]` in `tests/integration/test_lists/test-db/l0_dgx_b200.yml` gives targeted post-merge coverage for the new NVFP4 Super MTP AutoDeploy profile.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/auto_deploy/model_registry/configs/super_v3_mtp_low_latency.yaml`:
- Around line 68-69: Update the explanatory comment that currently states
"Triton SSM + causal conv are required for MTP as currently they are the only
backends that support speculative mamba state caching." to reflect that
FlashInfer SSM is also supported: mention that Triton SSM and FlashInfer SSM
(`flashinfer_ssm`) — together with causal conv — are supported backends for MTP
speculative mamba state caching where applicable, and ensure consistency with
the paired config `super_v3_mtp.yaml`.
- Around line 4-10: Update the header diff notes to match the actual config keys
and values: change the "16 vs 128" phrasing to reflect that super_v3_mtp.yaml in
this PR uses max_batch_size: 64 (so say "16 vs 64" or just "max_batch_size
lowered to 16 from 64"), and replace references to cuda_graph_batch_sizes with
the correct nested key cuda_graph_config.batch_sizes; ensure the explanatory
bullets reference the actual keys max_batch_size and
cuda_graph_config.batch_sizes and the actual removed batch sizes (drop 24, 32,
64, 128) so the header is accurate and not misleading.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py`:
- Around line 62-66: The branch that returns a capture-time token budget uses
cuda_graph_state.in_warm_up() and torch.cuda.is_current_stream_capturing() but
does not respect the global bypass flag, so steps forced to eager via
BypassCapturedGraphs() still get local_num_tokens; update the conditional that
checks capture/warm-up (the if using torch.cuda.is_current_stream_capturing() or
cuda_graph_state.in_warm_up()) to also require that BypassCapturedGraphs() is
false (i.e., only apply the capture-time budget when not bypassed), keeping the
existing budget checks (budget > 0 and budget * ep_size * 4 <= max_num_tokens)
and otherwise return max_num_tokens as before.
---
Outside diff comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py`:
- Around line 405-434: Update the remaining docs and any related constants to
match the 14-slot batch_info contract: change SequenceInfo's description of
batch_info_host to list the 14 elements (or refer to the shared doc) instead of
"DP-aware token info", and update nest_sequences() docstring to describe
batch_info as a 14-element vector (not a 3-element shape); also verify any
references to _NUM_ELEMENTS, batch_info, or batch_info_host in SequenceInfo,
nest_sequences(), and adjacent helpers reflect the 14-slot semantics (including
slots 0–13 names like num_prefill, max_context_length, max_draft_len,
use_replay) so all docs and constants are consistent.
---
Nitpick comments:
In `@tests/integration/test_lists/test-db/l0_dgx_b200.yml`:
- Line 397: Add the new targeted test entry for the NVFP4 Super MTP AutoDeploy
profile to the l0_dgx_b200 test list by inserting the exact test identifier
perf/test_perf_sanity.py::test_e2e[aggr_upload-super_mtp_ad_nvfp4_blackwell-super_mtp_ad_nvfp4_ws4_1k1k]
(use the identifier from the review content) into
tests/integration/test_lists/test-db/l0_dgx_b200.yml under the appropriate test
list block, preserving YAML list syntax and indentation, avoid duplicates, and
ensure the test string is quoted or escaped if needed to prevent YAML parsing
issues; verify the entry appears exactly as the reviewer requested and run a
quick YAML lint to confirm validity.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 0fb5e177-877d-4d70-8e49-f7d13437e036
📒 Files selected for processing (17)
- examples/auto_deploy/model_registry/configs/super_v3_mtp.yaml
- examples/auto_deploy/model_registry/configs/super_v3_mtp_low_latency.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
- tensorrt_llm/_torch/auto_deploy/llm_args.py
- tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
- tensorrt_llm/_torch/auto_deploy/models/eagle.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py
- tensorrt_llm/_torch/auto_deploy/utils/cuda_graph.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tests/integration/test_lists/test-db/l0_dgx_b200.yml
- tests/scripts/perf-sanity/aggregated/super_mtp_ad_nvfp4_blackwell.yaml
- tests/unittest/auto_deploy/multigpu/compile/test_bypass_captured_graphs.py
- tests/unittest/auto_deploy/singlegpu/models/test_eagle.py
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py
    # Diff from ``super_v3_mtp.yaml``:
    # - ``max_batch_size`` lowered (16 vs 128). Less mamba/KV cache pressure
    #   and only small captured graphs are exercised.
    # - ``cuda_graph_batch_sizes`` trimmed to the latency-relevant range
    #   (drop 24, 32, 64, 128). Every captured graph costs warmup time and
    #   GPU memory; if you'll never serve c>16, capturing larger sizes is
    #   wasted.
Header diff notes are stale and currently misleading.
Line [5] says `max_batch_size` is "16 vs 128", but `super_v3_mtp.yaml` in this PR uses `max_batch_size: 64`.
Line [7] references `cuda_graph_batch_sizes`, while this file uses `cuda_graph_config.batch_sizes`. Please sync the header notes with the actual config keys/values to avoid operator confusion.
PR_Github #52586 [ run ] triggered by Bot. Commit:
02fbb3d to 9563db8
PR_Github #52586 [ run ] completed with state
9563db8 to c0acbd7
/bot run
PR_Github #52591 [ run ] triggered by Bot. Commit:
…mization

Rebases the SuperV3-MTP attention-DP optimization onto current upstream/main (which now carries gk's MoE all-to-all stateful cache NVIDIA#13718/NVIDIA#13723 and gagam's SSM-replay PR). All net changes from the optimization branch are preserved.

MoE all-to-all per-rank token budget (runtime_max_tokens_per_rank):
- Replace the per-iteration cross-rank-max read (an int(batch_info_host[14].item()) on a pinned-host tensor, fed by a per-forward tp_allgather) with a sync-free shape-based budget via _hybrid_runtime_max_tokens_per_rank: under cuda-graph capture/warm-up the budget is x.shape[0] (uniform across DP ranks because maybe_pad_for_cuda_graph pads every rank to a common cg_batch_size and MTP tokens-per-seq is uniform), gated so the tight budget is only taken while the MoE-GEMM row count stays in the fast small-M tactic region; in eager (prefill or bypass) it falls back to the static max_num_tokens every rank computes identically. No per-layer host read.
- Drop the now-dead batch_info_host plumbing for the DP-max slot: the slot-14 (max_dp_num_tokens) storage and update/get accessors in BatchInfo (_NUM_ELEMENTS 15->14), the pre-forward tp_allgather + update in the AD shim, and the batch_info_host injection into the MoE op in both the dict-based (sharding.py) and IR-based (sharding_ir.py) sharding paths, plus the op signatures in trtllm_moe.py / torch_moe.py.

Scheduling / sharding:
- Enable the attention-DP request balancer (PyExecutor._balance_adp_requests) in the AD executor so prefill is co-scheduled across ranks, avoiding a single prefill straggler stalling the others at the MoE all-to-all collective.
- Keep the draft-EP revert under attention-DP in sharding (replicate the draft model's MoE rather than EP-sharding it) to avoid the shared-workspace corruption that hangs the all-to-all at concurrency.

Mixed-mode cuda-graph bypass uses upstream's process-wide BypassCapturedGraphs() context manager (cuda_graph_state.in_bypass()) instead of the per-instance flag, keeping all ranks consistent when one is in prefill.

Tests:
- Add NVFP4 SuperV3-MTP attn-DP perf-sanity config + post-merge enrollment.
- test_mtp / test_accuracy NVFP4 coverage resolved to upstream's parametrization.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
c0acbd7 to 25a8ad1
PR_Github #52591 [ run ] completed with state
Description
Optimizes the SuperV3 MTP + attention-DP path, building on the recently merged MoE all-to-all stateful cache (#13718 / #13723) and the SSM-replay PR.

MoE all-to-all per-rank token budget (`runtime_max_tokens_per_rank`)
- Replace the per-iteration cross-rank-max read (an `int(batch_info_host[14].item())` on a pinned-host tensor, fed by a per-forward `tp_allgather`) with a sync-free, shape-based budget via `_hybrid_runtime_max_tokens_per_rank`:
  - Under cuda-graph capture/warm-up the budget is `x.shape[0]`. It is uniform across DP ranks because `maybe_pad_for_cuda_graph` pads every rank to a common `cg_batch_size` and MTP tokens-per-seq is uniform. The tight budget is gated so that it is only taken while the MoE-GEMM row count stays in the fast small-M tactic region.
  - In eager (prefill or bypass) it falls back to the static `max_num_tokens`, which every rank computes identically.
- Drop the now-dead `batch_info_host` DP-max plumbing: the slot-14 (`max_dp_num_tokens`) storage + accessors in `BatchInfo` (`_NUM_ELEMENTS` 15→14), the pre-forward `tp_allgather` + update in the AutoDeploy shim, the injection into the MoE op in both the dict-based (`sharding.py`) and IR-based (`sharding_ir.py`) sharding paths, and the corresponding op signatures in `trtllm_moe.py` / `torch_moe.py`.

Scheduling / sharding
- Enable the attention-DP request balancer (`PyExecutor._balance_adp_requests`) in the AD executor so prefill is co-scheduled across ranks, avoiding a single prefill straggler stalling the others at the MoE all-to-all collective.
- Keep the draft-EP revert under attention-DP in sharding (replicate the draft model's MoE rather than EP-sharding it) to avoid the shared-workspace corruption that hangs the all-to-all at concurrency.
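The per-rank token budget selection described under "MoE all-to-all per-rank token budget" can be pictured with a minimal, framework-free sketch. This is not the PR's code: the function name, its parameters, and the `SMALL_M_HEADROOM` constant are hypothetical, the capture/warm-up and bypass states are passed in as plain booleans rather than read from CUDA or `cuda_graph_state`, and the `budget > 0 and budget * ep_size * 4 <= max_num_tokens` gate is copied from the check quoted in the review comment on `trtllm_moe.py`. The `bypassed` guard reflects the reviewer's suggestion and may not match what was merged.

```python
# Hypothetical name for the factor 4 that appears in the reviewed gate.
SMALL_M_HEADROOM = 4


def hybrid_runtime_max_tokens_per_rank(
    local_num_tokens: int,
    max_num_tokens: int,
    ep_size: int,
    capturing_or_warm_up: bool,
    bypassed: bool = False,
) -> int:
    """Choose the all-to-all token budget from shapes only (no host sync).

    local_num_tokens stands in for x.shape[0]; under cuda-graph capture it is
    assumed to be identical on every DP rank because of padding.
    """
    tight_budget_ok = (
        capturing_or_warm_up
        and not bypassed  # reviewer-suggested guard, assumed here
        and local_num_tokens > 0
        and local_num_tokens * ep_size * SMALL_M_HEADROOM <= max_num_tokens
    )
    # Otherwise fall back to the static bound that all ranks compute identically.
    return local_num_tokens if tight_budget_ok else max_num_tokens
```

With `max_num_tokens=8192` and `ep_size=4`, a captured step with 16 tokens would take the tight budget (16), while a captured step with 1024 tokens or any eager step would fall back to 8192. Because every input is either a shape or a static config value, no rank needs to read a pinned-host tensor to agree on the result.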
Mixed-mode cuda-graph bypass uses the process-wide `BypassCapturedGraphs()` context manager (`cuda_graph_state.in_bypass()`), keeping all ranks consistent when one is in prefill.

Perf Impact
see: #14225 (comment)
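For readers unfamiliar with the mixed-mode bypass mentioned in the Description, the difference between a process-wide flag and a per-instance flag can be sketched as follows. Everything here is illustrative: `bypass_captured_graphs`, `in_bypass`, and `run_step` are hypothetical stand-ins, not the upstream `BypassCapturedGraphs()` or `cuda_graph_state` implementations, and how ranks agree on when to enter the bypass is not shown.

```python
import contextlib
from typing import Iterator, Set

# Module-level state: every captured-graph wrapper in this process sees the
# same value, unlike a flag stored on one wrapper instance.
_bypass_depth = 0


@contextlib.contextmanager
def bypass_captured_graphs() -> Iterator[None]:
    """Force eager execution for all steps run inside the with-block."""
    global _bypass_depth
    _bypass_depth += 1
    try:
        yield
    finally:
        _bypass_depth -= 1  # restored even if the step raises


def in_bypass() -> bool:
    return _bypass_depth > 0


def run_step(batch_size: int, captured_sizes: Set[int]) -> str:
    """Replay a captured graph only when allowed and a matching graph exists."""
    if not in_bypass() and batch_size in captured_sizes:
        return "replay"
    return "eager"
```

Under this sketch, `run_step(4, {4})` replays outside the context manager and runs eagerly inside it, for every wrapper in the process at once.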
Test Coverage
- `tests/integration/defs/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[*]` and `::test_accuracy[*-4-attn_dp_on-trtllm]` (bf16/fp8/nvfp4): exercise the MTP + attention-DP MoE all-to-all path.
- `tests/unittest/auto_deploy/multigpu/compile/test_bypass_captured_graphs.py`: guards the mixed-mode captured-graph bypass.
- `perf/test_perf_sanity.py::test_e2e[aggr_upload-super_mtp_ad_nvfp4_blackwell-super_mtp_ad_nvfp4_ws4_1k1k]`.
- […] with no NaN / hang / all-to-all timeout; GSM8K 47/50 = 94.0%.
PR Checklist
Summary by CodeRabbit
New Features
Configuration Updates
Performance Improvements