Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 112 additions & 65 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from xarray import DataArray, Dataset
from xarray.core.groupby import GroupBy
from xarray.core.resample import Resample
from xarray.core.utils import Frozen

try:
from xarray.core.rolling import ( # type:ignore[import-not-found,no-redef,unused-ignore]
Expand Down Expand Up @@ -495,7 +496,10 @@ def _get_bounds(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
return list(results)


def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hashable]]:
@functools.lru_cache(maxsize=256)
def _parse_grid_mapping_attribute(
grid_mapping_attr: str,
) -> Mapping[str, list[Hashable]]:
"""
Parse a grid_mapping attribute that may contain multiple grid mappings.

Expand All @@ -507,19 +511,20 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash
- Multiple: "spatial_ref: crs_4326: latitude longitude crs_27700: x27700 y27700"
-> {"spatial_ref": [], "crs_4326": ["latitude", "longitude"], "crs_27700": ["x27700", "y27700"]}

Returns a dictionary mapping grid mapping variable names to their associated coordinate variables.
Returns a read-only mapping from grid mapping variable name to its associated
coordinate variables. The result is memoized, so callers must not mutate it.
"""
# Check if there are colons indicating multiple mappings
if ":" not in grid_mapping_attr:
return {grid_mapping_attr.strip(): []}
return Frozen({grid_mapping_attr.strip(): []})

# Use regex to parse the format
# First, find all grid mapping variables (words before colons)
grid_pattern = r"(?:^|\s)([a-zA-Z_][a-zA-Z0-9_]*)(?=\s*:)"
grid_mappings = re.findall(grid_pattern, grid_mapping_attr)

if not grid_mappings:
return {grid_mapping_attr.strip(): []}
return Frozen({grid_mapping_attr.strip(): []})

result: dict[str, list[Hashable]] = {}

Expand Down Expand Up @@ -548,13 +553,13 @@ def _parse_grid_mapping_attribute(grid_mapping_attr: str) -> dict[str, list[Hash
else:
result[gm] = []

return result
return Frozen(result)


def _create_grid_mapping(
var_name: str,
ds: Dataset,
grid_mapping_dict: dict[str, list[Hashable]],
grid_mapping_dict: Mapping[str, list[Hashable]],
) -> GridMapping:
"""
Create a GridMapping dataclass instance from a grid mapping variable.
Expand Down Expand Up @@ -1007,7 +1012,11 @@ def _getattr(
newmap.update(dict.fromkeys(inverted[key], value))
newmap.update({key: attribute[key] for key in unused_keys})

skip: dict[str, list[Literal["coords", "measures"]] | None] = {
skip: dict[
str,
list[Literal["coords", "measures", "grid_mapping_names", "geometries"]]
| None,
] = {
"data_vars": ["coords"],
"coords": None,
}
Expand Down Expand Up @@ -1048,22 +1057,25 @@ def wrapper(*args, **kwargs):
def _getitem(
accessor: CFAccessor,
key: Hashable,
skip: list[Literal["coords", "measures"]] | None = None,
skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]]
| None = None,
) -> DataArray: ...


@overload
def _getitem(
accessor: CFAccessor,
key: Iterable[Hashable],
skip: list[Literal["coords", "measures"]] | None = None,
skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]]
| None = None,
) -> Dataset: ...


def _getitem(
accessor: CFAccessor,
key: Hashable | Iterable[Hashable],
skip: list[Literal["coords", "measures"]] | None = None,
skip: list[Literal["coords", "measures", "grid_mapping_names", "geometries"]]
| None = None,
):
"""
Index into obj using key. Attaches CF associated variables.
Expand All @@ -1077,9 +1089,14 @@ def _getitem(
"""

obj = accessor._obj
all_bounds = obj.cf.bounds if isinstance(obj, Dataset) else {}
kind = str(type(obj).__name__)
scalar_key = isinstance(key, Hashable)
# obj.cf.bounds is expensive; only compute it when scalar lookup on a
# Dataset actually needs to drop bounds variables.
if not isinstance(obj, DataArray) and scalar_key:
all_bounds = obj.cf.bounds
else:
all_bounds = {}

key_iter: Iterable[Hashable]
if isinstance(key, Hashable): # using scalar_key breaks mypy type narrowing
Expand Down Expand Up @@ -1127,60 +1144,90 @@ def check_results(names, key):

custom_criteria = ChainMap(*OPTIONS["custom_criteria"])

varnames: list[Hashable] = []
coords: list[Hashable] = []
successful = dict.fromkeys(key_iter, False)
for k in key_iter:
if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
names = _get_all(obj, k)
names = drop_bounds(names)
check_results(names, k)
successful[k] = bool(names)
coords.extend(names)
elif "measures" not in skip and k in measures:
measure = _get_all(obj, k)
check_results(measure, k)
successful[k] = bool(measure)
if measure:
varnames.extend(measure)
elif "grid_mapping_names" not in skip and k in grid_mapping_names:
grid_mapping = _get_all(obj, k)
check_results(grid_mapping, k)
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
# Fast path: when every key is just a plain variable in the Dataset (no CF
# special meaning), skip the per-key classification loop below. Test the
# cheap predicates first and only build the reserved-name set / consult
# accessor.standard_names (a full attrs scan) if those could pass — the
# common ds.cf["X"] / ds.cf["longitude"] paths must not pay that cost.
fast_path = (
isinstance(obj, Dataset)
and not skip
and all(k in obj._variables for k in key_iter)
)
if fast_path:
reserved: set[Hashable] = set(_AXIS_NAMES).union(
_COORD_NAMES,
_GEOMETRY_TYPES,
("geometry",),
measures,
grid_mapping_names,
custom_criteria,
cf_role_criteria,
)
standard_names = accessor.standard_names
fast_path = all(k not in reserved and k not in standard_names for k in key_iter)

varnames: list[Hashable]
coords: list[Hashable]
if fast_path:
varnames = list(key_iter)
coords = []
successful = dict.fromkeys(key_iter, True)
else:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop below is unchanged; it is simply indented under the new `else:` branch.

varnames = []
coords = []
successful = dict.fromkeys(key_iter, False)
for k in key_iter:
if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
names = _get_all(obj, k)
names = drop_bounds(names)
check_results(names, k)
successful[k] = bool(names)
coords.extend(names)
elif "measures" not in skip and k in measures:
measure = _get_all(obj, k)
check_results(measure, k)
successful[k] = bool(measure)
if measure:
varnames.extend(measure)
elif "grid_mapping_names" not in skip and k in grid_mapping_names:
grid_mapping = _get_all(obj, k)
check_results(grid_mapping, k)
successful[k] = bool(grid_mapping)
if grid_mapping:
varnames.extend(grid_mapping)
elif "geometries" not in skip and (k == "geometry" or k in _GEOMETRY_TYPES):
geometries = _get_all(obj, k)
if geometries and k in _GEOMETRY_TYPES:
new = itertools.chain(
_parse_related_geometry_vars(
ChainMap(obj[g].attrs, obj[g].encoding)
)
for g in geometries
)
for g in geometries
)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
successful[k] = bool(names)
varnames.extend(names)
else:
stdnames = set(_get_with_standard_name(obj, k))
objcoords = set(obj.coords)
stdnames = drop_bounds(stdnames)
if "coords" in skip:
stdnames -= objcoords
check_results(stdnames, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames - objcoords)
coords.extend(stdnames & objcoords)
geometries.extend(*new)
if len(geometries) > 1 and scalar_key:
raise ValueError(
f"CF geometries must be represented by an Xarray Dataset. To request a Dataset in return please pass `[{k!r}]` instead."
)
successful[k] = bool(geometries)
if geometries:
varnames.extend(geometries)
elif k in custom_criteria or k in cf_role_criteria:
names = _get_all(obj, k)
check_results(names, k)
successful[k] = bool(names)
varnames.extend(names)
else:
stdnames = set(_get_with_standard_name(obj, k))
objcoords = set(obj.coords)
stdnames = drop_bounds(stdnames)
if "coords" in skip:
stdnames -= objcoords
check_results(stdnames, k)
successful[k] = bool(stdnames)
varnames.extend(stdnames - objcoords)
coords.extend(stdnames & objcoords)

# these are not special names but could be variable names in underlying object
# we allow this so that we can return variables with appropriate CF auxiliary variables
Expand Down Expand Up @@ -2042,7 +2089,7 @@ def cell_measures(self) -> dict[str, list[Hashable]]:
]
as_dataset = self._maybe_to_dataset().reset_coords()

keys = {}
keys: dict[str, str] = {}
for attr in set(all_attrs):
try:
keys.update(parse_cell_methods_attr(attr))
Expand Down
13 changes: 8 additions & 5 deletions cf_xarray/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import functools
import inspect
import os
import warnings
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import Any
from xml.etree import ElementTree

import numpy as np
from xarray import DataArray
from xarray.core.utils import Frozen

try:
import cftime
Expand Down Expand Up @@ -64,7 +66,8 @@ def _is_datetime_like(da: DataArray) -> bool:
return False


def parse_cell_methods_attr(attr: str) -> dict[str, str]:
@functools.lru_cache(maxsize=256)
def parse_cell_methods_attr(attr: str) -> Mapping[str, str]:
"""
Parse cell_methods attributes (format is 'measure: name').

Expand All @@ -75,14 +78,14 @@ def parse_cell_methods_attr(attr: str) -> dict[str, str]:

Returns
-------
Dictionary mapping measure to name
Read-only mapping from measure to name.
"""
strings = [s for scolons in attr.split(":") for s in scolons.split()]
if len(strings) % 2 != 0:
raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")

return dict(
zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)], strict=False)
return Frozen(
dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)], strict=False))
)


Expand Down
Loading