diff --git a/taming/data/imagenet.py b/taming/data/imagenet.py index 9a02ec44..e07e4b64 100644 --- a/taming/data/imagenet.py +++ b/taming/data/imagenet.py @@ -154,7 +154,26 @@ def _prepare(self): print("Extracting {} to {}".format(path, datadir)) os.makedirs(datadir, exist_ok=True) with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=datadir) print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) @@ -162,7 +181,26 @@ def _prepare(self): subdir = subpath[:-len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: - tar.extractall(path=subdir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=subdir) filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) @@ -212,7 +250,26 @@ def _prepare(self): print("Extracting {} to {}".format(path, datadir)) os.makedirs(datadir, exist_ok=True) with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=datadir) vspath = os.path.join(self.root, self.FILES[1]) if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: diff --git a/taming/data/utils.py b/taming/data/utils.py index 2b3c3d53..8a6b0901 100644 --- a/taming/data/utils.py +++ b/taming/data/utils.py @@ -16,10 +16,48 @@ def unpack(path): if path.endswith("tar.gz"): with tarfile.open(path, "r:gz") as tar: - tar.extractall(path=os.path.split(path)[0]) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=os.path.split(path)["0"]) elif path.endswith("tar"): with tarfile.open(path, "r:") as tar: - tar.extractall(path=os.path.split(path)[0]) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=os.path.split(path)["0"]) elif path.endswith("zip"): with zipfile.ZipFile(path, "r") as f: f.extractall(path=os.path.split(path)[0])