From 87591ae38c1ae6174e0007467eb12fe5f52e4cfc Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 6 Mar 2026 16:18:24 +0100 Subject: [PATCH 1/6] additional fix for filtering --- .../semantic_kernel/connectors/in_memory.py | 44 ++++++++++++++++--- .../unit/connectors/memory/test_in_memory.py | 43 ++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/python/semantic_kernel/connectors/in_memory.py b/python/semantic_kernel/connectors/in_memory.py index b62cc62a8420..914f9039e3cd 100644 --- a/python/semantic_kernel/connectors/in_memory.py +++ b/python/semantic_kernel/connectors/in_memory.py @@ -2,8 +2,8 @@ import ast import sys -from collections.abc import AsyncIterable, Callable, Sequence -from typing import Any, ClassVar, Final, Generic, TypeVar +from collections.abc import AsyncIterable, Callable, Mapping, Sequence +from typing import Any, ClassVar, Final, Generic, TypeVar, cast from numpy import dot from pydantic import Field @@ -81,6 +81,33 @@ def __delattr__(self, name) -> None: raise AttributeError(name) +class ReadOnlyAttributeDict(Mapping[TAKey, TAValue], Generic[TAKey, TAValue]): + """A read-only mapping that allows attribute access to keys.""" + + def __init__(self, data: Mapping[TAKey, TAValue]): + """Initialize the read-only mapping wrapper.""" + self._data = data + + def __getitem__(self, key: TAKey) -> TAValue: + """Get a value by key.""" + return self._data[key] + + def __iter__(self): + """Iterate over keys.""" + return iter(self._data) + + def __len__(self) -> int: + """Return the number of keys.""" + return len(self._data) + + def __getattr__(self, name: str) -> TAValue: + """Allow attribute-style access to mapping keys.""" + try: + return self._data[cast(TAKey, name)] + except KeyError: + raise AttributeError(name) + + class InMemoryCollection( VectorStoreCollection[TKey, TModel], VectorSearch[TKey, TModel], @@ -187,6 +214,8 @@ class InMemoryCollection( "__getattribute__", "__setattr__", "__delattr__", + "__setitem__", + "__delitem__", # Import and builtins "__builtins__", "__import__", @@ -431,13 +460,18 @@ def _parse_and_validate_filter(self, filter_str: str) -> Callable: # For Call nodes, validate that only allowed functions are called if isinstance(node, ast.Call): - func_name = None + func_name: str if isinstance(node.func, ast.Name): func_name = node.func.id elif isinstance(node.func, ast.Attribute): func_name = node.func.attr + else: + raise VectorStoreOperationException( + f"Call target node type '{type(node.func).__name__}' is not allowed in filter expressions. " + "Only direct function and method calls are supported." + ) - if func_name and func_name not in self.allowed_filter_functions: + if func_name not in self.allowed_filter_functions: raise VectorStoreOperationException( f"Function '{func_name}' is not allowed in filter expressions. " f"Allowed functions: {', '.join(sorted(self.allowed_filter_functions))}" @@ -457,7 +491,7 @@ def _parse_and_validate_filter(self, filter_str: str) -> Callable: def _run_filter(self, filter: Callable, record: AttributeDict[TAKey, TAValue]) -> bool: """Run the filter on the record, supporting attribute access.""" try: - return filter(record) + return filter(ReadOnlyAttributeDict(record)) except Exception as e: raise VectorStoreOperationException(f"Error running filter: {e}") from e diff --git a/python/tests/unit/connectors/memory/test_in_memory.py b/python/tests/unit/connectors/memory/test_in_memory.py index 1f334403e055..1e5f8f64d167 100644 --- a/python/tests/unit/connectors/memory/test_in_memory.py +++ b/python/tests/unit/connectors/memory/test_in_memory.py @@ -172,3 +172,46 @@ async def test_multiple_filters(collection): results = collection._get_filtered_records(type("opt", (), {"filter": filters})()) assert len(results) == 1 assert "1" in results + + +@mark.parametrize( + "filter_str", + [ + "lambda x: [x.clear][0]() or True", + "lambda x: [x.update][0]({'role': 'admin'}) or True", + "lambda x: [x.pop][0]('secret', '') or True", + "lambda x: [x.__setitem__][0]('leaked', ['{0.__class__.__mro__}'.format][0](x)) or True", + ], +) +def test_malicious_subscript_call_patterns_blocked(collection, filter_str): + with raises(VectorStoreOperationException, match="Call target node type 'Subscript' is not allowed"): + collection._parse_and_validate_filter(filter_str) + + +def test_direct_mutating_method_call_remains_blocked(collection): + with raises(VectorStoreOperationException, match="Function 'clear' is not allowed"): + collection._parse_and_validate_filter("lambda x: x.clear() or True") + + +async def test_valid_lambda_filter_with_get_method(collection): + record1 = {"id": "1", "vector": [1, 2, 3, 4, 5]} + record2 = {"id": "2", "vector": [5, 4, 3, 2, 1]} + await collection.upsert([record1, record2]) + results = collection._get_filtered_records(type("opt", (), {"filter": "lambda x: x.get('id') == '1'"})()) + assert len(results) == 1 + assert "1" in results + + +async def test_callable_filter_cannot_mutate_stored_record(collection): + record = {"id": "1", "content": "value", "vector": [1, 2, 3, 4, 5]} + await collection.upsert(record) + + def mutating_filter(x): + x["role"] = "admin" + return True + + with raises(VectorStoreOperationException, match="Error running filter"): + collection._get_filtered_records(type("opt", (), {"filter": mutating_filter})()) + + assert "role" not in collection.inner_storage["1"] + assert collection.inner_storage["1"]["content"] == "value" From 7e33298a3303026d181ba3cd773dc1000a88128e Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 6 Mar 2026 16:24:38 +0100 Subject: [PATCH 2/6] additional check --- .../unit/connectors/memory/test_in_memory.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/tests/unit/connectors/memory/test_in_memory.py b/python/tests/unit/connectors/memory/test_in_memory.py index 1e5f8f64d167..76b7fc239912 100644 --- a/python/tests/unit/connectors/memory/test_in_memory.py +++ b/python/tests/unit/connectors/memory/test_in_memory.py @@ -3,6 +3,7 @@ from pytest import fixture, mark, raises from semantic_kernel.connectors.in_memory import InMemoryCollection, InMemoryStore +from semantic_kernel.data._shared import default_dynamic_filter_function from semantic_kernel.data.vector import DistanceFunction from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreOperationException @@ -215,3 +216,21 @@ def mutating_filter(x): assert "role" not in collection.inner_storage["1"] assert collection.inner_storage["1"]["content"] == "value" + + +def test_default_dynamic_filter_injection_payload_is_blocked(collection): + class Param: + def __init__(self, name, default_value=None): + self.name = name + self.default_value = default_value + + injected_value = "' or [x.update][0]({'role':'admin'}) or x.name=='" + generated_filter = default_dynamic_filter_function( + filter=None, + parameters=[Param("category")], + category=injected_value, + ) + + assert isinstance(generated_filter, str) + with raises(VectorStoreOperationException, match="Call target node type 'Subscript' is not allowed"): + collection._parse_and_validate_filter(generated_filter) From 3bdde14a4e17f88d762d046e81c8ad8d889079c3 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 6 Mar 2026 19:02:40 +0100 Subject: [PATCH 3/6] fix tests --- .../unit/utils/model_diagnostics/test_trace_chat_completion.py | 2 ++ .../model_diagnostics/test_trace_streaming_chat_completion.py | 2 ++ .../model_diagnostics/test_trace_streaming_text_completion.py | 2 ++ .../unit/utils/model_diagnostics/test_trace_text_completion.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py index 8b322a72a52b..10b69aa52aa0 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py @@ -94,6 +94,7 @@ async def test_trace_chat_completion( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object(MockChatCompletion, "_inner_get_chat_message_contents", return_value=mock_response): # We need to reapply the decorator to the method since the mock will not have the decorator applied @@ -174,6 +175,7 @@ async def test_trace_chat_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object(MockChatCompletion, "_inner_get_chat_message_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py index 8b43818a6452..b3b520e5e165 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py @@ -100,6 +100,7 @@ async def test_trace_streaming_chat_completion( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value iterable = MagicMock(spec=AsyncGenerator) iterable.__aiter__.return_value = [mock_response] @@ -189,6 +190,7 @@ async def test_trace_streaming_chat_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object( MockChatCompletion, "_inner_get_streaming_chat_message_contents", side_effect=ServiceResponseException() diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py index f270192c7930..725977d145df 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py @@ -81,6 +81,7 @@ async def test_trace_streaming_text_completion( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value iterable = MagicMock(spec=AsyncGenerator) iterable.__aiter__.return_value = [mock_response] @@ -155,6 +156,7 @@ async def test_trace_streaming_text_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_streaming_text_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py index e94486071897..379d62586221 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py @@ -77,6 +77,7 @@ async def test_trace_text_completion( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_text_contents", return_value=mock_response): # We need to reapply the decorator to the method since the mock will not have the decorator applied @@ -148,6 +149,7 @@ async def test_trace_text_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") + mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_text_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied From 70663243d8b855c0d08c309dbd7516b6e425d5a5 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Mar 2026 13:16:46 +0100 Subject: [PATCH 4/6] further enveloped filters in inmemorycollection --- .../connectors/azure_cosmos_db.py | 2 +- .../semantic_kernel/connectors/in_memory.py | 394 +++++++++++++++++- python/semantic_kernel/data/_shared.py | 9 +- .../test_azure_cosmos_db_no_sql_collection.py | 62 +++ .../unit/connectors/memory/test_in_memory.py | 48 ++- 5 files changed, 494 insertions(+), 21 deletions(-) diff --git a/python/semantic_kernel/connectors/azure_cosmos_db.py b/python/semantic_kernel/connectors/azure_cosmos_db.py index 55c2bfc59eb9..0da6f52b8346 100644 --- a/python/semantic_kernel/connectors/azure_cosmos_db.py +++ b/python/semantic_kernel/connectors/azure_cosmos_db.py @@ -922,7 +922,7 @@ def _lambda_parser(self, node: ast.AST) -> Any: case ast.Constant(): # Quote strings, leave numbers as is if isinstance(node.value, str): - return f"'{node.value}'" + return "'" + node.value.replace("'", "''") + "'" if isinstance(node.value, (float, int)): return str(node.value) if node.value is None: diff --git a/python/semantic_kernel/connectors/in_memory.py b/python/semantic_kernel/connectors/in_memory.py index 914f9039e3cd..458f230cf8ca 100644 --- a/python/semantic_kernel/connectors/in_memory.py +++ b/python/semantic_kernel/connectors/in_memory.py @@ -90,7 +90,7 @@ def __init__(self, data: Mapping[TAKey, TAValue]): def __getitem__(self, key: TAKey) -> TAValue: """Get a value by key.""" - return self._data[key] + return self._wrap_value(self._data[key]) def __iter__(self): """Iterate over keys.""" @@ -103,10 +103,311 @@ def __len__(self) -> int: def __getattr__(self, name: str) -> TAValue: """Allow attribute-style access to mapping keys.""" try: - return self._data[cast(TAKey, name)] + return self._wrap_value(self._data[cast(TAKey, name)]) except KeyError: raise AttributeError(name) + @staticmethod + def _wrap_value(value: Any) -> Any: + """Wrap nested mappings to preserve read-only attribute access.""" + if isinstance(value, Mapping) and not isinstance(value, ReadOnlyAttributeDict): + return ReadOnlyAttributeDict(value) + return value + + +class _SafeFilterEvaluator: + """Evaluate a restricted filter AST without using eval().""" + + def __init__( + self, + *, + direct_call_functions: dict[str, Callable[..., Any]], + blocked_attributes: set[str], + max_literal_collection_size: int, + max_sequence_repeat_size: int, + ): + self._direct_call_functions = direct_call_functions + self._blocked_attributes = blocked_attributes + self._max_literal_collection_size = max_literal_collection_size + self._max_sequence_repeat_size = max_sequence_repeat_size + + def evaluate(self, node: ast.AST, context: Mapping[str, Any]) -> Any: + """Evaluate a supported AST node.""" + evaluator = getattr(self, f"_eval_{type(node).__name__}", None) + if evaluator is None: + raise VectorStoreOperationException( + f"AST node type '{type(node).__name__}' is not supported during filter evaluation." + ) + return evaluator(node, context) + + def _eval_Constant(self, node: ast.Constant, context: Mapping[str, Any]) -> Any: + """Evaluate a constant literal.""" + del context + if isinstance(node.value, str) and len(node.value) > self._max_literal_collection_size: + raise VectorStoreOperationException( + "String literals in filter expressions exceed the maximum allowed size." + ) + return node.value + + def _eval_Name(self, node: ast.Name, context: Mapping[str, Any]) -> Any: + """Evaluate a variable reference.""" + if node.id not in context: + raise VectorStoreOperationException(f"Use of name '{node.id}' is not allowed in filter expressions.") + return context[node.id] + + def _eval_Attribute(self, node: ast.Attribute, context: Mapping[str, Any]) -> Any: + """Evaluate an attribute access.""" + if node.attr in self._blocked_attributes: + raise VectorStoreOperationException( + f"Access to attribute '{node.attr}' is not allowed in filter expressions." + ) + value = self.evaluate(node.value, context) + try: + return ReadOnlyAttributeDict._wrap_value(getattr(value, node.attr)) + except AttributeError as e: + raise VectorStoreOperationException( + f"Attribute '{node.attr}' is not available in filter expressions." + ) from e + + def _eval_Subscript(self, node: ast.Subscript, context: Mapping[str, Any]) -> Any: + """Evaluate an index or slice operation.""" + value = self.evaluate(node.value, context) + slice_value = self.evaluate(node.slice, context) + try: + return ReadOnlyAttributeDict._wrap_value(value[slice_value]) + except Exception as e: + raise VectorStoreOperationException(f"Error evaluating subscript access: {e}") from e + + def _eval_Slice(self, node: ast.Slice, context: Mapping[str, Any]) -> slice: + """Evaluate a slice node.""" + lower = self._evaluate_optional(node.lower, context) + upper = self._evaluate_optional(node.upper, context) + step = self._evaluate_optional(node.step, context) + return slice(lower, upper, step) + + def _eval_List(self, node: ast.List, context: Mapping[str, Any]) -> list[Any]: + """Evaluate a list literal.""" + self._ensure_literal_collection_size(len(node.elts)) + return [self.evaluate(element, context) for element in node.elts] + + def _eval_Tuple(self, node: ast.Tuple, context: Mapping[str, Any]) -> tuple[Any, ...]: + """Evaluate a tuple literal.""" + self._ensure_literal_collection_size(len(node.elts)) + return tuple(self.evaluate(element, context) for element in node.elts) + + def _eval_Set(self, node: ast.Set, context: Mapping[str, Any]) -> set[Any]: + """Evaluate a set literal.""" + self._ensure_literal_collection_size(len(node.elts)) + return {self.evaluate(element, context) for element in node.elts} + + def _eval_Dict(self, node: ast.Dict, context: Mapping[str, Any]) -> dict[Any, Any]: + """Evaluate a dict literal.""" + self._ensure_literal_collection_size(len(node.keys)) + result: dict[Any, Any] = {} + for key, value in zip(node.keys, node.values, strict=True): + if key is None: + raise VectorStoreOperationException("Dictionary unpacking is not allowed in filter expressions.") + result[self.evaluate(key, context)] = self.evaluate(value, context) + return result + + def _eval_BoolOp(self, node: ast.BoolOp, context: Mapping[str, Any]) -> Any: + """Evaluate boolean operators with Python short-circuit semantics.""" + if isinstance(node.op, ast.And): + result = self.evaluate(node.values[0], context) + for value in node.values[1:]: + if not result: + return result + result = self.evaluate(value, context) + return result + if isinstance(node.op, ast.Or): + result = self.evaluate(node.values[0], context) + for value in node.values[1:]: + if result: + return result + result = self.evaluate(value, context) + return result + raise VectorStoreOperationException( + f"Boolean operator '{type(node.op).__name__}' is not allowed in filter expressions." + ) + + def _eval_UnaryOp(self, node: ast.UnaryOp, context: Mapping[str, Any]) -> Any: + """Evaluate a unary operator.""" + operand = self.evaluate(node.operand, context) + if isinstance(node.op, ast.Not): + return not operand + raise VectorStoreOperationException( + f"Unary operator '{type(node.op).__name__}' is not allowed in filter expressions." + ) + + def _eval_Compare(self, node: ast.Compare, context: Mapping[str, Any]) -> bool: + """Evaluate a comparison expression.""" + left = self.evaluate(node.left, context) + for operator_node, comparator in zip(node.ops, node.comparators, strict=True): + right = self.evaluate(comparator, context) + if not self._compare(operator_node, left, right): + return False + left = right + return True + + def _eval_BinOp(self, node: ast.BinOp, context: Mapping[str, Any]) -> Any: + """Evaluate a binary operator.""" + left = self.evaluate(node.left, context) + right = self.evaluate(node.right, context) + + if isinstance(node.op, ast.Add): + return self._safe_add(left, right) + if isinstance(node.op, ast.Sub): + return self._safe_numeric_operation(node.op, left, right, lambda a, b: a - b) + if isinstance(node.op, ast.Mult): + return self._safe_mult(left, right) + if isinstance(node.op, ast.Div): + return self._safe_numeric_operation(node.op, left, right, lambda a, b: a / b) + if isinstance(node.op, ast.Mod): + return self._safe_numeric_operation(node.op, left, right, lambda a, b: a % b) + if isinstance(node.op, ast.FloorDiv): + return self._safe_numeric_operation(node.op, left, right, lambda a, b: a // b) + + raise VectorStoreOperationException( + f"Binary operator '{type(node.op).__name__}' is not allowed in filter expressions." + ) + + def _eval_Call(self, node: ast.Call, context: Mapping[str, Any]) -> Any: + """Evaluate a function or method call.""" + args = [self.evaluate(arg, context) for arg in node.args] + + if isinstance(node.func, ast.Name): + try: + func = self._direct_call_functions[node.func.id] + except KeyError as e: + raise VectorStoreOperationException( + f"Function '{node.func.id}' is only supported as a method call in filter expressions." + ) from e + return func(*args) + + if isinstance(node.func, ast.Attribute): + target = self.evaluate(node.func.value, context) + if node.func.attr == "contains": + if len(args) != 1: + raise VectorStoreOperationException("Method 'contains' expects exactly one argument.") + return args[0] in target + + try: + func = getattr(target, node.func.attr) + except AttributeError as e: + raise VectorStoreOperationException( + f"Method '{node.func.attr}' is not available in filter expressions." + ) from e + + if not callable(func): + raise VectorStoreOperationException( + f"Attribute '{node.func.attr}' is not callable in filter expressions." + ) + return func(*args) + + raise VectorStoreOperationException( + f"Call target node type '{type(node.func).__name__}' is not allowed in filter expressions." + ) + + def _compare(self, operator_node: ast.AST, left: Any, right: Any) -> bool: + """Evaluate a comparison operator.""" + if isinstance(operator_node, ast.Eq): + return left == right + if isinstance(operator_node, ast.NotEq): + return left != right + if isinstance(operator_node, ast.Lt): + return left < right + if isinstance(operator_node, ast.LtE): + return left <= right + if isinstance(operator_node, ast.Gt): + return left > right + if isinstance(operator_node, ast.GtE): + return left >= right + if isinstance(operator_node, ast.In): + return left in right + if isinstance(operator_node, ast.NotIn): + return left not in right + if isinstance(operator_node, ast.Is): + return left is right + if isinstance(operator_node, ast.IsNot): + return left is not right + raise VectorStoreOperationException( + f"Comparison operator '{type(operator_node).__name__}' is not allowed in filter expressions." + ) + + def _safe_add(self, left: Any, right: Any) -> Any: + """Safely evaluate addition.""" + if isinstance(left, (int, float)) and isinstance(right, (int, float)): + return left + right + if isinstance(left, str) and isinstance(right, str): + return self._ensure_sequence_result_size(left, right, lambda a, b: a + b) + if isinstance(left, list) and isinstance(right, list): + return self._ensure_sequence_result_size(left, right, lambda a, b: a + b) + if isinstance(left, tuple) and isinstance(right, tuple): + return self._ensure_sequence_result_size(left, right, lambda a, b: a + b) + raise VectorStoreOperationException( + "Addition in filter expressions is only allowed for numeric values and same-type sequences." + ) + + def _safe_mult(self, left: Any, right: Any) -> Any: + """Safely evaluate multiplication.""" + if isinstance(left, (int, float)) and isinstance(right, (int, float)): + return left * right + if isinstance(left, int) and isinstance(right, (str, list, tuple)): + return self._safe_repeat(right, left) + if isinstance(right, int) and isinstance(left, (str, list, tuple)): + return self._safe_repeat(left, right) + raise VectorStoreOperationException( + "Multiplication in filter expressions is only allowed for numeric values and bounded sequence repetition." + ) + + def _safe_repeat(self, value: str | list[Any] | tuple[Any, ...], repeat_count: int) -> Any: + """Safely repeat a sequence.""" + if repeat_count <= 0 or len(value) == 0: + return value * repeat_count + if len(value) > self._max_sequence_repeat_size // repeat_count: + raise VectorStoreOperationException( + "Sequence repetition in filter expressions exceeds the maximum allowed size." + ) + return value * repeat_count + + def _safe_numeric_operation( + self, + operator_node: ast.AST, + left: Any, + right: Any, + operation: Callable[[float | int, float | int], Any], + ) -> Any: + """Safely evaluate a numeric binary operation.""" + if not isinstance(left, (int, float)) or not isinstance(right, (int, float)): + raise VectorStoreOperationException( + f"Operator '{type(operator_node).__name__}' is only allowed for numeric values in filter expressions." + ) + return operation(left, right) + + def _ensure_literal_collection_size(self, size: int) -> None: + """Reject excessively large literal collections.""" + if size > self._max_literal_collection_size: + raise VectorStoreOperationException( + "Collection literals in filter expressions exceed the maximum allowed size." + ) + + def _ensure_sequence_result_size( + self, + left: str | list[Any] | tuple[Any, ...], + right: str | list[Any] | tuple[Any, ...], + operation: Callable[[Any, Any], Any], + ) -> Any: + """Reject oversized sequence concatenation results.""" + if len(left) + len(right) > self._max_sequence_repeat_size: + raise VectorStoreOperationException( + "Sequence operations in filter expressions exceed the maximum allowed size." + ) + return operation(left, right) + + def _evaluate_optional(self, node: ast.AST | None, context: Mapping[str, Any]) -> Any: + """Evaluate an optional AST node.""" + return self.evaluate(node, context) if node is not None else None + class InMemoryCollection( VectorStoreCollection[TKey, TModel], @@ -118,6 +419,11 @@ class InMemoryCollection( inner_storage: dict[TKey, AttributeDict] = Field(default_factory=dict) supported_key_types: ClassVar[set[str] | None] = {"str", "int", "float"} supported_search_types: ClassVar[set[SearchType]] = {SearchType.VECTOR} + # Conservative defaults: callers can raise these per collection instance or subclass if needed. + max_filter_source_length: int = Field(default=2_048, exclude=True) + max_filter_ast_node_count: int = Field(default=128, exclude=True) + max_filter_literal_collection_size: int = Field(default=256, exclude=True) + max_filter_sequence_repeat_size: int = Field(default=1_024, exclude=True) # Allowlist of AST node types permitted in filter expressions. # This can be overridden in subclasses to extend or restrict allowed operations. @@ -148,7 +454,6 @@ class InMemoryCollection( ast.Load, ast.Attribute, ast.Subscript, - ast.Index, # For Python 3.8 compatibility ast.Slice, # Literals ast.Constant, @@ -192,6 +497,19 @@ class InMemoryCollection( "values", "items", } + direct_filter_functions: ClassVar[dict[str, Callable[..., Any]]] = { + "len": len, + "str": str, + "int": int, + "float": float, + "bool": bool, + "abs": abs, + "min": min, + "max": max, + "sum": sum, + "any": any, + "all": all, + } # Blocklist of dangerous attribute names that cannot be accessed in filter expressions. # These attributes can be used to escape the sandbox and execute arbitrary code. @@ -267,8 +585,17 @@ def __init__( > [Important] > Filters are powerful things, so make sure to not allow untrusted input here. - > Filters for this collection are parsed and evaluated using Python's `ast` module, so code might be executed. - > We only allow certain AST nodes and functions to be used in the filter expressions to mitigate security risks. + > Filters for this collection are parsed into Python's `ast` module and evaluated by a restricted interpreter. + > We only allow certain AST nodes and functions to be used in filter expressions, and we reject expressions + > that exceed reasonable size and complexity limits. + > + > The default filter limits are: + > - `max_filter_source_length=2048` + > - `max_filter_ast_node_count=128` + > - `max_filter_literal_collection_size=256` + > - `max_filter_sequence_repeat_size=1024` + > You can override these limits by passing them through `kwargs` or by setting them on the collection + > instance after initialization. """ super().__init__( @@ -419,6 +746,9 @@ def _parse_and_validate_filter(self, filter_str: str) -> Callable: are allowed. This can be customized by overriding `allowed_filter_ast_nodes` and `allowed_filter_functions` class attributes. """ + if len(filter_str) > self.max_filter_source_length: + raise VectorStoreOperationException("Filter string exceeds the maximum allowed length.") + try: tree = ast.parse(filter_str, mode="eval") except SyntaxError as e: @@ -433,9 +763,12 @@ def _parse_and_validate_filter(self, filter_str: str) -> Callable: # Get the lambda parameter name(s) to allow them as valid Name nodes lambda_node = tree.body lambda_param_names = {arg.arg for arg in lambda_node.args.args} - + lambda_param_order = [arg.arg for arg in lambda_node.args.args] # Walk the AST to validate all nodes against the allowlist - for node in ast.walk(tree): + for node_count, node in enumerate(ast.walk(tree), start=1): + if node_count > self.max_filter_ast_node_count: + raise VectorStoreOperationException("Filter expression exceeds the maximum allowed complexity.") + node_type = type(node) # Check if the node type is allowed @@ -477,16 +810,47 @@ def _parse_and_validate_filter(self, filter_str: str) -> Callable: f"Allowed functions: {', '.join(sorted(self.allowed_filter_functions))}" ) - try: - code = compile(tree, filename="", mode="eval") - func = eval(code, {"__builtins__": {}}, {}) # nosec - except Exception as e: - raise VectorStoreOperationException(f"Error compiling filter: {e}") from e + if ( + isinstance(node, (ast.List, ast.Tuple, ast.Set)) + and len(node.elts) > self.max_filter_literal_collection_size + ): + raise VectorStoreOperationException( + "Collection literals in filter expressions exceed the maximum allowed size." + ) - if not callable(func): - raise VectorStoreOperationException("Compiled filter is not callable.") + if isinstance(node, ast.Dict) and len(node.keys) > self.max_filter_literal_collection_size: + raise VectorStoreOperationException( + "Collection literals in filter expressions exceed the maximum allowed size." + ) + + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, str) + and len(node.value) > self.max_filter_literal_collection_size + ): + raise VectorStoreOperationException( + "String literals in filter expressions exceed the maximum allowed size." + ) + + evaluator = _SafeFilterEvaluator( + direct_call_functions=self.direct_filter_functions, + blocked_attributes=self.blocked_filter_attributes, + max_literal_collection_size=self.max_filter_literal_collection_size, + max_sequence_repeat_size=self.max_filter_sequence_repeat_size, + ) + + def filter_callable(*args: Any) -> Any: + if len(args) != len(lambda_param_order): + raise VectorStoreOperationException( + f"Filter expected {len(lambda_param_order)} argument(s), but received {len(args)}." + ) + context = { + name: ReadOnlyAttributeDict._wrap_value(value) + for name, value in zip(lambda_param_order, args, strict=True) + } + return evaluator.evaluate(lambda_node.body, context) - return func + return filter_callable def _run_filter(self, filter: Callable, record: AttributeDict[TAKey, TAValue]) -> bool: """Run the filter on the record, supporting attribute access.""" diff --git a/python/semantic_kernel/data/_shared.py b/python/semantic_kernel/data/_shared.py index 50c3cc0bc73e..52ea628b0015 100644 --- a/python/semantic_kernel/data/_shared.py +++ b/python/semantic_kernel/data/_shared.py @@ -168,9 +168,9 @@ def default_dynamic_filter_function( continue new_filter = None if param.name in kwargs: - new_filter = f"lambda x: x.{param.name} == '{kwargs[param.name]}'" + new_filter = f"lambda x: x.{param.name} == {_format_filter_literal(kwargs[param.name])}" elif param.default_value: - new_filter = f"lambda x: x.{param.name} == '{param.default_value}'" + new_filter = f"lambda x: x.{param.name} == {_format_filter_literal(param.default_value)}" if not new_filter: continue if filter is None: @@ -181,3 +181,8 @@ def default_dynamic_filter_function( filter = [filter, new_filter] return filter + + +def _format_filter_literal(value: Any) -> str: + """Format a value as a safe Python literal for filter strings.""" + return repr(value) diff --git a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py index 5e512d658c6d..617a097053f0 100644 --- a/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py +++ b/python/tests/unit/connectors/memory/azure_cosmos_db/test_azure_cosmos_db_no_sql_collection.py @@ -14,11 +14,13 @@ _create_default_indexing_policy_nosql, _create_default_vector_embedding_policy, ) +from semantic_kernel.data._shared import default_dynamic_filter_function from semantic_kernel.exceptions import ( VectorStoreInitializationException, VectorStoreModelException, VectorStoreOperationException, ) +from semantic_kernel.functions import KernelParameterMetadata def test_azure_cosmos_db_no_sql_collection_init( @@ -66,6 +68,66 @@ def test_azure_cosmos_db_no_sql_collection_init_env( assert vector_collection.create_database is False +def test_azure_cosmos_db_no_sql_build_filter_escapes_apostrophes( + azure_cosmos_db_no_sql_unit_test_env, + record_type, + collection_name: str, +) -> None: + """Test Cosmos DB filter building escapes apostrophes in string literals.""" + vector_collection = CosmosNoSqlCollection( + record_type=record_type, + collection_name=collection_name, + ) + + filter_string = vector_collection._build_filter('lambda x: x.content == "O\'Reilly"') + + assert filter_string == "c.content = 'O''Reilly'" + + +def test_azure_cosmos_db_no_sql_build_filter_escapes_injection_payload( + azure_cosmos_db_no_sql_unit_test_env, + record_type, + collection_name: str, +) -> None: + """Test Cosmos DB filter building keeps injection-shaped strings inside the literal.""" + vector_collection = CosmosNoSqlCollection( + record_type=record_type, + collection_name=collection_name, + ) + + filter_string = vector_collection._build_filter("lambda x: x.content == \"test' OR '1'='1\"") + + assert filter_string == "c.content = 'test'' OR ''1''=''1'" + + +def test_azure_cosmos_db_no_sql_dynamic_filter_injection_payload_stays_literal( + azure_cosmos_db_no_sql_unit_test_env, + record_type, + collection_name: str, +) -> None: + """Test default_dynamic_filter_function does not let user values alter Cosmos filter syntax.""" + vector_collection = CosmosNoSqlCollection( + record_type=record_type, + collection_name=collection_name, + ) + generated_filter = default_dynamic_filter_function( + filter=None, + parameters=[ + KernelParameterMetadata( + name="content", + description="Content filter", + type="str", + is_required=False, + type_object=str, + ) + ], + content="test' OR '1'='1", + ) + + assert isinstance(generated_filter, str) + assert vector_collection._build_filter(generated_filter) == "c.content = 'test'' OR ''1''=''1'" + + @pytest.mark.parametrize("exclude_list", [["AZURE_COSMOS_DB_NO_SQL_URL"]], indirect=True) def test_azure_cosmos_db_no_sql_collection_init_no_url( azure_cosmos_db_no_sql_unit_test_env, diff --git a/python/tests/unit/connectors/memory/test_in_memory.py b/python/tests/unit/connectors/memory/test_in_memory.py index 76b7fc239912..3c6705c3b28d 100644 --- a/python/tests/unit/connectors/memory/test_in_memory.py +++ b/python/tests/unit/connectors/memory/test_in_memory.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +import ast + from pytest import fixture, mark, raises from semantic_kernel.connectors.in_memory import InMemoryCollection, InMemoryStore @@ -203,6 +205,32 @@ async def test_valid_lambda_filter_with_get_method(collection): assert "1" in results +async def test_valid_lambda_filter_with_bounded_sequence_repeat(collection): + record = {"id": "1", "vector": [1, 2, 3, 4, 5]} + await collection.upsert(record) + + results = collection._get_filtered_records(type("opt", (), {"filter": "lambda x: ([0] * 2)[1] == 0"})()) + + assert len(results) == 1 + assert "1" in results + + +async def test_sequence_repeat_limit_can_be_overridden(collection): + record = {"id": "1", "vector": [1, 2, 3, 4, 5]} + await collection.upsert(record) + filter_options = type("opt", (), {"filter": "lambda x: ([0] * 2)[1] == 0"})() + + collection.max_filter_sequence_repeat_size = 1 + with raises(VectorStoreOperationException, match="Sequence repetition in filter expressions exceeds the maximum"): + collection._get_filtered_records(filter_options) + + collection.max_filter_sequence_repeat_size = 2 + results = collection._get_filtered_records(filter_options) + + assert len(results) == 1 + assert "1" in results + + async def test_callable_filter_cannot_mutate_stored_record(collection): record = {"id": "1", "content": "value", "vector": [1, 2, 3, 4, 5]} await collection.upsert(record) @@ -218,7 +246,7 @@ def mutating_filter(x): assert collection.inner_storage["1"]["content"] == "value" -def test_default_dynamic_filter_injection_payload_is_blocked(collection): +def test_default_dynamic_filter_injection_payload_remains_string_literal(collection): class Param: def __init__(self, name, default_value=None): self.name = name @@ -232,5 +260,19 @@ def __init__(self, name, default_value=None): ) assert isinstance(generated_filter, str) - with raises(VectorStoreOperationException, match="Call target node type 'Subscript' is not allowed"): - collection._parse_and_validate_filter(generated_filter) + tree = ast.parse(generated_filter, mode="eval") + assert isinstance(tree.body, ast.Lambda) + assert isinstance(tree.body.body, ast.Compare) + assert isinstance(tree.body.body.comparators[0], ast.Constant) + assert tree.body.body.comparators[0].value == injected_value + + filter_func = collection._parse_and_validate_filter(generated_filter) + assert filter_func({"category": "finance", "name": "alice", "vector": [0.1] * 5}) is False + + +async def test_large_sequence_repeat_filter_is_blocked(collection): + record = {"id": "1", "content": "value", "vector": [0.1, 0.2, 0.3, 0.4, 0.5]} + await collection.upsert(record) + + with raises(VectorStoreOperationException, match="Sequence repetition in filter expressions exceeds the maximum"): + collection._get_filtered_records(type("opt", (), {"filter": "lambda x: [0] * 2000000000"})()) From b1d22cda4f1ac0b95def2e41eca54a2acb076caa Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Mar 2026 13:57:39 +0100 Subject: [PATCH 5/6] added trace to exception when bedrock creation fails --- python/semantic_kernel/agents/bedrock/bedrock_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/semantic_kernel/agents/bedrock/bedrock_agent.py b/python/semantic_kernel/agents/bedrock/bedrock_agent.py index 12c52dadd5e5..13d356af4e5b 100644 --- a/python/semantic_kernel/agents/bedrock/bedrock_agent.py +++ b/python/semantic_kernel/agents/bedrock/bedrock_agent.py @@ -216,7 +216,7 @@ async def create_and_prepare_agent( env_file_encoding=env_file_encoding, ) except ValidationError as e: - raise AgentInitializationException("Failed to initialize the Amazon Bedrock Agent settings.") from e + raise AgentInitializationException(f"Failed to initialize the Amazon Bedrock Agent settings: {e}") from e import boto3 from botocore.exceptions import ClientError @@ -237,7 +237,7 @@ async def create_and_prepare_agent( ) except ClientError as e: logger.error(f"Failed to create agent {name}.") - raise AgentInitializationException("Failed to create the Amazon Bedrock Agent.") from e + raise AgentInitializationException(f"Failed to create the Amazon Bedrock Agent: {e}") from e bedrock_agent = cls( response["agent"], From b3af0e8197a1b643089d8addd2754a5fcc5ee711 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 12 Mar 2026 18:30:31 +0100 Subject: [PATCH 6/6] fix mypy --- .../unit/utils/model_diagnostics/test_trace_chat_completion.py | 2 -- .../model_diagnostics/test_trace_streaming_chat_completion.py | 2 -- .../model_diagnostics/test_trace_streaming_text_completion.py | 2 -- .../unit/utils/model_diagnostics/test_trace_text_completion.py | 2 -- 4 files changed, 8 deletions(-) diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py index 10b69aa52aa0..8b322a72a52b 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_chat_completion.py @@ -94,7 +94,6 @@ async def test_trace_chat_completion( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object(MockChatCompletion, "_inner_get_chat_message_contents", return_value=mock_response): # We need to reapply the decorator to the method since the mock will not have the decorator applied @@ -175,7 +174,6 @@ async def test_trace_chat_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object(MockChatCompletion, "_inner_get_chat_message_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py index b3b520e5e165..8b43818a6452 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_chat_completion.py @@ -100,7 +100,6 @@ async def test_trace_streaming_chat_completion( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value iterable = MagicMock(spec=AsyncGenerator) iterable.__aiter__.return_value = [mock_response] @@ -190,7 +189,6 @@ async def test_trace_streaming_chat_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup chat_completion: ChatCompletionClientBase = MockChatCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object( MockChatCompletion, "_inner_get_streaming_chat_message_contents", side_effect=ServiceResponseException() diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py index 725977d145df..f270192c7930 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_streaming_text_completion.py @@ -81,7 +81,6 @@ async def test_trace_streaming_text_completion( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value iterable = MagicMock(spec=AsyncGenerator) iterable.__aiter__.return_value = [mock_response] @@ -156,7 +155,6 @@ async def test_trace_streaming_text_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_streaming_text_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied diff --git a/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py b/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py index 379d62586221..e94486071897 100644 --- a/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py +++ b/python/tests/unit/utils/model_diagnostics/test_trace_text_completion.py @@ -77,7 +77,6 @@ async def test_trace_text_completion( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_text_contents", return_value=mock_response): # We need to reapply the decorator to the method since the mock will not have the decorator applied @@ -149,7 +148,6 @@ async def test_trace_text_completion_exception( mock_span = mock_tracer.start_span.return_value # Setup text_completion: TextCompletionClientBase = MockTextCompletion(ai_model_id="ai_model_id") - mock_span = mock_start_span.return_value with patch.object(MockTextCompletion, "_inner_get_text_contents", side_effect=ServiceResponseException()): # We need to reapply the decorator to the method since the mock will not have the decorator applied