diff --git a/python/semantic_kernel/connectors/memory/sql_server.py b/python/semantic_kernel/connectors/memory/sql_server.py index 448783ad54b1..3f303b0eb180 100644 --- a/python/semantic_kernel/connectors/memory/sql_server.py +++ b/python/semantic_kernel/connectors/memory/sql_server.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import ast import asyncio import json import logging @@ -15,6 +16,7 @@ from azure.identity.aio import DefaultAzureCredential from pydantic import SecretStr, ValidationError, field_validator +from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.data.const import DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, IndexKind from semantic_kernel.data.record_definition import ( VectorStoreRecordDataField, @@ -22,14 +24,9 @@ VectorStoreRecordKeyField, VectorStoreRecordVectorField, ) -from semantic_kernel.data.text_search import AnyTagsEqualTo, EqualTo, KernelSearchResults -from semantic_kernel.data.vector_search import ( - VectorizedSearchMixin, - VectorSearchFilter, - VectorSearchOptions, - VectorSearchResult, -) -from semantic_kernel.data.vector_storage import VectorStore, VectorStoreRecordCollection +from semantic_kernel.data.text_search import KernelSearchResults +from semantic_kernel.data.vector_search import SearchType, VectorSearch, VectorSearchOptions, VectorSearchResult +from semantic_kernel.data.vector_storage import GetFilteredRecordOptions, VectorStore, VectorStoreRecordCollection from semantic_kernel.exceptions import VectorStoreOperationException from semantic_kernel.exceptions.vector_store_exceptions import ( VectorSearchExecutionException, @@ -37,12 +34,13 @@ ) from semantic_kernel.kernel_pydantic import KernelBaseSettings from semantic_kernel.kernel_types import OneOrMany -from semantic_kernel.utils.feature_stage_decorator import experimental +from semantic_kernel.utils.feature_stage_decorator import release_candidate if sys.version_info >= (3, 12): from typing import override # pragma: no cover else: from typing_extensions import override # pragma: no cover + if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: @@ -54,7 +52,7 @@ logger = logging.getLogger(__name__) -TKey = TypeVar("TKey", str, int) +TKey = TypeVar("TKey", bound=str | int) TModel = TypeVar("TModel") # maximum number of parameters for SQL Server @@ -66,6 +64,11 @@ DistanceFunction.COSINE_DISTANCE: "cosine", DistanceFunction.EUCLIDEAN_DISTANCE: "euclidean", DistanceFunction.DOT_PROD: "dot", + DistanceFunction.DEFAULT: "cosine", +} +INDEX_KIND_MAP: Final[dict[IndexKind, str]] = { + IndexKind.FLAT: "flat", + IndexKind.DEFAULT: "flat", } __all__ = ["SqlServerCollection", "SqlServerStore"] @@ -73,7 +76,7 @@ # region: Settings -@experimental +@release_candidate class SqlSettings(KernelBaseSettings): """SQL settings. @@ -119,7 +122,7 @@ def validate_connection_string(cls, value: str) -> str: # region: SQL Command and Query Builder -@experimental +@release_candidate class QueryBuilder: """A class that helps you build strings for SQL queries.""" @@ -195,7 +198,7 @@ def __str__(self): return self._file_str.getvalue() -@experimental +@release_candidate class SqlCommand: """A class that represents a SQL command with parameters.""" @@ -261,24 +264,26 @@ async def _get_mssql_connection(settings: SqlSettings) -> "Connection": # region: SQL Server Collection -@experimental +@release_candidate class SqlServerCollection( VectorStoreRecordCollection[TKey, TModel], - VectorizedSearchMixin[TKey, TModel], + VectorSearch[TKey, TModel], Generic[TKey, TModel], ): """SQL collection implementation.""" connection: Any | None = None settings: SqlSettings | None = None - supported_key_types: ClassVar[list[str] | None] = ["str", "int"] - supported_vector_types: ClassVar[list[str] | None] = ["float"] + supported_key_types: ClassVar[set[str] | None] = {"str", "int"} + supported_vector_types: ClassVar[set[str] | None] = {"float"} + supported_search_types: ClassVar[set[SearchType]] = {SearchType.VECTOR} def __init__( self, - collection_name: str, data_model_type: type[TModel], data_model_definition: VectorStoreRecordDefinition | None = None, + collection_name: str | None = None, + embedding_generator: EmbeddingGeneratorBase | None = None, connection_string: str | None = None, connection: "Connection | None" = None, env_file_path: str | None = None, @@ -288,9 +293,10 @@ def __init__( """Initialize the collection. Args: - collection_name: The name of the collection, which corresponds to the table name. data_model_type: The type of the data model. data_model_definition: The data model definition. + collection_name: The name of the collection, which corresponds to the table name. + embedding_generator: The embedding generator to use. connection_string: The connection string to the database. connection: The connection, make sure to set the `LongAsMax=yes` option on the construction string used. env_file_path: Use the environment settings file as a fallback to environment variables. @@ -318,6 +324,7 @@ def __init__( connection=connection, settings=settings, managed_client=managed_client, + embedding_generator=embedding_generator, ) @override @@ -357,11 +364,7 @@ async def _inner_upsert( raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") if not records: return [] - data_fields = [ - field - for field in self.data_model_definition.fields.values() - if isinstance(field, VectorStoreRecordDataField) - ] + data_fields = self.data_model_definition.data_fields vector_fields = self.data_model_definition.vector_fields schema, table = self._get_schema_and_table() # Check how many parameters are likely to be passed @@ -385,26 +388,20 @@ async def _inner_upsert( return keys @override - async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[dict[str, Any]] | None: - """Get records from the database. - - Args: - keys: The keys to get. - **kwargs: Additional arguments. - - Returns: - The records from the store, not deserialized. - """ + async def _inner_get( + self, + keys: Sequence[TKey] | None = None, + options: GetFilteredRecordOptions | None = None, + **kwargs: Any, + ) -> OneOrMany[dict[str, Any]] | None: if not keys: + if options is not None: + raise NotImplementedError("Get without keys is not yet implemented.") return None query = _build_select_query( *self._get_schema_and_table(), self.data_model_definition.key_field, - [ - field - for field in self.data_model_definition.fields.values() - if isinstance(field, VectorStoreRecordDataField) - ], + self.data_model_definition.data_fields, self.data_model_definition.vector_fields if kwargs.get("include_vectors", True) else None, keys, ) @@ -474,15 +471,10 @@ async def create_collection( cursor.execute(query) return - data_fields = [ - field - for field in self.data_model_definition.fields.values() - if isinstance(field, VectorStoreRecordDataField) - ] create_table_query = _build_create_table_query( *self._get_schema_and_table(), key_field=self.data_model_definition.key_field, - data_fields=data_fields, + data_fields=self.data_model_definition.data_fields, vector_fields=self.data_model_definition.vector_fields, if_not_exists=create_if_not_exists, ) @@ -523,29 +515,25 @@ async def delete_collection(self, **kwargs: Any) -> None: @override async def _inner_search( self, + search_type: SearchType, options: VectorSearchOptions, - search_text: str | None = None, - vectorizable_text: str | None = None, - vector: list[float | int] | None = None, + values: Any | None = None, + vector: Sequence[float | int] | None = None, **kwargs: Any, ) -> KernelSearchResults[VectorSearchResult[TModel]]: - if vector is not None: - query = _build_search_query( - *self._get_schema_and_table(), - self.data_model_definition.key_field, - [ - field - for field in self.data_model_definition.fields.values() - if isinstance(field, VectorStoreRecordDataField) - ], - self.data_model_definition.vector_fields, - vector, - options, - ) - elif search_text: - raise VectorSearchExecutionException("Text search not supported.") - elif vectorizable_text: - raise VectorSearchExecutionException("Vectorizable text search not supported.") + if vector is None: + vector = await self._generate_vector_from_values(values, options) + if not vector: + raise VectorSearchExecutionException("No vector provided.") + query = _build_search_query( + *self._get_schema_and_table(), + self.data_model_definition.key_field, + self.data_model_definition.data_fields, + self.data_model_definition.vector_fields, + vector, + options, + self._build_filter(options.filter), # type: ignore + ) return KernelSearchResults( results=self._get_vector_search_results_from_results(self._fetch_records(query), options), @@ -570,6 +558,97 @@ async def _fetch_records(self, query: SqlCommand) -> AsyncIterable[dict[str, Any yield record await asyncio.sleep(0) + @override + def _lambda_parser(self, node: ast.AST) -> "SqlCommand": # type: ignore + """Parse a Python lambda AST node and return a SqlCommand object.""" + command = SqlCommand() + + def parse(node: ast.AST) -> str: + match node: + case ast.Compare(): + if len(node.ops) > 1: + # Chain comparisons (e.g., 1 < x < 3) become AND of each comparison + values = [] + for idx in range(len(node.ops)): + left = node.left if idx == 0 else node.comparators[idx - 1] + right = node.comparators[idx] + op = node.ops[idx] + values.append(parse(ast.Compare(left=left, ops=[op], comparators=[right]))) + return f"({' AND '.join(values)})" + left = parse(node.left) # type: ignore + right_node = node.comparators[0] + op = node.ops[0] + match op: + case ast.In(): + right = parse(right_node) # type: ignore + return f"{left} IN {right}" + case ast.NotIn(): + right = parse(right_node) # type: ignore + return f"{left} NOT IN {right}" + case ast.Eq(): + right = parse(right_node) # type: ignore + return f"{left} = {right}" + case ast.NotEq(): + right = parse(right_node) # type: ignore + return f"{left} <> {right}" + case ast.Gt(): + right = parse(right_node) # type: ignore + return f"{left} > {right}" + case ast.GtE(): + right = parse(right_node) # type: ignore + return f"{left} >= {right}" + case ast.Lt(): + right = parse(right_node) # type: ignore + return f"{left} < {right}" + case ast.LtE(): + right = parse(right_node) # type: ignore + return f"{left} <= {right}" + raise NotImplementedError(f"Unsupported operator: {type(op)}") + case ast.BoolOp(): + op = node.op # type: ignore + values = [parse(v) for v in node.values] + if isinstance(op, ast.And): + return f"({' AND '.join(values)})" + if isinstance(op, ast.Or): + return f"({' OR '.join(values)})" + raise NotImplementedError(f"Unsupported BoolOp: {type(op)}") + case ast.UnaryOp(): + match node.op: + case ast.Not(): + operand = parse(node.operand) + return f"NOT ({operand})" + case ast.UAdd() | ast.USub() | ast.Invert(): + raise NotImplementedError("Unary +, -, ~ are not supported in SQL filters.") + case ast.Attribute(): + # Only allow attributes that are in the data model + if node.attr not in self.data_model_definition.storage_property_names: + raise VectorStoreOperationException( + f"Field '{node.attr}' not in data model (storage property names are used)." + ) + return f"[{node.attr}]" + case ast.Name(): + # Only allow names that are in the data model + if node.id not in self.data_model_definition.storage_property_names: + raise VectorStoreOperationException( + f"Field '{node.id}' not in data model (storage property names are used)." + ) + return f"[{node.id}]" + case ast.Constant(): + # Always use parameterization for constants + command.add_parameter(node.value) + return "?" + case ast.List(): + # For IN/NOT IN lists, parameterize each element + placeholders = [] + for elt in node.elts: + placeholders.append(parse(elt)) + return f"({', '.join(placeholders)})" + raise NotImplementedError(f"Unsupported AST node: {type(node)}") + + where_clause = parse(node) + command.query.append(where_clause) + return command + @override def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]: return result @@ -582,7 +661,7 @@ def _get_score_from_result(self, result: Any) -> float | None: # region: SQL Server Store -@experimental +@release_candidate class SqlServerStore(VectorStore): """SQL Store implementation. @@ -597,6 +676,7 @@ def __init__( self, connection_string: str | None = None, connection: "Connection | None" = None, + embedding_generator: EmbeddingGeneratorBase | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, **kwargs: Any, @@ -606,12 +686,12 @@ def __init__( Args: connection_string: The connection string to the database. connection: The connection, make sure to set the `LongAsMax=yes` option on the construction string used. + embedding_generator: The embedding generator to use. env_file_path: Use the environment settings file as a fallback to environment variables. env_file_encoding: The encoding of the environment settings file. **kwargs: Additional arguments. + """ - managed_client = not connection - settings = None if not connection: try: settings = SqlSettings( @@ -623,8 +703,14 @@ def __init__( raise VectorStoreInitializationException( "Invalid settings provided. Please check the connection string." ) from e - - super().__init__(settings=settings, connection=connection, managed_client=managed_client, **kwargs) + else: + settings = None + super().__init__( + connection=connection, + settings=settings, + embedding_generator=embedding_generator, + **kwargs, + ) @override async def __aenter__(self) -> Self: @@ -664,20 +750,21 @@ async def list_collection_names(self, **kwargs) -> Sequence[str]: @override def get_collection( self, - collection_name: str, - data_model_type: type[object], + data_model_type: type[TModel], + *, data_model_definition: VectorStoreRecordDefinition | None = None, + collection_name: str | None = None, + embedding_generator: EmbeddingGeneratorBase | None = None, **kwargs: Any, - ) -> "VectorStoreRecordCollection": - self.vector_record_collections[collection_name] = SqlServerCollection( - collection_name=collection_name, + ) -> SqlServerCollection: + return SqlServerCollection( data_model_type=data_model_type, data_model_definition=data_model_definition, + collection_name=collection_name, connection=self.connection, - settings=self.settings, + embedding_generator=embedding_generator or self.embedding_generator, **kwargs, ) - return self.vector_record_collections[collection_name] # region: Query Build Functions @@ -769,25 +856,25 @@ def _build_create_table_query( with command.query.in_parenthesis(suffix=";"): # add the key field command.query.append( - f'"{key_field.name}" {_python_type_to_sql(key_field.property_type, is_key=True)} NOT NULL,\n' + f'"{key_field.storage_property_name or key_field.name}" ' + f"{_python_type_to_sql(key_field.property_type, is_key=True)} NOT NULL,\n" ) # add the data fields [ - command.query.append(f'"{field.name}" {_python_type_to_sql(field.property_type)} NULL,\n') + command.query.append( + f'"{field.storage_property_name or field.name}" {_python_type_to_sql(field.property_type)} NULL,\n' + ) for field in data_fields ] # add the vector fields for field in vector_fields: - if field.dimensions is None: - raise VectorStoreOperationException(f"Vector dimensions are not defined for field '{field.name}'") - if field.index_kind is not None and field.index_kind != IndexKind.FLAT: - # Only FLAT index kind is supported - # None is also accepted, which means no explicit index kind - # is set, so implicit default is used + if field.index_kind not in INDEX_KIND_MAP: raise VectorStoreOperationException( f"Index kind '{field.index_kind}' is not supported for field '{field.name}'" ) - command.query.append(f'"{field.name}" VECTOR({field.dimensions}) NULL,\n') + command.query.append( + f'"{field.storage_property_name or field.name}" VECTOR({field.dimensions}) NULL,\n' + ) # set the primary key with command.query.in_parenthesis("PRIMARY KEY", "\n"): command.query.append(key_field.name) @@ -859,9 +946,9 @@ def _add_field_names( """ fields = chain([key_field], data_fields, vector_fields or []) if table_identifier: - strings = [f"{table_identifier}.{field.name}" for field in fields] + strings = [f"{table_identifier}.{field.storage_property_name or field.name}" for field in fields] else: - strings = [field.name for field in fields] + strings = [field.storage_property_name or field.name for field in fields] command.query.append_list(strings) @@ -890,7 +977,7 @@ def _build_merge_query( query_list = [] param_list = [] for field in chain([key_field], data_fields, vector_fields): - value = record.get(field.name) + value = record.get(field.storage_property_name or field.name) # add the field name to the query list query_list.append(_add_cast_check("?", value)) # add the field value to the parameter list @@ -903,12 +990,19 @@ def _build_merge_query( _add_field_names(command, key_field, data_fields, vector_fields) # add the ON clause with command.query.in_parenthesis("ON", "\n"): - command.query.append(f"t.{key_field.name} = s.{key_field.name}") + command.query.append( + f"t.{key_field.storage_property_name or key_field.name} = " + f"s.{key_field.storage_property_name or key_field.name}" + ) # Set the Matched clause command.query.append("WHEN MATCHED THEN\n") command.query.append("UPDATE SET ") command.query.append_list( - [f"t.{field.name} = s.{field.name}" for field in chain(data_fields, vector_fields)], suffix="\n" + [ + f"t.{field.storage_property_name or field.name} = s.{field.storage_property_name or field.name}" + for field in chain(data_fields, vector_fields) + ], + suffix="\n", ) # Set the Not Matched clause command.query.append("WHEN NOT MATCHED THEN\n") @@ -941,7 +1035,7 @@ def _build_select_query( command.query.append_table_name(schema, table, prefix=" FROM", newline=True) # add the WHERE clause if keys: - command.query.append(f"WHERE {key_field.name} IN\n") + command.query.append(f"WHERE {key_field.storage_property_name or key_field.name} IN\n") with command.query.in_parenthesis(): # add the keys command.query.append_list(["?"] * len(keys)) @@ -961,7 +1055,7 @@ def _build_delete_query( # start the DELETE statement command.query.append_table_name(schema, table) # add the WHERE clause - command.query.append(f"WHERE [{key_field.name}] IN") + command.query.append(f"WHERE [{key_field.storage_property_name or key_field.name}] IN") with command.query.in_parenthesis(): # add the keys command.query.append_list(["?"] * len(keys)) @@ -970,32 +1064,15 @@ def _build_delete_query( return command -def _build_filter(command: SqlCommand, filters: VectorSearchFilter): - """Build the filter query based on the data model.""" - if not filters.filters: - return - command.query.append("WHERE ") - for filter in filters.filters: - match filter: - case EqualTo(): - command.query.append(f"[{filter.field_name}] = ? AND\n") - command.add_parameter(_cast_value(filter.value)) - case AnyTagsEqualTo(): - command.query.append(f"? IN [{filter.field_name}] AND\n") - command.add_parameter(_cast_value(filter.value)) - # remove the last AND - command.query.remove_last(4) - command.query.append("\n") - - def _build_search_query( schema: str, table: str, key_field: VectorStoreRecordKeyField, data_fields: list[VectorStoreRecordDataField], vector_fields: list[VectorStoreRecordVectorField], - vector: list[float], + vector: Sequence[float | int], options: VectorSearchOptions, + filter: SqlCommand | list[SqlCommand] | None = None, ) -> SqlCommand: """Build the SELECT query based on the data model.""" # start the SELECT statement @@ -1004,33 +1081,41 @@ def _build_search_query( _add_field_names(command, key_field, data_fields, vector_fields if options.include_vectors else None) # add the vector search clause vector_field: VectorStoreRecordVectorField | None = None - if options.vector_field_name: + if options.vector_property_name: vector_field = next( - (field for field in vector_fields if field.name == options.vector_field_name), + (field for field in vector_fields if field.name == options.vector_property_name), None, ) elif len(vector_fields) == 1: vector_field = vector_fields[0] if not vector_field: raise VectorStoreOperationException("Vector field not specified.") - + if vector_field.distance_function not in DISTANCE_FUNCTION_MAP: + raise VectorStoreOperationException( + f"Distance function '{vector_field.distance_function}' is not supported for field '{vector_field.name}'" + ) + distance_function = DISTANCE_FUNCTION_MAP[vector_field.distance_function] asc: bool = True - if vector_field.distance_function: - distance_function = DISTANCE_FUNCTION_MAP.get(vector_field.distance_function) - if not distance_function: - raise VectorStoreOperationException(f"Distance function '{vector_field.distance_function}' not supported.") - asc = DISTANCE_FUNCTION_DIRECTION_HELPER[vector_field.distance_function](0, 1) - else: - distance_function = "cosine" + asc = DISTANCE_FUNCTION_DIRECTION_HELPER[vector_field.distance_function](0, 1) command.query.append( - f", VECTOR_DISTANCE('{distance_function}', {vector_field.name}, CAST(? AS VECTOR({vector_field.dimensions}))) as {SCORE_FIELD_NAME}\n", # noqa: E501 + f", VECTOR_DISTANCE('{distance_function}', {vector_field.storage_property_name or vector_field.name}, CAST(? AS VECTOR({vector_field.dimensions}))) as {SCORE_FIELD_NAME}\n", # noqa: E501 ) command.add_parameter(_cast_value(vector)) # add the FROM clause command.query.append_table_name(schema, table, prefix=" FROM", newline=True) # add the WHERE clause - _build_filter(command, options.filter) + if filter: + if not isinstance(filter, list): + filter = [filter] + for idx, f in enumerate(filter): + if idx == 0: + command.query.append(" WHERE ") + else: + command.query.append(" AND ") + command.query.append(str(f.query), suffix=" \n") + command.add_parameters(f.parameters) + # add the ORDER BY clause command.query.append(f"ORDER BY {SCORE_FIELD_NAME} {'ASC' if asc else 'DESC'}\n") command.query.append(f"OFFSET {options.skip} ROWS FETCH NEXT {options.top} ROWS ONLY;") diff --git a/python/tests/unit/connectors/memory/test_sql_server.py b/python/tests/unit/connectors/memory/test_sql_server.py new file mode 100644 index 000000000000..63ceeb5a4e79 --- /dev/null +++ b/python/tests/unit/connectors/memory/test_sql_server.py @@ -0,0 +1,555 @@ +# Copyright (c) Microsoft. All rights reserved. + +import json +import sys +from dataclasses import dataclass +from typing import NamedTuple +from unittest.mock import AsyncMock, MagicMock, NonCallableMagicMock, patch + +from pytest import fixture, mark, param, raises + +from semantic_kernel.connectors.memory.sql_server import ( + QueryBuilder, + SqlCommand, + SqlServerCollection, + SqlServerStore, + _build_create_table_query, + _build_delete_query, + _build_delete_table_query, + _build_merge_query, + _build_search_query, + _build_select_query, + _build_select_table_names_query, +) +from semantic_kernel.data.const import DistanceFunction, IndexKind +from semantic_kernel.data.record_definition import ( + VectorStoreRecordDataField, + VectorStoreRecordKeyField, + VectorStoreRecordVectorField, +) +from semantic_kernel.data.vector_search import VectorSearchOptions +from semantic_kernel.exceptions.vector_store_exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, +) + + +class TestQueryBuilder: + def test_query_builder_append(self): + qb = QueryBuilder() + qb.append("SELECT * FROM") + qb.append(" table", suffix=";") + result = str(qb).strip() + assert result == "SELECT * FROM table;" + + def test_query_builder_append_list(self): + qb = QueryBuilder() + qb.append_list(["id", "name", "age"], sep=", ", suffix=";") + result = str(qb).strip() + assert result == "id, name, age;" + + def test_query_builder_append_table_name(self): + qb = QueryBuilder() + qb.append_table_name("dbo", "Users", prefix="SELECT * FROM", suffix=";", newline=False) + result = str(qb).strip() + assert result == "SELECT * FROM [dbo].[Users] ;" + + def test_query_builder_remove_last(self): + qb = QueryBuilder("SELECT * FROM table;") + qb.remove_last(1) # remove trailing semicolon + result = str(qb).strip() + assert result == "SELECT * FROM table" + + def test_query_builder_in_parenthesis(self): + qb = QueryBuilder("INSERT INTO table") + with qb.in_parenthesis(): + qb.append("id, name, age") + result = str(qb).strip() + assert result == "INSERT INTO table (id, name, age)" + + def test_query_builder_in_parenthesis_with_prefix_suffix(self): + qb = QueryBuilder() + with qb.in_parenthesis(prefix="VALUES", suffix=";"): + qb.append_list(["1", "'John'", "30"]) + result = str(qb).strip() + assert result == "VALUES (1, 'John', 30) ;" + + def test_query_builder_in_logical_group(self): + qb = QueryBuilder() + with qb.in_logical_group(): + qb.append("UPDATE Users SET name = 'John'") + result = str(qb).strip() + lines = result.splitlines() + assert lines[0] == "BEGIN" + assert lines[1] == "UPDATE Users SET name = 'John'" + assert lines[2] == "END" + + +class TestSqlCommand: + def test_sql_command_initial_query(self): + cmd = SqlCommand("SELECT 1") + assert str(cmd.query) == "SELECT 1" + + def test_sql_command_add_parameter(self): + cmd = SqlCommand("SELECT * FROM Test WHERE id = ?") + cmd.add_parameter("42") + assert cmd.parameters[0] == "42" + + def test_sql_command_add_parameters(self): + cmd = SqlCommand("SELECT * FROM Test WHERE id = ?") + cmd.add_parameters(["42", "43"]) + assert cmd.parameters[0] == "42" + assert cmd.parameters[1] == "43" + + def test_parameter_limit(self): + cmd = SqlCommand() + cmd.add_parameters(["42"] * 2100) + with raises(VectorStoreOperationException): + cmd.add_parameter("43") + with raises(VectorStoreOperationException): + cmd.add_parameters(["43", "44"]) + + +class TestQueryBuildFunctions: + def test_build_create_table_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=1536), + ] + cmd = _build_create_table_query(schema, table, key_field, data_fields, vector_fields) + assert not cmd.parameters + cmd_str = str(cmd.query) + assert ( + cmd_str + == 'BEGIN\nCREATE TABLE [dbo].[Test] \n ("id" nvarchar(255) NOT NULL,\n"name" nvarchar(max) NULL,\n"age" ' + 'int NULL,\n"embedding" VECTOR(1536) NULL,\nPRIMARY KEY (id) \n) ;\nEND\n' + ) + + def test_delete_table_query(self): + schema = "dbo" + table = "Test" + cmd = _build_delete_table_query(schema, table) + assert str(cmd.query) == f"DROP TABLE IF EXISTS [{schema}].[{table}] ;" + + @mark.parametrize("schema", ["dbo", None]) + def test_build_select_table_names_query(self, schema): + cmd = _build_select_table_names_query(schema) + if schema: + assert cmd.parameters == [schema] + assert str(cmd) == ( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE = 'BASE TABLE' " + "AND (@schema is NULL or TABLE_SCHEMA = ?);" + ) + else: + assert str(cmd) == "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';" + + def test_build_merge_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=5), + ] + records = [ + { + "id": "test", + "name": "name", + "age": 50, + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + } + ] + cmd = _build_merge_query(schema, table, key_field, data_fields, vector_fields, records) + assert cmd.parameters[0] == records[0]["id"] + assert cmd.parameters[1] == records[0]["name"] + assert cmd.parameters[2] == str(records[0]["age"]) + assert cmd.parameters[3] == json.dumps(records[0]["embedding"]) + str_cmd = str(cmd) + assert str_cmd == ( + "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[Test] AS t\nUSING ( " + "VALUES (?, ?, ?, ?) ) AS s (id, name, age, embedding) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE " + "SET t.name = s.name, t.age = s.age, t.embedding = s.embedding\nWHEN NOT MATCHED THEN\nINSERT " + "(id, name, age, embedding) VALUES (s.id, s.name, s.age, s.embedding) \nOUTPUT inserted.id " + "INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn FROM @UpsertedKeys;\n" + ) + + def test_build_select_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=5), + ] + keys = ["test"] + cmd = _build_select_query(schema, table, key_field, data_fields, vector_fields, keys) + assert cmd.parameters == ["test"] + str_cmd = str(cmd) + assert str_cmd == "SELECT\nid, name, age, embedding FROM [dbo].[Test] \nWHERE id IN\n (?) ;" + + def test_build_delete_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + keys = ["test"] + cmd = _build_delete_query(schema, table, key_field, keys) + str_cmd = str(cmd) + assert cmd.parameters[0] == "test" + assert str_cmd == "DELETE FROM [dbo].[Test] WHERE [id] IN (?) ;" + + def test_build_search_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField( + name="embedding", + property_type="float", + dimensions=5, + distance_function=DistanceFunction.COSINE_DISTANCE, + ), + ] + vector = [0.1, 0.2, 0.3, 0.4, 0.5] + options = VectorSearchOptions( + vector_property_name="embedding", + ) + + cmd = _build_search_query(schema, table, key_field, data_fields, vector_fields, vector, options) + assert cmd.parameters[0] == json.dumps(vector) + str_cmd = str(cmd) + assert ( + str_cmd == "SELECT id, name, age, VECTOR_DISTANCE('cosine', embedding, CAST(? AS VECTOR(5))) as " + "_vector_distance_value\n FROM [dbo].[Test] \nORDER BY " + "_vector_distance_value ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;" + ) + + +@fixture +async def mock_connection(*args, **kwargs): + return MagicMock() + + +@mark.parametrize( + "connection_string", + [ + param( + "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;uid=testuserLongAsMax=yes;", + id="with uid", + ), + param( + "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;LongAsMax=yes;", id="credential" + ), + ], +) +async def test_get_mssql_connection(connection_string): + mock_pyodbc = NonCallableMagicMock() + sys.modules["pyodbc"] = mock_pyodbc + + with patch("pyodbc.connect") as patched_connection: + from azure.identity.aio import DefaultAzureCredential + + from semantic_kernel.connectors.memory.sql_server import SqlSettings, _get_mssql_connection + + token = MagicMock() + token.token.return_value = "test_token" + token.token.encode.return_value = b"test_token" + credential = AsyncMock(spec=DefaultAzureCredential) + credential.__aenter__.return_value = credential + credential.get_token.return_value = token + + settings = SqlSettings(connection_string=connection_string) + with patch("semantic_kernel.connectors.memory.sql_server.DefaultAzureCredential", return_value=credential): + connection = await _get_mssql_connection(settings) + assert connection is not None + assert isinstance(connection, MagicMock) + if "uid" in connection_string: + assert patched_connection.call_args.kwargs["attrs_before"] is None + else: + assert patched_connection.call_args.kwargs["attrs_before"] == { + 1256: b"\n\x00\x00\x00test_token", + } + + +class TestSqlServerStore: + async def test_create_store(self, sql_server_unit_test_env): + store = SqlServerStore() + assert store is not None + assert store.settings is not None + assert store.settings.connection_string is not None + assert "LongAsMax=yes;" in store.settings.connection_string.get_secret_value() + + with patch("semantic_kernel.connectors.memory.sql_server._get_mssql_connection") as mock_get_connection: + mock_get_connection.return_value = AsyncMock() + await store.__aenter__() + assert store.connection is not None + + @mark.parametrize( + "override_env_param_dict", + [ + { + "SQL_SERVER_CONNECTION_STRING": "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;User Id=testuser;Password=example;LongAsMax=yes;" # noqa: E501 + } + ], + indirect=True, + ) + def test_create_store_with_long_as_max(self, sql_server_unit_test_env): + store = SqlServerStore() + assert store is not None + assert store.settings is not None + assert store.settings.connection_string is not None + + @mark.parametrize("exclude_list", ["SQL_SERVER_CONNECTION_STRING"], indirect=True) + def test_create_without_connection_string(self, sql_server_unit_test_env): + with raises(VectorStoreInitializationException): + SqlServerStore(env_file_path="test.env") + + def test_get_collection(self, sql_server_unit_test_env, data_model_definition): + store = SqlServerStore() + collection = store.get_collection( + collection_name="test", data_model_type=dict, data_model_definition=data_model_definition + ) + assert collection is not None + + async def test_list_collection_names(self, sql_server_unit_test_env, mock_connection): + async with SqlServerStore(connection=mock_connection) as store: + mock_connection.cursor.return_value.__enter__.return_value.fetchall.return_value = [ + ["Test1"], + ["Test2"], + ] + collection_names = await store.list_collection_names() + assert collection_names == ["Test1", "Test2"] + + async def test_no_connection(self, sql_server_unit_test_env): + store = SqlServerStore() + with raises(VectorStoreOperationException): + await store.list_collection_names() + + +class TestSqlServerCollection: + @mark.parametrize("exclude_list", ["SQL_SERVER_CONNECTION_STRING"], indirect=True) + def test_create_without_connection_string(self, sql_server_unit_test_env, data_model_definition): + with raises(VectorStoreInitializationException): + SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + env_file_path="test.env", + ) + + async def test_create(self, sql_server_unit_test_env, data_model_definition): + collection = SqlServerCollection( + collection_name="test", data_model_type=dict, data_model_definition=data_model_definition + ) + assert collection is not None + assert collection.collection_name == "test" + assert collection.settings is not None + assert collection.settings.connection_string is not None + + with patch("semantic_kernel.connectors.memory.sql_server._get_mssql_connection") as mock_get_connection: + mock_get_connection.return_value = AsyncMock() + await collection.__aenter__() + assert collection.connection is not None + + async def test_upsert( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + record = {"id": "1", "content": "test", "vector": [0.1, 0.2, 0.3, 0.4, 0.5]} + mock_connection.cursor.return_value.__enter__.return_value.nextset.side_effect = [True, False] + mock_connection.cursor.return_value.__enter__.return_value.fetchall.return_value = [ + ["1"], + ] + await collection.upsert(record) + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + ( + "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[test] AS t\nUSING ( VALUES" + " (?, ?, ?) ) AS s (id, content, vector) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE SET t.content" + " = s.content, t.vector = s.vector\nWHEN NOT MATCHED THEN\nINSERT (id, content, vector) VALUES (s.id, " + "s.content, s.vector) \nOUTPUT inserted.id INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn " + "FROM @UpsertedKeys;\n" + ), + ("1", "test", json.dumps([0.1, 0.2, 0.3, 0.4, 0.5])), + ) + + async def test_get( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + class MockRow(NamedTuple): + id: str + content: str + vector: str + + mock_cursor = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + key = "1" + + row = MockRow("1", "test", "[0.1, 0.2, 0.3, 0.4, 0.5]") + mock_cursor.description = [["id"], ["content"], ["vector"]] + + mock_cursor.__iter__.return_value = [row] + record = await collection.get(key, include_vectors=True) + mock_cursor.execute.assert_called_with( + "SELECT\nid, content, vector FROM [dbo].[test] \nWHERE id IN\n (?) ;", ("1",) + ) + assert record["id"] == "1" + assert record["content"] == "test" + assert record["vector"] == [0.1, 0.2, 0.3, 0.4, 0.5] + + async def test_delete( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + key = "1" + await collection.delete(key) + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + "DELETE FROM [dbo].[test] WHERE [id] IN (?) ;", ("1",) + ) + + async def test_search( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + mock_cursor = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + for field in data_model_definition.vector_fields: + field.distance_function = DistanceFunction.COSINE_DISTANCE + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + vector = [0.1, 0.2, 0.3, 0.4, 0.5] + + @dataclass + class MockRow: + id: str + content: str + _vector_distance_value: float + + row = MockRow("1", "test", 0.1) + mock_cursor.description = [["id"], ["content"], ["_vector_distance_value"]] + + mock_cursor.__iter__.return_value = [row] + search_result = await collection.search( + vector=vector, + vector_property_name="vector", + filter=lambda x: x.content == "test", + ) + async for record in search_result.results: + assert record.record["id"] == "1" + assert record.record["content"] == "test" + assert record.score == 0.1 + mock_cursor.execute.assert_called_with( + ( + "SELECT id, content, VECTOR_DISTANCE('cosine', vector, CAST(? AS VECTOR(5))) as " + "_vector_distance_value\n FROM [dbo].[test] \n WHERE [content] = ? \nORDER BY _vector_distance_value " + "ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;" + ), + (json.dumps(vector), "test"), + ) + + async def test_create_collection( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + for field in data_model_definition.vector_fields: + field.index_kind = IndexKind.FLAT + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + await collection.create_collection() + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + ( + "IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\nBEGIN\nCREATE TABLE [dbo].[test] \n (\"id\" nvarchar" + '(255) NOT NULL,\n"content" nvarchar(max) NULL,\n"vector" VECTOR(5) NULL,\nPRIMARY KEY (id) \n) ;' + "\nEND\n" + ), + (), + ) + + async def test_delete_collection( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + await collection.delete_collection() + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + "DROP TABLE IF EXISTS [dbo].[test] ;", () + ) + + async def test_no_connection(self, sql_server_unit_test_env, data_model_definition): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + with raises(VectorStoreOperationException): + await collection.create_collection() + with raises(VectorStoreOperationException): + await collection.delete_collection() + with raises(VectorStoreOperationException): + await collection.does_collection_exist() + with raises(VectorStoreOperationException): + await collection.upsert({"id": "1", "content": "test", "vector": [0.1, 0.2, 0.3, 0.4, 0.5]}) + with raises(VectorStoreOperationException): + await collection.get("1") + with raises(VectorStoreOperationException): + await collection.delete("1")