diff --git a/strands-py/src/strands/memory/__init__.py b/strands-py/src/strands/memory/__init__.py new file mode 100644 index 0000000000..90ba73c871 --- /dev/null +++ b/strands-py/src/strands/memory/__init__.py @@ -0,0 +1,59 @@ +"""Memory module for Strands Agents. + +This package gives agents cross-session recall and persistence through a +``MemoryManager`` plugin that manages pluggable memory stores, exposes search/add +tools, and runs automatic background extraction. +""" + +from ..types.exceptions import AggregateMemoryError +from .extraction.model_extractor import ModelExtractor +from .extraction.triggers import IntervalTrigger, InvocationTrigger +from .extraction.types import ( + ExtractionConfig, + ExtractionResult, + ExtractionTrigger, + ExtractionTriggerContext, + Extractor, + ExtractorContext, + MemoryContentBlockType, + MemoryMessageFilter, +) +from .memory_manager import MemoryManager +from .types import ( + AddMessagesContext, + MemoryAddOptions, + MemoryAddToolConfig, + MemoryEntry, + MemoryManagerConfig, + MemorySearchOptions, + MemoryStore, + MemoryStoreConfig, + MemoryToolConfig, + SearchOptions, +) + +__all__ = [ + "AddMessagesContext", + "AggregateMemoryError", + "ExtractionConfig", + "ExtractionResult", + "ExtractionTrigger", + "ExtractionTriggerContext", + "Extractor", + "ExtractorContext", + "IntervalTrigger", + "InvocationTrigger", + "MemoryAddOptions", + "MemoryAddToolConfig", + "MemoryContentBlockType", + "MemoryEntry", + "MemoryManager", + "MemoryManagerConfig", + "MemoryMessageFilter", + "MemorySearchOptions", + "MemoryStore", + "MemoryStoreConfig", + "MemoryToolConfig", + "ModelExtractor", + "SearchOptions", +] diff --git a/strands-py/src/strands/memory/extraction/__init__.py b/strands-py/src/strands/memory/extraction/__init__.py new file mode 100644 index 0000000000..bfae911262 --- /dev/null +++ b/strands-py/src/strands/memory/extraction/__init__.py @@ -0,0 +1,5 @@ +"""Extraction primitives for the memory module: coordinator, triggers, 
and extractor. + +The public surface is exported from ``strands.memory``; this subpackage groups the +implementation modules. +""" diff --git a/strands-py/src/strands/memory/extraction/coordinator.py b/strands-py/src/strands/memory/extraction/coordinator.py new file mode 100644 index 0000000000..ed534d2b77 --- /dev/null +++ b/strands-py/src/strands/memory/extraction/coordinator.py @@ -0,0 +1,264 @@ +"""Background coordinator that saves conversation messages to memory stores. + +The :class:`ExtractionCoordinator` buffers every message the agent produces and, +when a store's trigger fires, saves that store's unsaved messages in the +background. It keeps a per-store high-water mark so each message is delivered to +a store at most once, serializes a single store's saves through a per-store task +chain, and backs off stores that fail repeatedly. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass + +from ...models.model import Model +from ...types.content import ContentBlock, Message +from ...types.exceptions import AggregateMemoryError +from ..types import MemoryStore +from .types import DEFAULT_MEMORY_MESSAGE_FILTER, Extractor, ExtractorContext, MemoryMessageFilter + +logger = logging.getLogger(__name__) + +# Number of consecutive save failures after which a store backs off. +SAVE_FAILURES_BEFORE_BACKOFF = 10 + +# While backed off, a store retries only once every this many save attempts. +BACKOFF_PROBE_INTERVAL = 3 + + +@dataclass +class _Buffered: + """A buffered message and its monotonically increasing sequence number.""" + + seq: int + message: Message + + +class ExtractionCoordinator: + """Saves conversation messages to memory stores in the background. + + Buffers every recorded message and, per store, tracks a high-water mark of + the last ``seq`` saved so each message is delivered at most once. A single + store's saves are serialized through a per-store task chain; different stores + save independently. 
Failures are logged and swallowed, with per-store backoff + for repeatedly failing stores. + """ + + def __init__(self, stores: list[MemoryStore], default_model: Model) -> None: + """Initialize the coordinator. + + Args: + stores: The extraction-configured stores this coordinator manages. + default_model: The agent's model, passed to extractors that do not + configure their own. + """ + self._stores = list(stores) + self._default_model = default_model + # Messages waiting to be saved, oldest first. + self._pending: list[_Buffered] = [] + # The ``seq`` to assign the next buffered message. + self._next_seq = 0 + # Per store: ``seq`` of the last message it has saved (-1 means none). + self._marks: dict[int, int] = {id(store): -1 for store in stores} + # Per store: the currently-running save task, so the next save waits its turn. + self._chains: dict[int, asyncio.Task] = {} + # Per store: consecutive save failures, reset to 0 on success. + self._consecutive_failures: dict[int, int] = {} + # Per store: save-request count while backed off, to let every Nth through as a probe. + self._backoff_counters: dict[int, int] = {} + # Fire-and-forget background tasks, retained so they aren't GC'd mid-flight. + self._background: set[asyncio.Task] = set() + + def record(self, message: Message) -> None: + """Add a message to the buffer.""" + self._pending.append(_Buffered(self._next_seq, message)) + self._next_seq += 1 + + def schedule(self, store: MemoryStore) -> None: + """Save this store's unsaved messages in the background, non-blocking. + + Dispatches the save and returns immediately. A no-op when the store is + backed off and this request is not a probe. 
+ """ + task = self.process(store) + if task is None: + return + self._background.add(task) + + def _done(completed: asyncio.Task) -> None: + self._background.discard(completed) + if completed.cancelled(): + return + error = completed.exception() + if error is not None: + logger.warning("store=<%s>, reason=<%s> | background memory save failed", store.name, error) + + task.add_done_callback(_done) + + def process(self, store: MemoryStore) -> asyncio.Task | None: + """Queue a save for this store behind its previous save. + + Returns the task running the save, or ``None`` when the store is backed + off and this request is not a probe. + """ + if not self._should_attempt(store): + return None + return self._enqueue(store) + + def _enqueue(self, store: MemoryStore) -> asyncio.Task: + """Queue this store's save behind its previous one and return the task.""" + previous = self._chains.get(id(store)) + task = asyncio.create_task(self._run_chain(store, previous)) + self._chains[id(store)] = task + return task + + async def _run_chain(self, store: MemoryStore, previous: asyncio.Task | None) -> None: + """Run this store's save after its previous one completes.""" + if previous is not None: + await previous + await self._extract(store) + + def _should_attempt(self, store: MemoryStore) -> bool: + """Return whether to attempt a save now. + + A healthy store always attempts. A backed-off store attempts only once + every :data:`BACKOFF_PROBE_INTERVAL` requests (a probe) and skips the + rest. + """ + if self._consecutive_failures.get(id(store), 0) < SAVE_FAILURES_BEFORE_BACKOFF: + return True + count = self._backoff_counters.get(id(store), 0) + 1 + self._backoff_counters[id(store)] = count + return count % BACKOFF_PROBE_INTERVAL == 0 + + async def flush(self) -> None: + """Save every store's remaining buffered messages and wait for completion. + + Bypasses backoff and also waits out saves that start while waiting. + Never raises. 
+ """ + for store in self._stores: + self._enqueue(store) + while True: + snapshot = list(self._chains.values()) + await asyncio.gather(*snapshot, return_exceptions=True) + current = list(self._chains.values()) + # Done once nothing new started while we waited. + if len(current) == len(snapshot) and all( + current_task is snapshot_task for current_task, snapshot_task in zip(current, snapshot, strict=True) + ): + return + + async def _extract(self, store: MemoryStore) -> None: + """Save the store's messages newer than its high-water mark. + + On failure the mark is rolled back so the batch retries next time. + """ + mark = self._marks.get(id(store), -1) + fresh = [buffered for buffered in self._pending if buffered.seq > mark] + if not fresh: + return + + extraction = store.extraction + if extraction is None: + return + + # Mark saved before saving so a queued save won't pick these up again; + # rolled back below on failure. + self._marks[id(store)] = fresh[-1].seq + + message_filter = extraction.filter or DEFAULT_MEMORY_MESSAGE_FILTER + filtered = self._filter_messages([buffered.message for buffered in fresh], message_filter) + + try: + if filtered: + await self._write(store, filtered, extraction.extractor) + # Successful write clears the failure streak and ends backoff. A + # fully filtered (empty) turn never touched the backend, so it + # leaves backoff state untouched. + self._consecutive_failures[id(store)] = 0 + self._backoff_counters.pop(id(store), None) + except Exception as error: # noqa: BLE001 - saving must never break the agent loop. + self._on_save_failed(store, mark, error) + finally: + self._trim() + + async def _write(self, store: MemoryStore, messages: list[Message], extractor: Extractor | None) -> None: + """Save the messages to the store, one of two ways. + + - With an extractor: run it, then write each fact via ``add`` + concurrently. If any write fails the whole batch is re-raised and + retried later, so stores should expect duplicate writes. 
+ - Without an extractor: hand the raw messages to ``add_messages``. + + Raises: + AggregateMemoryError: If any concurrent ``add`` write fails. + """ + if extractor is not None: + entries = await extractor.extract(messages, ExtractorContext(default_model=self._default_model)) + results = await asyncio.gather( + *(store.add(entry.content, entry.metadata) for entry in entries), + return_exceptions=True, + ) + failures = [result for result in results if isinstance(result, BaseException)] + if failures: + raise AggregateMemoryError( + f"failed to write {len(failures)} of {len(entries)} extracted entries", + failures, + ) + return + + await store.add_messages(messages) + + def _filter_messages(self, messages: list[Message], message_filter: MemoryMessageFilter) -> list[Message]: + """Remove excluded content blocks, dropping any message left empty. + + Builds new message dicts rather than mutating the inputs. + """ + exclude = set(message_filter.exclude) + result: list[Message] = [] + for message in messages: + content = [block for block in message["content"] if self._block_kind(block) not in exclude] + if content: + new_message: Message = {"role": message["role"], "content": content} + if message.get("metadata") is not None: + new_message["metadata"] = message["metadata"] + result.append(new_message) + return result + + def _block_kind(self, block: ContentBlock) -> str: + """Return the content block's kind (its single key), or ``""`` if empty.""" + return next(iter(block.keys()), "") + + def _on_save_failed(self, store: MemoryStore, mark_before_save: int, error: BaseException) -> None: + """Handle a failed save. + + Rolls the mark back so the messages retry next time. After + :data:`SAVE_FAILURES_BEFORE_BACKOFF` consecutive failures the store + enters backoff and logs an error; before that it logs a warning. 
+ """ + failures = self._consecutive_failures.get(id(store), 0) + 1 + self._consecutive_failures[id(store)] = failures + self._marks[id(store)] = mark_before_save + reason = str(error) + + if failures >= SAVE_FAILURES_BEFORE_BACKOFF: + logger.error( + "store=<%s>, failures=<%s>, reason=<%s> | memory store save failing repeatedly", + store.name, + failures, + reason, + ) + else: + logger.warning("store=<%s>, reason=<%s> | memory extraction failed", store.name, reason) + + def _trim(self) -> None: + """Drop buffered messages every store has already saved. + + A store stuck failing keeps its messages buffered, so the buffer grows + until it recovers; this is bounded by the (non-persisted) session. + """ + min_mark = min(self._marks.values()) + self._pending = [buffered for buffered in self._pending if buffered.seq > min_mark] diff --git a/strands-py/src/strands/memory/extraction/model_extractor.py b/strands-py/src/strands/memory/extraction/model_extractor.py new file mode 100644 index 0000000000..fdfb1e76d1 --- /dev/null +++ b/strands-py/src/strands/memory/extraction/model_extractor.py @@ -0,0 +1,139 @@ +"""Model-backed :class:`Extractor` that distills messages into discrete facts. + +A :class:`ModelExtractor` calls a language model with a fact-extraction system +prompt and parses the response into :class:`ExtractionResult` entries. Backends +that extract server-side should omit the extractor entirely. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from ...models.model import Model +from ...types.content import Message +from .types import ExtractionResult, ExtractorContext + +logger = logging.getLogger(__name__) + +# Default instruction guiding the model to emit discrete, durable facts as a JSON array. +DEFAULT_SYSTEM_PROMPT = ( + "You extract durable facts worth remembering across future conversations from a transcript.\n" + "\n" + 'Return ONLY a JSON array of objects, each: {"content": string}. 
Each object is one discrete, ' + "self-contained fact (a preference, decision, or stable detail about the user or task). Do not " + "include transient chit-chat, questions, or anything already obvious. If there is nothing worth " + "remembering, return []." +) + + +class ModelExtractor: + """An :class:`Extractor` that calls a language model to distill messages into discrete facts. + + Use for self-managed stores that hold plain text and want automatic + distillation. + + Example: + ```python + ExtractionConfig( + trigger=[InvocationTrigger()], + extractor=ModelExtractor(model=cheap_model, system_prompt="Extract user preferences."), + ) + ``` + """ + + def __init__(self, model: Model | None = None, system_prompt: str | None = None) -> None: + """Initialize the extractor. + + Args: + model: Model used to extract facts. Defaults to the agent's own model + (via :attr:`ExtractorContext.default_model`); set a cheaper one to + cut cost. + system_prompt: System prompt steering what counts as a fact. Defaults + to a general fact-extraction prompt. + """ + self._model = model + self._system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT + + async def extract(self, messages: list[Message], context: ExtractorContext | None = None) -> list[ExtractionResult]: + """Extract entries from a batch of messages. + + Raises: + ValueError: If no model is configured and no default is available. + RuntimeError: If the model returns no response. + """ + model = self._model or (context.default_model if context else None) + if model is None: + raise ValueError("ModelExtractor: no model configured and no default model available") + if not messages: + return [] + + # Present the transcript as a single user turn so the system prompt governs extraction. 
+ transcript = "\n".join(_render_message(message) for message in messages) + prompt: Message = { + "role": "user", + "content": [{"text": f"Extract facts from the following transcript:\n\n{transcript}"}], + } + + # Lazy import to avoid a circular import with ``event_loop.streaming``. + from ...event_loop.streaming import stream_messages + + final_message: Message | None = None + async for event in stream_messages(model, self._system_prompt, [prompt], tool_specs=[]): + # The terminal ``ModelStopReason`` event carries ``{"stop": (stop_reason, message, ...)}``. + stop = event.get("stop") + if stop is not None: + final_message = stop[1] + + if final_message is None: + raise RuntimeError("ModelExtractor: model returned no response") + + text = "".join(block.get("text", "") for block in final_message["content"]).strip() + + return _parse_entries(text, type(model).__name__) + + +def _render_message(message: Message) -> str: + """Render one message as ``role: text``, joining its non-empty text blocks.""" + text = "\n".join(part for block in message["content"] if (part := block.get("text", "")) and len(part) > 0) + return f"{message['role']}: {text}" + + +def _extract_json_array(text: str) -> str | None: + """Extract the substring from the first ``[`` to the last ``]``, or None if absent.""" + start = text.find("[") + end = text.rfind("]") + if start == -1 or end == -1 or end < start: + return None + return text[start : end + 1] + + +def _parse_entries(text: str, model_name: str) -> list[ExtractionResult]: + """Parse the model's response into entries. + + Tolerates the array being wrapped in prose or a code fence. Malformed output + yields no entries (logged) rather than throwing. 
+ """ + json_text = _extract_json_array(text) + if json_text is None: + logger.warning("model=<%s> | ModelExtractor: no JSON array in model output, skipping", model_name) + return [] + + try: + parsed: Any = json.loads(json_text) + except ValueError as err: + logger.warning("model=<%s>, error=<%s> | ModelExtractor: failed to parse output", model_name, str(err)) + return [] + + if not isinstance(parsed, list): + return [] + + entries: list[ExtractionResult] = [] + for item in parsed: + if isinstance(item, dict) and isinstance(item.get("content"), str): + content = item["content"].strip() + if len(content) > 0: + metadata = item.get("metadata") + entries.append(ExtractionResult(content=content, metadata=metadata if metadata is not None else None)) + return entries diff --git a/strands-py/src/strands/memory/extraction/triggers.py b/strands-py/src/strands/memory/extraction/triggers.py new file mode 100644 index 0000000000..5e480576a9 --- /dev/null +++ b/strands-py/src/strands/memory/extraction/triggers.py @@ -0,0 +1,94 @@ +"""Built-in extraction triggers that control *when* a store's extraction runs. + +* :class:`InvocationTrigger` -- fire after every agent invocation. +* :class:`IntervalTrigger` -- fire once every ``turns`` invocations. + +See :class:`ExtractionTrigger` for the self-attaching trigger contract. +""" + +from __future__ import annotations + +from ...hooks.events import AfterInvocationEvent +from ...hooks.registry import HookOrder +from .types import ExtractionTrigger, ExtractionTriggerContext + + +class InvocationTrigger(ExtractionTrigger): + """Runs extraction after every agent invocation. + + The highest-fidelity option, and the most expensive when an + :class:`~strands.memory.extraction.types.Extractor` is configured (a model + call per turn). 
+ + Example: + ```python + ExtractionConfig(trigger=[InvocationTrigger()]) + ``` + """ + + name = "invocation" + + def attach(self, context: ExtractionTriggerContext) -> None: + """Register an after-invocation callback that fires extraction. + + Runs after the SDK's own after-invocation hooks so extraction sees the + settled turn. The save runs in a background task, so the hook never + blocks. + """ + context.agent.add_hook( + lambda event: context.fire(), + AfterInvocationEvent, + order=HookOrder.SDK_LAST, + ) + + +class IntervalTrigger(ExtractionTrigger): + """Runs extraction every ``turns`` agent invocations. + + A controllable middle ground: the high-water mark still picks up the skipped + turns when the trigger fires. + + Example: + ```python + ExtractionConfig(trigger=[IntervalTrigger(turns=5)]) + ``` + + Attributes: + name: Stable identifier for this trigger kind (``interval``). + """ + + name = "interval" + + def __init__(self, turns: int) -> None: + """Initialize the trigger with a firing cadence. + + Args: + turns: Run extraction once every this many invocations. Must be a + positive integer. + + Raises: + ValueError: If ``turns`` is not a positive integer (``bool`` is + rejected even though it subclasses ``int``). + """ + # Reject bool explicitly (bool is a subclass of int) and any value < 1. + if not isinstance(turns, int) or isinstance(turns, bool) or turns < 1: + raise ValueError(f"IntervalTrigger: turns must be a positive integer, got {turns}") + self._turns = turns + + def attach(self, context: ExtractionTriggerContext) -> None: + """Register an after-invocation callback that fires every ``turns`` turns. + + Each ``attach`` creates a fresh closure counter, so one trigger instance + shared across stores keeps an independent count per attachment. + """ + # Per-attach counter so stores sharing one instance fire independently. 
+ count = 0 + + def _callback(event: AfterInvocationEvent) -> None: + nonlocal count + count += 1 + # `fire` is fire-and-forget; it dispatches extraction in the background. + if count % self._turns == 0: + context.fire() + + context.agent.add_hook(_callback, AfterInvocationEvent, order=HookOrder.SDK_LAST) diff --git a/strands-py/src/strands/memory/extraction/types.py b/strands-py/src/strands/memory/extraction/types.py new file mode 100644 index 0000000000..dc763c0fea --- /dev/null +++ b/strands-py/src/strands/memory/extraction/types.py @@ -0,0 +1,144 @@ +"""Primitive types for the memory extraction subsystem.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Protocol + +from ...models.model import Model +from ...types.content import Message + +if TYPE_CHECKING: + # Lazy import to avoid a circular import; only used in annotations. + from ...agent.agent import Agent + +# Metadata mapping for an extracted entry (scores, ids, timestamps, etc.). +Metadata = dict[str, Any] + +# Content-block kinds a ``MemoryMessageFilter`` can exclude. Mirrors the keys of +# ``strands.types.content.ContentBlock`` (e.g. ``{"text": ...}`` -> ``"text"``). +MemoryContentBlockType = Literal[ + "text", + "toolUse", + "toolResult", + "image", + "document", + "reasoningContent", + "video", + "guardContent", + "citationsContent", + "cachePoint", +] + + +@dataclass +class ExtractionResult: + """A discrete entry produced by an :class:`Extractor`, ready to write via ``add``.""" + + content: str + metadata: Metadata | None = None + + +@dataclass +class ExtractorContext: + """Context passed to :meth:`Extractor.extract`. + + Attributes: + default_model: The agent's model, supplied so an extractor can default to + it. 
+ """ + + default_model: Model | None = None + + +class Extractor(Protocol): + """Transforms conversation messages into discrete, searchable entries. + + Optional on a store's :class:`ExtractionConfig`: when absent, the manager + passes messages straight to the store's ``add_messages`` (the no-extractor + passthrough), which is the right path for backends that extract server-side. + """ + + async def extract(self, messages: list[Message], context: ExtractorContext | None = None) -> list[ExtractionResult]: + """Extract entries from a batch of messages.""" + ... + + +@dataclass +class MemoryMessageFilter: + """Filters content blocks out of messages before extraction. + + Blocks whose kind is in :attr:`exclude` are stripped; a message left with no + content is dropped. Defaults to excluding tool traffic (``toolUse`` / + ``toolResult``). + """ + + exclude: list[MemoryContentBlockType] + + +# Default filter: drop tool-call traffic, keep everything else. +DEFAULT_MEMORY_MESSAGE_FILTER = MemoryMessageFilter(exclude=["toolUse", "toolResult"]) + + +@dataclass +class ExtractionTriggerContext: + """Context handed to :meth:`ExtractionTrigger.attach`. + + Attributes: + agent: The agent the trigger attaches its hooks to. + fire: Save this store's unsaved messages now. Runs in the background and + returns immediately. To await completion, see ``MemoryManager.flush``. + """ + + agent: Agent + fire: Callable[[], None] + + +class ExtractionTrigger(ABC): + """Controls when a store's :class:`ExtractionConfig` runs. + + A trigger is a self-attaching value object: :meth:`attach` wires the agent + hooks it needs and calls :attr:`ExtractionTriggerContext.fire` when extraction + should happen. Subclass for custom triggering logic. A trigger that never + fires never extracts; for a guaranteed final write, use + ``MemoryManager.flush``. + + Attributes: + name: Stable identifier for this trigger kind, used in logging. 
+ """ + + name: str + + @abstractmethod + def attach(self, context: ExtractionTriggerContext) -> None: + """Wire this trigger into the agent lifecycle. + + Called once per store during ``MemoryManager`` initialization. Register + hooks on ``context.agent`` and call ``context.fire()`` when extraction + should run. + """ + ... + + +@dataclass +class ExtractionConfig: + """Per-store automatic-extraction configuration. + + Attributes: + trigger: When to run extraction. A single trigger or a non-empty list; + multiple triggers compose (extraction runs whenever any fires). + extractor: How to turn messages into entries. When set, the store must + implement ``add``. When omitted, the manager hands the filtered + messages straight to the store's ``add_messages`` (for backends that + extract server-side). + filter: Content blocks to strip before extraction. Defaults to + :data:`DEFAULT_MEMORY_MESSAGE_FILTER` (excludes ``toolUse`` / + ``toolResult``). Pass ``MemoryMessageFilter(exclude=[])`` to keep tool + blocks. 
+ """ + + trigger: ExtractionTrigger | list[ExtractionTrigger] + extractor: Extractor | None = None + filter: MemoryMessageFilter | None = None diff --git a/strands-py/src/strands/memory/memory_manager.py b/strands-py/src/strands/memory/memory_manager.py new file mode 100644 index 0000000000..1e4804772b --- /dev/null +++ b/strands-py/src/strands/memory/memory_manager.py @@ -0,0 +1,571 @@ +"""Cross-session memory retrieval and storage for agents.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from ..hooks.events import AfterInvocationEvent, MessageAddedEvent +from ..hooks.registry import HookOrder +from ..plugins.plugin import Plugin +from ..tools.decorator import tool +from ..types.exceptions import AggregateMemoryError +from ..types.tools import AgentTool +from .extraction.coordinator import ExtractionCoordinator +from .extraction.types import ExtractionTrigger, ExtractionTriggerContext +from .types import ( + MemoryAddOptions, + MemoryAddToolConfig, + MemoryEntry, + MemorySearchOptions, + MemoryStore, + MemoryToolConfig, + _has_method, + _has_write_sink, +) + +if TYPE_CHECKING: + from ..agent.agent import Agent + +logger = logging.getLogger(__name__) + +SEARCH_TOOL_DESCRIPTION = ( + "Search long-term memory for facts, preferences, or context from previous conversations. Use when you need " + "background about the user or topic that may have been discussed before." +) + +ADD_TOOL_DESCRIPTION = ( + "Add facts, preferences, or decisions to long-term memory so they are remembered across conversations. Use when " + "the user shares something worth recalling later." +) + +# Default maximum results per store when neither caller nor store specifies one. 
+DEFAULT_MAX_SEARCH_RESULTS = 3 + + +def _normalize_triggers(trigger: ExtractionTrigger | list[ExtractionTrigger]) -> list[ExtractionTrigger]: + """Normalize a store's ``trigger`` field (a single trigger or a list) to a list.""" + return list(trigger) if isinstance(trigger, list) else [trigger] + + +def _flatten_reasons(reasons: list[BaseException]) -> list[BaseException]: + """Flatten nested aggregate errors so the leaves are concrete reasons.""" + flattened: list[BaseException] = [] + for reason in reasons: + if isinstance(reason, AggregateMemoryError): + flattened.extend(_flatten_reasons(reason.errors)) + else: + flattened.append(reason) + return flattened + + +class MemoryManager(Plugin): + """Provides cross-session memory retrieval and storage for agents. + + When using the synchronous ``Agent(...)`` entry point, set + ``flush_on_invocation_end=True`` so extraction writes persist across its + per-invocation event loop. + + Example: + ```python + from strands import Agent + from strands.memory import MemoryManager + + memory_manager = MemoryManager(stores=[my_store], flush_on_invocation_end=True) + agent = Agent(model=model, plugins=[memory_manager]) + agent("Remember I prefer dark mode") + + results = await memory_manager.search("user preferences") + ``` + """ + + name = "strands:memory-manager" + + def __init__( + self, + stores: list[MemoryStore], + search_tool_config: MemoryToolConfig | bool = True, + add_tool_config: MemoryAddToolConfig | bool = False, + flush_on_invocation_end: bool = False, + ) -> None: + """Initialize the memory manager. + + Args: + stores: One or more memory stores to manage. + search_tool_config: Search tool configuration. ``True`` (default) + registers a ``search_memory`` tool with default name/description; + a :class:`MemoryToolConfig` customizes it; ``False`` disables it. + add_tool_config: Add tool configuration. 
``False`` (default) disables + the add tool; ``True`` lets it write to all writable stores; a + :class:`MemoryAddToolConfig` restricts/customizes it. + flush_on_invocation_end: When True, await pending extraction writes at + the end of each agent invocation. Enable when driving the agent + through the synchronous ``Agent(...)`` entry point, whose + per-invocation event loop would otherwise cancel in-flight saves. + Defaults to False (fire-and-forget). + + Raises: + ValueError: If ``stores`` is empty, a store name is duplicated, a + writable store has no write sink, an extraction config is + misconfigured, or the add tool is enabled/scoped against stores + that cannot accept discrete ``add`` writes. + """ + if len(stores) == 0: + raise ValueError("MemoryManager: at least one store is required") + + seen_names: set[str] = set() + for store in stores: + if store.name in seen_names: + raise ValueError(f"MemoryManager: duplicate store name '{store.name}'") + seen_names.add(store.name) + + if store.writable and not _has_write_sink(store): + raise ValueError( + f"MemoryManager: store '{store.name}' is writable but has no add or add_messages method" + ) + + if store.extraction is not None: + if not store.writable: + raise ValueError(f"MemoryManager: store '{store.name}' has extraction config but is not writable") + if len(_normalize_triggers(store.extraction.trigger)) == 0: + raise ValueError(f"MemoryManager: store '{store.name}' has extraction config but no triggers") + # Each extraction shape needs its matching write sink. 
+ if store.extraction.extractor is not None: + if not _has_method(store, "add"): + raise ValueError( + f"MemoryManager: store '{store.name}' has an extractor but no add method " + "(extracted entries are written via add)" + ) + elif not _has_method(store, "add_messages"): + raise ValueError( + f"MemoryManager: store '{store.name}' has extraction config without an extractor " + "but no add_messages method" + ) + + super().__init__() + + self._stores = list(stores) + self._search_stores = list(stores) + # `add`-targeting paths (tool / programmatic) need an `add` method specifically. + self._add_stores = [store for store in stores if store.writable and _has_method(store, "add")] + self._extraction_stores = [store for store in stores if store.writable and store.extraction is not None] + + self._search_tool_config: MemoryToolConfig | bool + if search_tool_config is False: + self._search_tool_config = False + elif isinstance(search_tool_config, MemoryToolConfig): + self._search_tool_config = search_tool_config + else: + self._search_tool_config = MemoryToolConfig() + + self._add_tool_config: MemoryAddToolConfig | bool + self._add_tool_stores: list[MemoryStore] + if add_tool_config is None or add_tool_config is False: + self._add_tool_config = False + self._add_tool_stores = [] + else: + # The `add_memory` tool writes via `add`, so needs an `add`-capable store. + if len(self._add_stores) == 0: + raise ValueError("MemoryManager: add_tool_config is enabled but no writable stores implement add") + resolved_config = ( + add_tool_config if isinstance(add_tool_config, MemoryAddToolConfig) else MemoryAddToolConfig() + ) + self._add_tool_config = resolved_config + self._add_tool_stores = self._resolve_add_tool_stores(resolved_config) + + # Fire-and-forget background tasks, retained so they aren't GC'd mid-flight. + self._background_tasks: set[asyncio.Task] = set() + + # Extraction coordinator, created in ``init_agent`` when configured. 
+ self._coordinator: ExtractionCoordinator | None = None + + self._flush_on_invocation_end = flush_on_invocation_end + + # Build tools now; surfaced via the ``tools`` property. + self._memory_tools: list[AgentTool] = self._build_tools() + + def _resolve_add_tool_stores(self, tool_config: MemoryAddToolConfig) -> list[MemoryStore]: + """Resolve the writable stores the ``add_memory`` tool may write to. + + Each entry (a store name or instance) must resolve by name to a + configured, ``add``-capable writable store. Omitted means all such stores. + + Raises: + ValueError: If a referenced store is not configured, not writable, or + has no ``add`` method. + """ + if tool_config.stores is None: + return self._add_stores + + names = [store if isinstance(store, str) else store.name for store in tool_config.stores] + + resolved: list[MemoryStore] = [] + seen: set[str] = set() + for name in names: + if name in seen: + continue + seen.add(name) + found = next((store for store in self._stores if store.name == name), None) + if found is None: + raise ValueError(f"MemoryManager: add_tool_config store '{name}' not found") + if not found.writable: + raise ValueError(f"MemoryManager: add_tool_config store '{name}' is not writable") + if not _has_method(found, "add"): + raise ValueError(f"MemoryManager: add_tool_config store '{name}' has no add method (only add_messages)") + resolved.append(found) + return resolved + + def _build_tools(self) -> list[AgentTool]: + """Build the tools this plugin registers. + + Includes the manager's ``search_memory`` / ``add_memory`` tools plus any + tools the stores expose via + :meth:`~strands.memory.types.MemoryStore.get_tools`, in store order. 
+ """ + tools: list[AgentTool] = [] + + if isinstance(self._search_tool_config, MemoryToolConfig): + tools.append(self._create_search_tool(self._search_tool_config)) + + if isinstance(self._add_tool_config, MemoryAddToolConfig): + tools.append(self._create_add_tool(self._add_tool_config, self._add_tool_stores)) + + for store in self._stores: + if _has_method(store, "get_tools"): + tools.extend(store.get_tools()) + + return tools + + @property + def tools(self) -> list[AgentTool]: # type: ignore[override] + """Tools registered by this plugin: search/add plus any store-provided tools. + + Widens the base :class:`~strands.plugins.plugin.Plugin` annotation because + a store's ``get_tools`` may contribute any + :class:`~strands.types.tools.AgentTool`. + """ + return list(self._memory_tools) + + async def search(self, query: str, options: MemorySearchOptions | None = None) -> list[MemoryEntry]: + """Search stores for entries matching the query. + + Unscoped: searches all configured stores when ``options.stores`` is + omitted. Results are attributed to their store via ``store_name`` and + concatenated in target order. + + Raises: + ValueError: If a named store is not found (raised before querying). 
async def add(self, content: str, options: MemoryAddOptions | None = None) -> None:
    """Add content to writable stores.

    Unscoped (``options.stores`` omitted): targets every configured writable
    store that implements ``add``. Scoped: every named store is validated
    before anything is written; repeated names are targeted once. Writes are
    then awaited concurrently; per-store failures are logged and surfaced
    together as an :class:`~strands.types.exceptions.AggregateMemoryError`.

    Args:
        content: The text to store.
        options: Optional target-store filter and metadata. ``metadata`` is
            passed unchanged to each store's ``add``.

    Raises:
        ValueError: If a named store is not found, is read-only, or has no
            ``add`` method, or if no writable store matched.
        AggregateMemoryError: If any targeted store write fails.
    """
    requested_stores = options.stores if options is not None else None
    metadata = options.metadata if options is not None else None

    if requested_stores is not None:
        writable_stores: list[MemoryStore] = []
        seen: set[str] = set()
        for name in requested_stores:
            if name in seen:
                continue
            seen.add(name)
            found = next((store for store in self._stores if store.name == name), None)
            if found is None:
                raise ValueError(f"MemoryManager: store '{name}' not found")
            if not found.writable:
                raise ValueError(f"MemoryManager: store '{name}' is read-only")
            # A writable store may expose only ``add_messages``. Calling ``add`` on it
            # would either raise AttributeError or hit the inherited Protocol stub and
            # silently drop the write, so reject it up front like the unscoped path does.
            if not _has_method(found, "add"):
                raise ValueError(f"MemoryManager: store '{name}' has no add method (only add_messages)")
            writable_stores.append(found)
    else:
        writable_stores = self._add_stores

    if not writable_stores:
        raise ValueError("MemoryManager: no writable store matched")

    # ``return_exceptions=True`` lets every store finish so one failure cannot mask another.
    settled = await asyncio.gather(
        *(store.add(content, metadata) for store in writable_stores),
        return_exceptions=True,
    )

    failed_names: list[str] = []
    reasons: list[BaseException] = []
    for store, outcome in zip(writable_stores, settled, strict=True):
        if isinstance(outcome, BaseException):
            logger.warning("store=<%s>, reason=<%s> | store write failed", store.name, outcome)
            failed_names.append(store.name)
            reasons.append(outcome)

    if failed_names:
        raise AggregateMemoryError(
            f"MemoryManager: store writes failed: {', '.join(failed_names)}",
            reasons,
        )
+ """ + if requested is None or len(requested) == 0: + return scoped_names + + scoped_set = set(scoped_names) + in_scope = [name for name in requested if name in scoped_set] + out_of_scope = [name for name in requested if name not in scoped_set] + + if len(in_scope) == 0: + raise ValueError( + f"MemoryManager: requested=<{', '.join(requested)}> | none of the requested memory stores " + f"are available; available stores: {', '.join(scoped_names)}" + ) + + if out_of_scope: + logger.warning( + "requested=<%s> | ignoring memory stores outside this tool's scope", + ", ".join(out_of_scope), + ) + + return in_scope + + def _create_search_tool(self, config: MemoryToolConfig) -> AgentTool: + """Build the ``search_memory`` tool.""" + description = config.description if config.description is not None else SEARCH_TOOL_DESCRIPTION + store_descriptions = [ + f"- {store.name}: {store.description}" for store in self._search_stores if store.description + ] + if store_descriptions: + description += "\n\nAvailable memory stores:\n" + "\n".join(store_descriptions) + description += ( + "\n\nYou can target one or more memory stores by name if you know which domains are relevant, " + "or omit the stores parameter to search all." + ) + + scoped_names = [store.name for store in self._search_stores] + + async def search_memory( + query: str, + max_search_results: int | None = None, + stores: list[str] | None = None, + ) -> list[dict[str, Any]]: + """Search long-term memory. + + Args: + query: What to search for. + max_search_results: Maximum number of results per store. + stores: Filter to specific stores by name. Omit to search all + available stores. + + Returns: + Matching memory entries, each attributed to its store. 
+ """ + targets = self._resolve_tool_targets(scoped_names, stores) + results = await self.search( + query, + MemorySearchOptions(max_search_results=max_search_results, stores=targets), + ) + payload: list[dict[str, Any]] = [] + for entry in results: + item: dict[str, Any] = {"content": entry.content} + if entry.store_name: + item["store_name"] = entry.store_name + if entry.metadata: + item["metadata"] = entry.metadata + payload.append(item) + return payload + + return tool( + name=config.name if config.name is not None else "search_memory", + description=description, + )(search_memory) + + def _create_add_tool(self, config: MemoryAddToolConfig, stores: list[MemoryStore]) -> AgentTool: + """Build the ``add_memory`` tool.""" + description = config.description if config.description is not None else ADD_TOOL_DESCRIPTION + store_descriptions = [f"- {store.name}: {store.description}" for store in stores if store.description] + if store_descriptions: + description += "\n\nAvailable writable stores:\n" + "\n".join(store_descriptions) + description += ( + "\n\nYou can target a specific store by name to route facts to the right place, " + "or omit to add to all available writable stores." + ) + + scoped_names = [store.name for store in stores] + wait_for_writes = config.wait_for_writes + + async def add_memory(entries: list[str], stores: list[str] | None = None) -> dict[str, int]: + """Add data to long-term memory. + + Args: + entries: Data to add to long-term memory. + stores: Target specific stores by name. Omit to add to all + writable stores. + + Returns: + A summary of the write (``{"stored": n}`` or ``{"accepted": n}``). + """ + # @tool validation does not enforce ``minItems``, so guard here. + if not entries: + raise ValueError("MemoryManager: add_memory requires at least one entry") + + targets = self._resolve_tool_targets(scoped_names, stores) + + if not wait_for_writes: + # Fire-and-forget: dispatch without awaiting. ``add`` logs per-store failures. 
def init_agent(self, agent: Agent) -> None:
    """Initialize the plugin with the agent.

    Wires up automatic extraction for any store configured with an
    ``ExtractionConfig``: creates the coordinator, buffers every message added
    to the agent, and attaches each store's triggers. A no-op when no store
    uses extraction.

    When ``flush_on_invocation_end`` is enabled, a hook awaits pending
    extraction writes after every invocation; otherwise a warning explains
    that background extraction can be lost if the event loop closes first.

    Args:
        agent: The agent this plugin is being attached to.
    """
    if not self._extraction_stores:
        return

    coordinator = ExtractionCoordinator(self._extraction_stores, agent.model)
    self._coordinator = coordinator

    # Buffer every message so extraction has its own copy to save from.
    agent.add_hook(lambda event: coordinator.record(event.message), MessageAddedEvent)

    for store in self._extraction_stores:
        assert store.extraction is not None  # noqa: S101 - extraction stores always configure this.
        for trigger in _normalize_triggers(store.extraction.trigger):
            trigger.attach(ExtractionTriggerContext(agent=agent, fire=self._make_fire(coordinator, store)))

    if self._flush_on_invocation_end:
        agent.add_hook(self._flush_after_invocation, AfterInvocationEvent, order=HookOrder.SDK_LAST)
    else:
        # Follow the ``key=<value> | message`` log convention: the value was missing
        # from the field, so pass it as a lazy %-argument.
        logger.warning(
            "flush_on_invocation_end=<%s> | background extraction is lost if the event loop closes "
            "before it finishes (e.g. the synchronous Agent(...) entry point); safe to ignore if you "
            "await MemoryManager.flush() at a shutdown boundary or enable flush_on_invocation_end.",
            self._flush_on_invocation_end,
        )
@dataclass
class MemoryEntry:
    """A single memory entry retrieved from or stored to a memory store.

    Attributes:
        content: The entry's text.
        store_name: Name of the store this entry came from, set by
            ``MemoryManager.search``. Stores need not set this themselves.
        metadata: Optional metadata mapping attached to the entry; ``None``
            when there is none.
    """

    content: str
    store_name: str | None = None
    metadata: Metadata | None = None
@dataclass
class MemoryAddOptions:
    """Options for ``MemoryManager.add``.

    Attributes:
        metadata: Optional metadata mapping handed, together with the content,
            to each targeted store's ``add``. ``None`` when there is none.
        stores: Filter to specific writable stores by name. Omit to write to all.
            A programmatic add with an empty list matches no store (raises),
            whereas the ``add_memory`` tool treats an empty list as "write to all
            in-scope stores".
    """

    metadata: Metadata | None = None
    stores: list[str] | None = None
+ """ + + stores: list[MemoryStore] + search_tool_config: MemoryToolConfig | bool = True + add_tool_config: MemoryAddToolConfig | bool = False + flush_on_invocation_end: bool = False + + +class MemoryStoreConfig(Protocol): + """Declarative identity and behavior fields shared by every memory store. + + Attributes: + name: Unique identifier for this store, used to target it in tools. + description: Human-readable description; included in tool descriptions. + max_search_results: Default maximum results per search, used when a caller + does not pass a per-call value. + writable: Whether this store accepts writes. A writable store requires at + least one write sink (:meth:`MemoryStore.add` or + :meth:`MemoryStore.add_messages`). + extraction: Automatic-extraction configuration. Requires the store to be + writable. + """ + + name: str + description: str | None + max_search_results: int | None + writable: bool + extraction: ExtractionConfig | None + + +class MemoryStore(MemoryStoreConfig, Protocol): + """Runtime contract for a memory store backend. + + Extends :class:`MemoryStoreConfig` with runtime methods. Every store is + searchable; ``writable`` declares whether it also accepts writes. A store + author implements the config fields plus :meth:`search`, and optionally + :meth:`add`, :meth:`add_messages`, and :meth:`get_tools`. + """ + + async def search(self, query: str, options: SearchOptions | None = None) -> list[MemoryEntry]: + """Search the store for entries matching the query, ordered by relevance.""" + ... + + # --- Optional methods: detect presence via ``_has_method`` / ``_has_write_sink``. + + async def add(self, content: str, metadata: Metadata | None = None) -> Any: + """Add a single piece of content to the store. + + Extraction writes are at-least-once, so implementations used with + extraction should tolerate duplicate writes. The resolved value is + store-specific and not consumed by the manager. + """ + ... 
+ + async def add_messages(self, messages: list[Message], context: AddMessagesContext | None = None) -> Any: + """Ingest a batch of conversation messages, preserving role structure. + + The sink for extraction without a client-side extractor: the manager + hands the filtered batch straight here. The resolved value is + store-specific. + """ + ... + + def get_tools(self) -> list[AgentTool]: + """Return store-specific tools to register alongside the manager's tools.""" + ... + + +def _has_method(store: object, name: str) -> bool: + """Return whether ``store`` actually implements the named method. + + Inspects the store's type so a class that merely inherits the + :class:`MemoryStore` Protocol's stub counts as "not implemented". + """ + method = getattr(type(store), name, None) + if method is None: + return False + # A subclass can inherit the Protocol's stub; treat that as "not implemented". + if method is getattr(MemoryStore, name, None): + return False + return callable(method) + + +def _has_write_sink(store: MemoryStore) -> bool: + """Return whether ``store`` provides at least one write sink (``add`` or ``add_messages``).""" + return _has_method(store, "add") or _has_method(store, "add_messages") diff --git a/strands-py/src/strands/types/exceptions.py b/strands-py/src/strands/types/exceptions.py index beeb6c7aa6..7d621191d7 100644 --- a/strands-py/src/strands/types/exceptions.py +++ b/strands-py/src/strands/types/exceptions.py @@ -127,3 +127,22 @@ class CheckpointException(Exception): """Exception raised when checkpoint operations fail (e.g., incompatible schema version).""" pass + + +class AggregateMemoryError(Exception): + """Raised when one or more memory store operations fail. + + Attributes: + errors: The underlying exceptions that caused this aggregate failure. + """ + + def __init__(self, message: str, errors: list[BaseException]) -> None: + """Initialize the aggregate error. 
+ + Args: + message: A human-readable description of the aggregate failure, + typically naming the stores that failed. + errors: The underlying exceptions that caused this failure. + """ + super().__init__(message) + self.errors = errors diff --git a/strands-py/tests/strands/memory/__init__.py b/strands-py/tests/strands/memory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/strands-py/tests/strands/memory/test_extraction.py b/strands-py/tests/strands/memory/test_extraction.py new file mode 100644 index 0000000000..d3ee73a23f --- /dev/null +++ b/strands-py/tests/strands/memory/test_extraction.py @@ -0,0 +1,593 @@ +"""Tests for ``ExtractionCoordinator``. + +Ported from the coordinator-focused parts of +``strands-ts/src/memory/extraction/__tests__/extraction.test.ts``. Each TS +``it(...)`` that exercises coordinator behavior maps to a ``test_...`` case here. +Where the TS suite drove the coordinator through ``MemoryManager`` + agent hooks, +these tests drive the :class:`ExtractionCoordinator` directly (``record`` / +``process`` / ``flush`` / ``schedule``) so attempt counts are deterministic and +no agent/manager wiring (Task 9/10) is required. + +These tests are written test-first: ``ExtractionCoordinator`` lands in Task 8, so +they are expected to fail with an ``ImportError`` until that implementation +exists. The backoff constants are imported from the module under test and used in +assertions rather than hard-coded. +""" + +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from strands.memory.extraction.coordinator import ( + BACKOFF_PROBE_INTERVAL, + SAVE_FAILURES_BEFORE_BACKOFF, + ExtractionCoordinator, +) +from strands.memory.extraction.types import ( + ExtractionConfig, + ExtractionResult, + ExtractionTrigger, + MemoryMessageFilter, +) +from strands.types.content import Message + +# A stub model. 
The coordinator only passes ``default_model`` through to an +# extractor's context, so any sentinel object suffices. +_DEFAULT_MODEL: Any = SimpleNamespace(id="stub-model") + + +# --------------------------------------------------------------------------- # +# Test fakes / helpers +# --------------------------------------------------------------------------- # + + +class _NoopTrigger(ExtractionTrigger): + """A trigger that never wires anything. + + The coordinator never calls ``attach`` (triggers are attached by the manager + in Task 10), so a no-op satisfies ``ExtractionConfig.trigger`` without pulling + in agent wiring. + """ + + name = "noop" + + def attach(self, context: Any) -> None: # pragma: no cover - never called + pass + + +def _trigger() -> _NoopTrigger: + return _NoopTrigger() + + +def _make_store( + name: str, + extraction: ExtractionConfig, + sink: str = "both", +) -> Any: + """Build a writable fake store with the requested write sink(s). + + ``sink`` chooses which write method(s) the store exposes, mirroring the TS + ``createExtractionStore`` helper: + + - ``"add"`` -> only ``add`` (extractor route) + - ``"add_messages"`` -> only ``add_messages`` (passthrough route) + - ``"both"`` -> both methods present + + ``search`` / ``add`` / ``add_messages`` are ``AsyncMock``s so tests can assert + call counts/args and inject failures or gating. 
+ """ + store = SimpleNamespace() + store.name = name + store.description = None + store.max_search_results = None + store.writable = True + store.extraction = extraction + store.search = AsyncMock(return_value=[]) + store.add = AsyncMock(return_value=None) + store.add_messages = AsyncMock(return_value=None) + + if sink == "add": + del store.add_messages + elif sink == "add_messages": + del store.add + return store + + +def _make_extractor(entries: list[ExtractionResult]) -> Any: + """Build a fake ``Extractor`` whose ``extract`` is an ``AsyncMock``.""" + extractor = SimpleNamespace() + extractor.extract = AsyncMock(return_value=list(entries)) + return extractor + + +def _user_msg(text: str) -> Message: + return {"role": "user", "content": [{"text": text}]} + + +def _assistant_msg(text: str) -> Message: + return {"role": "assistant", "content": [{"text": text}]} + + +def _tool_use_msg() -> Message: + return {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]} + + +def _added_metadata(call: Any) -> Any: + """Pull the metadata argument from a recorded ``store.add`` call. + + Tolerates ``add(content, metadata)`` (positional) and + ``add(content, metadata=...)`` (keyword). + """ + if len(call.args) > 1: + return call.args[1] + return call.kwargs.get("metadata") + + +def _extractor_context(call: Any) -> Any: + """Pull the ``ExtractorContext`` argument from a recorded ``extract`` call.""" + if len(call.args) > 1: + return call.args[1] + return call.kwargs.get("context") + + +def _saved_texts(mock: AsyncMock) -> list[str]: + """Flatten every text block delivered across all calls to an add* mock.""" + texts: list[str] = [] + for call in mock.call_args_list: + batch = call.args[0] + for message in batch: + for block in message["content"]: + if "text" in block: + texts.append(block["text"]) + return texts + + +async def _drive(coordinator: ExtractionCoordinator, store: Any) -> None: + """Request a save and await it. 
+ + ``process`` returns the queued save task, or ``None`` when the store is + backed off and this request is not a probe; awaiting the task drives the save + to completion for deterministic assertions. + """ + task = coordinator.process(store) + if task is not None: + await task + + +# --------------------------------------------------------------------------- # +# No-extractor passthrough +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_no_extractor_passthrough_hands_raw_batch_to_add_messages(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("I prefer dark mode")) + coordinator.record(_assistant_msg("Noted")) + await _drive(coordinator, store) + + store.add_messages.assert_called_once() + batch = store.add_messages.call_args.args[0] + assert len(batch) == 2 + assert batch[0]["role"] == "user" + assert batch[1]["role"] == "assistant" + + +# --------------------------------------------------------------------------- # +# Extractor route +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_extractor_route_calls_extractor_and_writes_each_entry_via_add(): + extractor = _make_extractor( + [ExtractionResult(content="fact one"), ExtractionResult(content="fact two", metadata={"k": "v"})] + ) + store = _make_store("s", ExtractionConfig(trigger=_trigger(), extractor=extractor), sink="both") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("something happened")) + await _drive(coordinator, store) + + extractor.extract.assert_called_once() + assert store.add.call_count == 2 + assert store.add.call_args_list[0].args[0] == "fact one" + assert store.add.call_args_list[1].args[0] == "fact two" + assert _added_metadata(store.add.call_args_list[1]) == {"k": "v"} 
+ # The extractor route never uses the batch sink. + store.add_messages.assert_not_called() + + +@pytest.mark.asyncio +async def test_extractor_route_passes_default_model_in_context(): + extractor = _make_extractor([]) + store = _make_store("s", ExtractionConfig(trigger=_trigger(), extractor=extractor), sink="both") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("hi")) + await _drive(coordinator, store) + + extractor.extract.assert_called_once() + context = _extractor_context(extractor.extract.call_args) + assert context is not None + assert context.default_model is _DEFAULT_MODEL + + +@pytest.mark.asyncio +async def test_extractor_route_writes_entries_concurrently(): + # Both add() calls should be in flight before either resolves: the first add + # blocks until the second has started, which is only possible if the writes + # run concurrently (a serial await loop would deadlock). + second_started = asyncio.Event() + first_invoked_during_second = False + call_index = {"n": 0} + + async def add_impl(content: str, metadata: Any = None) -> None: + nonlocal first_invoked_during_second + index = call_index["n"] + call_index["n"] += 1 + if index == 0: + await second_started.wait() + first_invoked_during_second = second_started.is_set() + else: + second_started.set() + + extractor = _make_extractor([ExtractionResult(content="a"), ExtractionResult(content="b")]) + store = _make_store("s", ExtractionConfig(trigger=_trigger(), extractor=extractor), sink="add") + store.add.side_effect = add_impl + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("x")) + await _drive(coordinator, store) + + assert store.add.call_count == 2 + assert first_invoked_during_second is True + + +@pytest.mark.asyncio +async def test_extractor_route_rolls_back_and_retries_batch_on_entry_failure(): + extractor = _make_extractor([ExtractionResult(content="a"), ExtractionResult(content="b")]) + store = 
_make_store("s", ExtractionConfig(trigger=_trigger(), extractor=extractor), sink="add") + # First batch: second entry write fails -> whole batch rolled back. + store.add.side_effect = [None, RuntimeError("write failed"), None, None] + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("x")) + await _drive(coordinator, store) # fails, mark rolled back + await _drive(coordinator, store) # retries the same batch + + # 2 writes on the first attempt + 2 on the retry. + assert store.add.call_count == 4 + assert extractor.extract.call_count == 2 + + +# --------------------------------------------------------------------------- # +# Message filter +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_filter_drops_tool_blocks_by_default_and_empties(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("keep me")) + coordinator.record(_tool_use_msg()) # tool-only message -> emptied -> dropped + await _drive(coordinator, store) + + store.add_messages.assert_called_once() + batch = store.add_messages.call_args.args[0] + assert len(batch) == 1 + assert batch[0]["role"] == "user" + + +@pytest.mark.asyncio +async def test_filter_honors_a_custom_filter(): + store = _make_store( + "s", + ExtractionConfig(trigger=_trigger(), filter=MemoryMessageFilter(exclude=["text"])), + sink="add_messages", + ) + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("this is text and should be excluded")) + await _drive(coordinator, store) + + # The only message was text, excluded -> emptied -> nothing to write. 
+ store.add_messages.assert_not_called() + + +# --------------------------------------------------------------------------- # +# High-water-mark dedup +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_hwm_processes_only_messages_added_since_the_last_save(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("turn one")) + await _drive(coordinator, store) + coordinator.record(_user_msg("turn two")) + await _drive(coordinator, store) + + assert store.add_messages.call_count == 2 + assert len(store.add_messages.call_args_list[0].args[0]) == 1 + second = store.add_messages.call_args_list[1].args[0] + assert len(second) == 1 + assert second[0]["content"][0]["text"] == "turn two" + + +@pytest.mark.asyncio +async def test_hwm_does_nothing_when_no_new_messages_since_the_mark(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("only turn")) + await _drive(coordinator, store) + await _drive(coordinator, store) # no new messages + + assert store.add_messages.call_count == 1 + + +@pytest.mark.asyncio +async def test_hwm_retries_the_same_messages_on_the_next_save_if_a_write_fails(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = [RuntimeError("backend down"), None] + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("important")) + await _drive(coordinator, store) # fails, mark rolled back + await _drive(coordinator, store) # retries + + assert store.add_messages.call_count == 2 + assert len(store.add_messages.call_args_list[1].args[0]) == 1 + + +# --------------------------------------------------------------------------- # +# 
Backing off and recovering from a failing store +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_backs_off_to_periodic_probes_after_threshold_failures(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = RuntimeError("backend down") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + # Each call buffers a message and requests a save; every save fails. Run + # enough backed-off requests for exactly two probe intervals. + probes = 2 + requests = SAVE_FAILURES_BEFORE_BACKOFF + BACKOFF_PROBE_INTERVAL * probes + for index in range(requests): + coordinator.record(_user_msg(f"m{index}")) + await _drive(coordinator, store) + + # Attempts every request until backoff, then only every probe interval. + assert store.add_messages.call_count == SAVE_FAILURES_BEFORE_BACKOFF + probes + + +@pytest.mark.asyncio +async def test_recovers_and_saves_the_buffered_backlog_when_the_store_comes_back(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = RuntimeError("down") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + # Drive the store into backoff. + for index in range(SAVE_FAILURES_BEFORE_BACKOFF): + coordinator.record(_user_msg(f"down{index}")) + await _drive(coordinator, store) + + # Store recovers; run probe-interval requests so a probe lands and succeeds. + store.add_messages.reset_mock() + store.add_messages.side_effect = None + store.add_messages.return_value = None + for index in range(BACKOFF_PROBE_INTERVAL): + coordinator.record(_user_msg(f"up{index}")) + await _drive(coordinator, store) + + # The recovering probe saved, and its batch includes the outage backlog. 
+ assert store.add_messages.called + texts = _saved_texts(store.add_messages) + assert "down0" in texts + assert "up0" in texts + + +@pytest.mark.asyncio +async def test_a_healthy_store_keeps_saving_every_request_while_a_sibling_is_backed_off(): + bad = _make_store("bad", ExtractionConfig(trigger=_trigger()), sink="add_messages") + bad.add_messages.side_effect = RuntimeError("down") + good = _make_store("good", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([bad, good], _DEFAULT_MODEL) + + probes = 2 + requests = SAVE_FAILURES_BEFORE_BACKOFF + BACKOFF_PROBE_INTERVAL * probes + for index in range(requests): + coordinator.record(_user_msg(f"m{index}")) + await _drive(coordinator, bad) + await _drive(coordinator, good) + + # Good store saves every request; bad store stops at backoff + its probes. + assert good.add_messages.call_count == requests + assert bad.add_messages.call_count == SAVE_FAILURES_BEFORE_BACKOFF + probes + + +@pytest.mark.asyncio +async def test_flush_resolves_even_when_a_store_is_failing(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = RuntimeError("down") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("x")) + await _drive(coordinator, store) # fails (swallowed) + + assert await coordinator.flush() is None + + +@pytest.mark.asyncio +async def test_flush_bypasses_backoff_and_writes_the_backlog_of_a_recovered_store(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = RuntimeError("down") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + # Drive the store into backoff. + for index in range(SAVE_FAILURES_BEFORE_BACKOFF): + coordinator.record(_user_msg(f"down{index}")) + await _drive(coordinator, store) + + # Store recovers and a final message arrives, but no probe has landed yet. 
+ store.add_messages.reset_mock() + store.add_messages.side_effect = None + store.add_messages.return_value = None + coordinator.record(_user_msg("final")) + + # A single flush must write the backlog despite backoff (not be probe-gated). + await coordinator.flush() + + store.add_messages.assert_called_once() + texts = _saved_texts(store.add_messages) + assert "final" in texts + assert "down0" in texts + + +@pytest.mark.asyncio +async def test_a_fully_filtered_empty_turn_does_not_reset_the_failure_streak(): + # A no-extractor store; the default filter drops tool blocks. A turn of only + # tool blocks contributes no extractable content, so it must not be mistaken + # for a recovery that clears the prior failures. We prove the streak survives + # by showing backoff still engages and the next request is probe-gated. + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = RuntimeError("down") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + # One short of backoff. + for index in range(SAVE_FAILURES_BEFORE_BACKOFF - 1): + coordinator.record(_user_msg(f"m{index}")) + await _drive(coordinator, store) + + # An all-tool-blocks turn: its own content filters away. + coordinator.record(_tool_use_msg()) + await _drive(coordinator, store) + + # The next real failure tips into backoff (it would not if the streak reset). + coordinator.record(_user_msg("nth")) + await _drive(coordinator, store) + + # Backed off: the next request is probe-gated, so the backend isn't called. 
+ store.add_messages.reset_mock() + coordinator.record(_user_msg("after")) + await _drive(coordinator, store) + store.add_messages.assert_not_called() + + +# --------------------------------------------------------------------------- # +# Flush semantics +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_flush_force_extracts_a_buffered_tail_whose_trigger_never_fired(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + # Buffer messages but never call process (the trigger never fired). + coordinator.record(_user_msg("a")) + coordinator.record(_user_msg("b")) + store.add_messages.assert_not_called() + + await coordinator.flush() + + store.add_messages.assert_called_once() + assert len(store.add_messages.call_args.args[0]) == 2 + + +@pytest.mark.asyncio +async def test_flush_does_not_re_extract_messages_already_processed(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("a")) + await _drive(coordinator, store) # already extracted + assert store.add_messages.call_count == 1 + + await coordinator.flush() # nothing fresh -> no-op + assert store.add_messages.call_count == 1 + + +@pytest.mark.asyncio +async def test_flush_is_a_no_op_when_nothing_is_buffered(): + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + await coordinator.flush() + + store.add_messages.assert_not_called() + + +@pytest.mark.asyncio +async def test_flush_awaits_an_in_flight_write(): + release = asyncio.Event() + completed = {"v": False} + + async def add_messages_impl(messages: Any, context: Any = None) -> None: + await release.wait() + completed["v"] = True + + store = _make_store("s", 
ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = add_messages_impl + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("hello")) + coordinator.schedule(store) # non-blocking background save + + flushed = asyncio.ensure_future(coordinator.flush()) + await asyncio.sleep(0) # let flush start waiting on the in-flight write + assert completed["v"] is False + + release.set() + await flushed + assert completed["v"] is True + + +# --------------------------------------------------------------------------- # +# Background, non-blocking execution +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_background_save_does_not_block_scheduling_and_flush_awaits_it(): + release = asyncio.Event() + started = asyncio.Event() + completed = {"v": False} + + async def add_messages_impl(messages: Any, context: Any = None) -> None: + started.set() + await release.wait() + completed["v"] = True + + store = _make_store("s", ExtractionConfig(trigger=_trigger()), sink="add_messages") + store.add_messages.side_effect = add_messages_impl + coordinator = ExtractionCoordinator([store], _DEFAULT_MODEL) + + coordinator.record(_user_msg("hello")) + coordinator.schedule(store) # returns immediately, write runs in background + + # The background write begins but hangs; scheduling did not block on it. + await started.wait() + assert completed["v"] is False + + # flush must await the in-flight write to completion once it is released. 
+ flushed = asyncio.ensure_future(coordinator.flush()) + await asyncio.sleep(0) + assert completed["v"] is False + release.set() + await flushed + assert completed["v"] is True diff --git a/strands-py/tests/strands/memory/test_memory_manager.py b/strands-py/tests/strands/memory/test_memory_manager.py new file mode 100644 index 0000000000..b6e5e2a657 --- /dev/null +++ b/strands-py/tests/strands/memory/test_memory_manager.py @@ -0,0 +1,1147 @@ +"""Tests for ``MemoryManager``. + +Ported from ``strands-ts/src/memory/__tests__/memory-manager.test.ts`` and the +manager-level / ``initAgent`` parts of +``strands-ts/src/memory/extraction/__tests__/extraction.test.ts``. Each TS +``it(...)`` maps to a ``test_...`` case here. + +These tests are written test-first: ``MemoryManager`` lands in Task 10, so they +are expected to fail with an ``ImportError`` until that implementation exists. + +Driving the built tools +----------------------- +The manager builds ``search_memory`` / ``add_memory`` via the ``tool()`` factory +wrapping async closures. ``DecoratedFunctionTool.__call__`` delegates straight to +the wrapped function, so a test invokes a tool by calling it with kwargs and +awaiting the returned coroutine (e.g. ``await search_tool(query="q")``). This +exercises the full closure (scope resolution + ``MemoryManager.search`` / +``.add``) without the agent runtime. + +Fake stores +----------- +``_store`` builds each fake store as an instance of a freshly-created class whose +``search`` / ``add`` / ``add_messages`` / ``get_tools`` live on the *class* (as +mocks). The manager detects optional methods with ``_has_method``, which inspects +``type(store)`` (not the instance) -- so optional methods must be class +attributes, and are only defined when the store is meant to expose them. This +mirrors the TS ``createMockStore`` (writable -> ``add``; ``tools=`` -> +``get_tools``) and ``createExtractionStore`` helpers. 
+""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from strands.hooks.events import AfterInvocationEvent, MessageAddedEvent +from strands.hooks.registry import HookOrder +from strands.memory import AggregateMemoryError +from strands.memory.extraction.triggers import IntervalTrigger, InvocationTrigger +from strands.memory.extraction.types import ExtractionConfig, ExtractionResult +from strands.memory.memory_manager import DEFAULT_MAX_SEARCH_RESULTS, MemoryManager +from strands.memory.types import ( + MemoryAddOptions, + MemoryAddToolConfig, + MemoryEntry, + MemorySearchOptions, + MemoryToolConfig, +) +from strands.tools.decorator import tool + +# --------------------------------------------------------------------------- # +# Test fakes / helpers +# --------------------------------------------------------------------------- # + + +def _store( + name: str, + *, + entries: list[MemoryEntry] | None = None, + writable: bool = False, + description: str | None = None, + max_search_results: int | None = None, + tools: list[Any] | None = None, + extraction: ExtractionConfig | None = None, + sinks: set[str] | None = None, + search_error: BaseException | None = None, + add_error: BaseException | None = None, + add_messages_error: BaseException | None = None, +) -> Any: + """Build a fake ``MemoryStore``. + + Optional write/tool methods are placed on a freshly-created class so the + manager's ``_has_method`` (which inspects ``type(store)``) detects only the + capabilities the store is meant to expose. + + Args: + name: Store name. + entries: Entries the store's ``search`` returns (manager attributes them). + writable: Whether the store accepts writes. + description: Store description (surfaces in tool descriptions). + max_search_results: Per-store default search limit. 
+ tools: When given, the store exposes ``get_tools`` returning these. + extraction: Extraction config for the store. + sinks: Which write sinks to expose (subset of ``{"add", "add_messages"}``). + Defaults to ``{"add"}`` when ``writable`` else no sinks. + search_error: When set, ``search`` raises this. + add_error: When set, ``add`` raises this. + add_messages_error: When set, ``add_messages`` raises this. + """ + methods: dict[str, Any] = {} + + if search_error is not None: + methods["search"] = AsyncMock(side_effect=search_error) + else: + methods["search"] = AsyncMock(return_value=list(entries or [])) + + if sinks is None: + sinks = {"add"} if writable else set() + + if "add" in sinks: + methods["add"] = AsyncMock(side_effect=add_error) if add_error is not None else AsyncMock(return_value=None) + if "add_messages" in sinks: + methods["add_messages"] = ( + AsyncMock(side_effect=add_messages_error) + if add_messages_error is not None + else AsyncMock(return_value=None) + ) + if tools is not None: + methods["get_tools"] = MagicMock(return_value=list(tools)) + + store_cls = type(f"_FakeStore_{name}", (), dict(methods)) + store = store_cls() + store.name = name + store.description = description + store.max_search_results = max_search_results + store.writable = writable + store.extraction = extraction + return store + + +def _make_extractor(entries: list[ExtractionResult]) -> Any: + """Build a fake ``Extractor`` whose ``extract`` is an ``AsyncMock``.""" + extractor = SimpleNamespace() + extractor.extract = AsyncMock(return_value=list(entries)) + return extractor + + +def _named_tool(name: str) -> Any: + """Build a named function tool (mirrors the TS ``createNamedTool``).""" + + @tool(name=name, description=f"test tool {name}") + def _t() -> str: + return "ok" + + return _t + + +def _tool_named(mm: MemoryManager, name: str) -> Any: + """Return the manager-registered tool with the given ``tool_name``.""" + for built in mm.tools: + if built.tool_name == name: + return 
built + registered_names = [registered_tool.tool_name for registered_tool in mm.tools] + raise AssertionError(f"tool {name!r} not registered; have {registered_names}") + + +def _tool_names(mm: MemoryManager) -> list[str]: + return [registered_tool.tool_name for registered_tool in mm.tools] + + +def _added_metadata(call: Any) -> Any: + """Pull the metadata argument from a recorded ``store.add`` call.""" + if len(call.args) > 1: + return call.args[1] + return call.kwargs.get("metadata") + + +def _added_content(call: Any) -> Any: + """Pull the content argument from a recorded ``store.add`` call.""" + return call.args[0] if call.args else call.kwargs.get("content") + + +def _forwarded_max(mock: AsyncMock) -> Any: + """Return the ``max_search_results`` forwarded to a store's ``search``.""" + call = mock.call_args + options = call.args[1] if len(call.args) > 1 else call.kwargs.get("options") + if options is None: + return None + if isinstance(options, dict): + return options.get("max_search_results") + return getattr(options, "max_search_results", None) + + +def _extractor_context(call: Any) -> Any: + """Pull the ``ExtractorContext`` argument from a recorded ``extract`` call.""" + if len(call.args) > 1: + return call.args[1] + return call.kwargs.get("context") + + +def _user_msg(text: str) -> dict: + return {"role": "user", "content": [{"text": text}]} + + +def _assistant_msg(text: str) -> dict: + return {"role": "assistant", "content": [{"text": text}]} + + +class _FakeAgent: + """Minimal agent stand-in for ``init_agent`` wiring. + + The manager only uses ``agent.add_hook(callback, event_type, *, order=...)`` + and ``agent.model``. Recorded hooks are kept as ``(callback, event_type, + order)`` triples so tests can fire the matching events manually. 
+ """ + + def __init__(self, model: Any = None) -> None: + self.model = model + self.hooks: list[tuple[Any, Any, float]] = [] + + def add_hook(self, callback: Any, event_type: Any = None, *, order: float = HookOrder.DEFAULT) -> None: + self.hooks.append((callback, event_type, order)) + + +async def _invoke_all(agent: _FakeAgent, event: Any) -> None: + """Fire every recorded hook registered for ``event``'s type.""" + for callback, event_type, _order in list(agent.hooks): + if event_type is type(event): + result = callback(event) + if inspect.isawaitable(result): + await result + + +async def _add_messages(agent: _FakeAgent, *messages: dict) -> None: + """Drive ``MessageAddedEvent`` for each message into the coordinator buffer.""" + for message in messages: + await _invoke_all(agent, MessageAddedEvent(agent=agent, message=message)) + + +async def _fire_invocation(agent: _FakeAgent, mm: MemoryManager) -> None: + """Fire ``AfterInvocationEvent`` (drives triggers), then flush the manager.""" + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) + await mm.flush() + + +# --------------------------------------------------------------------------- # +# Constructor / store validation (Requirement 2) +# --------------------------------------------------------------------------- # + + +def test_constructor_raises_when_stores_empty(): + with pytest.raises(Exception, match="at least one store is required"): + MemoryManager(stores=[]) + + +def test_constructor_creates_instance_with_valid_config_and_name(): + mm = MemoryManager(stores=[_store("test")]) + assert mm.name == "strands:memory-manager" + + +def test_constructor_raises_on_duplicate_store_name(): + with pytest.raises(Exception, match="duplicate store name"): + MemoryManager(stores=[_store("dup"), _store("dup")]) + + +def test_constructor_raises_when_writable_without_write_sink(): + broken = _store("broken", writable=True, sinks=set()) + with pytest.raises(Exception, match="no add or add_messages"): + 
MemoryManager(stores=[broken]) + + +def test_constructor_raises_when_add_tool_enabled_but_no_store_implements_add(): + with pytest.raises(Exception, match="no writable stores implement add"): + MemoryManager(stores=[_store("a")], add_tool_config=True) + + +def test_constructor_allows_add_tool_config_true_with_single_writable_store(): + mm = MemoryManager(stores=[_store("a", writable=True)], add_tool_config=True) + assert "add_memory" in _tool_names(mm) + + +def test_constructor_allows_add_tool_config_true_with_multiple_writable_stores(): + mm = MemoryManager( + stores=[_store("a", writable=True), _store("b", writable=True)], + add_tool_config=True, + ) + assert "add_memory" in _tool_names(mm) + + +def test_constructor_raises_when_add_tool_config_stores_names_nonexistent(): + with pytest.raises(Exception, match="not found"): + MemoryManager( + stores=[_store("a", writable=True)], + add_tool_config=MemoryAddToolConfig(stores=["nonexistent"]), + ) + + +def test_constructor_raises_when_add_tool_config_stores_names_non_writable(): + with pytest.raises(Exception, match="not writable"): + MemoryManager( + stores=[_store("a", writable=True), _store("readonly")], + add_tool_config=MemoryAddToolConfig(stores=["readonly"]), + ) + + +def test_constructor_raises_when_add_tool_config_stores_names_writable_store_without_add(): + # R2.12: a referenced store is writable but exposes only ``add_messages`` (no + # ``add``), so the add tool cannot write discrete entries to it. The + # ``add``-capable peer keeps the construction otherwise valid up to this check. 
+ add_only = _store("add-only", writable=True, sinks={"add"}) + messages_only = _store("messages-only", writable=True, sinks={"add_messages"}) + with pytest.raises(Exception, match="has no add method"): + MemoryManager( + stores=[add_only, messages_only], + add_tool_config=MemoryAddToolConfig(stores=["messages-only"]), + ) + + +@pytest.mark.asyncio +async def test_constructor_accepts_memory_store_instances_in_add_tool_config_stores(): + personal = _store("personal", writable=True) + team = _store("team", writable=True) + # Pass the store instance instead of its name; resolves by name to scope to it. + mm = MemoryManager(stores=[personal, team], add_tool_config=MemoryAddToolConfig(stores=[personal])) + + await _tool_named(mm, "add_memory")(entries=["fact"]) + + personal.add.assert_called_once() + assert _added_content(personal.add.call_args) == "fact" + team.add.assert_not_called() + + +def test_constructor_raises_when_add_tool_config_stores_instance_not_configured(): + configured = _store("configured", writable=True) + stray = _store("stray", writable=True) + with pytest.raises(Exception, match="not found"): + MemoryManager(stores=[configured], add_tool_config=MemoryAddToolConfig(stores=[stray])) + + +# --- extraction-related construction validation (Requirement 2.5-2.8) ------- # + + +def test_constructor_raises_when_extraction_store_not_writable(): + store = _store("s", writable=False, sinks=set(), extraction=ExtractionConfig(trigger=InvocationTrigger())) + with pytest.raises(Exception, match="not writable"): + MemoryManager(stores=[store]) + + +def test_constructor_raises_when_extraction_config_has_no_triggers(): + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=[])) + with pytest.raises(Exception, match="no triggers"): + MemoryManager(stores=[store]) + + +def test_constructor_allows_store_writable_only_via_add_messages(): + store = _store("s", writable=True, sinks={"add_messages"}, 
extraction=ExtractionConfig(trigger=InvocationTrigger())) + # Should not raise. + MemoryManager(stores=[store]) + + +def test_constructor_rejects_add_tool_config_targeting_add_messages_only_store(): + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + with pytest.raises(Exception, match="no writable stores implement add"): + MemoryManager(stores=[store], add_tool_config=True) + + +def test_constructor_raises_when_extraction_has_extractor_but_no_add(): + store = _store( + "s", + writable=True, + sinks={"add_messages"}, + extraction=ExtractionConfig(trigger=InvocationTrigger(), extractor=_make_extractor([])), + ) + with pytest.raises(Exception, match="extractor but no add"): + MemoryManager(stores=[store]) + + +def test_constructor_raises_when_extraction_no_extractor_but_no_add_messages(): + store = _store("s", writable=True, sinks={"add"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + with pytest.raises(Exception, match="without an extractor but no add_messages"): + MemoryManager(stores=[store]) + + +# --------------------------------------------------------------------------- # +# get_tools composition and ordering (Requirement 5) +# --------------------------------------------------------------------------- # + + +def test_get_tools_registers_search_tool_by_default(): + mm = MemoryManager(stores=[_store("test")]) + assert _tool_names(mm) == ["search_memory"] + + +def test_get_tools_registers_add_tool_when_enabled(): + mm = MemoryManager(stores=[_store("test", writable=True)], add_tool_config=True) + assert _tool_names(mm) == ["search_memory", "add_memory"] + + +def test_get_tools_does_not_register_add_tool_by_default(): + mm = MemoryManager(stores=[_store("test", writable=True)]) + assert _tool_names(mm) == ["search_memory"] + + +def test_get_tools_empty_when_search_and_add_disabled_and_no_store_tools(): + mm = MemoryManager( + stores=[_store("test", writable=True)], + 
search_tool_config=False, + add_tool_config=False, + ) + assert mm.tools == [] + + +def test_get_tools_uses_custom_tool_names(): + mm = MemoryManager( + stores=[_store("test", writable=True)], + search_tool_config=MemoryToolConfig(name="recall"), + add_tool_config=MemoryAddToolConfig(name="remember"), + ) + assert _tool_names(mm) == ["recall", "remember"] + + +def test_get_tools_includes_store_descriptions_in_search_description(): + store = _store("personal", description="User preferences") + mm = MemoryManager(stores=[store]) + search = _tool_named(mm, "search_memory") + description = search.tool_spec["description"] + assert "personal: User preferences" in description + assert "target one or more memory stores by name" in description + + +def test_get_tools_includes_store_descriptions_in_add_description(): + store = _store("notes", writable=True, description="Personal notes") + mm = MemoryManager(stores=[store], add_tool_config=True) + add = _tool_named(mm, "add_memory") + description = add.tool_spec["description"] + assert "notes: Personal notes" in description + assert "target a specific store by name" in description + + +def test_get_tools_aggregates_tools_provided_by_a_store(): + store = _store("kb", tools=[_named_tool("kb_query")]) + mm = MemoryManager(stores=[store]) + assert _tool_names(mm) == ["search_memory", "kb_query"] + + +def test_get_tools_aggregates_store_tools_across_multiple_stores_with_manager_tools(): + store_a = _store("a", writable=True, tools=[_named_tool("a_tool")]) + store_b = _store("b", tools=[_named_tool("b_tool")]) + mm = MemoryManager(stores=[store_a, store_b], add_tool_config=True) + assert _tool_names(mm) == ["search_memory", "add_memory", "a_tool", "b_tool"] + + +def test_get_tools_includes_store_tools_even_when_manager_registers_none(): + store = _store("kb", tools=[_named_tool("kb_query")]) + mm = MemoryManager(stores=[store], search_tool_config=False) + assert _tool_names(mm) == ["kb_query"] + + +# 
--------------------------------------------------------------------------- # +# Programmatic search (Requirement 3) +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_search_queries_all_stores_and_concatenates_results(): + store1 = _store("a", entries=[MemoryEntry(content="fact one")]) + store2 = _store("b", entries=[MemoryEntry(content="fact two")]) + mm = MemoryManager(stores=[store1, store2]) + + results = await mm.search("query") + + assert results == [ + MemoryEntry(content="fact one", store_name="a"), + MemoryEntry(content="fact two", store_name="b"), + ] + + +@pytest.mark.asyncio +async def test_search_resolves_store_max_search_results_when_caller_omits(): + store = _store("a", max_search_results=5) + mm = MemoryManager(stores=[store]) + + await mm.search("query") + + assert _forwarded_max(store.search) == 5 + + +@pytest.mark.asyncio +async def test_search_forwards_explicit_max_search_results_override(): + store = _store("a", max_search_results=5) + mm = MemoryManager(stores=[store]) + + await mm.search("query", MemorySearchOptions(max_search_results=2)) + + assert _forwarded_max(store.search) == 2 + + +@pytest.mark.asyncio +async def test_search_falls_back_to_default_when_neither_caller_nor_store_specifies(): + store = _store("a") + mm = MemoryManager(stores=[store]) + + await mm.search("query") + + assert _forwarded_max(store.search) == DEFAULT_MAX_SEARCH_RESULTS + + +@pytest.mark.asyncio +async def test_search_filters_to_named_stores(): + store1 = _store("personal", entries=[MemoryEntry(content="personal fact")]) + store2 = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[store1, store2]) + + results = await mm.search("query", MemorySearchOptions(stores=["personal"])) + + assert results == [MemoryEntry(content="personal fact", store_name="personal")] + store2.search.assert_not_called() + + +@pytest.mark.asyncio +async def 
test_search_skips_failing_stores_and_returns_the_rest(): + store1 = _store("failing", search_error=RuntimeError("network error")) + store2 = _store("ok", entries=[MemoryEntry(content="fact")]) + mm = MemoryManager(stores=[store1, store2]) + + results = await mm.search("query") + + assert results == [MemoryEntry(content="fact", store_name="ok")] + + +@pytest.mark.asyncio +async def test_search_searches_all_stores_when_filter_omitted(): + store1 = _store("a", entries=[MemoryEntry(content="fact one")]) + store2 = _store("b", entries=[MemoryEntry(content="fact two")]) + mm = MemoryManager(stores=[store1, store2]) + + results = await mm.search("query") + + assert results == [ + MemoryEntry(content="fact one", store_name="a"), + MemoryEntry(content="fact two", store_name="b"), + ] + + +@pytest.mark.asyncio +async def test_search_searches_no_stores_when_filter_is_empty_list(): + store1 = _store("a", entries=[MemoryEntry(content="fact one")]) + store2 = _store("b", entries=[MemoryEntry(content="fact two")]) + mm = MemoryManager(stores=[store1, store2]) + + results = await mm.search("query", MemorySearchOptions(stores=[])) + + assert results == [] + store1.search.assert_not_called() + store2.search.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_raises_not_found_before_querying_when_named_store_missing(): + store = _store("personal", entries=[MemoryEntry(content="fact")]) + mm = MemoryManager(stores=[store]) + + with pytest.raises(Exception, match="not found"): + await mm.search("query", MemorySearchOptions(stores=["nonexistent"])) + store.search.assert_not_called() + + +# --------------------------------------------------------------------------- # +# Programmatic add (Requirement 4) +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_add_writes_to_all_writable_stores(): + store1 = _store("a", writable=True) + store2 = _store("b", writable=True) + mm = MemoryManager(stores=[store1, 
store2]) + + await mm.add("user likes coffee") + + assert _added_content(store1.add.call_args) == "user likes coffee" + assert _added_metadata(store1.add.call_args) is None + assert _added_content(store2.add.call_args) == "user likes coffee" + assert _added_metadata(store2.add.call_args) is None + + +@pytest.mark.asyncio +async def test_add_passes_metadata_to_stores(): + store = _store("a", writable=True) + mm = MemoryManager(stores=[store]) + + await mm.add("fact", MemoryAddOptions(metadata={"source": "user"})) + + assert _added_content(store.add.call_args) == "fact" + assert _added_metadata(store.add.call_args) == {"source": "user"} + + +@pytest.mark.asyncio +async def test_add_filters_to_named_stores(): + store1 = _store("personal", writable=True) + store2 = _store("team", writable=True) + mm = MemoryManager(stores=[store1, store2]) + + await mm.add("my preference", MemoryAddOptions(stores=["personal"])) + + assert _added_content(store1.add.call_args) == "my preference" + store2.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_dedupes_duplicate_store_names(): + store = _store("personal", writable=True) + mm = MemoryManager(stores=[store]) + + await mm.add("fact", MemoryAddOptions(stores=["personal", "personal"])) + + store.add.assert_called_once() # the repeated name collapses to a single write + + +@pytest.mark.asyncio +async def test_add_raises_when_no_writable_stores_match(): + mm = MemoryManager(stores=[_store("a")]) + with pytest.raises(Exception, match="no writable store matched"): + await mm.add("fact") + + +@pytest.mark.asyncio +async def test_add_raises_not_found_when_named_store_missing(): + mm = MemoryManager(stores=[_store("a", writable=True)]) + with pytest.raises(Exception, match="not found"): + await mm.add("fact", MemoryAddOptions(stores=["nonexistent"])) + + +@pytest.mark.asyncio +async def test_add_raises_read_only_when_named_store_not_writable(): + mm = MemoryManager(stores=[_store("readonly")]) + with pytest.raises(Exception, match="read-only"): + await mm.add("fact", 
MemoryAddOptions(stores=["readonly"])) + + +@pytest.mark.asyncio +async def test_add_awaits_writes_and_raises_aggregate_naming_failed_store(): + failing = _store("failing", writable=True, add_error=RuntimeError("write error")) + ok = _store("ok", writable=True) + mm = MemoryManager(stores=[failing, ok]) + + with pytest.raises(AggregateMemoryError, match="failing") as exc_info: + await mm.add("fact") + + # The remaining store still received its write (partial failure completes the rest). + assert _added_content(ok.add.call_args) == "fact" + assert len(exc_info.value.errors) == 1 + + +# --------------------------------------------------------------------------- # +# Search tool scoping and attribution (Requirement 6) +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_search_tool_queries_all_stores_when_stores_omitted(): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + team = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[personal, team]) + + await _tool_named(mm, "search_memory")(query="q") + + personal.search.assert_called() + team.search.assert_called() + + +@pytest.mark.asyncio +async def test_search_tool_treats_empty_stores_as_omitted(): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + team = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[personal, team]) + + await _tool_named(mm, "search_memory")(query="q", stores=[]) # tool treats [] like an omitted filter + + personal.search.assert_called() + team.search.assert_called() + + +@pytest.mark.asyncio +async def test_search_tool_targets_only_the_requested_in_scope_store(): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + team = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[personal, team]) + + await _tool_named(mm, "search_memory")(query="q", stores=["personal"]) + 
+ personal.search.assert_called() + team.search.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_tool_attributes_each_result_to_its_store(): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + team = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[personal, team]) + + result = await _tool_named(mm, "search_memory")(query="q") + + assert result == [ + {"content": "personal fact", "store_name": "personal"}, + {"content": "team fact", "store_name": "team"}, + ] + + +@pytest.mark.asyncio +async def test_search_tool_keeps_valid_names_and_warns_on_out_of_scope(caplog): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + team = _store("team", entries=[MemoryEntry(content="team fact")]) + mm = MemoryManager(stores=[personal, team]) + + with caplog.at_level(logging.WARNING): + await _tool_named(mm, "search_memory")(query="q", stores=["personal", "nonexistent"]) + + personal.search.assert_called() + team.search.assert_not_called() + assert "nonexistent" in caplog.text # the unknown name is reported in the captured log + + +@pytest.mark.asyncio +async def test_search_tool_raises_when_every_requested_store_is_out_of_scope(): + personal = _store("personal", entries=[MemoryEntry(content="personal fact")]) + mm = MemoryManager(stores=[personal]) + + with pytest.raises(Exception, match="none of the requested memory stores are available"): + await _tool_named(mm, "search_memory")(query="q", stores=["nonexistent"]) + personal.search.assert_not_called() + + +# --------------------------------------------------------------------------- # +# Add tool scoping and write modes (Requirement 7) +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_add_tool_writes_to_all_writable_stores_when_stores_omitted(): + personal = _store("personal", writable=True) + team = _store("team", writable=True) + mm = MemoryManager(stores=[personal, team], 
add_tool_config=True) + + await _tool_named(mm, "add_memory")(entries=["fact"]) + + assert _added_content(personal.add.call_args) == "fact" + assert _added_content(team.add.call_args) == "fact" + + +@pytest.mark.asyncio +async def test_add_tool_treats_empty_stores_as_omitted(): + personal = _store("personal", writable=True) + team = _store("team", writable=True) + mm = MemoryManager(stores=[personal, team], add_tool_config=True) + + await _tool_named(mm, "add_memory")(entries=["fact"], stores=[]) + + personal.add.assert_called() + team.add.assert_called() + + +@pytest.mark.asyncio +async def test_add_tool_is_scoped_to_allowlist_excluding_other_writable_stores(): + personal = _store("personal", writable=True) + team = _store("team", writable=True) + mm = MemoryManager(stores=[personal, team], add_tool_config=MemoryAddToolConfig(stores=["personal"])) + + # Omitting stores writes to the configured allowlist only -- not every writable store. + await _tool_named(mm, "add_memory")(entries=["fact"]) + + assert _added_content(personal.add.call_args) == "fact" + team.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_tool_rejects_a_writable_store_excluded_from_allowlist(): + personal = _store("personal", writable=True) + extraction_only = _store("extraction-only", writable=True) # writable, but outside the allowlist + mm = MemoryManager(stores=[personal, extraction_only], add_tool_config=MemoryAddToolConfig(stores=["personal"])) + + with pytest.raises(Exception, match="none of the requested memory stores are available"): + await _tool_named(mm, "add_memory")(entries=["fact"], stores=["extraction-only"]) + extraction_only.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_tool_excludes_read_only_stores_from_its_scope(): + personal = _store("personal", writable=True) + readonly = _store("readonly") + mm = MemoryManager(stores=[personal, readonly], add_tool_config=True) + + with pytest.raises(Exception, match="none of the requested memory stores are available"): + await 
_tool_named(mm, "add_memory")(entries=["fact"], stores=["readonly"]) + personal.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_tool_keeps_valid_names_and_warns_on_out_of_scope(caplog): + personal = _store("personal", writable=True) + team = _store("team", writable=True) + mm = MemoryManager(stores=[personal, team], add_tool_config=True) + + with caplog.at_level(logging.WARNING): + await _tool_named(mm, "add_memory")(entries=["fact"], stores=["personal", "nonexistent"]) + + assert _added_content(personal.add.call_args) == "fact" + team.add.assert_not_called() # "team" was not among the requested names + assert "nonexistent" in caplog.text + + +@pytest.mark.asyncio +async def test_add_tool_raises_when_every_requested_store_is_out_of_scope(): + personal = _store("personal", writable=True) + mm = MemoryManager(stores=[personal], add_tool_config=True) + + with pytest.raises(Exception, match="none of the requested memory stores are available"): + await _tool_named(mm, "add_memory")(entries=["fact"], stores=["nonexistent"]) + personal.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_tool_rejects_an_empty_entries_list(): + # R7.2: an empty ``entries`` list is rejected without writing to any store. + # NOTE: in TS this is enforced by the tool's Zod ``min(1)`` schema. The Python + # @tool validation model is derived from the closure signature and does not + # enforce the advertised JSON-schema ``minItems``, so calling the closure + # directly relies on a manager-side guard. This test pins that guard. 
+ personal = _store("personal", writable=True) + mm = MemoryManager(stores=[personal], add_tool_config=True) + + with pytest.raises(ValueError, match="requires at least one entry"): + await _tool_named(mm, "add_memory")(entries=[]) + personal.add.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_tool_returns_stored_count_by_default(): + store = _store("notes", writable=True) + mm = MemoryManager(stores=[store], add_tool_config=True) + + result = await _tool_named(mm, "add_memory")(entries=["a", "b"]) + + assert result == {"stored": 2} + + +@pytest.mark.asyncio +async def test_add_tool_raises_flattened_aggregate_with_concrete_reasons_on_failure(): + failing = _store("failing", writable=True, add_error=RuntimeError("write error")) + mm = MemoryManager(stores=[failing], add_tool_config=True) + + with pytest.raises(AggregateMemoryError) as exc_info: + await _tool_named(mm, "add_memory")(entries=["a", "b"]) + + agg = exc_info.value + assert "failed to add 2 of 2 entries" in str(agg) # both entries hit the single failing store + assert "write error" in str(agg) + # Leaves are the underlying store errors, not the per-entry aggregate errors. 
+ assert len(agg.errors) == 2 + assert all(not isinstance(error, AggregateMemoryError) for error in agg.errors) + + +@pytest.mark.asyncio +async def test_add_tool_wait_for_writes_false_returns_accepted_count(): + store = _store("notes", writable=True) + mm = MemoryManager(stores=[store], add_tool_config=MemoryAddToolConfig(wait_for_writes=False)) + + result = await _tool_named(mm, "add_memory")(entries=["a", "b"]) + + assert result == {"accepted": 2} + await asyncio.sleep(0.05) # let fire-and-forget writes drain + + +@pytest.mark.asyncio +async def test_add_tool_wait_for_writes_false_returns_accepted_even_when_a_write_fails(): + failing = _store("failing", writable=True, add_error=RuntimeError("write error")) + mm = MemoryManager(stores=[failing], add_tool_config=MemoryAddToolConfig(wait_for_writes=False)) + + result = await _tool_named(mm, "add_memory")(entries=["a", "b"]) + + assert result == {"accepted": 2} + await asyncio.sleep(0.05) # let the (swallowed) failing writes drain + + +# --------------------------------------------------------------------------- # +# init_agent extraction wiring (Requirements 8, 9.1) +# --------------------------------------------------------------------------- # + + +def test_init_agent_does_not_throw_without_extraction(): + mm = MemoryManager(stores=[_store("test")]) + mm.init_agent(_FakeAgent()) # should not raise + + +def test_init_agent_registers_no_hooks_when_no_store_has_extraction(): + store = _store("s", writable=True, sinks={"add"}) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + + mm.init_agent(agent) + + assert agent.hooks == [] # the only store was built without an extraction config + + +@pytest.mark.asyncio +async def test_init_agent_no_extractor_passthrough_hands_raw_batch_to_add_messages(): + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("I prefer dark mode"), 
_assistant_msg("Noted")) + await _fire_invocation(agent, mm) + + store.add_messages.assert_called_once() + batch = store.add_messages.call_args.args[0] + assert len(batch) == 2 + assert batch[0]["role"] == "user" + assert batch[1]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_init_agent_extractor_route_writes_each_entry_via_add(): + extractor = _make_extractor([ExtractionResult(content="fact one"), ExtractionResult(content="fact two")]) + # The store exposes both sinks so the ``add_messages.assert_not_called()`` + # check below is a real assertion against a mock (not an AttributeError on a + # missing attribute); the extractor route must still write via ``add`` only. + store = _store( + "s", + writable=True, + sinks={"add", "add_messages"}, + extraction=ExtractionConfig(trigger=InvocationTrigger(), extractor=extractor), + ) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("something happened")) + await _fire_invocation(agent, mm) + + extractor.extract.assert_called_once() + assert store.add.call_count == 2 # one add() per extracted entry + store.add_messages.assert_not_called() + + +@pytest.mark.asyncio +async def test_init_agent_passes_agent_model_as_default_model_to_extractor(): + extractor = _make_extractor([]) + store = _store( + "s", + writable=True, + sinks={"add"}, + extraction=ExtractionConfig(trigger=InvocationTrigger(), extractor=extractor), + ) + mm = MemoryManager(stores=[store]) + fake_model = SimpleNamespace(id="model") + agent = _FakeAgent(model=fake_model) + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("hi")) + await _fire_invocation(agent, mm) + + extractor.extract.assert_called_once() + context = _extractor_context(extractor.extract.call_args) + assert context is not None + assert context.default_model is fake_model + + +@pytest.mark.asyncio +async def test_init_agent_interval_trigger_fires_every_n_invocations(): + store = _store( + "s", writable=True, 
sinks={"add_messages"}, extraction=ExtractionConfig(trigger=IntervalTrigger(turns=2)) + ) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + # Fire the raw hook (not the flushing helper) so we observe interval gating. + await _add_messages(agent, _user_msg("a")) + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) # count 1, no fire + store.add_messages.assert_not_called() + + await _add_messages(agent, _user_msg("b")) + await _fire_invocation(agent, mm) # count 2, fire (+ flush drains it) + store.add_messages.assert_called_once() + assert len(store.add_messages.call_args.args[0]) == 2 # both buffered messages arrive in one batch + + +@pytest.mark.asyncio +async def test_init_agent_accepts_a_single_trigger_not_wrapped_in_a_list(): + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("hi")) + await _fire_invocation(agent, mm) + + store.add_messages.assert_called_once() + + +@pytest.mark.asyncio +async def test_init_agent_composes_multiple_triggers_fires_on_any(): + store = _store( + "s", + writable=True, + sinks={"add_messages"}, + extraction=ExtractionConfig(trigger=[IntervalTrigger(turns=2), InvocationTrigger()]), + ) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("a")) + await _fire_invocation(agent, mm) + + # The invocation trigger fired on turn 1 even though the interval would not have. 
+ store.add_messages.assert_called_once() + + +@pytest.mark.asyncio +async def test_flush_is_a_no_op_when_extraction_is_not_configured(): + store = _store("s", writable=True, sinks={"add"}) + mm = MemoryManager(stores=[store]) + mm.init_agent(_FakeAgent()) + + assert await mm.flush() is None # flush must resolve (to None) even with nothing to extract + + +@pytest.mark.asyncio +async def test_flush_force_extracts_a_buffered_tail_whose_trigger_never_fired(): + store = _store( + "s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=IntervalTrigger(turns=5)) + ) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("a")) + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) # count 1, no fire + await _add_messages(agent, _user_msg("b")) + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) # count 2, no fire + store.add_messages.assert_not_called() + + await mm.flush() + + store.add_messages.assert_called_once() + assert len(store.add_messages.call_args.args[0]) == 2 + + +@pytest.mark.asyncio +async def test_flush_does_not_re_extract_messages_already_processed(): + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("a")) + await _fire_invocation(agent, mm) # already extracted + flushed + store.add_messages.assert_called_once() + + await mm.flush() # nothing fresh -> no-op + store.add_messages.assert_called_once() + + +@pytest.mark.asyncio +async def test_init_agent_background_save_does_not_block_hook_and_flush_awaits_it(): + release = asyncio.Event() + completed = {"v": False} + + async def add_messages_impl(messages: Any, context: Any = None) -> None: + await release.wait() + completed["v"] = True + + store = _store("s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger())) + 
store.add_messages.side_effect = add_messages_impl + mm = MemoryManager(stores=[store]) + agent = _FakeAgent() + mm.init_agent(agent) + + await _add_messages(agent, _user_msg("hello")) + # Fire the hook directly: it must return while the store write hangs. + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) + + flushed = asyncio.ensure_future(mm.flush()) + await asyncio.sleep(0) # yield to the event loop once so the scheduled flush task can start + assert completed["v"] is False + + release.set() + await flushed + assert completed["v"] is True + + +@pytest.mark.asyncio +async def test_flush_on_invocation_end_registers_hook_that_awaits_flush(): + extraction_store = _store( + "s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger()) + ) + memory_manager = MemoryManager(stores=[extraction_store], flush_on_invocation_end=True) + agent = _FakeAgent() + memory_manager.init_agent(agent) + + flush_hooks = [ + callback + for callback, event_type, _order in agent.hooks + if event_type is AfterInvocationEvent and callback == memory_manager._flush_after_invocation + ] + assert len(flush_hooks) == 1 + + await _add_messages(agent, _user_msg("I prefer dark mode"), _assistant_msg("Noted")) + # Fire the invocation event only; the registered flush hook awaits flush, so + # the store write persists WITHOUT an explicit ``memory_manager.flush()`` call. + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) + + extraction_store.add_messages.assert_called_once() + + +@pytest.mark.asyncio +async def test_flush_on_invocation_end_disabled_by_default_registers_no_flush_hook(): + extraction_store = _store( + "s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger()) + ) + memory_manager = MemoryManager(stores=[extraction_store]) + agent = _FakeAgent() + memory_manager.init_agent(agent) + + # Only the recorder (MessageAddedEvent) + trigger (AfterInvocationEvent) hooks + # are registered; the flush method is not among them. 
+ registered_callbacks = [callback for callback, _event_type, _order in agent.hooks] + assert memory_manager._flush_after_invocation not in registered_callbacks + + await _add_messages(agent, _user_msg("I prefer dark mode"), _assistant_msg("Noted")) + # The event alone schedules a background save but does not await it (no flush + # hook), so the store has not been written synchronously by the event. + await _invoke_all(agent, AfterInvocationEvent(agent=agent)) + + extraction_store.add_messages.assert_not_called() + + +def test_init_agent_warns_when_extraction_configured_without_flush_on_invocation_end(caplog): + extraction_store = _store( + "s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger()) + ) + memory_manager = MemoryManager(stores=[extraction_store]) + + with caplog.at_level(logging.WARNING): + memory_manager.init_agent(_FakeAgent()) + + assert "flush_on_invocation_end" in caplog.text # the captured log must name the option + + +def test_init_agent_does_not_warn_when_flush_on_invocation_end_enabled(caplog): + extraction_store = _store( + "s", writable=True, sinks={"add_messages"}, extraction=ExtractionConfig(trigger=InvocationTrigger()) + ) + memory_manager = MemoryManager(stores=[extraction_store], flush_on_invocation_end=True) + + with caplog.at_level(logging.WARNING): + memory_manager.init_agent(_FakeAgent()) + + assert "flush_on_invocation_end" not in caplog.text diff --git a/strands-py/tests/strands/memory/test_model_extractor.py b/strands-py/tests/strands/memory/test_model_extractor.py new file mode 100644 index 0000000000..86148f9bc0 --- /dev/null +++ b/strands-py/tests/strands/memory/test_model_extractor.py @@ -0,0 +1,137 @@ +"""Tests for ``ModelExtractor`` parsing and behavior. + +Ported from ``strands-ts/src/memory/extraction/__tests__/model-extractor.test.ts``. +Each TS ``it(...)`` maps to a ``test_...`` case here. 
+ +The extractor is driven through its public ``extract`` API: a fake ``Model`` +(``MockedModelProvider``) yields stream events that +``strands.event_loop.streaming.stream_messages`` aggregates into a chosen +assistant message, so we exercise the real aggregate-and-parse path. + +``ModelExtractor`` is imported at module level, so an ``ImportError`` during +collection means the extractor module is missing, not that a parsing case broke. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from strands.memory.extraction.model_extractor import ModelExtractor +from strands.memory.extraction.types import ExtractionResult, ExtractorContext +from strands.types.content import Message +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +def _user_turn(text: str) -> Message: + """Build a single user turn, mirroring the TS ``userTurn`` helper.""" + return {"role": "user", "content": [{"text": text}]} + + +def _assistant_text(text: str) -> Message: + """Build an assistant response carrying ``text`` for the fake model to emit.""" + return {"role": "assistant", "content": [{"text": text}]} + + +class _RecordingModel(MockedModelProvider): + """A fake model that counts how many times ``stream`` is invoked. + + Used to assert the extractor short-circuits an empty batch without ever + calling the model. 
+ """ + + def __init__(self, agent_responses: Any) -> None: + super().__init__(agent_responses) + self.stream_calls = 0 + + async def stream(self, *args: Any, **kwargs: Any): # type: ignore[override] + self.stream_calls += 1 # async generator body: counted when iteration starts, not at call time + async for event in super().stream(*args, **kwargs): + yield event + + +@pytest.mark.asyncio +async def test_parses_a_json_array_of_entries_from_the_model_response(): + model = MockedModelProvider( + [_assistant_text('[{"content": "User prefers dark mode"}, {"content": "Lives in Berlin"}]')] + ) + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([_user_turn("I like dark mode and live in Berlin")]) + + assert entries == [ + ExtractionResult(content="User prefers dark mode"), + ExtractionResult(content="Lives in Berlin"), + ] + + +@pytest.mark.asyncio +async def test_extracts_a_json_array_even_when_wrapped_in_prose_or_a_code_fence(): + model = MockedModelProvider( + [_assistant_text('Here are the facts:\n```json\n[{"content": "fact"}]\n```\nHope that helps.')] + ) + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([_user_turn("something")]) + + assert entries == [ExtractionResult(content="fact")] + + +@pytest.mark.asyncio +async def test_preserves_entry_metadata(): + model = MockedModelProvider([_assistant_text('[{"content": "fact", "metadata": {"topic": "pref"}}]')]) + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([_user_turn("x")]) + + assert entries == [ExtractionResult(content="fact", metadata={"topic": "pref"})] + + +@pytest.mark.asyncio +async def test_returns_no_entries_on_malformed_json_without_throwing(): + model = MockedModelProvider([_assistant_text("not json at all")]) + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([_user_turn("x")]) + + assert entries == [] + + +@pytest.mark.asyncio +async def test_drops_entries_without_a_string_content_and_empty_strings(): + model = 
MockedModelProvider([_assistant_text('[{"content": "keep"}, {"content": ""}, {"foo": "bar"}, "loose"]')]) + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([_user_turn("x")]) + + assert entries == [ExtractionResult(content="keep")] + + +@pytest.mark.asyncio +async def test_returns_empty_for_an_empty_message_batch_without_calling_the_model(): + model = _RecordingModel([]) # no scripted responses: the model must never be streamed + extractor = ModelExtractor(model=model) + + entries = await extractor.extract([]) + + assert entries == [] + assert model.stream_calls == 0 + + +@pytest.mark.asyncio +async def test_falls_back_to_the_default_model_from_context_when_none_configured(): + model = MockedModelProvider([_assistant_text('[{"content": "fact"}]')]) + extractor = ModelExtractor() + + entries = await extractor.extract([_user_turn("x")], ExtractorContext(default_model=model)) + + assert entries == [ExtractionResult(content="fact")] + + +@pytest.mark.asyncio +async def test_raises_when_no_model_is_configured_and_no_default_is_provided(): + extractor = ModelExtractor() + + with pytest.raises(Exception, match="no model configured"): + await extractor.extract([_user_turn("x")])