Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import os
from pathlib import Path

from semantic_kernel.agents import BedrockAgent, BedrockAgentThread
from semantic_kernel.contents.binary_content import BinaryContent
Expand Down Expand Up @@ -57,8 +57,16 @@ async def main():
if not binary_item:
raise RuntimeError("No chart generated")

file_path = os.path.join(os.path.dirname(__file__), binary_item.metadata["name"])
binary_item.write_to_file(os.path.join(os.path.dirname(__file__), binary_item.metadata["name"]))
# Securely assemble the file path and validate it's within the expected directory
# This is a defense-in-depth measure against directory traversal attacks
output_dir = Path(__file__).parent.resolve()
file_path = (output_dir / binary_item.metadata["name"]).resolve()

# Verify the resolved path is within the expected directory
if not file_path.is_relative_to(output_dir):
raise RuntimeError("Invalid filename: would write outside the expected directory")

binary_item.write_to_file(file_path)
print(f"Chart saved to {file_path}")

# Sample output (using anthropic.claude-3-haiku-20240307-v1:0):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import os
from pathlib import Path

from semantic_kernel.agents import BedrockAgent, BedrockAgentThread
from semantic_kernel.contents.binary_content import BinaryContent
Expand Down Expand Up @@ -59,8 +59,16 @@ async def main():
if not binary_item:
raise RuntimeError("No chart generated")

file_path = os.path.join(os.path.dirname(__file__), binary_item.metadata["name"])
binary_item.write_to_file(os.path.join(os.path.dirname(__file__), binary_item.metadata["name"]))
# Securely assemble the file path and validate it's within the expected directory
# This is a defense-in-depth measure against directory traversal attacks
output_dir = Path(__file__).parent.resolve()
file_path = (output_dir / binary_item.metadata["name"]).resolve()

# Verify the resolved path is within the expected directory
if not file_path.is_relative_to(output_dir):
raise RuntimeError("Invalid filename: would write outside the expected directory")

binary_item.write_to_file(file_path)
print(f"Chart saved to {file_path}")

# Sample output (using anthropic.claude-3-haiku-20240307-v1:0):
Expand Down
27 changes: 25 additions & 2 deletions python/semantic_kernel/agents/bedrock/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import logging
import os
import sys
from collections.abc import AsyncIterable, Awaitable, Callable
from functools import partial, reduce
Expand Down Expand Up @@ -590,7 +591,7 @@ def _handle_files_event(self, event: dict[str, Any]) -> list[BinaryContent]:
data=file["bytes"],
data_format="base64",
mime_type=file["type"],
metadata={"name": file["name"]},
metadata={"name": self._sanitize_filename(file["name"])},
)
for file in files_event["files"]
]
Expand Down Expand Up @@ -639,7 +640,7 @@ def _handle_streaming_files_event(self, event: dict[str, Any]) -> StreamingChatM
data=file["bytes"],
data_format="base64",
mime_type=file["type"],
metadata={"name": file["name"]},
metadata={"name": self._sanitize_filename(file["name"])},
)
for file in files_event["files"]
]
Expand Down Expand Up @@ -720,3 +721,25 @@ async def _notify_thread_of_new_message(self, thread, new_message):
The new message is passed to the agent when invoking the agent.
"""
pass

@staticmethod
def _sanitize_filename(filename: str) -> str:
    """Reduce a filename to a bare name with no directory components.

    Backslashes are treated as path separators in addition to forward slashes,
    so both Unix- and Windows-style traversal attempts are stripped. Null bytes
    are removed as well. A warning is logged whenever the input was altered.

    Note:
        The special names "." and ".." contain no separator and are returned
        unchanged, and the result may be an empty string (e.g. for "../").
        Callers that join the result onto a directory should still verify that
        the resolved path stays inside that directory.

    Args:
        filename: The filename to sanitize.

    Returns:
        The sanitized filename with directory components removed.
    """
    # Normalize Windows separators first so basename() also splits on them.
    sanitized = os.path.basename(filename.replace("\\", "/"))
    # Defense in depth: drop any separator that survived, plus null bytes.
    result = sanitized.replace("/", "").replace("\\", "").replace("\x00", "")
    if result != filename:
        # %r escapes control characters so a crafted name cannot forge extra log
        # lines; lazy arguments skip formatting when WARNING is disabled.
        logger.warning(
            "Filename contained potentially malicious path components and was sanitized: %r -> %r",
            filename,
            result,
        )
    return result
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ def __init__(self, diagnostics_settings: ModelDiagnosticSettings | None = None)
Args:
diagnostics_settings (ModelDiagnosticSettings, optional): Model diagnostics settings. Defaults to None.
"""
settings.tracing_implementation = "opentelemetry"
super().__init__(diagnostics_settings=diagnostics_settings or ModelDiagnosticSettings())
# Only set tracing implementation when diagnostics is enabled to avoid
# interfering with method mocking in tests
if (
self.diagnostics_settings.enable_otel_diagnostics
or self.diagnostics_settings.enable_otel_diagnostics_sensitive
):
settings.tracing_implementation = "opentelemetry"

def __enter__(self) -> None:
"""Enable tracing.
Expand Down
17 changes: 15 additions & 2 deletions python/semantic_kernel/contents/binary_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,21 @@ def from_element(cls: type[_T], element: Element) -> _T:

return cls(uri=element.get("uri", None))

def write_to_file(self, path: str | FilePath) -> None:
"""Write the data to a file."""
def write_to_file(self, path: str | FilePath, *, overwrite: bool = False) -> None:
"""Write the data to a file.

Args:
path: The path to write the file to.
overwrite: If True, overwrite existing files. If False, raise an error if file exists.
Defaults to False.

Raises:
FileExistsError: If overwrite is False and the file already exists.
"""
file_path = Path(path)
if not overwrite and file_path.exists():
raise FileExistsError(f"File already exists and overwrite is disabled: {path}")

if self._data_uri and self._data_uri.data_array is not None:
self._data_uri.data_array.tofile(path)
return
Expand Down
78 changes: 78 additions & 0 deletions python/tests/unit/agents/bedrock_agent/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,81 @@ async def test_bedrock_agent_invoke_stream_with_function_call(


# endregion


# region Filename Sanitization Tests


def test_sanitize_filename_simple():
    """A plain filename without path components is returned unchanged."""
    sanitized = BedrockAgent._sanitize_filename("file.txt")
    assert sanitized == "file.txt"


def test_sanitize_filename_with_spaces():
    """A filename containing spaces is returned unchanged."""
    sanitized = BedrockAgent._sanitize_filename("my file.txt")
    assert sanitized == "my file.txt"


def test_sanitize_filename_directory_traversal_unix():
    """Unix-style relative and absolute directory prefixes are stripped."""
    expected_by_input = {
        "../../../etc/passwd": "passwd",
        "../../file.txt": "file.txt",
        "/etc/passwd": "passwd",
    }
    for raw_name, expected in expected_by_input.items():
        assert BedrockAgent._sanitize_filename(raw_name) == expected


def test_sanitize_filename_directory_traversal_windows():
    """Windows-style relative, drive-rooted and UNC prefixes are stripped."""
    expected_by_input = {
        "..\\..\\..\\Windows\\System32\\config": "config",
        "C:\\Users\\file.txt": "file.txt",
        "\\\\server\\share\\file.txt": "file.txt",
    }
    for raw_name, expected in expected_by_input.items():
        assert BedrockAgent._sanitize_filename(raw_name) == expected


def test_sanitize_filename_mixed_separators():
    """Paths mixing forward slashes and backslashes reduce to the final name."""
    for raw_name in ("../path\\to/file.txt", "..\\path/to\\file.txt"):
        assert BedrockAgent._sanitize_filename(raw_name) == "file.txt"


def test_sanitize_filename_null_byte():
    """Embedded null bytes are removed from the filename."""
    expected_by_input = {
        "file\x00.txt": "file.txt",
        "file.txt\x00.exe": "file.txt.exe",
    }
    for raw_name, expected in expected_by_input.items():
        assert BedrockAgent._sanitize_filename(raw_name) == expected


def test_sanitize_filename_empty():
    """Inputs with no name after the last separator sanitize to an empty string."""
    for raw_name in ("", "../", "..\\"):
        assert BedrockAgent._sanitize_filename(raw_name) == ""


def test_sanitize_filename_only_dots():
    """Test _sanitize_filename leaves the dot-only names "." and ".." unchanged."""
    # os.path.basename() returns "." and ".." as-is because they contain no separator.
    # Only "../" or "..\" patterns get stripped to empty string.
    assert BedrockAgent._sanitize_filename(".") == "."
    assert BedrockAgent._sanitize_filename("..") == ".."


def test_sanitize_filename_logs_warning(caplog):
    """Test _sanitize_filename logs a warning naming both the original and sanitized filename."""
    import logging

    with caplog.at_level(logging.WARNING):
        result = BedrockAgent._sanitize_filename("../malicious/file.txt")
        assert result == "file.txt"
        assert "potentially malicious path components" in caplog.text
        assert "../malicious/file.txt" in caplog.text
        # A bare "file.txt" check is already satisfied by the original path above,
        # so match the arrow-delimited tail to prove the sanitized name is reported.
        assert "-> 'file.txt'" in caplog.text


def test_sanitize_filename_no_warning_for_clean_filename(caplog):
    """A filename that needs no sanitization must not produce the warning."""
    import logging

    clean_name = "clean_file.txt"
    with caplog.at_level(logging.WARNING):
        sanitized = BedrockAgent._sanitize_filename(clean_name)
        assert sanitized == "clean_file.txt"
        assert "potentially malicious" not in caplog.text


# endregion
1 change: 1 addition & 0 deletions python/tests/unit/agents/chat_completion/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ def kernel_with_ai_service():
mock_ai_service_client.get_chat_message_contents = AsyncMock(
return_value=[ChatMessageContent(role=AuthorRole.SYSTEM, content="Processed Message")]
)
kernel.plugins = {} # Ensure plugins dict is initialized to avoid AttributeError during tests

return kernel, mock_ai_service_client
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ async def mock_get_chat_message_contents(
arguments: KernelArguments,
):
responses = [
ChatMessageContent(role=AuthorRole.TOOL, content="Tool Call Result"),
ChatMessageContent(
role=AuthorRole.TOOL,
items=[FunctionResultContent(result="Tool Call Result")],
),
]
chat_history.messages.extend(responses)
return responses
Expand All @@ -289,7 +292,7 @@ async def mock_get_chat_message_contents(
messages = [message async for message in agent.invoke(messages="test", thread=thread)]

assert len(messages) == 1
assert messages[0].message.content == "Tool Call Result"
assert messages[0].message.items[0].result == "Tool Call Result"
assert messages[0].message.role == AuthorRole.TOOL
assert messages[0].message.name == "TestAgent"

Expand All @@ -298,7 +301,7 @@ async def mock_get_chat_message_contents(

assert len(thread_messages) == 2
assert thread_messages[0].content == "test"
assert thread_messages[1].content == "Tool Call Result"
assert thread_messages[1].items[0].result == "Tool Call Result"
assert thread_messages[1].name == "TestAgent"
assert thread_messages[1].role == AuthorRole.TOOL

Expand Down
17 changes: 14 additions & 3 deletions python/tests/unit/connectors/ai/azure_ai_inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ def model_diagnostics_test_env(monkeypatch, exclude_list, override_env_param_dic
return env_vars


@pytest.fixture()
def disabled_model_diagnostics_test_env(monkeypatch):
    """Switch off both OTel diagnostics environment flags for tests that use mocking.

    AIInferenceInstrumentor's instrument/uninstrument cycle interferes with
    class-level mocking of ChatCompletionsClient.complete, so such tests
    request this fixture to keep diagnostics disabled.
    """
    diagnostics_flags = (
        "SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS",
        "SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS_SENSITIVE",
    )
    for flag_name in diagnostics_flags:
        monkeypatch.setenv(flag_name, "false")


@pytest.fixture(scope="function")
def azure_ai_inference_client(azure_ai_inference_unit_test_env, request) -> ChatCompletionsClient | EmbeddingsClient:
"""Fixture to create Azure AI Inference client for unit tests."""
Expand Down Expand Up @@ -164,8 +175,8 @@ def mock_azure_ai_inference_chat_completion_response_with_tool_call(model_id) ->
ChatCompletionsToolCall(
id="test_id",
function=FunctionCall(
name="test_function",
arguments={"test_arg": "test_value"},
name="getLightStatus",
arguments='{"arg1": "test_value"}',
),
),
],
Expand Down Expand Up @@ -254,7 +265,7 @@ def mock_azure_ai_inference_streaming_chat_completion_response_with_tool_call(mo
id="test_id",
function=FunctionCall(
name="getLightStatus",
arguments={"arg1": "test_value"},
arguments='{"arg1": "test_value"}',
),
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,12 @@ async def test_azure_ai_inference_chat_completion_with_function_choice_behavior_
async def test_azure_ai_inference_chat_completion_with_function_choice_behavior(
mock_complete,
azure_ai_inference_service,
kernel,
kernel: Kernel,
chat_history: ChatHistory,
mock_azure_ai_inference_chat_completion_response_with_tool_call,
mock_azure_ai_inference_chat_completion_response,
decorated_native_function,
disabled_model_diagnostics_test_env,
) -> None:
"""Test completion of AzureAIInferenceChatCompletion with function choice behavior"""
user_message_content: str = "Hello"
Expand All @@ -284,7 +287,13 @@ async def test_azure_ai_inference_chat_completion_with_function_choice_behavior(
)
settings.function_choice_behavior.maximum_auto_invoke_attempts = 1

mock_complete.return_value = mock_azure_ai_inference_chat_completion_response_with_tool_call
# First call returns tool call, second call returns final response
mock_complete.side_effect = [
mock_azure_ai_inference_chat_completion_response_with_tool_call,
mock_azure_ai_inference_chat_completion_response,
]

kernel.add_function(plugin_name="TestPlugin", function=decorated_native_function)

responses = await azure_ai_inference_service.get_chat_message_contents(
chat_history=chat_history,
Expand All @@ -293,13 +302,13 @@ async def test_azure_ai_inference_chat_completion_with_function_choice_behavior(
arguments=KernelArguments(),
)

# The function should be called twice:
# Completion should be called twice:
# One for the tool call and one for the last completion
# after the maximum_auto_invoke_attempts is reached
assert mock_complete.call_count == 2
assert len(responses) == 1
assert responses[0].role == "assistant"
assert responses[0].finish_reason == FinishReason.TOOL_CALLS
assert responses[0].content == "Hello"


@pytest.mark.parametrize(
Expand Down
Loading
Loading