Add HuggingfaceDatasetSplitReader for using Huggingface datasets
Added a new reader to allow reading Huggingface datasets as AllenNLP instances.
Mapped a limited set of `datasets.features` to `allennlp.data.fields`.
Verified for selected datasets and/or dataset configurations.

Signed-off-by: Abhishek P (VMware) <[email protected]>
1 parent f82d3f1, commit 08d3012
Showing 3 changed files with 217 additions and 2 deletions.
allennlp/data/dataset_readers/hugging_face_datasets_reader.py (170 additions, 0 deletions)
from typing import Iterable, Optional

from allennlp.data import DatasetReader, Token
from allennlp.data.fields import TextField, LabelField, ListField
from allennlp.data.instance import Instance
from datasets import load_dataset
from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages, Value


class HuggingfaceDatasetSplitReader(DatasetReader):
    """
    This reader wraps the huggingface `datasets` package to make use of its dataset management
    functionality and to load the data in AllenNLP-friendly formats.

    Note: the reader works with only one split of the dataset, i.e. you need to create a separate
    reader for each split.

    The following datasets and configurations have been verified to work with this reader:

        Dataset                    Dataset Configuration
        `xnli`                     `ar`
        `xnli`                     `en`
        `xnli`                     `de`
        `xnli`                     `all_languages`
        `glue`                     `cola`
        `glue`                     `mrpc`
        `glue`                     `sst2`
        `glue`                     `qqp`
        `glue`                     `mnli`
        `glue`                     `mnli_matched`
        `universal_dependencies`   `en_lines`
        `universal_dependencies`   `ko_kaist`
        `universal_dependencies`   `af_afribooms`
        `afrikaans_ner_corpus`     `NA`
        `swahili`                  `NA`
        `conll2003`                `NA`
        `dbpedia_14`               `NA`
        `trec`                     `NA`
        `emotion`                  `NA`
    """

    def __init__(
        self,
        max_instances: Optional[int] = None,
        manual_distributed_sharding: bool = False,
        manual_multiprocess_sharding: bool = False,
        serialization_dir: Optional[str] = None,
        dataset_name: Optional[str] = None,
        split: str = "train",
        config_name: Optional[str] = None,
    ) -> None:
        super().__init__(
            max_instances, manual_distributed_sharding, manual_multiprocess_sharding, serialization_dir
        )

        # It would be cleaner to create a separate reader object for each dataset
        self.dataset = None
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.index = -1

        if config_name:
            self.dataset = load_dataset(self.dataset_name, self.config_name, split=split)
        else:
            self.dataset = load_dataset(self.dataset_name, split=split)

    def _read(self, file_path) -> Iterable[Instance]:
        """
        Reads the dataset and converts each entry into an AllenNLP-friendly `Instance`.
        """
        # The split is already loaded in the constructor, so `file_path` is not used here.
        for entry in self.dataset:
            yield self.text_to_instance(entry)

    def text_to_instance(self, *inputs) -> Instance:
        """
        Converts a dataset entry into an AllenNLP-friendly `Instance`.

        Currently this is implemented in an ad-hoc manner: only the `datasets.features` types
        required by the supported datasets are converted. Ideally we would cleanly map each
        `datasets.features` type to an AllenNLP field and convert entries one feature at a time;
        that would give the best chance of covering the largest possible set of datasets.

        Currently `datasets.features` types are mapped to AllenNLP fields as follows:

            dataset.feature type             allennlp.data.fields
            `ClassLabel`                     `LabelField` in the feature-name namespace
            `Value.string`                   `TextField` with the value as a single `Token`
            `Value.*`                        `LabelField` with the value as the label, in the feature-name namespace
            `Sequence.string`                `ListField` of `TextField`, each string as a single token
            `Sequence.ClassLabel`            `ListField` of `LabelField` in the feature-name namespace
            `Translation`                    `ListField` of 2 `ListField`s (`LabelField`s and `TextField`s)
            `TranslationVariableLanguages`   `ListField` of 2 `ListField`s (`LabelField`s and `TextField`s)
        """

        # `features` describes the different pieces of information available in each dataset entry,
        # and the feature types indicate what kind of information they hold.
        # e.g. in a sentiment dataset, an entry could have one feature indicating the text and
        # another indicating the label.
        features = self.dataset.features
        fields = dict()

        # TODO we need to support all the dataset feature types listed at
        # https://huggingface.co/docs/datasets/features.html
        for feature in features:
            value = features[feature]

            # datasets `ClassLabel` maps to a `LabelField`
            if isinstance(value, ClassLabel):
                field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)

            # datasets `Value` can be of different types
            elif isinstance(value, Value):

                # a string value maps to a `TextField`
                if value.dtype == "string":
                    # Since a TextField has to be made of Tokens, add the whole text as a single token
                    # TODO Should we use simple heuristics to identify what is a token and what is not?
                    field = TextField([Token(inputs[0][feature])])

                else:
                    field = LabelField(inputs[0][feature], label_namespace=feature, skip_indexing=True)

            elif isinstance(value, Sequence):
                # datasets `Sequence` of strings maps to a `ListField` of `TextField`s
                if value.feature.dtype == "string":
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = TextField([Token(item)])
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

                # datasets `Sequence` of `ClassLabel` maps to a `ListField` of `LabelField`s
                elif isinstance(value.feature, ClassLabel):
                    field_list = list()
                    for item in inputs[0][feature]:
                        item_field = LabelField(label=item, label_namespace=feature, skip_indexing=True)
                        field_list.append(item_field)
                    if len(field_list) == 0:
                        continue
                    field = ListField(field_list)

            # datasets `Translation` cannot be mapped directly, but its dict structure can be mapped
            # to a `ListField` of 2 `ListField`s
            elif isinstance(value, Translation):
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = list(input_dict.keys())
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    for lang in langs:
                        texts.append(TextField([Token(input_dict[lang])]))
                    field = ListField([langs_field, ListField(texts)])

            # `TranslationVariableLanguages` is functionally a pair of parallel lists and is hence
            # mapped to a `ListField` of 2 `ListField`s
            elif isinstance(value, TranslationVariableLanguages):
                # Although indicated as a dict, it is made up of a pair of parallel lists
                if value.dtype == "dict":
                    input_dict = inputs[0][feature]
                    langs = input_dict["language"]
                    field_langs = [LabelField(lang, label_namespace="languages") for lang in langs]
                    langs_field = ListField(field_langs)
                    texts = list()
                    # Iterate the parallel "translation" list directly so that repeated language
                    # codes still map to the correct translation (list.index would always return
                    # the first match)
                    for translation in input_dict["translation"]:
                        texts.append(TextField([Token(translation)]))
                    field = ListField([langs_field, ListField(texts)])

            else:
                raise ValueError(f"Datasets feature type {type(value)} is not supported yet.")

            fields[feature] = field

        return Instance(fields)
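
For reference, here is a minimal usage sketch of the new reader. The dataset and config names are just examples taken from the verified list above, the printed feature names are an assumption about what `glue`/`cola` exposes, and the reader ignores `file_path`, so `None` is passed just as in the tests below.

# Minimal usage sketch (assumption: `datasets` and `allennlp` are installed and downloading
# the dataset is acceptable; "glue"/"cola" are example names from the verified list).
from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader

reader = HuggingfaceDatasetSplitReader(dataset_name="glue", config_name="cola", split="train")

# The split is loaded in the constructor, so the file_path argument to read() is unused; pass None.
for instance in reader.read(None):
    # For glue/cola this is expected to show the "sentence", "label" and "idx" features mapped
    # to fields (an assumption about the dataset's feature names, not part of this commit).
    print(instance.fields)
    break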
Test file for the new reader (45 additions, 0 deletions)
import logging

import pytest

from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader

logger = logging.getLogger(__name__)

# TODO these tests actually download the datasets and will be very, very slow
class HuggingfaceDatasetSplitReaderTest:

    SUPPORTED_DATASETS_WITHOUT_CONFIG = [
        "afrikaans_ner_corpus",
        "dbpedia_14",
        "trec",
        "swahili",
        "conll2003",
        "emotion",
    ]

    # Run the tests for supported datasets which do not require a config name to be specified
    @pytest.mark.parametrize("dataset", SUPPORTED_DATASETS_WITHOUT_CONFIG)
    def test_read_for_datasets_without_config(self, dataset):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)

    # Not testing all configurations, only some
    SUPPORTED_DATASET_CONFIGURATION = (
        ("glue", "cola"),
        ("universal_dependencies", "af_afribooms"),
        ("xnli", "all_languages"),
    )

    # Run the tests for supported datasets which require a config name to be specified
    @pytest.mark.parametrize("dataset, config", SUPPORTED_DATASET_CONFIGURATION)
    def test_read_for_datasets_requiring_config(self, dataset, config):
        huggingface_reader = HuggingfaceDatasetSplitReader(dataset_name=dataset, config_name=config)
        instances = list(huggingface_reader.read(None))
        assert len(instances) == len(huggingface_reader.dataset)
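
Beyond the length checks above, a quick way to see the feature-to-field mapping in action is to inspect a single instance. The sketch below uses `conll2003` from the verified list; the feature names (`tokens`, `ner_tags`) are an assumption about what the Huggingface dataset exposes and are not part of this commit.

# Illustrative sketch of the mapping described in the reader's docstring.
# Assumptions: the Huggingface `conll2003` dataset exposes "tokens" (Sequence of strings)
# and "ner_tags" (Sequence of ClassLabel); adjust if the dataset schema differs.
from allennlp.data.dataset_readers.hugging_face_datasets_reader import HuggingfaceDatasetSplitReader

reader = HuggingfaceDatasetSplitReader(dataset_name="conll2003", split="validation")
instance = next(iter(reader.read(None)))

# Expected: "tokens" -> ListField of TextFields, "ner_tags" -> ListField of LabelFields
print(type(instance.fields["tokens"]).__name__, type(instance.fields["ner_tags"]).__name__)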