diff --git a/config.yaml.full b/config.yaml.full index 4ebfccc4..e3b0ad4b 100644 --- a/config.yaml.full +++ b/config.yaml.full @@ -119,6 +119,10 @@ database: endpoint: tos-cn-beijing.volces.com # default Volcengine TOS endpoint region: cn-beijing # default Volcengine TOS region bucket: + mem0: + base_url: + api_key: + # [optional] for prompt optimization in cli/app diff --git a/pyproject.toml b/pyproject.toml index ecc76a89..fc7f0163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ database = [ "pymysql>=1.1.1", # For MySQL database "volcengine>=1.0.193", # For Viking DB "tos>=2.8.4", # For TOS storage and Viking DB + "mem0ai==0.1.118", # For mem0 ] eval = [ "prometheus-client>=0.22.1", # For exporting data to Prometheus pushgateway diff --git a/veadk/configs/database_configs.py b/veadk/configs/database_configs.py index c724af04..62081299 100644 --- a/veadk/configs/database_configs.py +++ b/veadk/configs/database_configs.py @@ -83,6 +83,15 @@ class RedisConfig(BaseSettings): """STS token for Redis auth, not supported yet.""" +class Mem0Config(BaseSettings): + model_config = SettingsConfigDict(env_prefix="DATABASE_MEM0_") + + api_key: str = "" + """Mem0 API key""" + + base_url: str = "" # "https://api.mem0.ai/v1" + + class VikingKnowledgebaseConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="DATABASE_VIKING_") diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index 5b07c301..2040ccaf 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -62,6 +62,12 @@ def _get_backend_cls(backend: str) -> type[BaseLongTermMemoryBackend]: ) return RedisLTMBackend + case "mem0": + from veadk.memory.long_term_memory_backends.mem0_backend import ( + Mem0LTMBackend, + ) + + return Mem0LTMBackend raise ValueError(f"Unsupported long term memory backend: {backend}") @@ -72,7 +78,7 @@ def build_long_term_memory_index(app_name: str, user_id: str): class LongTermMemory(BaseMemoryService, BaseModel): backend: Union[ - Literal["local", "opensearch", "redis", "viking", "viking_mem"], + Literal["local", "opensearch", "redis", "viking", "viking_mem", "mem0"], BaseLongTermMemoryBackend, ] = "opensearch" """Long term memory backend type""" @@ -153,12 +159,6 @@ async def add_session_to_memory( app_name = session.app_name user_id = session.user_id - if self._index != build_long_term_memory_index(app_name, user_id): - logger.warning( - f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}" - ) - return - if not self._backend and isinstance(self.backend, str): self._index = build_long_term_memory_index(app_name, user_id) self._backend = _get_backend_cls(self.backend)( @@ -168,6 +168,13 @@ async def add_session_to_memory( f"Initialize long term memory backend now, index is {self._index}" ) + if not self._index and self._index != build_long_term_memory_index( + app_name, user_id + ): + logger.warning( + f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}" + ) + return event_strings = self._filter_and_convert_events(session.events) logger.info( diff --git a/veadk/memory/long_term_memory_backends/mem0_backend.py b/veadk/memory/long_term_memory_backends/mem0_backend.py new file mode 100644 index 00000000..bcaae40e --- /dev/null +++ b/veadk/memory/long_term_memory_backends/mem0_backend.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing_extensions import override +from pydantic import Field + +from veadk.configs.database_configs import Mem0Config + + +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + +try: + from mem0 import MemoryClient + +except ImportError: + logger.error( + "Failed to import mem0 or dotenv. Please install them with 'pip install mem0 '" + ) + raise ImportError("Required packages not installed: mem0") + + +class Mem0LTMBackend(BaseLongTermMemoryBackend): + """Mem0 long term memory backend implementation""" + + mem0_config: Mem0Config = Field(default_factory=Mem0Config) + + def model_post_init(self, __context: Any) -> None: + """Initialize Mem0 client""" + + try: + self._mem0_client = MemoryClient( + # base_url=self.mem0_config.base_url, # mem0 endpoint + api_key=self.mem0_config.api_key, # mem0 API key + ) + logger.info(f"Initialized Mem0 client for index: {self.index}") + except Exception as e: + logger.error(f"Failed to initialize Mem0 client: {str(e)}") + raise + + def precheck_index_naming(self): + """Check if the index name is valid + For Mem0, there are no specific naming constraints + """ + pass + + @override + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + """Save memory to Mem0 + + Args: + event_strings: List of event strings to save + **kwargs: Additional parameters, including 'user_id' for Mem0 + + Returns: + bool: True if saved successfully, False otherwise + """ + user_id = kwargs.get("user_id", "default_user") + + try: + logger.info( + f"Saving {len(event_strings)} events to Mem0 for user: {user_id}" + ) + + for event_string in event_strings: + # Save event string to Mem0 + result = self._mem0_client.add( + [{"role": "user", "content": event_string}], + user_id=user_id, + output_format="v1.1", + ) + logger.debug(f"Saved memory result: {result}") + + logger.info(f"Successfully saved {len(event_strings)} events to Mem0") + return True + except Exception as e: + logger.error(f"Failed to save memory to Mem0: {str(e)}") + return False + + @override + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + """Search memory from Mem0 + + Args: + query: Search query + top_k: Number of results to return + **kwargs: Additional parameters, including 'user_id' for Mem0 + + Returns: + list[str]: List of memory strings + """ + user_id = kwargs.get("user_id", "default_user") + + try: + logger.info( + f"Searching Mem0 for query: {query}, user: {user_id}, top_k: {top_k}" + ) + + memories = self._mem0_client.search( + query, user_id=user_id, output_format="v1.1", top_k=top_k + ) + + memory_list = [] + if memories.get("results", []): + for mem in memories["results"]: + if "memory" in mem: + memory_list.append(mem["memory"]) + + logger.info(f"Found {len(memory_list)} memories matching query: {query}") + return memory_list + except Exception as e: + logger.error(f"Failed to search memory from Mem0: {str(e)}") + return []