Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
dc92437
Reorganize graph test files for clarity
Andy-Jost Mar 31, 2026
281ed82
Enhance Graph.update() and add whole-graph update tests
Andy-Jost Mar 31, 2026
7854b76
Add AdjacencySet proxy for pred/succ and GraphNode.remove()
Andy-Jost Mar 31, 2026
5fbd288
Add edge mutation support and MutableSet interface for GraphNode adja…
Andy-Jost Apr 2, 2026
aa84e26
Use requires_module mark for numpy version checks in mutation tests
Andy-Jost Apr 2, 2026
b27dd93
Fix empty-graph return type: return set() instead of () for nodes/edges
Andy-Jost Apr 2, 2026
8554d30
Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards
Andy-Jost Apr 2, 2026
9813c20
Add destroy() method with handle invalidation, remove GRAPH_NODE_SENT…
Andy-Jost Apr 2, 2026
6411881
Add GraphNode identity cache for stable Python object round-trips
Andy-Jost Apr 2, 2026
7a3dbb4
Purge node cache on destroy to prevent stale identity lookups
Andy-Jost Apr 2, 2026
91b3b4e
Skip NULL nodes in graph_node_registry to fix sentinel identity colli…
Andy-Jost Apr 2, 2026
1b7743d
Unregister destroyed nodes from C++ graph_node_registry
Andy-Jost Apr 3, 2026
84f0b30
Add dedicated test for node identity preservation through round-trips
Andy-Jost Apr 3, 2026
64d6c2d
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
6b36e47
Add handle= to all GraphNode subclass __repr__ for debugging
Andy-Jost Apr 3, 2026
a40be9a
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
729af49
Rename _node_cache/_cached to _node_registry/_registered
Andy-Jost Apr 3, 2026
42131b6
Fix unregister_handle and rename invalidate_graph_node_handle
Andy-Jost Apr 3, 2026
9766e54
Merge branch 'graph-node-identity' into graph-node-repr
Andy-Jost Apr 3, 2026
15d0036
Add cheap containment test and early type check for AdjacencySetProxy
Andy-Jost Apr 3, 2026
347693f
Add GraphDef.empty(), stack-buffer query optimization, and registry test
Andy-Jost Apr 3, 2026
641a089
Document the two-level handle and object registry design
Andy-Jost Apr 3, 2026
8370687
Fix import formatting in test_registry_cleanup
Andy-Jost Apr 3, 2026
36527da
Merge origin/main into graph-node-repr
Andy-Jost Apr 3, 2026
f779f30
Optimize GraphDef.nodes() and edges() to try a single driver call
Andy-Jost Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Handle and Object Registries

When Python-managed objects round-trip through the CUDA driver (e.g.,
querying a graph's nodes and getting back raw `CUgraphNode` pointers),
we need to recover the original Python object rather than creating a
duplicate.

This document describes the approach used to achieve this. The pattern
is driven mainly by needs arising in the context of CUDA graphs, but
it is general and can be extended to other object types as needs arise.

This solves the same problem as pybind11's `registered_instances` map
and is sometimes called the Identity Map pattern. Two registries work
together to map a raw driver handle all the way back to the original
Python object. Both use weak references so they
do not prevent cleanup. Entries are removed either explicitly (via
`destroy()` or a Box destructor) or implicitly when the weak reference
expires.

## Level 1: Driver Handle -> Resource Handle (C++)

`HandleRegistry` in `resource_handles.cpp` maps a raw CUDA handle
(e.g., `CUevent`, `CUkernel`, `CUgraphNode`) to the `weak_ptr` that
owns it. When a `_ref` constructor receives a raw handle, it
checks the registry first. If found, it returns the existing
`shared_ptr`, preserving the Box and its metadata (e.g., `EventBox`
carries timing/IPC flags, `KernelBox` carries the library dependency).

Without this level, a round-tripped handle would produce a new Box
with default metadata, losing information that was set at creation.

Instances: `event_registry`, `kernel_registry`, `graph_node_registry`.

## Level 2: Resource Handle -> Python Object (Cython)

`_node_registry` in `_graph_node.pyx` is a `WeakValueDictionary`
mapping a resource address (`shared_ptr::get()`) to a Python
`GraphNode` object. When `GraphNode._create` receives a handle from
Level 1, it checks this registry. If found, it returns the existing
Python object.

Without this level, each driver round-trip would produce a distinct
Python object for the same logical node, resulting in surprising
behavior:

```python
a = g.empty()
a.succ = {b}
b2, = a.succ # queries driver, gets back CUgraphNode for b
assert b2 is b # fails without Level 2 registry
```
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ ContextHandle get_event_context(const EventHandle& h) noexcept {
return h ? get_box(h)->h_context : ContextHandle{};
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUevent, EventHandle> event_registry;

EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags,
Expand Down Expand Up @@ -894,6 +895,7 @@ static const KernelBox* get_box(const KernelHandle& h) {
);
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUkernel, KernelHandle> kernel_registry;

KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) {
Expand Down Expand Up @@ -964,6 +966,7 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
);
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;

GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
Expand Down
52 changes: 42 additions & 10 deletions cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet):
def __contains__(self, x):
if not isinstance(x, GraphNode):
return False
return x in (<_AdjacencySetCore>self._core).query()
return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)

def __iter__(self):
return iter((<_AdjacencySetCore>self._core).query())
Expand Down Expand Up @@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet):
if isinstance(other, GraphNode):
nodes.append(other)
else:
nodes.extend(other)
for n in other:
if not isinstance(n, GraphNode):
raise TypeError(
f"expected GraphNode, got {type(n).__name__}")
nodes.append(n)
if not nodes:
return
for n in nodes:
if not isinstance(n, GraphNode):
raise TypeError(
f"expected GraphNode, got {type(n).__name__}")
new = [n for n in nodes if n not in self]
if new:
(<_AdjacencySetCore>self._core).add_edges(new)
Expand Down Expand Up @@ -143,11 +143,14 @@ cdef class _AdjacencySetCore:
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
if c_node == NULL:
return []
cdef size_t count = 0
cdef cydriver.CUgraphNode buf[16]
cdef size_t count = 16
cdef size_t i
with nogil:
HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
if count == 0:
return []
HANDLE_RETURN(self._query_fn(c_node, buf, &count))
if count <= 16:
return [GraphNode._create(self._h_graph, buf[i])
for i in range(count)]
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(count)
with nogil:
Expand All @@ -156,6 +159,35 @@ cdef class _AdjacencySetCore:
return [GraphNode._create(self._h_graph, nodes_vec[i])
for i in range(count)]

cdef bint contains(self, GraphNode other):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
cdef cydriver.CUgraphNode target = as_cu(other._h_node)
if c_node == NULL or target == NULL:
return False
cdef cydriver.CUgraphNode buf[16]
cdef size_t count = 16
cdef size_t i
with nogil:
HANDLE_RETURN(self._query_fn(c_node, buf, &count))

# Fast path for small sets.
if count <= 16:
for i in range(count):
if buf[i] == target:
return True
return False

# Fallback for large sets.
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(count)
with nogil:
HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
assert count == nodes_vec.size()
for i in range(count):
if nodes_vec[i] == target:
return True
return False

cdef Py_ssize_t count(self):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
if c_node == NULL:
Expand Down
55 changes: 36 additions & 19 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ cdef class GraphDef:
"""
return self._entry.launch(config, kernel, *args)

def empty(self) -> "EmptyNode":
"""Add an entry-point empty node (no dependencies).

Returns
-------
EmptyNode
A new EmptyNode with no dependencies.
"""
return self._entry.join()

def join(self, *nodes) -> "EmptyNode":
"""Create an empty node that depends on all given nodes.

Expand Down Expand Up @@ -322,18 +332,20 @@ cdef class GraphDef:
set of GraphNode
All nodes in the graph.
"""
cdef size_t num_nodes = 0
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(128)
cdef size_t num_nodes = 128

with nogil:
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))

if num_nodes == 0:
return set()

cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(num_nodes)
with nogil:
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
if num_nodes > 128:
nodes_vec.resize(num_nodes)
with nogil:
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))

return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))

Expand All @@ -346,21 +358,12 @@ cdef class GraphDef:
Each element is a (from_node, to_node) pair representing
a dependency edge in the graph.
"""
cdef size_t num_edges = 0

with nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges))
ELSE:
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))

if num_edges == 0:
return set()

cdef vector[cydriver.CUgraphNode] from_nodes
cdef vector[cydriver.CUgraphNode] to_nodes
from_nodes.resize(num_edges)
to_nodes.resize(num_edges)
from_nodes.resize(128)
to_nodes.resize(128)
cdef size_t num_edges = 128

with nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
HANDLE_RETURN(cydriver.cuGraphGetEdges(
Expand All @@ -369,6 +372,20 @@ cdef class GraphDef:
HANDLE_RETURN(cydriver.cuGraphGetEdges(
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))

if num_edges == 0:
return set()

if num_edges > 128:
from_nodes.resize(num_edges)
to_nodes.resize(num_edges)
with nogil:
IF CUDA_CORE_BUILD_MAJOR >= 13:
HANDLE_RETURN(cydriver.cuGraphGetEdges(
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges))
ELSE:
HANDLE_RETURN(cydriver.cuGraphGetEdges(
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))

return set(
(GraphNode._create(self._h_graph, from_nodes[i]),
GraphNode._create(self._h_graph, to_nodes[i]))
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ from cuda.core import Device
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
from cuda.core._utils.cuda_utils import driver, handle_return

# See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object)
_node_registry = weakref.WeakValueDictionary()


Expand Down
Loading
Loading