diff --git a/src/langchain_google_cloud_sql_pg/async_vectorstore.py b/src/langchain_google_cloud_sql_pg/async_vectorstore.py
index 0af1f9f4..2e04bb93 100644
--- a/src/langchain_google_cloud_sql_pg/async_vectorstore.py
+++ b/src/langchain_google_cloud_sql_pg/async_vectorstore.py
@@ -291,6 +291,60 @@ async def __aadd_embeddings(
 
         return ids
 
+    async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
+        """Get documents by ids.
+
+        Args:
+            ids (Sequence[str]): Ids of the documents to fetch.
+
+        Returns:
+            list[Document]: One document per matching row. Ids that match no
+            row are skipped, and the query has no ORDER BY, so the result
+            order is not guaranteed to follow the order of ``ids``.
+        """
+        if not ids:
+            # "IN ()" is not valid PostgreSQL, and there is nothing to fetch.
+            return []
+
+        columns = self.metadata_columns + [
+            self.id_column,
+            self.content_column,
+        ]
+        if self.metadata_json_column:
+            columns.append(self.metadata_json_column)
+
+        column_names = ", ".join(f'"{col}"' for col in columns)
+
+        # Bind every id as a parameter rather than quoting it into the SQL
+        # text, so an id containing a quote cannot change the statement.
+        placeholders = ", ".join(f":id_{i}" for i in range(len(ids)))
+        param_dict = {f"id_{i}": id_val for i, id_val in enumerate(ids)}
+
+        query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({placeholders});'
+
+        async with self.pool.connect() as conn:
+            result = await conn.execute(text(query), param_dict)
+            result_map = result.mappings()
+            results = result_map.fetchall()
+
+        documents = []
+        for row in results:
+            metadata = (
+                row[self.metadata_json_column]
+                if self.metadata_json_column and row[self.metadata_json_column]
+                else {}
+            )
+            for col in self.metadata_columns:
+                metadata[col] = row[col]
+            documents.append(
+                Document(
+                    page_content=row[self.content_column],
+                    metadata=metadata,
+                    id=row[self.id_column],
+                )
+            )
+
+        return documents
+
     async def aadd_texts(
         self,
         texts: Iterable[str],
@@ -772,6 +826,11 @@ async def is_valid_index(
 
         return bool(len(results) == 1)
 
+    def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."
+        )
+
     def similarity_search(
         self,
         query: str,
diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py
index de7275de..e59deedf 100644
--- a/src/langchain_google_cloud_sql_pg/vectorstore.py
+++ b/src/langchain_google_cloud_sql_pg/vectorstore.py
@@ -15,7 +15,7 @@
 # TODO: Remove below import when minimum supported Python version is 3.10
 from __future__ import annotations
 
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional, Sequence
 
 import numpy as np
 from langchain_core.documents import Document
@@ -813,3 +813,11 @@ def is_valid_index(
     ) -> bool:
         """Check if index exists in the table."""
         return self._engine._run_as_sync(self.__vs.is_valid_index(index_name))
+
+    async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
+        """Get documents by ids."""
+        return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids))
+
+    def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
+        """Get documents by ids."""
+        return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))
diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py
index aaaecea6..fae5e964 100644
--- a/tests/test_async_vectorstore_search.py
+++ b/tests/test_async_vectorstore_search.py
@@ -28,6 +28,7 @@
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
 VECTOR_SIZE = 768
+sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."
 
 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
 
@@ -269,3 +270,20 @@ async def test_max_marginal_relevance_search_vector_score(self, vs_custom):
             embedding, lambda_mult=0.75, fetch_k=10
         )
         assert results[0][0] == Document(page_content="bar", id=ids[1])
+
+    async def test_aget_by_ids(self, vs):
+        test_ids = [ids[0]]
+        results = await vs.aget_by_ids(ids=test_ids)
+
+        assert results[0] == Document(page_content="foo", id=ids[0])
+
+    async def test_aget_by_ids_custom_vs(self, vs_custom):
+        test_ids = [ids[0]]
+        results = await vs_custom.aget_by_ids(ids=test_ids)
+
+        assert results[0] == Document(page_content="foo", id=ids[0])
+
+    def test_get_by_ids(self, vs):
+        test_ids = [ids[0]]
+        with pytest.raises(Exception, match=sync_method_exception_str):
+            vs.get_by_ids(ids=test_ids)
diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py
index 3ea977ba..2141d951 100644
--- a/tests/test_vectorstore_search.py
+++ b/tests/test_vectorstore_search.py
@@ -228,6 +228,18 @@ async def test_amax_marginal_relevance_search_vector_score(self, vs):
         )
         assert results[0][0] == Document(page_content="bar", id=ids[1])
 
+    async def test_aget_by_ids(self, vs):
+        test_ids = [ids[0]]
+        results = await vs.aget_by_ids(ids=test_ids)
+
+        assert results[0] == Document(page_content="foo", id=ids[0])
+
+    async def test_aget_by_ids_custom_vs(self, vs_custom):
+        test_ids = [ids[0]]
+        results = await vs_custom.aget_by_ids(ids=test_ids)
+
+        assert results[0] == Document(page_content="foo", id=ids[0])
+
 
 class TestVectorStoreSearchSync:
     @pytest.fixture(scope="module")
@@ -331,3 +343,9 @@ def test_max_marginal_relevance_search_vector_score(self, vs_custom):
             embedding, lambda_mult=0.75, fetch_k=10
         )
         assert results[0][0] == Document(page_content="bar", id=ids[1])
+
+    def test_get_by_ids_custom_vs(self, vs_custom):
+        test_ids = [ids[0]]
+        results = vs_custom.get_by_ids(ids=test_ids)
+
+        assert results[0] == Document(page_content="foo", id=ids[0])