Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/itk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
run: bash run_itk.sh
working-directory: itk
env:
A2A_SAMPLES_REVISION: itk-v.016-alpha
A2A_SAMPLES_REVISION: itk-v.02-alpha
2 changes: 1 addition & 1 deletion itk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re

Example:
```
export A2A_SAMPLES_REVISION=itk-v.015-alpha
export A2A_SAMPLES_REVISION=itk-v.02-alpha
```

### 2. Execute Tests
Expand Down
127 changes: 79 additions & 48 deletions itk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.events import EventQueue
from a2a.server.routes import (
create_agent_card_routes,
create_jsonrpc_routes,
create_rest_routes,
)
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks import (
TaskUpdater,
BasePushNotificationSender,
InMemoryPushNotificationConfigStore,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.context import ServerCallContext
from a2a.types import a2a_pb2_grpc
from a2a.types.a2a_pb2 import (
AgentCapabilities,
Expand All @@ -35,11 +43,12 @@
Task,
TaskState,
TaskStatus,
TaskPushNotificationConfig,
)
from a2a.utils import TransportProtocol


log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
log_level = getattr(logging, log_level_str, logging.INFO)
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,7 +115,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
)


async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
async def handle_call_agent(
call: instruction_pb2.CallAgent,
) -> list[str]:
"""Handles the CallAgent instruction by invoking another agent."""
logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)

Expand All @@ -131,36 +142,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
selected_transport == TransportProtocol.GRPC
)

if call.HasField('push_notification'):
url = call.push_notification.url
if not url:
raise ValueError('URL not specified in push_notification behavior')
if not url.startswith(('http://', 'https://')):
url = f'http://{url}'
config.push_notification_config = TaskPushNotificationConfig(
url=f'{url}/notifications',
token='itk-token', # noqa: S106
)

try:
client = await create_client(call.agent_card_uri, client_config=config)
client = await create_client(
call.agent_card_uri,
client_config=config,
)

# Wrap nested instruction
async with client:
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)

results: list[str] = []
async for event in client.send_message(request):
# Event is StreamResponse
logger.info('Event: %s', event)
stream_resp = event

message = None
if stream_resp.HasField('message'):
message = stream_resp.message
elif stream_resp.HasField(
'task'
) and stream_resp.task.status.HasField('message'):
message = stream_resp.task.status.message
elif stream_resp.HasField(
'status_update'
) and stream_resp.status_update.status.HasField('message'):
message = stream_resp.status_update.status.message

if message:
results.extend(
part.text for part in message.parts if part.text
)
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)

results = []
async for event in client.send_message(request):
# Event is streaming response and task
logger.info('Event: %s', event)
stream_resp = event

message = None
if stream_resp.HasField('message'):
message = stream_resp.message
elif stream_resp.HasField(
'task'
) and stream_resp.task.status.HasField('message'):
message = stream_resp.task.status.message
elif stream_resp.HasField(
'status_update'
) and stream_resp.status_update.status.HasField('message'):
message = stream_resp.status_update.status.message

if message:
results.extend(part.text for part in message.parts if part.text)

except Exception as e:
logger.exception('Failed to call outbound agent')
Expand All @@ -171,7 +193,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
return results


async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
async def handle_instruction(
inst: instruction_pb2.Instruction,
) -> list[str]:
"""Recursively handles instructions."""
if inst.HasField('call_agent'):
return await handle_call_agent(inst.call_agent)
Expand Down Expand Up @@ -303,33 +327,40 @@ async def main_async(http_port: int, grpc_port: int) -> None:
description='Python agent using SDK 1.0.',
version='1.0.0',
capabilities=AgentCapabilities(
streaming=True,
push_notifications=True,
extended_agent_card=True,
streaming=True, push_notifications=True, extended_agent_card=True
),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
supported_interfaces=interfaces,
)

task_store = InMemoryTaskStore()
push_config_store = InMemoryPushNotificationConfigStore()
push_sender = BasePushNotificationSender(
httpx_client=httpx.AsyncClient(),
config_store=push_config_store,
context=ServerCallContext(),
)

handler = DefaultRequestHandler(
agent_executor=V10AgentExecutor(),
task_store=task_store,
agent_card=agent_card,
task_store=task_store,
queue_manager=InMemoryQueueManager(),
push_config_store=push_config_store,
push_sender=push_sender,
)

handler_extended = DefaultRequestHandler(
agent_executor=V10AgentExecutor(),
task_store=task_store,
agent_card=agent_card,
task_store=task_store,
queue_manager=InMemoryQueueManager(),
push_config_store=push_config_store,
push_sender=push_sender,
extended_agent_card=agent_card,
)

app = FastAPI()

agent_card_routes = create_agent_card_routes(
agent_card=agent_card, card_url='/.well-known/agent-card.json'
)
Expand All @@ -338,15 +369,16 @@ async def main_async(http_port: int, grpc_port: int) -> None:
rpc_url='/',
enable_v0_3_compat=True,
)
app.mount(
'/jsonrpc',
FastAPI(routes=jsonrpc_routes + agent_card_routes),
)

rest_routes = create_rest_routes(
request_handler=handler,
enable_v0_3_compat=True,
)

app = FastAPI()
app.mount(
'/jsonrpc',
FastAPI(routes=jsonrpc_routes + agent_card_routes),
)
app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes))

server = grpc.aio.server()
Expand All @@ -365,9 +397,8 @@ async def main_async(http_port: int, grpc_port: int) -> None:
grpc_port,
)

uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower()
config = uvicorn.Config(
app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level
app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower()
)
uvicorn_server = uvicorn.Server(config)

Expand Down
28 changes: 24 additions & 4 deletions itk/run_itk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -119,30 +119,50 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
"protocols": ["jsonrpc", "grpc"]
"protocols": ["jsonrpc", "grpc"],
"behavior": "send_message"
},
{
"name": "Star Topology (No Go v03) - HTTP_JSON",
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
"protocols": ["http_json"]
"protocols": ["http_json"],
"behavior": "send_message"
},
{
"name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)",
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
"protocols": ["jsonrpc", "grpc"],
"streaming": true
"streaming": true,
"behavior": "send_message"
},
{
"name": "Star Topology (No Go v03) - HTTP_JSON (Streaming)",
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
"protocols": ["http_json"],
"streaming": true
"streaming": true,
"behavior": "send_message"
},
{
"name": "Push Notification Test - JSONRPC & GRPC",
"sdks": ["current", "python_v10", "python_v03", "go_v03"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
"protocols": ["jsonrpc", "grpc"],
"behavior": "push_notification"
},
{
"name": "Push Notification Test - HTTP_JSON",
"sdks": ["current", "python_v10", "python_v03"],
"traversal": "euler",
"edges": ["0->1", "0->2", "1->0", "2->0"],
"protocols": ["http_json"],
"behavior": "push_notification"
}
]
}')
Expand Down
Loading