diff --git a/src/langchain_google_cloud_sql_pg/async_vectorstore.py b/src/langchain_google_cloud_sql_pg/async_vectorstore.py index 0aa4bc07..f900dfe0 100644 --- a/src/langchain_google_cloud_sql_pg/async_vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/async_vectorstore.py @@ -15,6 +15,7 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations +import copy import json import uuid from typing import Any, Callable, Iterable, Optional, Sequence @@ -37,6 +38,36 @@ QueryOptions, ) +COMPARISONS_TO_NATIVE = { + "$eq": "=", + "$ne": "!=", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +SPECIAL_CASED_OPERATORS = { + "$in", + "$nin", + "$between", + "$exists", +} + +TEXT_OPERATORS = { + "$like", + "$ilike", +} + +LOGICAL_OPERATORS = {"$and", "$or", "$not"} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(TEXT_OPERATORS) + .union(LOGICAL_OPERATORS) + .union(SPECIAL_CASED_OPERATORS) +) + class AsyncPostgresVectorStore(VectorStore): """Google Cloud SQL for PostgreSQL Vector Store class""" @@ -253,7 +284,7 @@ async def __aadd_embeddings( values_stmt = "VALUES (:id, :content, :embedding" # Add metadata - extra = metadata + extra = copy.deepcopy(metadata) for metadata_column in self.metadata_columns: if metadata_column in metadata: values_stmt += f", :{metadata_column}" @@ -537,7 +568,7 @@ async def __query_collection( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> Sequence[RowMapping]: """Perform similarity search query on the vector store table.""" @@ -555,6 +586,8 @@ async def __query_collection( column_names = ", ".join(f'"{col}"' for col in columns) + if filter and isinstance(filter, dict): + filter = self._create_filter_clause(filter) filter = f"WHERE {filter}" if filter else "" stmt = f"SELECT {column_names}, {search_function}({self.embedding_column}, '{embedding}') as 
distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};" if self.index_query_options: @@ -576,7 +609,7 @@ async def asimilarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -601,7 +634,7 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -615,7 +648,7 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -629,7 +662,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" @@ -665,7 +698,7 @@ async def amax_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -686,7 +719,7 @@ async def amax_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return 
def _handle_field_filter(
    self,
    field: str,
    value: Any,
) -> str:
    """Translate a single ``field: condition`` pair into a SQL predicate.

    Args:
        field: Name of the metadata column to filter on. Must be a valid
            identifier and must not start with ``$``.
        value: Either a plain value (treated as an equality check) or a
            dictionary with exactly one key, where the key is an operator
            (for example ``$gt`` or ``$in``) and the value is its operand.

    Returns:
        A parenthesised SQL predicate, e.g. ``(price >= 10)``.

    Raises:
        ValueError: If the field name, the operator or the operand is
            invalid.
        NotImplementedError: If the operator is supported in general but
            cannot be applied to a field (the logical operators), or if an
            ``$in``/``$nin`` element has an unsupported type.
    """

    def to_literal(operand: Any) -> str:
        # Strings become single-quoted SQL literals with embedded quotes
        # doubled, so a value such as "O'Brien" can neither break the
        # statement nor inject SQL. Everything else keeps its str() form.
        # NOTE(review): assumes standard_conforming_strings is on (the
        # PostgreSQL default), so backslashes need no escaping -- confirm.
        if isinstance(operand, str):
            return "'" + operand.replace("'", "''") + "'"
        return str(operand)

    if not isinstance(field, str):
        raise ValueError(
            f"field should be a string but got: {type(field)} with value: {field}"
        )

    if field.startswith("$"):
        raise ValueError(
            f"Invalid filter condition. Expected a field but got an operator: "
            f"{field}"
        )

    # The field name is interpolated unquoted, so only identifier-like names
    # are accepted (no spaces, quotes or punctuation).
    if not field.isidentifier():
        raise ValueError(
            f"Invalid field name: {field}. Expected a valid identifier."
        )

    if isinstance(value, dict):
        # This is a filter specification: {operator: operand}
        if len(value) != 1:
            raise ValueError(
                "Invalid filter condition. Expected a value which "
                "is a dictionary with a single key that corresponds to an operator "
                f"but got a dictionary with {len(value)} keys. The first few "
                f"keys are: {list(value.keys())[:3]}"
            )
        operator, filter_value = next(iter(value.items()))
        if operator not in SUPPORTED_OPERATORS:
            raise ValueError(
                f"Invalid operator: {operator}. "
                f"Expected one of {SUPPORTED_OPERATORS}"
            )
    else:
        # A bare value is shorthand for an equality check.
        operator, filter_value = "$eq", value

    if operator in COMPARISONS_TO_NATIVE:
        if filter_value is None:
            # "= NULL" never matches in SQL; use the IS [NOT] NULL forms.
            if operator == "$eq":
                return f"({field} IS NULL)"
            if operator == "$ne":
                return f"({field} IS NOT NULL)"
            raise ValueError(
                f"Operator {operator} cannot be used with a None value"
            )
        # native comes from a fixed mapping and is trusted input
        native = COMPARISONS_TO_NATIVE[operator]
        return f"({field} {native} {to_literal(filter_value)})"

    if operator == "$between":
        try:
            low, high = filter_value
        except (TypeError, ValueError) as exc:
            raise ValueError(
                "Expected a pair of (low, high) values for $between "
                f"operator, but got: {filter_value}"
            ) from exc
        return f"({field} BETWEEN {to_literal(low)} AND {to_literal(high)})"

    if operator in {"$in", "$nin"}:
        try:
            candidates = list(filter_value)
        except TypeError as exc:
            raise ValueError(
                f"Expected a list of values for {operator} operator, "
                f"but got: {filter_value}"
            ) from exc
        for candidate in candidates:
            # bool is checked explicitly because it is a subclass of int
            if isinstance(candidate, bool) or not isinstance(
                candidate, (str, int, float)
            ):
                raise NotImplementedError(
                    f"Unsupported type: {type(candidate)} for value: {candidate}"
                )
        if not candidates:
            # "IN ()" is a syntax error in PostgreSQL; fail before querying.
            raise ValueError(
                f"Expected at least one value for {operator} operator"
            )
        # Render the list by hand: str(tuple(...)) yields "('a',)" for one
        # element and double-quotes strings that contain a single quote.
        rendered = ", ".join(to_literal(candidate) for candidate in candidates)
        keyword = "IN" if operator == "$in" else "NOT IN"
        return f"({field} {keyword} ({rendered}))"

    pattern_keywords = {"$like": "LIKE", "$ilike": "ILIKE"}
    if operator in pattern_keywords:
        # The pattern is always compared as text.
        pattern = to_literal(str(filter_value))
        return f"({field} {pattern_keywords[operator]} {pattern})"

    if operator == "$exists":
        if not isinstance(filter_value, bool):
            raise ValueError(
                "Expected a boolean value for $exists "
                f"operator, but got: {filter_value}"
            )
        return f"({field} IS NOT NULL)" if filter_value else f"({field} IS NULL)"

    # Remaining supported operators ($and, $or, $not) combine whole filters
    # and are meaningless as a per-field operator.
    raise NotImplementedError(
        f"Operator {operator} cannot be applied to a single field"
    )


def _create_filter_clause(self, filters: Any) -> str:
    """Convert a LangChain-style filter dictionary into a SQL WHERE body.

    Args:
        filters: Dictionary of filters to apply to the query. Either a
            single logical operator (``$and``, ``$or``, ``$not``) mapping to
            nested filters, or one or more ``field: condition`` pairs, which
            are combined with AND.

    Returns:
        The SQL condition as a string (without the ``WHERE`` keyword), or
        an empty string when ``filters`` is an empty dictionary.

    Raises:
        ValueError: If ``filters`` is not a dictionary or is malformed.
        NotImplementedError: Propagated from ``_handle_field_filter``.
    """
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid type: Expected a dictionary but got type: {type(filters)}"
        )
    if not filters:
        return ""

    for key in filters:
        # Reject non-string keys here so callers get a ValueError rather
        # than an AttributeError from key.startswith below.
        if not isinstance(key, str):
            raise ValueError(
                f"Invalid filter condition. Expected a string key but got: {key!r}"
            )

    if len(filters) > 1:
        # Several keys: all of them must be fields, combined with AND.
        for key in filters:
            if key.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {key}"
                )
        predicates = [
            self._handle_field_filter(name, condition)
            for name, condition in filters.items()
        ]
        return f"({' AND '.join(predicates)})"

    key, value = next(iter(filters.items()))
    if not key.startswith("$"):
        # A single field condition
        return self._handle_field_filter(key, value)

    def sub_clauses(operands: list) -> list:
        # Build the nested clauses, refusing inputs that would otherwise
        # yield malformed SQL such as "()" or "NOT ".
        if not operands:
            raise ValueError(
                "Invalid filter condition. Expected a non-empty list of "
                f"filters for {key}"
            )
        clauses = [self._create_filter_clause(operand) for operand in operands]
        if not all(clauses):
            raise ValueError(
                f"Invalid filter condition. Got an empty filter inside {key}"
            )
        return clauses

    logical = key.lower()
    if logical in ("$and", "$or"):
        if not isinstance(value, list):
            raise ValueError(
                f"Expected a list, but got {type(value)} for value: {value}"
            )
        clauses = sub_clauses(value)
        if len(clauses) == 1:
            return clauses[0]
        joiner = f" {logical[1:].upper()} "
        return f"({joiner.join(clauses)})"

    if logical == "$not":
        if isinstance(value, dict):
            operands = [value]
        elif isinstance(value, list):
            operands = value
        else:
            raise ValueError(
                f"Invalid filter condition. Expected a dictionary "
                f"or a list but got: {type(value)}"
            )
        # A list under $not means "none of these": NOT a AND NOT b ...
        negated = [f"NOT {clause}" for clause in sub_clauses(operands)]
        return f"({' AND '.join(negated)})"

    # The only operators allowed at the top level are $and, $or and $not
    raise ValueError(
        f"Invalid filter condition. Expected $and, $or or $not "
        f"but got: {key}"
    )
None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -967,7 +1188,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py index e59deedf..f5333fd6 100644 --- a/src/langchain_google_cloud_sql_pg/vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/vectorstore.py @@ -551,7 +551,7 @@ async def asimilarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -563,7 +563,7 @@ def similarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -586,7 +586,7 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -598,7 +598,7 @@ def similarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -610,7 +610,7 @@ async def 
asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -622,7 +622,7 @@ def similarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -634,7 +634,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" @@ -648,7 +648,7 @@ def similarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on vector.""" @@ -664,7 +664,7 @@ async def amax_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -680,7 +680,7 @@ def max_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -696,7 +696,7 @@ async def amax_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, 
lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -712,7 +712,7 @@ def max_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -728,7 +728,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" @@ -744,7 +744,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" diff --git a/tests/metadata_filtering_data.py b/tests/metadata_filtering_data.py new file mode 100644 index 00000000..0b5c2024 --- /dev/null +++ b/tests/metadata_filtering_data.py @@ -0,0 +1,260 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +METADATAS = [ + { + "name": "Wireless Headphones", + "code": "WH001", + "price": 149.99, + "is_available": True, + "release_date": "2023-10-26", + "tags": ["audio", "wireless", "electronics"], + "dimensions": [18.5, 7.2, 21.0], + "inventory_location": [101, 102], + "available_quantity": 50, + }, + { + "name": "Ergonomic Office Chair", + "code": "EC002", + "price": 299.00, + "is_available": True, + "release_date": "2023-08-15", + "tags": ["furniture", "office", "ergonomic"], + "dimensions": [65.0, 60.0, 110.0], + "inventory_location": [201], + "available_quantity": 10, + }, + { + "name": "Stainless Steel Water Bottle", + "code": "WB003", + "price": 25.50, + "is_available": False, + "release_date": "2024-01-05", + "tags": ["hydration", "eco-friendly", "kitchen"], + "dimensions": [7.5, 7.5, 25.0], + "available_quantity": 0, + }, + { + "name": "Smart Fitness Tracker", + "code": "FT004", + "price": 79.95, + "is_available": True, + "release_date": "2023-11-12", + "tags": ["fitness", "wearable", "technology"], + "dimensions": [2.0, 1.0, 25.0], + "inventory_location": [401], + "available_quantity": 100, + }, +] + +FILTERING_TEST_CASES = [ + # These tests only involve equality checks + ( + {"code": "FT004"}, + ["FT004"], + ), + # String field + ( + # check name + {"name": "Smart Fitness Tracker"}, + ["FT004"], + ), + # Boolean fields + ( + {"is_available": True}, + ["WH001", "FT004", "EC002"], + ), + # And semantics for top level filtering + ( + {"code": "WH001", "is_available": True}, + ["WH001"], + ), + # These involve equality checks and other operators + # 
like $ne, $gt, $gte, $lt, $lte + ( + {"available_quantity": {"$eq": 10}}, + ["EC002"], + ), + ( + {"available_quantity": {"$ne": 0}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"available_quantity": {"$gt": 60}}, + ["FT004"], + ), + ( + {"available_quantity": {"$gte": 50}}, + ["WH001", "FT004"], + ), + ( + {"available_quantity": {"$lt": 5}}, + ["WB003"], + ), + ( + {"available_quantity": {"$lte": 10}}, + ["WB003", "EC002"], + ), + # Repeat all the same tests with name (string column) + ( + {"code": {"$eq": "WH001"}}, + ["WH001"], + ), + ( + {"code": {"$ne": "WB003"}}, + ["WH001", "FT004", "EC002"], + ), + # And also gt, gte, lt, lte relying on lexicographical ordering + ( + {"name": {"$gt": "Wireless Headphones"}}, + [], + ), + ( + {"name": {"$gte": "Wireless Headphones"}}, + ["WH001"], + ), + ( + {"name": {"$lt": "Smart Fitness Tracker"}}, + ["EC002"], + ), + ( + {"name": {"$lte": "Smart Fitness Tracker"}}, + ["FT004", "EC002"], + ), + ( + {"is_available": {"$eq": True}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"is_available": {"$ne": True}}, + ["WB003"], + ), + # Test float column. 
+ ( + {"price": {"$gt": 200.0}}, + ["EC002"], + ), + ( + {"price": {"$gte": 149.99}}, + ["WH001", "EC002"], + ), + ( + {"price": {"$lt": 50.0}}, + ["WB003"], + ), + ( + {"price": {"$lte": 79.95}}, + ["FT004", "WB003"], + ), + # These involve usage of AND, OR and NOT operators + ( + {"$or": [{"code": "WH001"}, {"code": "EC002"}]}, + ["WH001", "EC002"], + ), + ( + {"$or": [{"code": "WH001"}, {"available_quantity": 10}]}, + ["WH001", "EC002"], + ), + ( + {"$and": [{"code": "WH001"}, {"code": "EC002"}]}, + [], + ), + # Test for $not operator + ( + {"$not": {"code": "WB003"}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"$not": [{"code": "WB003"}]}, + ["WH001", "FT004", "EC002"], + ), + ( + {"$not": {"available_quantity": 0}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"$not": [{"available_quantity": 0}]}, + ["WH001", "FT004", "EC002"], + ), + ( + {"$not": {"is_available": True}}, + ["WB003"], + ), + ( + {"$not": [{"is_available": True}]}, + ["WB003"], + ), + ( + {"$not": {"price": {"$gt": 150.0}}}, + ["WH001", "FT004", "WB003"], + ), + ( + {"$not": [{"price": {"$gt": 150.0}}]}, + ["WH001", "FT004", "WB003"], + ), + # These involve special operators like $in, $nin, $between + # Test between + ( + {"available_quantity": {"$between": (40, 60)}}, + ["WH001"], + ), + # Test in + ( + {"name": {"$in": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}}, + ["FT004", "WB003"], + ), + # With numeric fields + ( + {"available_quantity": {"$in": [0, 10]}}, + ["WB003", "EC002"], + ), + # Test nin + ( + {"name": {"$nin": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}}, + ["WH001", "EC002"], + ), + ## with numeric fields + ( + {"available_quantity": {"$nin": [50, 0, 10]}}, + ["FT004"], + ), + # These involve special operators like $like, $ilike that + # may be specified to certain databases. 
+ ( + {"name": {"$like": "Wireless%"}}, + ["WH001"], + ), + ( + {"name": {"$like": "%less%"}}, # adam and jane + ["WH001", "WB003"], + ), + # These involve the special operator $exists + ( + {"tags": {"$exists": False}}, + [], + ), + ( + {"inventory_location": {"$exists": False}}, + ["WB003"], + ), +] + +NEGATIVE_TEST_CASES = [ + {"$nor": [{"code": "WH001"}, {"code": "EC002"}]}, + {"$and": {"is_available": True}}, + {"is_available": {"$and": True}}, + {"is_available": {"name": "{Wireless Headphones", "code": "EC002"}}, + {"my column": {"$and": True}}, + {"is_available": {"code": "WH001", "code": "EC002"}}, +] diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index fae5e964..418dbbad 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -19,6 +19,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS from sqlalchemy import text from langchain_google_cloud_sql_pg import Column, PostgresEngine @@ -27,6 +28,7 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
@@ -38,7 +40,9 @@ docs = [ Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] - +filter_docs = [ + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) +] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -87,6 +91,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -129,6 +134,42 @@ async def vs_custom(self, engine): await vs_custom.aadd_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_custom_filter(self, engine): + await engine._ainit_vectorstore_table( + CUSTOM_FILTER_TABLE, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + ) + + vs_custom_filter = await AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + await vs_custom_filter.aadd_documents(filter_docs, ids=ids) + yield vs_custom_filter + async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 @@ -287,3 +328,16 @@ def test_get_by_ids(self, vs): test_ids = [ids[0]] with pytest.raises(Exception, match=sync_method_exception_str): vs.get_by_ids(ids=test_ids) + + @pytest.mark.parametrize("test_filter, 
expected_ids", FILTERING_TEST_CASES) + async def test_vectorstore_with_metadata_filters( + self, + vs_custom_filter, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + docs = await vs_custom_filter.asimilarity_search( + "meow", k=5, filter=test_filter + ) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 2141d951..ae1341ed 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -19,6 +19,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS, NEGATIVE_TEST_CASES from sqlalchemy import text from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore @@ -27,6 +28,10 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE_SYNC = "test_table_custom_filter_sync" + str(uuid.uuid4()).replace( + "-", "_" +) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -37,7 +42,9 @@ docs = [ Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] - +filter_docs = [ + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) +] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -88,6 +95,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() 
@pytest_asyncio.fixture(scope="class") @@ -142,6 +150,43 @@ async def vs_custom(self, engine_sync): vs_custom.add_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_custom_filter(self, engine): + await engine.ainit_vectorstore_table( + CUSTOM_FILTER_TABLE, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + overwrite_existing=True, + ) + + vs_custom_filter = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + await vs_custom_filter.aadd_documents(filter_docs, ids=ids) + yield vs_custom_filter + async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 @@ -240,6 +285,19 @@ async def test_aget_by_ids_custom_vs(self, vs_custom): assert results[0] == Document(page_content="foo", id=ids[0]) + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + async def test_vectorstore_with_metadata_filters( + self, + vs_custom_filter, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + docs = await vs_custom_filter.asimilarity_search( + "meow", k=5, filter=test_filter + ) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + class TestVectorStoreSearchSync: @pytest.fixture(scope="module") @@ -268,6 +326,7 @@ async def engine_sync(self, db_project, db_region, db_instance, db_name): ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_SYNC}") + await 
aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}") await engine.close() @pytest.fixture(scope="class") @@ -297,6 +356,44 @@ def vs_custom(self, engine_sync): vs_custom.add_documents(docs, ids=ids) yield vs_custom + @pytest.fixture(scope="class") + def vs_custom_filter_sync(self, engine_sync): + engine_sync.init_vectorstore_table( + CUSTOM_FILTER_TABLE_SYNC, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + overwrite_existing=True, + ) + + vs_custom_filter_sync = PostgresVectorStore.create_sync( + engine_sync, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE_SYNC, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + + vs_custom_filter_sync.add_documents(filter_docs, ids=ids) + yield vs_custom_filter_sync + def test_similarity_search(self, vs_custom): results = vs_custom.similarity_search("foo", k=1) assert len(results) == 1 @@ -349,3 +446,22 @@ def test_get_by_ids_custom_vs(self, vs_custom): results = vs_custom.get_by_ids(ids=test_ids) assert results[0] == Document(page_content="foo", id=ids[0]) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + def test_sync_vectorstore_with_metadata_filters( + self, + vs_custom_filter_sync, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + + docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES) + def test_metadata_filter_negative_tests(self, 
vs_custom_filter_sync, test_filter): + with pytest.raises((ValueError, NotImplementedError)): + docs = vs_custom_filter_sync.similarity_search( + "meow", k=5, filter=test_filter + )