CP Tests batching using subprocess worker pool #2993
Conversation
The existing test path spawns one torchrun per parametrized case, paying NCCL init + CUDA context + Python startup on every call. With ~hundreds of cases the launch overhead dominates wall time and was a primary driver of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived torchrun per world_size. NCCL is initialized once at session start and reused across cases. Pytest sends one JSON request per case over rank-0 stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers (ok, error) from every rank, and writes one JSON response on rank-0 stdout.

run_attention_with_cp.py is left almost untouched; a new NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and dist.destroy_process_group() calls so the function reuses the pool's main PG instead of creating its own. The per-case cp_comm_group (and a2a+p2p sub-groups) are explicitly destroyed at function exit to prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior subprocess-per-case design got for free: a select-based per-call timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on worker death or timeout. A test-level exception is reported as an AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires world_size=4 and the others use world_size=2; you can't resize an active PG. Pools are spawned lazily so a 2-GPU-only run never pays the 4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
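A minimal sketch of the worker-side loop this describes, assuming the request carries a "kwargs" dict and a "shutdown" flag — the dispatch_loop helper, field names, and response shape are illustrative, not the actual contents of run_attention_with_cp_pool.py:

```python
import json
import sys
import traceback

import torch.distributed as dist

from run_attention_with_cp import run_dpa_with_cp  # existing single-case entry point


def dispatch_loop(rank: int, world_size: int) -> None:
    while True:
        request = [None]
        if rank == 0:
            line = sys.stdin.readline()
            request[0] = json.loads(line) if line else {"shutdown": True}
        dist.broadcast_object_list(request, src=0)  # every rank sees the same case
        if request[0].get("shutdown"):
            break
        ok, err = True, None
        try:
            run_dpa_with_cp(**request[0]["kwargs"])
        except Exception:  # report the failure, keep the pool alive for the next case
            ok, err = False, traceback.format_exc()
        results = [None] * world_size if rank == 0 else None
        dist.gather_object((ok, err), results, dst=0)
        if rank == 0:
            failures = [e for okay, e in results if not okay]
            payload = {"ok": not failures, "error": "\n".join(failures)}
            sys.stdout.write(json.dumps(payload) + "\n")
            sys.stdout.flush()
```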
Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.
1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
(recipe handles, autocast counters) lives in module-level globals.
Reusing one Python process across cases otherwise carries that state
forward. The prior batching PR landed an explicit fix for the same
issue ("Fix FP8 cascade failures") after observing real test
failures from this.
2. dist.barrier() after each case. If one rank's case errored before
its last collective, the others can be stuck waiting on a comm that
will never complete. The barrier surfaces that immediately as
a timeout on the offending case rather than letting the corruption leak into
the next case's collectives.
Also pops the transient NVTE_* env vars that run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already re-sets them unconditionally, so
this is purely defensive — cheap insurance against future variants that
might not.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.
Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
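A minimal sketch of the deepcopy guard this describes; model_configs_fused_attn is the shared module-level dict named above, while the accessor function itself is illustrative:

```python
import copy


def get_case_config(model_configs_fused_attn: dict, model: str):
    # Work on a private copy so the THD branch's in-place rewrite of
    # attn_mask_type (causal -> padding_causal, no_mask -> padding) cannot
    # leak into the shared module-level dict seen by the next pooled case.
    return copy.deepcopy(model_configs_fused_attn[model])
```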
Mirrors the two pre-emptive skips on the PR-batching branch:

* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside run_dpa_with_cp instead of being marked SKIPPED.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The pool worker reused RNG state across cases, which produced small numerical drift on some non-FP8 fused-attention configs (cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the single-shot worker.

Matches the per-case startup of the single-shot worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at the start of every case, alongside the existing FP8 / env / cache resets. Moved the reset call from the post-case finally block to the start of _run_one so the first case is also seeded consistently with subsequent cases. Otherwise the first case would inherit the process-default RNG and only the second-and-later cases would be deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
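A sketch of what the per-case reset described above might look like; the FP8GlobalStateManager import path is Transformer Engine's, but the exact body of the pool worker's reset helper is an assumption:

```python
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager


def _reset_between_cases(seed: int = 1234) -> None:
    # Reseed so input tensors match the fresh-process (single-shot) worker.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Drop module-level FP8 state (recipe handles, autocast counters) that
    # would otherwise carry over from the previous case.
    FP8GlobalStateManager.reset()
    # Return cached CUDA blocks so a large previous case doesn't skew the next one.
    torch.cuda.empty_cache()
```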
Greptile Summary: This PR replaces the per-test torchrun launch with a long-lived per-world_size worker pool.
Confidence Score: 5/5 — Safe to merge; the test infrastructure is well-structured and the stream race fix in context_parallel.py is a genuine correctness improvement for long-lived CP processes. The pool design is sound: sentinel-based stdout protocol, stderr draining, kill-and-respawn on timeout, and pool-shared NCCL groups eliminate per-case spawn cost without changing test semantics. The stream race fix is correct — the new wait_stream sits outside the with torch.cuda.stream block (which ends after the copy operations), so it correctly synchronises the default stream against cp_stream before the torch.maximum read. The two hardening observations in test_attention_with_cp.py do not affect current correctness. No files require special attention for merging; the two style/hardening observations in test_attention_with_cp.py are non-blocking.
Sequence Diagram

sequenceDiagram
participant P as pytest (session)
participant PW as PoolWorker (pytest-side)
participant R0 as torchrun rank-0
participant RN as torchrun rank-1..N
P->>PW: "cp_pool(world_size=2)"
PW->>R0: spawn torchrun --standalone (lazy, first use)
R0->>RN: dist.init_process_group (NCCL, once)
R0->>RN: _create_cp_comm_groups (once)
Note over R0,RN: Pool enters dispatch loop
loop per test case
P->>PW: "_submit(pool, **kwargs)"
PW->>R0: JSON line to stdin
R0->>RN: broadcast_object_list(request)
R0->>RN: _reset_between_cases() + run_dpa_with_cp()
R0->>RN: gather_object((ok, traceback))
R0->>PW: stdout sentinel-prefixed JSON response
PW->>P: return or raise AssertionError
end
P->>PW: fixture teardown
PW->>R0: shutdown JSON line
R0->>RN: destroy CP groups + dist.destroy_process_group()
Reviews (7) — last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```diff
  # destroy distribution group
- dist.destroy_process_group()
+ if _pool_managed_pg:
+     # Pool owns the main PG; only clean up groups created for this case.
+     dist.destroy_process_group(cp_comm_group)
+     if cp_comm_type == "a2a+p2p":
+         for g in cp_comm_sub_groups:
+             dist.destroy_process_group(g)
+ else:
+     dist.destroy_process_group()
```
Process group leak on exception in pool mode
When run_dpa_with_cp raises any exception after cp_comm_group is created (line 250) — e.g. a tensor-mismatch assertion, an OOM, or a NCCL error — the exception propagates out of the function before reaching this cleanup block. In pool mode dist.destroy_process_group() is never called for that case's cp_comm_group (or cp_comm_sub_groups). Every failed case therefore leaks one or more NCCL communicators. NCCL tracks live communicators internally and can exhaust resources over a long test session with repeated failures, causing subsequent cases to fail for unrelated reasons.
The standard fix is to wrap the body of run_dpa_with_cp in a try/finally so that the selected cleanup always runs. Note that cp_comm_group is created unconditionally before any test logic, so it is always defined by the time the finally block runs.
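A minimal sketch of that shape, with the actual test body elided behind a run_case callable; names other than cp_comm_group / cp_comm_sub_groups are illustrative:

```python
import torch.distributed as dist


def run_case_with_cleanup(cp_comm_ranks, cp_comm_type, rank, world_size,
                          _pool_managed_pg, run_case):
    cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
    cp_comm_sub_groups = []  # always defined so the finally block is safe
    try:
        if cp_comm_type == "a2a+p2p":
            sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
            sub_ranks += [range(i, world_size, 2) for i in range(2)]
            for ranks in sub_ranks:
                group = dist.new_group(list(ranks), backend="nccl")
                if rank in ranks:
                    cp_comm_sub_groups.append(group)
        run_case(cp_comm_group, cp_comm_sub_groups)  # the actual test work
    finally:
        if _pool_managed_pg:
            # Pool owns the main PG; only destroy the groups created for this case.
            for g in [cp_comm_group, *cp_comm_sub_groups]:
                dist.destroy_process_group(g)
        else:
            dist.destroy_process_group()
```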
```python
def _send_response(rank: int, payload: dict) -> None:
    if rank == 0:
        sys.stdout.write(json.dumps(payload) + "\n")
        sys.stdout.flush()
```
stdout pollution can silently corrupt the JSON protocol
torchrun (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a print call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's readline() in PoolWorker.submit would then receive a non-JSON line and raise a json.JSONDecodeError, killing the pool and failing the test with a misleading error message.
Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.
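A sketch of the sentinel approach; the "[CP_POOL_RESP] " prefix matches what the PR later adopts, while the reader loop here is illustrative and assumes text-mode pipes:

```python
import json
import sys

SENTINEL = "[CP_POOL_RESP] "


def send_response(rank: int, payload: dict) -> None:
    if rank == 0:
        sys.stdout.write(SENTINEL + json.dumps(payload) + "\n")
        sys.stdout.flush()


def read_response(stdout) -> dict:
    # Skip interleaved torchrun / library chatter until a protocol line arrives.
    while True:
        line = stdout.readline()
        if not line:
            raise AssertionError("pool worker died mid-request")
        if line.startswith(SENTINEL):
            return json.loads(line[len(SENTINEL):])
        sys.stderr.write(line)  # echo non-protocol noise for visibility
```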
```python
self._kill()
raise AssertionError("pool worker died before request could be sent")

ready, _, _ = select.select([self.proc.stdout], [], [], timeout)
```
select.select on a subprocess pipe works on Linux/macOS but raises OSError on Windows because Windows does not support select on non-socket file descriptors. This is test infrastructure for GPU hardware so Windows is not a target, but the call should at least be documented as Linux-only.
```diff
- ready, _, _ = select.select([self.proc.stdout], [], [], timeout)
+ ready, _, _ = select.select([self.proc.stdout], [], [], timeout)  # Linux/macOS only
```
Three changes that bring the pool's failure semantics on par with the per-batch torchrun approach in PR NVIDIA#2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to crash-path AssertionErrors. Equivalent in spirit to PR NVIDIA#2965's run_distributed() — CI JUnit XML now shows the actual cause (NCCL error, Python traceback, OOM) inline with the failing test, instead of just "pool worker died mid-request" / "timed out". A daemon drainer thread reads stderr line-by-line into a deque(maxlen=200) and also echoes to sys.stderr so pytest's per-test capture still gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s). 90 s gives ~6x headroom over the worst observed case while still detecting a genuine hang within ~1.5 min instead of ~10 min. Env var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr on rank 0 only. Grep-friendly; lets future tuning recalibrate the timeout against the observed distribution. Off by default so normal runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True, with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
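A sketch of item 1's drainer thread, assuming a text-mode stderr pipe; names are illustrative:

```python
import collections
import sys
import threading


def start_stderr_drainer(proc):
    """Drain the pool worker's stderr into a bounded ring buffer and echo it live."""
    tail = collections.deque(maxlen=200)  # ~40 KB worst case

    def _drain():
        for line in proc.stderr:       # blocks until the pipe closes at worker exit
            tail.append(line)
            sys.stderr.write(line)     # keep pytest's per-test capture intact

    threading.Thread(target=_drain, daemon=True).start()
    # Crash paths can attach "".join(tail)[-4096:] to the raised AssertionError.
    return tail
```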
Three fixes responding to NVIDIA#2993 review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py). run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups) near the top, but the destroy_process_group() calls ran only on the success path at the end of the function. Any exception in between (tensor assertion, OOM, NCCL error) skipped the cleanup, leaking communicators in pool mode. Long sessions with repeated failures could exhaust NCCL internal tracking. Wrap the test work in try/finally so the destroy logic always runs. Initialise cp_comm_sub_groups = [] unconditionally so the finally block is safe even when cp_comm_type != "a2a+p2p" (or when an assert fires before the populate loop). Each destroy is itself try/except so a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter. torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0 print, NCCL debug line, or torchrun status output interleaves with the JSON response and breaks json.loads, killing the pool with a misleading "json decode error". Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py and have PoolWorker.submit() scan stdout for sentinel-prefixed lines, echoing non-protocol lines to stderr for visibility. Bounded scan (MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only. Added a short comment noting Windows portability. CP attention tests run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True (was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Thanks for the review @greptile-apps — all three findings addressed in the latest push.

P1 — NCCL communicator leak on exception (run_attention_with_cp.py): wrapped the per-case body in try/finally so the CP groups created for a case are always destroyed, even when the case raises.

P2 — stdout protocol can be corrupted by interleaved chatter: the worker now prefixes every response with [CP_POOL_RESP] and PoolWorker.submit() scans for sentinel-prefixed lines, echoing anything else to stderr.

P2 — select.select on a pipe fd: added a short comment noting the Windows portability gap. Tests target Linux GPU hosts only; this is a doc issue, not a real bug, but flagged for future readers.

PR-body note: the "default 600 s" inconsistency you flagged on the env-var table was already reconciled in the previous push — the documented default is now 90 s.

Validation: re-ran on H100 with test_essential=True — 38 passed / 0 failed. The 5 state-cascade failures on the full matrix (test_essential=False) are a separate issue, later root-caused to the all_gather stream race fixed below.
All three fixes look correct. Let me note one small remaining issue in the deadline logic before signing off. In the sentinel-scan loop in submit():

```python
deadline = None
while scanned < self._MAX_NOISE_LINES:
    remaining = timeout if deadline is None else max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if deadline is None:
        deadline = time.monotonic() + timeout  # ← set after the first select completes
```

If the first line (noise or response) arrives after, say, 50 s, the deadline is then set to that point plus the full timeout, so a single request can block for nearly double the configured limit. The fix is to set the deadline once, before the loop:

```python
deadline = time.monotonic() + timeout
while scanned < self._MAX_NOISE_LINES:
    remaining = max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if not ready:
        ...
```

In practice this won't matter much — torchrun status chatter arrives nearly instantaneously before the response — but the current code doesn't actually enforce the stated timeout bound.

Everything else looks good.
The deadline fix above is the only remaining item; everything else is addressed cleanly.
```python
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
# Always defined so the finally cleanup below is safe even when cp_comm_type != "a2a+p2p".
cp_comm_sub_groups = []
if cp_comm_type == "a2a+p2p":
    assert world_size % 2 == 0, (
        f"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
        " = 2."
    )
    cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
    cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
    cp_comm_sub_groups = []
    for sub_ranks in cp_comm_sub_ranks:
        sub_group = dist.new_group(sub_ranks, backend="nccl")
        if rank in sub_ranks:
            cp_comm_sub_groups.append(sub_group)
```
Group creation sits outside the try/finally guard
cp_comm_group (line 250) and the sub-group creation loop (lines 260–263) all run before the try block that starts at line 268. If the sub-group loop raises — e.g., NCCL rejects a new_group call on a partially recovered communicator after a previous case failed — the finally block is never entered, so cp_comm_group and any sub-groups already appended to cp_comm_sub_groups are never destroyed. In a long-running pool this accumulates a communicator leak on every such failure, eventually hitting NCCL's internal communicator limit and breaking unrelated subsequent cases. Moving all three dist.new_group sites inside the try block (or at minimum wrapping the sub-group loop and the unconditional cp_comm_group creation in a nested try/finally) would close the gap.
cyanguwa left a comment:
I think the structure is a lot simpler/cleaner than #2965 - thanks!
For the 5 remaining failing tests, I wonder if there's something else we need to reset for FP8 (please check with @ksivaman).
Also, please compare the number of tests before/after this PR and make sure we're still running the same number of tests! If the new, reduced runtime allows now, we can turn on test_essential=False, but I'll leave that to you.
If no major issues, I approve! Thanks!
```python
if _pool_managed_pg:
    # Pool owns the main PG; only clean up groups created for this case.
    try:
        dist.destroy_process_group(cp_comm_group)
```
Do we want to destroy (and create earlier in the file) cp_comm_group for every config? I feel if the pool is the same, the world_size would be the same, and so is cp_comm_group?
In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is written inside `with torch.cuda.stream(flash_attn_streams[i])`. For i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])`, which runs on the default stream. Without an explicit wait_stream, this is a read-after-write race across streams. The post-loop `current_stream().wait_stream(cp_stream)` is too late — the race has already fired.

The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path), streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process — exposed by PR NVIDIA#2993's pool design — prior workloads shape stream state differently, the read can fire before the write completes, and max_logit ends up with stale values in some heads (~0.3 abs diff, 3/12 elements wrong on the H100 matrix).

Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])` before the torch.maximum read. No-op when the streams are identical (i=1 case, where flash_attn_streams[0] is current_stream), only fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in 27m 10s (was 323 passed + 5 failed at this commit's parent, all 5 failing on cp_comm_type=all_gather with mismatched max_logit). The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now pass under the pool — confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
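A self-contained toy showing the ordering the one-line fix enforces — not the actual AttnFuncWithCPAndKVAllGather code, just the write-on-cp_stream / read-on-default-stream pattern and the wait_stream that closes the race:

```python
import torch

cp_stream = torch.cuda.Stream()
max_logit = torch.full((12,), -1e4, device="cuda")
max_logit_per_step = torch.empty(12, device="cuda")

with torch.cuda.stream(cp_stream):
    # Producer: runs on cp_stream, like the i=1 attention step in the real code.
    max_logit_per_step.copy_(torch.randn(12, device="cuda"))

# Without this wait, the torch.maximum below (issued on the default stream) may
# read max_logit_per_step before the copy queued on cp_stream has completed.
torch.cuda.current_stream().wait_stream(cp_stream)
max_logit = torch.maximum(max_logit, max_logit_per_step)
```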
Line-level cleanups from the second reviewer pass on PR NVIDIA#2993. Each item is dead/redundant; none changes behaviour. Full-matrix test_essential=False on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top and pops the FP8 ones itself. The pop loop was defensive against a hypothetical "future caller that doesn't re-set them" that doesn't exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create no Python reference cycles between iterations and empty_cache only frees CUDA blocks PyTorch already considers free; the combination was a no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is itself a collective synchronization point — if every rank reaches it, none is ahead. The "surface a wedged communicator here" comment was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the unreachable post-loop "1000+ lines" branch. select()'s deadline already bounds the loop; the line-count cap was redundant and the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and self-terminates when the pipe closes; we never read the field anywhere, so initialising and nulling it was bookkeeping for no reason.
- Drop the dead assert in submit() — _ensure_alive() on the prior line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real semantic change: hides errors that occur when a previous test wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var), M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
```diff
  torch.cuda.manual_seed(seed)

- test_essential = True
+ test_essential = False
```
test_essential left as False — will break CI
The default was True (38 essential configs, all passing) and has been changed to False (full matrix, 328 runnable, 5 known failures). The PR description explicitly calls these out as "state-cascade failures" under Known Issues. Merging with test_essential = False will cause CI to report 5 failures by default, regressing from the current baseline.
world_size and the rank set don't change for the lifetime of one pool, so recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms of NCCL setup each. Pre-create them once in the pool worker (new helper _create_cp_comm_groups), stash on the run_attention_with_cp module via module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population otherwise leaks every communicator created before the failure. The finally now only destroys groups we created locally (cp_comm_group / sub_groups populated in the else-branch), leaving pool-owned groups alone for reuse.

Addresses cyanguwa's review feedback on PR NVIDIA#2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
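A sketch of what the pool-owned group creation could look like; the helper name comes from the commit message, and the body — which mirrors the existing a2a+p2p sub-rank construction — is an assumption:

```python
import torch.distributed as dist


def _create_cp_comm_groups(rank: int, world_size: int):
    """Create the CP world group and a2a+p2p sub-groups once per pool."""
    cp_comm_group = dist.new_group(list(range(world_size)), backend="nccl")
    cp_comm_sub_groups = []
    # a2a level groups of width 2 plus the strided p2p groups.
    sub_ranks = [list(range(i * 2, (i + 1) * 2)) for i in range(world_size // 2)]
    sub_ranks += [list(range(i, world_size, 2)) for i in range(2)]
    for ranks in sub_ranks:
        group = dist.new_group(ranks, backend="nccl")  # collective: every rank must call it
        if rank in ranks:
            cp_comm_sub_groups.append(group)
    return cp_comm_group, cp_comm_sub_groups
```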
…kars/cp_batching_pool
The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the ~540-line body of run_dpa_with_cp in try/finally. The wrap itself was tiny but it re-indented every line of the body by one level, inflating the PR diff of run_attention_with_cp.py to ~1000 lines against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:
- In pool mode, cp_comm_group and cp_comm_sub_groups are owned by the pool worker (which destroys them once at pool shutdown). run_dpa_with_cp neither creates nor destroys them, so an in-body exception can't leak communicators.
- In single-shot mode, groups are still created locally, but the subprocess exits at function return; NCCL releases everything at process teardown, so a stray exception leaks communicators only for the milliseconds before the process dies — a bounded one-off cost, not the unbounded accumulation that Round-1 flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against origin/main from ~1000 lines to ~120 lines without changing observable behaviour. Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Design
Problem
8×H100, test_essential=True, 38 runnable CP attention configs: the current path spawns one torchrun per parametrized case (baseline measured in "[PyTorch] Batch CP attention tests in single torchrun to amortize NCC… #2965"), so NCCL init + CUDA context + Python startup are paid on every case and launch overhead dominates wall time.

We need that overhead amortised, without changing how tests are written or how skips report. #2965 takes one approach (dry-run + per-batch torchrun). This PR proposes a simpler alternative.
Approach
One long-lived torchrun per world_size, fed by JSON-over-stdio. No dry-run, no batch chunking, no two-pass dispatch.

A session-scoped fixture spawns at most two workers (one for world_size=2, one for world_size=4, both lazily). Each test body calls pool.submit(kwargs), which writes one JSON request line to rank-0's stdin and reads one JSON response line from rank-0's stdout. NCCL initialises once per pool; every subsequent case reuses the same process group.

How the pool works
Worker (run_attention_with_cp_pool.py) is launched once per world_size and loops over a read-request / dispatch / respond cycle. run_dpa_with_cp checks NVTE_CP_POOL_PG and skips its own dist.init_process_group call; on teardown it destroys only the per-case CP sub-groups, leaving the main PG intact for the next case.

Pytest side (PoolWorker) is a thin wrapper around subprocess.Popen of python -m torch.distributed.run --standalone. submit() writes one request line, selects on stdout with a configurable timeout, and reads one sentinel-prefixed JSON response line. On timeout or pipe error it terminates the pool; the next submit() lazily respawns it.

A daemon thread drains the worker's stderr into a bounded ring buffer (200 lines / ~40 KB) and echoes each line live to sys.stderr. On a crash/timeout, the last 4 KB of that buffer is attached to the AssertionError raised on the failing test — so CI JUnit XML carries the actual cause (NCCL error, Python traceback, OOM) inline rather than just "pool worker died". Equivalent in spirit to PR #2965's run_distributed() stderr capture.

Bug found while validating: stream race in AttnFuncWithCPAndKVAllGather.forward

A latent TE bug surfaced once the persistent pool ran the full test_essential=False matrix. In the all-gather CP forward, max_logit_per_step[i] is written inside with torch.cuda.stream(flash_attn_streams[i]): — and for i=1 that's cp_stream, not the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via torch.maximum(...) on the default stream without a wait_stream in between. The post-loop current_stream().wait_stream(cp_stream) is too late — the race has already fired.

The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path) streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process — exactly what the pool exposes — prior workloads leave stream state in a different shape and the read can fire before the write completes. The result was max_logit stale on 3 of 12 head entries (~0.3 abs diff), surfacing in 5 specific configs (all cp_comm_type=all_gather, cp_1_0/cp_1_1, bshd or fp16).

Fix is one line in context_parallel.py: current_stream().wait_stream(flash_attn_streams[i-1]) before the torch.maximum read. No-op when streams are identical (i=1), only fires when reading from cp_stream (i=2). Independently useful for anyone running CP attention in a long-lived process, not just the pool.

Performance
Measured back-to-back on the same 8×H100 box, for test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU) and test_essential=False (full matrix, 50 976 collected, 348 runnable), against the per-case torchrun baseline and #2965's B=16 batching (per-scenario timing table not reproduced here).

Pool is ~7 % faster on the full matrix. Note the runnable count rises from 328 to 348 once the race fix lands — the 20 extra cases were silently dropped earlier by the same numerical-corruption path.
Knobs
- NVTE_CP_POOL_TIMEOUT_SEC=N — per-case submit timeout, default 90 s. The slowest observed case (H100, test_essential=True) is ~15 s; 90 s gives ~6× headroom. Override for slower machines or heavier matrices.
- NVTE_CP_POOL_TIMING=1 — prints [POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B on rank-0 stderr per case. Off by default. Used to recalibrate the timeout against a new matrix.
- NVTE_CP_POOL_PG=1 — set by the pool worker and read only by run_dpa_with_cp; tells that function not to call dist.init_process_group and to leave the main PG alone on teardown. Not for end-user use.

There is intentionally no batch-size knob — there's no concept of a batch to size.
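Illustrative only — roughly how the knobs above might be read on the pytest side; the exact parsing in test_attention_with_cp.py may differ:

```python
import os

# Per-case submit timeout (seconds); default mirrors the documented 90 s.
POOL_SUBMIT_TIMEOUT_SEC = float(os.environ.get("NVTE_CP_POOL_TIMEOUT_SEC", "90"))
# Opt-in per-case wall-time logging on rank-0 stderr.
POOL_TIMING = os.environ.get("NVTE_CP_POOL_TIMING", "0") == "1"
# Internal flag: pool worker owns the main process group.
POOL_MANAGED_PG = os.environ.get("NVTE_CP_POOL_PG", "0") == "1"
```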
Adding a pooled test
- Keep the existing @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
- Add cp_pool to the function signature.
- Replace the body with pool = cp_pool(num_gpus); _submit(pool, **kwargs), where kwargs becomes run_dpa_with_cp(**kwargs).

That's it. No two-pass logic, no fixture stubs.
Failure semantics
- pytest.skip(...) fires before pool.submit — no pool work.
- @pytest.mark.skip(if) marker fires — no pool work.
- A failure inside run_dpa_with_cp is collected per rank (dist.gather_object); other ranks' tracebacks are also visible in subprocess stderr.
- Timeout: "pool worker (world_size=N) timed out after 90s; ..." plus the last 4 KB of the worker's stderr. Pool killed; subsequent tests respawn a fresh pool.
- Worker death mid-request: "pool worker died mid-request" plus stderr tail; next case respawns.
- Worker death before the request is sent: "pool worker died before request could be sent" plus stderr tail; next case respawns.

Cross-rank failure detail is strictly better than all_reduce(ok, op=MIN) (which only tells you some rank failed): gather_object brings back each rank's (ok, traceback) tuple, so the reported error is the actual non-zero-rank stack trace, not "see subprocess stderr."

What happens when a pool worker stalls
Three terminal states for any submit(): response arrives → normal handling; no response after POOL_SUBMIT_TIMEOUT_SEC → timeout path; process died → mid-request-death path. Any stall (application hang, NCCL deadlock, GPU wedge, even a stdout-pipe-full self-deadlock) eventually resolves to one of the latter two. The pool is SIGTERM'd (5 s grace) then SIGKILL'd, the current test FAILs with the stderr tail attached, and the next case lazily respawns a fresh pool of the same world_size (~6–9 s NCCL re-init). The other pool (different world_size) is unaffected. Blast radius: one failed test + one respawn.

Mitigations for shared-process state
All cases share one Python process and one NCCL world per world_size, so anything that needs a clean per-test starting point is reset before each case (in _run_one, not in a finally block, so the first case is also clean):

- torch.manual_seed(1234) + torch.cuda.manual_seed(1234) — RNG reseeded so input tensors are reproducible per case.
- FP8GlobalStateManager.reset() — drops FP8 amax history etc. that would otherwise leak across cases.
- torch.cuda.empty_cache().
- copy.deepcopy(model_configs_*[model]) inside run_dpa_with_cp — the THD branch rewrites attn_mask_type in place; without deepcopy the change leaks into the module-level dict.
- run_dpa_with_cp itself re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top of every call and pops the FP8-related transient env vars, so no explicit env-key reset is needed in the pool worker.

The single-shot run_dpa_with_cp does some of these inherently (it's a fresh process). For the pooled path we replicate them explicitly so the two execution modes produce identical per-case state.

Edge cases
- Port collisions: torch.distributed.run --standalone picks a free rendezvous port at bind time. No MASTER_PORT plumbing needed; parallel pytest sessions (e.g., L1_pytorch_distributed_unittest's det vs non-det concurrent runs on disjoint GPU sets) cannot collide.
- Hung worker: _kill() terminates the pool; the next submit() lazily respawns a fresh worker, so the blast radius is one test (the timed-out one) plus a one-time respawn cost.
- Worker crash: BrokenPipeError / empty read → AssertionError on the current test with the stderr tail; kill the pool; respawn on next case.
- GPU sharing between the two pools: torch.cuda.set_device(rank % device_count) means both pools claim GPUs 0–N starting at 0. They never run CUDA concurrently (pytest serialises tests), so the idle pool only holds ~1 GB of CUDA context per shared GPU — well within H100's 80 GB. NCCL worlds are independent. No collision. This is intentional: setting an explicit CUDA_VISIBLE_DEVICES per pool would prevent the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh (det / non-det on disjoint GPU sets); see Follow-ups.
- NVTE_CP_POOL_PG collision: the env var is set by the pool worker after dist.init_process_group and only read by run_dpa_with_cp — no other consumer. If end-user code accidentally sets it without an init'd PG, run_dpa_with_cp will fail when it tries to use the (non-existent) PG. Harmless; same failure as setting any other internal flag incorrectly.

Validation
test_essential=True: 38 passed / 0 failed. test_essential=False (full matrix): 348 passed / 0 failed.

Per-case wall-time distribution on H100 (test_essential=True, with NVTE_CP_POOL_TIMING=1): min 1.87 s, p50 4.77 s, p95 12.43 s, max 15.39 s. Drove the POOL_SUBMIT_TIMEOUT_SEC default of 90 s (~6× max).

Follow-ups (not in this PR)
Items intentionally deferred to keep this PR scoped:
- Drop the try/except Exception: pass around dist.destroy_process_group calls. Real semantic change: would surface errors that the swallow currently hides. Needs its own validation.
- Replace the NVTE_CP_POOL_PG env-var contract between the pool worker and run_dpa_with_cp with a module-level flag (run_attention_with_cp._POOL_MANAGED_PG = True). Cleaner, type-checkable; same effect.
- Redirect rank>0 stdout to /dev/null at worker startup instead of sentinel-prefix scanning. Closes the stdout-pollution class at the source rather than papering over it.
- Explicit CUDA_VISIBLE_DEVICES per pool: pinning devices per pool (so 2-GPU and 4-GPU pools claim disjoint GPUs) would break the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh. Each top-level pytest session sets CUDA_VISIBLE_DEVICES itself (e.g. det on GPUs 0-3, non-det on 4-7); the pool inherits that and uses rank % device_count to map within the session's slice. Adding per-pool device pinning would override the session's slice and produce overlap across sessions. The current "overlap within a session, idle context only" behaviour (edge case 4) is the right trade-off.
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py—wait_streamfix inAttnFuncWithCPAndKVAllGather.forward(the latent stream race the pool exposed).tests/pytorch/attention/test_attention_with_cp.py—cp_poolsession fixture +PoolWorker(lazy spawn,--standalone, 90 s timeout, kill-and-respawn, stderr drainer thread + tail attached to crash AssertionErrors).tests/pytorch/attention/run_attention_with_cp_pool.py— pool worker: init NCCL once, dispatch loop,_reset_between_cases(seed/FP8/cache),gather_objectfor cross-rank failure detail, optionalNVTE_CP_POOL_TIMING=1per-case timing log.tests/pytorch/attention/run_attention_with_cp.py—run_dpa_with_cphonoursNVTE_CP_POOL_PG(skips its own PG init/destroy), deep-copies model configs, wraps the per-case body in try/finally so the per-case CP communicator is destroyed even when an assertion or other exception fires (previously leaked one or more NCCL communicators per failed case).Comparison to #2965
Both PRs solve the same problem; this one is structurally smaller and has fewer concepts.
origin/mainCP_TEST_BATCH_SIZE,CP_TEST_BATCH_RETRYNVTE_CP_POOL_TIMEOUT_SEC(default 90 s),NVTE_CP_POOL_TIMING_COLLECT_MODE,_DummyRequest,_item_static_skip,_BACKEND_CACHE, batch chunking, atomic JSON flush, singleton retryall_reduce(ok, MIN)— boolean onlygather_object— full traceback per rankrun_distributed()attaches last 4 KB)MASTER_PORTenv per parallel pytest session--standalone, automatictest_essential=True)Type of change
Checklist
test_essential=False: 348 / 0)