Skip to content

Commit d0b95b9

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Make InMemoryMemoryService thread-safe
Even though InMemoryMemoryService is intended only for testing and local development, we eliminate a potential source of bugs during prototyping by providing a thread-safe InMemoryMemoryService. PiperOrigin-RevId: 780854478
1 parent 584c8c6 commit d0b95b9

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

src/google/adk/memory/in_memory_memory_service.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import re
17+
import threading
1718
from typing import TYPE_CHECKING
1819

1920
from typing_extensions import override
@@ -42,38 +43,43 @@ class InMemoryMemoryService(BaseMemoryService):
4243
4344
Uses keyword matching instead of semantic search.
4445
45-
It is not suitable for multi-threaded production environments. Use it for
46-
testing and development only.
46+
This class is thread-safe, however, it should be used for testing and
47+
development only.
4748
"""
4849

4950
def __init__(self):
51+
self._lock = threading.Lock()
52+
5053
self._session_events: dict[str, dict[str, list[Event]]] = {}
51-
"""Keys are app_name/user_id, session_id. Values are session event lists."""
54+
"""Keys are "{app_name}/{user_id}". Values are dicts of session_id to
55+
session event lists.
56+
"""
5257

5358
@override
5459
async def add_session_to_memory(self, session: Session):
5560
user_key = _user_key(session.app_name, session.user_id)
56-
self._session_events[user_key] = self._session_events.get(
57-
_user_key(session.app_name, session.user_id), {}
58-
)
59-
self._session_events[user_key][session.id] = [
60-
event
61-
for event in session.events
62-
if event.content and event.content.parts
63-
]
61+
62+
with self._lock:
63+
self._session_events[user_key] = self._session_events.get(user_key, {})
64+
self._session_events[user_key][session.id] = [
65+
event
66+
for event in session.events
67+
if event.content and event.content.parts
68+
]
6469

6570
@override
6671
async def search_memory(
6772
self, *, app_name: str, user_id: str, query: str
6873
) -> SearchMemoryResponse:
6974
user_key = _user_key(app_name, user_id)
70-
if user_key not in self._session_events:
71-
return SearchMemoryResponse()
75+
76+
with self._lock:
77+
session_event_lists = self._session_events.get(user_key, {})
7278

7379
words_in_query = set(query.lower().split())
7480
response = SearchMemoryResponse()
7581

76-
for session_events in self._session_events[user_key].values():
82+
for session_events in session_event_lists.values():
7783
for event in session_events:
7884
if not event.content or not event.content.parts:
7985
continue

0 commit comments

Comments
 (0)