diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 0902b667..89eb3f91 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 - uses: actions/cache@v4 with: key: ${{ github.ref }} diff --git a/README.md b/README.md index 3d9e39cc..bda81dab 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ See the [docs](https://cesnet.github.io/dp3/howto/get-started/) for more details ### Installing for application development -Pre-requisites: Python 3.9 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. +Pre-requisites: Python 3.11 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. Create a virtualenv and install the DP³ platform using: @@ -117,7 +117,7 @@ You are now ready to start developing your application! ## Installing for platform development -Pre-requisites: Python 3.9 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. +Pre-requisites: Python 3.11 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. Pull the repository and install using: diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index aea62099..426ae6d6 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 # Base interpreter with installed requirements -FROM python:3.9-slim AS base +FROM python:3.11-slim AS base RUN apt-get update; apt-get install -y \ gcc \ git diff --git a/docs/howto/develop-dp3.md b/docs/howto/develop-dp3.md index 070cbffd..a29b5f31 100644 --- a/docs/howto/develop-dp3.md +++ b/docs/howto/develop-dp3.md @@ -16,7 +16,7 @@ You will end up with: For platform development, you need: -- Python 3.9 or higher +- Python 3.11 or higher - `pip` - `git` - Docker diff --git a/docs/howto/get-started.md b/docs/howto/get-started.md index 53c89818..ce6e6e19 100644 --- a/docs/howto/get-started.md +++ b/docs/howto/get-started.md @@ -15,7 +15,7 @@ You will end up with: For local application development, you need: -- Python 3.9 or higher +- Python 3.11 or higher - `pip` - `git` - Docker diff --git a/dp3/api/internal/config.py b/dp3/api/internal/config.py index d7ef4626..965fbda5 100644 --- a/dp3/api/internal/config.py +++ b/dp3/api/internal/config.py @@ -10,12 +10,15 @@ from dp3.api.internal.dp_logger import DPLogger from dp3.common.config import ModelSpec, read_config_dir +from dp3.common.utils import suppress_dependency_loggers from dp3.database.database import EntityDatabase from dp3.history_management.telemetry import TelemetryReader from dp3.task_processing.task_queue import TaskQueueWriter DATAPOINTS_INGESTION_URL_PATH = "/datapoints" +suppress_dependency_loggers() + class ConfigEnv(BaseModel): """Configuration environment variables container""" @@ -44,7 +47,7 @@ def validate(cls, v): try: # Validate and parse environmental variables - conf_env = ConfigEnv.parse_obj(os.environ) + conf_env = ConfigEnv.model_validate(os.environ) except ValidationError as e: config_error = any("CONF_DIR" in x["loc"] and len(x["loc"]) > 1 for x in e.errors()) env_error = any(len(x["loc"]) == 1 for x in e.errors()) diff --git a/dp3/api/internal/entity_response_models.py b/dp3/api/internal/entity_response_models.py index 89adba14..c483ae87 100644 --- a/dp3/api/internal/entity_response_models.py +++ b/dp3/api/internal/entity_response_models.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, Field, NonNegativeInt, PlainSerializer @@ -25,11 +25,11 @@ class EntityState(BaseModel): JsonVal = Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] LinkVal = dict[str, JsonVal] -PlainVal = Union[LinkVal, JsonVal] +PlainVal = LinkVal | JsonVal MultiVal = list[PlainVal] HistoryVal = list[dict[str, PlainVal]] -Dp3Val = Union[HistoryVal, MultiVal, PlainVal] +Dp3Val = HistoryVal | MultiVal | PlainVal EntityEidMasterRecord = dict[str, Dp3Val] @@ -45,7 +45,7 @@ class EntityEidList(BaseModel): Data does not include history of observations attributes and timeseries. """ - time_created: Optional[datetime] = None + time_created: datetime | None = None count: int data: EntityEidSnapshots diff --git a/dp3/api/internal/models.py b/dp3/api/internal/models.py index 65918730..79f0dda8 100644 --- a/dp3/api/internal/models.py +++ b/dp3/api/internal/models.py @@ -1,4 +1,6 @@ -from typing import Annotated, Any, Literal, Optional, Union +from functools import reduce +from operator import or_ +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, TypeAdapter, create_model, model_validator @@ -26,10 +28,10 @@ class DataPoint(BaseModel): id: Any attr: str v: Any - t1: Optional[AwareDatetime] = None - t2: Optional[T2Datetime] = Field(None, validate_default=True) + t1: AwareDatetime | None = None + t2: T2Datetime | None = Field(None, validate_default=True) c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0 - src: Optional[str] = None + src: str | None = None @model_validator(mode="after") def validate_against_attribute(self): @@ -43,14 +45,14 @@ def validate_against_attribute(self): class EntityId(BaseModel): - """Dummy model for entity id + """Common interface for validated entity identifiers. Attributes: type: Entity type id: Entity ID """ - type: Literal["entity_type"] + type: str id: Any @@ -60,11 +62,14 @@ class EntityId(BaseModel): entity_id_models.append( create_model( f"EntityId{{{entity_type}}}", - __base__=BaseModel, + __base__=EntityId, type=(Literal[entity_type], Field(..., alias="etype")), id=(dtype, Field(..., alias="eid")), ) ) -EntityId = Annotated[Union[tuple(entity_id_models)], Field(discriminator="type")] # noqa: F811 -EntityIdAdapter = TypeAdapter(EntityId) +if not entity_id_models: + raise RuntimeError("At least one entity type must be configured to run the API.") + +EntityIdType = Annotated[reduce(or_, entity_id_models), Field(discriminator="type")] +EntityIdAdapter = TypeAdapter(EntityIdType) diff --git a/dp3/api/routers/entity.py b/dp3/api/routers/entity.py index b952a4bc..209072c6 100644 --- a/dp3/api/routers/entity.py +++ b/dp3/api/routers/entity.py @@ -1,5 +1,5 @@ -from datetime import datetime -from typing import Annotated, Any, Optional +from datetime import UTC, datetime +from typing import Annotated, Any, cast from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import Json, NonNegativeInt, ValidationError @@ -22,7 +22,7 @@ from dp3.common.attrspec import AttrType from dp3.common.datapoint import to_json_friendly from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC, AwareDatetime +from dp3.common.types import AwareDatetime from dp3.database.database import DatabaseError @@ -33,10 +33,10 @@ async def check_etype(etype: str): return etype -async def parse_eid(etype: str, eid: str): +async def parse_eid(etype: str, eid: str) -> EntityId: """Middleware to parse EID""" try: - return EntityIdAdapter.validate_python({"etype": etype, "eid": eid}) + return cast(EntityId, EntityIdAdapter.validate_python({"etype": etype, "eid": eid})) except ValidationError as e: raise RequestValidationError(["path", "eid"], e.errors()[0]["msg"]) from e @@ -44,12 +44,12 @@ async def parse_eid(etype: str, eid: str): ParsedEid = Annotated[EntityId, Depends(parse_eid)] -def _parse_optional_eid(etype: str, eid: Optional[str]) -> Any: +def _parse_optional_eid(etype: str, eid: str | None) -> Any: """Parse optional entity id query parameter for entity-scoped endpoints.""" if eid is None: return None try: - return EntityIdAdapter.validate_python({"etype": etype, "eid": eid}).id + return cast(EntityId, EntityIdAdapter.validate_python({"etype": etype, "eid": eid})).id except ValidationError as e: raise RequestValidationError(["query", "eid"], e.errors()[0]["msg"]) from e @@ -72,7 +72,7 @@ def _raw_datapoint_to_response(raw_datapoint: dict[str, Any]) -> dict[str, Any]: def get_eid_master_record_handler( - e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: EntityId, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ): """Handler for getting master record of EID""" # TODO: This is probably not the most efficient way. Maybe gather only @@ -97,8 +97,8 @@ def get_eid_master_record_handler( def get_eid_snapshots_handler( e: EntityId, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: int = 0, limit: int = 0, ) -> list[dict[str, Any]]: @@ -271,9 +271,9 @@ async def count_entity_type_eids( ) async def get_entity_type_raw_datapoints( etype: str, - eid: Optional[str] = None, - attr: Optional[str] = None, - src: Optional[str] = None, + eid: str | None = None, + attr: str | None = None, + src: str | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 20, ) -> EntityRawDataPage: @@ -305,7 +305,7 @@ async def get_entity_type_raw_datapoints( @router.get("/{etype}/{eid}") async def get_eid_data( - e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: ParsedEid, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ) -> EntityEidData: """Get data of the entity identified by `etype` and `eid`. @@ -325,7 +325,7 @@ async def get_eid_data( @router.get("/{etype}/{eid}/master") async def get_eid_master_record( - e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: ParsedEid, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ) -> EntityEidMasterRecord: """Get the master record of the entity identified by `etype` and `eid`.""" return get_eid_master_record_handler(e, date_from, date_to) @@ -334,8 +334,8 @@ async def get_eid_master_record( @router.get("/{etype}/{eid}/snapshots") async def get_eid_snapshots( e: ParsedEid, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 0, ) -> EntityEidSnapshots: @@ -351,8 +351,8 @@ async def get_eid_snapshots( async def get_eid_attr_value( e: ParsedEid, attr: str, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, ) -> EntityEidAttrValueOrHistory: """Get attribute value @@ -373,9 +373,7 @@ async def get_eid_attr_value( @router.post("/{etype}/{eid}/set/{attr}") -async def set_eid_attr_value( - etype: str, eid: str, attr: str, body: EntityEidAttrValue, request: Request -) -> SuccessResponse: +async def set_eid_attr_value(etype: str, eid: str, attr: str, request: Request) -> SuccessResponse: """Set current value of attribute Internally just creates datapoint for specified attribute and value. @@ -386,6 +384,11 @@ async def set_eid_attr_value( if attr not in MODEL_SPEC.attribs(etype): raise RequestValidationError(["path", "attr"], f"Attribute '{attr}' doesn't exist") + try: + body = EntityEidAttrValue.model_validate(await request.json()) + except ValueError as e: + raise RequestValidationError(["body"], str(e)) from e + # Construct datapoint try: dp = DataPoint( @@ -396,7 +399,7 @@ async def set_eid_attr_value( t1=datetime.now(UTC), src=f"{request.client.host} via API", ) - dp3_dp = api_to_dp3_datapoint(dp.dict()) + dp3_dp = api_to_dp3_datapoint(dp.model_dump()) except ValidationError as e: raise RequestValidationError(["body", "value"], e.errors()[0]["msg"]) from e diff --git a/dp3/api/routers/telemetry.py b/dp3/api/routers/telemetry.py index 60acd77f..0c2fe9d7 100644 --- a/dp3/api/routers/telemetry.py +++ b/dp3/api/routers/telemetry.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Literal, Optional +from typing import Literal import requests from fastapi import APIRouter, HTTPException @@ -41,9 +41,9 @@ async def get_snapshot_summary() -> dict: @router.get("/metadata") async def get_metadata( - module: Optional[str] = None, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + module: str | None = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 0, sort: Literal["newest", "oldest"] = "newest", diff --git a/dp3/bin/check.py b/dp3/bin/check.py index 8836de6b..b852db1a 100755 --- a/dp3/bin/check.py +++ b/dp3/bin/check.py @@ -218,7 +218,7 @@ def main(args): unique_sources = [] source_paths_and_errors = [] - for path, source, err in zip(paths, sources, errors): + for path, source, err in zip(paths, sources, errors, strict=False): if source in unique_sources: i = unique_sources.index(source) source_paths_and_errors[i].add((path, err)) @@ -226,7 +226,7 @@ def main(args): unique_sources.append(source) source_paths_and_errors.append({(path, err)}) - for source, paths_and_errors in zip(unique_sources, source_paths_and_errors): + for source, paths_and_errors in zip(unique_sources, source_paths_and_errors, strict=False): for path, err in paths_and_errors: print(" -> ".join(path)) print(" ", err) diff --git a/dp3/bin/schema_update.py b/dp3/bin/schema_update.py index 7a45238f..d078a2df 100644 --- a/dp3/bin/schema_update.py +++ b/dp3/bin/schema_update.py @@ -7,6 +7,7 @@ import logging from dp3.common.config import ModelSpec, read_config_dir +from dp3.common.utils import suppress_dependency_loggers from dp3.database.database import EntityDatabase @@ -41,6 +42,7 @@ def main(args): LOGFORMAT = "%(asctime)-15s,%(name)s,[%(levelname)s] %(message)s" LOGDATEFORMAT = "%Y-%m-%dT%H:%M:%S" logging.basicConfig(level=logging.DEBUG, format=LOGFORMAT, datefmt=LOGDATEFORMAT) + suppress_dependency_loggers() log = logging.getLogger("SchemaUpdate") # Connect to database diff --git a/dp3/bin/shcmd/common.py b/dp3/bin/shcmd/common.py index 860951fd..266bde98 100644 --- a/dp3/bin/shcmd/common.py +++ b/dp3/bin/shcmd/common.py @@ -5,7 +5,7 @@ import os import sys from functools import lru_cache -from typing import Any, Optional +from typing import Any from urllib.parse import urljoin import requests @@ -28,9 +28,9 @@ class DP3APIClient: def __init__( self, config_dir: str, - base_url: Optional[str] = None, + base_url: str | None = None, timeout: float = 5.0, - model_spec: Optional[ModelSpec] = None, + model_spec: ModelSpec | None = None, ): self.config_dir = os.path.abspath(config_dir) self.model_spec = model_spec @@ -42,7 +42,7 @@ def __init__( def _normalize_base_url(base_url: str) -> str: return base_url.rstrip("/") + "/" - def _resolve_base_url(self, base_url: Optional[str]) -> str: + def _resolve_base_url(self, base_url: str | None) -> str: if base_url is not None: normalized = self._normalize_base_url(base_url) self._check_health(normalized) @@ -84,7 +84,7 @@ def request( method: str, path: str, *, - params: Optional[dict[str, Any]] = None, + params: dict[str, Any] | None = None, json_body: Any = None, stream: bool = False, ) -> requests.Response: @@ -115,7 +115,7 @@ def read_json_value(raw_value: str) -> Any: raise APIError(f"Invalid JSON value: {e}") from e -def read_json_input(path: Optional[str]) -> Any: +def read_json_input(path: str | None) -> Any: """Decode JSON from a file path or standard input.""" if path in (None, "-"): content = sys.stdin.read() @@ -207,7 +207,7 @@ def stream_json_pages( return 0 -def resolve_config_dir(config_dir: Optional[str]) -> str: +def resolve_config_dir(config_dir: str | None) -> str: """Resolve the configuration directory for the shell-oriented CLI.""" if config_dir is not None: return os.path.abspath(config_dir) @@ -217,7 +217,7 @@ def resolve_config_dir(config_dir: Optional[str]) -> str: @lru_cache(maxsize=32) -def load_completion_model_spec(config_dir: str) -> Optional[ModelSpec]: +def load_completion_model_spec(config_dir: str) -> ModelSpec | None: """Load the model specification used by shell completion.""" try: config = read_config_dir(config_dir, recursive=True) @@ -228,8 +228,8 @@ def load_completion_model_spec(config_dir: str) -> Optional[ModelSpec]: @lru_cache(maxsize=32) def load_completion_entity_catalog( - config_dir: str, base_url: Optional[str], timeout: float -) -> Optional[dict[str, Any]]: + config_dir: str, base_url: str | None, timeout: float +) -> dict[str, Any] | None: """Load entity metadata from the API when config-based completion is unavailable.""" try: client = DP3APIClient(config_dir, base_url, timeout) @@ -241,7 +241,7 @@ def load_completion_entity_catalog( def get_completion_context( parsed_args, -) -> tuple[Optional[ModelSpec], Optional[dict[str, Any]]]: +) -> tuple[ModelSpec | None, dict[str, Any] | None]: """Return completion metadata derived from config and API sources.""" config_dir = resolve_config_dir(getattr(parsed_args, "config", None)) model_spec = load_completion_model_spec(config_dir) @@ -257,8 +257,8 @@ def get_completion_context( def _entity_type_description( etype: str, - model_spec: Optional[ModelSpec], - entity_catalog: Optional[dict[str, Any]], + model_spec: ModelSpec | None, + entity_catalog: dict[str, Any] | None, ) -> str: if model_spec is not None and etype in model_spec.entities: entity_spec = model_spec.entity(etype) diff --git a/dp3/bin/shcmd/entity/__init__.py b/dp3/bin/shcmd/entity/__init__.py index c8874972..46de9e9a 100644 --- a/dp3/bin/shcmd/entity/__init__.py +++ b/dp3/bin/shcmd/entity/__init__.py @@ -2,7 +2,6 @@ """Entity commands for the shell-oriented DP3 CLI.""" import argparse -from typing import Optional from . import etype from .common import complete_entity_rest, complete_entity_selector @@ -25,7 +24,7 @@ def _build_overview_parser() -> argparse.ArgumentParser: return parser -def parse_entity_command(args) -> tuple[Optional[argparse.Namespace], Optional[int]]: +def parse_entity_command(args) -> tuple[argparse.Namespace | None, int | None]: """Parse the path-like entity command grammar.""" overview_parser = _build_overview_parser() if args.selector is None: diff --git a/dp3/bin/shcmd/entity/common.py b/dp3/bin/shcmd/entity/common.py index acf620f2..0b3eb9a0 100644 --- a/dp3/bin/shcmd/entity/common.py +++ b/dp3/bin/shcmd/entity/common.py @@ -3,7 +3,7 @@ import argparse import json -from typing import Any, Optional +from typing import Any from argcomplete.finders import CompletionFinder @@ -154,7 +154,7 @@ def build_has_attr_filter(client, etype: str, attr: str) -> dict[str, Any]: return query -def build_generic_filter_param(client, args) -> Optional[str]: +def build_generic_filter_param(client, args) -> str | None: """Build the generic-filter query parameter for entity type queries.""" query = None if getattr(args, "filter_json", None) is not None: @@ -240,7 +240,7 @@ def _match_descriptions(values: dict[str, str], prefix: str) -> dict[str, str]: return {value: description for value, description in values.items() if value.startswith(prefix)} -def _entity_types(model_spec, entity_catalog: Optional[dict[str, Any]] = None) -> list[str]: +def _entity_types(model_spec, entity_catalog: dict[str, Any] | None = None) -> list[str]: if model_spec is not None: return sorted(model_spec.entities) if entity_catalog is not None: @@ -249,7 +249,7 @@ def _entity_types(model_spec, entity_catalog: Optional[dict[str, Any]] = None) - def _entity_attrs( - model_spec, etype: str, entity_catalog: Optional[dict[str, Any]] = None + model_spec, etype: str, entity_catalog: dict[str, Any] | None = None ) -> list[str]: if model_spec is not None and etype in model_spec.entities: return sorted(model_spec.attribs(etype)) @@ -259,7 +259,7 @@ def _entity_attrs( def _entity_attr_descriptions( - model_spec, etype: str, entity_catalog: Optional[dict[str, Any]] = None + model_spec, etype: str, entity_catalog: dict[str, Any] | None = None ) -> dict[str, str]: attrs = _entity_attrs(model_spec, etype, entity_catalog) descriptions = {attr: f"Attribute on entity type '{etype}'." for attr in attrs} diff --git a/dp3/common/attrspec.py b/dp3/common/attrspec.py index 79536837..86d26efb 100644 --- a/dp3/common/attrspec.py +++ b/dp3/common/attrspec.py @@ -1,6 +1,6 @@ from datetime import timedelta from enum import Flag -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -68,11 +68,11 @@ def from_str(cls, type_str: str): class ObservationsHistoryParams(BaseModel): """History parameters field of observations attribute""" - max_age: Optional[ParsedTimedelta] = None - max_items: Optional[PositiveInt] = None - expire_time: Optional[ParsedTimedelta] = None - pre_validity: Optional[ParsedTimedelta] = timedelta() - post_validity: Optional[ParsedTimedelta] = timedelta() + max_age: ParsedTimedelta | None = None + max_items: PositiveInt | None = None + expire_time: ParsedTimedelta | None = None + pre_validity: ParsedTimedelta | None = timedelta() + post_validity: ParsedTimedelta | None = timedelta() aggregate: bool = True @@ -85,8 +85,8 @@ def expire_time_inf_transform(cls, v): class TimeseriesTSParams(BaseModel): """Timeseries parameters field of timeseries attribute""" - max_age: Optional[ParsedTimedelta] = None - time_step: Optional[ParsedTimedelta] = None + max_age: ParsedTimedelta | None = None + time_step: ParsedTimedelta | None = None class TimeseriesSeries(BaseModel): @@ -119,7 +119,7 @@ class AttrSpecGeneric(SpecModel, use_enum_values=True): id: str = Field(pattern=ID_REGEX) name: str description: str = "" - ttl: Optional[ParsedTimedelta] = timedelta() + ttl: ParsedTimedelta | None = timedelta() _dp_model = PrivateAttr() @@ -298,7 +298,7 @@ def add_default_series(cls, v, info: FieldValidationInfo): - [AttrSpecObservations][dp3.common.attrspec.AttrSpecObservations] - [AttrSpecTimeseries][dp3.common.attrspec.AttrSpecTimeseries] """ -AttrSpecType = Union[AttrSpecTimeseries, AttrSpecObservations, AttrSpecPlain] +AttrSpecType = AttrSpecTimeseries | AttrSpecObservations | AttrSpecPlain def AttrSpec(id: str, spec: dict[str, Any]) -> AttrSpecType: diff --git a/dp3/common/callback_registrar.py b/dp3/common/callback_registrar.py index 8ed7010b..a74a6bfd 100644 --- a/dp3/common/callback_registrar.py +++ b/dp3/common/callback_registrar.py @@ -1,7 +1,8 @@ import logging +from collections.abc import Callable from functools import partial, wraps from logging import Logger -from typing import Any, Callable, Union +from typing import Any from pydantic import BaseModel @@ -68,7 +69,7 @@ def on_entity_creation_in_snapshots( def on_attr_change_in_snapshots( model_spec: ModelSpec, run_flag: SharedFlag, - original_hook: Callable[[AnyEidT, DataPointTask], Union[list[DataPointTask], None]], + original_hook: Callable[[AnyEidT, DataPointTask], list[DataPointTask] | None], etype: str, record: dict, ) -> list[DataPointTask]: @@ -139,16 +140,16 @@ def scheduler_register( self, func: Callable, *, - func_args: Union[list, tuple] = None, + func_args: list | tuple = None, func_kwargs: dict = None, - year: Union[int, str] = None, - month: Union[int, str] = None, - day: Union[int, str] = None, - week: Union[int, str] = None, - day_of_week: Union[int, str] = None, - hour: Union[int, str] = None, - minute: Union[int, str] = None, - second: Union[int, str] = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, timezone: str = "UTC", misfire_grace_time: int = 1, ) -> int: @@ -271,7 +272,7 @@ def register_entity_hook(self, hook_type: str, hook: Callable, entity: str): def register_on_new_attr_hook( self, - hook: Callable[[AnyEidT, DataPointType], Union[None, list[DataPointTask]]], + hook: Callable[[AnyEidT, DataPointType], None | list[DataPointTask]], entity: str, attr: str, refresh: SharedFlag = None, @@ -353,7 +354,7 @@ def register_timeseries_hook( def register_correlation_hook( self, - hook: Callable[[str, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -385,7 +386,7 @@ def register_correlation_hook( def register_correlation_hook_with_master_record( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], diff --git a/dp3/common/config.py b/dp3/common/config.py index 2b538740..1e2c86f9 100644 --- a/dp3/common/config.py +++ b/dp3/common/config.py @@ -6,13 +6,12 @@ from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, TypeVar import yaml from pydantic import ( BaseModel, ConfigDict, - Extra, Field, NonNegativeInt, PositiveInt, @@ -32,6 +31,8 @@ from dp3.common.datatype import AnyEidT from dp3.common.entityspec import EntitySpec +_T = TypeVar("_T") + class NoDefault: pass @@ -50,7 +51,7 @@ def __repr__(self): def copy(self): return HierarchicalDict(dict.copy(self)) - def get(self, key, default=NoDefault): + def get(self, key: str, default: type[NoDefault] | _T = NoDefault) -> Any | _T: """ Key may be a path (in dot notation) into a hierarchy of dicts. For example `dictionary.get('abc.x.y')` @@ -157,7 +158,7 @@ def read_config_dir(dir_path: str, recursive: bool = False) -> HierarchicalDict: TimeInt = Annotated[int, Field(ge=0, le=59)] -class CronExpression(BaseModel, extra=Extra.forbid): +class CronExpression(BaseModel): """ Cron expression used for scheduling. Also support standard cron expressions, such as @@ -177,16 +178,18 @@ class CronExpression(BaseModel, extra=Extra.forbid): timezone: Timezone for time specification (default is UTC). """ - second: Optional[Union[TimeInt, CronStr]] = None - minute: Optional[Union[TimeInt, CronStr]] = None - hour: Optional[Union[TimeInt, CronStr]] = None + model_config = ConfigDict(extra="forbid") + + second: TimeInt | CronStr | None = None + minute: TimeInt | CronStr | None = None + hour: TimeInt | CronStr | None = None - day: Optional[Union[Annotated[int, Field(ge=1, le=31)], CronStr]] = None - day_of_week: Optional[Union[Annotated[int, Field(ge=0, le=6)], CronStr]] = None + day: Annotated[int, Field(ge=1, le=31)] | CronStr | None = None + day_of_week: Annotated[int, Field(ge=0, le=6)] | CronStr | None = None - week: Optional[int] = Field(default=None, ge=1, le=53) - month: Optional[int] = Field(default=None, ge=1, le=12) - year: Optional[str] = Field(default=None, pattern=r"^\d{4}$") + week: int | None = Field(default=None, ge=1, le=53) + month: int | None = Field(default=None, ge=1, le=12) + year: str | None = Field(default=None, pattern=r"^\d{4}$") timezone: str = "UTC" diff --git a/dp3/common/control.py b/dp3/common/control.py index c3ccf3f7..83b9d9f8 100644 --- a/dp3/common/control.py +++ b/dp3/common/control.py @@ -3,8 +3,8 @@ """ import logging +from collections.abc import Callable from enum import Enum -from typing import Callable from pydantic import BaseModel diff --git a/dp3/common/datapoint.py b/dp3/common/datapoint.py index ee8850c0..d6c4a945 100644 --- a/dp3/common/datapoint.py +++ b/dp3/common/datapoint.py @@ -1,5 +1,5 @@ from ipaddress import IPv4Address, IPv6Address -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer @@ -8,7 +8,7 @@ def to_json_friendly(v): - if isinstance(v, (IPv4Address, IPv6Address, MACAddress)): + if isinstance(v, IPv4Address | IPv6Address | MACAddress): return str(v) return v @@ -33,7 +33,7 @@ class DataPointBase(BaseModel, use_enum_values=True): etype: str eid: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] = None attr: str - src: Optional[str] = None + src: str | None = None v: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] = None c: Any = None t1: Any = None @@ -147,10 +147,10 @@ def dp_ts_root_validator_irregular_intervals(self): # Check time_first[i] <= time_last[i] assert all( - t[0] <= t[1] for t in zip(self.v.time_first, self.v.time_last) + t[0] <= t[1] for t in zip(self.v.time_first, self.v.time_last, strict=False) ), "'time_first[i] <= time_last[i]' isn't true for all 'i'" return self -DataPointType = Union[DataPointPlainBase, DataPointObservationsBase, DataPointTimeseriesBase] +DataPointType = DataPointPlainBase | DataPointObservationsBase | DataPointTimeseriesBase diff --git a/dp3/common/datatype.py b/dp3/common/datatype.py index fe8b80ce..65930e34 100644 --- a/dp3/common/datatype.py +++ b/dp3/common/datatype.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv6Address -from typing import Any, Optional, Union +from typing import Any from pydantic import ( BaseModel, @@ -38,12 +38,12 @@ "mac": MACAddress, "time": datetime, "special": Any, - "json": Union[Json[Any], dict, list], + "json": Json[Any] | dict | list, } eid_data_types = ["string", "int", "ipv4", "ipv6", "mac"] -AnyEidT = Union[str, int, IPv4Address, IPv6Address, MACAddress] +AnyEidT = str | int | IPv4Address | IPv6Address | MACAddress """Type alias for any of possible entity ID data types. Note that the type is determined based on the loaded entity configuration @@ -190,7 +190,7 @@ def _determine_value_validator(self): # Set (type, default value) for the key if k_optional: k = k[:-1] # Remove question mark from key - dict_spec[k] = (Optional[primitive_data_types[v]], None) + dict_spec[k] = (primitive_data_types[v] | None, None) else: dict_spec[k] = (primitive_data_types[v], ...) @@ -228,7 +228,7 @@ def determine_value_validator(self): return self._determine_value_validator() @property - def data_type(self) -> Union[type, BaseModel]: + def data_type(self) -> type | BaseModel: """Type for incoming value validation""" return self._data_type @@ -263,7 +263,7 @@ def mirror_link(self) -> bool: return self._mirror_link @property - def mirror_as(self) -> Union[str, None]: + def mirror_as(self) -> str | None: """If `mirror_link`, what is the name of the mirrored attribute""" return self._mirror_as diff --git a/dp3/common/entityspec.py b/dp3/common/entityspec.py index f0e0fba7..5992b075 100644 --- a/dp3/common/entityspec.py +++ b/dp3/common/entityspec.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, Field, PrivateAttr, model_validator @@ -62,7 +62,7 @@ class EntitySpec(SpecModel): id_data_type: EidDataType = EidDataType("string") name: str snapshot: bool - lifetime: Union[ImmortalLifetime, TimeToLiveLifetime, WeakLifetime] = Field( + lifetime: ImmortalLifetime | TimeToLiveLifetime | WeakLifetime = Field( default_factory=lambda: ImmortalLifetime(type="immortal"), discriminator="type" ) diff --git a/dp3/common/mac_address.py b/dp3/common/mac_address.py index 28100d48..936cfad6 100644 --- a/dp3/common/mac_address.py +++ b/dp3/common/mac_address.py @@ -1,4 +1,6 @@ -from typing import Any, Union +from __future__ import annotations + +from typing import Any from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -10,7 +12,7 @@ class MACAddress: Can be initialized from colon or comma separated string, or from raw bytes. """ - def __init__(self, mac: Union[bytes, str, "MACAddress"]): + def __init__(self, mac: bytes | str | MACAddress): if isinstance(mac, self.__class__): mac = mac.mac # type: ignore if not isinstance(mac, bytes) or len(mac) != 6: @@ -19,11 +21,11 @@ def __init__(self, mac: Union[bytes, str, "MACAddress"]): self.mac: bytes = mac @classmethod - def _validate(cls, value: Any) -> "MACAddress": + def _validate(cls, value: Any) -> MACAddress: return cls(value) @classmethod - def _serialize(cls, value: "MACAddress", info: Any) -> Any: + def _serialize(cls, value: MACAddress, info: Any) -> Any: if info.mode == "json": return str(value) return value @@ -33,7 +35,7 @@ def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: base_schema = core_schema.no_info_after_validator_function( - cls._validate, handler(Union[str, bytes]) + cls._validate, handler(str | bytes) ) python_schema = core_schema.union_schema( @@ -52,7 +54,7 @@ def __get_pydantic_core_schema__( ) @staticmethod - def _parse_mac(mac: Union[bytes, str]) -> bytes: + def _parse_mac(mac: bytes | str) -> bytes: if isinstance(mac, str): mac = mac.encode() if not isinstance(mac, bytes): diff --git a/dp3/common/scheduler.py b/dp3/common/scheduler.py index 75b03ac8..031d2bfc 100644 --- a/dp3/common/scheduler.py +++ b/dp3/common/scheduler.py @@ -6,7 +6,7 @@ """ import logging -from typing import Callable, Union +from collections.abc import Callable from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger @@ -39,16 +39,16 @@ def stop(self) -> None: def register( self, func: Callable, - func_args: Union[list, tuple] = None, + func_args: list | tuple = None, func_kwargs: dict = None, - year: Union[int, str] = None, - month: Union[int, str] = None, - day: Union[int, str] = None, - week: Union[int, str] = None, - day_of_week: Union[int, str] = None, - hour: Union[int, str] = None, - minute: Union[int, str] = None, - second: Union[int, str] = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, timezone: str = "UTC", misfire_grace_time: int = 1, ) -> int: diff --git a/dp3/common/task.py b/dp3/common/task.py index b59d85a7..3c319666 100644 --- a/dp3/common/task.py +++ b/dp3/common/task.py @@ -1,12 +1,12 @@ import hashlib from abc import ABC, abstractmethod -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from contextvars import ContextVar from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv6Address -from typing import Annotated, Any, Callable, Optional, Union +from typing import Annotated, Any from pydantic import ( AfterValidator, @@ -150,7 +150,7 @@ class DataPointTask(Task): eid: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] data_points: list[ValidatedDataPoint] = [] tags: list[Any] = [] - ttl_tokens: Optional[dict[str, datetime]] = None + ttl_tokens: dict[str, datetime] | None = None delete: bool = False def __init__(__pydantic_self__, **data: Any) -> None: @@ -222,13 +222,11 @@ def get_discriminator_value(entity_tuple: tuple[str, Any]) -> str: EntityTuple = Annotated[ - Union[ - Annotated[tuple[str, str], Tag("string")], - Annotated[tuple[str, int], Tag("int")], - Annotated[tuple[str, IPv4Address], Tag("ipv4")], - Annotated[tuple[str, IPv6Address], Tag("ipv6")], - Annotated[tuple[str, MACAddress], Tag("mac")], - ], + Annotated[tuple[str, str], Tag("string")] + | Annotated[tuple[str, int], Tag("int")] + | Annotated[tuple[str, IPv4Address], Tag("ipv4")] + | Annotated[tuple[str, IPv6Address], Tag("ipv6")] + | Annotated[tuple[str, MACAddress], Tag("mac")], Discriminator(get_discriminator_value), ] @@ -259,8 +257,8 @@ def as_message(self) -> str: return self.model_dump_json() @staticmethod - def get_validator(model_spec: ModelSpec) -> Callable[[Union[str, bytes]], "Snapshot"]: - def json_validator(serialized: Union[str, bytes]) -> Snapshot: + def get_validator(model_spec: ModelSpec) -> Callable[[str | bytes], "Snapshot"]: + def json_validator(serialized: str | bytes) -> Snapshot: with entity_type_context(model_spec): return Snapshot.model_validate_json(serialized) diff --git a/dp3/common/types.py b/dp3/common/types.py index 922963da..ab5f81fe 100644 --- a/dp3/common/types.py +++ b/dp3/common/types.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from ipaddress import IPv4Address, IPv6Address from json import JSONEncoder -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from event_count_logger import DummyEventGroup, EventGroup from pydantic import AfterValidator, BeforeValidator @@ -9,8 +9,6 @@ from dp3.common.utils import parse_time_duration, time_duration_pattern -UTC = timezone.utc - def parse_timedelta_or_passthrough(v): """ @@ -24,7 +22,7 @@ def parse_timedelta_or_passthrough(v): ParsedTimedelta = Annotated[timedelta, BeforeValidator(parse_timedelta_or_passthrough)] -def ensure_timezone_aware(v: Optional[datetime]): +def ensure_timezone_aware(v: datetime | None): """Ensure datetime is timezone-aware by defaulting to UTC.""" if v is None: return v @@ -55,7 +53,7 @@ def t2_after_t1(v, info: FieldValidationInfo): AfterValidator(t2_after_t1), ] -EventGroupType = Union[EventGroup, DummyEventGroup] +EventGroupType = EventGroup | DummyEventGroup class DP3Encoder(JSONEncoder): @@ -64,6 +62,6 @@ class DP3Encoder(JSONEncoder): def default(self, o: Any) -> Any: if isinstance(o, datetime): return o.strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-4] - if isinstance(o, (IPv4Address, IPv6Address)): + if isinstance(o, IPv4Address | IPv6Address): return str(o) return super().default(o) diff --git a/dp3/common/utils.py b/dp3/common/utils.py index 742bb258..a70c2f7c 100644 --- a/dp3/common/utils.py +++ b/dp3/common/utils.py @@ -2,12 +2,12 @@ auxiliary/utility functions and classes """ +import logging import re from collections.abc import Iterable, Iterator from datetime import datetime, timedelta from functools import partial from itertools import islice -from typing import Union # *** IP conversion functions *** ipv4_re = re.compile(r"^([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})$") @@ -74,7 +74,7 @@ def parse_rfc_time(time_str): time_duration_pattern = re.compile(r"^\s*(\d+)([smhd])?$") -def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelta: +def parse_time_duration(duration_string: str | int | timedelta) -> timedelta: """ Parse duration in format (or just "0"). @@ -84,7 +84,7 @@ def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelt if isinstance(duration_string, timedelta): return duration_string # if number is passed, consider it number of seconds - if isinstance(duration_string, (int, float)): + if isinstance(duration_string, int | float): return timedelta(seconds=duration_string) d = 0 @@ -186,3 +186,17 @@ def get_func_name(func_or_method): except AttributeError: return fname return wrapper.format(name=f"{module}.{fname}", args=args) + + +DEPENDENCY_LOGGERS = ( + "requests", + "urllib3", + "amqpstorm", + "pymongo", +) + + +def suppress_dependency_loggers(level: int = logging.WARNING) -> None: + """Keep noisy third-party loggers quiet when DP3 enables verbose logging.""" + for logger_name in DEPENDENCY_LOGGERS: + logging.getLogger(logger_name).setLevel(level) diff --git a/dp3/core/collector.py b/dp3/core/collector.py index f6ea9785..31801f7f 100644 --- a/dp3/core/collector.py +++ b/dp3/core/collector.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from functools import partial from pydantic import BaseModel @@ -15,7 +15,6 @@ from dp3.common.datapoint import DataPointBase, DataPointObservationsBase, DataPointTimeseriesBase from dp3.common.datatype import AnyEidT from dp3.common.task import DataPointTask, parse_eids_from_cache -from dp3.common.types import UTC from dp3.database.database import EntityDatabase DB_SEND_CHUNK = 1000 diff --git a/dp3/core/link_manager.py b/dp3/core/link_manager.py index 94aedc09..a0f60393 100644 --- a/dp3/core/link_manager.py +++ b/dp3/core/link_manager.py @@ -3,7 +3,7 @@ """ import logging -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from functools import partial from pymongo import DeleteMany @@ -14,7 +14,6 @@ from dp3.common.datapoint import DataPointBase, DataPointObservationsBase from dp3.common.datatype import AnyEidT from dp3.common.task import parse_eids_from_cache -from dp3.common.types import UTC from dp3.database.database import EntityDatabase diff --git a/dp3/core/updater.py b/dp3/core/updater.py index 8832a77f..fb0ed21f 100644 --- a/dp3/core/updater.py +++ b/dp3/core/updater.py @@ -2,10 +2,10 @@ import logging from collections import defaultdict -from collections.abc import Iterator -from datetime import datetime, timedelta +from collections.abc import Callable, Iterator +from datetime import UTC, datetime, timedelta from functools import partial -from typing import Callable, Literal +from typing import Literal from pydantic import BaseModel, validate_call from pymongo.cursor import Cursor @@ -14,7 +14,7 @@ from dp3.common.config import CronExpression, PlatformConfig from dp3.common.scheduler import Scheduler from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC, EventGroupType, ParsedTimedelta +from dp3.common.types import EventGroupType, ParsedTimedelta from dp3.database.database import EntityDatabase from dp3.task_processing.task_queue import TaskQueueWriter diff --git a/dp3/database/config.py b/dp3/database/config.py index 3137736e..573cce01 100644 --- a/dp3/database/config.py +++ b/dp3/database/config.py @@ -1,5 +1,5 @@ import urllib -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, Field, field_validator @@ -38,7 +38,7 @@ class MongoConfig(BaseModel, extra="forbid"): db_name: str = "dp3" username: str = "dp3" password: str = "dp3" - connection: Union[MongoStandaloneConfig, MongoReplicaConfig] = Field(..., discriminator="mode") + connection: MongoStandaloneConfig | MongoReplicaConfig = Field(..., discriminator="mode") storage: StorageConfig = StorageConfig() @field_validator("username", "password") diff --git a/dp3/database/database.py b/dp3/database/database.py index e676bed6..4a2bfede 100644 --- a/dp3/database/database.py +++ b/dp3/database/database.py @@ -3,9 +3,8 @@ import threading import time from collections import defaultdict -from collections.abc import Generator, Iterator -from datetime import datetime -from typing import Callable, Optional +from collections.abc import Callable, Generator, Iterator +from datetime import UTC, datetime import pymongo from event_count_logger import DummyEventGroup @@ -24,7 +23,7 @@ from dp3.common.datatype import AnyEidT from dp3.common.scheduler import Scheduler from dp3.common.task import HASH -from dp3.common.types import UTC, EventGroupType +from dp3.common.types import EventGroupType from dp3.database.config import MongoConfig, MongoReplicaConfig, MongoStandaloneConfig from dp3.database.encodings import get_codec_options from dp3.database.exceptions import DatabaseError @@ -62,7 +61,7 @@ def __init__( model_spec: ModelSpec, num_processes: int, process_index: int = 0, - elog: Optional[EventGroupType] = None, + elog: EventGroupType | None = None, ) -> None: self.log = logging.getLogger("EntityDatabase") self.elog = elog or DummyEventGroup() @@ -559,7 +558,7 @@ def update_master_records(self, etype: str, eids: list[AnyEidT], records: list[d res = master_col.bulk_write( [ ReplaceOne({"_id": eid}, record, upsert=True) - for eid, record in zip(eids, records) + for eid, record in zip(eids, records, strict=False) ], ordered=False, ) @@ -776,7 +775,7 @@ def delete_many_link_dps( try: updates = [] for etype, affected_eid_list, attr_name, eid_to_list in zip( - etypes, affected_eids, attr_names, eids_to + etypes, affected_eids, attr_names, eids_to, strict=False ): master_col = self._master_col(etype) attr_type = self._db_schema_config.attr(etype, attr_name).t @@ -836,8 +835,8 @@ def get_value_or_history( etype: str, attr_name: str, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, ) -> dict: """Gets current value and/or history of attribute for given `eid`. @@ -878,12 +877,12 @@ def estimate_count_eids(self, etype: str) -> int: master_col = self._master_col(etype) return master_col.estimated_document_count({}) - def _get_metadata_id(self, module: str, time: datetime, worker_id: Optional[int] = None) -> str: + def _get_metadata_id(self, module: str, time: datetime, worker_id: int | None = None) -> str: """Generates unique metadata id based on `module`, `time` and the worker index.""" worker_id = self._process_index if worker_id is None else worker_id return f"{module}{time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}w{worker_id}" - def save_metadata(self, time: datetime, metadata: dict, worker_id: Optional[int] = None): + def save_metadata(self, time: datetime, metadata: dict, worker_id: int | None = None): """Saves metadata dict under the caller module and passed timestamp.""" module = get_caller_id() metadata["_id"] = self._get_metadata_id(module, time, worker_id) @@ -897,7 +896,7 @@ def save_metadata(self, time: datetime, metadata: dict, worker_id: Optional[int] raise DatabaseError(f"Insert of metadata failed: {e}\n{metadata}") from e def update_metadata( - self, time: datetime, metadata: dict, increase: dict = None, worker_id: Optional[int] = None + self, time: datetime, metadata: dict, increase: dict = None, worker_id: int | None = None ): """Updates existing metadata of caller module and passed timestamp.""" module = get_caller_id() @@ -1120,7 +1119,7 @@ def move_raw_to_archive(self, etype: str): except Exception as e: raise DatabaseError(f"Move of raw collection failed: {e}") from e - def get_archive_summary(self, etype: str, before: datetime) -> Optional[dict]: + def get_archive_summary(self, etype: str, before: datetime) -> dict | None: collection_summaries = [] for archive_col_name in self._archive_col_names(etype): result_cursor = self._get_archive_summary(archive_col_name, before=before) @@ -1187,7 +1186,7 @@ def drop_empty_archives(self, etype: str) -> int: raise DatabaseError(f"Drop of empty archive failed: {e}") from e return dropped_count - def get_module_cache(self, override_called_id: Optional[str] = None): + def get_module_cache(self, override_called_id: str | None = None): """Return a persistent cache collection for given module name. Module name is determined automatically, but you can override it. diff --git a/dp3/database/magic.py b/dp3/database/magic.py index 4e2a856e..e4f7a17e 100644 --- a/dp3/database/magic.py +++ b/dp3/database/magic.py @@ -58,9 +58,9 @@ """ import re -from datetime import datetime, timezone +from datetime import UTC, datetime from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network -from typing import Any, Union +from typing import Any from bson import Binary @@ -87,9 +87,9 @@ def _binary_id_filter(value: Any, _) -> dict[str, Binary]: magic_type, value = match.groups() value = magic_string_replacements[magic_type](value, True) - if isinstance(value, (IPv4Address, IPv6Address, MACAddress)): + if isinstance(value, IPv4Address | IPv6Address | MACAddress): return _binary_snapshot_bucket_range(value.packed) - if isinstance(value, (IPv4Network, IPv6Network)): + if isinstance(value, IPv4Network | IPv6Network): return { "$gte": _pack_binary_snapshot_bucket_id(value[0].packed, 0), "$lte": _pack_binary_snapshot_bucket_id(value[-1].packed, -1), @@ -100,18 +100,14 @@ def _binary_id_filter(value: Any, _) -> dict[str, Binary]: raise ValueError(f"Unsupported value type {type(value)}: {value}") -def _parse_ipv4_network( - value: str, in_id_filter: bool -) -> Union[IPv4Network, dict[str, IPv4Address]]: +def _parse_ipv4_network(value: str, in_id_filter: bool) -> IPv4Network | dict[str, IPv4Address]: ip = IPv4Network(value) if in_id_filter: return ip return {"$gte": ip[0], "$lte": ip[-1]} -def _parse_ipv6_network( - value: str, in_id_filter: bool -) -> Union[IPv6Network, dict[str, IPv6Address]]: +def _parse_ipv6_network(value: str, in_id_filter: bool) -> IPv6Network | dict[str, IPv6Address]: ip = IPv6Network(value) if in_id_filter: return ip @@ -122,12 +118,12 @@ def _parse_mac_address(value: str, _) -> MACAddress: return MACAddress(value) -def _parse_date_ts(value: Union[int, float], _) -> datetime: - return datetime.fromtimestamp(float(value), timezone.utc) +def _parse_date_ts(value: int | float, _) -> datetime: + return datetime.fromtimestamp(float(value), UTC) def _parse_date_string(value: str, _) -> datetime: - return datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC) magic_string_replacements = { @@ -153,7 +149,7 @@ def search_and_replace(query: dict[str, Any]) -> dict[str, Any]: """ if isinstance(query, dict): for key, value in query.items(): - if isinstance(value, (dict, list)): + if isinstance(value, dict | list): search_and_replace(value) elif isinstance(value, str): match = magic_regex.match(value) diff --git a/dp3/database/schema_cleaner.py b/dp3/database/schema_cleaner.py index 09ada81b..ae30e73a 100644 --- a/dp3/database/schema_cleaner.py +++ b/dp3/database/schema_cleaner.py @@ -1,9 +1,9 @@ import logging import time from collections import defaultdict -from datetime import datetime +from collections.abc import Callable +from datetime import UTC, datetime from logging import Logger -from typing import Callable import pymongo from pymongo import DeleteOne, InsertOne @@ -12,7 +12,6 @@ from dp3.common.attrspec import ID_REGEX, AttrSpecType, AttrType from dp3.common.config import HierarchicalDict, ModelSpec -from dp3.common.types import UTC from dp3.common.utils import batched # number of seconds to wait for the i-th attempt to reconnect after error diff --git a/dp3/database/snapshots.py b/dp3/database/snapshots.py index 1e8a5407..79b1e536 100644 --- a/dp3/database/snapshots.py +++ b/dp3/database/snapshots.py @@ -6,7 +6,7 @@ from collections.abc import Iterable from datetime import datetime, timedelta from ipaddress import IPv4Address, IPv6Address -from typing import Any, Optional, Union +from typing import Any import pymongo from bson import Binary @@ -125,11 +125,11 @@ def _binary_bucket_range(self, eid: bytes) -> dict: } @abc.abstractmethod - def _bucket_id(self, eid: AnyEidT, ctime: datetime) -> Union[str, Binary]: + def _bucket_id(self, eid: AnyEidT, ctime: datetime) -> str | Binary: """Returns `_id` for snapshot bucket document.""" @abc.abstractmethod - def _filter_from_bid(self, b_id: Union[bytes, str]) -> dict: + def _filter_from_bid(self, b_id: bytes | str) -> dict: """Returns filter for snapshots with same eid as given bucket document _id. Args: b_id: the _id of the snapshot bucket, type depends on etype's data type @@ -154,8 +154,8 @@ def get_latest_one(self, eid: AnyEidT) -> dict: def find_latest( self, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> Cursor: """Find latest snapshots of given `etype`. @@ -193,8 +193,8 @@ def find_latest( def count_latest( self, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> int: """Count latest snapshots of given `etype`. @@ -251,11 +251,11 @@ def _prepare_latest_query( def get_by_eid( self, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, - ) -> Union[Cursor, CommandCursor]: + ) -> Cursor | CommandCursor: """Get all (or filtered) snapshots of given `eid`. This method is useful for displaying `eid`'s history on web. @@ -360,8 +360,8 @@ def get_distinct_val_count(self, attr: str) -> dict[Any, int]: def _get_oversized( self, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, ) -> Cursor: @@ -789,8 +789,8 @@ def get_latest_one(self, entity_type: str, eid: AnyEidT) -> dict: def find_latest( self, entity_type: str, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> Cursor: """Find latest snapshots of given `etype`. @@ -825,8 +825,8 @@ def find_latest( def count_latest( self, entity_type: str, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> int: """Count latest snapshots of given `etype`. @@ -844,11 +844,11 @@ def get_by_eid( self, entity_type: str, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, - ) -> Union[Cursor, CommandCursor]: + ) -> Cursor | CommandCursor: """Get all (or filtered) snapshots of given `eid`. This method is useful for displaying `eid`'s history on web. diff --git a/dp3/history_management/history_manager.py b/dp3/history_management/history_manager.py index 15d78926..97f1f9e9 100644 --- a/dp3/history_management/history_manager.py +++ b/dp3/history_management/history_manager.py @@ -2,11 +2,10 @@ import json import logging import os -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path -from typing import Optional -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict from dp3.common.attrspec import ( AttrSpecObservations, @@ -16,7 +15,7 @@ ) from dp3.common.callback_registrar import CallbackRegistrar from dp3.common.config import CronExpression, PlatformConfig -from dp3.common.types import UTC, DP3Encoder, ParsedTimedelta +from dp3.common.types import DP3Encoder, ParsedTimedelta from dp3.common.utils import entity_expired from dp3.database.database import DatabaseError, EntityDatabase @@ -46,10 +45,10 @@ class DPArchivationConfig(BaseModel): schedule: CronExpression older_than: ParsedTimedelta - archive_dir: Optional[str] = None + archive_dir: str | None = None -class HistoryManagerConfig(BaseModel, extra=Extra.forbid): +class HistoryManagerConfig(BaseModel): """Configuration for history manager. Attributes: @@ -60,6 +59,8 @@ class HistoryManagerConfig(BaseModel, extra=Extra.forbid): datapoint_archivation: Configuration for datapoint archivation. """ + model_config = ConfigDict(extra="forbid") + aggregation_schedule: CronExpression datapoint_cleaning_schedule: CronExpression mark_datapoints_schedule: CronExpression @@ -237,7 +238,7 @@ def _reformat_dp(dp): def _get_raw_dps_summary( self, before: datetime - ) -> tuple[Optional[datetime], Optional[datetime], int]: + ) -> tuple[datetime | None, datetime | None, int]: date_ranges = [] for etype in self.model_spec.entities: summary = self.db.get_archive_summary(etype, before=before) diff --git a/dp3/history_management/telemetry.py b/dp3/history_management/telemetry.py index f4d0f17e..20c84420 100644 --- a/dp3/history_management/telemetry.py +++ b/dp3/history_management/telemetry.py @@ -1,7 +1,7 @@ import logging import threading import time -from datetime import datetime +from datetime import UTC, datetime import requests from pymongo import ASCENDING, UpdateOne @@ -10,7 +10,6 @@ from dp3.common.config import PlatformConfig from dp3.common.datapoint import DataPointObservationsBase, DataPointTimeseriesBase from dp3.common.task import DataPointTask -from dp3.common.types import UTC from dp3.database.database import EntityDatabase @@ -43,7 +42,7 @@ def note_latest_src_timestamp(self, task: DataPointTask): """Note the latest timestamp of each source in the task""" latest_timestamps = {} for dp in task.data_points: - has_timestamp = isinstance(dp, (DataPointObservationsBase, DataPointTimeseriesBase)) + has_timestamp = isinstance(dp, DataPointObservationsBase | DataPointTimeseriesBase) if dp.src is None or not has_timestamp: continue latest_timestamp = dp.t2 or dp.t1 diff --git a/dp3/scripts/add_hashes.py b/dp3/scripts/add_hashes.py index 87571bac..7104e8f4 100755 --- a/dp3/scripts/add_hashes.py +++ b/dp3/scripts/add_hashes.py @@ -29,7 +29,7 @@ model_spec = ModelSpec(config.get("db_entities")) # Connect to database -connection_conf = MongoConfig.parse_obj(config.get("database", {})) +connection_conf = MongoConfig.model_validate(config.get("database", {})) client = EntityDatabase.connect(connection_conf) client.admin.command("ping") diff --git a/dp3/scripts/datapoint_log_converter.py b/dp3/scripts/datapoint_log_converter.py index b5086f4b..a9dd1ec4 100755 --- a/dp3/scripts/datapoint_log_converter.py +++ b/dp3/scripts/datapoint_log_converter.py @@ -7,7 +7,8 @@ import logging import os import re -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import pandas as pd from dateutil.parser import parse as parsetime diff --git a/dp3/scripts/dummy_sender.py b/dp3/scripts/dummy_sender.py index b6d8cb2f..d9e9c856 100755 --- a/dp3/scripts/dummy_sender.py +++ b/dp3/scripts/dummy_sender.py @@ -6,7 +6,7 @@ import os import time from argparse import ArgumentParser -from datetime import datetime, timezone +from datetime import UTC, datetime from itertools import islice from queue import Queue from threading import Event, Thread @@ -14,8 +14,6 @@ import pandas as pd import requests -UTC = timezone.utc - def get_valid_path(parser, arg): if not os.path.exists(arg): diff --git a/dp3/snapshots/snapshooter.py b/dp3/snapshots/snapshooter.py index 717524ee..826cc2a0 100644 --- a/dp3/snapshots/snapshooter.py +++ b/dp3/snapshots/snapshooter.py @@ -19,8 +19,9 @@ import logging from collections import defaultdict -from datetime import datetime -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from datetime import UTC, datetime +from typing import Any import pymongo.errors from event_count_logger import DummyEventGroup @@ -42,7 +43,7 @@ parse_eids_from_cache, task_context, ) -from dp3.common.types import UTC, EventGroupType +from dp3.common.types import EventGroupType from dp3.common.utils import get_func_name from dp3.database.database import EntityDatabase from dp3.snapshots.snapshot_hooks import ( @@ -69,7 +70,7 @@ def __init__( task_queue_writer: TaskQueueWriter, platform_config: PlatformConfig, scheduler: Scheduler, - elog: Optional[EventGroupType] = None, + elog: EventGroupType | None = None, ) -> None: self.log = logging.getLogger("SnapShooter") @@ -185,7 +186,7 @@ def register_timeseries_hook( def register_correlation_hook( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -543,7 +544,7 @@ def make_snapshot(self, task: Snapshot): self.db.update_metadata(task.time, metadata={"linked_finished": True}, worker_id=0) @staticmethod - def _remove_record_from_value(spec: AttrSpecType, value: Union[dict, list[dict]]): + def _remove_record_from_value(spec: AttrSpecType, value: dict | list[dict]): if spec.is_iterable: for link_dict in value: if "record" in link_dict: @@ -574,7 +575,7 @@ def extend_master_record(etype, master_record, new_tasks: list[DataPointTask]): for datapoint in task.data_points: if datapoint.etype != etype: continue - dp_dict = datapoint.dict(include={"v", "t1", "t2", "c"}) + dp_dict = datapoint.model_dump(include={"v", "t1", "t2", "c"}) if datapoint.attr in master_record: master_record[datapoint.attr].append() else: @@ -647,7 +648,7 @@ def get_linked_entity_ids(self, entity_type: str, current_values: dict) -> set[t @staticmethod def _get_link_entity_ids( - spec: AttrSpecType, link_value: Union[list[dict], dict] + spec: AttrSpecType, link_value: list[dict] | dict ) -> set[tuple[str, str]]: if spec.is_iterable: return {(spec.relation_to, v["eid"]) for v in link_value} @@ -664,7 +665,7 @@ def link_loaded_entities(self, loaded_entities: dict): entity[attr] = [] val_conf = entity[f"{attr}#c"] pruned_conf = [] - for v, conf in zip(val, val_conf): + for v, conf in zip(val, val_conf, strict=False): if self._keep_link(loaded_entities, attr_spec, v): self._link_record(loaded_entities, attr_spec, v) entity[attr].append(v) @@ -682,7 +683,7 @@ def link_loaded_entities(self, loaded_entities: dict): del entity[key] def _keep_link( - self, loaded_entities: dict, attr_spec: AttrSpecType, val: Union[dict, list[dict]] + self, loaded_entities: dict, attr_spec: AttrSpecType, val: dict | list[dict] ) -> bool: if self.config.keep_empty: return True @@ -693,7 +694,7 @@ def _keep_link( return loaded_entities.get((attr_spec.relation_to, val["eid"])) is not None @staticmethod - def _link_record(loaded_entities: dict, attr_spec: AttrSpecType, val: Union[dict, list[dict]]): + def _link_record(loaded_entities: dict, attr_spec: AttrSpecType, val: dict | list[dict]): if attr_spec.is_iterable: for link_dict in val: link_dict["record"] = loaded_entities.get( diff --git a/dp3/snapshots/snapshot_hooks.py b/dp3/snapshots/snapshot_hooks.py index dcf437f5..2ffe3c1e 100644 --- a/dp3/snapshots/snapshot_hooks.py +++ b/dp3/snapshots/snapshot_hooks.py @@ -4,9 +4,8 @@ import logging from collections import defaultdict -from collections.abc import Hashable +from collections.abc import Callable, Hashable from dataclasses import dataclass, field -from typing import Callable, Union from dp3.common.attrspec import AttrType from dp3.common.config import ModelSpec @@ -84,7 +83,7 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT def register( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], diff --git a/dp3/task_processing/task_executor.py b/dp3/task_processing/task_executor.py index 92199a67..ff8a487d 100644 --- a/dp3/task_processing/task_executor.py +++ b/dp3/task_processing/task_executor.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable from event_count_logger import DummyEventGroup diff --git a/dp3/task_processing/task_hooks.py b/dp3/task_processing/task_hooks.py index 4438d78e..9e015f37 100644 --- a/dp3/task_processing/task_hooks.py +++ b/dp3/task_processing/task_hooks.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable from dp3.common.attrspec import AttrType from dp3.common.config import ModelSpec diff --git a/dp3/task_processing/task_queue.py b/dp3/task_processing/task_queue.py index 6ddd1718..f6f2f17a 100644 --- a/dp3/task_processing/task_queue.py +++ b/dp3/task_processing/task_queue.py @@ -36,7 +36,8 @@ import logging import threading import time -from typing import Callable, Union +from collections.abc import Callable +from typing import Literal import amqpstorm @@ -94,7 +95,7 @@ class RobustAMQPConnection: host, port, virtual_host, username, password """ - def __init__(self, rabbit_config: dict = None) -> None: + def __init__(self, rabbit_config: dict | None = None) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config self.log = logging.getLogger("RobustAMQPConnection") self.conn_params = { @@ -104,8 +105,8 @@ def __init__(self, rabbit_config: dict = None) -> None: "username": rabbit_config.get("username", "guest"), "password": rabbit_config.get("password", "guest"), } - self.connection: amqpstorm.Connection = None - self.channel: amqpstorm.Channel = None + self.connection: amqpstorm.Connection | None = None + self.channel: amqpstorm.Channel | None = None self._connection_id = 0 def __del__(self): @@ -133,9 +134,10 @@ def connect(self) -> None: # This was a repeated attempt, print success message with ERROR level self.log.error("... it's OK now, we're successfully connected!") - self.channel = self.connection.channel() - self.channel.confirm_deliveries() - self.channel.basic.qos(PREFETCH_COUNT) + channel = self.connection.channel() + channel.confirm_deliveries() + channel.basic.qos(PREFETCH_COUNT) + self.channel = channel break except amqpstorm.AMQPError as e: sleep_time = RECONNECT_DELAYS[min(attempts, len(RECONNECT_DELAYS)) - 1] @@ -152,7 +154,7 @@ def disconnect(self) -> None: self.connection = None self.channel = None - def check_queue_existence(self, queue_name: str) -> bool: + def check_queue_existence(self, queue_name: str | None) -> bool: if queue_name is None: return True assert self.channel is not None, "not connected" @@ -191,10 +193,10 @@ def __init__( self, app_name: str, workers: int = 1, - rabbit_config: dict = None, - exchange: str = None, - priority_exchange: str = None, - parent_logger: logging.Logger = None, + rabbit_config: dict | None = None, + exchange: str | None = None, + priority_exchange: str | None = None, + parent_logger: logging.Logger | None = None, ) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config assert isinstance(workers, int) and workers >= 1, "count of workers must be positive number" @@ -360,10 +362,10 @@ def __init__( parse_task: Callable[[str], Task], app_name: str, worker_index: int = 0, - rabbit_config: dict = None, - queue: str = None, - priority_queue: Union[str, bool] = None, - parent_logger: logging.Logger = None, + rabbit_config: dict | None = None, + queue: str | None = None, + priority_queue: str | Literal[False] | None = None, + parent_logger: logging.Logger | None = None, ) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config assert callable(callback), "callback must be callable object" @@ -391,14 +393,14 @@ def __init__( priority_queue = DEFAULT_PRIORITY_QUEUE.format(app_name, worker_index) elif priority_queue is False: priority_queue = None - self.queue_name = queue - self.priority_queue_name = priority_queue + self.queue_name: str = queue + self.priority_queue_name: str | None = priority_queue self.worker_index = worker_index self.running = False - self._consuming_thread = None - self._processing_thread = None + self._consuming_thread: threading.Thread | None = None + self._processing_thread: threading.Thread | None = None # Receive messages into 2 temporary queues # (max length should be equal to prefetch_count set in RabbitMQReader) @@ -490,11 +492,12 @@ def ack(self, msg_tag: tuple[int, int]) -> bool: Returns: Whether the message was acknowledged successfully and can be processed further. """ - conn_id, msg_tag = msg_tag + assert self.channel is not None, "not connected" + conn_id, _delivery_tag = msg_tag if conn_id != self._connection_id: return False try: - self.channel.basic.ack(delivery_tag=msg_tag) + self.channel.basic.ack(delivery_tag=_delivery_tag) except amqpstorm.AMQPChannelError as why: self.log.error("Channel error while acknowledging message: %s", why) self.reconnect() @@ -503,6 +506,7 @@ def ack(self, msg_tag: tuple[int, int]) -> bool: def _consuming_thread_func(self): # Register consumers and start consuming loop, reconnect on error + assert self.channel is not None, "not connected" while self.running: try: # Register consumers on both queues @@ -587,8 +591,8 @@ def watchdog(self): Register to be called periodically by scheduler. """ - proc = self._processing_thread.is_alive() - cons = self._consuming_thread.is_alive() + proc = self._processing_thread is not None and self._processing_thread.is_alive() + cons = self._consuming_thread is not None and self._consuming_thread.is_alive() if not proc or not cons: self.log.error( @@ -599,7 +603,8 @@ def watchdog(self): self._stop_consuming_thread() self._stop_processing_thread() - self.channel.close() + if self.channel is not None: + self.channel.close() self.channel = None self.cache.clear() self.cache_pri.clear() @@ -609,16 +614,17 @@ def watchdog(self): def _stop_consuming_thread(self) -> None: if self._consuming_thread: - if self._consuming_thread.is_alive: + if self._consuming_thread.is_alive(): # if not connected, no problem with contextlib.suppress(amqpstorm.AMQPError): - self.channel.stop_consuming() + if self.channel is not None: + self.channel.stop_consuming() self._consuming_thread.join() self._consuming_thread = None def _stop_processing_thread(self) -> None: if self._processing_thread: - if self._processing_thread.is_alive: + if self._processing_thread.is_alive(): self.running = False # tell processing thread to stop self.cache_full.set() # break potential wait() for data self._processing_thread.join() diff --git a/dp3/template/app/docker/python/Dockerfile b/dp3/template/app/docker/python/Dockerfile index f01df3ba..81c6589f 100644 --- a/dp3/template/app/docker/python/Dockerfile +++ b/dp3/template/app/docker/python/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 # Base interpreter with installed requirements -FROM python:3.9-slim as base +FROM python:3.11-slim as base RUN apt-get update; apt-get install -y git # Install requirements diff --git a/dp3/testing/case.py b/dp3/testing/case.py index bd6aec81..23fff28c 100644 --- a/dp3/testing/case.py +++ b/dp3/testing/case.py @@ -2,16 +2,15 @@ import copy import unittest -from collections.abc import Iterable, Mapping, Sequence -from datetime import datetime -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from collections.abc import Callable, Iterable, Mapping, Sequence +from datetime import UTC, datetime +from typing import Any, Generic, TypeVar from dp3.common.attrspec import AttrType from dp3.common.base_module import BaseModule from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig from dp3.common.datapoint import DataPointBase from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC from dp3.common.utils import get_func_name from dp3.testing.assertions import ModuleAssertions from dp3.testing.config import ( @@ -34,11 +33,11 @@ class DP3ModuleTestCase(ModuleAssertions, unittest.TestCase, Generic[ModuleT]): ``config_dir`` explicitly when they need a fixed fixture config. """ - config_dir: Optional[str] = None + config_dir: str | None = None config_env_var: str = CONFIG_DIR_ENV module_class: type[ModuleT] - module_name: Optional[str] = None - module_config: Optional[dict] = None + module_name: str | None = None + module_config: dict | None = None app_name: str = "test" process_index: int = 0 num_processes: int = 1 @@ -78,7 +77,7 @@ def get_module_config(self) -> dict: return copy.deepcopy(self.module_config) return copy.deepcopy(get_module_config(self.config, self.get_module_name())) - def get_module_name(self) -> Optional[str]: + def get_module_name(self) -> str | None: if self.module_name is not None: return self.module_name return self.module_class.__module__.split(".")[-1] @@ -99,9 +98,9 @@ def make_task( self, etype: str, eid: Any, - data_points: Optional[list[Union[dict, DataPointBase]]] = None, - tags: Optional[list] = None, - ttl_tokens: Optional[dict] = None, + data_points: list[dict | DataPointBase] | None = None, + tags: list | None = None, + ttl_tokens: dict | None = None, delete: bool = False, ) -> DataPointTask: with task_context(self.model_spec): @@ -142,8 +141,8 @@ def make_observation_datapoint( attr: str, v: Any, src: str = "test", - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, c: float = 1.0, **fields, ) -> DataPointBase: @@ -160,8 +159,8 @@ def make_timeseries_datapoint( attr: str, v: Mapping[str, Sequence[Any]], src: str = "test", - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, **fields, ) -> DataPointBase: """Create a validated timeseries datapoint. @@ -190,7 +189,7 @@ def make_timeseries_datapoint( return self._make_datapoint(etype, eid, attr, values, src=src, **data) @staticmethod - def _infer_timeseries_t1(attr_spec, values: Mapping[str, Sequence[Any]]) -> Optional[datetime]: + def _infer_timeseries_t1(attr_spec, values: Mapping[str, Sequence[Any]]) -> datetime | None: if attr_spec.timeseries_type == "irregular" and values.get("time"): return values["time"][0] if attr_spec.timeseries_type == "irregular_intervals" and values.get("time_first"): @@ -208,13 +207,13 @@ def run_task_hooks(self, hook_type: str, task: DataPointTask) -> None: self.registrar.run_task_hooks(hook_type, task) def run_allow_entity_creation( - self, entity: str, eid: Any, task: Optional[DataPointTask] = None + self, entity: str, eid: Any, task: DataPointTask | None = None ) -> bool: task = task or self._make_synthetic_task(entity, eid) return self.registrar.run_allow_entity_creation(entity, eid, task) def run_on_entity_creation( - self, entity: str, eid: Any, task: Optional[DataPointTask] = None + self, entity: str, eid: Any, task: DataPointTask | None = None ) -> list[DataPointTask]: task = task or self._make_synthetic_task(entity, eid) return self.registrar.run_on_entity_creation(entity, eid, task) @@ -230,24 +229,24 @@ def run_correlation_hooks( self, entity_type: str, record: dict, - master_record: Optional[dict] = None, + master_record: dict | None = None, ) -> list[DataPointTask]: return self.registrar.run_correlation_hooks(entity_type, record, master_record) def run_periodic_update( - self, entity_type: str, eid: Any, master_record: dict, hook_id: Optional[str] = None + self, entity_type: str, eid: Any, master_record: dict, hook_id: str | None = None ) -> list[DataPointTask]: return self.registrar.run_periodic_update(entity_type, eid, master_record, hook_id) def run_periodic_eid_update( - self, entity_type: str, eid: Any, hook_id: Optional[str] = None + self, entity_type: str, eid: Any, hook_id: str | None = None ) -> list[DataPointTask]: return self.registrar.run_periodic_eid_update(entity_type, eid, hook_id) - def run_scheduler_job(self, job: Union[int, str, Callable, HookRegistration]): + def run_scheduler_job(self, job: int | str | Callable | HookRegistration): return self.registrar.run_scheduler_job(job) - def registered(self, kind: Optional[str] = None, **fields) -> list[HookRegistration]: + def registered(self, kind: str | None = None, **fields) -> list[HookRegistration]: """Return registrations matching ``kind`` and the supplied registration fields.""" return [ registration @@ -305,7 +304,7 @@ def assert_scheduler_registered(self, **fields) -> HookRegistration: assertSchedulerRegistered = assert_scheduler_registered def _registration_matches( - self, registration: HookRegistration, kind: Optional[str], fields: dict[str, Any] + self, registration: HookRegistration, kind: str | None, fields: dict[str, Any] ) -> bool: if kind is not None and registration.kind != kind: return False diff --git a/dp3/testing/config.py b/dp3/testing/config.py index 5ed9a9d3..db9bef0d 100644 --- a/dp3/testing/config.py +++ b/dp3/testing/config.py @@ -1,14 +1,13 @@ """Configuration helpers for DP3 module tests.""" import os -from typing import Optional from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig, read_config_dir CONFIG_DIR_ENV = "DP3_CONFIG_DIR" -def resolve_config_dir(config_dir: Optional[str] = None, env_var: str = CONFIG_DIR_ENV) -> str: +def resolve_config_dir(config_dir: str | None = None, env_var: str = CONFIG_DIR_ENV) -> str: """Return an absolute DP3 config directory path. Explicit ``config_dir`` values take precedence. If no explicit path is supplied, the path is @@ -23,9 +22,7 @@ def resolve_config_dir(config_dir: Optional[str] = None, env_var: str = CONFIG_D return os.path.abspath(resolved) -def load_config( - config_dir: Optional[str] = None, env_var: str = CONFIG_DIR_ENV -) -> HierarchicalDict: +def load_config(config_dir: str | None = None, env_var: str = CONFIG_DIR_ENV) -> HierarchicalDict: """Load a DP3 config directory for module tests.""" return read_config_dir(resolve_config_dir(config_dir, env_var), recursive=True) @@ -38,7 +35,7 @@ def build_model_spec(config: HierarchicalDict) -> ModelSpec: def build_platform_config( config: HierarchicalDict, model_spec: ModelSpec, - config_dir: Optional[str] = None, + config_dir: str | None = None, *, app_name: str = "test", process_index: int = 0, @@ -57,7 +54,7 @@ def build_platform_config( ) -def get_module_config(config: HierarchicalDict, module_name: Optional[str]) -> dict: +def get_module_config(config: HierarchicalDict, module_name: str | None) -> dict: """Return module-specific config from loaded app config.""" if module_name is None: return {} diff --git a/dp3/testing/registrar.py b/dp3/testing/registrar.py index fa50f2d1..3d9fa902 100644 --- a/dp3/testing/registrar.py +++ b/dp3/testing/registrar.py @@ -4,10 +4,11 @@ import logging import warnings from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass, field from datetime import timedelta from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any from apscheduler.triggers.cron import CronTrigger from event_count_logger import DummyEventGroup @@ -40,12 +41,12 @@ class HookRegistration: kind: str hook: Callable - entity: Optional[str] = None - attr: Optional[str] = None - hook_type: Optional[str] = None - hook_id: Optional[str] = None - entity_type: Optional[str] = None - attr_type: Optional[str] = None + entity: str | None = None + attr: str | None = None + hook_type: str | None = None + hook_id: str | None = None + entity_type: str | None = None + attr_type: str | None = None depends_on: list[list[str]] = field(default_factory=list) may_change: list[list[str]] = field(default_factory=list) refresh: Any = None @@ -60,7 +61,7 @@ class TestCallbackRegistrar: def __init__( self, model_spec: ModelSpec, - log: Optional[logging.Logger] = None, + log: logging.Logger | None = None, update_batch_period: Any = None, ): self.model_spec = model_spec @@ -92,16 +93,16 @@ def scheduler_register( # noqa: PLR0913 self, func: Callable, *, - func_args: Union[list, tuple] = None, + func_args: list | tuple = None, func_kwargs: dict = None, - year: Union[int, str] = None, - month: Union[int, str] = None, - day: Union[int, str] = None, - week: Union[int, str] = None, - day_of_week: Union[int, str] = None, - hour: Union[int, str] = None, - minute: Union[int, str] = None, - second: Union[int, str] = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, timezone: str = "UTC", misfire_grace_time: int = 1, ) -> int: @@ -368,7 +369,7 @@ def run_correlation_hooks( self, entity_type: str, record: dict, - master_record: Optional[dict] = None, + master_record: dict | None = None, ) -> list[DataPointTask]: eid = self._assert_record_eid(record) return self.run_correlation_hooks_for_entities( @@ -376,7 +377,7 @@ def run_correlation_hooks( ) def run_correlation_hooks_for_entities( - self, entities: dict[tuple[str, Any], dict], master_records: Optional[dict] = None + self, entities: dict[tuple[str, Any], dict], master_records: dict | None = None ) -> list[DataPointTask]: master_records = master_records or {} for entity_type, _ in entities: @@ -415,7 +416,7 @@ def run_periodic_update( entity_type: str, eid: Any, master_record: dict, - hook_id: Optional[str] = None, + hook_id: str | None = None, ) -> list[DataPointTask]: hooks = self._matching_update_hooks(self._periodic_record_hooks, entity_type, hook_id) tasks: list[DataPointTask] = [] @@ -426,7 +427,7 @@ def run_periodic_update( return tasks def run_periodic_eid_update( - self, entity_type: str, eid: Any, hook_id: Optional[str] = None + self, entity_type: str, eid: Any, hook_id: str | None = None ) -> list[DataPointTask]: hooks = self._matching_update_hooks(self._periodic_eid_hooks, entity_type, hook_id) tasks: list[DataPointTask] = [] @@ -436,9 +437,7 @@ def run_periodic_eid_update( tasks.extend(hook_tasks) return tasks - def get_scheduler_job( - self, job: Union[int, str, Callable, HookRegistration] - ) -> HookRegistration: + def get_scheduler_job(self, job: int | str | Callable | HookRegistration) -> HookRegistration: """Return a registered scheduler job by id, callable, or callable name.""" if isinstance(job, int): for reg in self._scheduler_jobs: @@ -457,7 +456,7 @@ def get_scheduler_job( raise ValueError(f"Multiple scheduler jobs match {job!r}.") return matches[0] - def run_scheduler_job(self, job: Union[int, str, Callable, HookRegistration]): + def run_scheduler_job(self, job: int | str | Callable | HookRegistration): reg = self.get_scheduler_job(job) return reg.hook(*reg.extra["func_args"], **reg.extra["func_kwargs"]) @@ -536,7 +535,7 @@ def _validate_periodic_hook( thread_id: UpdateThreadId = (period_seconds, entity_type, eid_only) return get_update_thread_hooks(update_thread_hooks, hook_id, thread_id) - def _update_batch_period_seconds(self) -> Optional[float]: + def _update_batch_period_seconds(self) -> float | None: if self.update_batch_period is None: return None return self._period_seconds(self.update_batch_period) @@ -556,7 +555,7 @@ def _run_no_arg_hooks(self, hooks: list[HookRegistration]) -> list[DataPointTask return tasks @staticmethod - def _matching_update_hooks(hooks: dict, entity_type: str, hook_id: Optional[str]): + def _matching_update_hooks(hooks: dict, entity_type: str, hook_id: str | None): matches = [] for (_, etype, _), thread_hooks in hooks.items(): if etype != entity_type: @@ -569,7 +568,7 @@ def _matching_update_hooks(hooks: dict, entity_type: str, hook_id: Optional[str] return matches -def _callable_matches(func: Callable, expected: Union[str, Callable]) -> bool: +def _callable_matches(func: Callable, expected: str | Callable) -> bool: if callable(expected): return func == expected func_name = get_func_name(func) diff --git a/dp3/worker.py b/dp3/worker.py index 722abbb2..179d1c1c 100755 --- a/dp3/worker.py +++ b/dp3/worker.py @@ -18,6 +18,7 @@ from dp3.common.callback_registrar import CallbackRegistrar, reload_module_config from dp3.common.config import PlatformConfig from dp3.common.control import Control, ControlAction, refresh_on_entity_creation +from dp3.common.utils import suppress_dependency_loggers from dp3.core.collector import GarbageCollector from dp3.core.link_manager import LinkManager from dp3.core.updater import Updater @@ -122,10 +123,7 @@ def main(app_name: str, config_dir: str, process_index: int, verbose: bool) -> N ) log = logging.getLogger() - # Disable INFO and DEBUG messages from some libraries - logging.getLogger("requests").setLevel(logging.WARNING) - logging.getLogger("urllib3").setLevel(logging.WARNING) - logging.getLogger("amqpstorm").setLevel(logging.WARNING) + suppress_dependency_loggers() ############################################## # Load configuration diff --git a/pyproject.toml b/pyproject.toml index 0cd5d32c..f7a78460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Intended Audience :: Developers", ] -requires-python = ">=3.9" +requires-python = ">=3.11" dynamic = ["version", "dependencies", "optional-dependencies"] [tool.setuptools_scm] @@ -64,13 +64,13 @@ scripts = { file = ["requirements.scripts.txt"] } ### Black Formatting ################################################################### [tool.black] -target-version = ["py39"] +target-version = ["py311"] line-length = 100 extend-exclude = "/(install|docker)/" ### Ruff Code Linting ################################################################## [tool.ruff] -target-version = "py39" +target-version = "py311" extend-exclude = ["install", "docker"] line-length = 100 show-fixes = true diff --git a/requirements.scripts.txt b/requirements.scripts.txt index 74983ca4..3f4736c5 100644 --- a/requirements.scripts.txt +++ b/requirements.scripts.txt @@ -1,2 +1,2 @@ numpy>=1.23.0 -pandas~=1.4.3 +pandas~=2.2 diff --git a/requirements.txt b/requirements.txt index 2809ff2d..faa473d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -AMQPStorm~=2.7.2 -apscheduler~=3.10.0 -argcomplete~=3.6.0 +AMQPStorm~=2.7 +apscheduler~=3.10 +argcomplete~=3.6 event-count-logger>=1.1 fastapi>=0.109.1 pydantic>=2.4.0 -pymongo~=4.6.3 +pymongo~=4.6 python-dateutil~=2.8 pyyaml~=6.0 -requests~=2.32.0 +requests~=2.32 uvicorn>=0.22.0 diff --git a/tests/test_api/common.py b/tests/test_api/common.py index 9f766fb0..df1b9302 100644 --- a/tests/test_api/common.py +++ b/tests/test_api/common.py @@ -3,7 +3,8 @@ import sys import time import unittest -from typing import Callable, TypeVar +from collections.abc import Callable +from typing import TypeVar import requests from pydantic import BaseModel diff --git a/tests/test_api/test_01_datapoints.py b/tests/test_api/test_01_datapoints.py index c7e73c36..4e92ddc5 100644 --- a/tests/test_api/test_01_datapoints.py +++ b/tests/test_api/test_01_datapoints.py @@ -1,13 +1,11 @@ import json import sys -from datetime import datetime +from datetime import UTC, datetime from typing import Any import common from common import ACCEPTED_ERROR_CODES -from dp3.common.types import UTC - class PushDatapoints(common.APITest): def test_invalid_payload(self): diff --git a/tests/test_api/test_get_entity_eid_data.py b/tests/test_api/test_get_entity_eid_data.py index b7e1ac9e..818273b8 100644 --- a/tests/test_api/test_get_entity_eid_data.py +++ b/tests/test_api/test_get_entity_eid_data.py @@ -1,11 +1,10 @@ import sys -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta import common from pydantic import RootModel from dp3.api.internal.entity_response_models import EntityEidData, EntityEidMasterRecord -from dp3.common.types import UTC DATAPOINT_COUNT = 6 diff --git a/tests/test_api/test_raw.py b/tests/test_api/test_raw.py index 1b9ca919..2c9a5170 100644 --- a/tests/test_api/test_raw.py +++ b/tests/test_api/test_raw.py @@ -1,11 +1,11 @@ import datetime import json import sys +from datetime import UTC import common from dp3.api.internal.entity_response_models import EntityRawDataPage -from dp3.common.types import UTC class RawDatapointsIntegration(common.APITest): diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py index d8fd82e4..f31a38b3 100644 --- a/tests/test_api/test_snapshots.py +++ b/tests/test_api/test_snapshots.py @@ -1,12 +1,12 @@ import datetime import json import sys +from datetime import UTC from time import sleep import common from dp3.api.internal.entity_response_models import EntityEidData, EntityEidSnapshots -from dp3.common.types import UTC class SnapshotIntegration(common.APITest): diff --git a/tests/test_api/test_telemetry.py b/tests/test_api/test_telemetry.py index 171079b6..71e1c480 100644 --- a/tests/test_api/test_telemetry.py +++ b/tests/test_api/test_telemetry.py @@ -1,12 +1,11 @@ import datetime import json import sys +from datetime import UTC from time import sleep import common -from dp3.common.types import UTC - class TelemetryEndpoints(common.APITest): @classmethod diff --git a/tests/test_common/test_magic.py b/tests/test_common/test_magic.py index 70f658d8..6bf12925 100644 --- a/tests/test_common/test_magic.py +++ b/tests/test_common/test_magic.py @@ -1,7 +1,7 @@ """Test the search & replace functionality for snapshot generic filter endpoint""" import unittest -from datetime import datetime, timezone +from datetime import UTC, datetime from ipaddress import IPv4Address, IPv6Address from bson import Binary @@ -34,17 +34,17 @@ def test_replace_int(self): def test_replace_date(self): query = {"date_attr": "$$Date{2021-01-01T00:00:00Z}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) def test_replace_date_ts(self): query = {"date_attr": "$$DateTs{1609459200}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) # Test with float value query = {"date_attr": "$$DateTs{1609459200.5}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, 500000, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, 500000, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) def test_replace_ipv4_prefix(self): diff --git a/tests/test_common/test_module_testing.py b/tests/test_common/test_module_testing.py index f3b89548..197b5f14 100644 --- a/tests/test_common/test_module_testing.py +++ b/tests/test_common/test_module_testing.py @@ -1,13 +1,12 @@ import os import warnings -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from pydantic import ValidationError from dp3.common.base_module import BaseModule from dp3.common.state import SharedFlag from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC from dp3.testing import DP3ModuleTestCase diff --git a/tests/test_common/test_snapshots.py b/tests/test_common/test_snapshots.py index ae39eed8..89590f27 100644 --- a/tests/test_common/test_snapshots.py +++ b/tests/test_common/test_snapshots.py @@ -3,14 +3,14 @@ import logging import os import unittest +from collections.abc import Callable +from datetime import UTC from functools import partial, update_wrapper -from typing import Callable, Optional from event_count_logger import DummyEventGroup from dp3.common.config import ModelSpec, PlatformConfig, read_config_dir from dp3.common.task import Task -from dp3.common.types import UTC from dp3.snapshots.snapshooter import SnapShooter from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer @@ -115,7 +115,7 @@ def register_on_entity_delete( self, f_one: Callable[[str, str], None], f_many: Callable[[str, list[str]], None] ): ... - def get_module_cache(self, override_called_id: Optional[str] = None): + def get_module_cache(self, override_called_id: str | None = None): return self.module_cache def save_snapshot(self, etype: str, snapshot: dict, time: datetime): diff --git a/tests/test_common/test_types.py b/tests/test_common/test_types.py index 1c0e0fb0..3235293a 100644 --- a/tests/test_common/test_types.py +++ b/tests/test_common/test_types.py @@ -1,5 +1,5 @@ import unittest -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta, timezone from pydantic import BaseModel, Field @@ -18,7 +18,7 @@ class _T2Model(BaseModel): class TestAwareDatetime(unittest.TestCase): def test_naive_datetime_defaults_to_utc(self): model = _AwareModel(dt="2024-01-01T10:00:00") - self.assertEqual(model.dt.tzinfo, timezone.utc) + self.assertEqual(model.dt.tzinfo, UTC) def test_existing_timezone_is_preserved(self): cest_timezone = timezone(timedelta(hours=2), "CEST") @@ -29,6 +29,6 @@ def test_existing_timezone_is_preserved(self): def test_t2_datetime_inherits_timezone_when_missing(self): model = _T2Model(t1="2024-01-01T00:00:00") self.assertIsNotNone(model.t2) - self.assertEqual(model.t1.tzinfo, timezone.utc) - self.assertEqual(model.t2.tzinfo, timezone.utc) + self.assertEqual(model.t1.tzinfo, UTC) + self.assertEqual(model.t2.tzinfo, UTC) self.assertEqual(model.t2, model.t1) diff --git a/tests/test_example/dps_gen.py b/tests/test_example/dps_gen.py index 9001a162..83885434 100644 --- a/tests/test_example/dps_gen.py +++ b/tests/test_example/dps_gen.py @@ -2,12 +2,12 @@ import json import random -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta class TimeContainer: def __init__(self): - self.time = datetime.now(timezone.utc) - timedelta(days=4) + self.time = datetime.now(UTC) - timedelta(days=4) def add_minutes(self, minutes: int): self.time += timedelta(minutes=minutes) diff --git a/tests/test_example/dps_gen_realtime.py b/tests/test_example/dps_gen_realtime.py index 9f701e87..1d5979c3 100644 --- a/tests/test_example/dps_gen_realtime.py +++ b/tests/test_example/dps_gen_realtime.py @@ -2,14 +2,12 @@ import random from argparse import ArgumentParser -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from sys import stderr from time import sleep import requests -UTC = timezone.utc - def random_initial_location(): latitude = random.uniform(39.0, 41.0)