From 3372b05dc9f0ea6308fe73f7aa1c6233a00fd9db Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 26 May 2026 22:52:30 +0200 Subject: [PATCH] perf: cache lexicographic chunk coords in sharding codec The subchunk_write_order feature (#3826) regressed sharded write performance: _encode_partial_single rebuilt the full per-shard chunk coordinate grid on every write via `np.array(list(_subchunk_order_iter(..., "lexicographic")))`, and `to_dict_vectorized` rebuilt a tuple key per row with `tuple(coords.ravel())`. For a single-chunk write into a shard with tens of thousands of chunks this roughly doubled write time (~22ms -> ~40ms on test_sharded_morton_write_single_chunk, matching the -44% CodSpeed regression). Add cached `_lexicographic_order` (array) and `_lexicographic_order_keys` (tuples) helpers in indexing.py, mirroring `_morton_order`/`_morton_order_keys`, and pass the cached keys into `to_dict_vectorized` instead of deriving them row-by-row. This restores write throughput to the pre-#3826 baseline while preserving identical chunk ordering (verified equal to np.ndindex across shapes including 0-d and empty). Co-Authored-By: Claude Opus 4.7 (1M context) --- changes/4001.misc.md | 4 ++++ src/zarr/codecs/sharding.py | 24 ++++++++++++++++++------ src/zarr/core/indexing.py | 27 +++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 changes/4001.misc.md diff --git a/changes/4001.misc.md b/changes/4001.misc.md new file mode 100644 index 0000000000..e90f16a9e8 --- /dev/null +++ b/changes/4001.misc.md @@ -0,0 +1,4 @@ +Restore sharding write performance for shards with many chunks. The +`subchunk_write_order` feature inadvertently rebuilt the per-shard chunk +coordinate grid (up to tens of thousands of tuples) on every partial write; +these coordinates are now cached, restoring throughput to its previous level. 
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 33c8602ecb..535c682053 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -46,8 +46,11 @@ BasicIndexer, ChunkProjection, SelectorTuple, + _lexicographic_order, + _lexicographic_order_keys, c_order_iter, get_indexer, + lexicographic_order_iter, morton_order_iter, ) from zarr.core.metadata.v3 import ( @@ -266,6 +269,7 @@ def __iter__(self) -> Iterator[tuple[int, ...]]: def to_dict_vectorized( self, chunk_coords_array: npt.NDArray[np.integer[Any]], + chunk_coords_keys: tuple[tuple[int, ...], ...], ) -> dict[tuple[int, ...], Buffer | None]: """Build a dict of chunk coordinates to buffers using vectorized lookup. @@ -273,6 +277,11 @@ def to_dict_vectorized( ---------- chunk_coords_array : ndarray of shape (n_chunks, n_dims) Array of chunk coordinates for vectorized index lookup. + chunk_coords_keys : tuple of coordinate tuples + The same coordinates as `chunk_coords_array`, in the same order, as + plain tuples for use as dict keys. Passed in (rather than derived + row-by-row from the array) so the cached value can be reused instead + of rebuilding a tuple for every chunk in the shard on every write. 
Returns ------- @@ -281,11 +290,11 @@ def to_dict_vectorized( starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array) result: dict[tuple[int, ...], Buffer | None] = {} - for i, coords in enumerate(chunk_coords_array): + for i, coords in enumerate(chunk_coords_keys): if valid[i]: - result[tuple(coords.ravel())] = self.buf[int(starts[i]) : int(ends[i])] + result[coords] = self.buf[int(starts[i]) : int(ends[i])] else: - result[tuple(coords.ravel())] = None + result[coords] = None return result @@ -533,7 +542,7 @@ def _subchunk_order_iter( case "morton": subchunk_iter = morton_order_iter(chunks_per_shard) case "lexicographic": - subchunk_iter = np.ndindex(chunks_per_shard) + subchunk_iter = lexicographic_order_iter(chunks_per_shard) case "colexicographic": subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1])) case "unordered": @@ -612,9 +621,12 @@ async def _encode_partial_single( chunks_per_shard=chunks_per_shard, ) shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) - # Use vectorized lookup for better performance + # Use vectorized lookup for better performance. The lexicographic + # coordinate array and keys are cached, so neither is rebuilt on + # every write. 
shard_dict = shard_reader.to_dict_vectorized( - np.array(list(self._subchunk_order_iter(chunks_per_shard, "lexicographic"))) + _lexicographic_order(chunks_per_shard), + _lexicographic_order_keys(chunks_per_shard), ) await self.codec_pipeline.write( diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index cb81164209..ab658a4924 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1584,6 +1584,33 @@ def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]] return iter(_morton_order_keys(tuple(chunk_shape))) +@lru_cache(maxsize=16) +def _lexicographic_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: + # Lexicographic (C-order) coordinates, computed vectorized and cached so that + # the sharding codec's per-shard chunk grid is not rebuilt on every call. + # Equivalent to `np.array(list(np.ndindex(chunk_shape)))` but without the + # Python-level iteration over every coordinate. + n_dims = len(chunk_shape) + if n_dims == 0: + # A 0-d shard holds a single chunk addressed by the empty coordinate, so + # the coordinate array has one row and zero columns. np.indices(()) cannot + # express this, so build it directly. Matches list(np.ndindex(())) == [()]. + order = np.empty((1, 0), dtype=np.intp) + else: + order = np.indices(chunk_shape, dtype=np.intp).reshape(n_dims, -1).T + order.flags.writeable = False + return order + + +@lru_cache(maxsize=16) +def _lexicographic_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + return tuple(tuple(int(x) for x in row) for row in _lexicographic_order(chunk_shape)) + + +def lexicographic_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + return iter(_lexicographic_order_keys(tuple(chunk_shape))) + + def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: return itertools.product(*(range(x) for x in chunks_per_shard))