src/litdata/streaming/downloader.py: 71 changes (12 additions & 59 deletions)
@@ -15,7 +15,6 @@
 import logging
 import os
 import shutil
-import subprocess
 import tempfile
 from abc import ABC
 from contextlib import suppress
@@ -26,7 +25,6 @@
 
 from litdata.constants import (
     _AZURE_STORAGE_AVAILABLE,
-    _DISABLE_S5CMD,
     _GOOGLE_STORAGE_AVAILABLE,
     _HF_HUB_AVAILABLE,
     _INDEX_FILENAME,
@@ -122,10 +120,11 @@ def __init__(
         # check if kwargs contains session_options
         self.session_options = kwargs.get("session_options", {})
 
-        if not self._s5cmd_available or _DISABLE_S5CMD:
-            self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)
+        self._client = S3Client(storage_options=self._storage_options, session_options=self.session_options)
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        import obstore as obs
+
         obj = parse.urlparse(remote_filepath)
 
         if obj.scheme != "s3":
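A small detail in the hunk above: the new `import obstore as obs` sits inside `download_file` rather than at module top level, so obstore is only imported (and only needs to be installed) when an S3 download actually runs. A minimal, self-contained sketch of this deferred-import pattern, using the stdlib `json` module as a stand-in for the optional dependency (the function name is hypothetical, not from the PR):

def parse_payload(raw: str) -> dict:
    # Deferred import: resolved on first call, not at module import time,
    # so importing this module never requires the dependency up front.
    import json

    return json.loads(raw)

# Usage: parse_payload('{"ok": true}') returns {"ok": True}.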
@@ -134,65 +133,19 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if os.path.exists(local_filepath):
             return
 
+        bucket = obj.netloc
+        key = obj.path.lstrip("/")
+
         with (
             suppress(Timeout, FileNotFoundError),
             FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0),
         ):
-            if self._s5cmd_available and not _DISABLE_S5CMD:
-                env = None
-                if self._storage_options:
-                    env = os.environ.copy()
-                    env.update(self._storage_options)
-
-                aws_no_sign_request = self._storage_options.get("AWS_NO_SIGN_REQUEST", "no").lower() == "yes"
-                # prepare the s5cmd command
-                no_signed_option = "--no-sign-request" if aws_no_sign_request else None
-                cmd_parts = ["s5cmd", no_signed_option, "cp", remote_filepath, local_filepath]
-                cmd = " ".join(part for part in cmd_parts if part)
-
-                proc = subprocess.Popen(
-                    cmd,
-                    shell=True,
-                    stdout=subprocess.PIPE,
-                    stderr=subprocess.PIPE,
-                    env=env,
-                )
-                return_code = proc.wait()
-
-                if return_code != 0:
-                    stderr_output = proc.stderr.read().decode().strip() if proc.stderr else ""
-                    error_message = (
-                        f"Failed to execute command `{cmd}` (exit code: {return_code}). "
-                        "This might be due to an incorrect file path, insufficient permissions, or network issues. "
-                        "To resolve this issue, you can either:\n"
-                        "- Pass `storage_options` with the necessary credentials and endpoint. \n"
-                        "- Example:\n"
-                        "    storage_options = {\n"
-                        '        "AWS_ACCESS_KEY_ID": "your-key",\n'
-                        '        "AWS_SECRET_ACCESS_KEY": "your-secret",\n'
-                        '        "S3_ENDPOINT_URL": "https://s3.example.com" (Optional if using AWS)\n'
-                        "    }\n"
-                        "- or disable `s5cmd` by setting `DISABLE_S5CMD=1` if `storage_options` do not work.\n"
-                    )
-                    if stderr_output:
-                        error_message += (
-                            f"For further debugging, please check the command output below:\n{stderr_output}"
-                        )
-                    raise RuntimeError(error_message)
-            else:
-                from boto3.s3.transfer import TransferConfig
-
-                extra_args: dict[str, Any] = {}
-
-                if not os.path.exists(local_filepath):
-                    # Issue: https://github.com/boto/boto3/issues/3113
-                    self._client.client.download_file(
-                        obj.netloc,
-                        obj.path.lstrip("/"),
-                        local_filepath,
-                        ExtraArgs=extra_args,
-                        Config=TransferConfig(use_threads=False),
-                    )
+            store = self._get_store(bucket)
+            resp = obs.get(store, key)
+            with tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(local_filepath)) as tmpfile:
+                for chunk in resp:
+                    tmpfile.write(chunk)
+            os.replace(tmpfile.name, local_filepath)
 
     def download_bytes(self, remote_filepath: str, offset: int, length: int, local_chunkpath: str) -> bytes:
         obj = parse.urlparse(remote_filepath)
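Taken together, the new `download_file` body replaces both the s5cmd subprocess and the boto3 fallback with a single obstore path: build a store for the bucket, stream the object, write it to a temporary file in the destination directory, and `os.replace` it into place, so concurrent readers see either the old file or the complete new one, never a partial write. Below is a standalone sketch of the same flow, with assumptions flagged: `S3Store` is constructed directly here, whereas the PR goes through `self._get_store(bucket)` (defined outside this diff); credentials are assumed to come from the environment; bucket, key, and path names are placeholders.

import os
import tempfile

import obstore as obs
from obstore.store import S3Store  # assumption: S3Store picks up AWS credentials from the environment

def download_s3_file(bucket: str, key: str, local_filepath: str) -> None:
    if os.path.exists(local_filepath):
        return

    store = S3Store(bucket)     # hypothetical stand-in for self._get_store(bucket)
    resp = obs.get(store, key)  # streaming GET; chunks arrive lazily

    # Write into a temp file in the same directory, then rename atomically:
    # os.replace is atomic on POSIX when source and target share a filesystem.
    with tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(local_filepath)) as tmpfile:
        for chunk in resp:  # GetResult is iterable over byte chunks
            tmpfile.write(chunk)
    os.replace(tmpfile.name, local_filepath)

# Example (placeholder names; the target directory must already exist):
# download_s3_file("my-bucket", "chunks/chunk-0.bin", "/tmp/chunk-0.bin")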
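The visible context also shows `download_bytes(self, remote_filepath, offset, length, local_chunkpath)`, whose body is collapsed in this diff. The following is therefore only an illustration of what a ranged read looks like with obstore, not the PR's actual implementation; it assumes obstore exposes `get_range` with `start`/`length` keyword arguments, which should be verified against the installed obstore version.

import obstore as obs
from obstore.store import S3Store

def read_byte_range(bucket: str, key: str, offset: int, length: int) -> bytes:
    # Assumption: get_range issues a single HTTP range request for
    # bytes [offset, offset + length) and returns a buffer-protocol object.
    store = S3Store(bucket)
    data = obs.get_range(store, key, start=offset, length=length)
    return bytes(data)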