Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@ async def handle_intermediate_steps(message: ChatMessageContent) -> None:
async def main():
credential = AzureCliCredential()

# Define the resources directory for file uploads
resources_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "resources")

# 1. Create the python code interpreter tool using the SessionsPythonTool
python_code_interpreter = SessionsPythonTool(credential=credential)
# allowed_upload_directories restricts which local directories can be accessed for uploads
python_code_interpreter = SessionsPythonTool(
credential=credential,
allowed_upload_directories=[resources_dir],
)

# 2. Create the agent
agent = ChatCompletionAgent(
Expand All @@ -41,7 +48,7 @@ async def main():
)

# 3. Upload a CSV file to the session
csv_file_path = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "resources", "sales.csv")
csv_file_path = os.path.join(resources_dir, "sales.csv")
file_metadata = await python_code_interpreter.upload_file(local_file_path=csv_file_path)

# 4. Invoke the agent for a response to a task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class SessionsPythonTool(KernelBaseModel):
settings: SessionsPythonSettings
auth_callback: Callable[..., Any | Awaitable[Any]]
http_client: AsyncClient
allowed_upload_directories: set[str] | None = None
"""Allowed local directories for file uploads. If None, upload_file is disabled (deny-by-default)."""
allowed_download_directories: set[str] | None = None
"""Allowed local directories for file downloads. If None, all paths are allowed (permissive-by-default)."""
Comment thread
eavanvalkenburg marked this conversation as resolved.

def __init__(
self,
Expand All @@ -48,9 +52,28 @@ def __init__(
env_file_path: str | None = None,
token_endpoint: str | None = None,
credential: TokenCredential | None = None,
allowed_upload_directories: set[str] | list[str] | None = None,
allowed_download_directories: set[str] | list[str] | None = None,
**kwargs,
):
"""Initializes a new instance of the SessionsPythonTool class."""
"""Initializes a new instance of the SessionsPythonTool class.

Args:
auth_callback: Callback to retrieve authentication token.
pool_management_endpoint: The ACA pool management endpoint URL.
settings: Python session settings.
http_client: HTTP client for making requests.
env_file_path: Path to .env file.
token_endpoint: Token endpoint for authentication.
credential: Azure credential for authentication.
allowed_upload_directories: Set or list of allowed directories for file uploads.
If None, upload_file will be disabled (deny-by-default).
Empty set/list means no directories are allowed (all uploads denied).
allowed_download_directories: Set or list of allowed directories for file downloads.
If None, all paths are allowed (permissive-by-default).
Empty set/list means no directories are allowed (all local downloads denied).
kwargs: Additional keyword arguments.
"""
try:
aca_settings = ACASessionsSettings(
env_file_path=env_file_path,
Expand All @@ -70,11 +93,19 @@ def __init__(
if auth_callback is None:
auth_callback = self._default_auth_callback(aca_settings, credential)

# Convert lists to sets and filter out empty strings (which resolve to CWD)
upload_dirs = {d for d in allowed_upload_directories if d} if allowed_upload_directories is not None else None
download_dirs = (
{d for d in allowed_download_directories if d} if allowed_download_directories is not None else None
)

super().__init__(
pool_management_endpoint=aca_settings.pool_management_endpoint,
settings=settings,
auth_callback=auth_callback,
http_client=http_client,
allowed_upload_directories=upload_dirs,
allowed_download_directories=download_dirs,
**kwargs,
)

Expand Down Expand Up @@ -145,6 +176,74 @@ def _build_url_with_version(self, base_url, endpoint, params):
endpoint = endpoint[:-1]
return f"{base_url}{endpoint}?{query_string}"

def _validate_local_path_for_upload(self, local_file_path: str) -> str:
    """Resolve an upload source path and confirm it lies under an allowed directory.

    Uploads are deny-by-default: when ``allowed_upload_directories`` is None the
    method refuses every path. Otherwise the path is canonicalized with
    ``os.path.realpath`` and accepted as soon as one allowed directory
    (canonicalized the same way) is a path-prefix of it.

    Args:
        local_file_path: The local path requested for upload.

    Returns:
        str: The canonicalized form of ``local_file_path``.

    Raises:
        FunctionExecutionException: If uploads are disabled, or the resolved path
            is not inside any allowed upload directory.
    """
    permitted_roots = self.allowed_upload_directories
    if permitted_roots is None:
        raise FunctionExecutionException(
            "File upload is disabled. Configure 'allowed_upload_directories' to enable."
        )

    resolved = os.path.realpath(local_file_path)

    def _contains(root: str) -> bool:
        # commonpath compares whole path components, so "/data" does not match
        # "/data_private". It raises ValueError when the two paths cannot share a
        # prefix at all (e.g. different drives on Windows) - treat that as "no".
        try:
            return os.path.commonpath([root, resolved]) == root
        except ValueError:
            return False

    if any(_contains(os.path.realpath(candidate)) for candidate in permitted_roots):
        return resolved

    logger.warning(f"Upload denied for path: {local_file_path} (resolved: {resolved})")
    raise FunctionExecutionException(
        f"Access denied: '{local_file_path}' is not within allowed upload directories."
    )

def _validate_local_path_for_download(self, local_file_path: str) -> str:
    """Validate that a download destination is within the allowed download directories.

    Downloads are permissive by default: when ``allowed_download_directories`` is
    None every path is accepted. When it is configured, the *fully resolved*
    destination must have its parent directory inside one of the allowed
    directories.

    The complete path - including its final component - is canonicalized with
    ``os.path.realpath``. Resolving only the parent directory would let a symlink
    that lives inside an allowed directory but points elsewhere pass validation,
    and the subsequent write would follow it out of the allow-list.
    ``realpath`` does not require the destination to exist, so not-yet-created
    files are handled.

    Args:
        local_file_path: The local path the downloaded file should be written to.

    Returns:
        str: The canonicalized absolute path (symlinks resolved).

    Raises:
        FunctionExecutionException: If allowed_download_directories is set and the
            resolved path is not within any of them.
    """
    canonical_path = os.path.realpath(local_file_path)

    # Permissive by default - if no restrictions configured, allow all paths
    if self.allowed_download_directories is None:
        return canonical_path

    # Check the parent of the resolved file: the destination has to be a file
    # *inside* an allowed directory, never the allowed directory itself.
    canonical_parent = os.path.dirname(canonical_path)

    for allowed_dir in self.allowed_download_directories:
        allowed_canonical = os.path.realpath(allowed_dir)
        try:
            if os.path.commonpath([allowed_canonical, canonical_parent]) == allowed_canonical:
                return canonical_path
        except ValueError:
            continue  # No common prefix possible (e.g. different drives on Windows)

    logger.warning(f"Download denied for path: {local_file_path}")
    raise FunctionExecutionException(
        f"Access denied: '{local_file_path}' is not within allowed download directories."
    )

# endregion

# region Kernel Functions
Expand Down Expand Up @@ -230,17 +329,21 @@ async def upload_file(
Args:
remote_file_path (str): The path to the file in the session.
local_file_path (str): The path to the file on the local machine.
Must be within allowed_upload_directories.

Returns:
RemoteFileMetadata: The metadata of the uploaded file.

Raises:
FunctionExecutionException: If local_file_path is not provided.
FunctionExecutionException: If local_file_path is not provided or not in allowed directories.
"""
if not local_file_path:
raise FunctionExecutionException("Please provide a local file path to upload.")

remote_file_path = self._construct_remote_file_path(remote_file_path or os.path.basename(local_file_path))
# Validate path is in allowed directories (deny-by-default)
validated_path = self._validate_local_path_for_upload(local_file_path)

remote_file_path = self._construct_remote_file_path(remote_file_path or os.path.basename(validated_path))

auth_token = await self._ensure_auth_token()
self.http_client.headers.update({
Expand All @@ -255,7 +358,7 @@ async def upload_file(
)

try:
with open(local_file_path, "rb") as data:
with open(validated_path, "rb") as data:
files = {"file": (remote_file_path, data, "application/octet-stream")}
response = await self.http_client.post(url=url, files=files)
response.raise_for_status()
Expand Down Expand Up @@ -312,10 +415,14 @@ async def download_file(
Args:
remote_file_name: The name of the file to download, relative to `/mnt/data`.
local_file_path: The path to save the downloaded file to. Should include the extension.
If not provided, the file is returned as a BufferedReader.
If not provided, the file is returned as a BytesIO object.
If allowed_download_directories is configured, must be within those directories.

Returns:
BufferedReader: The data of the downloaded file.
BytesIO | None: The file content as BytesIO if no local_file_path provided, otherwise None.

Raises:
FunctionExecutionException: If local_file_path is not in allowed directories (when configured).
"""
auth_token = await self._ensure_auth_token()
self.http_client.headers.update({
Expand All @@ -335,7 +442,9 @@ async def download_file(
)
response.raise_for_status()
if local_file_path:
with open(local_file_path, "wb") as f:
# Validate path is in allowed directories (optional, permissive by default)
validated_path = self._validate_local_path_for_download(local_file_path)
with open(validated_path, "wb") as f:
f.write(response.content)
Comment thread
moonbox3 marked this conversation as resolved.
return None

Expand Down
Loading
Loading