This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit f77cfa3

Converted HFDatasetSplitReader to HFDatasetReader

Now all splits can be used in the same reader.
Supports both pre-loading all splits and on-demand loading of a single split.
Reduced tests to the glue/cola dataset:config, which is a ~0.36 MB download.
Updated the datasets dependency to the range >=1.5.0,<1.6.0.

Signed-off-by: Abhishek P (VMware) <[email protected]>

Abhishek-P committed Apr 7, 2021
1 parent 6e613b9 commit f77cfa3
Showing 2 changed files with 51 additions and 26 deletions.
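
For orientation, here is a rough usage sketch of the converted reader, based on the commit message and the diff below. It is an illustration, not part of the commit, and it assumes that `pre_load=True` wires up `load_dataset()` at construction time (that wiring sits in an elided part of the diff):

from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader

# One reader now serves every split of a dataset/config pair.
reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola")

# On-demand: a split is loaded the first time it is read.
train_instances = list(reader.read("train"))
test_instances = list(reader.read("test"))

# Pre-load: fetch all splits of the dataset up front instead.
eager_reader = HuggingfaceDatasetReader(dataset_name="glue", config_name="cola", pre_load=True)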
63 changes: 43 additions & 20 deletions allennlp/data/dataset_readers/huggingface_datasets_reader.py
@@ -3,12 +3,13 @@
 from allennlp.data import DatasetReader, Token, Field
 from allennlp.data.fields import TextField, LabelField, ListField
 from allennlp.data.instance import Instance
-from datasets import load_dataset, Dataset, DatasetDict
+from datasets import load_dataset, Dataset, DatasetDict, Split
 from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages
 from datasets.features import Value

-# TODO pab complete the documentation comments
-class HuggingfaceDatasetSplitReader(DatasetReader):
+
+# TODO pab-vmware complete the documentation comments
+class HuggingfaceDatasetReader(DatasetReader):
     """
     This reader implementation wraps the huggingface datasets package
     to utilize its dataset management functionality and load the information in AllenNLP-friendly formats
@@ -44,6 +45,8 @@ class HuggingfaceDatasetSplitReader(DatasetReader):
     pre_load : `bool`, optional (default=`False`)
     """

+    SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION]
+
     def __init__(
         self,
         max_instances: Optional[int] = None,
@@ -52,7 +55,7 @@ def __init__(
         serialization_dir: Optional[str] = None,
         dataset_name: str = None,
         config_name: Optional[str] = None,
-        pre_load: Optional[bool] = False
+        pre_load: Optional[bool] = False,
     ) -> None:
         super().__init__(
             max_instances,
@@ -61,7 +64,7 @@ def __init__(
             serialization_dir,
         )

-        # It would be cleaner to create a separate reader object for different dataset
+        # It would be cleaner to create a separate reader object for different datasets
         self.dataset: Dataset = None
         self.datasets: DatasetDict = DatasetDict()
         self.dataset_name = dataset_name
@@ -77,22 +80,33 @@ def load_dataset(self):
         else:
             self.datasets = load_dataset(self.dataset_name)

-    def load_dataset_split(self, split):
-        if self.config_name is not None:
-            self.datasets[split] = load_dataset(self.dataset_name, self.config_name, split=split)
+    def load_dataset_split(self, split: str):
+        # TODO add support for datasets.split.NamedSplit
+        if split in self.SUPPORTED_SPLITS:
+            if self.config_name is not None:
+                self.datasets[split] = load_dataset(
+                    self.dataset_name, self.config_name, split=split
+                )
+            else:
+                self.datasets[split] = load_dataset(self.dataset_name, split=split)
         else:
-            self.datasets[split] = load_dataset(self.dataset_name, split=split)
+            raise ValueError(
+                f"Only default splits:{self.SUPPORTED_SPLITS} are currently supported."
+            )

-    def _read(self, file_path) -> Iterable[Instance]:
+    def _read(self, file_path: str) -> Iterable[Instance]:
         """
         Reads the dataset and converts each entry to an AllenNLP-friendly instance
         """
         if file_path is None:
             raise ValueError("parameter split cannot be None")

-        if self.datasets is not None and self.datasets[file_path] is not None:
-            for entry in self.datasets[file_path]:
-                yield self.text_to_instance(entry)
+        # If split is not loaded, load the specific split
+        if file_path not in self.datasets:
+            self.load_dataset_split(file_path)
+
+        for entry in self.datasets[file_path]:
+            yield self.text_to_instance(entry)

     def raise_feature_not_supported_value_error(self, value):
         raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
@@ -136,7 +150,9 @@ def text_to_instance(self, *inputs) -> Instance:

             # datasets ClassLabel maps to LabelField
             if isinstance(value, ClassLabel):
-                field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)
+                field = LabelField(
+                    inputs[0][feature], label_namespace=feature, skip_indexing=True
+                )

             # datasets Value can be of different types
             elif isinstance(value, Value):
@@ -179,30 +195,35 @@ def text_to_instance(self, *inputs) -> Instance:
                 else:
                     self.raise_feature_not_supported_value_error(value)

-
             # datasets.Translation cannot be mapped directly
             # but its dict structure can be mapped to a ListField of 2 ListFields
             elif isinstance(value, Translation):
                 if value.dtype == "dict":
                     input_dict = inputs[0][feature]
                     langs = list(input_dict.keys())
-                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
+                    field_langs = [
+                        LabelField(lang, label_namespace="languages") for lang in langs
+                    ]
                     langs_field = ListField(field_langs)
                     texts = list()
                     for lang in langs:
                         texts.append(TextField([Token(input_dict[lang])]))
                     field = ListField([langs_field, ListField(texts)])

                 else:
-                    raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
+                    raise ValueError(
+                        f"Datasets feature type {type(value)} is not supported yet."
+                    )

             # datasets.TranslationVariableLanguages
             # is functionally a pair of Lists and hence mapped to a ListField of 2 ListFields
             elif isinstance(value, TranslationVariableLanguages):
                 if value.dtype == "dict":
                     input_dict = inputs[0][feature]
                     langs = input_dict["language"]
-                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
+                    field_langs = [
+                        LabelField(lang, label_namespace="languages") for lang in langs
+                    ]
                     langs_field = ListField(field_langs)
                     texts = list()
                     for lang in langs:
@@ -211,12 +232,14 @@ def text_to_instance(self, *inputs) -> Instance:
                     field = ListField([langs_field, ListField(texts)])

                 else:
-                    raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")
+                    raise ValueError(
+                        f"Datasets feature type {type(value)} is not supported yet."
+                    )

             else:
                 raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")

-            if field:
+            if field is not None:
                 fields[feature] = field

         return Instance(fields)
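
To make the Translation branch above concrete, a small standalone illustration of the mapping it implements; the example entry is invented for illustration and is not part of the commit:

from allennlp.data import Token
from allennlp.data.fields import LabelField, ListField, TextField

input_dict = {"en": "Hello", "fr": "Bonjour"}  # an invented translation entry
langs = list(input_dict.keys())
langs_field = ListField([LabelField(lang, label_namespace="languages") for lang in langs])
texts = ListField([TextField([Token(input_dict[lang])]) for lang in langs])
field = ListField([langs_field, texts])  # languages paired with their texts, as in the diff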
14 changes: 8 additions & 6 deletions tests/data/dataset_readers/huggingface_datasets_test.py
@@ -1,21 +1,23 @@
 import pytest

-from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetSplitReader
+from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader
 import logging

 logger = logging.getLogger(__name__)


 # TODO these UTs are actually downloading the datasets and will be very very slow
-# TODO add UT were we compare huggingface wrapped reader with an explicitly coded builder
+# TODO add UT where we compare huggingface wrapped reader with an explicitly coded dataset
 class HuggingfaceDatasetSplitReaderTest:

     """
     Running the tests for supported datasets which require config name to be specified
     """
-    @pytest.mark.parametrize("dataset, config, split", (("glue", "cola", "train"), ("glue", "cola", "test")))
+    @pytest.mark.parametrize(
+        "dataset, config, split", (("glue", "cola", "train"), ("glue", "cola", "test"))
+    )
     def test_read_for_datasets_requiring_config(self, dataset, config, split):
-        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
+        huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config)
         instances = list(huggingface_reader.read(split))
         assert len(instances) == len(huggingface_reader.datasets[split])
         print(instances[0], huggingface_reader.datasets[split][0])
