Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,30 @@ def execute(

def run_subprocess(self, command, directory, input, executable, non_zero_throw=True):
# Start a child process and start reading output immediately
if isinstance(command, str):
import warnings

warnings.warn(
"String commands are deprecated; use list form for safety",
DeprecationWarning,
stacklevel=2,
)
use_shell = True
else:
use_shell = False
if executable is not None:
raise TypeError(
"executable is not supported with list-form commands; "
"include the binary as the first element of the list"
)

with subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if input else None,
cwd=directory,
shell=True,
shell=use_shell,
executable=executable,
) as process:
if input:
Expand Down
45 changes: 39 additions & 6 deletions agent/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import shutil
import tarfile
import time
from datetime import datetime
from shlex import quote
Expand Down Expand Up @@ -136,6 +137,36 @@ def restore_site(
finally:
self.bench.drop_mariadb_user(self.name, mariadb_root_password, self.database)

@staticmethod
def _safe_extract_tar(path: str, dest: str, strip: int = 0):
"""Extract a tar archive safely using Python tarfile.

Strips ``strip`` leading path components from each member.
Rejects symlinks, hardlinks, absolute paths, and parent-traversal.
"""
with tarfile.open(path) as tar:
members = tar.getmembers()
valid = []
for member in members:
if member.issym() or member.islnk():
raise tarfile.ExtractError(f"Refusing to extract link: {member.name}")
if os.path.isabs(member.name):
raise tarfile.ExtractError(f"Refusing absolute path: {member.name}")
parts = member.name.split("/")
if ".." in parts:
raise tarfile.ExtractError(f"Refusing parent traversal: {member.name}")
if strip:
stripped = "/".join(parts[strip:])
if not stripped:
continue
member.name = stripped
dest_real = os.path.realpath(dest)
target = os.path.realpath(os.path.join(dest, member.name))
if not target.startswith(dest_real + os.sep):
raise tarfile.ExtractError(f"Refusing path outside destination: {member.name}")
valid.append(member)
tar.extractall(path=dest, members=valid)
Comment on lines +147 to +168

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 dest_real is recomputed via os.path.realpath on every loop iteration. Since dest never changes within the loop, this should be moved above the loop to avoid repeated syscalls proportional to archive member count.

Suggested change
with tarfile.open(path) as tar:
members = tar.getmembers()
valid = []
for member in members:
if member.issym() or member.islnk():
raise tarfile.ExtractError(f"Refusing to extract link: {member.name}")
if os.path.isabs(member.name):
raise tarfile.ExtractError(f"Refusing absolute path: {member.name}")
parts = member.name.split("/")
if ".." in parts:
raise tarfile.ExtractError(f"Refusing parent traversal: {member.name}")
if strip:
stripped = "/".join(parts[strip:])
if not stripped:
continue
member.name = stripped
dest_real = os.path.realpath(dest)
target = os.path.realpath(os.path.join(dest, member.name))
if not target.startswith(dest_real + os.sep):
raise tarfile.ExtractError(f"Refusing path outside destination: {member.name}")
valid.append(member)
tar.extractall(path=dest, members=valid)
with tarfile.open(path) as tar:
members = tar.getmembers()
valid = []
dest_real = os.path.realpath(dest)
for member in members:
if member.issym() or member.islnk():
raise tarfile.ExtractError(f"Refusing to extract link: {member.name}")
if os.path.isabs(member.name):
raise tarfile.ExtractError(f"Refusing absolute path: {member.name}")
parts = member.name.split("/")
if ".." in parts:
raise tarfile.ExtractError(f"Refusing parent traversal: {member.name}")
if strip:
stripped = "/".join(parts[strip:])
if not stripped:
continue
member.name = stripped
target = os.path.realpath(os.path.join(dest, member.name))
if not target.startswith(dest_real + os.sep):
raise tarfile.ExtractError(f"Refusing path outside destination: {member.name}")
valid.append(member)
tar.extractall(path=dest, members=valid)
Prompt To Fix With AI
This is a comment left during a code review.
Path: agent/site.py
Line: 147-168

Comment:
`dest_real` is recomputed via `os.path.realpath` on every loop iteration. Since `dest` never changes within the loop, this should be moved above the loop to avoid repeated syscalls proportional to archive member count.

```suggestion
        with tarfile.open(path) as tar:
            members = tar.getmembers()
            valid = []
            dest_real = os.path.realpath(dest)
            for member in members:
                if member.issym() or member.islnk():
                    raise tarfile.ExtractError(f"Refusing to extract link: {member.name}")
                if os.path.isabs(member.name):
                    raise tarfile.ExtractError(f"Refusing absolute path: {member.name}")
                parts = member.name.split("/")
                if ".." in parts:
                    raise tarfile.ExtractError(f"Refusing parent traversal: {member.name}")
                if strip:
                    stripped = "/".join(parts[strip:])
                    if not stripped:
                        continue
                    member.name = stripped
                target = os.path.realpath(os.path.join(dest, member.name))
                if not target.startswith(dest_real + os.sep):
                    raise tarfile.ExtractError(f"Refusing path outside destination: {member.name}")
                valid.append(member)
            tar.extractall(path=dest, members=valid)
```

How can I resolve this? If you propose a fix, please make it concise.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


@step("Restore Files")
def restore_files(
self,
Expand All @@ -152,9 +183,10 @@ def restore_files(
finally:
os.makedirs(dir_path, exist_ok=True)

self.execute(
f"tar {'z' if public_file.endswith('.tgz') else ''}xvf {public_file} --strip 2",
directory=os.path.join(sites_directory, self.name),
self._safe_extract_tar(
public_file,
dest=os.path.join(sites_directory, self.name),
strip=2,
)

if private_file:
Expand All @@ -164,9 +196,10 @@ def restore_files(
finally:
os.makedirs(dir_path, exist_ok=True)

self.execute(
f"tar {'z' if private_file.endswith('.tgz') else ''}xvf {private_file} --strip 2",
directory=os.path.join(sites_directory, self.name),
self._safe_extract_tar(
private_file,
dest=os.path.join(sites_directory, self.name),
strip=2,
)

@step("Checksum of Downloaded Backup Files")
Expand Down
11 changes: 10 additions & 1 deletion agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import os
import re
import secrets
import shutil
import struct
import subprocess
Expand Down Expand Up @@ -58,7 +59,15 @@ def to_bytes(size_str: str) -> float:

def download_file(url, prefix):
"""Download file locally under path prefix and return local path"""
filename = urlparse(url).path.split("/")[-1]
basename = os.path.basename(urlparse(url).path)
ext = ""
for known in (".sql.gz", ".tar.gz", ".tgz", ".sql", ".gz", ".tar"):
if basename.endswith(known):
ext = known
break
if ext and not all(c.isalnum() or c in "._-" for c in ext):
ext = ""
filename = secrets.token_urlsafe(16) + ext
local_filename = os.path.join(prefix, filename)

with requests.get(url, stream=True) as r:
Expand Down
Loading