Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ dependencies = [
"taskiq_dependencies>=1.3.1,<2",
"anyio>=4",
"packaging>=19",
"izulu==0.50.0",
"aiohttp>=3",
]

Expand Down
82 changes: 82 additions & 0 deletions taskiq/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Minimal exception templating used by taskiq exceptions."""

import sys
from string import Formatter

if sys.version_info >= (3, 11):
from typing import dataclass_transform
else:
from typing_extensions import dataclass_transform


@dataclass_transform(
eq_default=False,
order_default=False,
kw_only_default=True,
frozen_default=False,
)
class Error(Exception):
"""Base templated exception compatible with taskiq needs."""

__template__ = "Exception occurred"

@classmethod
def _collect_annotations(cls) -> dict[str, object]:
"""Collect all annotated fields from the class hierarchy."""
annotations: dict[str, object] = {}
for class_ in reversed(cls.__mro__):
annotations.update(getattr(class_, "__annotations__", {}))
return annotations

@classmethod
def _format_fields(cls, names: set[str]) -> str:
"""Format field names in a deterministic error message."""
return ", ".join(f"'{name}'" for name in sorted(names))

@classmethod
def _template_fields(cls, template: str) -> set[str]:
"""Extract plain field names used in a format template."""
fields: set[str] = set()
for _, field_name, _, _ in Formatter().parse(template):
if not field_name:
continue
field = field_name.split(".", maxsplit=1)[0].split("[", maxsplit=1)[0]
fields.add(field)
return fields

def __init__(self, **kwargs: object) -> None:
annotations = self._collect_annotations()
undeclared = set(kwargs) - set(annotations)
if undeclared:
raise TypeError(f"Undeclared arguments: {self._format_fields(undeclared)}")

missing = {
field
for field in annotations
if field not in kwargs and not hasattr(type(self), field)
}
if missing:
raise TypeError(f"Missing arguments: {self._format_fields(missing)}")

for key, value in kwargs.items():
setattr(self, key, value)

template = getattr(type(self), "__template__", self.__template__)
missing_annotations = self._template_fields(template) - set(annotations)
if missing_annotations:
raise ValueError(
f"Fields must be annotated: {self._format_fields(missing_annotations)}",
)

payload = {field: getattr(self, field) for field in annotations}
super().__init__(template.format(**payload))

def __repr__(self) -> str:
"""Represent exception with all declared fields."""
annotations = self._collect_annotations()
module = type(self).__module__
qualname = type(self).__qualname__
if not annotations:
return f"{module}.{qualname}()"
args = ", ".join(f"{field}={getattr(self, field)!r}" for field in annotations)
return f"{module}.{qualname}({args})"
4 changes: 2 additions & 2 deletions taskiq/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any

from izulu import root
from taskiq.error import Error


class TaskiqError(root.Error):
class TaskiqError(Error):
"""Base exception for all errors."""

__template__ = "Exception occurred"
Expand Down
88 changes: 88 additions & 0 deletions tests/test_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest

from taskiq.error import Error
from taskiq.exceptions import SecurityError, TaskiqResultTimeoutError


class SimpleError(Error):
__template__ = "simple"


class ValueTemplateError(Error):
__template__ = "value={value}"
value: int


class DefaultValueTemplateError(Error):
__template__ = "value={value}"
value: int = 10


class BaseError(Error):
__template__ = "base={base}, child={child}"
base: int = 1


class ChildError(BaseError):
child: str


class MissingAnnotationError(Error):
__template__ = "value={value}"


class IndexedTemplateError(Error):
__template__ = "{payload[key]}"
payload: dict[str, str]


def test_simple_error_message_and_repr() -> None:
error = SimpleError()
assert str(error) == "simple"
assert error.args == ("simple",)
assert repr(error).endswith(".SimpleError()")


def test_template_with_required_value() -> None:
error = ValueTemplateError(value=3)
assert str(error) == "value=3"
assert repr(error).endswith(".ValueTemplateError(value=3)")


def test_missing_argument_raises_type_error() -> None:
with pytest.raises(TypeError, match="Missing arguments: 'value'"):
ValueTemplateError() # type: ignore[call-arg]


def test_undeclared_argument_raises_type_error() -> None:
with pytest.raises(TypeError, match="Undeclared arguments: 'extra'"):
ValueTemplateError(value=1, extra=2) # type: ignore[call-arg]


def test_default_value_is_used_without_kwargs() -> None:
error = DefaultValueTemplateError()
assert str(error) == "value=10"
assert repr(error).endswith(".DefaultValueTemplateError(value=10)")


def test_annotations_are_collected_from_inheritance() -> None:
error = ChildError(child="ok")
assert str(error) == "base=1, child=ok"
assert repr(error).endswith(".ChildError(base=1, child='ok')")


def test_template_fields_must_be_annotated() -> None:
with pytest.raises(ValueError, match="Fields must be annotated: 'value'"):
MissingAnnotationError()


def test_indexed_template_field_does_not_require_extra_annotation() -> None:
error = IndexedTemplateError(payload={"key": "value"})
assert str(error) == "value"


def test_taskiq_exceptions_use_error_base_correctly() -> None:
timeout_error = TaskiqResultTimeoutError(timeout=1.5)
security_error = SecurityError(description="boom")
assert str(timeout_error) == "Waiting for task results has timed out, timeout=1.5"
assert str(security_error) == "Security exception occurred: boom"
66 changes: 66 additions & 0 deletions tests/test_exceptions_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import re

import pytest

from taskiq import InMemoryBroker
from taskiq.brokers.shared_broker import AsyncSharedBroker
from taskiq.exceptions import (
SharedBrokerListenError,
SharedBrokerSendTaskError,
TaskBrokerMismatchError,
UnknownTaskError,
)
from taskiq.message import BrokerMessage


def _broker_message(task_name: str) -> BrokerMessage:
return BrokerMessage(
task_id="task-id",
task_name=task_name,
message=b"{}",
labels={},
)


async def test_inmemory_broker_raises_unknown_task_error() -> None:
broker = InMemoryBroker()

with pytest.raises(
UnknownTaskError,
match=re.escape(
"Cannot send unknown task to the queue, task name - missing.task",
),
):
await broker.kick(_broker_message("missing.task"))


async def test_shared_broker_raises_send_task_error() -> None:
broker = AsyncSharedBroker()

with pytest.raises(
SharedBrokerSendTaskError,
match="You cannot use kiq directly on shared task",
):
await broker.kick(_broker_message("any.task"))


async def test_shared_broker_raises_listen_error() -> None:
broker = AsyncSharedBroker()

with pytest.raises(SharedBrokerListenError, match="Shared broker cannot listen"):
await broker.listen()


def test_registering_task_in_another_broker_raises_mismatch_error() -> None:
first_broker = InMemoryBroker()
second_broker = InMemoryBroker()

@first_broker.task(task_name="test.task")
async def test_task() -> None:
return None

with pytest.raises(
TaskBrokerMismatchError,
match="Task already has a different broker",
):
second_broker._register_task(test_task.task_name, test_task)
11 changes: 0 additions & 11 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading