
Commit 24d8f49

Enhance LoadHF class to support optional splits and improve dataset loading logic

Signed-off-by: elronbandel <[email protected]>
elronbandel committed Feb 4, 2025
1 parent 8ce505d commit 24d8f49
Showing 2 changed files with 48 additions and 71 deletions.
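For orientation before the diff, here is a minimal usage sketch of the enhanced loader. Only the `splits` field is new in this commit; its semantics (loading just the listed splits) are inferred from the commit message rather than shown in the hunks below, and the `glue`/`wnli` path and config are illustrative examples.

from unitxt.loaders import LoadHF

# Hypothetical usage: `splits` is the optional field added by this commit.
# Restricting the load to the listed splits is an assumption based on the
# commit message, not something demonstrated in the hunks below.
loader = LoadHF(
    path="glue",                     # example Hugging Face dataset path
    name="wnli",                     # example config name
    splits=["train", "validation"],  # new optional field
)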
115 changes: 46 additions & 69 deletions src/unitxt/loaders.py
@@ -223,6 +223,7 @@ class LoadHF(Loader):
    filtering_lambda: Optional[str] = None
    num_proc: Optional[int] = None
    requirements_list: List[str] = OptionalField(default_factory=list)
+   splits: Optional[List[str]] = None

    def verify(self):
        for requirement in self.requirements_list:
@@ -243,63 +244,42 @@ def is_streaming(self) -> bool:
            return settings.stream_hf_datasets_by_default
        return self.streaming

-    def stream_dataset(self, split=None):
-        with tempfile.TemporaryDirectory() as dir_to_be_deleted:
-            if settings.disable_hf_datasets_cache and not self.is_streaming():
-                cache_dir = dir_to_be_deleted
-            else:
-                cache_dir = None
-            try:
-                dataset = hf_load_dataset(
-                    self.path,
-                    name=self.name,
-                    data_dir=self.data_dir,
-                    data_files=self.data_files,
-                    revision=self.revision,
-                    streaming=self.is_streaming(),
-                    cache_dir=cache_dir,
-                    split=split,
-                    trust_remote_code=settings.allow_unverified_code,
-                    num_proc=self.num_proc,
-                )
-            except ValueError as e:
-                if "trust_remote_code" in str(e):
-                    raise ValueError(
-                        f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
-                    ) from e
-                raise e
-
-        return dataset

-    # returns Dict when split names are not known in advance, and just the the single split dataset - if known
-    def load_dataset(self, split: str) -> Union[IterableDatasetDict, IterableDataset]:
-        with tempfile.TemporaryDirectory() as dir_to_be_deleted:
-            if settings.disable_hf_datasets_cache:
-                cache_dir = dir_to_be_deleted
-            else:
-                cache_dir = None
-            try:
-                dataset = hf_load_dataset(
-                    self.path,
-                    name=self.name,
-                    data_dir=self.data_dir,
-                    data_files=self.data_files,
-                    streaming=False,
-                    keep_in_memory=True,
-                    cache_dir=cache_dir,
-                    split=split,
-                    trust_remote_code=settings.allow_unverified_code,
-                    num_proc=self.num_proc,
-                )
-            except ValueError as e:
-                if "trust_remote_code" in str(e):
-                    raise ValueError(
-                        f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
-                    ) from e
-
-        if self.split is not None:
-            dataset = {self.split: dataset}
+    def load_dataset(
+        self, split: str, streaming=None
+    ) -> Union[IterableDatasetDict, IterableDataset]:
+        dataset = self.__class__._loader_cache.get(str(self) + "_" + str(split), None)
+        if dataset is None:
+            if streaming is None:
+                streaming = self.is_streaming()
+
+            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                if settings.disable_hf_datasets_cache:
+                    cache_dir = dir_to_be_deleted
+                else:
+                    cache_dir = None
+                try:
+                    dataset = hf_load_dataset(
+                        self.path,
+                        name=self.name,
+                        data_dir=self.data_dir,
+                        data_files=self.data_files,
+                        revision=self.revision,
+                        streaming=streaming,
+                        keep_in_memory=True,
+                        cache_dir=cache_dir,
+                        verification_mode="no_checks",
+                        split=split,
+                        trust_remote_code=settings.allow_unverified_code,
+                        num_proc=self.num_proc,
+                    )
+                except ValueError as e:
+                    if "trust_remote_code" in str(e):
+                        raise ValueError(
+                            f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
+                        ) from e
+            self.__class__._loader_cache.max_size = settings.loader_cache_size
+            self.__class__._loader_cache[str(self) + "_" + str(split)] = dataset
+        return dataset

    def _maybe_set_classification_policy(self):
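The rewritten load_dataset above folds the old stream_dataset/load_dataset pair into a single method fronted by the class-level loader cache, keyed by the loader's string form plus the split. A self-contained sketch of that caching pattern follows; SimpleLRUCache and cached_load are illustrative stand-ins, not unitxt APIs.

from collections import OrderedDict

class SimpleLRUCache:
    # Tiny stand-in for unitxt's _loader_cache: bounded, oldest-first eviction.
    def __init__(self, max_size=10):
        self.max_size = max_size
        self._data = OrderedDict()

    def get(self, key, default=None):
        return self._data.get(key, default)

    def __setitem__(self, key, value):
        self._data[key] = value
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the oldest entry

_cache = SimpleLRUCache()

def cached_load(loader_repr: str, split: str, expensive_load):
    # Key mirrors str(self) + "_" + str(split) in the commit.
    key = loader_repr + "_" + str(split)
    dataset = _cache.get(key)
    if dataset is None:
        dataset = expensive_load(split)  # hit the hub only on a cache miss
        _cache[key] = dataset
    return dataset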
@@ -365,22 +345,19 @@ def load_iterables(
        return dataset

    def split_generator(self, split: str) -> Generator:
-        dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
-        if dataset is None:
-            try:
-                dataset = self.stream_dataset(split)
-            except NotImplementedError:  # streaming is not supported for zipped files so we load without streaming
-                dataset = self.load_dataset(split)
-
-            if self.filtering_lambda is not None:
-                dataset = self.filter_load(dataset)
+        try:
+            dataset = self.load_dataset(split=split)
+        except (
+            NotImplementedError
+        ):  # streaming is not supported for zipped files so we load without streaming
+            dataset = self.load_dataset(split=split, streaming=False)

-            limit = self.get_limit()
-            if limit is not None:
-                dataset = dataset.take(limit)
+        if self.filtering_lambda is not None:
+            dataset = self.filter_load(dataset)

-            self.__class__._loader_cache.max_size = settings.loader_cache_size
-            self.__class__._loader_cache[str(self) + "_" + split] = dataset
+        limit = self.get_limit()
+        if limit is not None:
+            dataset = dataset.take(limit)

        yield from dataset
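The slimmed-down split_generator now delegates caching to load_dataset and keeps only the streaming fallback. Below is a minimal sketch of that fallback pattern against the public datasets API, assuming network access; the dataset path is just an example.

from datasets import load_dataset

def load_split(path: str, split: str):
    try:
        # Prefer streaming: rows arrive lazily, nothing is fully downloaded.
        return load_dataset(path, split=split, streaming=True)
    except NotImplementedError:
        # Some builders (e.g. zipped data files) cannot stream; retry eagerly,
        # mirroring load_dataset(split=split, streaming=False) above.
        return load_dataset(path, split=split, streaming=False)

for i, example in enumerate(load_split("rajpurkar/squad", "train")):
    print(example["question"])
    if i == 2:
        break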

4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 603,
"line_number": 580,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-02-03T08:16:28Z"
"generated_at": "2025-02-04T09:34:01Z"
}
