
CP Tests batching using subprocess worker pool#2993

Open
sudhakarsingh27 wants to merge 16 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_batching_pool

Conversation


sudhakarsingh27 (Collaborator) commented May 14, 2026

Design

Problem

8×H100, test_essential=True, 38 runnable CP attention configs: the existing path spawns one torchrun per parametrized case, paying Python startup, CUDA-context creation, and NCCL init every time, so launch overhead dominates the ~554 s baseline wall time.

We need that overhead amortised, without changing how tests are written or how skips report. #2965 takes one approach (dry-run + per-batch torchrun). This PR proposes a simpler alternative.

Approach

One long-lived torchrun per world_size, fed by JSON-over-stdio. No dry-run, no batch chunking, no two-pass dispatch.

A session-scoped fixture spawns at most two workers (one for world_size=2, one for world_size=4, both lazily). Each test body calls pool.submit(kwargs) which writes one JSON request line to rank-0's stdin and reads one JSON response line from rank-0's stdout. NCCL initialises once per pool; every subsequent case reuses the same process group.

pytest (session)
    ├─ fixture cp_pool spawns torchrun --standalone (world_size=2)
    │    ├─ rank 0 ── stdin/stdout ──┐
    │    ├─ rank 1                   │
    │                                │
    ├─ test_cp_with_flash_attention[case0] ── pool.submit(kwargs) ──┘
    ├─ test_cp_with_flash_attention[case1] ── pool.submit(kwargs)
    ├─ ...
    └─ fixture teardown sends {"op":"shutdown"} to each pool
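
A minimal sketch of the fixture side, assuming a PoolWorker wrapper class and a shutdown() helper as described below; the real fixture in test_attention_with_cp.py may differ in detail:

import pytest

_POOLS = {}  # world_size -> PoolWorker, spawned lazily on first use

@pytest.fixture(scope="session")
def cp_pool():
    def _get(world_size):
        # At most one persistent torchrun per world_size for the whole session.
        if world_size not in _POOLS:
            _POOLS[world_size] = PoolWorker(world_size)  # assumed wrapper class
        return _POOLS[world_size]

    yield _get

    # Session teardown: ask every live pool to leave its dispatch loop.
    for pool in _POOLS.values():
        pool.shutdown()  # assumed helper that writes {"op": "shutdown"} to rank-0 stdin
    _POOLS.clear()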

How the pool works

Worker (run_attention_with_cp_pool.py) is launched once per world_size and loops:

import os

import torch
import torch.distributed as dist


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    os.environ["NVTE_CP_POOL_PG"] = "1"  # tell run_dpa_with_cp to reuse our PG

    while True:
        req = _recv_request(rank)              # rank 0 reads stdin, broadcasts to all ranks
        if req.get("op") == "shutdown":
            break
        ok, msg = _run_one(req, rank)          # _reset_between_cases + run_dpa_with_cp
        gathered = [None] * world_size
        dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0)
        if rank == 0:
            all_ok = all(ok_ for ok_, _ in gathered)
            first_error = next((tb for ok_, tb in gathered if not ok_), None)
            _send_response(rank, {"ok": all_ok, "error": first_error})
    dist.destroy_process_group()
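
For reference, the two stdio helpers used above could look roughly like this; this is a sketch, and the exact container handling in run_attention_with_cp_pool.py may differ (the [CP_POOL_RESP] sentinel is the one introduced for the stdout-pollution fix discussed later in this thread):

import json
import sys

import torch.distributed as dist

_RESP_PREFIX = "[CP_POOL_RESP] "  # lets the parent skip torchrun/NCCL chatter on stdout

def _recv_request(rank: int) -> dict:
    # Only rank 0 owns the stdin pipe; every other rank receives the request via a collective.
    obj = [None]
    if rank == 0:
        obj[0] = json.loads(sys.stdin.readline())
    dist.broadcast_object_list(obj, src=0)
    return obj[0]

def _send_response(rank: int, payload: dict) -> None:
    if rank == 0:
        sys.stdout.write(_RESP_PREFIX + json.dumps(payload) + "\n")
        sys.stdout.flush()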

run_dpa_with_cp checks NVTE_CP_POOL_PG and skips its own dist.init_process_group call; on teardown it destroys only the per-case CP sub-groups, leaving the main PG intact for the next case.

Pytest side (PoolWorker) is a thin wrapper around subprocess.Popen of python -m torch.distributed.run --standalone. submit() writes one request line, selects on stdout with a configurable timeout, reads one sentinel-prefixed JSON response line. On timeout or pipe error it terminates the pool; the next submit() lazily respawns it.
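
A sketch of that submit path, assuming text-mode pipes, the _RESP_PREFIX sentinel described later in this thread, and the _ensure_alive/_kill/_diag helpers mentioned elsewhere in this PR; the real class carries more bookkeeping:

import json
import select
import sys
import time

class PoolWorker:  # sketch of submit() only; spawn, kill and stderr handling omitted
    _RESP_PREFIX = "[CP_POOL_RESP] "

    def submit(self, kwargs: dict, timeout: float = 90.0) -> None:
        self._ensure_alive()                            # lazily (re)spawns torchrun --standalone
        self.proc.stdin.write(json.dumps(kwargs) + "\n")
        self.proc.stdin.flush()

        deadline = time.monotonic() + timeout           # one budget for noise lines + response
        while True:
            remaining = max(0.0, deadline - time.monotonic())
            # select on a pipe fd works on Linux/macOS only.
            ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
            if not ready:
                self._kill()
                raise AssertionError(f"pool worker timed out after {timeout}s" + self._diag())
            line = self.proc.stdout.readline()
            if not line:                                # EOF: worker died mid-request
                self._kill()
                raise AssertionError("pool worker died mid-request" + self._diag())
            if not line.startswith(self._RESP_PREFIX):
                sys.stderr.write(line)                  # non-protocol chatter stays visible
                continue
            resp = json.loads(line[len(self._RESP_PREFIX):])
            if not resp["ok"]:
                raise AssertionError(resp["error"])
            return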

A daemon thread drains the worker's stderr into a bounded ring buffer (200 lines / ~40 KB) and echoes each line live to sys.stderr. On a crash/timeout, the last 4 KB of that buffer is attached to the AssertionError raised on the failing test — so CI JUnit XML carries the actual cause (NCCL error, Python traceback, OOM) inline rather than just "pool worker died". Equivalent in spirit to PR #2965's run_distributed() stderr capture.
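
A sketch of that drainer, assuming text-mode pipes and a deque-backed tail as described above (the helper name is illustrative):

import collections
import sys
import threading

def _start_stderr_drainer(proc):
    """Drain the worker's stderr into a bounded ring buffer while echoing it live."""
    tail = collections.deque(maxlen=200)  # roughly 40 KB at typical line lengths

    def _drain():
        for line in proc.stderr:          # returns when the worker closes the pipe
            tail.append(line)
            sys.stderr.write(line)        # keep pytest's per-test capture intact

    threading.Thread(target=_drain, daemon=True).start()
    return tail                           # crash paths attach "".join(tail)[-4096:]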

Bug found while validating: stream race in AttnFuncWithCPAndKVAllGather.forward

A latent TE bug surfaced once the persistent pool ran the full test_essential=False matrix. In the all-gather CP forward, max_logit_per_step[i] is written inside with torch.cuda.stream(flash_attn_streams[i]): — and for i=1 that's cp_stream, not the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via torch.maximum(...) on the default stream without a wait_stream in between. The post-loop current_stream().wait_stream(cp_stream) is too late — the race has already fired.

The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path) streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process — exactly what the pool exposes — prior workloads leave stream state in a different shape and the read can fire before the write completes. The result was max_logit stale on 3 of 12 head entries (~0.3 abs diff), surfacing in 5 specific configs (all cp_comm_type=all_gather, cp_1_0/cp_1_1, bshd or fp16).

Fix is one line in context_parallel.py: current_stream().wait_stream(flash_attn_streams[i-1]) before the torch.maximum read. No-op when streams are identical (i=1), only fires when reading from cp_stream (i=2). Independently useful for anyone running CP attention in a long-lived process, not just the pool.
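
A self-contained toy showing the ordering the fix enforces (requires a CUDA device; the variable names mimic the TE code, but this is an illustration, not the actual AttnFuncWithCPAndKVAllGather implementation):

import torch

default_stream = torch.cuda.current_stream()
cp_stream = torch.cuda.Stream()
flash_attn_streams = [default_stream, cp_stream]

x = torch.randn(1 << 20, device="cuda")
cp_stream.wait_stream(default_stream)      # x was produced on the default stream

max_logit = None
max_logit_per_step = [None, None]
for i in range(3):
    if i < 2:
        with torch.cuda.stream(flash_attn_streams[i]):
            # write max_logit_per_step[i] on flash_attn_streams[i];
            # for i == 1 that stream is cp_stream, not the default stream
            max_logit_per_step[i] = x * float(i + 1)
    if i > 0:
        # the fix: the default stream waits for the producing stream before reading;
        # a no-op at i == 1, where the producer was the default stream itself
        torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1])
        prev = max_logit_per_step[i - 1]
        max_logit = prev if max_logit is None else torch.maximum(max_logit, prev)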

Performance

Measured back-to-back on the same 8×H100 box.

test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU)

| Run | Approach | torchrun spawns | Wall | Pass / Fail | Speedup |
| --- | --- | --- | --- | --- | --- |
| Unbatched (baseline) | one torchrun per config | 38 | ~554 s | n/a (PR #2965) | 1.0× |
| #2965 B=16 | dry-run + per-batch torchrun | 4 | 263 s | 38 / 0 | 2.1× |
| This PR (pool) | persistent NCCL pool per world_size | 2 | 248 s | 38 / 0 | 2.23× |

test_essential=False (full matrix, 50 976 collected, 348 runnable)

| Run | Approach | Wall | Pass / Fail | Notes |
| --- | --- | --- | --- | --- |
| #2965 (B=16) | dry-run + per-batch torchrun | 28 m 10 s (1 690 s) | 328 / 0 | clean |
| This PR (pool) — before the all-gather race fix | persistent pool | 26 m 25 s (1 585 s) | 323 / 5 | 5 latent-TE-bug failures; same configs pass on #2965 because per-batch fresh-process re-init hides the race |
| This PR (pool) — current | persistent pool + race fix | 26 m 23 s (1 583 s) | 348 / 0 | clean |

Pool is ~7 % faster on the full matrix. Note the runnable count rises from 328 to 348 once the race fix lands — the 20 extra cases were silently dropped earlier by the same numerical-corruption path.

Knobs

| Env var | Effect |
| --- | --- |
| NVTE_CP_POOL_TIMEOUT_SEC=N | Max seconds the fixture waits for one case's response. Default 90. Slowest observed case on H100 (test_essential=True) is ~15 s; 90 s gives ~6× headroom. Override for slower machines or heavier matrices. |
| NVTE_CP_POOL_TIMING=1 | Emit [POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B on rank-0 stderr per case. Off by default. Used to recalibrate the timeout against a new matrix. |
| NVTE_CP_POOL_PG=1 | Set by the pool worker itself before calling run_dpa_with_cp; tells that function not to call dist.init_process_group and to leave the main PG alone on teardown. Not for end-user use. |

There is intentionally no batch-size knob — there's no concept of a batch to size.

Adding a pooled test

  1. Write the test as usual: @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
  2. Add cp_pool to the function signature.
  3. At the end, call pool = cp_pool(num_gpus) and _submit(pool, **kwargs); the kwargs are forwarded to run_dpa_with_cp(**kwargs) inside the pool worker (see the sketch after this list).

That's it. No two-pass logic, no fixture stubs.
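
A minimal illustration of the shape such a test takes; the parameter names, values, and the skip condition below are placeholders, not the real parametrize stack of test_cp_with_flash_attention:

import pytest

@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a+p2p"])
def test_cp_with_flash_attention_pooled_example(dtype, cp_comm_type, cp_pool):
    if dtype == "fp16" and cp_comm_type == "all_gather":
        pytest.skip("example inline skip; runs before any pool work")

    num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2        # a2a+p2p needs world_size=4
    pool = cp_pool(num_gpus)                                 # lazily spawns the matching pool
    _submit(pool, dtype=dtype, cp_comm_type=cp_comm_type)    # worker runs run_dpa_with_cp(**kwargs)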

Failure semantics

| Outcome | What pytest sees |
| --- | --- |
| Inline pytest.skip(...) fires | Standard SKIP — body returns before pool.submit, no pool work. |
| @pytest.mark.skipif marker fires | Standard SKIP via pytest's normal path. |
| Config ran, assertion failed on any rank | FAIL with the first failing rank's full traceback (via dist.gather_object). Other ranks' tracebacks visible in subprocess stderr. |
| Pool worker timed out | FAIL on the current test: "pool worker (world_size=N) timed out after 90s; ..." plus the last 4 KB of the worker's stderr. Pool killed; subsequent tests respawn a fresh pool. |
| Pool worker died mid-request | FAIL: "pool worker died mid-request" plus stderr tail; next case respawns. |
| Pool worker died before request | FAIL: "pool worker died before request could be sent" plus stderr tail; next case respawns. |

Cross-rank failure detail is strictly better than all_reduce(ok, op=MIN) (which only tells you some rank failed): gather_object brings back each rank's (ok, traceback) tuple so the reported error is the actual non-zero-rank stack trace, not "see subprocess stderr."

What happens when a pool worker stalls

Three terminal states for any submit(): response arrives → normal handling; no response after POOL_SUBMIT_TIMEOUT_SEC → timeout path; process died → mid-request-death path. Any stall (application hang, NCCL deadlock, GPU wedge, even a stdout-pipe-full self-deadlock) eventually resolves to one of the latter two. The pool is SIGTERM'd (5 s grace) then SIGKILL'd, the current test FAILs with the stderr tail attached, and the next case lazily respawns a fresh pool of the same world_size (~6–9 s NCCL re-init). The other pool (different world_size) is unaffected. Blast radius: one failed test + one respawn.
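
A sketch of that kill-and-respawn path; the method bodies are assumptions consistent with the description above, and _spawn is a hypothetical helper wrapping the Popen of torch.distributed.run --standalone:

import subprocess

class PoolWorker:  # sketch of the kill-and-respawn path only; other methods omitted
    def _kill(self):
        if self.proc is None:
            return
        self.proc.terminate()              # SIGTERM
        try:
            self.proc.wait(timeout=5)      # 5 s grace
        except subprocess.TimeoutExpired:
            self.proc.kill()               # then SIGKILL
            self.proc.wait()
        self.proc = None                   # next submit() sees a dead pool

    def _ensure_alive(self):
        # Lazy (re)spawn: the first submit() after a kill pays one torchrun
        # launch + NCCL init (~6-9 s); the other pool is untouched.
        if self.proc is None or self.proc.poll() is not None:
            self._spawn()  # hypothetical helper: Popen of `python -m torch.distributed.run --standalone ...`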

Mitigations for shared-process state

All cases share one Python process and one NCCL world per world_size, so anything that needs a clean per-test starting point is reset before each case (in _run_one, not in a finally block, so the first case is also clean):

  • torch.manual_seed(1234) + torch.cuda.manual_seed(1234) — RNG reseeded so input tensors are reproducible per case.
  • FP8GlobalStateManager.reset() — drops FP8 amax history etc. that would otherwise leak across cases.
  • torch.cuda.empty_cache().
  • copy.deepcopy(model_configs_*[model]) inside run_dpa_with_cp — the THD branch rewrites attn_mask_type in place; without deepcopy the change leaks into the module-level dict.

run_dpa_with_cp itself re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top of every call and pops the FP8-related transient env vars, so no explicit env-key reset is needed in the pool worker.

The single-shot run_dpa_with_cp does some of these inherently (it's a fresh process). For the pooled path we replicate them explicitly so the two execution modes produce identical per-case state.
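
A sketch of the per-case reset, assuming the helper is named _reset_between_cases as above and that FP8GlobalStateManager is imported from transformer_engine.pytorch.fp8:

import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

def _reset_between_cases():
    # Called at the start of _run_one so the first case starts from the same state as later ones.
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)    # reproducible input tensors per case
    FP8GlobalStateManager.reset()   # drop FP8 amax history / autocast state
    torch.cuda.empty_cache()        # release cached CUDA blocks
    # The deepcopy guard lives inside run_dpa_with_cp itself:
    #     config = copy.deepcopy(model_configs_flash_attn[model])
    # because the THD branch rewrites attn_mask_type in place.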

Edge cases

  1. Cross-pytest-session port collisions. torch.distributed.run --standalone picks a free rendezvous port at bind time. No MASTER_PORT plumbing needed; parallel pytest sessions (e.g., L1_pytorch_distributed_unittest's det vs non-det concurrent runs on disjoint GPU sets) cannot collide.
  2. Pool wedges after N cases. Per-case timeout is 90 s by default. On expiry, _kill() terminates the pool; the next submit() lazily respawns a fresh worker, so the blast radius is one test (the timed-out one) plus a one-time respawn cost.
  3. Pool worker crashes mid-case. BrokenPipeError / empty read → AssertionError on the current test with the stderr tail; kill the pool; respawn on next case.
  4. Two pools (2-GPU and 4-GPU) co-resident on overlapping low-index GPUs. torch.cuda.set_device(rank % device_count) means both pools claim GPUs 0–N starting at 0. They never run CUDA concurrently (pytest serialises tests), so the idle pool only holds ~1 GB of CUDA context per shared GPU — well within H100's 80 GB. NCCL worlds are independent. No collision. This is intentional: setting an explicit CUDA_VISIBLE_DEVICES per pool would prevent the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh (det / non-det on disjoint GPU sets); see Follow-ups.
  5. NVTE_CP_POOL_PG collision. The env var is set by the pool worker after dist.init_process_group and only read by run_dpa_with_cp — no other consumer. If end-user code accidentally sets it without an init'd PG, run_dpa_with_cp will fail when it tries to use the (non-existent) PG. Harmless; same failure as setting any other internal flag incorrectly.

Validation

| Run | Result |
| --- | --- |
| 8×H100, test_essential=True | 38 passed / 10234 skipped / 0 failed in 248 s |
| 8×H100, test_essential=False | 348 passed / 50628 skipped / 0 failed in 26 m 23 s |

Per-case wall-time distribution on H100 (test_essential=True, with NVTE_CP_POOL_TIMING=1): min 1.87 s, p50 4.77 s, p95 12.43 s, max 15.39 s. These numbers drove the POOL_SUBMIT_TIMEOUT_SEC default of 90 s (~6× the worst observed case).
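
When recalibrating the timeout against a new matrix, the [POOL-TIMING] lines can be summarised with a short helper like this; the script name and usage are illustrative, only the log format comes from this PR:

import re
import statistics
import sys

# Usage sketch: run with NVTE_CP_POOL_TIMING=1, save stderr, then
# `python summarise_pool_timing.py < stderr.log`.
_PAT = re.compile(r"\[POOL-TIMING\] case_idx=\d+ world_size=\d+ wall_s=([0-9.]+) ok=\w+")

walls = sorted(float(m.group(1)) for line in sys.stdin if (m := _PAT.search(line)))
if walls:
    print(
        f"n={len(walls)} min={walls[0]:.2f}s p50={statistics.median(walls):.2f}s "
        f"p95={walls[int(0.95 * (len(walls) - 1))]:.2f}s max={walls[-1]:.2f}s"
    )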

Follow-ups (not in this PR)

Items intentionally deferred to keep this PR scoped:

  • L8 — drop the try/except Exception: pass around dist.destroy_process_group calls. Real semantic change: would surface errors that the swallow currently hides. Needs its own validation.
  • M1 — replace the NVTE_CP_POOL_PG env-var contract between the pool worker and run_dpa_with_cp with a module-level flag (run_attention_with_cp._POOL_MANAGED_PG = True). Cleaner, type-checkable; same effect.
  • M2 — redirect non-rank-0 stdout to /dev/null at worker startup instead of sentinel-prefix scanning. Closes the stdout-pollution class at the source rather than papering over it.
  • M3 — intentionally not done. Setting explicit CUDA_VISIBLE_DEVICES per pool (so 2-GPU and 4-GPU pools claim disjoint GPUs) would break the parallel-pytest-session pattern used by qa/L1_pytorch_distributed_unittest/test.sh. Each top-level pytest session sets CUDA_VISIBLE_DEVICES itself (e.g. det on GPUs 0-3, non-det on 4-7); the pool inherits that and uses rank % device_count to map within the session's slice. Adding per-pool device pinning would override the session's slice and produce overlap across sessions. The current "overlap within a session, idle context only" behaviour (edge case 4) is the right trade-off.

Files

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — wait_stream fix in AttnFuncWithCPAndKVAllGather.forward (the latent stream race the pool exposed).
  • tests/pytorch/attention/test_attention_with_cp.py — cp_pool session fixture + PoolWorker (lazy spawn, --standalone, 90 s timeout, kill-and-respawn, stderr drainer thread + tail attached to crash AssertionErrors).
  • tests/pytorch/attention/run_attention_with_cp_pool.py — pool worker: init NCCL once, dispatch loop, _reset_between_cases (seed/FP8/cache), gather_object for cross-rank failure detail, optional NVTE_CP_POOL_TIMING=1 per-case timing log.
  • tests/pytorch/attention/run_attention_with_cp.py — run_dpa_with_cp honours NVTE_CP_POOL_PG (skips its own PG init/destroy), deep-copies model configs, wraps the per-case body in try/finally so the per-case CP communicator is destroyed even when an assertion or other exception fires (previously leaked one or more NCCL communicators per failed case).

Comparison to #2965

Both PRs solve the same problem; this one is structurally smaller and has fewer concepts.

| | #2965 (dry-run + batched torchrun) | This PR (persistent pool) |
| --- | --- | --- |
| Batching-only diff vs origin/main | 355 + / 69 − (2 files) | 326 + / 51 − (3 files) |
| Knobs | CP_TEST_BATCH_SIZE, CP_TEST_BATCH_RETRY | NVTE_CP_POOL_TIMEOUT_SEC (default 90 s), NVTE_CP_POOL_TIMING |
| Required concepts | dry-run, _COLLECT_MODE, _DummyRequest, _item_static_skip, _BACKEND_CACHE, batch chunking, atomic JSON flush, singleton retry | long-lived worker, JSON-over-stdio |
| Cross-rank failure | all_reduce(ok, MIN) — boolean only | gather_object — full traceback per rank |
| Crash-path stderr in AssertionError | ✓ (run_distributed() attaches last 4 KB) | ✓ (drainer thread, last 4 KB attached) |
| Master-port handling | needs explicit MASTER_PORT env per parallel pytest session | --standalone, automatic |
| Per-case timeout | none | 90 s default |
| Spawns (test_essential=True) | 4 @ B=16, 2 @ B=50 | 2 always |
| Pre-pytest overhead | ~14 s dry-run + collection | none |
| Full-matrix wall (H100, 348 runnable) | n/a (#2965 measured at 328 runnable, before the race fix) | 26 m 23 s |
| Full-matrix correctness | clean (race didn't surface; fresh process per batch hides it) | clean (race fixed in TE) |

Type of change

  • Code refactoring (test infrastructure)
  • Bug fix (latent CP stream race exposed by the pool — fix in TE source)

Checklist

  • Contributing guidelines followed
  • Functionality complete
  • Code commented where non-obvious
  • Documentation (n/a — internal test infra)
  • No new warnings
  • Existing test suite serves as input + validation
  • Existing tests pass locally (8×H100, test_essential=False: 348 / 0)

sudhakarsingh27 and others added 6 commits May 13, 2026 22:10
The existing test path spawns one torchrun per parametrized case, paying
NCCL init + CUDA context + Python startup on every call. With ~hundreds
of cases the launch overhead dominates wall time and was a primary driver
of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived
torchrun per (world_size). NCCL is initialized once at session start and
reused across cases. Pytest sends one JSON request per case over rank-0
stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers
(ok, error) from every rank, and writes one JSON response on rank-0
stdout.

run_attention_with_cp.py is left almost untouched; a new
NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and
dist.destroy_process_group() calls so the function reuses the pool's
main PG instead of creating its own. The per-case cp_comm_group (and
a2a+p2p sub-groups) are explicitly destroyed at function exit to
prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior
subprocess-per-case design got for free: a select-based per-call
timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on
worker death or timeout. A test-level exception is reported as an
AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires
world_size=4 and the others use world_size=2; you can't resize an
active PG. Pools are spawned lazily so a 2-GPU-only run never pays the
4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Two resilience fixes carried over from the existing batching PR
(sudhakars/cp_test_batching_pr) without which the pool will
cascade-fail FP8 tests and silently propagate NCCL desync.

1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state
   (recipe handles, autocast counters) lives in module-level globals.
   Reusing one Python process across cases otherwise carries that state
   forward. The prior batching PR landed an explicit fix for the same
   issue ("Fix FP8 cascade failures") after observing real test
   failures from this.

2. dist.barrier() after each case. If one rank's case errored before
   its last collective, the others can be stuck waiting on a comm that
   will never complete. The barrier here surfaces that immediately as
   a timeout in this case rather than letting the corruption leak into
   the next case's collectives.

Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the
top of each call. run_dpa_with_cp already sets them unconditionally so
this is defensive, but cheap insurance against future variants that
might not.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The model_configs_{flash,fused}_attn dicts are module-level and shared
across pool cases. The THD branch below rewrites config.attn_mask_type
in place (causal -> padding_causal, no_mask -> padding). With the
persistent-pool runner, the next case looking up the same model key
gets the mutated config and fails the "causal or no_mask only" assert.

Caught at benchmark time on cp_2_0 + thd, identical to the cascade the
existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the
same way in commit 6355f62.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Mirrors the two pre-emptive skips on the PR-batching branch:

* non-vanilla softmax with FusedAttention is not deterministic
* post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool
worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside
run_dpa_with_cp instead of being marked SKIPPED.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The pool worker reused RNG state across cases, which produced
small numerical drift on some non-FP8 fused-attention configs
(cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the
single-shot worker. Matches the per-case startup of the single-shot
worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at
the start of every case, alongside the existing FP8 / env / cache
resets.

Moved the reset call from the post-case finally block to the start
of _run_one so the first case is also seeded consistently with
subsequent cases. Otherwise the first case would inherit the
process-default RNG and only the second-and-later cases would be
deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

greptile-apps Bot commented May 14, 2026

Greptile Summary

This PR replaces per-test torchrun spawning with a persistent NCCL worker pool (one per world_size), achieving ~2.23x speedup on the essential test matrix by amortising Python import and NCCL init costs. It also fixes a latent CUDA stream race in AttnFuncWithCPAndKVAllGather.forward where max_logit_per_step[1] written on cp_stream was read on the default stream at the next loop iteration without an intervening wait_stream.

  • Pool worker (run_attention_with_cp_pool.py): NCCL initialised once, JSON-over-stdio dispatch loop, per-case gather_object for full cross-rank tracebacks, optional per-case timing via NVTE_CP_POOL_TIMING=1.
  • Test harness (test_attention_with_cp.py): session-scoped cp_pool fixture with lazy spawn, sentinel-prefixed stdout protocol, 90s per-case timeout with kill-and-respawn, daemon stderr drainer that attaches the last 4 KB to crash AssertionErrors.
  • Stream race fix (context_parallel.py): one wait_stream(flash_attn_streams[i-1]) before the torch.maximum read in the all-gather CP forward, placed outside the with torch.cuda.stream(...) block so it correctly synchronises the default stream against cp_stream.

Confidence Score: 5/5

Safe to merge; the test infrastructure is well-structured and the stream race fix in context_parallel.py is a genuine correctness improvement for long-lived CP processes.

The pool design is sound: sentinel-based stdout protocol, stderr draining, kill-and-respawn on timeout, and pool-shared NCCL groups eliminate per-case spawn cost without changing test semantics. The stream race fix is correct — the new wait_stream sits outside the with torch.cuda.stream block (which ends after the copy operations), so it correctly synchronises the default stream against cp_stream before the torch.maximum read. The two hardening observations in test_attention_with_cp.py do not affect current correctness.

No files require special attention for merging; the two style/hardening observations in test_attention_with_cp.py are non-blocking.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | One-line stream race fix: adds wait_stream(flash_attn_streams[i-1]) before reading max_logit_per_step[i-1] on the default stream, correctly ordering the cp_stream write (at i=1) before the default-stream read (at i=2). |
| tests/pytorch/attention/run_attention_with_cp_pool.py | New persistent pool worker: NCCL init once, JSON dispatch loop, gather_object for per-rank tracebacks, pool-shared CP groups pre-created once, finally-guarded teardown. |
| tests/pytorch/attention/test_attention_with_cp.py | Adds PoolWorker class and cp_pool fixture; _RESP_PREFIX sentinel hardcoded as a string literal in both this file and run_attention_with_cp_pool.py; timeout loop can spin past deadline if worker floods stdout with non-sentinel lines. |
| tests/pytorch/attention/run_attention_with_cp.py | Adds deep-copy of model configs, pool PG reuse logic (_pool_managed_pg / _reusing_pool_groups), and pool-shared CP group module-level references; teardown is success-path only (intentional for single-shot mode). |

Sequence Diagram

sequenceDiagram
    participant P as pytest (session)
    participant PW as PoolWorker (pytest-side)
    participant R0 as torchrun rank-0
    participant RN as torchrun rank-1..N

    P->>PW: "cp_pool(world_size=2)"
    PW->>R0: spawn torchrun --standalone (lazy, first use)
    R0->>RN: dist.init_process_group (NCCL, once)
    R0->>RN: _create_cp_comm_groups (once)
    Note over R0,RN: Pool enters dispatch loop

    loop per test case
        P->>PW: "_submit(pool, **kwargs)"
        PW->>R0: JSON line to stdin
        R0->>RN: broadcast_object_list(request)
        R0->>RN: _reset_between_cases() + run_dpa_with_cp()
        R0->>RN: gather_object((ok, traceback))
        R0->>PW: stdout sentinel-prefixed JSON response
        PW->>P: return or raise AssertionError
    end

    P->>PW: fixture teardown
    PW->>R0: shutdown JSON line
    R0->>RN: destroy CP groups + dist.destroy_process_group()


Comment on lines +773 to +781
# destroy distribution group
if _pool_managed_pg:
    # Pool owns the main PG; only clean up groups created for this case.
    dist.destroy_process_group(cp_comm_group)
    if cp_comm_type == "a2a+p2p":
        for g in cp_comm_sub_groups:
            dist.destroy_process_group(g)
else:
    dist.destroy_process_group()

P1 Process group leak on exception in pool mode

When run_dpa_with_cp raises any exception after cp_comm_group is created (line 250) — e.g. a tensor-mismatch assertion, an OOM, or a NCCL error — the exception propagates out of the function before reaching this cleanup block. In pool mode dist.destroy_process_group() is never called for that case's cp_comm_group (or cp_comm_sub_groups). Every failed case therefore leaks one or more NCCL communicators. NCCL tracks live communicators internally and can exhaust resources over a long test session with repeated failures, causing subsequent cases to fail for unrelated reasons.

The standard fix is to wrap the body of run_dpa_with_cp in a try/finally so that the selected cleanup always runs. Note that cp_comm_group is created unconditionally before any test logic, so it is always defined by the time the finally block runs.

Comment on lines +66 to +69
def _send_response(rank: int, payload: dict) -> None:
    if rank == 0:
        sys.stdout.write(json.dumps(payload) + "\n")
        sys.stdout.flush()

P2 stdout pollution can silently corrupt the JSON protocol

torchrun (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a print call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's readline() in PoolWorker.submit would then receive a non-JSON line and raise a json.JSONDecodeError, killing the pool and failing the test with a misleading error message.

Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.

            self._kill()
            raise AssertionError("pool worker died before request could be sent")

        ready, _, _ = select.select([self.proc.stdout], [], [], timeout)

P2 select.select on a subprocess pipe works on Linux/macOS but raises OSError on Windows because Windows does not support select on non-socket file descriptors. This is test infrastructure for GPU hardware so Windows is not a target, but the call should at least be documented as Linux-only.

Suggested change
-        ready, _, _ = select.select([self.proc.stdout], [], [], timeout)
+        ready, _, _ = select.select([self.proc.stdout], [], [], timeout)  # Linux/macOS only

sudhakarsingh27 and others added 4 commits May 14, 2026 12:49
Three changes that bring the pool's failure semantics on par with the
per-batch torchrun approach in PR NVIDIA#2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to
   crash-path AssertionErrors. Equivalent in spirit to PR NVIDIA#2965's
   run_distributed() — CI JUnit XML now shows the actual cause (NCCL
   error, Python traceback, OOM) inline with the failing test, instead
   of just "pool worker died mid-request" / "timed out". A daemon
   drainer thread reads stderr line-by-line into a deque(maxlen=200)
   and also echoes to sys.stderr so pytest's per-test capture still
   gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the
   slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s).
   90 s gives ~6x headroom over the worst observed case while still
   detecting a genuine hang within ~1.5 min instead of ~10 min. Env
   var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints
   "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr
   on rank 0 only. Grep-friendly; lets future tuning recalibrate the
   timeout against the observed distribution. Off by default so normal
   runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True,
with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Three fixes responding to NVIDIA#2993
review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py)

run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups)
near the top, but the destroy_process_group() calls ran only on the
success path at the end of the function. Any exception in between
(tensor assertion, OOM, NCCL error) skipped the cleanup, leaking
communicators in pool mode. Long sessions with repeated failures
could exhaust NCCL internal tracking.

Wrap the test work in try/finally so the destroy logic always runs.
Initialise cp_comm_sub_groups = [] unconditionally so the finally
block is safe even when cp_comm_type != "a2a+p2p" (or when an assert
fires before the populate loop). Each destroy is itself try/except so
a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter

torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0
print, NCCL debug line, or torchrun status output interleaves with
the JSON response and breaks json.loads, killing the pool with a
misleading "json decode error".

Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py
and have PoolWorker.submit() scan stdout for sentinel-prefixed lines,
echoing non-protocol lines to stderr for visibility. Bounded scan
(MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only

Added a short comment noting Windows portability. CP attention tests
run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True
(was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line
overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
sudhakarsingh27 (Collaborator, Author) commented:

Thanks for the review @greptile-apps — all three findings addressed in e162a9ec:

P1 — NCCL communicator leak on exception (run_attention_with_cp.py)

Wrapped the body of run_dpa_with_cp in try/finally. cp_comm_sub_groups = [] is now initialised unconditionally before the a2a+p2p branch so the finally block is safe even when (a) cp_comm_type != "a2a+p2p" or (b) the inline assert at the top of the a2a+p2p block fires before the populate loop. Each dist.destroy_process_group() is itself wrapped in try/except Exception: pass so a destroy failure on one group can't leak the others.

P2 — stdout protocol can be corrupted by interleaved chatter (run_attention_with_cp_pool.py + test_attention_with_cp.py)

Worker now prefixes every response with [CP_POOL_RESP] (constant _RESP_PREFIX). PoolWorker.submit() reads lines in a loop and skips anything that doesn't start with the sentinel — echoing the non-protocol lines to stderr so they're still visible in CI logs. Bounded by _MAX_NOISE_LINES = 1000 so a chatty worker can't keep us spinning past the deadline. The timeout budget is preserved across iterations via a monotonic deadline.

P2 — select.select on a pipe fd is Linux/macOS only (test_attention_with_cp.py)

Added a short comment noting the Windows portability gap. Tests target Linux GPU hosts only; this is a doc issue, not a real bug, but flagged for future readers.

PR-body note

The "default 600 s" inconsistency you flagged on the env-var table was already reconciled in the previous push — POOL_SUBMIT_TIMEOUT_SEC defaults to 90 s in both code and description. (You reviewed at 385e9662, which was pre-polish.)

Validation

Re-ran test_essential=True on H100 after the three fixes: 38 passed / 0 failed in 270 s (was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line overhead, within noise).

The 5 state-cascade failures on test_essential=False documented under Known issues are unchanged by these fixes (different root cause — process-shared state outside the _reset_between_cases hook). Tracking separately.


greptile-apps Bot commented May 14, 2026

All three fixes look correct. Let me note one small remaining issue in the deadline logic before signing off.

In submit(), the deadline is initialized after the first select.select returns, not before the loop:

deadline = None
while scanned < self._MAX_NOISE_LINES:
    remaining = timeout if deadline is None else max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if deadline is None:
        deadline = time.monotonic() + timeout   # ← set after first select completes

If the first line (noise or response) arrives after, say, 50 s, the deadline is then set to now + 90 s, making the total possible wall time for a chatty worker ~140 s instead of 90 s. The comment "The timeout budget is preserved across iterations via a monotonic deadline" overstates what's implemented — the deadline only bounds iterations 2..N, not the first.

The fix is to set the deadline once, before the loop:

deadline = time.monotonic() + timeout
while scanned < self._MAX_NOISE_LINES:
    remaining = max(0.0, deadline - time.monotonic())
    ready, _, _ = select.select([self.proc.stdout], [], [], remaining)
    if not ready:
        ...

In practice this won't matter much — torchrun status chatter arrives nearly instantaneously before the response — but the current code doesn't actually enforce the stated timeout bound end-to-end.

Everything else looks good:

  • P1: cp_comm_sub_groups = [] is now unconditionally initialized before the a2a+p2p branch, the test body is wrapped in try/finally, and each destroy_process_group() is individually guarded with try/except Exception: pass. Communicators will be reclaimed even if the body raises mid-assertion.
  • P2 sentinel: _RESP_PREFIX = "[CP_POOL_RESP] " is written by the worker and the parent scans for it correctly, echoing non-sentinel lines to stderr. The _MAX_NOISE_LINES = 1000 cap prevents an infinite spin.
  • P2 doc: The select.select Windows portability note is in place.

The deadline fix above is the only remaining item; everything else is addressed cleanly.


Comment on lines 250 to 263
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
# Always defined so the finally cleanup below is safe even when cp_comm_type != "a2a+p2p".
cp_comm_sub_groups = []
if cp_comm_type == "a2a+p2p":
    assert world_size % 2 == 0, (
        "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
        " = 2."
    )
    cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
    cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
    cp_comm_sub_groups = []
    for sub_ranks in cp_comm_sub_ranks:
        sub_group = dist.new_group(sub_ranks, backend="nccl")
        if rank in sub_ranks:
            cp_comm_sub_groups.append(sub_group)

P1 Group creation sits outside the try/finally guard

cp_comm_group (line 250) and the sub-group creation loop (lines 260–263) all run before the try block that starts at line 268. If the sub-group loop raises — e.g., NCCL rejects a new_group call on a partially recovered communicator after a previous case failed — the finally block is never entered, so cp_comm_group and any sub-groups already appended to cp_comm_sub_groups are never destroyed. In a long-running pool this accumulates a communicator leak on every such failure, eventually hitting NCCL's internal communicator limit and breaking unrelated subsequent cases. Moving all three dist.new_group sites inside the try block (or at minimum wrapping the sub-group loop and the unconditional cp_comm_group creation in a nested try/finally) would close the gap.

cyanguwa (Collaborator) left a comment


I think the structure is a lot simpler/cleaner than #2965 - thanks!

For the 5 remaining failing tests, I wonder if there's something else we need to reset for FP8 (please check with @ksivaman).

Also, please compare the number of tests before/after this PR and make sure we're still running the same number of tests! If the new, reduced runtime allows now, we can turn on test_essential=False, but I'll leave that to you.

If no major issues, I approve! Thanks!

if _pool_managed_pg:
    # Pool owns the main PG; only clean up groups created for this case.
    try:
        dist.destroy_process_group(cp_comm_group)

Do we want to destroy (and create earlier in the file) cp_comm_group for every config? I feel if the pool is the same, the world_size would be the same, and so is cp_comm_group?

In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is
written inside `with torch.cuda.stream(flash_attn_streams[i])`. For
i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default
stream. Later, at loop iteration i=2, the code reads
max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])`
which runs on the default stream. Without an explicit wait_stream,
this is a read-after-write race across streams. The post-loop
`current_stream().wait_stream(cp_stream)` is too late — the race has
already fired.

The race is latent: outcome depends on stream scheduling. In a
fresh-process subprocess (one-torchrun-per-test path), streams are
cleanly initialised and timing happens to put the write before the
read. In a long-running persistent-worker process — exposed by
PR NVIDIA#2993's pool design — prior workloads shape stream state
differently, the read can fire before the write completes, and
max_logit ends up with stale values in some heads (~0.3 abs diff,
3/12 elements wrong on the H100 matrix).

Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])`
before the torch.maximum read. No-op when the streams are identical
(i=1 case, where flash_attn_streams[0] is current_stream), only
fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in
27m 10s (was 323 passed + 5 failed at this commit's parent, all 5
failing on cp_comm_type=all_gather with mismatched max_logit).
The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now
pass under the pool — confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Line-level cleanups from the second reviewer pass on PR NVIDIA#2993. Each item
is dead/redundant; none changes behaviour. Full-matrix test_essential=False
on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:
- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already
  re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top
  and pops the FP8 ones itself. The pop loop was defensive against a
  hypothetical "future caller that doesn't re-set them" that doesn't
  exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create
  no Python reference cycles between iterations and empty_cache only
  frees CUDA blocks PyTorch already considers free; the combination
  was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is
  itself a collective synchronization point — if every rank reaches
  it, none is ahead. The "surface a wedged communicator here" comment
  was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):
- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the
  unreachable post-loop "1000+ lines" branch. select()'s deadline
  already bounds the loop; the line-count cap was redundant and
  the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and
  self-terminates when the pipe closes; we never read the field
  anywhere, so initialising and nulling it was bookkeeping for no
  reason.
- Drop the dead assert in submit() — _ensure_alive() on the prior
  line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:
- L8 (drop try/except around dist.destroy_process_group). Real
  semantic change: hides errors that occur when a previous test
  wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var),
  M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit
  CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
    torch.cuda.manual_seed(seed)

- test_essential = True
+ test_essential = False

P1 test_essential left as False — will break CI

The default was True (38 essential configs, all passing) and has been changed to False (full matrix, 328 runnable, 5 known failures). The PR description explicitly calls these out as "state-cascade failures" under Known Issues. Merging with test_essential = False will cause CI to report 5 failures by default, regressing from the current baseline.

sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 15, 2026
sudhakarsingh27 and others added 4 commits May 15, 2026 14:42
world_size and the rank set don't change for the lifetime of one pool, so
recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms
of NCCL setup each. Pre-create them once in the pool worker (new helper
_create_cp_comm_groups), stash on the run_attention_with_cp module via
module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and
reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them
once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in
run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population
otherwise leaks every communicator created before the failure. The finally
now only destroys groups we created locally (cp_comm_group / sub_groups
populated in the else-branch), leaving pool-owned groups alone for reuse.

cyanguwa's review feedback on PR NVIDIA#2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the
~540-line body of run_dpa_with_cp in try/finally. The wrap itself
was tiny but it re-indented every line of the body by one level,
inflating the PR diff of run_attention_with_cp.py to ~1000 lines
against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:
  - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by
    the pool worker (which destroys them once at pool shutdown).
    run_dpa_with_cp neither creates nor destroys them, so an
    in-body exception can't leak communicators.
  - In single-shot mode, groups are still created locally, but the
    subprocess exits at function return; NCCL releases everything
    at process teardown, so a stray exception leaks communicators
    only for the milliseconds before the process dies — a bounded
    one-off cost, not the unbounded accumulation that Round-1
    flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against
origin/main from ~1000 lines to ~120 lines without changing
observable behaviour. Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
