From aef2f30de680e86572e2877ee799d1f8ca09b4e9 Mon Sep 17 00:00:00 2001 From: Krzysztof Dziedzic Date: Tue, 21 Apr 2026 13:08:56 +0000 Subject: [PATCH] test: test push notifications in itk --- .github/workflows/itk.yaml | 2 +- itk/README.md | 2 +- itk/main.py | 127 +++++++++++++++++++++++-------------- itk/run_itk.sh | 28 ++++++-- 4 files changed, 105 insertions(+), 54 deletions(-) diff --git a/.github/workflows/itk.yaml b/.github/workflows/itk.yaml index feb9325e3..33d7585d6 100644 --- a/.github/workflows/itk.yaml +++ b/.github/workflows/itk.yaml @@ -31,4 +31,4 @@ jobs: run: bash run_itk.sh working-directory: itk env: - A2A_SAMPLES_REVISION: itk-v.016-alpha + A2A_SAMPLES_REVISION: itk-v.02-alpha diff --git a/itk/README.md b/itk/README.md index 9a82d0469..3044b37af 100644 --- a/itk/README.md +++ b/itk/README.md @@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re Example: ``` -export A2A_SAMPLES_REVISION=itk-v.015-alpha +export A2A_SAMPLES_REVISION=itk-v.02-alpha ``` ### 2. Execute Tests diff --git a/itk/main.py b/itk/main.py index 6792c540a..cc761d081 100644 --- a/itk/main.py +++ b/itk/main.py @@ -17,13 +17,21 @@ from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes -from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.events import EventQueue +from a2a.server.routes import ( + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler -from a2a.server.tasks import TaskUpdater +from a2a.server.tasks import ( + TaskUpdater, + BasePushNotificationSender, + InMemoryPushNotificationConfigStore, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.context import ServerCallContext from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -35,11 +43,12 @@ Task, TaskState, TaskStatus, + TaskPushNotificationConfig, ) from a2a.utils import TransportProtocol - -log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper() +log_level = getattr(logging, log_level_str, logging.INFO) logging.basicConfig(level=log_level) logger = logging.getLogger(__name__) @@ -106,7 +115,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message: ) -async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: +async def handle_call_agent( + call: instruction_pb2.CallAgent, +) -> list[str]: """Handles the CallAgent instruction by invoking another agent.""" logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport) @@ -131,36 +142,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: selected_transport == TransportProtocol.GRPC ) + if call.HasField('push_notification'): + url = call.push_notification.url + if not url: + raise ValueError('URL not specified in push_notification behavior') + if not url.startswith(('http://', 'https://')): + url = f'http://{url}' + config.push_notification_config = TaskPushNotificationConfig( + url=f'{url}/notifications', + token='itk-token', # noqa: S106 + ) + try: - client = await create_client(call.agent_card_uri, client_config=config) + client = await create_client( + call.agent_card_uri, + client_config=config, + ) # Wrap nested instruction - async with client: - nested_msg = wrap_instruction_to_request(call.instruction) - request = SendMessageRequest(message=nested_msg) - - results: list[str] = [] - async for event in client.send_message(request): - # Event is StreamResponse - logger.info('Event: %s', event) - stream_resp = event - - message = None - if stream_resp.HasField('message'): - message = stream_resp.message - elif stream_resp.HasField( - 'task' - ) and stream_resp.task.status.HasField('message'): - message = stream_resp.task.status.message - elif stream_resp.HasField( - 'status_update' - ) and stream_resp.status_update.status.HasField('message'): - message = stream_resp.status_update.status.message - - if message: - results.extend( - part.text for part in message.parts if part.text - ) + nested_msg = wrap_instruction_to_request(call.instruction) + request = SendMessageRequest(message=nested_msg) + + results = [] + async for event in client.send_message(request): + # Event is streaming response and task + logger.info('Event: %s', event) + stream_resp = event + + message = None + if stream_resp.HasField('message'): + message = stream_resp.message + elif stream_resp.HasField( + 'task' + ) and stream_resp.task.status.HasField('message'): + message = stream_resp.task.status.message + elif stream_resp.HasField( + 'status_update' + ) and stream_resp.status_update.status.HasField('message'): + message = stream_resp.status_update.status.message + + if message: + results.extend(part.text for part in message.parts if part.text) except Exception as e: logger.exception('Failed to call outbound agent') @@ -171,7 +193,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: return results -async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]: +async def handle_instruction( + inst: instruction_pb2.Instruction, +) -> list[str]: """Recursively handles instructions.""" if inst.HasField('call_agent'): return await handle_call_agent(inst.call_agent) @@ -303,9 +327,7 @@ async def main_async(http_port: int, grpc_port: int) -> None: description='Python agent using SDK 1.0.', version='1.0.0', capabilities=AgentCapabilities( - streaming=True, - push_notifications=True, - extended_agent_card=True, + streaming=True, push_notifications=True, extended_agent_card=True ), default_input_modes=['text/plain'], default_output_modes=['text/plain'], @@ -313,23 +335,32 @@ async def main_async(http_port: int, grpc_port: int) -> None: ) task_store = InMemoryTaskStore() + push_config_store = InMemoryPushNotificationConfigStore() + push_sender = BasePushNotificationSender( + httpx_client=httpx.AsyncClient(), + config_store=push_config_store, + context=ServerCallContext(), + ) + handler = DefaultRequestHandler( agent_executor=V10AgentExecutor(), - task_store=task_store, agent_card=agent_card, + task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, ) handler_extended = DefaultRequestHandler( agent_executor=V10AgentExecutor(), - task_store=task_store, agent_card=agent_card, + task_store=task_store, queue_manager=InMemoryQueueManager(), + push_config_store=push_config_store, + push_sender=push_sender, extended_agent_card=agent_card, ) - app = FastAPI() - agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) @@ -338,15 +369,16 @@ async def main_async(http_port: int, grpc_port: int) -> None: rpc_url='/', enable_v0_3_compat=True, ) - app.mount( - '/jsonrpc', - FastAPI(routes=jsonrpc_routes + agent_card_routes), - ) - rest_routes = create_rest_routes( request_handler=handler, enable_v0_3_compat=True, ) + + app = FastAPI() + app.mount( + '/jsonrpc', + FastAPI(routes=jsonrpc_routes + agent_card_routes), + ) app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes)) server = grpc.aio.server() @@ -365,9 +397,8 @@ async def main_async(http_port: int, grpc_port: int) -> None: grpc_port, ) - uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower() config = uvicorn.Config( - app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level + app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower() ) uvicorn_server = uvicorn.Server(config) diff --git a/itk/run_itk.sh b/itk/run_itk.sh index 2d9371c14..21736f171 100755 --- a/itk/run_itk.sh +++ b/itk/run_itk.sh @@ -119,14 +119,16 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], - "protocols": ["jsonrpc", "grpc"] + "protocols": ["jsonrpc", "grpc"], + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON", "sdks": ["current", "python_v10", "python_v03", "go_v10"], "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], - "protocols": ["http_json"] + "protocols": ["http_json"], + "behavior": "send_message" }, { "name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)", @@ -134,7 +136,8 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"], "protocols": ["jsonrpc", "grpc"], - "streaming": true + "streaming": true, + "behavior": "send_message" }, { "name": "Star Topology (No Go v03) - HTTP_JSON (Streaming)", @@ -142,7 +145,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \ "traversal": "euler", "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], "protocols": ["http_json"], - "streaming": true + "streaming": true, + "behavior": "send_message" + }, + { + "name": "Push Notification Test - JSONRPC & GRPC", + "sdks": ["current", "python_v10", "python_v03", "go_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"], + "protocols": ["jsonrpc", "grpc"], + "behavior": "push_notification" + }, + { + "name": "Push Notification Test - HTTP_JSON", + "sdks": ["current", "python_v10", "python_v03"], + "traversal": "euler", + "edges": ["0->1", "0->2", "1->0", "2->0"], + "protocols": ["http_json"], + "behavior": "push_notification" } ] }')