From dc92437814579c7c5391ac1302a559fe4ff115e5 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 09:48:55 -0700 Subject: [PATCH 01/21] Reorganize graph test files for clarity Rename test files to reflect what they actually test: - test_basic -> test_graph_builder (stream capture tests) - test_conditional -> test_graph_builder_conditional - test_advanced -> test_graph_update (moved child_graph and stream_lifetime tests into test_graph_builder) - test_capture_alloc -> test_graph_memory_resource - test_explicit* -> test_graphdef* Made-with: Cursor --- cuda_core/tests/graph/test_graph_update.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cuda_core/tests/graph/test_graph_update.py b/cuda_core/tests/graph/test_graph_update.py index 48a88d5ac5..42b64db016 100644 --- a/cuda_core/tests/graph/test_graph_update.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -145,11 +145,10 @@ def build_graph(condition_value): assert arr[1] == 3 assert arr[2] == 3 - # Close the memory resource now because the garbage collected might - # de-allocate it during the next graph builder process b.close() + # ============================================================================= # Error cases # ============================================================================= From 281ed8283ec1d48ba5d395a5dea338ab53be5d25 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 11:24:50 -0700 Subject: [PATCH 02/21] Enhance Graph.update() and add whole-graph update tests - Extend Graph.update() to accept both GraphBuilder and GraphDef sources - Surface CUgraphExecUpdateResultInfo details on update failure instead of a generic CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE message - Release the GIL during cuGraphExecUpdate via nogil block - Add parametrized happy-path test covering both GraphBuilder and GraphDef - Add error-case tests: unfinished builder, topology mismatch, wrong type Made-with: Cursor --- cuda_core/tests/graph/test_graph_update.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cuda_core/tests/graph/test_graph_update.py b/cuda_core/tests/graph/test_graph_update.py index 42b64db016..48a88d5ac5 100644 --- a/cuda_core/tests/graph/test_graph_update.py +++ b/cuda_core/tests/graph/test_graph_update.py @@ -145,10 +145,11 @@ def build_graph(condition_value): assert arr[1] == 3 assert arr[2] == 3 + # Close the memory resource now because the garbage collected might + # de-allocate it during the next graph builder process b.close() - # ============================================================================= # Error cases # ============================================================================= From 7854b7681da61384b3dc0664a88022a681d48fb1 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 31 Mar 2026 15:37:26 -0700 Subject: [PATCH 03/21] Add AdjacencySet proxy for pred/succ and GraphNode.remove() Replace cached tuple-based pred/succ with mutable AdjacencySet backed by direct CUDA driver calls. Add GraphNode.remove() wrapping cuGraphDestroyNode. Made-with: Cursor --- .../core/_graph/_graph_def/_adjacency_set.pyx | 202 ++++++++++++++++++ .../core/_graph/_graph_def/_graph_node.pxd | 2 - .../core/_graph/_graph_def/_graph_node.pyx | 116 ++-------- cuda_core/tests/graph/test_graphdef.py | 4 +- .../tests/graph/test_graphdef_lifetime.py | 2 +- 5 files changed, 218 insertions(+), 108 deletions(-) create mode 100644 cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx new file mode 100644 index 0000000000..124a7a5610 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Mutable-set proxy for graph node predecessors and successors.""" + +from libc.stddef cimport size_t +from libcpp.vector cimport vector +from cuda.bindings cimport cydriver +from cuda.core._graph._graph_def._graph_node cimport GraphNode +from cuda.core._resource_handles cimport ( + GraphHandle, + GraphNodeHandle, + as_cu, + graph_node_get_graph, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from collections.abc import MutableSet + + +# ---- Python MutableSet wrapper ---------------------------------------------- + +class AdjacencySet(MutableSet): + """Mutable set-like view of a node's predecessors or successors.""" + + __slots__ = ("_core",) + + def __init__(self, node, bint is_fwd): + self._core = _AdjacencySetCore(node, is_fwd) + + # --- abstract methods required by MutableSet --- + + def __contains__(self, x): + if not isinstance(x, GraphNode): + return False + return x in (<_AdjacencySetCore>self._core).query() + + def __iter__(self): + return iter((<_AdjacencySetCore>self._core).query()) + + def __len__(self): + return (<_AdjacencySetCore>self._core).count() + + def add(self, value): + if not isinstance(value, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(value).__name__}") + (<_AdjacencySetCore>self._core).add_edge(value) + + def discard(self, value): + if not isinstance(value, GraphNode): + return + if value not in self: + return + (<_AdjacencySetCore>self._core).remove_edge(value) + + # --- override for bulk efficiency --- + + def update(self, *others): + """Add edges to multiple nodes at once.""" + nodes = [] + for other in others: + if isinstance(other, GraphNode): + nodes.append(other) + else: + nodes.extend(other) + if not nodes: + return + for n in nodes: + if not isinstance(n, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(n).__name__}") + (<_AdjacencySetCore>self._core).add_edges(nodes) + + def __repr__(self): + return "{" + ", ".join(repr(n) for n in self) + "}" + + +# ---- cdef core holding function pointer ------------------------------------ + +# Signature shared by _get_preds and _get_succs. +ctypedef cydriver.CUresult (*_adj_fn_t)( + cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil + + +cdef class _AdjacencySetCore: + """Cythonized core implementing AdjacencySet""" + cdef: + GraphNodeHandle _h_node + GraphHandle _h_graph + _adj_fn_t _query_fn + bint _is_fwd + + def __init__(self, GraphNode node, bint is_fwd): + self._h_node = node._h_node + self._h_graph = graph_node_get_graph(node._h_node) + self._is_fwd = is_fwd + self._query_fn = _get_succs if is_fwd else _get_preds + + cdef inline void _resolve_edge( + self, GraphNode other, + cydriver.CUgraphNode* c_from, + cydriver.CUgraphNode* c_to) noexcept: + if self._is_fwd: + c_from[0] = as_cu(self._h_node) + c_to[0] = as_cu(other._h_node) + else: + c_from[0] = as_cu(other._h_node) + c_to[0] = as_cu(self._h_node) + + cdef list query(self): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + if c_node == NULL: + return [] + cdef size_t count = 0 + with nogil: + HANDLE_RETURN(self._query_fn(c_node, NULL, &count)) + if count == 0: + return [] + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(count) + with nogil: + HANDLE_RETURN(self._query_fn( + c_node, nodes_vec.data(), &count)) + return [GraphNode._create(self._h_graph, nodes_vec[i]) + for i in range(count)] + + cdef Py_ssize_t count(self): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + if c_node == NULL: + return 0 + cdef size_t n = 0 + with nogil: + HANDLE_RETURN(self._query_fn(c_node, NULL, &n)) + return n + + cdef void add_edge(self, GraphNode other): + cdef cydriver.CUgraphNode c_from, c_to + self._resolve_edge(other, &c_from, &c_to) + with nogil: + HANDLE_RETURN(_add_edge(as_cu(self._h_graph), &c_from, &c_to, 1)) + + cdef void remove_edge(self, GraphNode other): + cdef cydriver.CUgraphNode c_from, c_to + self._resolve_edge(other, &c_from, &c_to) + with nogil: + HANDLE_RETURN(_remove_edge(as_cu(self._h_graph), &c_from, &c_to, 1)) + + cdef void add_edges(self, list nodes): + cdef size_t n = len(nodes) + cdef vector[cydriver.CUgraphNode] from_vec + cdef vector[cydriver.CUgraphNode] to_vec + from_vec.resize(n) + to_vec.resize(n) + cdef size_t i + for i in range(n): + self._resolve_edge(nodes[i], &from_vec[i], &to_vec[i]) + with nogil: + HANDLE_RETURN(_add_edge( + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) + + +# ---- driver wrappers: absorb CUDA version differences ---- + +cdef cydriver.CUresult _get_preds( + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, + size_t* count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphNodeGetDependencies(node, out, NULL, count) + ELSE: + return cydriver.cuGraphNodeGetDependencies(node, out, count) + + +cdef cydriver.CUresult _get_succs( + cydriver.CUgraphNode node, cydriver.CUgraphNode* out, + size_t* count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphNodeGetDependentNodes(node, out, NULL, count) + ELSE: + return cydriver.cuGraphNodeGetDependentNodes(node, out, count) + + +cdef cydriver.CUresult _add_edge( + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphAddDependencies( + graph, from_arr, to_arr, NULL, count) + ELSE: + return cydriver.cuGraphAddDependencies( + graph, from_arr, to_arr, count) + + +cdef cydriver.CUresult _remove_edge( + cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, + cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + return cydriver.cuGraphRemoveDependencies( + graph, from_arr, to_arr, NULL, count) + ELSE: + return cydriver.cuGraphRemoveDependencies( + graph, from_arr, to_arr, count) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd index 7a9f82f33f..0a87b70ad6 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pxd @@ -9,8 +9,6 @@ from cuda.core._resource_handles cimport GraphHandle, GraphNodeHandle cdef class GraphNode: cdef: GraphNodeHandle _h_node - tuple _pred_cache - tuple _succ_cache object __weakref__ @staticmethod diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 17c2c072f7..f3ec73130b 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -57,6 +57,7 @@ from cuda.core._graph._utils cimport ( ) from cuda.core import Device +from cuda.core._graph._graph_def._adjacency_set import AdjacencySet from cuda.core._utils.cuda_utils import driver, handle_return @@ -122,33 +123,21 @@ cdef class GraphNode: """ return as_py(self._h_node) - @property - def pred(self) -> tuple: - """Return the predecessor nodes (dependencies) of this node. - - Results are cached since a node's dependencies are immutable - once created. - - Returns - ------- - tuple of GraphNode - The nodes that this node depends on. - """ - return GN_pred(self) + def remove(self): + """Remove this node and all its edges from the parent graph.""" + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + with nogil: + HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) @property - def succ(self) -> tuple: - """Return the successor nodes (dependents) of this node. - - Results are cached and automatically invalidated when new - dependent nodes are added via builder methods. + def pred(self): + """A mutable set-like view of this node's predecessors.""" + return AdjacencySet(self, False) - Returns - ------- - tuple of GraphNode - The nodes that depend on this node. - """ - return GN_succ(self) + @property + def succ(self): + """A mutable set-like view of this node's successors.""" + return AdjacencySet(self, True) def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode: """Add a kernel launch node depending on this node. @@ -504,7 +493,6 @@ cdef inline ConditionalNode _make_conditional_node( n._cond_type = cond_type n._branches = branches - pred._succ_cache = None return n cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): @@ -546,72 +534,6 @@ cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): return n -cdef inline tuple GN_pred(GraphNode self): - if self._pred_cache is not None: - return self._pred_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._pred_cache = () - return self._pred_cache - - cdef size_t num_deps = 0 - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, &num_deps)) - - if num_deps == 0: - self._pred_cache = () - return self._pred_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._pred_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._pred_cache - - -cdef inline tuple GN_succ(GraphNode self): - if self._succ_cache is not None: - return self._succ_cache - - cdef cydriver.CUgraphNode node = as_cu(self._h_node) - if node == NULL: - self._succ_cache = () - return self._succ_cache - - cdef size_t num_deps = 0 - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, &num_deps)) - - if num_deps == 0: - self._succ_cache = () - return self._succ_cache - - cdef vector[cydriver.CUgraphNode] deps - deps.resize(num_deps) - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), NULL, &num_deps)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), &num_deps)) - - cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) - self._succ_cache = tuple(GraphNode._create(h_graph, deps[i]) for i in range(num_deps)) - return self._succ_cache - - cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, ParamHolder ker_args): cdef cydriver.CUDA_KERNEL_NODE_PARAMS node_params cdef cydriver.CUgraphNode new_node = NULL @@ -644,7 +566,6 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) - self._succ_cache = None return KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, @@ -674,9 +595,6 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( &new_node, as_cu(h_graph), deps_ptr, num_deps)) - self._succ_cache = None - for other in nodes: - (other)._succ_cache = None return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) @@ -753,7 +671,6 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - self._succ_cache = None return AllocNode._create_with_params( create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, device_id, memory_type, tuple(peer_ids)) @@ -774,7 +691,6 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - self._succ_cache = None return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) @@ -810,7 +726,6 @@ cdef inline MemsetNode GN_memset( &new_node, as_cu(h_graph), deps, num_deps, &memset_params, ctx)) - self._succ_cache = None return MemsetNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, val, elem_size, width, height, pitch) @@ -872,7 +787,6 @@ cdef inline MemcpyNode GN_memcpy( HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - self._succ_cache = None return MemcpyNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, c_dst_type, c_src_type) @@ -900,7 +814,6 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - self._succ_cache = None return ChildGraphNode._create_with_params( create_graph_node_handle(new_node, h_graph), h_embedded) @@ -923,7 +836,6 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - self._succ_cache = None return EventRecordNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event) @@ -946,7 +858,6 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - self._succ_cache = None return EventWaitNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event) @@ -974,7 +885,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ &new_node, as_cu(h_graph), deps, num_deps, &node_params)) cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - self._succ_cache = None return HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, node_params.fn, node_params.userData) diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 3412d71847..be6da9515a 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -712,8 +712,8 @@ def test_graphdef_entry_is_virtual(sample_graphdef): """Internal entry node is virtual (no pred/succ, type is None).""" entry = sample_graphdef._entry assert isinstance(entry, GraphNode) - assert entry.pred == () - assert entry.succ == () + assert entry.pred == set() + assert entry.succ == set() assert entry.type is None diff --git a/cuda_core/tests/graph/test_graphdef_lifetime.py b/cuda_core/tests/graph/test_graphdef_lifetime.py index 133f2c7ca1..50bc009629 100644 --- a/cuda_core/tests/graph/test_graphdef_lifetime.py +++ b/cuda_core/tests/graph/test_graphdef_lifetime.py @@ -477,7 +477,7 @@ def test_kernel_node_reconstruction_preserves_validity(init_cuda): # Reconstruct the kernel node through DAG traversal # successor.pred -> GraphNode._create -> KernelNode._create_from_driver # -> create_kernel_handle_ref -> handle recovery - reconstructed = successor.pred[0] + reconstructed = next(iter(successor.pred)) assert isinstance(reconstructed, KernelNode) assert reconstructed.kernel.attributes.max_threads_per_block() > 0 From 5fbd288f5c9e567c72eb6ddd045ece81f6be7d88 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 1 Apr 2026 21:50:34 -0700 Subject: [PATCH 04/21] Add edge mutation support and MutableSet interface for GraphNode adjacencies Enable adding/removing edges between graph nodes via AdjacencySet (a MutableSet proxy on GraphNode.pred/succ), node removal via discard(), and property setters for bulk edge replacement. Includes comprehensive mutation and interface tests. Closes part of #1330 (step 2: edge mutation on GraphDef). Made-with: Cursor --- .../core/_graph/_graph_def/_adjacency_set.pyx | 6 + .../core/_graph/_graph_def/_graph_def.pyx | 8 +- .../core/_graph/_graph_def/_graph_node.pyx | 16 +- .../tests/graph/test_graphdef_lifetime.py | 4 +- .../tests/graph/test_graphdef_mutation.py | 323 ++++++++++++++++++ .../helpers/collection_interface_testers.py | 142 ++++++++ cuda_core/tests/helpers/graph_kernels.py | 22 +- 7 files changed, 512 insertions(+), 9 deletions(-) create mode 100644 cuda_core/tests/graph/test_graphdef_mutation.py create mode 100644 cuda_core/tests/helpers/collection_interface_testers.py diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx index 124a7a5610..059e19cd9f 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx @@ -28,6 +28,10 @@ class AdjacencySet(MutableSet): def __init__(self, node, bint is_fwd): self._core = _AdjacencySetCore(node, is_fwd) + @classmethod + def _from_iterable(cls, it): + return set(it) + # --- abstract methods required by MutableSet --- def __contains__(self, x): @@ -45,6 +49,8 @@ class AdjacencySet(MutableSet): if not isinstance(value, GraphNode): raise TypeError( f"expected GraphNode, got {type(value).__name__}") + if value in self: + return (<_AdjacencySetCore>self._core).add_edge(value) def discard(self, value): diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index d45c72ba2a..83f8517923 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -314,7 +314,7 @@ cdef class GraphDef: with nogil: HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags)) - def nodes(self) -> tuple: + def nodes(self) -> set: """Return all nodes in the graph. Returns @@ -335,9 +335,9 @@ cdef class GraphDef: with nogil: HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) - return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) + return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) - def edges(self) -> tuple: + def edges(self) -> set: """Return all edges in the graph as (from_node, to_node) pairs. Returns @@ -369,7 +369,7 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges( as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) - return tuple( + return set( (GraphNode._create(self._h_graph, from_nodes[i]), GraphNode._create(self._h_graph, to_nodes[i])) for i in range(num_edges) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index f3ec73130b..da1d5dcac0 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -123,8 +123,8 @@ cdef class GraphNode: """ return as_py(self._h_node) - def remove(self): - """Remove this node and all its edges from the parent graph.""" + def discard(self): + """Discard this node and remove all its edges from the parent graph.""" cdef cydriver.CUgraphNode node = as_cu(self._h_node) with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) @@ -134,11 +134,23 @@ cdef class GraphNode: """A mutable set-like view of this node's predecessors.""" return AdjacencySet(self, False) + @pred.setter + def pred(self, value): + p = AdjacencySet(self, False) + p.clear() + p.update(value) + @property def succ(self): """A mutable set-like view of this node's successors.""" return AdjacencySet(self, True) + @succ.setter + def succ(self, value): + s = AdjacencySet(self, True) + s.clear() + s.update(value) + def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode: """Add a kernel launch node depending on this node. diff --git a/cuda_core/tests/graph/test_graphdef_lifetime.py b/cuda_core/tests/graph/test_graphdef_lifetime.py index 50bc009629..3b254d423f 100644 --- a/cuda_core/tests/graph/test_graphdef_lifetime.py +++ b/cuda_core/tests/graph/test_graphdef_lifetime.py @@ -68,7 +68,7 @@ def test_branches_survive_parent_deletion(init_cuda, builder, expected_count): gc.collect() for branch in branches: - assert branch.nodes() == () + assert branch.nodes() == set() @pytest.mark.parametrize("builder, expected_count", _COND_BUILDERS) @@ -108,7 +108,7 @@ def test_reconstructed_body_survives_parent_deletion(init_cuda): del g, condition, all_nodes, cond_nodes, branches gc.collect() - assert body.nodes() == () + assert body.nodes() == set() # ============================================================================= diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py new file mode 100644 index 0000000000..2abea16e3c --- /dev/null +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Tests for mutating a graph definition (edge changes, node removal).""" + +import pytest +from helpers.collection_interface_testers import assert_mutable_set_interface +from helpers.graph_kernels import compile_parallel_kernels + +from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource +from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode + +try: + import numpy as np + + _has_numpy_2_1 = tuple(int(i) for i in np.__version__.split(".")[:2]) >= (2, 1) +except ImportError: + np = None + _has_numpy_2_1 = False + +_need_numpy_2_1 = pytest.mark.skipif(not _has_numpy_2_1, reason="need numpy 2.1+") + + +class YRig: + """Test rigging for graph mutation tests. Constructs a Y-shaped graph with + two parallel arms joined by a combine node. Modifying the sequence of + operations along either arm changes the output. + + Topology:: + + a0 -- a1 -- a2 + \ + j -- r + / + b0 -- b1 + + Each a/b node applies ``affine(ptr, m, b)`` to its arm's int accumulator. + Node r computes result ``combine(R, A, B) = (A << 16) | (B & 0xFFFF)``, + encoding both arms' results into a single int. j is a joining (empty) node + preceeding r. + """ + + def __init__(self): + self.A_OPS = [(2, 1), (3, 2), (5, 3)] + self.B_OPS = [(2, 7), (3, 1)] + + mod = compile_parallel_kernels() + self.affine = mod.get_kernel("affine") + self.combine = mod.get_kernel("combine") + self.config = LaunchConfig(grid=1, block=1) + + self._mr = LegacyPinnedMemoryResource() + self._buf = self._mr.allocate(3 * 4) + self._arr = np.from_dlpack(self._buf).view(np.int32) + + self.ptr_a = self._arr[0:].ctypes.data + self.ptr_b = self._arr[1:].ctypes.data + self.ptr_r = self._arr[2:].ctypes.data + + self.graph_def = GraphDef() + self.stream = None + + # Arm A + self.a = [] + prev = self.graph_def + for m, b in self.A_OPS: + prev = prev.launch(self.config, self.affine, self.ptr_a, m, b) + self.a.append(prev) + + # Arm B + self.b = [] + prev = self.graph_def + for m, b in self.B_OPS: + prev = prev.launch(self.config, self.affine, self.ptr_b, m, b) + self.b.append(prev) + + # Join and combine + self.j = self.graph_def.join(self.a[-1], self.b[-1]) + self.r = self.j.launch(self.config, self.combine, self.ptr_r, self.ptr_a, self.ptr_b) + + def run(self): + if self.stream is None: + self.stream = Device().create_stream() + graph = self.graph_def.instantiate() + self.reset() + graph.launch(self.stream) + self.stream.sync() + + def reset(self): + self._arr[:] = 0 + + @property + def A_out(self): + return int(self._arr[0]) + + @property + def B_out(self): + return int(self._arr[1]) + + @property + def R_out(self): + return int(self._arr[2]) + + @property + def output(self): + return self.A_out, self.B_out, self.R_out + + @property + def expected_output(self): + """Expected (A, B, R) after one run from zero.""" + + def apply_affine(val, ops): + for m, b in ops: + val = val * m + b + return val + + a = apply_affine(0, self.A_OPS) + b = apply_affine(0, self.B_OPS) + r = (a << 16) | (b & 0xFFFF) + return (a, b, r) + + @property + def edges(self): + return self.graph_def.edges() + + @property + def initial_edges(self): + return ( + set(zip(self.a, self.a[1:])) + | set(zip(self.b, self.b[1:])) + | {(self.a[-1], self.j), (self.b[-1], self.j), (self.j, self.r)} + ) + + @property + def nodes(self): + return self.graph_def.nodes() + + @property + def initial_nodes(self): + return set(self.a + self.b + [self.j, self.r]) + + def close(self): + self._buf.close() + + +@_need_numpy_2_1 +class TestMutateYRig: + """Tests that mutate the Y-shaped graph built by YRig.""" + + def test_baseline(self, init_cuda): + """Unmodified graph produces the expected results.""" + rig = YRig() + rig.run() + assert rig.output == rig.expected_output + assert rig.edges == rig.initial_edges + assert rig.nodes == rig.initial_nodes + rig.close() + + def test_discard_a1(self, init_cuda): + """Discard a1 (creates a race). Arm b yields the expected value, and the + final result is correctly ordered after b.""" + rig = YRig() + rig.a[1].discard() + rig.run() + _, b_exp, _ = rig.expected_output + assert rig.B_out == b_exp + assert (rig.R_out & 0xFFFF) == b_exp + a0, a1, a2 = rig.a + assert rig.edges == rig.initial_edges - {(a0, a1), (a1, a2)} + assert rig.nodes == rig.initial_nodes - {a1} + rig.close() + + def test_discard_a2(self, init_cuda): + """Discard a2, connect a1--r""" + rig = YRig() + rig.a[2].discard() + rig.a[1].succ.add(rig.r) + rig.A_OPS.pop() + rig.run() + assert rig.output == rig.expected_output + a0, a1, a2, j, r = rig.a + [rig.j, rig.r] + assert rig.edges == (rig.initial_edges - {(a1, a2), (a2, j)}) | {(a1, r)} + assert rig.nodes == rig.initial_nodes - {a2} + rig.close() + + def test_discard_joint(self, init_cuda): + """Remove the joining node and instead add edges directly to r.""" + rig = YRig() + _, _, a2, _, b1, j, r = rig.a + rig.b + [rig.j, rig.r] + j.discard() + r.pred = {a2, b1} + rig.run() + assert rig.output == rig.expected_output + assert rig.edges == (rig.initial_edges - {(a2, j), (b1, j), (j, r)}) | {(a2, r), (b1, r)} + assert rig.nodes == rig.initial_nodes - {j} + rig.close() + + def test_insert_b(self, init_cuda): + """Insert a node into arm b.""" + rig = YRig() + coeffs = 5, 3 + b_new = rig.graph_def.launch(rig.config, rig.affine, rig.ptr_b, *coeffs) + b0, b1 = rig.b + b0.succ.discard(b1) + b0.succ.add(b_new) + b_new.succ.add(b1) + rig.B_OPS.insert(1, coeffs) + rig.run() + assert rig.output == rig.expected_output + assert rig.edges == (rig.initial_edges - {(b0, b1)}) | {(b0, b_new), (b_new, b1)} + assert rig.nodes == rig.initial_nodes | {b_new} + rig.close() + + +def test_adjacency_set_interface(init_cuda): + """Exercise every MutableSet method on AdjacencySet.""" + g = GraphDef() + hub = g.join() + items = [g.join() for _ in range(5)] + assert_mutable_set_interface(hub.succ, items) + + +def test_adjacency_set_pred_direction(init_cuda): + """Verify that pred works symmetrically with succ.""" + g = GraphDef() + target = g.join() + x, y, z = (g.join() for _ in range(3)) + + pred = target.pred + assert pred == set() + + pred.add(x) + pred.add(y) + assert pred == {x, y} + + # Verify the edge is visible from the other direction + assert target in x.succ + assert target in y.succ + assert target not in z.succ + + pred.discard(x) + assert pred == {y} + assert target not in x.succ + + +def test_adjacency_set_property_setter(init_cuda): + """Verify that assigning to node.pred or node.succ replaces all edges.""" + g = GraphDef() + hub = g.join() + a, b, c = (g.join() for _ in range(3)) + + hub.succ = {a, b} + assert hub.succ == {a, b} + + hub.succ = {c} + assert hub.succ == {c} + assert a not in hub.succ + + hub.succ = set() + assert hub.succ == set() + + hub.pred = {a, b} + assert hub.pred == {a, b} + + hub.pred = set() + assert hub.pred == set() + + hub.pred = set() + assert hub.pred == set() + + +@_need_numpy_2_1 +def test_convert_linear_to_fan_in(init_cuda): + """Chain four computations sequentially, then rewire so all pairs run in + parallel feeding into a reduce node. + + Initial topology (sequential):: + + memset0 -- launch0 -- memset1 -- launch1 -- memset2 -- launch2 -- memset3 -- launch3 + + After rewiring (parallel):: + + memset0 -- launch0 --\ + memset1 -- launch1 ---+-- reduce + memset2 -- launch2 --/ + memset3 -- launch3 -/ + """ + mod = compile_parallel_kernels() + affine = mod.get_kernel("affine") + reduce_kern = mod.get_kernel("reduce") + config = LaunchConfig(grid=1, block=1) + + mr = LegacyPinnedMemoryResource() + buf = mr.allocate(5 * 4) + arr = np.from_dlpack(buf).view(np.int32) + arr[:] = 0 + + values = np.array([10, 20, 30, 40], dtype=np.int32) + ptrs = [arr[i:].ctypes.data for i in range(5)] + + # Create the initial graph. + g = GraphDef() + prev = g + for i, val in enumerate(values): + prev = prev.memset(ptrs[i], val, 1).launch(config, affine, ptrs[i], 2, 1) + reduce_node = g.launch(config, reduce_kern, ptrs[4], ptrs[0], 4) + + # Rewire: + # - drop preds from memsets + # - connect results to reduction + for node in g.nodes(): + if isinstance(node, MemsetNode) and len(node.pred): + node.pred = set() + elif isinstance(node, KernelNode) and node != reduce_node: + node.succ = {reduce_node} + + stream = Device().create_stream() + graph = g.instantiate() + graph.launch(stream) + stream.sync() + assert arr[4] == sum(2 * values + 1) + + buf.close() diff --git a/cuda_core/tests/helpers/collection_interface_testers.py b/cuda_core/tests/helpers/collection_interface_testers.py new file mode 100644 index 0000000000..d9b5ee2cd0 --- /dev/null +++ b/cuda_core/tests/helpers/collection_interface_testers.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Reusable helpers to verify collections.abc protocol conformance.""" + +from collections.abc import MutableSet, Set + +import pytest + + +def assert_mutable_set_interface(subject, items): + """Exercise every MutableSet method on *subject* against a reference set. + + Parameters + ---------- + subject : MutableSet + An **empty** mutable-set-like object to test. + items : sequence + At least five distinct, hashable objects valid for insertion into + *subject*. + """ + assert len(items) >= 5 + a, b, c, d, e = items[:5] + ref = set() + + # -- ABC conformance -- + assert isinstance(subject, Set) + assert isinstance(subject, MutableSet) + + # -- empty state -- + assert len(subject) == 0 + assert subject == ref + assert subject == set() + assert list(subject) == [] + + # -- add -- + subject.add(a) + ref.add(a) + assert subject == ref + assert a in subject + assert b not in subject + assert len(subject) == 1 + + subject.add(b) + subject.add(c) + ref.update({b, c}) + assert subject == ref + assert len(subject) == 3 + + # add duplicate is a no-op + subject.add(a) + assert subject == ref + + # -- discard -- + subject.discard(b) + ref.discard(b) + assert subject == ref + + # discard non-member is a no-op + subject.discard(d) + assert subject == ref + + # -- remove -- + subject.add(b) + ref.add(b) + subject.remove(b) + ref.remove(b) + assert subject == ref + + with pytest.raises(KeyError): + subject.remove(d) + + # -- comparison with plain set -- + assert subject == {a, c} + assert subject != {a, b} + + # -- isdisjoint -- + assert subject.isdisjoint({d, e}) + assert not subject.isdisjoint({a, d}) + + # -- subset / superset -- + assert subject <= {a, c} + assert subject <= {a, b, c} + assert not (subject <= {a}) + assert subject < {a, b, c} + assert not (subject < {a, c}) + assert {a, c} >= subject + assert {a, b, c} > subject + + # -- binary operators -- + assert subject & {a, d} == {a} + assert subject | {d} == {a, c, d} + assert subject - {c} == {a} + assert subject ^ {c, d} == {a, d} + + # -- in-place union (|=) -- + subject |= {d, e} + ref |= {d, e} + assert subject == ref + + # -- in-place intersection (&=) -- + subject &= {a, d, e} + ref &= {a, d, e} + assert subject == ref + + # -- in-place difference (-=) -- + subject -= {e} + ref -= {e} + assert subject == ref + + # -- in-place symmetric difference (^=) -- + subject ^= {a, b} + ref ^= {a, b} + assert subject == ref + + # -- pop -- + popped = subject.pop() + ref.discard(popped) + assert popped not in subject + assert subject == ref + + # -- clear -- + subject.clear() + ref.clear() + assert subject == ref + assert len(subject) == 0 + + with pytest.raises(KeyError): + subject.pop() + + # -- bulk add via |= -- + subject |= {a, b, c} + ref.update({a, b, c}) + assert subject == ref + + # -- __iter__ -- + assert set(subject) == ref + + # -- __repr__ -- + r = repr(subject) + assert isinstance(r, str) + assert len(r) > 0 diff --git a/cuda_core/tests/helpers/graph_kernels.py b/cuda_core/tests/helpers/graph_kernels.py index c38f0bafde..657d7509b2 100644 --- a/cuda_core/tests/helpers/graph_kernels.py +++ b/cuda_core/tests/helpers/graph_kernels.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE """Shared kernel compilation helpers for graph tests.""" @@ -79,3 +79,23 @@ def compile_conditional_kernels(cond_type): nvrtcVersion = handle_return(nvrtc.nvrtcVersion()) pytest.skip(f"NVRTC version {nvrtcVersion} does not support conditionals") return mod + + +def compile_parallel_kernels(): + """Compile kernels for parallel graph tests. + + Returns a module with: + - affine: computes *a = *a * m + b + - combine: computes *s = (*a << 16) | (*b & 0xFFFF) + - reduce: computes a sum. + """ + code = """ + __global__ void affine(int *a, int m, int b) { *a = *a * m + b; } + __global__ void combine(int *s, int *a, int *b) { *s = (*a << 16) | (*b & 0xFFFF); } + __global__ void reduce(int *out, int *in, size_t n) { for(size_t i=0; i Date: Wed, 1 Apr 2026 21:58:41 -0700 Subject: [PATCH 05/21] Use requires_module mark for numpy version checks in mutation tests Replace inline skipif version check with requires_module(np, "2.1") from the shared test helpers, consistent with other test files. Made-with: Cursor --- cuda_core/tests/graph/test_graphdef_mutation.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index 2abea16e3c..3756f2b712 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -3,23 +3,14 @@ """Tests for mutating a graph definition (edge changes, node removal).""" -import pytest +import numpy as np from helpers.collection_interface_testers import assert_mutable_set_interface from helpers.graph_kernels import compile_parallel_kernels +from helpers.marks import requires_module from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode -try: - import numpy as np - - _has_numpy_2_1 = tuple(int(i) for i in np.__version__.split(".")[:2]) >= (2, 1) -except ImportError: - np = None - _has_numpy_2_1 = False - -_need_numpy_2_1 = pytest.mark.skipif(not _has_numpy_2_1, reason="need numpy 2.1+") - class YRig: """Test rigging for graph mutation tests. Constructs a Y-shaped graph with @@ -143,7 +134,7 @@ def close(self): self._buf.close() -@_need_numpy_2_1 +@requires_module(np, "2.1") class TestMutateYRig: """Tests that mutate the Y-shaped graph built by YRig.""" @@ -269,7 +260,7 @@ def test_adjacency_set_property_setter(init_cuda): assert hub.pred == set() -@_need_numpy_2_1 +@requires_module(np, "2.1") def test_convert_linear_to_fan_in(init_cuda): """Chain four computations sequentially, then rewire so all pairs run in parallel feeding into a reduce node. From b27dd93ea2ae84476d8fc997752f9ae929dd4f33 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 07:36:18 -0700 Subject: [PATCH 06/21] Fix empty-graph return type: return set() instead of () for nodes/edges Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index 83f8517923..03673844d5 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -319,7 +319,7 @@ cdef class GraphDef: Returns ------- - tuple of GraphNode + set of GraphNode All nodes in the graph. """ cdef size_t num_nodes = 0 @@ -328,7 +328,7 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) if num_nodes == 0: - return () + return set() cdef vector[cydriver.CUgraphNode] nodes_vec nodes_vec.resize(num_nodes) @@ -342,7 +342,7 @@ cdef class GraphDef: Returns ------- - tuple of tuple + set of tuple Each element is a (from_node, to_node) pair representing a dependency edge in the graph. """ @@ -355,7 +355,7 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges)) if num_edges == 0: - return () + return set() cdef vector[cydriver.CUgraphNode] from_nodes cdef vector[cydriver.CUgraphNode] to_nodes From 8554d306e0dfeed282f41b689aec2445693ea469 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 09:22:10 -0700 Subject: [PATCH 07/21] Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards Rename class and file to AdjacencySetProxy to clarify write-through semantics. Add bulk-efficient clear(), __isub__(), __ior__() overrides and remove_edges() on the Cython core. Guard GraphNode.discard() against double-destroy via membership check. Filter duplicates in update(). Add error-path tests for wrong types, cross-graph edges, and self-edges. Made-with: Cursor --- ...cency_set.pyx => _adjacency_set_proxy.pyx} | 68 ++++++++++---- .../core/_graph/_graph_def/_graph_node.pyx | 17 ++-- .../tests/graph/test_graphdef_mutation.py | 88 +++++++++++++++++-- 3 files changed, 144 insertions(+), 29 deletions(-) rename cuda_core/cuda/core/_graph/_graph_def/{_adjacency_set.pyx => _adjacency_set_proxy.pyx} (73%) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx similarity index 73% rename from cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx rename to cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx index 059e19cd9f..5c5dae1ddd 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx @@ -20,14 +20,16 @@ from collections.abc import MutableSet # ---- Python MutableSet wrapper ---------------------------------------------- -class AdjacencySet(MutableSet): - """Mutable set-like view of a node's predecessors or successors.""" +class AdjacencySetProxy(MutableSet): + """Mutable set proxy for a node's predecessors or successors. Mutations + write through to the underlying CUDA graph.""" __slots__ = ("_core",) def __init__(self, node, bint is_fwd): self._core = _AdjacencySetCore(node, is_fwd) + # Used by operators such as &|^ to create non-proxy views when needed. @classmethod def _from_iterable(cls, it): return set(it) @@ -62,6 +64,22 @@ class AdjacencySet(MutableSet): # --- override for bulk efficiency --- + def clear(self): + """Remove all edges in a single driver call.""" + members = (<_AdjacencySetCore>self._core).query() + if members: + (<_AdjacencySetCore>self._core).remove_edges(members) + + def __isub__(self, it): + """Remove edges to all nodes in *it* in a single driver call.""" + if it is self: + self.clear() + else: + to_remove = [v for v in it if isinstance(v, GraphNode) and v in self] + if to_remove: + (<_AdjacencySetCore>self._core).remove_edges(to_remove) + return self + def update(self, *others): """Add edges to multiple nodes at once.""" nodes = [] @@ -76,21 +94,28 @@ class AdjacencySet(MutableSet): if not isinstance(n, GraphNode): raise TypeError( f"expected GraphNode, got {type(n).__name__}") - (<_AdjacencySetCore>self._core).add_edges(nodes) + new = [n for n in nodes if n not in self] + if new: + (<_AdjacencySetCore>self._core).add_edges(new) + + def __ior__(self, it): + """Add edges to all nodes in *it* in a single driver call.""" + self.update(it) + return self def __repr__(self): return "{" + ", ".join(repr(n) for n in self) + "}" -# ---- cdef core holding function pointer ------------------------------------ +# ---- cdef core holding a function pointer ------------------------------------ -# Signature shared by _get_preds and _get_succs. +# Signature shared by driver_get_preds and driver_get_succs. ctypedef cydriver.CUresult (*_adj_fn_t)( cydriver.CUgraphNode, cydriver.CUgraphNode*, size_t*) noexcept nogil cdef class _AdjacencySetCore: - """Cythonized core implementing AdjacencySet""" + """Cythonized core implementing AdjacencySetProxy""" cdef: GraphNodeHandle _h_node GraphHandle _h_graph @@ -101,7 +126,7 @@ cdef class _AdjacencySetCore: self._h_node = node._h_node self._h_graph = graph_node_get_graph(node._h_node) self._is_fwd = is_fwd - self._query_fn = _get_succs if is_fwd else _get_preds + self._query_fn = driver_get_succs if is_fwd else driver_get_preds cdef inline void _resolve_edge( self, GraphNode other, @@ -144,15 +169,28 @@ cdef class _AdjacencySetCore: cdef cydriver.CUgraphNode c_from, c_to self._resolve_edge(other, &c_from, &c_to) with nogil: - HANDLE_RETURN(_add_edge(as_cu(self._h_graph), &c_from, &c_to, 1)) + HANDLE_RETURN(driver_add_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) + + cdef void add_edges(self, list nodes): + cdef size_t n = len(nodes) + cdef vector[cydriver.CUgraphNode] from_vec + cdef vector[cydriver.CUgraphNode] to_vec + from_vec.resize(n) + to_vec.resize(n) + cdef size_t i + for i in range(n): + self._resolve_edge(nodes[i], &from_vec[i], &to_vec[i]) + with nogil: + HANDLE_RETURN(driver_add_edges( + as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) cdef void remove_edge(self, GraphNode other): cdef cydriver.CUgraphNode c_from, c_to self._resolve_edge(other, &c_from, &c_to) with nogil: - HANDLE_RETURN(_remove_edge(as_cu(self._h_graph), &c_from, &c_to, 1)) + HANDLE_RETURN(driver_remove_edges(as_cu(self._h_graph), &c_from, &c_to, 1)) - cdef void add_edges(self, list nodes): + cdef void remove_edges(self, list nodes): cdef size_t n = len(nodes) cdef vector[cydriver.CUgraphNode] from_vec cdef vector[cydriver.CUgraphNode] to_vec @@ -162,13 +200,13 @@ cdef class _AdjacencySetCore: for i in range(n): self._resolve_edge(nodes[i], &from_vec[i], &to_vec[i]) with nogil: - HANDLE_RETURN(_add_edge( + HANDLE_RETURN(driver_remove_edges( as_cu(self._h_graph), from_vec.data(), to_vec.data(), n)) # ---- driver wrappers: absorb CUDA version differences ---- -cdef cydriver.CUresult _get_preds( +cdef inline cydriver.CUresult driver_get_preds( cydriver.CUgraphNode node, cydriver.CUgraphNode* out, size_t* count) noexcept nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: @@ -177,7 +215,7 @@ cdef cydriver.CUresult _get_preds( return cydriver.cuGraphNodeGetDependencies(node, out, count) -cdef cydriver.CUresult _get_succs( +cdef inline cydriver.CUresult driver_get_succs( cydriver.CUgraphNode node, cydriver.CUgraphNode* out, size_t* count) noexcept nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: @@ -186,7 +224,7 @@ cdef cydriver.CUresult _get_succs( return cydriver.cuGraphNodeGetDependentNodes(node, out, count) -cdef cydriver.CUresult _add_edge( +cdef inline cydriver.CUresult driver_add_edges( cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: @@ -197,7 +235,7 @@ cdef cydriver.CUresult _add_edge( graph, from_arr, to_arr, count) -cdef cydriver.CUresult _remove_edge( +cdef inline cydriver.CUresult driver_remove_edges( cydriver.CUgraph graph, cydriver.CUgraphNode* from_arr, cydriver.CUgraphNode* to_arr, size_t count) noexcept nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index da1d5dcac0..e17f1b30e7 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -57,7 +57,7 @@ from cuda.core._graph._utils cimport ( ) from cuda.core import Device -from cuda.core._graph._graph_def._adjacency_set import AdjacencySet +from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return @@ -124,7 +124,12 @@ cdef class GraphNode: return as_py(self._h_node) def discard(self): - """Discard this node and remove all its edges from the parent graph.""" + """Discard this node and remove all its edges from the parent graph. + + Safe to call on an already-discarded node (no-op). + """ + if self not in self.graph.nodes(): + return cdef cydriver.CUgraphNode node = as_cu(self._h_node) with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) @@ -132,22 +137,22 @@ cdef class GraphNode: @property def pred(self): """A mutable set-like view of this node's predecessors.""" - return AdjacencySet(self, False) + return AdjacencySetProxy(self, False) @pred.setter def pred(self, value): - p = AdjacencySet(self, False) + p = AdjacencySetProxy(self, False) p.clear() p.update(value) @property def succ(self): """A mutable set-like view of this node's successors.""" - return AdjacencySet(self, True) + return AdjacencySetProxy(self, True) @succ.setter def succ(self, value): - s = AdjacencySet(self, True) + s = AdjacencySetProxy(self, True) s.clear() s.update(value) diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index 3756f2b712..b42556b2b8 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -4,12 +4,15 @@ """Tests for mutating a graph definition (edge changes, node removal).""" import numpy as np +import pytest from helpers.collection_interface_testers import assert_mutable_set_interface from helpers.graph_kernels import compile_parallel_kernels from helpers.marks import requires_module +from cuda.bindings import driver from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode +from cuda.core._utils.cuda_utils import CUDAError class YRig: @@ -29,6 +32,9 @@ class YRig: Node r computes result ``combine(R, A, B) = (A << 16) | (B & 0xFFFF)``, encoding both arms' results into a single int. j is a joining (empty) node preceeding r. + + The affine operation a * m + b is noncommutative, so we can be sure the + graph has exactly the topology we expect by checking the final value. """ def __init__(self): @@ -148,8 +154,8 @@ def test_baseline(self, init_cuda): rig.close() def test_discard_a1(self, init_cuda): - """Discard a1 (creates a race). Arm b yields the expected value, and the - final result is correctly ordered after b.""" + """Discard a1 (creates a race on arm a). Arm b yields the expected + value, and the final step is correctly ordered after b completes.""" rig = YRig() rig.a[1].discard() rig.run() @@ -204,7 +210,7 @@ def test_insert_b(self, init_cuda): def test_adjacency_set_interface(init_cuda): - """Exercise every MutableSet method on AdjacencySet.""" + """Exercise every MutableSet method on AdjacencySetProxy.""" g = GraphDef() hub = g.join() items = [g.join() for _ in range(5)] @@ -260,6 +266,67 @@ def test_adjacency_set_property_setter(init_cuda): assert hub.pred == set() +def test_discarded_node(init_cuda): + """Test uses of discarded nodes.""" + mr = LegacyPinnedMemoryResource() + buf = mr.allocate(4) + arr = np.from_dlpack(buf).view(np.int32) + arr[:] = 0 + ptr = arr[0:].ctypes.data + + g = GraphDef() + a = g.memset(ptr, 0, 4) + b = a.memset(ptr, 42, 4) + + assert b in g.nodes() + assert (a, b) in g.edges() + + b.discard() + + # b is removed from the graph but still usable + assert b not in g.nodes() + assert (a, b) not in g.edges() + assert isinstance(b, MemsetNode) + assert b.type == driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_KERNEL + assert b.pred == set() + assert b.succ == set() + assert b.handle != 0 + assert b.dptr == ptr + assert b.value == 42 + assert b.width == 4 + + # Repeated discard succeeds quietly. + b.discard() + + +def test_add_wrong_type(init_cuda): + """Adding a non-GraphNode raises TypeError.""" + g = GraphDef() + node = g.join() + with pytest.raises(TypeError, match="expected GraphNode"): + node.succ.add("not a node") + with pytest.raises(TypeError, match="expected GraphNode"): + node.succ.add(42) + + +def test_cross_graph_edge(init_cuda): + """Adding an edge to a node from a different graph raises CUDAError.""" + g1 = GraphDef() + g2 = GraphDef() + a = g1.join() + b = g2.join() + with pytest.raises(CUDAError): + a.succ.add(b) + + +def test_self_edge(init_cuda): + """Adding a self-edge raises CUDAError.""" + g = GraphDef() + node = g.join() + with pytest.raises(CUDAError): + node.succ.add(node) + + @requires_module(np, "2.1") def test_convert_linear_to_fan_in(init_cuda): """Chain four computations sequentially, then rewire so all pairs run in @@ -293,17 +360,22 @@ def test_convert_linear_to_fan_in(init_cuda): g = GraphDef() prev = g for i, val in enumerate(values): - prev = prev.memset(ptrs[i], val, 1).launch(config, affine, ptrs[i], 2, 1) + prev = prev.memset(ptrs[i], val, 1) + prev = prev.launch(config, affine, ptrs[i], 2, 1) reduce_node = g.launch(config, reduce_kern, ptrs[4], ptrs[0], 4) # Rewire: # - drop preds from memsets - # - connect results to reduction + # - connect kernel launches to the reduction + assert len(g.edges()) == 7 + for node in g.nodes(): - if isinstance(node, MemsetNode) and len(node.pred): - node.pred = set() + if isinstance(node, MemsetNode): + node.pred.clear() elif isinstance(node, KernelNode) and node != reduce_node: - node.succ = {reduce_node} + node.succ.add(reduce_node) + + assert len(g.edges()) == 8 stream = Device().create_stream() graph = g.instantiate() From 9813c20b1d736af0c0c4ae59575e1288377b9567 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 15:12:24 -0700 Subject: [PATCH 08/21] Add destroy() method with handle invalidation, remove GRAPH_NODE_SENTINEL Replace discard() with destroy() which calls cuGraphDestroyNode and then zeroes the CUgraphNode resource in the handle box via invalidate_graph_node_handle. This prevents stale memory access on destroyed nodes. Properties (type, pred, succ, handle) degrade gracefully to None/empty for destroyed nodes. Remove the GRAPH_NODE_SENTINEL (0x1) approach in favor of using NULL for both sentinels and destroyed nodes, which is simpler and avoids the risk of passing 0x1 to driver APIs that treat it as a valid pointer. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 8 +++- cuda_core/cuda/core/_cpp/resource_handles.hpp | 3 ++ .../core/_graph/_graph_def/_graph_node.pyx | 22 +++++++-- cuda_core/cuda/core/_resource_handles.pxd | 1 + cuda_core/cuda/core/_resource_handles.pyx | 2 + .../tests/graph/test_graphdef_mutation.py | 48 +++++++++++-------- 6 files changed, 58 insertions(+), 26 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 0e3d2d7857..2355d64717 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -957,7 +957,7 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) namespace { struct GraphNodeBox { - CUgraphNode resource; + mutable CUgraphNode resource; GraphHandle h_graph; }; } // namespace @@ -978,6 +978,12 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { return h ? get_box(h)->h_graph : GraphHandle{}; } +void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept { + if (h) { + get_box(h)->resource = nullptr; + } +} + // ============================================================================ // Graphics Resource Handles // ============================================================================ diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 92d3cd4669..064f1406f6 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -415,6 +415,9 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_ // Extract the owning graph handle from a node handle. GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept; +// Zero the CUgraphNode resource inside the handle, marking it invalid. +void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept; + // ============================================================================ // Graphics resource handle functions // ============================================================================ diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index e17f1b30e7..4048c9ee06 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -48,6 +48,7 @@ from cuda.core._resource_handles cimport ( create_graph_handle_ref, create_graph_node_handle, graph_node_get_graph, + invalidate_graph_node_handle, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value @@ -123,16 +124,27 @@ cdef class GraphNode: """ return as_py(self._h_node) - def discard(self): - """Discard this node and remove all its edges from the parent graph. + @property + def is_valid(self): + """Whether this node is valid (not destroyed). - Safe to call on an already-discarded node (no-op). + Returns ``False`` after :meth:`destroy` has been called. + """ + return as_intptr(self._h_node) != 0 + + def destroy(self): + """Destroy this node and remove all its edges from the parent graph. + + After this call, :attr:`is_valid` returns ``False`` and the node + cannot be re-added to any graph. Safe to call on an + already-destroyed node (no-op). """ - if self not in self.graph.nodes(): - return cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + return with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) + invalidate_graph_node_handle(self._h_node) @property def pred(self): diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 419106f04a..f847e60223 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -186,6 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil +cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil # Graphics resource handles cdef GraphicsResourceHandle create_graphics_resource_handle( diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 39b425b9ed..001f9b4a0c 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -159,6 +159,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" ( const GraphNodeHandle& h) noexcept nogil + void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" ( + const GraphNodeHandle& h) noexcept nogil # Graphics resource handles GraphicsResourceHandle create_graphics_resource_handle "cuda_core::create_graphics_resource_handle" ( diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index b42556b2b8..dcfd4aab89 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -9,7 +9,6 @@ from helpers.graph_kernels import compile_parallel_kernels from helpers.marks import requires_module -from cuda.bindings import driver from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource from cuda.core._graph._graph_def import GraphDef, KernelNode, MemsetNode from cuda.core._utils.cuda_utils import CUDAError @@ -153,11 +152,11 @@ def test_baseline(self, init_cuda): assert rig.nodes == rig.initial_nodes rig.close() - def test_discard_a1(self, init_cuda): - """Discard a1 (creates a race on arm a). Arm b yields the expected + def test_destroy_a1(self, init_cuda): + """Destroy a1 (creates a race on arm a). Arm b yields the expected value, and the final step is correctly ordered after b completes.""" rig = YRig() - rig.a[1].discard() + rig.a[1].destroy() rig.run() _, b_exp, _ = rig.expected_output assert rig.B_out == b_exp @@ -167,10 +166,10 @@ def test_discard_a1(self, init_cuda): assert rig.nodes == rig.initial_nodes - {a1} rig.close() - def test_discard_a2(self, init_cuda): - """Discard a2, connect a1--r""" + def test_destroy_a2(self, init_cuda): + """Destroy a2, connect a1--r""" rig = YRig() - rig.a[2].discard() + rig.a[2].destroy() rig.a[1].succ.add(rig.r) rig.A_OPS.pop() rig.run() @@ -180,11 +179,11 @@ def test_discard_a2(self, init_cuda): assert rig.nodes == rig.initial_nodes - {a2} rig.close() - def test_discard_joint(self, init_cuda): + def test_destroy_joint(self, init_cuda): """Remove the joining node and instead add edges directly to r.""" rig = YRig() _, _, a2, _, b1, j, r = rig.a + rig.b + [rig.j, rig.r] - j.discard() + j.destroy() r.pred = {a2, b1} rig.run() assert rig.output == rig.expected_output @@ -266,8 +265,8 @@ def test_adjacency_set_property_setter(init_cuda): assert hub.pred == set() -def test_discarded_node(init_cuda): - """Test uses of discarded nodes.""" +def test_destroyed_node(init_cuda): + """Test that destroy() invalidates a node.""" mr = LegacyPinnedMemoryResource() buf = mr.allocate(4) arr = np.from_dlpack(buf).view(np.int32) @@ -278,25 +277,34 @@ def test_discarded_node(init_cuda): a = g.memset(ptr, 0, 4) b = a.memset(ptr, 42, 4) + assert a.is_valid + assert b.is_valid assert b in g.nodes() assert (a, b) in g.edges() - b.discard() + b.destroy() - # b is removed from the graph but still usable + assert not b.is_valid assert b not in g.nodes() assert (a, b) not in g.edges() + + # Python object is invalid but using it does not crash. assert isinstance(b, MemsetNode) - assert b.type == driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_KERNEL + assert b.type is None assert b.pred == set() assert b.succ == set() - assert b.handle != 0 - assert b.dptr == ptr - assert b.value == 42 - assert b.width == 4 + assert b.handle is None + assert b.dptr == ptr # tolerable + assert b.value == 42 # tolerable + assert b.width == 4 # tolerable + + # Adding an edge to a destroyed node fails. + with pytest.raises(CUDAError): + a.succ.add(b) - # Repeated discard succeeds quietly. - b.discard() + # Repeated destroy succeeds quietly. + b.destroy() + assert not b.is_valid def test_add_wrong_type(init_cuda): From 6411881b430d3a2d718fe5c699efa56ef6eb7c5f Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 12:55:24 -0700 Subject: [PATCH 09/21] Add GraphNode identity cache for stable Python object round-trips Nodes retrieved via GraphDef.nodes(), edges(), or pred/succ traversal now return the same Python object that was originally created, enabling identity checks with `is`. A C++ HandleRegistry deduplicates CUgraphNode handles, and a Cython WeakValueDictionary caches the Python wrapper objects. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 9 ++- .../core/_graph/_graph_def/_graph_node.pyx | 65 ++++++++++++------- cuda_core/tests/graph/test_graphdef.py | 2 + .../tests/graph/test_graphdef_mutation.py | 2 +- 4 files changed, 54 insertions(+), 24 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 2355d64717..670ddadb1d 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -969,9 +969,16 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) { ); } +static HandleRegistry graph_node_registry; + GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { + if (auto h = graph_node_registry.lookup(node)) { + return h; + } auto box = std::make_shared(GraphNodeBox{node, h_graph}); - return GraphNodeHandle(box, &box->resource); + GraphNodeHandle h(box, &box->resource); + graph_node_registry.register_handle(node, h); + return h; } GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 4048c9ee06..3f4a9d1446 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -57,10 +57,19 @@ from cuda.core._graph._utils cimport ( _attach_user_object, ) +import weakref + from cuda.core import Device from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return +_node_cache = weakref.WeakValueDictionary() + + +cdef inline GraphNode _cached(GraphNode n): + _node_cache[n._h_node.get()] = n + return n + cdef class GraphNode: """Base class for all graph nodes. @@ -522,18 +531,30 @@ cdef inline ConditionalNode _make_conditional_node( n._cond_type = cond_type n._branches = branches - return n + return _cached(n) + cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): + cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) + + # Sentinel: virtual node to represent the graph entry point. if node == NULL: n = GraphNode.__new__(GraphNode) - (n)._h_node = create_graph_node_handle(node, h_graph) + (n)._h_node = h_node return n - cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) + # Return a cached object or create and cache a new one. + cached = _node_cache.get(h_node.get()) + if cached is not None: + return cached + else: + return _cached(GN_create_impl(h_node)) + + +cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node): cdef cydriver.CUgraphNodeType node_type with nogil: - HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) + HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type)) if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY: return EmptyNode._create_impl(h_node) @@ -595,10 +616,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) - return KernelNode._create_with_params( + return _cached(KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, - ker._h_kernel) + ker._h_kernel)) cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): @@ -624,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( &new_node, as_cu(h_graph), deps_ptr, num_deps)) - return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) + return _cached(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))) cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): @@ -700,9 +721,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - return AllocNode._create_with_params( + return _cached(AllocNode._create_with_params( create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, - device_id, memory_type, tuple(peer_ids)) + device_id, memory_type, tuple(peer_ids))) cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): @@ -720,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) + return _cached(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)) cdef inline MemsetNode GN_memset( @@ -755,9 +776,9 @@ cdef inline MemsetNode GN_memset( &new_node, as_cu(h_graph), deps, num_deps, &memset_params, ctx)) - return MemsetNode._create_with_params( + return _cached(MemsetNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, - val, elem_size, width, height, pitch) + val, elem_size, width, height, pitch)) cdef inline MemcpyNode GN_memcpy( @@ -816,9 +837,9 @@ cdef inline MemcpyNode GN_memcpy( HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - return MemcpyNode._create_with_params( + return _cached(MemcpyNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, - c_dst_type, c_src_type) + c_dst_type, c_src_type)) cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): @@ -843,8 +864,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - return ChildGraphNode._create_with_params( - create_graph_node_handle(new_node, h_graph), h_embedded) + return _cached(ChildGraphNode._create_with_params( + create_graph_node_handle(new_node, h_graph), h_embedded)) cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): @@ -865,8 +886,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return EventRecordNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) + return _cached(EventRecordNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event)) cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): @@ -887,8 +908,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return EventWaitNode._create_with_params( - create_graph_node_handle(new_node, h_graph), ev._h_event) + return _cached(EventWaitNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event)) cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data): @@ -914,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ &new_node, as_cu(h_graph), deps, num_deps, &node_params)) cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - return HostCallbackNode._create_with_params( + return _cached(HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, - node_params.fn, node_params.userData) + node_params.fn, node_params.userData)) diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index be6da9515a..7f3872eb8d 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec): matched = [n for n in all_nodes if n == node] assert len(matched) == 1 assert isinstance(matched[0], spec.roundtrip_class) + assert matched[0] is node def test_node_type_preserved_by_pred_succ(node_spec): @@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec): matched = [s for s in predecessor.succ if s == node] assert len(matched) == 1 assert isinstance(matched[0], spec.roundtrip_class) + assert matched[0] is node def test_node_attrs(node_spec): diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index dcfd4aab89..ac0d8f5e61 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -380,7 +380,7 @@ def test_convert_linear_to_fan_in(init_cuda): for node in g.nodes(): if isinstance(node, MemsetNode): node.pred.clear() - elif isinstance(node, KernelNode) and node != reduce_node: + elif isinstance(node, KernelNode) and node is not reduce_node: node.succ.add(reduce_node) assert len(g.edges()) == 8 From 7a3dbb4915d6eda0f705c4f53f32534a673c877d Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 15:17:12 -0700 Subject: [PATCH 10/21] Purge node cache on destroy to prevent stale identity lookups Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 3f4a9d1446..4bec7422fd 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -153,6 +153,7 @@ cdef class GraphNode: return with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) + _node_cache.pop(self._h_node.get(), None) invalidate_graph_node_handle(self._h_node) @property From 91b3b4e7dae8ae0580f9653b0aae53fde3607cfb Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Thu, 2 Apr 2026 16:34:14 -0700 Subject: [PATCH 11/21] Skip NULL nodes in graph_node_registry to fix sentinel identity collision Sentinel (entry) nodes use NULL as their CUgraphNode, so caching them under a NULL key caused all sentinels across different graphs to share the same handle. This made nodes built from the wrong graph's entry point, causing CUDA_ERROR_INVALID_VALUE for conditional nodes and hash collisions in equality tests. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 670ddadb1d..dabc9b8253 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -972,12 +972,16 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) { static HandleRegistry graph_node_registry; GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { - if (auto h = graph_node_registry.lookup(node)) { - return h; + if (node) { + if (auto h = graph_node_registry.lookup(node)) { + return h; + } } auto box = std::make_shared(GraphNodeBox{node, h_graph}); GraphNodeHandle h(box, &box->resource); - graph_node_registry.register_handle(node, h); + if (node) { + graph_node_registry.register_handle(node, h); + } return h; } From 1b7743dd8e508ba335dca61c66beaa1623a7a7f2 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 10:19:31 -0700 Subject: [PATCH 12/21] Unregister destroyed nodes from C++ graph_node_registry When a node is destroyed, the driver may reuse its CUgraphNode pointer for a new node. Without unregistering the old entry, the registry returns a stale handle pointing to the wrong node type and graph. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index dabc9b8253..dd8a3d04cb 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -991,6 +991,10 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept { if (h) { + CUgraphNode node = get_box(h)->resource; + if (node) { + graph_node_registry.unregister_handle(node); + } get_box(h)->resource = nullptr; } } From 84f0b30b01e9f06cc6482ef57822ff39dcdfcafb Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 10:29:20 -0700 Subject: [PATCH 13/21] Add dedicated test for node identity preservation through round-trips Made-with: Cursor --- cuda_core/tests/graph/test_graphdef.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 7f3872eb8d..562f720ca8 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -699,6 +699,31 @@ def test_node_attrs_preserved_by_nodes(node_spec): assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()" +def test_identity_preservation(init_cuda): + """Round-trips through nodes(), edges(), and pred/succ return extant + objects rather than duplicates.""" + g = GraphDef() + a = g.join() + b = a.join() + + # nodes() + assert any(x is a for x in g.nodes()) + assert any(x is b for x in g.nodes()) + + # succ/pred + a.succ = {b} + (b2,) = a.succ + assert b2 is b + + (a2,) = b.pred + assert a2 is a + + # edges() + ((a2, b2),) = g.edges() + assert a2 is a + assert b2 is b + + # ============================================================================= # GraphDef basics # ============================================================================= From 6b36e4719c83bea9978cf898eb34e27b219ab622 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 13:14:53 -0700 Subject: [PATCH 14/21] Add handle= to all GraphNode subclass __repr__ for debugging Every subclass repr now starts with handle=0x... (the CUgraphNode pointer) followed by type-specific identity/parameter data. Dynamic queries (pred counts, subnode counts) are removed in favor of deterministic, cheap fields. This makes set comparison failures in test output readable when debugging graph mutation tests. Made-with: Cursor --- .../core/_graph/_graph_def/_subclasses.pyx | 55 ++++++++++--------- cuda_core/tests/test_object_protocols.py | 28 +++++----- 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx index 2c78b3b0ac..e1ab3a460a 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx @@ -58,8 +58,7 @@ cdef class EmptyNode(GraphNode): return n def __repr__(self) -> str: - cdef Py_ssize_t n = len(self.pred) - return f"" + return f"" cdef class KernelNode(GraphNode): @@ -108,7 +107,8 @@ cdef class KernelNode(GraphNode): h_kernel) def __repr__(self) -> str: - return (f"") + return (f"") @property def grid(self) -> tuple: @@ -207,7 +207,8 @@ cdef class AllocNode(GraphNode): params.poolProps.location.id, memory_type, tuple(peer_ids)) def __repr__(self) -> str: - return f"" + return (f"") @property def dptr(self) -> int: @@ -273,7 +274,7 @@ cdef class FreeNode(GraphNode): return FreeNode._create_with_params(h_node, dptr) def __repr__(self) -> str: - return f"" + return f"" @property def dptr(self) -> int: @@ -328,8 +329,8 @@ cdef class MemsetNode(GraphNode): params.elementSize, params.width, params.height, params.pitch) def __repr__(self) -> str: - return (f"") + return (f"") @property def dptr(self) -> int: @@ -416,8 +417,8 @@ cdef class MemcpyNode(GraphNode): def __repr__(self) -> str: cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D" cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D" - return (f"") + return (f"") @property def dst(self) -> int: @@ -465,12 +466,8 @@ cdef class ChildGraphNode(GraphNode): return ChildGraphNode._create_with_params(h_node, h_child) def __repr__(self) -> str: - cdef cydriver.CUgraph g = as_cu(self._h_child_graph) - cdef size_t num_nodes = 0 - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes)) - cdef Py_ssize_t n = num_nodes - return f"" + return (f"") @property def child_graph(self) -> "GraphDef": @@ -507,7 +504,8 @@ cdef class EventRecordNode(GraphNode): return EventRecordNode._create_with_params(h_node, h_event) def __repr__(self) -> str: - return f"" + return (f"") @property def event(self) -> Event: @@ -544,7 +542,8 @@ cdef class EventWaitNode(GraphNode): return EventWaitNode._create_with_params(h_node, h_event) def __repr__(self) -> str: - return f"" + return (f"") @property def event(self) -> Event: @@ -591,8 +590,10 @@ cdef class HostCallbackNode(GraphNode): def __repr__(self) -> str: if self._callable is not None: name = getattr(self._callable, '__name__', '?') - return f"" - return f"self._fn:x}>" + return (f"") + return (f"self._fn:x}>") @property def callback_fn(self): @@ -672,7 +673,7 @@ cdef class ConditionalNode(GraphNode): return n def __repr__(self) -> str: - return "" + return f"" @property def condition(self) -> Condition | None: @@ -709,7 +710,8 @@ cdef class IfNode(ConditionalNode): """An if-conditional node (1 branch, executes when condition is non-zero).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def then(self) -> "GraphDef": @@ -721,7 +723,8 @@ cdef class IfElseNode(ConditionalNode): """An if-else conditional node (2 branches).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def then(self) -> "GraphDef": @@ -738,7 +741,8 @@ cdef class WhileNode(ConditionalNode): """A while-loop conditional node (1 branch, repeats while condition is non-zero).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def body(self) -> "GraphDef": @@ -750,6 +754,5 @@ cdef class SwitchNode(ConditionalNode): """A switch conditional node (N branches, selected by condition value).""" def __repr__(self) -> str: - cdef Py_ssize_t n = len(self._branches) - return (f"self._condition._c_handle:x}" - f" with {n} {'branch' if n == 1 else 'branches'}>") + return (f"self._condition._c_handle:x}>") diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index ef4f1337d1..3a523a8943 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -685,20 +685,20 @@ def sample_switch_node_alt(sample_graphdef): ("sample_graphdef", r""), ("sample_condition", r""), ("sample_root_node", r""), - ("sample_empty_node", r""), - ("sample_alloc_node", r""), - ("sample_kernel_node", r""), - ("sample_free_node", r""), - ("sample_memset_node", r""), - ("sample_memcpy_node", r""), - ("sample_child_graph_node", r""), - ("sample_event_record_node", r""), - ("sample_event_wait_node", r""), - ("sample_host_callback_node", r""), - ("sample_if_node", r""), - ("sample_if_else_node", r""), - ("sample_while_node", r""), - ("sample_switch_node", r""), + ("sample_empty_node", r""), + ("sample_alloc_node", r""), + ("sample_kernel_node", r""), + ("sample_free_node", r""), + ("sample_memset_node", r""), + ("sample_memcpy_node", r""), + ("sample_child_graph_node", r""), + ("sample_event_record_node", r""), + ("sample_event_wait_node", r""), + ("sample_host_callback_node", r""), + ("sample_if_node", r""), + ("sample_if_else_node", r""), + ("sample_while_node", r""), + ("sample_switch_node", r""), ] From 729af4983f0870986e3422d56c9ab585f33b408a Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 13:24:58 -0700 Subject: [PATCH 15/21] Rename _node_cache/_cached to _node_registry/_registered Aligns Python-side terminology with the C++ graph_node_registry. Made-with: Cursor --- .../core/_graph/_graph_def/_graph_node.pyx | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index d471eeb40b..7ddc074460 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -63,11 +63,11 @@ from cuda.core import Device from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return -_node_cache = weakref.WeakValueDictionary() +_node_registry = weakref.WeakValueDictionary() -cdef inline GraphNode _cached(GraphNode n): - _node_cache[n._h_node.get()] = n +cdef inline GraphNode _registered(GraphNode n): + _node_registry[n._h_node.get()] = n return n @@ -153,7 +153,7 @@ cdef class GraphNode: return with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) - _node_cache.pop(self._h_node.get(), None) + _node_registry.pop(self._h_node.get(), None) invalidate_graph_node_handle(self._h_node) @property @@ -532,7 +532,7 @@ cdef inline ConditionalNode _make_conditional_node( n._cond_type = cond_type n._branches = branches - return _cached(n) + return _registered(n) cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) @@ -543,12 +543,12 @@ cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node): (n)._h_node = h_node return n - # Return a cached object or create and cache a new one. - cached = _node_cache.get(h_node.get()) - if cached is not None: - return cached + # Return a registered object or create and register a new one. + registered = _node_registry.get(h_node.get()) + if registered is not None: + return registered else: - return _cached(GN_create_impl(h_node)) + return _registered(GN_create_impl(h_node)) cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node): @@ -616,7 +616,7 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker, _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), _destroy_kernel_handle_copy) - return _cached(KernelNode._create_with_params( + return _registered(KernelNode._create_with_params( create_graph_node_handle(new_node, h_graph), conf.grid, conf.block, conf.shmem_size, ker._h_kernel)) @@ -645,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes): HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( &new_node, as_cu(h_graph), deps_ptr, num_deps)) - return _cached(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))) + return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))) cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): @@ -721,7 +721,7 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options): HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) - return _cached(AllocNode._create_with_params( + return _registered(AllocNode._create_with_params( create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, device_id, memory_type, tuple(peer_ids))) @@ -741,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr): HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) - return _cached(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)) + return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)) cdef inline MemsetNode GN_memset( @@ -776,7 +776,7 @@ cdef inline MemsetNode GN_memset( &new_node, as_cu(h_graph), deps, num_deps, &memset_params, ctx)) - return _cached(MemsetNode._create_with_params( + return _registered(MemsetNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, val, elem_size, width, height, pitch)) @@ -837,7 +837,7 @@ cdef inline MemcpyNode GN_memcpy( HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) - return _cached(MemcpyNode._create_with_params( + return _registered(MemcpyNode._create_with_params( create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, c_dst_type, c_src_type)) @@ -864,7 +864,7 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def): cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) - return _cached(ChildGraphNode._create_with_params( + return _registered(ChildGraphNode._create_with_params( create_graph_node_handle(new_node, h_graph), h_embedded)) @@ -886,7 +886,7 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return _cached(EventRecordNode._create_with_params( + return _registered(EventRecordNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event)) @@ -908,7 +908,7 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev): _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), _destroy_event_handle_copy) - return _cached(EventWaitNode._create_with_params( + return _registered(EventWaitNode._create_with_params( create_graph_node_handle(new_node, h_graph), ev._h_event)) @@ -935,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_ &new_node, as_cu(h_graph), deps, num_deps, &node_params)) cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None - return _cached(HostCallbackNode._create_with_params( + return _registered(HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, node_params.fn, node_params.userData)) From 42131b609f50190fce0862c57fa67a6ef9d290ea Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 13:45:05 -0700 Subject: [PATCH 16/21] Fix unregister_handle and rename invalidate_graph_node_handle unregister_handle: remove the expired() guard that prevented erasure when the shared_ptr was still alive. This caused stale registry entries after destroy(), leading to CUDA_ERROR_INVALID_VALUE when the driver reused CUgraphNode pointer values. Rename invalidate_graph_node_handle -> invalidate_graph_node for consistency with the rest of the graph node API. Made-with: Cursor --- cuda_core/cuda/core/_cpp/resource_handles.cpp | 11 +++-------- cuda_core/cuda/core/_cpp/resource_handles.hpp | 2 +- cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx | 4 ++-- cuda_core/cuda/core/_resource_handles.pxd | 2 +- cuda_core/cuda/core/_resource_handles.pyx | 2 +- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index dd8a3d04cb..904b84c657 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -174,13 +174,8 @@ class HandleRegistry { } void unregister_handle(const Key& key) noexcept { - try { - std::lock_guard lock(mutex_); - auto it = map_.find(key); - if (it != map_.end() && it->second.expired()) { - map_.erase(it); - } - } catch (...) {} + std::lock_guard lock(mutex_); + map_.erase(key); } Handle lookup(const Key& key) { @@ -989,7 +984,7 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { return h ? get_box(h)->h_graph : GraphHandle{}; } -void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept { +void invalidate_graph_node(const GraphNodeHandle& h) noexcept { if (h) { CUgraphNode node = get_box(h)->resource; if (node) { diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 064f1406f6..d63fb86997 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept; // Zero the CUgraphNode resource inside the handle, marking it invalid. -void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept; +void invalidate_graph_node(const GraphNodeHandle& h) noexcept; // ============================================================================ // Graphics resource handle functions diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 7ddc074460..1474d10430 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport ( create_graph_handle_ref, create_graph_node_handle, graph_node_get_graph, - invalidate_graph_node_handle, + invalidate_graph_node, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value @@ -154,7 +154,7 @@ cdef class GraphNode: with nogil: HANDLE_RETURN(cydriver.cuGraphDestroyNode(node)) _node_registry.pop(self._h_node.get(), None) - invalidate_graph_node_handle(self._h_node) + invalidate_graph_node(self._h_node) @property def pred(self): diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index f847e60223..9e7307e821 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand # Graph node handles cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil -cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil +cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil # Graphics resource handles cdef GraphicsResourceHandle create_graphics_resource_handle( diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 001f9b4a0c..2090f5026d 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" ( const GraphNodeHandle& h) noexcept nogil - void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" ( + void invalidate_graph_node "cuda_core::invalidate_graph_node" ( const GraphNodeHandle& h) noexcept nogil # Graphics resource handles From 15d0036ea3ee8c5833ba0e1dadcf3164e6f1557a Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 14:22:09 -0700 Subject: [PATCH 17/21] Add cheap containment test and early type check for AdjacencySetProxy Add _AdjacencySetCore.contains() that checks membership by comparing raw CUgraphNode handles at the C level, avoiding Python object construction. Uses a 16-element stack buffer for a single driver call in the common case. Move the type check in update() inline next to the extend loop so invalid input is rejected immediately. Made-with: Cursor --- .../_graph_def/_adjacency_set_proxy.pyx | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx index 5c5dae1ddd..82a1e00828 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx @@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet): def __contains__(self, x): if not isinstance(x, GraphNode): return False - return x in (<_AdjacencySetCore>self._core).query() + return (<_AdjacencySetCore>self._core).contains(x) def __iter__(self): return iter((<_AdjacencySetCore>self._core).query()) @@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet): if isinstance(other, GraphNode): nodes.append(other) else: - nodes.extend(other) + for n in other: + if not isinstance(n, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(n).__name__}") + nodes.append(n) if not nodes: return - for n in nodes: - if not isinstance(n, GraphNode): - raise TypeError( - f"expected GraphNode, got {type(n).__name__}") new = [n for n in nodes if n not in self] if new: (<_AdjacencySetCore>self._core).add_edges(new) @@ -156,6 +156,31 @@ cdef class _AdjacencySetCore: return [GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(count)] + cdef bint contains(self, GraphNode other): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode target = as_cu(other._h_node) + if c_node == NULL or target == NULL: + return False + cdef cydriver.CUgraphNode buf[16] + cdef size_t count = 16 + cdef size_t i + with nogil: + HANDLE_RETURN(self._query_fn(c_node, buf, &count)) + if count <= 16: + for i in range(count): + if buf[i] == target: + return True + else: + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(count) + with nogil: + HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count)) + assert count == nodes_vec.size() + for i in range(count): + if nodes_vec[i] == target: + return True + return False + cdef Py_ssize_t count(self): cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) if c_node == NULL: From 347693f3bc3cd0a5116c50c7e560e4a3302735b1 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 14:55:36 -0700 Subject: [PATCH 18/21] Add GraphDef.empty(), stack-buffer query optimization, and registry test - Add GraphDef.empty() for creating entry-point empty nodes; replace all no-arg join() calls on GraphDef with empty() in tests. - Optimize _AdjacencySetCore.query() to use a 16-element stack buffer, matching the contains() optimization. - Add test_registry_cleanup exercising destroy(), graph deletion, and weak-reference cleanup of the node registry. Made-with: Cursor --- .../_graph_def/_adjacency_set_proxy.pyx | 33 ++++++++------ .../core/_graph/_graph_def/_graph_def.pyx | 10 +++++ cuda_core/tests/graph/test_graphdef.py | 45 ++++++++++++++++++- cuda_core/tests/graph/test_graphdef_errors.py | 6 +-- .../tests/graph/test_graphdef_mutation.py | 20 ++++----- 5 files changed, 87 insertions(+), 27 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx index 82a1e00828..3f5f419fb6 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx @@ -143,11 +143,14 @@ cdef class _AdjacencySetCore: cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) if c_node == NULL: return [] - cdef size_t count = 0 + cdef cydriver.CUgraphNode buf[16] + cdef size_t count = 16 + cdef size_t i with nogil: - HANDLE_RETURN(self._query_fn(c_node, NULL, &count)) - if count == 0: - return [] + HANDLE_RETURN(self._query_fn(c_node, buf, &count)) + if count <= 16: + return [GraphNode._create(self._h_graph, buf[i]) + for i in range(count)] cdef vector[cydriver.CUgraphNode] nodes_vec nodes_vec.resize(count) with nogil: @@ -166,19 +169,23 @@ cdef class _AdjacencySetCore: cdef size_t i with nogil: HANDLE_RETURN(self._query_fn(c_node, buf, &count)) + + # Fast path for small sets. if count <= 16: for i in range(count): if buf[i] == target: return True - else: - cdef vector[cydriver.CUgraphNode] nodes_vec - nodes_vec.resize(count) - with nogil: - HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count)) - assert count == nodes_vec.size() - for i in range(count): - if nodes_vec[i] == target: - return True + return False + + # Fallback for large sets. + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(count) + with nogil: + HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count)) + assert count == nodes_vec.size() + for i in range(count): + if nodes_vec[i] == target: + return True return False cdef Py_ssize_t count(self): diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index 03673844d5..6c8e01069b 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -159,6 +159,16 @@ cdef class GraphDef: """ return self._entry.launch(config, kernel, *args) + def empty(self) -> "EmptyNode": + """Add an entry-point empty node (no dependencies). + + Returns + ------- + EmptyNode + A new EmptyNode with no dependencies. + """ + return self._entry.join() + def join(self, *nodes) -> "EmptyNode": """Create an empty node that depends on all given nodes. diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 562f720ca8..fe760f8994 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -703,7 +703,7 @@ def test_identity_preservation(init_cuda): """Round-trips through nodes(), edges(), and pred/succ return extant objects rather than duplicates.""" g = GraphDef() - a = g.join() + a = g.empty() b = a.join() # nodes() @@ -724,6 +724,49 @@ def test_identity_preservation(init_cuda): assert b2 is b +def test_registry_cleanup(init_cuda): + """Node registry entries are removed on destroy() and graph teardown.""" + import gc + from cuda.core._graph._graph_def._graph_node import _node_registry + + def registered(node): + return any(v is node for v in _node_registry.values()) + + gc.collect() + assert len(_node_registry) == 0 + + g = GraphDef() + a = g.empty() + b = g.empty() + c = g.empty() + + assert len(_node_registry) == 3 + assert registered(a) + assert registered(b) + assert registered(c) + + a.destroy() + assert len(_node_registry) == 2 + assert not registered(a) + assert registered(b) + assert registered(c) + + del g + gc.collect() + assert len(_node_registry) == 2 + assert registered(b) + assert registered(c) + + b.destroy() + assert len(_node_registry) == 1 + assert not registered(b) + assert registered(c) + + del c + gc.collect() + assert len(_node_registry) == 0 + + # ============================================================================= # GraphDef basics # ============================================================================= diff --git a/cuda_core/tests/graph/test_graphdef_errors.py b/cuda_core/tests/graph/test_graphdef_errors.py index 9c6a870562..596f83bffe 100644 --- a/cuda_core/tests/graph/test_graphdef_errors.py +++ b/cuda_core/tests/graph/test_graphdef_errors.py @@ -101,10 +101,10 @@ def test_condition_from_different_graph(init_cuda): # ============================================================================= -def test_join_no_extra_nodes(init_cuda): - """join() from entry with no extra nodes creates a single empty node.""" +def test_empty_node(init_cuda): + """empty() creates a single entry-point empty node.""" g = GraphDef() - joined = g.join() + joined = g.empty() assert isinstance(joined, EmptyNode) assert len(g.nodes()) == 1 diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index ac0d8f5e61..c0fc1bb242 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -211,16 +211,16 @@ def test_insert_b(self, init_cuda): def test_adjacency_set_interface(init_cuda): """Exercise every MutableSet method on AdjacencySetProxy.""" g = GraphDef() - hub = g.join() - items = [g.join() for _ in range(5)] + hub = g.empty() + items = [g.empty() for _ in range(5)] assert_mutable_set_interface(hub.succ, items) def test_adjacency_set_pred_direction(init_cuda): """Verify that pred works symmetrically with succ.""" g = GraphDef() - target = g.join() - x, y, z = (g.join() for _ in range(3)) + target = g.empty() + x, y, z = (g.empty() for _ in range(3)) pred = target.pred assert pred == set() @@ -242,8 +242,8 @@ def test_adjacency_set_pred_direction(init_cuda): def test_adjacency_set_property_setter(init_cuda): """Verify that assigning to node.pred or node.succ replaces all edges.""" g = GraphDef() - hub = g.join() - a, b, c = (g.join() for _ in range(3)) + hub = g.empty() + a, b, c = (g.empty() for _ in range(3)) hub.succ = {a, b} assert hub.succ == {a, b} @@ -310,7 +310,7 @@ def test_destroyed_node(init_cuda): def test_add_wrong_type(init_cuda): """Adding a non-GraphNode raises TypeError.""" g = GraphDef() - node = g.join() + node = g.empty() with pytest.raises(TypeError, match="expected GraphNode"): node.succ.add("not a node") with pytest.raises(TypeError, match="expected GraphNode"): @@ -321,8 +321,8 @@ def test_cross_graph_edge(init_cuda): """Adding an edge to a node from a different graph raises CUDAError.""" g1 = GraphDef() g2 = GraphDef() - a = g1.join() - b = g2.join() + a = g1.empty() + b = g2.empty() with pytest.raises(CUDAError): a.succ.add(b) @@ -330,7 +330,7 @@ def test_cross_graph_edge(init_cuda): def test_self_edge(init_cuda): """Adding a self-edge raises CUDAError.""" g = GraphDef() - node = g.join() + node = g.empty() with pytest.raises(CUDAError): node.succ.add(node) From 641a08925c3b6fea5c36e9aa7920c0eac9a12397 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 15:23:53 -0700 Subject: [PATCH 19/21] Document the two-level handle and object registry design Add REGISTRY_DESIGN.md explaining how the C++ HandleRegistry (Level 1) and Cython _node_registry (Level 2) work together to preserve Python object identity through driver round-trips. Add cross-references at each registry instantiation site. Made-with: Cursor --- cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md | 51 +++++++++++++++++++ cuda_core/cuda/core/_cpp/resource_handles.cpp | 3 ++ .../core/_graph/_graph_def/_graph_node.pyx | 1 + 3 files changed, 55 insertions(+) create mode 100644 cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md diff --git a/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md new file mode 100644 index 0000000000..cbfc609686 --- /dev/null +++ b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md @@ -0,0 +1,51 @@ +# Handle and Object Registries + +When Python-managed objects round-trip through the CUDA driver (e.g., +querying a graph's nodes and getting back raw `CUgraphNode` pointers), +we need to recover the original Python object rather than creating a +duplicate. + +This document describes the approach used to achieve this. The pattern +is driven mainly by needs arising in the context of CUDA graphs, but +it is general and can be extended to other object types as needs arise. + +This solves the same problem as pybind11's `registered_instances` map +and is sometimes called the Identity Map pattern. Two registries work +together to map a raw driver handle all the way back to the original +Python object. Both use weak references so they +do not prevent cleanup. Entries are removed either explicitly (via +`destroy()` or a Box destructor) or implicitly when the weak reference +expires. + +## Level 1: Driver Handle -> Resource Handle (C++) + +`HandleRegistry` in `resource_handles.cpp` maps a raw CUDA handle +(e.g., `CUevent`, `CUkernel`, `CUgraphNode`) to the `weak_ptr` that +owns it. When a `_ref` constructor receives a raw handle, it +checks the registry first. If found, it returns the existing +`shared_ptr`, preserving the Box and its metadata (e.g., `EventBox` +carries timing/IPC flags, `KernelBox` carries the library dependency). + +Without this level, a round-tripped handle would produce a new Box +with default metadata, losing information that was set at creation. + +Instances: `event_registry`, `kernel_registry`, `graph_node_registry`. + +## Level 2: Resource Handle -> Python Object (Cython) + +`_node_registry` in `_graph_node.pyx` is a `WeakValueDictionary` +mapping a resource address (`shared_ptr::get()`) to a Python +`GraphNode` object. When `GraphNode._create` receives a handle from +Level 1, it checks this registry. If found, it returns the existing +Python object. + +Without this level, each driver round-trip would produce a distinct +Python object for the same logical node, resulting in surprising +behavior: + +```python +a = g.empty() +a.succ = {b} +b2, = a.succ # queries driver, gets back CUgraphNode for b +assert b2 is b # fails without Level 2 registry +``` diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 904b84c657..a21cd8a8aa 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -388,6 +388,7 @@ ContextHandle get_event_context(const EventHandle& h) noexcept { return h ? get_box(h)->h_context : ContextHandle{}; } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry event_registry; EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags, @@ -894,6 +895,7 @@ static const KernelBox* get_box(const KernelHandle& h) { ); } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry kernel_registry; KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) { @@ -964,6 +966,7 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) { ); } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry graph_node_registry; GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 1474d10430..195acbe765 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -63,6 +63,7 @@ from cuda.core import Device from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return +# See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object) _node_registry = weakref.WeakValueDictionary() From 83706879486672635be805ad9c7662d2c97d54f3 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 15:28:18 -0700 Subject: [PATCH 20/21] Fix import formatting in test_registry_cleanup Made-with: Cursor --- cuda_core/tests/graph/test_graphdef.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index fe760f8994..03d82c737c 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -727,6 +727,7 @@ def test_identity_preservation(init_cuda): def test_registry_cleanup(init_cuda): """Node registry entries are removed on destroy() and graph teardown.""" import gc + from cuda.core._graph._graph_def._graph_node import _node_registry def registered(node): From f779f30b8ce88b8239c49e01099ae9b05a4592e3 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 3 Apr 2026 15:50:09 -0700 Subject: [PATCH 21/21] Optimize GraphDef.nodes() and edges() to try a single driver call Pre-allocate vectors to 128 entries and pass them on the first call. Only fall back to a second call if the graph exceeds 128 nodes/edges. Made-with: Cursor --- .../core/_graph/_graph_def/_graph_def.pyx | 45 +++++++++++-------- cuda_core/tests/graph/test_graphdef.py | 2 +- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index 6c8e01069b..8776b7d49f 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -332,18 +332,20 @@ cdef class GraphDef: set of GraphNode All nodes in the graph. """ - cdef size_t num_nodes = 0 + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(128) + cdef size_t num_nodes = 128 with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) if num_nodes == 0: return set() - cdef vector[cydriver.CUgraphNode] nodes_vec - nodes_vec.resize(num_nodes) - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) + if num_nodes > 128: + nodes_vec.resize(num_nodes) + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) @@ -356,21 +358,12 @@ cdef class GraphDef: Each element is a (from_node, to_node) pair representing a dependency edge in the graph. """ - cdef size_t num_edges = 0 - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges)) - - if num_edges == 0: - return set() - cdef vector[cydriver.CUgraphNode] from_nodes cdef vector[cydriver.CUgraphNode] to_nodes - from_nodes.resize(num_edges) - to_nodes.resize(num_edges) + from_nodes.resize(128) + to_nodes.resize(128) + cdef size_t num_edges = 128 + with nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: HANDLE_RETURN(cydriver.cuGraphGetEdges( @@ -379,6 +372,20 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges( as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) + if num_edges == 0: + return set() + + if num_edges > 128: + from_nodes.resize(num_edges) + to_nodes.resize(num_edges) + with nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + HANDLE_RETURN(cydriver.cuGraphGetEdges( + as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges)) + ELSE: + HANDLE_RETURN(cydriver.cuGraphGetEdges( + as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) + return set( (GraphNode._create(self._h_graph, from_nodes[i]), GraphNode._create(self._h_graph, to_nodes[i])) diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 03d82c737c..c4e34fc02c 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -704,7 +704,7 @@ def test_identity_preservation(init_cuda): objects rather than duplicates.""" g = GraphDef() a = g.empty() - b = a.join() + b = g.empty() # nodes() assert any(x is a for x in g.nodes())