diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py index 129ba5bcb..8b6f36138 100644 --- a/src/google/adk/memory/in_memory_memory_service.py +++ b/src/google/adk/memory/in_memory_memory_service.py @@ -14,6 +14,7 @@ from __future__ import annotations import re +import threading from typing import TYPE_CHECKING from typing_extensions import override @@ -42,38 +43,43 @@ class InMemoryMemoryService(BaseMemoryService): Uses keyword matching instead of semantic search. - It is not suitable for multi-threaded production environments. Use it for - testing and development only. + This class is thread-safe, however, it should be used for testing and + development only. """ def __init__(self): + self._lock = threading.Lock() + self._session_events: dict[str, dict[str, list[Event]]] = {} - """Keys are app_name/user_id, session_id. Values are session event lists.""" + """Keys are "{app_name}/{user_id}". Values are dicts of session_id to + session event lists. + """ @override async def add_session_to_memory(self, session: Session): user_key = _user_key(session.app_name, session.user_id) - self._session_events[user_key] = self._session_events.get( - _user_key(session.app_name, session.user_id), {} - ) - self._session_events[user_key][session.id] = [ - event - for event in session.events - if event.content and event.content.parts - ] + + with self._lock: + self._session_events[user_key] = self._session_events.get(user_key, {}) + self._session_events[user_key][session.id] = [ + event + for event in session.events + if event.content and event.content.parts + ] @override async def search_memory( self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: user_key = _user_key(app_name, user_id) - if user_key not in self._session_events: - return SearchMemoryResponse() + + with self._lock: + session_event_lists = self._session_events.get(user_key, {}) words_in_query = set(query.lower().split()) response = SearchMemoryResponse() - for session_events in self._session_events[user_key].values(): + for session_events in session_event_lists.values(): for event in session_events: if not event.content or not event.content.parts: continue