diff --git a/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md new file mode 100644 index 0000000000..cbfc609686 --- /dev/null +++ b/cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md @@ -0,0 +1,51 @@ +# Handle and Object Registries + +When Python-managed objects round-trip through the CUDA driver (e.g., +querying a graph's nodes and getting back raw `CUgraphNode` pointers), +we need to recover the original Python object rather than creating a +duplicate. + +This document describes the approach used to achieve this. The pattern +is driven mainly by needs arising in the context of CUDA graphs, but +it is general and can be extended to other object types as needs arise. + +This solves the same problem as pybind11's `registered_instances` map +and is sometimes called the Identity Map pattern. Two registries work +together to map a raw driver handle all the way back to the original +Python object. Both use weak references so they +do not prevent cleanup. Entries are removed either explicitly (via +`destroy()` or a Box destructor) or implicitly when the weak reference +expires. + +## Level 1: Driver Handle -> Resource Handle (C++) + +`HandleRegistry` in `resource_handles.cpp` maps a raw CUDA handle +(e.g., `CUevent`, `CUkernel`, `CUgraphNode`) to a `weak_ptr` referring to the Box that +owns it. When a `_ref` constructor receives a raw handle, it +checks the registry first. If found, it returns the existing +`shared_ptr`, preserving the Box and its metadata (e.g., `EventBox` +carries timing/IPC flags, `KernelBox` carries the library dependency). + +Without this level, a round-tripped handle would produce a new Box +with default metadata, losing information that was set at creation. + +Instances: `event_registry`, `kernel_registry`, `graph_node_registry`. 
+ +## Level 2: Resource Handle -> Python Object (Cython) + +`_node_registry` in `_graph_node.pyx` is a `WeakValueDictionary` +mapping a resource address (`shared_ptr::get()`) to a Python +`GraphNode` object. When `GraphNode._create` receives a handle from +Level 1, it checks this registry. If found, it returns the existing +Python object. + +Without this level, each driver round-trip would produce a distinct +Python object for the same logical node, resulting in surprising +behavior: + +```python +a, b = g.empty(), g.empty() +a.succ = {b} +b2, = a.succ # queries driver, gets back CUgraphNode for b +assert b2 is b # fails without Level 2 registry +``` diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 904b84c657..a21cd8a8aa 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -388,6 +388,7 @@ ContextHandle get_event_context(const EventHandle& h) noexcept { return h ? get_box(h)->h_context : ContextHandle{}; } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry event_registry; EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags, @@ -894,6 +895,7 @@ static const KernelBox* get_box(const KernelHandle& h) { ); } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry kernel_registry; KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) { @@ -964,6 +966,7 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) { ); } +// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle) static HandleRegistry graph_node_registry; GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { diff --git a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx index 5c5dae1ddd..3f5f419fb6 100644 --- 
a/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx @@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet): def __contains__(self, x): if not isinstance(x, GraphNode): return False - return x in (<_AdjacencySetCore>self._core).query() + return (<_AdjacencySetCore>self._core).contains(x) def __iter__(self): return iter((<_AdjacencySetCore>self._core).query()) @@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet): if isinstance(other, GraphNode): nodes.append(other) else: - nodes.extend(other) + for n in other: + if not isinstance(n, GraphNode): + raise TypeError( + f"expected GraphNode, got {type(n).__name__}") + nodes.append(n) if not nodes: return - for n in nodes: - if not isinstance(n, GraphNode): - raise TypeError( - f"expected GraphNode, got {type(n).__name__}") new = [n for n in nodes if n not in self] if new: (<_AdjacencySetCore>self._core).add_edges(new) @@ -143,11 +143,14 @@ cdef class _AdjacencySetCore: cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) if c_node == NULL: return [] - cdef size_t count = 0 + cdef cydriver.CUgraphNode buf[16] + cdef size_t count = 16 + cdef size_t i with nogil: - HANDLE_RETURN(self._query_fn(c_node, NULL, &count)) - if count == 0: - return [] + HANDLE_RETURN(self._query_fn(c_node, buf, &count)) + if count <= 16: + return [GraphNode._create(self._h_graph, buf[i]) + for i in range(count)] cdef vector[cydriver.CUgraphNode] nodes_vec nodes_vec.resize(count) with nogil: @@ -156,6 +159,35 @@ cdef class _AdjacencySetCore: return [GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(count)] + cdef bint contains(self, GraphNode other): + cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode target = as_cu(other._h_node) + if c_node == NULL or target == NULL: + return False + cdef cydriver.CUgraphNode buf[16] + cdef size_t count = 16 + cdef size_t i + with nogil: + HANDLE_RETURN(self._query_fn(c_node, buf, 
&count)) + + # Fast path for small sets. + if count <= 16: + for i in range(count): + if buf[i] == target: + return True + return False + + # Fallback for large sets. + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(count) + with nogil: + HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count)) + assert count == nodes_vec.size() + for i in range(count): + if nodes_vec[i] == target: + return True + return False + cdef Py_ssize_t count(self): cdef cydriver.CUgraphNode c_node = as_cu(self._h_node) if c_node == NULL: diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx index 03673844d5..8776b7d49f 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx @@ -159,6 +159,16 @@ cdef class GraphDef: """ return self._entry.launch(config, kernel, *args) + def empty(self) -> "EmptyNode": + """Add an entry-point empty node (no dependencies). + + Returns + ------- + EmptyNode + A new EmptyNode with no dependencies. + """ + return self._entry.join() + def join(self, *nodes) -> "EmptyNode": """Create an empty node that depends on all given nodes. @@ -322,18 +332,20 @@ cdef class GraphDef: set of GraphNode All nodes in the graph. 
""" - cdef size_t num_nodes = 0 + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(128) + cdef size_t num_nodes = 128 with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) if num_nodes == 0: return set() - cdef vector[cydriver.CUgraphNode] nodes_vec - nodes_vec.resize(num_nodes) - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) + if num_nodes > 128: + nodes_vec.resize(num_nodes) + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) @@ -346,21 +358,12 @@ cdef class GraphDef: Each element is a (from_node, to_node) pair representing a dependency edge in the graph. """ - cdef size_t num_edges = 0 - - with nogil: - IF CUDA_CORE_BUILD_MAJOR >= 13: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges)) - ELSE: - HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges)) - - if num_edges == 0: - return set() - cdef vector[cydriver.CUgraphNode] from_nodes cdef vector[cydriver.CUgraphNode] to_nodes - from_nodes.resize(num_edges) - to_nodes.resize(num_edges) + from_nodes.resize(128) + to_nodes.resize(128) + cdef size_t num_edges = 128 + with nogil: IF CUDA_CORE_BUILD_MAJOR >= 13: HANDLE_RETURN(cydriver.cuGraphGetEdges( @@ -369,6 +372,20 @@ cdef class GraphDef: HANDLE_RETURN(cydriver.cuGraphGetEdges( as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) + if num_edges == 0: + return set() + + if num_edges > 128: + from_nodes.resize(num_edges) + to_nodes.resize(num_edges) + with nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + HANDLE_RETURN(cydriver.cuGraphGetEdges( + as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges)) + ELSE: + 
HANDLE_RETURN(cydriver.cuGraphGetEdges( + as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges)) + return set( (GraphNode._create(self._h_graph, from_nodes[i]), GraphNode._create(self._h_graph, to_nodes[i])) diff --git a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx index 1474d10430..195acbe765 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx @@ -63,6 +63,7 @@ from cuda.core import Device from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver, handle_return +# See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object) _node_registry = weakref.WeakValueDictionary() diff --git a/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx index 2c78b3b0ac..e1ab3a460a 100644 --- a/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx +++ b/cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx @@ -58,8 +58,7 @@ cdef class EmptyNode(GraphNode): return n def __repr__(self) -> str: - cdef Py_ssize_t n = len(self.pred) - return f"" + return f"" cdef class KernelNode(GraphNode): @@ -108,7 +107,8 @@ cdef class KernelNode(GraphNode): h_kernel) def __repr__(self) -> str: - return (f"") + return (f"") @property def grid(self) -> tuple: @@ -207,7 +207,8 @@ cdef class AllocNode(GraphNode): params.poolProps.location.id, memory_type, tuple(peer_ids)) def __repr__(self) -> str: - return f"" + return (f"") @property def dptr(self) -> int: @@ -273,7 +274,7 @@ cdef class FreeNode(GraphNode): return FreeNode._create_with_params(h_node, dptr) def __repr__(self) -> str: - return f"" + return f"" @property def dptr(self) -> int: @@ -328,8 +329,8 @@ cdef class MemsetNode(GraphNode): params.elementSize, params.width, params.height, params.pitch) def __repr__(self) -> str: - return (f"") + return (f"") 
@property def dptr(self) -> int: @@ -416,8 +417,8 @@ cdef class MemcpyNode(GraphNode): def __repr__(self) -> str: cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D" cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D" - return (f"") + return (f"") @property def dst(self) -> int: @@ -465,12 +466,8 @@ cdef class ChildGraphNode(GraphNode): return ChildGraphNode._create_with_params(h_node, h_child) def __repr__(self) -> str: - cdef cydriver.CUgraph g = as_cu(self._h_child_graph) - cdef size_t num_nodes = 0 - with nogil: - HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes)) - cdef Py_ssize_t n = num_nodes - return f"" + return (f"") @property def child_graph(self) -> "GraphDef": @@ -507,7 +504,8 @@ cdef class EventRecordNode(GraphNode): return EventRecordNode._create_with_params(h_node, h_event) def __repr__(self) -> str: - return f"" + return (f"") @property def event(self) -> Event: @@ -544,7 +542,8 @@ cdef class EventWaitNode(GraphNode): return EventWaitNode._create_with_params(h_node, h_event) def __repr__(self) -> str: - return f"" + return (f"") @property def event(self) -> Event: @@ -591,8 +590,10 @@ cdef class HostCallbackNode(GraphNode): def __repr__(self) -> str: if self._callable is not None: name = getattr(self._callable, '__name__', '?') - return f"" - return f"self._fn:x}>" + return (f"") + return (f"self._fn:x}>") @property def callback_fn(self): @@ -672,7 +673,7 @@ cdef class ConditionalNode(GraphNode): return n def __repr__(self) -> str: - return "" + return f"" @property def condition(self) -> Condition | None: @@ -709,7 +710,8 @@ cdef class IfNode(ConditionalNode): """An if-conditional node (1 branch, executes when condition is non-zero).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def then(self) -> "GraphDef": @@ -721,7 +723,8 @@ cdef class IfElseNode(ConditionalNode): """An if-else conditional node (2 
branches).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def then(self) -> "GraphDef": @@ -738,7 +741,8 @@ cdef class WhileNode(ConditionalNode): """A while-loop conditional node (1 branch, repeats while condition is non-zero).""" def __repr__(self) -> str: - return f"self._condition._c_handle:x}>" + return (f"self._condition._c_handle:x}>") @property def body(self) -> "GraphDef": @@ -750,6 +754,5 @@ cdef class SwitchNode(ConditionalNode): """A switch conditional node (N branches, selected by condition value).""" def __repr__(self) -> str: - cdef Py_ssize_t n = len(self._branches) - return (f"self._condition._c_handle:x}" - f" with {n} {'branch' if n == 1 else 'branches'}>") + return (f"self._condition._c_handle:x}>") diff --git a/cuda_core/tests/graph/test_graphdef.py b/cuda_core/tests/graph/test_graphdef.py index 562f720ca8..c4e34fc02c 100644 --- a/cuda_core/tests/graph/test_graphdef.py +++ b/cuda_core/tests/graph/test_graphdef.py @@ -703,8 +703,8 @@ def test_identity_preservation(init_cuda): """Round-trips through nodes(), edges(), and pred/succ return extant objects rather than duplicates.""" g = GraphDef() - a = g.join() - b = a.join() + a = g.empty() + b = g.empty() # nodes() assert any(x is a for x in g.nodes()) @@ -724,6 +724,50 @@ def test_identity_preservation(init_cuda): assert b2 is b +def test_registry_cleanup(init_cuda): + """Node registry entries are removed on destroy() and graph teardown.""" + import gc + + from cuda.core._graph._graph_def._graph_node import _node_registry + + def registered(node): + return any(v is node for v in _node_registry.values()) + + gc.collect() + assert len(_node_registry) == 0 + + g = GraphDef() + a = g.empty() + b = g.empty() + c = g.empty() + + assert len(_node_registry) == 3 + assert registered(a) + assert registered(b) + assert registered(c) + + a.destroy() + assert len(_node_registry) == 2 + assert not registered(a) + assert 
registered(b) + assert registered(c) + + del g + gc.collect() + assert len(_node_registry) == 2 + assert registered(b) + assert registered(c) + + b.destroy() + assert len(_node_registry) == 1 + assert not registered(b) + assert registered(c) + + del c + gc.collect() + assert len(_node_registry) == 0 + + # ============================================================================= # GraphDef basics # ============================================================================= diff --git a/cuda_core/tests/graph/test_graphdef_errors.py b/cuda_core/tests/graph/test_graphdef_errors.py index 9c6a870562..596f83bffe 100644 --- a/cuda_core/tests/graph/test_graphdef_errors.py +++ b/cuda_core/tests/graph/test_graphdef_errors.py @@ -101,10 +101,10 @@ def test_condition_from_different_graph(init_cuda): # ============================================================================= -def test_join_no_extra_nodes(init_cuda): - """join() from entry with no extra nodes creates a single empty node.""" +def test_empty_node(init_cuda): + """empty() creates a single entry-point empty node.""" g = GraphDef() - joined = g.join() + joined = g.empty() assert isinstance(joined, EmptyNode) assert len(g.nodes()) == 1 diff --git a/cuda_core/tests/graph/test_graphdef_mutation.py b/cuda_core/tests/graph/test_graphdef_mutation.py index ac0d8f5e61..c0fc1bb242 100644 --- a/cuda_core/tests/graph/test_graphdef_mutation.py +++ b/cuda_core/tests/graph/test_graphdef_mutation.py @@ -211,16 +211,16 @@ def test_insert_b(self, init_cuda): def test_adjacency_set_interface(init_cuda): """Exercise every MutableSet method on AdjacencySetProxy.""" g = GraphDef() - hub = g.join() - items = [g.join() for _ in range(5)] + hub = g.empty() + items = [g.empty() for _ in range(5)] assert_mutable_set_interface(hub.succ, items) def test_adjacency_set_pred_direction(init_cuda): """Verify that pred works symmetrically with succ.""" g = GraphDef() - target = g.join() - x, y, z = (g.join() for _ in range(3)) + target = 
g.empty() + x, y, z = (g.empty() for _ in range(3)) pred = target.pred assert pred == set() @@ -242,8 +242,8 @@ def test_adjacency_set_pred_direction(init_cuda): def test_adjacency_set_property_setter(init_cuda): """Verify that assigning to node.pred or node.succ replaces all edges.""" g = GraphDef() - hub = g.join() - a, b, c = (g.join() for _ in range(3)) + hub = g.empty() + a, b, c = (g.empty() for _ in range(3)) hub.succ = {a, b} assert hub.succ == {a, b} @@ -310,7 +310,7 @@ def test_destroyed_node(init_cuda): def test_add_wrong_type(init_cuda): """Adding a non-GraphNode raises TypeError.""" g = GraphDef() - node = g.join() + node = g.empty() with pytest.raises(TypeError, match="expected GraphNode"): node.succ.add("not a node") with pytest.raises(TypeError, match="expected GraphNode"): @@ -321,8 +321,8 @@ def test_cross_graph_edge(init_cuda): """Adding an edge to a node from a different graph raises CUDAError.""" g1 = GraphDef() g2 = GraphDef() - a = g1.join() - b = g2.join() + a = g1.empty() + b = g2.empty() with pytest.raises(CUDAError): a.succ.add(b) @@ -330,7 +330,7 @@ def test_cross_graph_edge(init_cuda): def test_self_edge(init_cuda): """Adding a self-edge raises CUDAError.""" g = GraphDef() - node = g.join() + node = g.empty() with pytest.raises(CUDAError): node.succ.add(node) diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index ef4f1337d1..3a523a8943 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -685,20 +685,20 @@ def sample_switch_node_alt(sample_graphdef): ("sample_graphdef", r""), ("sample_condition", r""), ("sample_root_node", r""), - ("sample_empty_node", r""), - ("sample_alloc_node", r""), - ("sample_kernel_node", r""), - ("sample_free_node", r""), - ("sample_memset_node", r""), - ("sample_memcpy_node", r""), - ("sample_child_graph_node", r""), - ("sample_event_record_node", r""), - ("sample_event_wait_node", r""), - 
("sample_host_callback_node", r""), - ("sample_if_node", r""), - ("sample_if_else_node", r""), - ("sample_while_node", r""), - ("sample_switch_node", r""), + ("sample_empty_node", r""), + ("sample_alloc_node", r""), + ("sample_kernel_node", r""), + ("sample_free_node", r""), + ("sample_memset_node", r""), + ("sample_memcpy_node", r""), + ("sample_child_graph_node", r""), + ("sample_event_record_node", r""), + ("sample_event_wait_node", r""), + ("sample_host_callback_node", r""), + ("sample_if_node", r""), + ("sample_if_else_node", r""), + ("sample_while_node", r""), + ("sample_switch_node", r""), ]