diff --git a/src/bedrock_agentcore/memory/integrations/strands/config.py b/src/bedrock_agentcore/memory/integrations/strands/config.py index 3fc8901b..a42e824e 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/config.py +++ b/src/bedrock_agentcore/memory/integrations/strands/config.py @@ -68,6 +68,22 @@ class AgentCoreMemoryConfig(BaseModel): persistence_mode: Controls what gets persisted to AgentCore Memory. FULL (default): persist everything. NONE: disable all persistence while keeping local state management and memory injection working. + async_mode: When True, the session manager registers async hook callbacks that + offload the per-turn boto3 calls (append_message, sync_agent, + retrieve_customer_context, and buffer flushes) to a thread via + asyncio.to_thread, keeping the asyncio event loop unblocked. Intended for + async agent runtimes (e.g. Agent.stream_async() in a WebSocket server). + Default is False (existing synchronous behavior, unchanged). + + Requires async invocation (stream_async / invoke_async). Sync agent() calls + will raise RuntimeError from Strands' hook registry because it refuses to + dispatch coroutine callbacks through the sync path. + + Note: this does NOT cover agent initialization. Strands disallows async + callbacks for AgentInitializedEvent, so the read_session / read_agent / + list_messages calls that run during Agent(...) construction still block + the calling thread. If that matters, construct the Agent off-loop + (e.g. `await asyncio.to_thread(Agent, ...)`). """ memory_id: str = Field(min_length=1) @@ -81,6 +97,7 @@ class AgentCoreMemoryConfig(BaseModel): default_metadata: Optional[Dict[str, Any]] = None metadata_provider: Optional[Callable[[], Dict[str, Any]]] = None persistence_mode: PersistenceMode = Field(default=PersistenceMode.FULL) + async_mode: bool = Field(default=False) @field_validator("default_metadata", mode="before") @classmethod diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 6bfaca9d..c4fe8540 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -1,5 +1,6 @@ """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.""" +import asyncio import json import logging import threading @@ -10,7 +11,18 @@ import boto3 from botocore.config import Config as BotocoreConfig +from strands.experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAgentInitializedEvent, + BidiMessageAddedEvent, +) +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from strands.hooks import AfterInvocationEvent, MessageAddedEvent +from strands.hooks.events import AgentInitializedEvent from strands.hooks.registry import HookRegistry from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository @@ -906,16 +918,85 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): def register_hooks(self, registry: HookRegistry, **kwargs) -> None: """Register additional hooks. + In sync mode (the default), delegates to the base class and adds the + retrieve_customer_context + batching callbacks synchronously, preserving + existing behavior exactly. + + In async mode, registers async callbacks that wrap every per-turn + boto3-backed operation (append_message, sync_agent, buffer flushes, + customer-context retrieval) with asyncio.to_thread, so the asyncio + event loop stays free while boto3 is blocking on the network. + + Note: AgentInitializedEvent cannot be async per Strands' HookRegistry, + so agent restoration (read_session / read_agent / list_messages) still + blocks the calling thread in async mode — see AgentCoreMemoryConfig + docstring for mitigations. + Args: registry (HookRegistry): The hook registry to register callbacks with. **kwargs: Additional keyword arguments. """ - RepositorySessionManager.register_hooks(self, registry, **kwargs) - registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + if not self.config.async_mode: + RepositorySessionManager.register_hooks(self, registry, **kwargs) + registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) + + # Only register AfterInvocationEvent hook when batching is enabled + if self.config.batch_size > 1: + registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + return + + # Async mode: register async callbacks that offload the existing sync + # methods to a worker thread via asyncio.to_thread. AgentInitializedEvent + # and BidiAgentInitializedEvent must stay sync (Strands disallows async + # callbacks for AgentInitializedEvent — see strands/hooks/registry.py:227). + logger.warning( + "AgentCoreMemorySessionManager async_mode=True: the agent must be invoked " + "via the async path (e.g. agent.stream_async(...) or agent.invoke_async(...)). " + "Sync invocation will raise RuntimeError from Strands' hook registry." + ) + + def _offload(method, *event_args): + """Build an async callback that offloads `method(*[a(event) for a in event_args])` to a thread. + + Each entry in `event_args` is a callable that extracts an argument from the event; + pass none for a zero-arg method. + """ + + async def _callback(event): + await asyncio.to_thread(method, *(extract(event) for extract in event_args)) + + return _callback + + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + + async def _on_message_added_persist(event: MessageAddedEvent) -> None: + await asyncio.to_thread(self.append_message, event.message, event.agent) + await asyncio.to_thread(self.sync_agent, event.agent) + + registry.add_callback(MessageAddedEvent, _on_message_added_persist) + registry.add_callback(AfterInvocationEvent, _offload(self.sync_agent, lambda e: e.agent)) + registry.add_callback(MessageAddedEvent, _offload(self.retrieve_customer_context, lambda e: e)) - # Only register AfterInvocationEvent hook when batching is enabled if self.config.batch_size > 1: - registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + registry.add_callback(AfterInvocationEvent, _offload(self._flush_messages)) + + # Register multi-agent callbacks so async-mode parity matches sync-mode + registry.add_callback(MultiAgentInitializedEvent, _offload(self.initialize_multi_agent, lambda e: e.source)) + registry.add_callback(AfterNodeCallEvent, _offload(self.sync_multi_agent, lambda e: e.source)) + registry.add_callback(AfterMultiAgentInvocationEvent, _offload(self.sync_multi_agent, lambda e: e.source)) + + # Register BidiAgent callbacks so async-mode parity matches sync-mode. + # BidiAgentInitializedEvent dispatches through invoke_callbacks (sync), + # so its callback must stay sync; the other two dispatch through + # invoke_callbacks_async, so async wrappers are safe. + registry.add_callback(BidiAgentInitializedEvent, lambda event: self.initialize_bidi_agent(event.agent)) + + async def _on_bidi_message_added(event: BidiMessageAddedEvent) -> None: + await asyncio.to_thread(self.append_bidi_message, event.message, event.agent) + await asyncio.to_thread(self.sync_bidi_agent, event.agent) + + registry.add_callback(BidiMessageAddedEvent, _on_bidi_message_added) + registry.add_callback(BidiAfterInvocationEvent, _offload(self.sync_bidi_agent, lambda e: e.agent)) @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: @@ -1071,6 +1152,7 @@ def _flush_agent_states_only(self) -> list[dict[str, Any]]: with self._agent_state_lock: agent_states_to_send = list(self._agent_state_buffer) + self._agent_state_buffer.clear() if not agent_states_to_send: return [] @@ -1101,11 +1183,10 @@ def _flush_agent_states_only(self) -> list[dict[str, Any]]: results.append(event) logger.debug("Flushed %d agent states for agent %s: %s", len(payloads), agent_id, event.get("eventId")) - # Clear agent state buffer only after ALL events succeed - with self._agent_state_lock: - self._agent_state_buffer.clear() - except Exception as e: + # Restore agent states to buffer so they aren't lost + with self._agent_state_lock: + self._agent_state_buffer.extend(agent_states_to_send) logger.error("Failed to flush agent states to AgentCore Memory: %s", e) raise SessionException(f"Failed to flush agent states: {e}") from e diff --git a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py index 2daca61b..8e26a0a2 100644 --- a/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py +++ b/tests/bedrock_agentcore/memory/integrations/strands/test_agentcore_memory_session_manager.py @@ -1,5 +1,7 @@ """Tests for AgentCoreMemorySessionManager.""" +import asyncio +import inspect import logging import time from datetime import datetime, timezone @@ -9,6 +11,16 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError from strands.agent.agent import Agent +from strands.experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAgentInitializedEvent, + BidiMessageAddedEvent, +) +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from strands.hooks import AfterInvocationEvent, MessageAddedEvent from strands.hooks.registry import HookRegistry from strands.types.exceptions import SessionException @@ -3580,3 +3592,196 @@ def test_retrieve_customer_context_works(self, mock_memory_client): mock_memory_client.retrieve_memories.assert_called_once() assert "" in mock_agent.messages[0]["content"][0]["text"] + + +class TestAsyncMode: + """Tests for async_mode: callbacks must not block the event loop.""" + + def test_async_mode_defaults_to_false(self, agentcore_config): + assert agentcore_config.async_mode is False + + def test_sync_mode_registers_sync_callbacks(self, mock_memory_client): + """async_mode=False: all MessageAddedEvent/AfterInvocationEvent callbacks are sync.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=False) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + for event_type in (MessageAddedEvent, AfterInvocationEvent): + for cb in registry.get_callbacks_for( + event_type(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + if event_type is MessageAddedEvent + else event_type(agent=Mock()) + ): + assert not inspect.iscoroutinefunction(cb), f"Sync mode leaked an async callback for {event_type}" + + def test_async_mode_registers_async_callbacks(self, mock_memory_client): + """async_mode=True: MessageAddedEvent and AfterInvocationEvent callbacks are coroutine functions.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + msg_callbacks = registry.get_callbacks_for( + MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + ) + assert msg_callbacks, "No MessageAddedEvent callbacks registered in async mode" + assert all(inspect.iscoroutinefunction(cb) for cb in msg_callbacks) + + after_callbacks = registry.get_callbacks_for(AfterInvocationEvent(agent=Mock())) + assert after_callbacks, "No AfterInvocationEvent callbacks registered in async mode" + assert all(inspect.iscoroutinefunction(cb) for cb in after_callbacks) + + async def test_async_mode_does_not_block_event_loop(self, mock_memory_client): + """The async hooks run boto3 on a worker thread, so the event loop can make progress concurrently.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + + # Simulate each sync session-manager method blocking on boto3. + def slow_append_message(message, agent, **kwargs): + time.sleep(0.2) + + def slow_sync_agent(agent, **kwargs): + time.sleep(0.2) + + manager.append_message = slow_append_message + manager.sync_agent = slow_sync_agent + + registry = HookRegistry() + manager.register_hooks(registry) + + persist_callbacks = [ + cb + for cb in registry.get_callbacks_for( + MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "x"}]}) + ) + if asyncio.iscoroutinefunction(cb) + ] + assert persist_callbacks + + event = MessageAddedEvent(agent=Mock(), message={"role": "user", "content": [{"text": "hello"}]}) + + # Ticker proves the event loop made progress while the hook awaited to_thread. + ticks = 0 + + async def ticker(): + nonlocal ticks + while True: + await asyncio.sleep(0.01) + ticks += 1 + + ticker_task = asyncio.create_task(ticker()) + try: + # Run the persist callback (append_message + sync_agent); both sleep 0.2s on a worker thread. + await persist_callbacks[0](event) + finally: + ticker_task.cancel() + + assert ticks > 5, f"Event loop was blocked; only {ticks} ticks recorded" + + async def test_async_mode_batching_registers_flush_callback(self, mock_memory_client): + """async_mode=True with batch_size>1: AfterInvocationEvent gets both sync_agent and flush callbacks.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", batch_size=5, async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + after_callbacks = list(registry.get_callbacks_for(AfterInvocationEvent(agent=Mock()))) + assert len(after_callbacks) == 2 + assert all(asyncio.iscoroutinefunction(cb) for cb in after_callbacks) + + def test_async_mode_registers_multi_agent_callbacks(self, mock_memory_client): + """async_mode=True: multi-agent events get async callbacks (parity with sync mode).""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + for event_type in (MultiAgentInitializedEvent, AfterNodeCallEvent, AfterMultiAgentInvocationEvent): + callbacks = registry._registered_callbacks.get(event_type, []) + assert callbacks, f"No callbacks registered for {event_type.__name__}" + assert all(asyncio.iscoroutinefunction(cb) for cb in callbacks) + + def test_async_mode_logs_sync_invocation_warning(self, mock_memory_client, caplog): + """async_mode=True emits a WARNING at register_hooks time pointing users to stream_async/invoke_async.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + + with caplog.at_level(logging.WARNING, logger="bedrock_agentcore.memory.integrations.strands.session_manager"): + manager.register_hooks(registry) + + assert any("async_mode=True" in rec.message and "stream_async" in rec.message for rec in caplog.records) + + def test_async_mode_registers_bidi_agent_callbacks(self, mock_memory_client): + """async_mode=True: BidiAgent events get callbacks; init stays sync, others are async.""" + config = AgentCoreMemoryConfig(memory_id="m", session_id="s", actor_id="a", async_mode=True) + manager = _create_session_manager(config, mock_memory_client) + registry = HookRegistry() + manager.register_hooks(registry) + + # BidiAgentInitializedEvent dispatches via the sync hook path, so its callback must NOT be a coroutine. + init_callbacks = registry._registered_callbacks.get(BidiAgentInitializedEvent, []) + assert init_callbacks, "No callbacks registered for BidiAgentInitializedEvent" + assert not any(asyncio.iscoroutinefunction(cb) for cb in init_callbacks) + + # BidiMessageAddedEvent and BidiAfterInvocationEvent dispatch via invoke_callbacks_async, + # so their callbacks should be async to keep the event loop unblocked. + for event_type in (BidiMessageAddedEvent, BidiAfterInvocationEvent): + callbacks = registry._registered_callbacks.get(event_type, []) + assert callbacks, f"No callbacks registered for {event_type.__name__}" + assert all(asyncio.iscoroutinefunction(cb) for cb in callbacks) + + +class TestFlushAgentStatesRaceCondition: + """Tests for the copy-and-clear-under-one-lock fix in _flush_agent_states_only.""" + + def test_flush_agent_states_does_not_drop_concurrent_appends(self, batching_session_manager, mock_memory_client): + """States appended during the network I/O window must survive the flush.""" + states_appended_during_flush = [] + + # Pre-populate the buffer with one state to force a flush. + initial_agent = SessionAgent( + agent_id="agent-1", + state={"description": "initial"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", initial_agent) + assert batching_session_manager.pending_agent_state_count() == 1 + + # Simulate a concurrent create_agent during the boto3 call. The mock fires + # while the buffer is being flushed — i.e. between the copy and any clear — + # so a second append must NOT be lost. + def create_event_and_append_concurrently(**kwargs): + new_agent = SessionAgent( + agent_id="agent-2", + state={"description": "appended-mid-flush"}, + conversation_manager_state={}, + ) + batching_session_manager.create_agent("test-session-456", new_agent) + states_appended_during_flush.append(new_agent) + return {"eventId": "event_during_flush"} + + mock_memory_client.gmdp_client.create_event.side_effect = create_event_and_append_concurrently + + batching_session_manager._flush_agent_states_only() + + # The state appended during the flush must still be in the buffer afterwards. + assert states_appended_during_flush, "Test setup error: concurrent append did not run" + assert batching_session_manager.pending_agent_state_count() == 1, ( + "State appended during flush was dropped — copy/clear is not atomic" + ) + + def test_flush_agent_states_failure_restores_buffer(self, batching_session_manager, mock_memory_client): + """A failed flush must restore the originally-buffered states (no data loss).""" + mock_memory_client.gmdp_client.create_event.side_effect = Exception("API Error") + + agent = SessionAgent(agent_id="agent-1", state={"description": "v1"}, conversation_manager_state={}) + batching_session_manager.create_agent("test-session-456", agent) + assert batching_session_manager.pending_agent_state_count() == 1 + + with pytest.raises(SessionException): + batching_session_manager._flush_agent_states_only() + + # State must be back in the buffer for retry. + assert batching_session_manager.pending_agent_state_count() == 1