diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index fbba9a8c..b1084fb1 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -27,6 +27,7 @@ from xarray import DataArray, Dataset from xarray.core.groupby import GroupBy from xarray.core.resample import Resample +from xarray.core.utils import Frozen try: from xarray.core.rolling import ( # type:ignore[import-not-found,no-redef,unused-ignore] @@ -495,7 +496,10 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]: return list(results) -def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hashable]]: +@functools.lru_cache(maxsize=256) +def _parse_grid_mapping_attribute( + grid_mapping_attr: str, +) -> Mapping[str, list[Hashable]]: """ Parse a grid_mapping attribute that may contain multiple grid mappings. @@ -507,11 +511,12 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash - Multiple: "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700" -> {"spatial_ref": [], "crs_4326": ["latitude", "longitude"], "crs_27700": ["x27700", "y27700"]} - Returns a dictionary mapping grid mapping variable names to their associated coordinate variables. + Returns a read-only mapping from grid mapping variable name to its associated + coordinate variables. The result is memoized, so callers must not mutate it. """ # Check if there are colons indicating multiple mappings if ":" not in grid_mapping_attr: - return {grid_mapping_attr.strip(): []} + return Frozen({grid_mapping_attr.strip(): []}) # Use regex to parse the format # First, find all grid mapping variables (words before colons) @@ -519,7 +524,7 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash grid_mappings = re.findall(grid_pattern, grid_mapping_attr) if not grid_mappings: - return {grid_mapping_attr.strip(): []} + return Frozen({grid_mapping_attr.strip(): []}) result: dict[str, list[Hashable]] = {} @@ -548,13 +553,13 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash else: result[gm] = [] - return result + return Frozen(result) def _create_grid_mapping( var_name: str, ds: Dataset, - grid_mapping_dict: dict[str, list[Hashable]], + grid_mapping_dict: Mapping[str, list[Hashable]], ) -> GridMapping: """ Create a GridMapping dataclass instance from a grid mapping variable. @@ -1007,7 +1012,11 @@ def _getattr( newmap.update(dict.fromkeys(inverted[key], value)) newmap.update({key: attribute[key] for key in unused_keys}) - skip: dict[str, list[Literal["coords", "measures"]] | None] = { + skip: dict[ + str, + list[Literal["coords", "measures", "grid_mapping_names", "geometries"]] + | None, + ] = { "data_vars": ["coords"], "coords": None, } @@ -1048,7 +1057,8 @@ def wrapper(*args, **kwargs): def _getitem( accessor: CFAccessor, key: Hashable, - skip: list[Literal["coords", "measures"]] | None = None, + skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]] + | None = None, ) -> DataArray: ... @@ -1056,14 +1066,16 @@ def _getitem( def _getitem( accessor: CFAccessor, key: Iterable[Hashable], - skip: list[Literal["coords", "measures"]] | None = None, + skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]] + | None = None, ) -> Dataset: ... def _getitem( accessor: CFAccessor, key: Hashable | Iterable[Hashable], - skip: list[Literal["coords", "measures"]] | None = None, + skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]] + | None = None, ): """ Index into obj using key. Attaches CF associated variables. @@ -1077,9 +1089,14 @@ def _getitem( """ obj = accessor._obj - all_bounds = obj.cf.bounds if isinstance(obj, Dataset) else {} kind = str(type(obj).__name__) scalar_key = isinstance(key, Hashable) + # obj.cf.bounds is expensive; only compute it when scalar lookup on a + # Dataset actually needs to drop bounds variables. + if not isinstance(obj, DataArray) and scalar_key: + all_bounds = obj.cf.bounds + else: + all_bounds = {} key_iter: Iterable[Hashable] if isinstance(key, Hashable): # using scalar_key breaks mypy type narrowing @@ -1127,60 +1144,90 @@ def check_results(names, key): custom_criteria = ChainMap(*OPTIONS["custom_criteria"]) - varnames: list[Hashable] = [] - coords: list[Hashable] = [] - successful = dict.fromkeys(key_iter, False) - for k in key_iter: - if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES: - names = _get_all(obj, k) - names = drop_bounds(names) - check_results(names, k) - successful[k] = bool(names) - coords.extend(names) - elif "measures" not in skip and k in measures: - measure = _get_all(obj, k) - check_results(measure, k) - successful[k] = bool(measure) - if measure: - varnames.extend(measure) - elif "grid_mapping_names" not in skip and k in grid_mapping_names: - grid_mapping = _get_all(obj, k) - check_results(grid_mapping, k) - successful[k] = bool(grid_mapping) - if grid_mapping: - varnames.extend(grid_mapping) - elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES): - geometries = _get_all(obj, k) - if geometries and k in _GEOMETRY_TYPES: - new = itertools.chain( - _parse_related_geometry_vars( - ChainMap(obj[g].attrs, obj[g].encoding) + # Fast path: when every key is just a plain variable in the Dataset (no CF + # special meaning), skip the per-key classification loop below. Test the + # cheap predicates first and only build the reserved-name set / consult + # accessor.standard_names (a full attrs scan) if those could pass — the + # common ds.cf["X"] / ds.cf["longitude"] paths must not pay that cost. + fast_path = ( + isinstance(obj, Dataset) + and not skip + and all(k in obj._variables for k in key_iter) + ) + if fast_path: + reserved: set[Hashable] = set(_AXIS_NAMES).union( + _COORD_NAMES, + _GEOMETRY_TYPES, + ("geometry",), + measures, + grid_mapping_names, + custom_criteria, + cf_role_criteria, + ) + standard_names = accessor.standard_names + fast_path = all(k not in reserved and k not in standard_names for k in key_iter) + + varnames: list[Hashable] + coords: list[Hashable] + if fast_path: + varnames = list(key_iter) + coords = [] + successful = dict.fromkeys(key_iter, True) + else: + varnames = [] + coords = [] + successful = dict.fromkeys(key_iter, False) + for k in key_iter: + if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES: + names = _get_all(obj, k) + names = drop_bounds(names) + check_results(names, k) + successful[k] = bool(names) + coords.extend(names) + elif "measures" not in skip and k in measures: + measure = _get_all(obj, k) + check_results(measure, k) + successful[k] = bool(measure) + if measure: + varnames.extend(measure) + elif "grid_mapping_names" not in skip and k in grid_mapping_names: + grid_mapping = _get_all(obj, k) + check_results(grid_mapping, k) + successful[k] = bool(grid_mapping) + if grid_mapping: + varnames.extend(grid_mapping) + elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES): + geometries = _get_all(obj, k) + if geometries and k in _GEOMETRY_TYPES: + new = itertools.chain( + _parse_related_geometry_vars( + ChainMap(obj[g].attrs, obj[g].encoding) + ) + for g in geometries ) - for g in geometries - ) - geometries.extend(*new) - if len(geometries) > 1 and scalar_key: - raise ValueError( - f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead." - ) - successful[k] = bool(geometries) - if geometries: - varnames.extend(geometries) - elif k in custom_criteria or k in cf_role_criteria: - names = _get_all(obj, k) - check_results(names, k) - successful[k] = bool(names) - varnames.extend(names) - else: - stdnames = set(_get_with_standard_name(obj, k)) - objcoords = set(obj.coords) - stdnames = drop_bounds(stdnames) - if "coords" in skip: - stdnames -= objcoords - check_results(stdnames, k) - successful[k] = bool(stdnames) - varnames.extend(stdnames - objcoords) - coords.extend(stdnames & objcoords) + geometries.extend(*new) + if len(geometries) > 1 and scalar_key: + raise ValueError( + f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead." + ) + successful[k] = bool(geometries) + if geometries: + varnames.extend(geometries) + elif k in custom_criteria or k in cf_role_criteria: + names = _get_all(obj, k) + check_results(names, k) + successful[k] = bool(names) + varnames.extend(names) + else: + stdnames = set(_get_with_standard_name(obj, k)) + objcoords = set(obj.coords) + stdnames = drop_bounds(stdnames) + if "coords" in skip: + stdnames -= objcoords + check_results(stdnames, k) + successful[k] = bool(stdnames) + varnames.extend(stdnames - objcoords) + coords.extend(stdnames & objcoords) # these are not special names but could be variable names in underlying object # we allow this so that we can return variables with appropriate CF auxiliary variables @@ -2042,7 +2089,7 @@ def cell_measures(self) -> dict[str, list[Hashable]]: ] as_dataset = self._maybe_to_dataset().reset_coords() - keys = {} + keys: dict[str, str] = {} for attr in set(all_attrs): try: keys.update(parse_cell_methods_attr(attr)) diff --git a/cf_xarray/utils.py b/cf_xarray/utils.py index bdc2605e..a9e6c1ba 100644 --- a/cf_xarray/utils.py +++ b/cf_xarray/utils.py @@ -1,13 +1,15 @@ +import functools import inspect import os import warnings from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import Any from xml.etree import ElementTree import numpy as np from xarray import DataArray +from xarray.core.utils import Frozen try: import cftime @@ -64,7 +66,8 @@ def _is_datetime_like(da: DataArray) -> bool: return False -def parse_cell_methods_attr(attr: str) -> dict[str, str]: +@functools.lru_cache(maxsize=256) +def parse_cell_methods_attr(attr: str) -> Mapping[str, str]: """ Parse cell_methods attributes (format is 'measure: name'). @@ -75,14 +78,14 @@ def parse_cell_methods_attr(attr: str) -> dict[str, str]: Returns ------- - Dictionary mapping measure to name + Read-only mapping from measure to name. """ strings = [s for scolons in attr.split(":") for s in scolons.split()] if len(strings) % 2 != 0: raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.") - return dict( - zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)], strict=False) + return Frozen( + dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)], strict=False)) )