diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py index 79e1f739d..6cc6350b6 100644 --- a/src/a2a/helpers/proto_helpers.py +++ b/src/a2a/helpers/proto_helpers.py @@ -3,6 +3,10 @@ import uuid from collections.abc import Sequence +from typing import Any + +from google.protobuf import struct_pb2 +from google.protobuf.json_format import ParseDict from a2a.types.a2a_pb2 import ( Artifact, @@ -23,9 +27,9 @@ def new_message( parts: list[Part], - role: Role = Role.ROLE_AGENT, context_id: str | None = None, task_id: str | None = None, + role: Role = Role.ROLE_AGENT, ) -> Message: """Creates a new message containing a list of Parts.""" return Message( @@ -39,16 +43,17 @@ def new_message( def new_text_message( text: str, + media_type: str | None = None, context_id: str | None = None, task_id: str | None = None, role: Role = Role.ROLE_AGENT, ) -> Message: """Creates a new message containing a single text Part.""" return new_message( - parts=[Part(text=text)], - role=role, - task_id=task_id, + parts=[new_text_part(text, media_type=media_type)], context_id=context_id, + task_id=task_id, + role=role, ) @@ -57,6 +62,91 @@ def get_message_text(message: Message, delimiter: str = '\n') -> str: return delimiter.join(get_text_parts(message.parts)) +def new_data_message( + data: Any, + media_type: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single data Part. + + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single data Part. + """ + return new_message( + parts=[new_data_part(data, media_type=media_type)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_raw_message( # noqa: PLR0913 + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single raw bytes Part. + + Args: + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single raw Part. + """ + return new_message( + parts=[new_raw_part(raw, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + +def new_url_message( # noqa: PLR0913 + url: str, + media_type: str | None = None, + filename: str | None = None, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single URL Part. + + Args: + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + context_id: Optional context ID. + task_id: Optional task ID. + role: The role of the message sender (default: ROLE_AGENT). + + Returns: + A Message with a single URL Part. + """ + return new_message( + parts=[new_url_part(url, media_type=media_type, filename=filename)], + context_id=context_id, + task_id=task_id, + role=role, + ) + + # --- Artifact Helpers --- @@ -78,12 +168,98 @@ def new_artifact( def new_text_artifact( name: str, text: str, + media_type: str | None = None, description: str | None = None, artifact_id: str | None = None, ) -> Artifact: """Creates a new Artifact object containing only a single text Part.""" return new_artifact( - [Part(text=text)], + [new_text_part(text, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_data_artifact( + name: str, + data: Any, + media_type: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single data Part. + + Args: + name: The name of the artifact. + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single data Part. + """ + return new_artifact( + [new_data_part(data, media_type=media_type)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_raw_artifact( # noqa: PLR0913 + name: str, + raw: bytes, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single raw bytes Part. + + Args: + name: The name of the artifact. + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single raw Part. + """ + return new_artifact( + [new_raw_part(raw, media_type=media_type, filename=filename)], + name, + description, + artifact_id=artifact_id, + ) + + +def new_url_artifact( # noqa: PLR0913 + name: str, + url: str, + media_type: str | None = None, + filename: str | None = None, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single URL Part. + + Args: + name: The name of the artifact. + url: The URL pointing to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + description: Optional description. + artifact_id: Optional artifact ID (auto-generated if not provided). + + Returns: + An Artifact with a single URL Part. + """ + return new_artifact( + [new_url_part(url, media_type=media_type, filename=filename)], name, description, artifact_id=artifact_id, @@ -141,6 +317,85 @@ def new_task( # --- Part Helpers --- +def new_text_part( + text: str, + media_type: str | None = None, +) -> Part: + """Creates a Part with text content. + + Args: + text: The text content. + media_type: Optional MIME type (e.g. 'text/plain', 'text/markdown'). + + Returns: + A Part with the text field set. + """ + return Part(text=text, media_type=media_type or '') + + +def new_data_part( + data: Any, + media_type: str | None = None, +) -> Part: + """Creates a Part with structured data (google.protobuf.Value). + + Args: + data: JSON-serializable data to embed (dict, list, str, etc.). + media_type: Optional MIME type of the part content (e.g., "text/plain", "application/json", "image/png"). + + Returns: + A Part with the data field set. + """ + return Part( + data=ParseDict(data, struct_pb2.Value()), + media_type=media_type or '', + ) + + +def new_raw_part( + raw: bytes, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with raw bytes content. + + Args: + raw: The raw bytes content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the raw field set. + """ + return Part( + raw=raw, + media_type=media_type or '', + filename=filename or '', + ) + + +def new_url_part( + url: str, + media_type: str | None = None, + filename: str | None = None, +) -> Part: + """Creates a Part with a URL pointing to file content. + + Args: + url: The URL to the file content. + media_type: Optional MIME type (e.g. 'image/png'). + filename: Optional filename. + + Returns: + A Part with the url field set. + """ + return Part( + url=url, + media_type=media_type or '', + filename=filename or '', + ) + + def get_text_parts(parts: Sequence[Part]) -> list[str]: """Extracts text content from all text Parts.""" return [part.text for part in parts if part.HasField('text')] diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py index a4f6498ab..8fb68dbc2 100644 --- a/tests/helpers/test_proto_helpers.py +++ b/tests/helpers/test_proto_helpers.py @@ -1,37 +1,49 @@ """Tests for proto helpers.""" import pytest + from a2a.helpers.proto_helpers import ( - new_message, - new_text_message, + get_artifact_text, get_message_text, + get_stream_response_text, + get_text_parts, new_artifact, - new_text_artifact, - get_artifact_text, - new_task_from_user_message, + new_data_artifact, + new_data_message, + new_data_part, + new_message, + new_raw_artifact, + new_raw_message, + new_raw_part, new_task, - get_text_parts, - new_text_status_update_event, + new_task_from_user_message, + new_text_artifact, new_text_artifact_update_event, - get_stream_response_text, + new_text_message, + new_text_part, + new_text_status_update_event, + new_url_artifact, + new_url_message, + new_url_part, ) from a2a.types.a2a_pb2 import ( + Artifact, + Message, Part, Role, - Message, - Artifact, + StreamResponse, Task, TaskState, - StreamResponse, ) + # --- Message Helpers Tests --- def test_new_message() -> None: parts = [Part(text='hello')] msg = new_message( - parts=parts, role=Role.ROLE_USER, context_id='ctx1', task_id='task1' + parts, context_id='ctx1', task_id='task1', role=Role.ROLE_USER ) assert msg.role == Role.ROLE_USER assert msg.parts == parts @@ -42,11 +54,74 @@ def test_new_message() -> None: def test_new_text_message() -> None: msg = new_text_message( - text='hello', context_id='ctx1', task_id='task1', role=Role.ROLE_USER + 'hello', + media_type='text/plain', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, ) assert msg.role == Role.ROLE_USER assert len(msg.parts) == 1 assert msg.parts[0].text == 'hello' + assert msg.parts[0].media_type == 'text/plain' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_data_message() -> None: + msg = new_data_message( + data={'key': 'value'}, + media_type='application/json', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('data') + assert msg.parts[0].data.struct_value.fields['key'].string_value == 'value' + assert msg.parts[0].media_type == 'application/json' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_raw_message() -> None: + msg = new_raw_message( + b'\x89PNG', + media_type='image/png', + filename='img.png', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('raw') + assert msg.parts[0].raw == b'\x89PNG' + assert msg.parts[0].media_type == 'image/png' + assert msg.parts[0].filename == 'img.png' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_url_message() -> None: + msg = new_url_message( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + context_id='ctx1', + task_id='task1', + role=Role.ROLE_USER, + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].HasField('url') + assert msg.parts[0].url == 'https://example.com/file.pdf' + assert msg.parts[0].media_type == 'application/pdf' + assert msg.parts[0].filename == 'file.pdf' assert msg.context_id == 'ctx1' assert msg.task_id == 'task1' assert msg.message_id != '' @@ -90,6 +165,74 @@ def test_new_text_artifact_with_id() -> None: assert art.artifact_id == 'art1' +def test_new_data_artifact() -> None: + art = new_data_artifact( + name='result', data={'score': 1.0}, description='desc' + ) + assert art.name == 'result' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].HasField('data') + assert art.parts[0].data.struct_value.fields['score'].number_value == 1.0 + assert art.artifact_id != '' + + +def test_new_data_artifact_with_id() -> None: + art = new_data_artifact(name='result', data={'x': 'y'}, artifact_id='art1') + assert art.artifact_id == 'art1' + assert art.parts[0].data.struct_value.fields['x'].string_value == 'y' + + +def test_new_raw_artifact() -> None: + art = new_raw_artifact( + name='screenshot', + raw=b'\x89PNG', + media_type='image/png', + filename='screen.png', + description='desc', + artifact_id='art1', + ) + assert art.name == 'screenshot' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('raw') + assert art.parts[0].raw == b'\x89PNG' + assert art.parts[0].media_type == 'image/png' + assert art.parts[0].filename == 'screen.png' + + +def test_new_raw_artifact_minimal() -> None: + art = new_raw_artifact(name='file', raw=b'data') + assert art.parts[0].raw == b'data' + assert art.artifact_id != '' + + +def test_new_url_artifact() -> None: + art = new_url_artifact( + name='report', + url='https://example.com/report.pdf', + media_type='application/pdf', + filename='report.pdf', + description='desc', + artifact_id='art1', + ) + assert art.name == 'report' + assert art.description == 'desc' + assert art.artifact_id == 'art1' + assert len(art.parts) == 1 + assert art.parts[0].HasField('url') + assert art.parts[0].url == 'https://example.com/report.pdf' + assert art.parts[0].media_type == 'application/pdf' + assert art.parts[0].filename == 'report.pdf' + + +def test_new_url_artifact_minimal() -> None: + art = new_url_artifact(name='img', url='https://example.com/img.png') + assert art.parts[0].url == 'https://example.com/img.png' + assert art.artifact_id != '' + + def test_get_artifact_text() -> None: art = Artifact(parts=[Part(text='hello'), Part(text='world')]) assert get_artifact_text(art) == 'hello\nworld' @@ -149,6 +292,78 @@ def test_get_text_parts() -> None: assert get_text_parts(parts) == ['hello', 'world'] +def test_new_text_part() -> None: + part = new_text_part('hello') + assert part.HasField('text') + assert part.text == 'hello' + assert part.media_type == '' + + +def test_new_text_part_with_media_type() -> None: + part = new_text_part('# Hello', media_type='text/markdown') + assert part.HasField('text') + assert part.text == '# Hello' + assert part.media_type == 'text/markdown' + + +def test_new_data_part_from_dict() -> None: + part = new_data_part({'key': 'value', 'count': 42}) + assert part.HasField('data') + assert part.data.struct_value.fields['key'].string_value == 'value' + assert part.data.struct_value.fields['count'].number_value == 42 + assert part.media_type == '' + + +def test_new_data_part_with_media_type() -> None: + part = new_data_part({'key': 'value'}, media_type='application/json') + assert part.HasField('data') + assert part.media_type == 'application/json' + + +def test_new_data_part_from_list() -> None: + part = new_data_part([1, 2, 3]) + assert part.HasField('data') + assert part.data.list_value.values[0].number_value == 1 + assert part.data.list_value.values[1].number_value == 2 + assert part.data.list_value.values[2].number_value == 3 + + +def test_new_raw_part() -> None: + part = new_raw_part(b'\x89PNG', media_type='image/png', filename='img.png') + assert part.HasField('raw') + assert part.raw == b'\x89PNG' + assert part.media_type == 'image/png' + assert part.filename == 'img.png' + + +def test_new_raw_part_minimal() -> None: + part = new_raw_part(b'data') + assert part.HasField('raw') + assert part.raw == b'data' + assert part.media_type == '' + assert part.filename == '' + + +def test_new_url_part() -> None: + part = new_url_part( + 'https://example.com/file.pdf', + media_type='application/pdf', + filename='file.pdf', + ) + assert part.HasField('url') + assert part.url == 'https://example.com/file.pdf' + assert part.media_type == 'application/pdf' + assert part.filename == 'file.pdf' + + +def test_new_url_part_minimal() -> None: + part = new_url_part('https://example.com/img.png') + assert part.HasField('url') + assert part.url == 'https://example.com/img.png' + assert part.media_type == '' + assert part.filename == '' + + # --- Event & Stream Helpers Tests ---