diff --git a/src/a2a/helpers/__init__.py b/src/a2a/helpers/__init__.py index c42429d43..a4a0401e7 100644 --- a/src/a2a/helpers/__init__.py +++ b/src/a2a/helpers/__init__.py @@ -3,32 +3,64 @@ from a2a.helpers.agent_card import display_agent_card from a2a.helpers.proto_helpers import ( get_artifact_text, + get_data_parts, get_message_text, + get_raw_parts, get_stream_response_text, get_text_parts, + get_url_parts, new_artifact, + new_data_artifact, + new_data_artifact_update_event, + new_data_message, + new_data_part, new_message, + new_raw_artifact, + new_raw_artifact_update_event, + new_raw_message, + new_raw_part, new_task, new_task_from_user_message, new_text_artifact, new_text_artifact_update_event, new_text_message, + new_text_part, new_text_status_update_event, + new_url_artifact, + new_url_artifact_update_event, + new_url_message, + new_url_part, ) __all__ = [ 'display_agent_card', 'get_artifact_text', + 'get_data_parts', 'get_message_text', + 'get_raw_parts', 'get_stream_response_text', 'get_text_parts', + 'get_url_parts', 'new_artifact', + 'new_data_artifact', + 'new_data_artifact_update_event', + 'new_data_message', + 'new_data_part', 'new_message', + 'new_raw_artifact', + 'new_raw_artifact_update_event', + 'new_raw_message', + 'new_raw_part', 'new_task', 'new_task_from_user_message', 'new_text_artifact', 'new_text_artifact_update_event', 'new_text_message', + 'new_text_part', 'new_text_status_update_event', + 'new_url_artifact', + 'new_url_artifact_update_event', + 'new_url_message', + 'new_url_part', ] diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py index 6cc6350b6..a32413cf0 100644 --- a/src/a2a/helpers/proto_helpers.py +++ b/src/a2a/helpers/proto_helpers.py @@ -401,6 +401,45 @@ def get_text_parts(parts: Sequence[Part]) -> list[str]: return [part.text for part in parts if part.HasField('text')] +def get_data_parts(parts: Sequence[Part]) -> list[Any]: + """Extracts structured data from all data Parts. + + Each returned element is the Python object obtained from the + ``google.protobuf.Value`` stored in the Part. + + Args: + parts: A sequence of ``Part`` objects. + + Returns: + A list of deserialized data values from any data Parts found. + """ + return [part.data for part in parts if part.HasField('data')] + + +def get_raw_parts(parts: Sequence[Part]) -> list[bytes]: + """Extracts raw bytes content from all raw Parts. + + Args: + parts: A sequence of ``Part`` objects. + + Returns: + A list of ``bytes`` from any raw Parts found. + """ + return [part.raw for part in parts if part.HasField('raw')] + + +def get_url_parts(parts: Sequence[Part]) -> list[str]: + """Extracts URL strings from all URL Parts. + + Args: + parts: A sequence of ``Part`` objects. + + Returns: + A list of URL strings from any URL Parts found. + """ + return [part.url for part in parts if part.HasField('url')] + + # --- Event & Stream Helpers --- @@ -447,6 +486,129 @@ def new_text_artifact_update_event( # noqa: PLR0913 ) +def new_data_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + data: Any, + media_type: str | None = None, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single data artifact. + + Args: + task_id: The task ID. + context_id: The context ID. + name: The name of the artifact. + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content. + append: Whether to append to the existing artifact. + last_chunk: Whether this is the last chunk. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + A TaskArtifactUpdateEvent with a single data artifact. + """ + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_data_artifact( + name=name, + data=data, + media_type=media_type, + artifact_id=artifact_id, + ), + append=append, + last_chunk=last_chunk, + ) + + +def new_raw_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single raw bytes artifact. + + Args: + task_id: The task ID. + context_id: The context ID. + name: The name of the artifact. + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + append: Whether to append to the existing artifact. + last_chunk: Whether this is the last chunk. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + A TaskArtifactUpdateEvent with a single raw artifact. + """ + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_raw_artifact( + name=name, + raw=raw, + media_type=media_type, + filename=filename, + artifact_id=artifact_id, + ), + append=append, + last_chunk=last_chunk, + ) + + +def new_url_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + url: str, + media_type: str | None = None, + filename: str | None = None, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single URL artifact. + + Args: + task_id: The task ID. + context_id: The context ID. + name: The name of the artifact. + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + append: Whether to append to the existing artifact. + last_chunk: Whether this is the last chunk. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + A TaskArtifactUpdateEvent with a single URL artifact. + """ + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_url_artifact( + name=name, + url=url, + media_type=media_type, + filename=filename, + artifact_id=artifact_id, + ), + append=append, + last_chunk=last_chunk, + ) + + def get_stream_response_text( response: StreamResponse, delimiter: str = '\n' ) -> str: diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py index 8fb68dbc2..f05e6bbc3 100644 --- a/tests/helpers/test_proto_helpers.py +++ b/tests/helpers/test_proto_helpers.py @@ -4,15 +4,20 @@ from a2a.helpers.proto_helpers import ( get_artifact_text, + get_data_parts, get_message_text, + get_raw_parts, get_stream_response_text, get_text_parts, + get_url_parts, new_artifact, new_data_artifact, + new_data_artifact_update_event, new_data_message, new_data_part, new_message, new_raw_artifact, + new_raw_artifact_update_event, new_raw_message, new_raw_part, new_task, @@ -23,6 +28,7 @@ new_text_part, new_text_status_update_event, new_url_artifact, + new_url_artifact_update_event, new_url_message, new_url_part, ) @@ -443,3 +449,165 @@ def test_get_stream_response_text_artifact_update() -> None: def test_get_stream_response_text_empty() -> None: resp = StreamResponse() assert get_stream_response_text(resp) == '' + + +# --- Part Extractor Tests --- + + +def test_get_data_parts() -> None: + parts = [ + new_data_part({'key': 'value'}), + Part(text='hello'), + new_data_part([1, 2]), + ] + result = get_data_parts(parts) + assert len(result) == 2 + assert result[0].struct_value.fields['key'].string_value == 'value' + assert result[1].list_value.values[0].number_value == 1 + + +def test_get_data_parts_empty() -> None: + parts = [Part(text='hello'), Part(url='http://example.com')] + assert get_data_parts(parts) == [] + + +def test_get_raw_parts() -> None: + parts = [ + Part(raw=b'\x89PNG'), + Part(text='hello'), + Part(raw=b'\xff\xd8'), + ] + result = get_raw_parts(parts) + assert result == [b'\x89PNG', b'\xff\xd8'] + + +def test_get_raw_parts_empty() -> None: + parts = [Part(text='hello')] + assert get_raw_parts(parts) == [] + + +def test_get_url_parts() -> None: + parts = [ + Part(url='https://example.com/a.png'), + Part(text='hello'), + Part(url='https://example.com/b.pdf'), + ] + result = get_url_parts(parts) + assert result == [ + 'https://example.com/a.png', + 'https://example.com/b.pdf', + ] + + +def test_get_url_parts_empty() -> None: + parts = [Part(text='hello')] + assert get_url_parts(parts) == [] + + +# --- Non-text Artifact Update Event Tests --- + + +def test_new_data_artifact_update_event() -> None: + event = new_data_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='result', + data={'score': 0.95}, + media_type='application/json', + append=True, + last_chunk=True, + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'result' + assert event.artifact.artifact_id == 'art1' + assert event.artifact.parts[0].HasField('data') + assert ( + event.artifact.parts[0].data.struct_value.fields['score'].number_value + == 0.95 + ) + assert event.artifact.parts[0].media_type == 'application/json' + assert event.append is True + assert event.last_chunk is True + + +def test_new_data_artifact_update_event_minimal() -> None: + event = new_data_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='result', + data=[1, 2, 3], + ) + assert event.artifact.parts[0].HasField('data') + assert event.append is False + assert event.last_chunk is False + assert event.artifact.artifact_id != '' + + +def test_new_raw_artifact_update_event() -> None: + event = new_raw_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='screenshot', + raw=b'\x89PNG', + media_type='image/png', + filename='screen.png', + append=False, + last_chunk=True, + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'screenshot' + assert event.artifact.artifact_id == 'art1' + assert event.artifact.parts[0].HasField('raw') + assert event.artifact.parts[0].raw == b'\x89PNG' + assert event.artifact.parts[0].media_type == 'image/png' + assert event.artifact.parts[0].filename == 'screen.png' + assert event.last_chunk is True + + +def test_new_raw_artifact_update_event_minimal() -> None: + event = new_raw_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='file', + raw=b'data', + ) + assert event.artifact.parts[0].raw == b'data' + assert event.artifact.artifact_id != '' + + +def test_new_url_artifact_update_event() -> None: + event = new_url_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='report', + url='https://example.com/report.pdf', + media_type='application/pdf', + filename='report.pdf', + append=True, + last_chunk=False, + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'report' + assert event.artifact.artifact_id == 'art1' + assert event.artifact.parts[0].HasField('url') + assert event.artifact.parts[0].url == 'https://example.com/report.pdf' + assert event.artifact.parts[0].media_type == 'application/pdf' + assert event.artifact.parts[0].filename == 'report.pdf' + assert event.append is True + + +def test_new_url_artifact_update_event_minimal() -> None: + event = new_url_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='img', + url='https://example.com/img.png', + ) + assert event.artifact.parts[0].url == 'https://example.com/img.png' + assert event.artifact.artifact_id != ''