|
14 | 14 | from __future__ import annotations
|
15 | 15 |
|
16 | 16 | import re
|
| 17 | +import threading |
17 | 18 | from typing import TYPE_CHECKING
|
18 | 19 |
|
19 | 20 | from typing_extensions import override
|
@@ -42,38 +43,43 @@ class InMemoryMemoryService(BaseMemoryService):
|
42 | 43 |
|
43 | 44 | Uses keyword matching instead of semantic search.
|
44 | 45 |
|
45 |
| - It is not suitable for multi-threaded production environments. Use it for |
46 |
| - testing and development only. |
| 46 | + This class is thread-safe, however, it should be used for testing and |
| 47 | + development only. |
47 | 48 | """
|
48 | 49 |
|
49 | 50 | def __init__(self):
|
| 51 | + self._lock = threading.Lock() |
| 52 | + |
50 | 53 | self._session_events: dict[str, dict[str, list[Event]]] = {}
|
51 |
| - """Keys are app_name/user_id, session_id. Values are session event lists.""" |
| 54 | + """Keys are "{app_name}/{user_id}". Values are dicts of session_id to |
| 55 | + session event lists. |
| 56 | + """ |
52 | 57 |
|
53 | 58 | @override
|
54 | 59 | async def add_session_to_memory(self, session: Session):
|
55 | 60 | user_key = _user_key(session.app_name, session.user_id)
|
56 |
| - self._session_events[user_key] = self._session_events.get( |
57 |
| - _user_key(session.app_name, session.user_id), {} |
58 |
| - ) |
59 |
| - self._session_events[user_key][session.id] = [ |
60 |
| - event |
61 |
| - for event in session.events |
62 |
| - if event.content and event.content.parts |
63 |
| - ] |
| 61 | + |
| 62 | + with self._lock: |
| 63 | + self._session_events[user_key] = self._session_events.get(user_key, {}) |
| 64 | + self._session_events[user_key][session.id] = [ |
| 65 | + event |
| 66 | + for event in session.events |
| 67 | + if event.content and event.content.parts |
| 68 | + ] |
64 | 69 |
|
65 | 70 | @override
|
66 | 71 | async def search_memory(
|
67 | 72 | self, *, app_name: str, user_id: str, query: str
|
68 | 73 | ) -> SearchMemoryResponse:
|
69 | 74 | user_key = _user_key(app_name, user_id)
|
70 |
| - if user_key not in self._session_events: |
71 |
| - return SearchMemoryResponse() |
| 75 | + |
| 76 | + with self._lock: |
| 77 | + session_event_lists = self._session_events.get(user_key, {}) |
72 | 78 |
|
73 | 79 | words_in_query = set(query.lower().split())
|
74 | 80 | response = SearchMemoryResponse()
|
75 | 81 |
|
76 |
| - for session_events in self._session_events[user_key].values(): |
| 82 | + for session_events in session_event_lists.values(): |
77 | 83 | for event in session_events:
|
78 | 84 | if not event.content or not event.content.parts:
|
79 | 85 | continue
|
|
0 commit comments