-
Notifications
You must be signed in to change notification settings - Fork 112
feat: add async support to MemorySessionManager #478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| """AgentCore Memory-based session manager for Bedrock AgentCore Memory integration.""" | ||
|
|
||
| import asyncio | ||
| import json | ||
| import logging | ||
| import threading | ||
|
|
@@ -10,7 +11,13 @@ | |
|
|
||
| import boto3 | ||
| from botocore.config import Config as BotocoreConfig | ||
| from strands.experimental.hooks.multiagent.events import ( | ||
| AfterMultiAgentInvocationEvent, | ||
| AfterNodeCallEvent, | ||
| MultiAgentInitializedEvent, | ||
| ) | ||
| from strands.hooks import AfterInvocationEvent, MessageAddedEvent | ||
| from strands.hooks.events import AgentInitializedEvent | ||
| from strands.hooks.registry import HookRegistry | ||
| from strands.session.repository_session_manager import RepositorySessionManager | ||
| from strands.session.session_repository import SessionRepository | ||
|
|
@@ -906,16 +913,79 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): | |
| def register_hooks(self, registry: HookRegistry, **kwargs) -> None: | ||
| """Register additional hooks. | ||
|
|
||
| In sync mode (the default), delegates to the base class and adds the | ||
| retrieve_customer_context + batching callbacks synchronously, preserving | ||
| existing behavior exactly. | ||
|
|
||
| In async mode, registers async callbacks that wrap every per-turn | ||
| boto3-backed operation (append_message, sync_agent, buffer flushes, | ||
| customer-context retrieval) with asyncio.to_thread, so the asyncio | ||
| event loop stays free while boto3 is blocking on the network. | ||
|
|
||
| Note: AgentInitializedEvent cannot be async per Strands' HookRegistry, | ||
| so agent restoration (read_session / read_agent / list_messages) still | ||
| blocks the calling thread in async mode — see AgentCoreMemoryConfig | ||
| docstring for mitigations. | ||
|
|
||
| Args: | ||
| registry (HookRegistry): The hook registry to register callbacks with. | ||
| **kwargs: Additional keyword arguments. | ||
| """ | ||
| RepositorySessionManager.register_hooks(self, registry, **kwargs) | ||
| registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) | ||
| if not self.config.async_mode: | ||
| RepositorySessionManager.register_hooks(self, registry, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: We're manually registering hooks in the async path instead of going through RepositorySessionManager.register_hooks(), so if strands adds new hooks upstream, we won't pick them up here. |
||
| registry.add_callback(MessageAddedEvent, lambda event: self.retrieve_customer_context(event)) | ||
|
|
||
| # Only register AfterInvocationEvent hook when batching is enabled | ||
| if self.config.batch_size > 1: | ||
| registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) | ||
| return | ||
|
|
||
| # Async mode: register async callbacks that offload the existing sync | ||
| # methods to a worker thread via asyncio.to_thread. AgentInitializedEvent | ||
| # must stay sync (Strands disallows async callbacks on this event; see | ||
| # strands/hooks/registry.py:174). | ||
| logger.warning( | ||
| "AgentCoreMemorySessionManager async_mode=True: the agent must be invoked " | ||
| "via the async path (e.g. agent.stream_async(...) or agent.invoke_async(...)). " | ||
| "Sync invocation will raise RuntimeError from Strands' hook registry." | ||
| ) | ||
|
|
||
| registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note (for my own understanding): because this path doesn't call In other words, we pick this synchronous hook from that implementation and leave out the rest to overwrite them with our own async hooks. |
||
|
|
||
| async def _on_message_added_persist(event: MessageAddedEvent) -> None: | ||
| await asyncio.to_thread(self.append_message, event.message, event.agent) | ||
| await asyncio.to_thread(self.sync_agent, event.agent) | ||
|
|
||
| async def _on_message_added_retrieve(event: MessageAddedEvent) -> None: | ||
| await asyncio.to_thread(self.retrieve_customer_context, event) | ||
|
|
||
| async def _on_after_invocation_sync(event: AfterInvocationEvent) -> None: | ||
| await asyncio.to_thread(self.sync_agent, event.agent) | ||
|
|
||
| registry.add_callback(MessageAddedEvent, _on_message_added_persist) | ||
| registry.add_callback(AfterInvocationEvent, _on_after_invocation_sync) | ||
| registry.add_callback(MessageAddedEvent, _on_message_added_retrieve) | ||
|
|
||
| # Only register AfterInvocationEvent hook when batching is enabled | ||
| if self.config.batch_size > 1: | ||
| registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) | ||
|
|
||
| async def _on_after_invocation_flush(event: AfterInvocationEvent) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: could we add a small helper to reduce boilerplate here? |
||
| await asyncio.to_thread(self._flush_messages) | ||
|
|
||
| registry.add_callback(AfterInvocationEvent, _on_after_invocation_flush) | ||
|
|
||
| # Register multi-agent callbacks so async-mode parity matches sync-mode | ||
| async def _on_multi_agent_initialized(event: MultiAgentInitializedEvent) -> None: | ||
| await asyncio.to_thread(self.initialize_multi_agent, event.source) | ||
|
|
||
| async def _on_after_node_call(event: AfterNodeCallEvent) -> None: | ||
| await asyncio.to_thread(self.sync_multi_agent, event.source) | ||
|
|
||
| async def _on_after_multi_agent_invocation(event: AfterMultiAgentInvocationEvent) -> None: | ||
| await asyncio.to_thread(self.sync_multi_agent, event.source) | ||
|
|
||
| registry.add_callback(MultiAgentInitializedEvent, _on_multi_agent_initialized) | ||
| registry.add_callback(AfterNodeCallEvent, _on_after_node_call) | ||
| registry.add_callback(AfterMultiAgentInvocationEvent, _on_after_multi_agent_invocation) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Do we support BidiAgent? The sync path picks up BidiAgent hooks through the parent, but async mode doesn't register them. Should we add them here or explicitly document that BidiAgent + async_mode is unsupported? |
||
|
|
||
| @override | ||
| def initialize(self, agent: "Agent", **kwargs: Any) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I prefer if self.config.async_mode: — easier to read when the primary condition isn't negated.