Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 237 additions & 16 deletions src/langchain_google_cloud_sql_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import copy
import json
import uuid
from typing import Any, Callable, Iterable, Optional, Sequence
Expand All @@ -37,6 +38,36 @@
QueryOptions,
)

# Filter comparison operators and the SQL operator each one is rendered as.
COMPARISONS_TO_NATIVE = {
    "$eq": "=",
    "$ne": "!=",
    "$lt": "<",
    "$lte": "<=",
    "$gt": ">",
    "$gte": ">=",
}

# Operators that need dedicated SQL rendering instead of a plain infix symbol.
SPECIAL_CASED_OPERATORS = {"$in", "$nin", "$between", "$exists"}

# Pattern-matching operators.
TEXT_OPERATORS = {"$like", "$ilike"}

# Operators that combine or negate other filter clauses.
LOGICAL_OPERATORS = {"$and", "$or", "$not"}

# Every operator key accepted inside a filter dictionary.
SUPPORTED_OPERATORS = {
    *COMPARISONS_TO_NATIVE,
    *TEXT_OPERATORS,
    *LOGICAL_OPERATORS,
    *SPECIAL_CASED_OPERATORS,
}


class AsyncPostgresVectorStore(VectorStore):
"""Google Cloud SQL for PostgreSQL Vector Store class"""
Expand Down Expand Up @@ -253,7 +284,7 @@ async def __aadd_embeddings(
values_stmt = "VALUES (:id, :content, :embedding"

# Add metadata
extra = metadata
extra = copy.deepcopy(metadata)
for metadata_column in self.metadata_columns:
if metadata_column in metadata:
values_stmt += f", :{metadata_column}"
Expand Down Expand Up @@ -537,7 +568,7 @@ async def __query_collection(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> Sequence[RowMapping]:
"""Perform similarity search query on the vector store table."""
Expand All @@ -555,6 +586,8 @@ async def __query_collection(

column_names = ", ".join(f'"{col}"' for col in columns)

if filter and isinstance(filter, dict):
filter = self._create_filter_clause(filter)
filter = f"WHERE {filter}" if filter else ""
stmt = f"SELECT {column_names}, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.schema_name}\".\"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
if self.index_query_options:
Expand All @@ -576,7 +609,7 @@ async def asimilarity_search(
self,
query: str,
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected by similarity search on query."""
Expand All @@ -601,7 +634,7 @@ async def asimilarity_search_with_score(
self,
query: str,
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected by similarity search on query."""
Expand All @@ -615,7 +648,7 @@ async def asimilarity_search_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected by vector similarity search."""
Expand All @@ -629,7 +662,7 @@ async def asimilarity_search_with_score_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected by vector similarity search."""
Expand Down Expand Up @@ -665,7 +698,7 @@ async def amax_marginal_relevance_search(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance."""
Expand All @@ -686,7 +719,7 @@ async def amax_marginal_relevance_search_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance."""
Expand All @@ -709,7 +742,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""Return docs and distance scores selected using the maximal marginal relevance."""
Expand Down Expand Up @@ -815,6 +848,194 @@ async def is_valid_index(

return bool(len(results) == 1)

def _handle_field_filter(
    self,
    field: str,
    value: Any,
) -> str:
    """Translate a single field condition into a parenthesised SQL predicate.

    Args:
        field: Name of the column to filter on. Must be a string that is a
            valid identifier and does not start with "$".
        value: Either a bare value, which is treated as an equality test,
            or a dictionary with exactly one key. In the dictionary form
            the key is the operator (for example "$gte" or "$in") and the
            value is the operand.

    Returns:
        The SQL predicate as a string, wrapped in parentheses.

    Raises:
        ValueError: If the field name is invalid, the operator dictionary
            does not have exactly one supported key, "$in"/"$nin" is given
            an empty collection, or "$exists" is given a non-boolean.
        NotImplementedError: If an "$in"/"$nin" member has an unsupported
            type, or the operator has no field-level rendering.
    """

    def _sql_literal(raw: Any) -> str:
        # Strings become single-quoted literals with embedded quotes doubled,
        # so a value such as "O'Brien" can neither break nor extend the
        # statement. This relies on PostgreSQL's standard-conforming string
        # syntax. Non-string values are interpolated via str() unchanged.
        if isinstance(raw, str):
            escaped = raw.replace("'", "''")
            return f"'{escaped}'"
        return str(raw)

    if not isinstance(field, str):
        raise ValueError(
            f"field should be a string but got: {type(field)} with value: {field}"
        )

    if field.startswith("$"):
        raise ValueError(
            f"Invalid filter condition. Expected a field but got an operator: "
            f"{field}"
        )

    # The field name is interpolated unquoted, so only identifier-shaped
    # names are accepted.
    if not field.isidentifier():
        raise ValueError(
            f"Invalid field name: {field}. Expected a valid identifier."
        )

    if isinstance(value, dict):
        # This is a filter specification: {operator: operand}
        if len(value) != 1:
            raise ValueError(
                "Invalid filter condition. Expected a value which "
                "is a dictionary with a single key that corresponds to an operator "
                f"but got a dictionary with {len(value)} keys. The first few "
                f"keys are: {list(value.keys())[:3]}"
            )
        operator, filter_value = next(iter(value.items()))
        if operator not in SUPPORTED_OPERATORS:
            raise ValueError(
                f"Invalid operator: {operator}. "
                f"Expected one of {SUPPORTED_OPERATORS}"
            )
    else:
        # A bare value is shorthand for an equality test.
        operator = "$eq"
        filter_value = value

    if operator in COMPARISONS_TO_NATIVE:
        # native comes from a fixed table and is trusted input
        native = COMPARISONS_TO_NATIVE[operator]
        return f"({field} {native} {_sql_literal(filter_value)})"

    if operator == "$between":
        low, high = filter_value
        return f"({field} BETWEEN {_sql_literal(low)} AND {_sql_literal(high)})"

    if operator in {"$in", "$nin"}:
        # Materialise once so one-shot iterables are not consumed by validation.
        members = list(filter_value)
        for val in members:
            # bool is a subclass of int, so it has to be rejected explicitly
            if isinstance(val, bool) or not isinstance(val, (str, int, float)):
                raise NotImplementedError(
                    f"Unsupported type: {type(val)} for value: {val}"
                )
        if not members:
            # "IN ()" is not valid SQL, so fail before building the query.
            raise ValueError(
                f"Invalid filter condition. Expected a non-empty collection "
                f"for {operator} operator"
            )
        # Join the members by hand: str(tuple(...)) leaves a trailing comma
        # for one element and uses Python quoting rules for strings.
        rendered = ", ".join(_sql_literal(val) for val in members)
        keyword = "IN" if operator == "$in" else "NOT IN"
        return f"({field} {keyword} ({rendered}))"

    if operator in {"$like", "$ilike"}:
        keyword = "LIKE" if operator == "$like" else "ILIKE"
        # The pattern is always emitted as a text literal.
        return f"({field} {keyword} {_sql_literal(str(filter_value))})"

    if operator == "$exists":
        if not isinstance(filter_value, bool):
            raise ValueError(
                "Expected a boolean value for $exists "
                f"operator, but got: {filter_value}"
            )
        if filter_value:
            return f"({field} IS NOT NULL)"
        return f"({field} IS NULL)"

    # Remaining supported operators ($and, $or, $not) are not valid here.
    raise NotImplementedError()

def _create_filter_clause(self, filters: Any) -> str:
    """Convert a LangChain-style filter dictionary into a SQL WHERE expression.

    A dictionary with several keys is treated as an implicit AND of field
    conditions. A dictionary with one key is either a single field
    condition or one of the logical operators $and, $or, $not (matched
    case-insensitively), whose operands are converted recursively.

    Args:
        filters: Dictionary of filters to apply to the query.

    Returns:
        String containing the SQL condition, without the WHERE keyword.
        An empty dictionary yields an empty string.

    Raises:
        ValueError: If ``filters`` is not a dictionary, a key is not a
            string, an operator is used where a field is expected (or the
            reverse), or a logical operator receives an operand of the
            wrong type or an empty list.
    """
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid type: Expected a dictionary but got type: {type(filters)}"
        )

    if not filters:
        return ""

    # Check key types first so a bad key raises ValueError, not
    # AttributeError from the .startswith() calls below.
    for key in filters:
        if not isinstance(key, str):
            raise ValueError(
                f"field should be a string but got: {type(key)} with value: {key}"
            )

    if len(filters) > 1:
        # Then all keys have to be fields (they cannot be operators)
        for key in filters:
            if key.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {key}"
                )
        # At least two fields are present here, combined with an implicit AND.
        and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
        return f"({' AND '.join(and_)})"

    # Exactly one entry: either a field or a top-level logical operator.
    key, value = next(iter(filters.items()))
    if not key.startswith("$"):
        return self._handle_field_filter(key, value)

    logical = key.lower()
    if logical in ("$and", "$or"):
        if not isinstance(value, list):
            raise ValueError(
                f"Expected a list, but got {type(value)} for value: {value}"
            )
        op = key[1:].upper()  # "$and" -> "AND", "$or" -> "OR"
        filter_clause = [self._create_filter_clause(el) for el in value]
        if not filter_clause:
            raise ValueError(
                "Invalid filter condition. Expected a dictionary "
                "but got an empty dictionary"
            )
        if len(filter_clause) == 1:
            return filter_clause[0]
        joiner = f" {op} "
        return f"({joiner.join(filter_clause)})"

    if logical == "$not":
        if isinstance(value, list):
            if not value:
                # An empty list would otherwise render as "()", which is
                # not a valid SQL expression.
                raise ValueError(
                    "Invalid filter condition. Expected a non-empty list "
                    "for $not but got an empty list"
                )
            not_conditions = [self._create_filter_clause(item) for item in value]
            not_stmts = [f"NOT {condition}" for condition in not_conditions]
            return f"({' AND '.join(not_stmts)})"
        if isinstance(value, dict):
            not_ = self._create_filter_clause(value)
            return f"(NOT {not_})"
        raise ValueError(
            f"Invalid filter condition. Expected a dictionary "
            f"or a list but got: {type(value)}"
        )

    # The only operators allowed at the top level are $and, $or and $not.
    raise ValueError(
        f"Invalid filter condition. Expected $and, $or or $not "
        f"but got: {key}"
    )

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."
Expand All @@ -824,7 +1045,7 @@ def similarity_search(
self,
query: str,
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand Down Expand Up @@ -906,7 +1127,7 @@ def similarity_search_with_score(
self,
query: str,
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand All @@ -917,7 +1138,7 @@ def similarity_search_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -928,7 +1149,7 @@ def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: Optional[int] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand All @@ -941,7 +1162,7 @@ def max_marginal_relevance_search(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -954,7 +1175,7 @@ def max_marginal_relevance_search_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[Document]:
raise NotImplementedError(
Expand All @@ -967,7 +1188,7 @@ def max_marginal_relevance_search_with_score_by_vector(
k: Optional[int] = None,
fetch_k: Optional[int] = None,
lambda_mult: Optional[float] = None,
filter: Optional[str] = None,
filter: Optional[dict] | Optional[str] = None,
**kwargs: Any,
) -> list[tuple[Document, float]]:
raise NotImplementedError(
Expand Down
Loading