From af9661e6a421ebc7c799655ad0793a2b7818cbf6 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Wed, 31 Mar 2021 20:12:55 +0530 Subject: [PATCH 01/63] Add `HuggingfaceDatasetReader` for using Huggingface `datasets` Introduced new dependency - "datasets>=1.5.0,<1.6.0"" Added a new reader to allow for reading huggingface datasets as instance Mapped limited `datasets.features` to `allenlp.data.fields` Added Tests for the same Verified for selective dataset and/or dataset configurations Added `test-with-cov-html` to provide contributor friendly html coverage report Signed-off-by: Abhishek P (VMware) <pab@vmware.com> --- .gitignore | 1 + CHANGELOG.md | 2 +- Makefile | 7 + .../huggingface_datasets_reader.py | 276 ++++++++++++++++++ setup.py | 1 + .../huggingface_datasets_reader_test.py | 184 ++++++++++++ 6 files changed, 470 insertions(+), 1 deletion(-) create mode 100644 allennlp/data/dataset_readers/huggingface_datasets_reader.py create mode 100644 tests/data/dataset_readers/huggingface_datasets_reader_test.py diff --git a/.gitignore b/.gitignore index 6917232047e..2c0ee5edca6 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ __pycache__ .coverage .pytest_cache/ .benchmarks +htmlcov/ # documentation build artifacts diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ca16d3c4f8..d394cb03dae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Print the first batch to the console by default. ### Added - +- Added `HuggingfaceDatasetReader` for using huggingface datasets in AllenNLP - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. - Added `allennlp.nn.util.load_state_dict` helper function. diff --git a/Makefile b/Makefile index 228796e56be..365260bbceb 100644 --- a/Makefile +++ b/Makefile @@ -66,6 +66,13 @@ test-with-cov : --cov=$(SRC) \ --cov-report=xml +.PHONY : test-with-cov-html +test-with-cov-html : + pytest --color=yes -rf --durations=40 \ + --cov-config=.coveragerc \ + --cov=$(SRC) \ + --cov-report=html + .PHONY : gpu-test gpu-test : check-for-cuda pytest --color=yes -v -rf -m gpu diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py new file mode 100644 index 00000000000..af030a377f9 --- /dev/null +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -0,0 +1,276 @@ +import typing +from typing import Iterable, Optional + +from allennlp.data import DatasetReader, Token, Field, Tokenizer +from allennlp.data.fields import TextField, LabelField, ListField +from allennlp.data.instance import Instance +from datasets import load_dataset, DatasetDict, Split, list_datasets +from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages +from datasets.features import Value + + +@DatasetReader.register("huggingface-datasets") +class HuggingfaceDatasetReader(DatasetReader): + """ + Reads instances from the given huggingface supported dataset + + This reader implementation wraps the huggingface datasets package + + Following dataset and configurations have been verified and work with this reader + + Dataset Dataset Configuration + `xnli` `ar` + `xnli` `en` + `xnli` `de` + `xnli` `all_languages` + `glue` `cola` + `glue` `mrpc` + `glue` `sst2` + `glue` `qqp` + `glue` `mnli` + `glue` `mnli_matched` + `universal_dependencies` `en_lines` + `universal_dependencies` `ko_kaist` + `universal_dependencies` `af_afribooms` + `swahili` `NA` + `conll2003` `NA` + `dbpedia_14` `NA` + `trec` `NA` + `emotion` `NA` + Note: universal_dependencies will require you to install `conllu` package separately + + Registered as a `DatasetReader` with name `huggingface-datasets` + + # Parameters + + dataset_name : `str` + Name of the dataset from huggingface datasets the reader will be used for. + config_name : `str`, optional (default=`None`) + Configuration(mandatory for some datasets) of the dataset. + preload : `bool`, optional (default=`False`) + If `True` all splits for the dataset is loaded(includes download etc) as part of the initialization, + otherwise each split is loaded on when `read()` is used for the same for the first time. + tokenizer : `Tokenizer`, optional (default=`None`) + If specified is used for tokenization of string and text fields from the dataset. + This is useful since text in allennlp is dealt with as a series of tokens. + """ + + SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION] + + def __init__( + self, + dataset_name: str = None, + config_name: Optional[str] = None, + preload: Optional[bool] = False, + tokenizer: Optional[Tokenizer] = None, + **kwargs, + ) -> None: + super().__init__( + manual_distributed_sharding=True, + manual_multiprocess_sharding=True, + **kwargs, + ) + + # It would be cleaner to create a separate reader object for diferent dataset + if dataset_name not in list_datasets(): + raise ValueError(f"Dataset {dataset_name} not available in huggingface datasets") + self.dataset: DatasetDict = DatasetDict() + self.dataset_name = dataset_name + self.config_name = config_name + self.tokenizer = tokenizer + + if preload: + self.load_dataset() + + def load_dataset(self): + if self.config_name is not None: + self.dataset = load_dataset(self.dataset_name, self.config_name) + else: + self.dataset = load_dataset(self.dataset_name) + + def load_dataset_split(self, split: str): + # TODO add support for datasets.split.NamedSplit + if split in self.SUPPORTED_SPLITS: + if self.config_name is not None: + self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split) + else: + self.dataset[split] = load_dataset(self.dataset_name, split=split) + else: + raise ValueError( + f"Only default splits:{self.SUPPORTED_SPLITS} are currently supported." + ) + + def _read(self, file_path: str) -> Iterable[Instance]: + """ + Reads the dataset and converts the entry to AllenNLP friendly instance + """ + if file_path is None: + raise ValueError("parameter split cannot be None") + + # If split is not loaded, load the specific split + if file_path not in self.dataset: + self.load_dataset_split(file_path) + + # TODO see if use of Dataset.select() is better + for entry in self.shard_iterable(self.dataset[file_path]): + yield self.text_to_instance(file_path, entry) + + def raise_feature_not_supported_value_error(self, value): + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + def text_to_instance(self, *inputs) -> Instance: + """ + Takes care of converting dataset entry into AllenNLP friendly instance + Currently it is implemented in an unseemly catch-up model + where it converts datasets.features that are required for the supported dataset, + ideally it would require design where we cleanly deliberate, decide + map dataset.feature to an allenlp.data.field and then go ahead with converting it + Doing that would provide the best chance of providing largest possible coverage with datasets + + Currently this is how datasets.features types are mapped to AllenNLP Fields + + dataset.feature type allennlp.data.fields + `ClassLabel` `LabelField` in feature name namespace + `Value.string` `TextField` with value as Token + `Value.*` `LabelField` with value being label in feature name namespace + `Sequence.string` `ListField` of `TextField` with individual string as token + `Sequence.ClassLabel` `ListField` of `ClassLabel` in feature name namespace + `Translation` `ListField` of 2 ListField (ClassLabel and TextField) + `TranslationVariableLanguages` `ListField` of 2 ListField (ClassLabel and TextField) + """ + + # features indicate the different information available in each entry from dataset + # feature types decide what type of information they are + # e.g. In a Sentiment dataset an entry could have one feature (of type text/string) indicating the text + # and another indicate the sentiment (of typeint32/ClassLabel) + + split = inputs[0] + features = self.dataset[split].features + fields = dict() + + # TODO we need to support all different datasets features described + # in https://huggingface.co/docs/datasets/features.html + for feature in features: + fields_to_be_added: typing.Dict[str, Field] = dict() + item_field: Field + field_list: list + value = features[feature] + + # datasets ClassLabel maps to LabelField + if isinstance(value, ClassLabel): + fields_to_be_added[feature] = LabelField( + inputs[1][feature], label_namespace=feature, skip_indexing=True + ) + + # datasets Value can be of different types + elif isinstance(value, Value): + + # String value maps to TextField + if value.dtype == "string": + # datasets.Value[string] maps to TextField + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + if self.tokenizer is not None: + fields_to_be_added[feature] = TextField( + self.tokenizer.tokenize(inputs[1][feature]) + ) + + else: + fields_to_be_added[feature] = TextField([Token(inputs[1][feature])]) + + else: + fields_to_be_added[feature] = LabelField( + inputs[1][feature], label_namespace=feature, skip_indexing=True + ) + + elif isinstance(value, Sequence): + # We do not know if the string is token or text, we will assume text and make each a TextField + # datasets.features.Sequence of strings maps to ListField of TextField + if hasattr(value.feature, "dtype") and value.feature.dtype == "string": + field_list2: typing.List[TextField] = list() + for item in inputs[1][feature]: + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + tokens: typing.List[Token] + if self.tokenizer is not None: + tokens = self.tokenizer.tokenize(item) + + else: + tokens = [Token(item)] + + item_field = TextField(tokens) + field_list2.append(item_field) + + fields_to_be_added[feature] = ListField(field_list2) + + # datasets Sequence of strings to ListField of LabelField + elif isinstance(value.feature, ClassLabel): + field_list = list() + for item in inputs[1][feature]: + item_field = LabelField( + label=item, label_namespace=feature, skip_indexing=True + ) + field_list.append(item_field) + + fields_to_be_added[feature] = ListField(field_list) + + else: + self.raise_feature_not_supported_value_error(value) + + # datasets.Translation cannot be mapped directly + # but it's dict structure can be mapped to a ListField of 2 ListField + elif isinstance(value, Translation): + if value.dtype == "dict": + input_dict = inputs[1][feature] + langs = list(input_dict.keys()) + texts = list() + for lang in langs: + if self.tokenizer is not None: + tokens = self.tokenizer.tokenize(input_dict[lang]) + + else: + tokens = [Token(input_dict[lang])] + texts.append(TextField(tokens)) + + fields_to_be_added[feature + "-languages"] = ListField( + [LabelField(lang, label_namespace="languages") for lang in langs] + ) + fields_to_be_added[feature + "-texts"] = ListField(texts) + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + # datasets.TranslationVariableLanguages + # is functionally a pair of Lists and hence mapped to a ListField of 2 ListField + elif isinstance(value, TranslationVariableLanguages): + if value.dtype == "dict": + input_dict = inputs[1][feature] + fields_to_be_added[feature + "-language"] = ListField( + [ + LabelField(lang, label_namespace=feature + "-language") + for lang in input_dict["language"] + ] + ) + + if self.tokenizer is not None: + fields_to_be_added[feature + "-translation"] = ListField( + [ + TextField(self.tokenizer.tokenize(text)) + for text in input_dict["translation"] + ] + ) + else: + fields_to_be_added[feature + "-translation"] = ListField( + [TextField([Token(text)]) for text in input_dict["translation"]] + ) + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + for field_key in fields_to_be_added: + fields[field_key] = fields_to_be_added[field_key] + + return Instance(fields) diff --git a/setup.py b/setup.py index bd0e7e0f6cd..13587bcd70e 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ "wandb>=0.10.0,<0.11.0", "huggingface_hub>=0.0.8", "google-cloud-storage>=1.38.0,<1.39.0", + "datasets>=1.5.0,<1.6.0", ], entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]}, include_package_data=True, diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py new file mode 100644 index 00000000000..14657bee392 --- /dev/null +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -0,0 +1,184 @@ +import pytest +from allennlp.data import Tokenizer + +from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader +from allennlp.data.tokenizers import WhitespaceTokenizer + + +# TODO Add test where we compare huggingface wrapped reader with an explicitly built dataset +# TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it +# the way tested for conll2003 specific reader +class HuggingfaceDatasetReaderTest: + + """ + Test read for some lightweight datasets + """ + + @pytest.mark.parametrize( + "dataset, config, split", + (("glue", "cola", "train"), ("glue", "cola", "test")), + ) + def test_read(self, dataset, config, split): + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + def test_read_unsupported_sequence_nesting(self): + dataset = "diplomacy_detection" + split = "train" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) + with pytest.raises(ValueError): + next(huggingface_reader.read(split)) + + def test_read_with_tokenizer(self): + dataset = "glue" + config = "cola" + split = "train" + tokenizer: Tokenizer = WhitespaceTokenizer() + huggingface_reader = HuggingfaceDatasetReader( + dataset_name=dataset, config_name=config, tokenizer=tokenizer + ) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + # Confirm it was tokenized + assert len(instance["sentence"]) > 1 + + def test_read_without_config(self): + dataset = "urdu_fake_news" + split = "train" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + def test_read_with_preload(self): + dataset = "glue" + config = "cola" + split = "train" + tokenizer: Tokenizer = WhitespaceTokenizer() + huggingface_reader = HuggingfaceDatasetReader( + dataset_name=dataset, config_name=config, tokenizer=tokenizer, preload=True + ) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + # Confirm it was tokenized + assert len(instance["sentence"]) > 1 + + """ + Test mapping of the datasets.feature.Translation and datasets.feature.TranslationVariableLanguages + """ + + def test_read_xnli_all_languages(self): + dataset = "xnli" + config = "all_languages" + split = "validation" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + instance = instances[0] + # We are splitting datasets.features.Translation and + # datasets.features.TranslationVariableLanguages into two fields each + # For XNLI that means 3 fields become 5 + assert len(instance.fields) == 5 + + def test_non_supported_feature(self): + dataset = "pubmed_qa" + config = "pqa_labeled" + split = "train" + with pytest.raises(ValueError): + next(HuggingfaceDatasetReader(dataset_name=dataset, config_name=config).read(split)) + + def test_non_available_dataset(self): + with pytest.raises(ValueError): + HuggingfaceDatasetReader(dataset_name="surely-such-a-dataset-does-not-exist") + + @pytest.mark.parametrize("split", (None, "surely-such-a-split-does-not-exist")) + def test_read_with_invalid_split(self, split): + with pytest.raises(ValueError): + next(HuggingfaceDatasetReader(dataset_name="glue", config_name="cola").read(split)) + + """ + Test to help validate for the known supported datasets + Skipped by default, enable when required + """ + + @pytest.mark.skip() + @pytest.mark.parametrize( + "dataset, config, split", + ( + ("xnli", "ar", "train"), + ("xnli", "en", "train"), + ("xnli", "de", "train"), + ("glue", "mrpc", "train"), + ("glue", "sst2", "train"), + ("glue", "qqp", "train"), + ("glue", "mnli", "train"), + ("glue", "mnli_matched", "validation"), + ("universal_dependencies", "en_lines", "train"), + ("universal_dependencies", "ko_kaist", "train"), + ("universal_dependencies", "af_afribooms", "train"), + ), + ) + def test_read_known_supported_datasets_with_config(self, dataset, config, split): + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset, config_name=config) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) + + """ + Test to help validate for the known supported datasets without config + Skipped by default, enable when required + """ + + @pytest.mark.skip() + @pytest.mark.parametrize( + "dataset", (("swahili"), ("conll2003"), ("dbpedia_14"), ("trec"), ("emotion")) + ) + def test_read_known_supported_datasets_without_config(self, dataset): + split = "train" + huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) + instances = list(huggingface_reader.read(split)) + # Confirm instance were made for all rows + assert len(instances) == len(huggingface_reader.dataset[split]) + + entry = huggingface_reader.dataset[split][0] + instance = instances[0] + + # Confirm all features were mapped + assert len(instance.fields) == len(entry) From 8370803dad04bc487d48bad335b5403090d955a4 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 9 May 2021 19:55:19 +0530 Subject: [PATCH 02/63] Move mapping to funcs, remove preload support --- .../huggingface_datasets_reader.py | 242 ++++++++++-------- .../huggingface_datasets_reader_test.py | 23 +- 2 files changed, 130 insertions(+), 135 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index af030a377f9..5c9d3756982 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,6 +1,7 @@ import typing from typing import Iterable, Optional +import datasets from allennlp.data import DatasetReader, Token, Field, Tokenizer from allennlp.data.fields import TextField, LabelField, ListField from allennlp.data.instance import Instance @@ -61,7 +62,6 @@ def __init__( self, dataset_name: str = None, config_name: Optional[str] = None, - preload: Optional[bool] = False, tokenizer: Optional[Tokenizer] = None, **kwargs, ) -> None: @@ -79,15 +79,6 @@ def __init__( self.config_name = config_name self.tokenizer = tokenizer - if preload: - self.load_dataset() - - def load_dataset(self): - if self.config_name is not None: - self.dataset = load_dataset(self.dataset_name, self.config_name) - else: - self.dataset = load_dataset(self.dataset_name) - def load_dataset_split(self, split: str): # TODO add support for datasets.split.NamedSplit if split in self.SUPPORTED_SPLITS: @@ -115,7 +106,7 @@ def _read(self, file_path: str) -> Iterable[Instance]: for entry in self.shard_iterable(self.dataset[file_path]): yield self.text_to_instance(file_path, entry) - def raise_feature_not_supported_value_error(self, value): + def raise_feature_not_supported_value_error(value): raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") def text_to_instance(self, *inputs) -> Instance: @@ -145,7 +136,7 @@ def text_to_instance(self, *inputs) -> Instance: # and another indicate the sentiment (of typeint32/ClassLabel) split = inputs[0] - features = self.dataset[split].features + features: typing.List[datasets.features.Feature] = self.dataset[split].features fields = dict() # TODO we need to support all different datasets features described @@ -158,114 +149,22 @@ def text_to_instance(self, *inputs) -> Instance: # datasets ClassLabel maps to LabelField if isinstance(value, ClassLabel): - fields_to_be_added[feature] = LabelField( - inputs[1][feature], label_namespace=feature, skip_indexing=True - ) + fields_to_be_added = map_ClassLabel(feature, inputs[1]) # datasets Value can be of different types elif isinstance(value, Value): - - # String value maps to TextField - if value.dtype == "string": - # datasets.Value[string] maps to TextField - # If tokenizer is provided we will use it to split it to tokens - # Else put whole text as a single token - if self.tokenizer is not None: - fields_to_be_added[feature] = TextField( - self.tokenizer.tokenize(inputs[1][feature]) - ) - - else: - fields_to_be_added[feature] = TextField([Token(inputs[1][feature])]) - - else: - fields_to_be_added[feature] = LabelField( - inputs[1][feature], label_namespace=feature, skip_indexing=True - ) + fields_to_be_added = map_Value(feature, inputs[1], value, self.tokenizer) elif isinstance(value, Sequence): - # We do not know if the string is token or text, we will assume text and make each a TextField - # datasets.features.Sequence of strings maps to ListField of TextField - if hasattr(value.feature, "dtype") and value.feature.dtype == "string": - field_list2: typing.List[TextField] = list() - for item in inputs[1][feature]: - # If tokenizer is provided we will use it to split it to tokens - # Else put whole text as a single token - tokens: typing.List[Token] - if self.tokenizer is not None: - tokens = self.tokenizer.tokenize(item) - - else: - tokens = [Token(item)] - - item_field = TextField(tokens) - field_list2.append(item_field) - - fields_to_be_added[feature] = ListField(field_list2) - - # datasets Sequence of strings to ListField of LabelField - elif isinstance(value.feature, ClassLabel): - field_list = list() - for item in inputs[1][feature]: - item_field = LabelField( - label=item, label_namespace=feature, skip_indexing=True - ) - field_list.append(item_field) - - fields_to_be_added[feature] = ListField(field_list) - - else: - self.raise_feature_not_supported_value_error(value) - - # datasets.Translation cannot be mapped directly - # but it's dict structure can be mapped to a ListField of 2 ListField + fields_to_be_added = map_Sequence(feature, inputs[1], value, self.tokenizer) + elif isinstance(value, Translation): - if value.dtype == "dict": - input_dict = inputs[1][feature] - langs = list(input_dict.keys()) - texts = list() - for lang in langs: - if self.tokenizer is not None: - tokens = self.tokenizer.tokenize(input_dict[lang]) - - else: - tokens = [Token(input_dict[lang])] - texts.append(TextField(tokens)) - - fields_to_be_added[feature + "-languages"] = ListField( - [LabelField(lang, label_namespace="languages") for lang in langs] - ) - fields_to_be_added[feature + "-texts"] = ListField(texts) - - else: - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") - - # datasets.TranslationVariableLanguages - # is functionally a pair of Lists and hence mapped to a ListField of 2 ListField + fields_to_be_added = map_Translation(feature, inputs[1], value, self.tokenizer) + elif isinstance(value, TranslationVariableLanguages): - if value.dtype == "dict": - input_dict = inputs[1][feature] - fields_to_be_added[feature + "-language"] = ListField( - [ - LabelField(lang, label_namespace=feature + "-language") - for lang in input_dict["language"] - ] - ) - - if self.tokenizer is not None: - fields_to_be_added[feature + "-translation"] = ListField( - [ - TextField(self.tokenizer.tokenize(text)) - for text in input_dict["translation"] - ] - ) - else: - fields_to_be_added[feature + "-translation"] = ListField( - [TextField([Token(text)]) for text in input_dict["translation"]] - ) - - else: - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + fields_to_be_added = map_TranslationVariableLanguages( + feature, inputs[1], value, self.tokenizer + ) else: raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") @@ -274,3 +173,120 @@ def text_to_instance(self, *inputs) -> Instance: fields[field_key] = fields_to_be_added[field_key] return Instance(fields) + + +def map_ClassLabel(feature: str, entry: typing.Dict) -> typing.Dict[str, Field]: + fields: typing.Dict[str, Field] = dict() + fields[feature] = LabelField(entry[feature], label_namespace=feature, skip_indexing=True) + return fields + + +def map_Value( + feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] +) -> typing.Dict[str, Field]: + fields: typing.Dict[str, Field] = dict() + if value.dtype == "string": + # datasets.Value[string] maps to TextField + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + if tokenizer is not None: + fields[feature] = TextField(tokenizer.tokenize(entry[feature])) + + else: + fields[feature] = TextField([Token(entry[feature])]) + + else: + fields[feature] = LabelField(entry[feature], label_namespace=feature, skip_indexing=True) + return fields + + +def map_Sequence( + feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] +) -> typing.Dict[str, Field]: + item_field: typing.Union[LabelField, TextField] + fields: typing.Dict[str, Field] = dict() + if hasattr(value.feature, "dtype") and value.feature.dtype == "string": + field_list2: typing.List[TextField] = list() + for item in entry[feature]: + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + tokens: typing.List[Token] + if tokenizer is not None: + tokens = tokenizer.tokenize(item) + + else: + tokens = [Token(item)] + + item_field = TextField(tokens) + field_list2.append(item_field) + + fields[feature] = ListField(field_list2) + + # datasets Sequence of strings to ListField of LabelField + elif isinstance(value.feature, ClassLabel): + field_list = list() + for item in entry[feature]: + item_field = LabelField(label=item, label_namespace=feature, skip_indexing=True) + field_list.append(item_field) + + fields[feature] = ListField(field_list) + + else: + HuggingfaceDatasetReader.raise_feature_not_supported_value_error(value) + + return fields + + +def map_Translation( + feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] +) -> typing.Dict[str, Field]: + fields: typing.Dict[str, Field] = dict() + if value.dtype == "dict": + input_dict = entry[feature] + langs = list(input_dict.keys()) + texts = list() + for lang in langs: + if tokenizer is not None: + tokens = tokenizer.tokenize(input_dict[lang]) + + else: + tokens = [Token(input_dict[lang])] + texts.append(TextField(tokens)) + + fields[feature + "-languages"] = ListField( + [LabelField(lang, label_namespace="languages") for lang in langs] + ) + fields[feature + "-texts"] = ListField(texts) + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + return fields + + +def map_TranslationVariableLanguages( + feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] +) -> typing.Dict[str, Field]: + fields: typing.Dict[str, Field] = dict() + if value.dtype == "dict": + input_dict = entry[feature] + fields[feature + "-language"] = ListField( + [ + LabelField(lang, label_namespace=feature + "-language") + for lang in input_dict["language"] + ] + ) + + if tokenizer is not None: + fields[feature + "-translation"] = ListField( + [TextField(tokenizer.tokenize(text)) for text in input_dict["translation"]] + ) + else: + fields[feature + "-translation"] = ListField( + [TextField([Token(text)]) for text in input_dict["translation"]] + ) + + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + + return fields diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 14657bee392..84b213ba8d7 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -5,7 +5,7 @@ from allennlp.data.tokenizers import WhitespaceTokenizer -# TODO Add test where we compare huggingface wrapped reader with an explicitly built dataset +# TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset # TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it # the way tested for conll2003 specific reader class HuggingfaceDatasetReaderTest: @@ -72,27 +72,6 @@ def test_read_without_config(self): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_read_with_preload(self): - dataset = "glue" - config = "cola" - split = "train" - tokenizer: Tokenizer = WhitespaceTokenizer() - huggingface_reader = HuggingfaceDatasetReader( - dataset_name=dataset, config_name=config, tokenizer=tokenizer, preload=True - ) - instances = list(huggingface_reader.read(split)) - # Confirm instance were made for all rows - assert len(instances) == len(huggingface_reader.dataset[split]) - - entry = huggingface_reader.dataset[split][0] - instance = instances[0] - - # Confirm all features were mapped - assert len(instance.fields) == len(entry) - - # Confirm it was tokenized - assert len(instance["sentence"]) > 1 - """ Test mapping of the datasets.feature.Translation and datasets.feature.TranslationVariableLanguages """ From d5b8f3f5aae6197a75cb69f01143cd237e22cbfd Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 9 May 2021 23:28:42 +0530 Subject: [PATCH 03/63] Support for Sequence Nesting --- .../huggingface_datasets_reader.py | 150 +++++++++--------- .../huggingface_datasets_reader_test.py | 6 +- 2 files changed, 82 insertions(+), 74 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 5c9d3756982..7f656025b51 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,13 +1,16 @@ -import typing -from typing import Iterable, Optional - -import datasets from allennlp.data import DatasetReader, Token, Field, Tokenizer from allennlp.data.fields import TextField, LabelField, ListField from allennlp.data.instance import Instance from datasets import load_dataset, DatasetDict, Split, list_datasets -from datasets.features import ClassLabel, Sequence, Translation, TranslationVariableLanguages -from datasets.features import Value +from datasets.features import ( + ClassLabel, + Sequence, + Translation, + TranslationVariableLanguages, + Value, + FeatureType, +) +from typing import Iterable, Optional, Dict, List, Union @DatasetReader.register("huggingface-datasets") @@ -43,7 +46,6 @@ class HuggingfaceDatasetReader(DatasetReader): Registered as a `DatasetReader` with name `huggingface-datasets` # Parameters - dataset_name : `str` Name of the dataset from huggingface datasets the reader will be used for. config_name : `str`, optional (default=`None`) @@ -136,100 +138,92 @@ def text_to_instance(self, *inputs) -> Instance: # and another indicate the sentiment (of typeint32/ClassLabel) split = inputs[0] - features: typing.List[datasets.features.Feature] = self.dataset[split].features - fields = dict() + features: Dict[str, FeatureType] = self.dataset[split].features + fields: Dict[str, Field] = dict() # TODO we need to support all different datasets features described # in https://huggingface.co/docs/datasets/features.html for feature in features: - fields_to_be_added: typing.Dict[str, Field] = dict() + fields_to_be_added: Dict[str, Field] = dict() item_field: Field field_list: list value = features[feature] - # datasets ClassLabel maps to LabelField - if isinstance(value, ClassLabel): - fields_to_be_added = map_ClassLabel(feature, inputs[1]) + fields_to_be_added = map_Feature(feature, inputs[1], value, self.tokenizer) + for field_key in fields_to_be_added: + fields[field_key] = fields_to_be_added[field_key] - # datasets Value can be of different types - elif isinstance(value, Value): - fields_to_be_added = map_Value(feature, inputs[1], value, self.tokenizer) + return Instance(fields) - elif isinstance(value, Sequence): - fields_to_be_added = map_Sequence(feature, inputs[1], value, self.tokenizer) - elif isinstance(value, Translation): - fields_to_be_added = map_Translation(feature, inputs[1], value, self.tokenizer) +# Feature Mappers - These functions map a FeatureType into Fields +def map_Feature( + feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + fields_to_be_added: Dict[str, Field] = dict() + if isinstance(value, ClassLabel): + fields_to_be_added[feature] = map_ClassLabel(feature, entry[feature]) + # datasets Value can be of different types + elif isinstance(value, Value): + fields_to_be_added[feature] = map_Value(feature, entry[feature], value, tokenizer) - elif isinstance(value, TranslationVariableLanguages): - fields_to_be_added = map_TranslationVariableLanguages( - feature, inputs[1], value, self.tokenizer - ) + elif isinstance(value, Sequence): + fields_to_be_added = map_Sequence(feature, entry, value, tokenizer) - else: - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + elif isinstance(value, Translation): + fields_to_be_added = map_Translation(feature, entry, value, tokenizer) - for field_key in fields_to_be_added: - fields[field_key] = fields_to_be_added[field_key] + elif isinstance(value, TranslationVariableLanguages): + fields_to_be_added = map_TranslationVariableLanguages(feature, entry, value, tokenizer) - return Instance(fields) + else: + raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + return fields_to_be_added -def map_ClassLabel(feature: str, entry: typing.Dict) -> typing.Dict[str, Field]: - fields: typing.Dict[str, Field] = dict() - fields[feature] = LabelField(entry[feature], label_namespace=feature, skip_indexing=True) - return fields +def map_ClassLabel(feature: str, entry: Dict) -> Field: + field: Field = map_to_Label(feature, entry, skip_indexing=True) + return field def map_Value( - feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] -) -> typing.Dict[str, Field]: - fields: typing.Dict[str, Field] = dict() + feature: str, item: Value, value, tokenizer: Optional[Tokenizer] +) -> Union[TextField, LabelField]: + field: Union[TextField, LabelField] if value.dtype == "string": # datasets.Value[string] maps to TextField # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - if tokenizer is not None: - fields[feature] = TextField(tokenizer.tokenize(entry[feature])) - - else: - fields[feature] = TextField([Token(entry[feature])]) + field = map_String(feature, item, None, tokenizer) else: - fields[feature] = LabelField(entry[feature], label_namespace=feature, skip_indexing=True) - return fields + field = LabelField(item, label_namespace=feature, skip_indexing=True) + return field def map_Sequence( - feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] -) -> typing.Dict[str, Field]: - item_field: typing.Union[LabelField, TextField] - fields: typing.Dict[str, Field] = dict() - if hasattr(value.feature, "dtype") and value.feature.dtype == "string": - field_list2: typing.List[TextField] = list() + feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + item_field: Union[LabelField, TextField] + field_list: List[Union[TextField, LabelField]] = list() + fields: Dict[str, Field] = dict() + if isinstance(value.feature, Value): for item in entry[feature]: # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - tokens: typing.List[Token] - if tokenizer is not None: - tokens = tokenizer.tokenize(item) - - else: - tokens = [Token(item)] - - item_field = TextField(tokens) - field_list2.append(item_field) - - fields[feature] = ListField(field_list2) + item_field = map_Value(feature, item, value.feature, tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + fields[feature] = ListField(field_list) # datasets Sequence of strings to ListField of LabelField elif isinstance(value.feature, ClassLabel): - field_list = list() for item in entry[feature]: - item_field = LabelField(label=item, label_namespace=feature, skip_indexing=True) + item_field = map_to_Label(feature, item, skip_indexing=True) field_list.append(item_field) - fields[feature] = ListField(field_list) + if len(field_list) > 0: + fields[feature] = ListField(field_list) else: HuggingfaceDatasetReader.raise_feature_not_supported_value_error(value) @@ -238,9 +232,9 @@ def map_Sequence( def map_Translation( - feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] -) -> typing.Dict[str, Field]: - fields: typing.Dict[str, Field] = dict() + feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() if value.dtype == "dict": input_dict = entry[feature] langs = list(input_dict.keys()) @@ -254,7 +248,7 @@ def map_Translation( texts.append(TextField(tokens)) fields[feature + "-languages"] = ListField( - [LabelField(lang, label_namespace="languages") for lang in langs] + [map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in langs] ) fields[feature + "-texts"] = ListField(texts) @@ -265,14 +259,14 @@ def map_Translation( def map_TranslationVariableLanguages( - feature: str, entry: typing.Dict, value, tokenizer: Optional[Tokenizer] -) -> typing.Dict[str, Field]: - fields: typing.Dict[str, Field] = dict() + feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() if value.dtype == "dict": input_dict = entry[feature] fields[feature + "-language"] = ListField( [ - LabelField(lang, label_namespace=feature + "-language") + map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in input_dict["language"] ] ) @@ -290,3 +284,17 @@ def map_TranslationVariableLanguages( raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") return fields + + +# Value mapper - Maps a single Value +def map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) -> TextField: + field: TextField + if tokenizer is not None: + field = TextField(tokenizer.tokenize(text)) + else: + field = TextField([Token(text)]) + return field + + +def map_to_Label(namespace, item, skip_indexing=True) -> LabelField: + return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 84b213ba8d7..235471aab60 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -30,12 +30,12 @@ def test_read(self, dataset, config, split): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_read_unsupported_sequence_nesting(self): + def test_read_sequence_nesting(self): dataset = "diplomacy_detection" split = "train" huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) - with pytest.raises(ValueError): - next(huggingface_reader.read(split)) + instances = list(huggingface_reader.read(split)) + assert len(instances) == len(huggingface_reader.dataset[split]) def test_read_with_tokenizer(self): dataset = "glue" From 49fa0bc3bc1e116be631ed0eadc9271c54f086da Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 9 May 2021 23:48:29 +0530 Subject: [PATCH 04/63] Misc Fixes --- .../huggingface_datasets_reader.py | 52 ++++--------------- 1 file changed, 11 insertions(+), 41 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 7f656025b51..7e6a93fee74 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -20,29 +20,6 @@ class HuggingfaceDatasetReader(DatasetReader): This reader implementation wraps the huggingface datasets package - Following dataset and configurations have been verified and work with this reader - - Dataset Dataset Configuration - `xnli` `ar` - `xnli` `en` - `xnli` `de` - `xnli` `all_languages` - `glue` `cola` - `glue` `mrpc` - `glue` `sst2` - `glue` `qqp` - `glue` `mnli` - `glue` `mnli_matched` - `universal_dependencies` `en_lines` - `universal_dependencies` `ko_kaist` - `universal_dependencies` `af_afribooms` - `swahili` `NA` - `conll2003` `NA` - `dbpedia_14` `NA` - `trec` `NA` - `emotion` `NA` - Note: universal_dependencies will require you to install `conllu` package separately - Registered as a `DatasetReader` with name `huggingface-datasets` # Parameters @@ -50,9 +27,6 @@ class HuggingfaceDatasetReader(DatasetReader): Name of the dataset from huggingface datasets the reader will be used for. config_name : `str`, optional (default=`None`) Configuration(mandatory for some datasets) of the dataset. - preload : `bool`, optional (default=`False`) - If `True` all splits for the dataset is loaded(includes download etc) as part of the initialization, - otherwise each split is loaded on when `read()` is used for the same for the first time. tokenizer : `Tokenizer`, optional (default=`None`) If specified is used for tokenization of string and text fields from the dataset. This is useful since text in allennlp is dealt with as a series of tokens. @@ -105,8 +79,10 @@ def _read(self, file_path: str) -> Iterable[Instance]: self.load_dataset_split(file_path) # TODO see if use of Dataset.select() is better - for entry in self.shard_iterable(self.dataset[file_path]): - yield self.text_to_instance(file_path, entry) + dataset_split = self.dataset[file_path] + for index in self.shard_iterable(range(len(dataset_split))): + yield self.text_to_instance(file_path, dataset_split[index]) + def raise_feature_not_supported_value_error(value): raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") @@ -114,22 +90,16 @@ def raise_feature_not_supported_value_error(value): def text_to_instance(self, *inputs) -> Instance: """ Takes care of converting dataset entry into AllenNLP friendly instance - Currently it is implemented in an unseemly catch-up model - where it converts datasets.features that are required for the supported dataset, - ideally it would require design where we cleanly deliberate, decide - map dataset.feature to an allenlp.data.field and then go ahead with converting it - Doing that would provide the best chance of providing largest possible coverage with datasets Currently this is how datasets.features types are mapped to AllenNLP Fields - dataset.feature type allennlp.data.fields - `ClassLabel` `LabelField` in feature name namespace - `Value.string` `TextField` with value as Token - `Value.*` `LabelField` with value being label in feature name namespace - `Sequence.string` `ListField` of `TextField` with individual string as token - `Sequence.ClassLabel` `ListField` of `ClassLabel` in feature name namespace - `Translation` `ListField` of 2 ListField (ClassLabel and TextField) - `TranslationVariableLanguages` `ListField` of 2 ListField (ClassLabel and TextField) + dataset.feature type allennlp.data.fields + `ClassLabel` `LabelField` in feature name namespace + `Value.string` `TextField` with value as Token + `Value.*` `LabelField` with value being label in feature name namespace + `Translation` `ListField` of 2 ListField (ClassLabel and TextField) + `TranslationVariableLanguages` `ListField` of 2 ListField (ClassLabel and TextField) + `Sequence` `ListField` of sub-types """ # features indicate the different information available in each entry from dataset From 17cd4acadd6323fdf41ca1e04e81985e7495055e Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Mon, 10 May 2021 00:07:53 +0530 Subject: [PATCH 05/63] Misc check --- allennlp/data/dataset_readers/huggingface_datasets_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 7e6a93fee74..e5709184f99 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -83,7 +83,6 @@ def _read(self, file_path: str) -> Iterable[Instance]: for index in self.shard_iterable(range(len(dataset_split))): yield self.text_to_instance(file_path, dataset_split[index]) - def raise_feature_not_supported_value_error(value): raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") From 5159f690b92c619d3164f685229ea00474e622a7 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 11 May 2021 00:00:06 +0530 Subject: [PATCH 06/63] Comments --- allennlp/data/dataset_readers/huggingface_datasets_reader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index e5709184f99..f3a70ae5ed4 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -255,7 +255,7 @@ def map_TranslationVariableLanguages( return fields -# Value mapper - Maps a single Value +# value mapper - Maps a single text value to TextField def map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) -> TextField: field: TextField if tokenizer is not None: @@ -265,5 +265,6 @@ def map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) - return field +# value mapper - Maps a single value to a LabelField def map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) From 7155a32a1f876cd3fcbe6d6e9fa480a3a1cd47f2 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 11 May 2021 21:13:18 +0530 Subject: [PATCH 07/63] map funcs _ prefix --- .../huggingface_datasets_reader.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index f3a70ae5ed4..4a07a90d5d9 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -118,7 +118,7 @@ def text_to_instance(self, *inputs) -> Instance: field_list: list value = features[feature] - fields_to_be_added = map_Feature(feature, inputs[1], value, self.tokenizer) + fields_to_be_added = _map_Feature(feature, inputs[1], value, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -126,36 +126,36 @@ def text_to_instance(self, *inputs) -> Instance: # Feature Mappers - These functions map a FeatureType into Fields -def map_Feature( +def _map_Feature( feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields_to_be_added: Dict[str, Field] = dict() if isinstance(value, ClassLabel): - fields_to_be_added[feature] = map_ClassLabel(feature, entry[feature]) + fields_to_be_added[feature] = _map_ClassLabel(feature, entry[feature]) # datasets Value can be of different types elif isinstance(value, Value): - fields_to_be_added[feature] = map_Value(feature, entry[feature], value, tokenizer) + fields_to_be_added[feature] = _map_Value(feature, entry[feature], value, tokenizer) elif isinstance(value, Sequence): - fields_to_be_added = map_Sequence(feature, entry, value, tokenizer) + fields_to_be_added = _map_Sequence(feature, entry, value, tokenizer) elif isinstance(value, Translation): - fields_to_be_added = map_Translation(feature, entry, value, tokenizer) + fields_to_be_added = _map_Translation(feature, entry, value, tokenizer) elif isinstance(value, TranslationVariableLanguages): - fields_to_be_added = map_TranslationVariableLanguages(feature, entry, value, tokenizer) + fields_to_be_added = _map_TranslationVariableLanguages(feature, entry, value, tokenizer) else: raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") return fields_to_be_added -def map_ClassLabel(feature: str, entry: Dict) -> Field: - field: Field = map_to_Label(feature, entry, skip_indexing=True) +def _map_ClassLabel(feature: str, entry: Dict) -> Field: + field: Field = _map_to_Label(feature, entry, skip_indexing=True) return field -def map_Value( +def _map_Value( feature: str, item: Value, value, tokenizer: Optional[Tokenizer] ) -> Union[TextField, LabelField]: field: Union[TextField, LabelField] @@ -163,14 +163,14 @@ def map_Value( # datasets.Value[string] maps to TextField # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - field = map_String(feature, item, None, tokenizer) + field = _map_String(feature, item, None, tokenizer) else: field = LabelField(item, label_namespace=feature, skip_indexing=True) return field -def map_Sequence( +def _map_Sequence( feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: item_field: Union[LabelField, TextField] @@ -180,7 +180,7 @@ def map_Sequence( for item in entry[feature]: # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - item_field = map_Value(feature, item, value.feature, tokenizer) + item_field = _map_Value(feature, item, value.feature, tokenizer) field_list.append(item_field) if len(field_list) > 0: fields[feature] = ListField(field_list) @@ -188,7 +188,7 @@ def map_Sequence( # datasets Sequence of strings to ListField of LabelField elif isinstance(value.feature, ClassLabel): for item in entry[feature]: - item_field = map_to_Label(feature, item, skip_indexing=True) + item_field = _map_to_Label(feature, item, skip_indexing=True) field_list.append(item_field) if len(field_list) > 0: @@ -200,7 +200,7 @@ def map_Sequence( return fields -def map_Translation( +def _map_Translation( feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() @@ -217,7 +217,7 @@ def map_Translation( texts.append(TextField(tokens)) fields[feature + "-languages"] = ListField( - [map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in langs] + [_map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in langs] ) fields[feature + "-texts"] = ListField(texts) @@ -227,7 +227,7 @@ def map_Translation( return fields -def map_TranslationVariableLanguages( +def _map_TranslationVariableLanguages( feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() @@ -235,7 +235,7 @@ def map_TranslationVariableLanguages( input_dict = entry[feature] fields[feature + "-language"] = ListField( [ - map_to_Label(feature + "-languages", lang, skip_indexing=False) + _map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in input_dict["language"] ] ) @@ -256,7 +256,7 @@ def map_TranslationVariableLanguages( # value mapper - Maps a single text value to TextField -def map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) -> TextField: +def _map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) -> TextField: field: TextField if tokenizer is not None: field = TextField(tokenizer.tokenize(text)) @@ -266,5 +266,5 @@ def map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) - # value mapper - Maps a single value to a LabelField -def map_to_Label(namespace, item, skip_indexing=True) -> LabelField: +def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) From eb4b573ebe37523d59a9f7699f16dfa9ffb4d0d5 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Wed, 12 May 2021 22:58:17 +0530 Subject: [PATCH 08/63] Parameters rename and cleanup --- .../huggingface_datasets_reader.py | 114 +++++++++--------- .../huggingface_datasets_reader_test.py | 21 ++-- 2 files changed, 72 insertions(+), 63 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 4a07a90d5d9..2f285bf316f 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -104,7 +104,7 @@ def text_to_instance(self, *inputs) -> Instance: # features indicate the different information available in each entry from dataset # feature types decide what type of information they are # e.g. In a Sentiment dataset an entry could have one feature (of type text/string) indicating the text - # and another indicate the sentiment (of typeint32/ClassLabel) + # and another indicate the sentiment (of type int32/ClassLabel) split = inputs[0] features: Dict[str, FeatureType] = self.dataset[split].features @@ -112,13 +112,12 @@ def text_to_instance(self, *inputs) -> Instance: # TODO we need to support all different datasets features described # in https://huggingface.co/docs/datasets/features.html - for feature in features: - fields_to_be_added: Dict[str, Field] = dict() + for feature_name in features: item_field: Field field_list: list - value = features[feature] + feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature, inputs[1], value, self.tokenizer) + fields_to_be_added = _map_Feature(feature_name, inputs[1], feature_type, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -127,85 +126,88 @@ def text_to_instance(self, *inputs) -> Instance: # Feature Mappers - These functions map a FeatureType into Fields def _map_Feature( - feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] + feature_name: str, entry: Dict, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields_to_be_added: Dict[str, Field] = dict() - if isinstance(value, ClassLabel): - fields_to_be_added[feature] = _map_ClassLabel(feature, entry[feature]) + if isinstance(feature_type, ClassLabel): + fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, entry[feature_name]) # datasets Value can be of different types - elif isinstance(value, Value): - fields_to_be_added[feature] = _map_Value(feature, entry[feature], value, tokenizer) + elif isinstance(feature_type, Value): + fields_to_be_added[feature_name] = _map_Value(feature_name, entry[feature_name], feature_type, tokenizer) - elif isinstance(value, Sequence): - fields_to_be_added = _map_Sequence(feature, entry, value, tokenizer) + elif isinstance(feature_type, Sequence): + fields_to_be_added[feature_name] = _map_Sequence(feature_name, entry, feature_type.feature, tokenizer) - elif isinstance(value, Translation): - fields_to_be_added = _map_Translation(feature, entry, value, tokenizer) + elif isinstance(feature_type, Translation): + fields_to_be_added = _map_Translation(feature_name, entry[feature_name], feature_type, tokenizer) - elif isinstance(value, TranslationVariableLanguages): - fields_to_be_added = _map_TranslationVariableLanguages(feature, entry, value, tokenizer) + elif isinstance(feature_type, TranslationVariableLanguages): + fields_to_be_added = _map_TranslationVariableLanguages(feature_name, entry[feature_name], feature_type, tokenizer) else: - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") return fields_to_be_added -def _map_ClassLabel(feature: str, entry: Dict) -> Field: - field: Field = _map_to_Label(feature, entry, skip_indexing=True) +def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field: + field: Field = _map_to_Label(feature_name, value, skip_indexing=True) return field def _map_Value( - feature: str, item: Value, value, tokenizer: Optional[Tokenizer] + feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer] ) -> Union[TextField, LabelField]: field: Union[TextField, LabelField] - if value.dtype == "string": + if feature_type.dtype == "string": # datasets.Value[string] maps to TextField # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - field = _map_String(feature, item, None, tokenizer) + field = _map_String(value, tokenizer) else: - field = LabelField(item, label_namespace=feature, skip_indexing=True) + field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field - -def _map_Sequence( - feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] -) -> Dict[str, Field]: - item_field: Union[LabelField, TextField] - field_list: List[Union[TextField, LabelField]] = list() - fields: Dict[str, Field] = dict() - if isinstance(value.feature, Value): - for item in entry[feature]: +def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Optional[Tokenizer]) -> Field: + field_list: List[Field] = list() + field: ListField = None + if isinstance(item_feature_type, Value): + for item in value: # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - item_field = _map_Value(feature, item, value.feature, tokenizer) + item_field = _map_Value(value.feature, item, item.value, tokenizer) field_list.append(item_field) if len(field_list) > 0: - fields[feature] = ListField(field_list) + field = ListField(field_list) # datasets Sequence of strings to ListField of LabelField - elif isinstance(value.feature, ClassLabel): - for item in entry[feature]: - item_field = _map_to_Label(feature, item, skip_indexing=True) + elif isinstance(item_feature_type, ClassLabel): + for item in value: + item_field = _map_to_Label(value.feature, item, skip_indexing=True) field_list.append(item_field) if len(field_list) > 0: - fields[feature] = ListField(field_list) + field = ListField(field_list) - else: - HuggingfaceDatasetReader.raise_feature_not_supported_value_error(value) + elif isinstance(item_feature_type, Sequence): + for item in value: + item_field = _map_Sequence(value.feature, item, tokenizer) + field_list.append(item_field) - return fields + if len(field_list) > 0: + field = ListField(field_list) + + else: + HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name) + return field def _map_Translation( - feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] + feature_name: str, value: Translation, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() - if value.dtype == "dict": - input_dict = entry[feature] + if feature_type.dtype == "dict": + input_dict = value langs = list(input_dict.keys()) texts = list() for lang in langs: @@ -216,36 +218,36 @@ def _map_Translation( tokens = [Token(input_dict[lang])] texts.append(TextField(tokens)) - fields[feature + "-languages"] = ListField( - [_map_to_Label(feature + "-languages", lang, skip_indexing=False) for lang in langs] + fields[feature_name + "-languages"] = ListField( + [_map_to_Label(feature_name + "-languages", lang, skip_indexing=False) for lang in langs] ) - fields[feature + "-texts"] = ListField(texts) + fields[feature_name + "-texts"] = ListField(texts) else: - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") return fields def _map_TranslationVariableLanguages( - feature: str, entry: Dict, value, tokenizer: Optional[Tokenizer] + feature_name: str, value: TranslationVariableLanguages, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() - if value.dtype == "dict": - input_dict = entry[feature] - fields[feature + "-language"] = ListField( + if feature_type.dtype == "dict": + input_dict = value + fields[feature_name + "-language"] = ListField( [ - _map_to_Label(feature + "-languages", lang, skip_indexing=False) + _map_to_Label(feature_name + "-languages", lang, skip_indexing=False) for lang in input_dict["language"] ] ) if tokenizer is not None: - fields[feature + "-translation"] = ListField( + fields[feature_name + "-translation"] = ListField( [TextField(tokenizer.tokenize(text)) for text in input_dict["translation"]] ) else: - fields[feature + "-translation"] = ListField( + fields[feature_name + "-translation"] = ListField( [TextField([Token(text)]) for text in input_dict["translation"]] ) @@ -256,7 +258,7 @@ def _map_TranslationVariableLanguages( # value mapper - Maps a single text value to TextField -def _map_String(feature: str, text: str, value, tokenizer: Optional[Tokenizer]) -> TextField: +def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: field: TextField if tokenizer is not None: field = TextField(tokenizer.tokenize(text)) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 235471aab60..b261188d4df 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -8,6 +8,9 @@ # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset # TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it # the way tested for conll2003 specific reader +from datasets import list_datasets, load_dataset + + class HuggingfaceDatasetReaderTest: """ @@ -30,13 +33,6 @@ def test_read(self, dataset, config, split): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_read_sequence_nesting(self): - dataset = "diplomacy_detection" - split = "train" - huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) - instances = list(huggingface_reader.read(split)) - assert len(instances) == len(huggingface_reader.dataset[split]) - def test_read_with_tokenizer(self): dataset = "glue" config = "cola" @@ -161,3 +157,14 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) + + def test_load_all(self): + for dataset_name in list_datasets(): + try: + print("Dataset:", dataset_name) + reader = HuggingfaceDatasetReader(dataset_name) + reader.read() + except Exception as e: + print(e) + + From a9ef47540c1839ae5e6d6ce4547894faa244bcfc Mon Sep 17 00:00:00 2001 From: pab-vmware <80775579+pab-vmware@users.noreply.github.com> Date: Wed, 12 May 2021 23:14:34 +0530 Subject: [PATCH 09/63] Apply suggestions from code review by Dirk - comment text Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com> --- allennlp/data/dataset_readers/huggingface_datasets_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 2f285bf316f..99cabe40548 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -47,7 +47,7 @@ def __init__( **kwargs, ) - # It would be cleaner to create a separate reader object for diferent dataset + # It would be cleaner to create a separate reader object for each different dataset if dataset_name not in list_datasets(): raise ValueError(f"Dataset {dataset_name} not available in huggingface datasets") self.dataset: DatasetDict = DatasetDict() From 2610df84ce238e8a5fc9e4e11b549009c917edf5 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld <dirkg@allenai.org> Date: Wed, 19 May 2021 17:19:29 -0700 Subject: [PATCH 10/63] Formatting --- .../huggingface_datasets_reader.py | 35 ++++++++++++++----- .../huggingface_datasets_reader_test.py | 4 +-- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 99cabe40548..20e7b41b946 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -133,16 +133,24 @@ def _map_Feature( fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, entry[feature_name]) # datasets Value can be of different types elif isinstance(feature_type, Value): - fields_to_be_added[feature_name] = _map_Value(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added[feature_name] = _map_Value( + feature_name, entry[feature_name], feature_type, tokenizer + ) elif isinstance(feature_type, Sequence): - fields_to_be_added[feature_name] = _map_Sequence(feature_name, entry, feature_type.feature, tokenizer) + fields_to_be_added[feature_name] = _map_Sequence( + feature_name, entry, feature_type.feature, tokenizer + ) elif isinstance(feature_type, Translation): - fields_to_be_added = _map_Translation(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added = _map_Translation( + feature_name, entry[feature_name], feature_type, tokenizer + ) elif isinstance(feature_type, TranslationVariableLanguages): - fields_to_be_added = _map_TranslationVariableLanguages(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added = _map_TranslationVariableLanguages( + feature_name, entry[feature_name], feature_type, tokenizer + ) else: raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") @@ -163,12 +171,14 @@ def _map_Value( # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token field = _map_String(value, tokenizer) - else: field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field -def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Optional[Tokenizer]) -> Field: + +def _map_Sequence( + feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] +) -> Field: field_list: List[Field] = list() field: ListField = None if isinstance(item_feature_type, Value): @@ -178,7 +188,7 @@ def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Opt item_field = _map_Value(value.feature, item, item.value, tokenizer) field_list.append(item_field) if len(field_list) > 0: - field = ListField(field_list) + field = ListField(field_list) # datasets Sequence of strings to ListField of LabelField elif isinstance(item_feature_type, ClassLabel): @@ -202,6 +212,7 @@ def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Opt return field + def _map_Translation( feature_name: str, value: Translation, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: @@ -219,7 +230,10 @@ def _map_Translation( texts.append(TextField(tokens)) fields[feature_name + "-languages"] = ListField( - [_map_to_Label(feature_name + "-languages", lang, skip_indexing=False) for lang in langs] + [ + _map_to_Label(feature_name + "-languages", lang, skip_indexing=False) + for lang in langs + ] ) fields[feature_name + "-texts"] = ListField(texts) @@ -230,7 +244,10 @@ def _map_Translation( def _map_TranslationVariableLanguages( - feature_name: str, value: TranslationVariableLanguages, feature_type, tokenizer: Optional[Tokenizer] + feature_name: str, + value: TranslationVariableLanguages, + feature_type, + tokenizer: Optional[Tokenizer], ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() if feature_type.dtype == "dict": diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index b261188d4df..3ff5eef1389 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -8,7 +8,7 @@ # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset # TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it # the way tested for conll2003 specific reader -from datasets import list_datasets, load_dataset +from datasets import list_datasets class HuggingfaceDatasetReaderTest: @@ -166,5 +166,3 @@ def test_load_all(self): reader.read() except Exception as e: print(e) - - From a0d14085e423d43fe277948bde4d8068ef5835db Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 20 May 2021 20:19:00 +0530 Subject: [PATCH 11/63] Comments addressed --- .../huggingface_datasets_reader.py | 30 ++++++++--------- .../huggingface_datasets_reader_test.py | 33 +++++++++++-------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 20e7b41b946..930678d4b57 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -32,8 +32,6 @@ class HuggingfaceDatasetReader(DatasetReader): This is useful since text in allennlp is dealt with as a series of tokens. """ - SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION] - def __init__( self, dataset_name: str = None, @@ -55,17 +53,13 @@ def __init__( self.config_name = config_name self.tokenizer = tokenizer + self.features = None + def load_dataset_split(self, split: str): - # TODO add support for datasets.split.NamedSplit - if split in self.SUPPORTED_SPLITS: - if self.config_name is not None: - self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split) - else: - self.dataset[split] = load_dataset(self.dataset_name, split=split) + if self.config_name is not None: + self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split) else: - raise ValueError( - f"Only default splits:{self.SUPPORTED_SPLITS} are currently supported." - ) + self.dataset[split] = load_dataset(self.dataset_name, split=split) def _read(self, file_path: str) -> Iterable[Instance]: """ @@ -77,6 +71,8 @@ def _read(self, file_path: str) -> Iterable[Instance]: # If split is not loaded, load the specific split if file_path not in self.dataset: self.load_dataset_split(file_path) + if self.features is None: + self.features = self.dataset[file_path].features # TODO see if use of Dataset.select() is better dataset_split = self.dataset[file_path] @@ -86,7 +82,7 @@ def _read(self, file_path: str) -> Iterable[Instance]: def raise_feature_not_supported_value_error(value): raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") - def text_to_instance(self, *inputs) -> Instance: + def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ Takes care of converting dataset entry into AllenNLP friendly instance @@ -106,7 +102,6 @@ def text_to_instance(self, *inputs) -> Instance: # e.g. In a Sentiment dataset an entry could have one feature (of type text/string) indicating the text # and another indicate the sentiment (of type int32/ClassLabel) - split = inputs[0] features: Dict[str, FeatureType] = self.dataset[split].features fields: Dict[str, Field] = dict() @@ -117,7 +112,7 @@ def text_to_instance(self, *inputs) -> Instance: field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, inputs[1], feature_type, self.tokenizer) + fields_to_be_added = _map_Feature(feature_name, entry, feature_type, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -178,9 +173,10 @@ def _map_Value( def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] -) -> Field: +) -> Union[ListField]: field_list: List[Field] = list() - field: ListField = None + field: ListField + item_field: Field if isinstance(item_feature_type, Value): for item in value: # If tokenizer is provided we will use it to split it to tokens @@ -201,7 +197,7 @@ def _map_Sequence( elif isinstance(item_feature_type, Sequence): for item in value: - item_field = _map_Sequence(value.feature, item, tokenizer) + item_field = _map_Sequence(value.feature, item, item_feature_type.feature, tokenizer) field_list.append(item_field) if len(field_list) > 0: diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 3ff5eef1389..138b65be50d 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -1,16 +1,14 @@ import pytest +from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.util import ensure_list from allennlp.data import Tokenizer +from allennlp.data.dataset_readers import Conll2003DatasetReader from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader from allennlp.data.tokenizers import WhitespaceTokenizer # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset -# TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it -# the way tested for conll2003 specific reader -from datasets import list_datasets - - class HuggingfaceDatasetReaderTest: """ @@ -158,11 +156,20 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_load_all(self): - for dataset_name in list_datasets(): - try: - print("Dataset:", dataset_name) - reader = HuggingfaceDatasetReader(dataset_name) - reader.read() - except Exception as e: - print(e) + def test_read_from_file_with_deprecated_parameter(self): + conll_reader = HuggingfaceDatasetReader("conll2003") + instances = ensure_list( + conll_reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" / "conll2003.txt") + ) + + expected_labels = ["I-ORG", "O", "I-PER", "O", "O", "I-LOC", "O"] + + fields = instances[0].fields + tokens = [t.text for t in fields["tokens"].tokens] + assert tokens == ["U.N.", "official", "Ekeus", "heads", "for", "Baghdad", "."] + assert fields["tags"].labels == expected_labels + + fields = instances[1].fields + tokens = [t.text for t in fields["tokens"].tokens] + assert tokens == ["AI2", "engineer", "Joel", "lives", "in", "Seattle", "."] + assert fields["tags"].labels == expected_labels From 57b6f9eb1328e61843acd7f1f594d7c48870d292 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 23 May 2021 21:00:24 +0530 Subject: [PATCH 12/63] Formatting --- allennlp/data/dataset_readers/huggingface_datasets_reader.py | 2 +- tests/data/dataset_readers/huggingface_datasets_reader_test.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 930678d4b57..ed8feba6a85 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,7 +1,7 @@ from allennlp.data import DatasetReader, Token, Field, Tokenizer from allennlp.data.fields import TextField, LabelField, ListField from allennlp.data.instance import Instance -from datasets import load_dataset, DatasetDict, Split, list_datasets +from datasets import load_dataset, DatasetDict, list_datasets from datasets.features import ( ClassLabel, Sequence, diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 138b65be50d..d679e2d4ff8 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -2,7 +2,6 @@ from allennlp.common.testing import AllenNlpTestCase from allennlp.common.util import ensure_list from allennlp.data import Tokenizer -from allennlp.data.dataset_readers import Conll2003DatasetReader from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader from allennlp.data.tokenizers import WhitespaceTokenizer From e841b6e94dbcd7805863740bf2e0e7c567d0bb29 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 23 May 2021 21:17:52 +0530 Subject: [PATCH 13/63] removed invalid conll test --- .../huggingface_datasets_reader_test.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index d679e2d4ff8..27408444e13 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -155,20 +155,3 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_read_from_file_with_deprecated_parameter(self): - conll_reader = HuggingfaceDatasetReader("conll2003") - instances = ensure_list( - conll_reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" / "conll2003.txt") - ) - - expected_labels = ["I-ORG", "O", "I-PER", "O", "O", "I-LOC", "O"] - - fields = instances[0].fields - tokens = [t.text for t in fields["tokens"].tokens] - assert tokens == ["U.N.", "official", "Ekeus", "heads", "for", "Baghdad", "."] - assert fields["tags"].labels == expected_labels - - fields = instances[1].fields - tokens = [t.text for t in fields["tokens"].tokens] - assert tokens == ["AI2", "engineer", "Joel", "lives", "in", "Seattle", "."] - assert fields["tags"].labels == expected_labels From 2497b2403d62575484bb2c4a0bbccd681786b642 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 25 May 2021 22:18:52 +0530 Subject: [PATCH 14/63] Regression Fix --- .../dataset_readers/huggingface_datasets_reader.py | 7 ++++--- .../huggingface_datasets_reader_test.py | 11 ++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index ed8feba6a85..98cadd2815b 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -134,7 +134,7 @@ def _map_Feature( elif isinstance(feature_type, Sequence): fields_to_be_added[feature_name] = _map_Sequence( - feature_name, entry, feature_type.feature, tokenizer + feature_name, entry[feature_name], feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): @@ -177,11 +177,12 @@ def _map_Sequence( field_list: List[Field] = list() field: ListField item_field: Field + # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): for item in value: # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - item_field = _map_Value(value.feature, item, item.value, tokenizer) + item_field = _map_Value(feature_name, item, item_feature_type, tokenizer) field_list.append(item_field) if len(field_list) > 0: field = ListField(field_list) @@ -189,7 +190,7 @@ def _map_Sequence( # datasets Sequence of strings to ListField of LabelField elif isinstance(item_feature_type, ClassLabel): for item in value: - item_field = _map_to_Label(value.feature, item, skip_indexing=True) + item_field = _map_to_Label(feature_name, item, skip_indexing=True) field_list.append(item_field) if len(field_list) > 0: diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 27408444e13..33cfe4ef8c5 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -1,6 +1,4 @@ import pytest -from allennlp.common.testing import AllenNlpTestCase -from allennlp.common.util import ensure_list from allennlp.data import Tokenizer from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader @@ -103,8 +101,7 @@ def test_read_with_invalid_split(self, split): Test to help validate for the known supported datasets Skipped by default, enable when required """ - - @pytest.mark.skip() + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset, config, split", ( @@ -138,7 +135,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) Skipped by default, enable when required """ - @pytest.mark.skip() + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset", (("swahili"), ("conll2003"), ("dbpedia_14"), ("trec"), ("emotion")) ) @@ -155,3 +152,7 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) + # def test_air_dialogue(self): + # reader = HuggingfaceDatasetReader(dataset_name="amazon_us_reviews", config_name="Apparel_v1_00") + # instances = list(reader.read("train")) + # print(instances[0]) From 74931dc260ae4fa33ac558ca7d08912931b1376e Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 24 Jun 2021 22:35:45 +0530 Subject: [PATCH 15/63] Add float mapping to TensorField --- .../huggingface_datasets_reader.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 98cadd2815b..a152e078fb3 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,5 +1,5 @@ from allennlp.data import DatasetReader, Token, Field, Tokenizer -from allennlp.data.fields import TextField, LabelField, ListField +from allennlp.data.fields import TextField, LabelField, ListField, TensorField from allennlp.data.instance import Instance from datasets import load_dataset, DatasetDict, list_datasets from datasets.features import ( @@ -10,6 +10,8 @@ Value, FeatureType, ) + +import torch from typing import Iterable, Optional, Dict, List, Union @@ -79,8 +81,8 @@ def _read(self, file_path: str) -> Iterable[Instance]: for index in self.shard_iterable(range(len(dataset_split))): yield self.text_to_instance(file_path, dataset_split[index]) - def raise_feature_not_supported_value_error(value): - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + def raise_feature_not_supported_value_error(feature_name, feature_type): + raise ValueError(f"Datasets feature {feature_name} type {feature_type} is not supported yet.") def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ @@ -166,8 +168,11 @@ def _map_Value( # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token field = _map_String(value, tokenizer) + + elif feature_type.dtype == "float32" or feature_type.dtype == "float64": + field = _map_Float(value) else: - field = LabelField(value, label_namespace=feature_name, skip_indexing=True) + field = LabelField(value, label_namespace=feature_name, skip_indexing=False) return field @@ -188,6 +193,15 @@ def _map_Sequence( field = ListField(field_list) # datasets Sequence of strings to ListField of LabelField + elif isinstance(item_feature_type, str): + for item in value: + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + item_field = _map_Value(feature_name, item, item_feature_type, tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + field = ListField(field_list) + elif isinstance(item_feature_type, ClassLabel): for item in value: item_field = _map_to_Label(feature_name, item, skip_indexing=True) @@ -203,9 +217,9 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - + # Add support for Dict else: - HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name) + HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) return field @@ -280,7 +294,12 @@ def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: field = TextField([Token(text)]) return field +def _map_Float(value: float) -> TensorField: + return TensorField(torch.tensor(value)) + # value mapper - Maps a single value to a LabelField def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) + + From 10dd3e628039adf2e8dace18fd7edb5df9a25358 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 29 Jun 2021 12:03:04 +0530 Subject: [PATCH 16/63] Verification tests --- .../huggingface_datasets_reader_test.py | 72 +++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 33cfe4ef8c5..f3554c8233e 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -137,7 +137,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( - "dataset", (("swahili"), ("conll2003"), ("dbpedia_14"), ("trec"), ("emotion")) + "dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion")) ) def test_read_known_supported_datasets_without_config(self, dataset): split = "train" @@ -152,7 +152,71 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - # def test_air_dialogue(self): - # reader = HuggingfaceDatasetReader(dataset_name="amazon_us_reviews", config_name="Apparel_v1_00") - # instances = list(reader.read("train")) + + # def test_conll2003(self): + # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) # print(instances[0]) + + + @pytest.mark.skip("Requires implementation of Dict") + def test_squad(self): + instances = list(HuggingfaceDatasetReader("squad").read("train")) + print(instances[0]) + + @pytest.mark.parametrize("config", (("default"), ("ptb"))) + def test_sst(self, config): + instances = list(HuggingfaceDatasetReader("sst", config).read("test")) + print(instances[0]) + + def test_open_web_text(self): + instances = list(HuggingfaceDatasetReader("openwebtext").read("plain_text")) + print(instances[0]) + + @pytest.mark.skip("Requires mapping of dict type") + def test_mocha(self): + instances = list(HuggingfaceDatasetReader("mocha").read("test")) + print(instances[0]) + + @pytest.mark.skip("Requires implementation of Dict") + def test_commonsense_qa(self): + instances = list(HuggingfaceDatasetReader("commonsense_qa").read("test")) + print(instances[0]) + + def test_piqa(self): + instances = list(HuggingfaceDatasetReader("piqa").read("test")) + print(instances[0]) + + def test_swag(self): + instances = list(HuggingfaceDatasetReader("swag").read("test")) + print(instances[0]) + + def test_snli(self): + instances = list(HuggingfaceDatasetReader("snli").read("test")) + print(instances[0]) + + def test_multi_nli(self): + instances = list(HuggingfaceDatasetReader("multi_nli").read("test")) + print(instances[0]) + + def test_super_glue(self): + instances = list(HuggingfaceDatasetReader("super_glue").read("test")) + print(instances[0]) + + @pytest.mark.parametrize("config", (("cola"), ("mnli"), ("ax"), ("mnli_matched"), ("mnli_mismatched"), ("mrpc"), ("qnli"),\ + ("qqp"), ("rte"), ("sst2"), ("stsb"), ("wnli"))) + def test_glue(self, config): + instances = list(HuggingfaceDatasetReader("glue", config).read("test")) + print(instances[0]) + + def test_drop(self): + instances = list(HuggingfaceDatasetReader("drop").read("test")) + print(instances[0]) + + + + + + + + + From f3e54dd1af3432c76b64d95e80658ba30a42cc3d Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 1 Jul 2021 00:55:35 +0530 Subject: [PATCH 17/63] Attempt to Support Dict --- .../huggingface_datasets_reader.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index a152e078fb3..dfec615bdb6 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -114,7 +114,7 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, entry, feature_type, self.tokenizer) + fields_to_be_added = _map_Feature(feature_name, entry[feature_name], feature_type, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -123,32 +123,34 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore # Feature Mappers - These functions map a FeatureType into Fields def _map_Feature( - feature_name: str, entry: Dict, feature_type, tokenizer: Optional[Tokenizer] + feature_name: str, value, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields_to_be_added: Dict[str, Field] = dict() if isinstance(feature_type, ClassLabel): - fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, entry[feature_name]) + fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value) # datasets Value can be of different types elif isinstance(feature_type, Value): fields_to_be_added[feature_name] = _map_Value( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) elif isinstance(feature_type, Sequence): fields_to_be_added[feature_name] = _map_Sequence( - feature_name, entry[feature_name], feature_type.feature, tokenizer + feature_name, value, feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): fields_to_be_added = _map_Translation( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) elif isinstance(feature_type, TranslationVariableLanguages): fields_to_be_added = _map_TranslationVariableLanguages( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) + elif isinstance(feature_type, dict): + fields_to_be_added = _map_Dict(feature_type, value, tokenizer) else: raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") return fields_to_be_added @@ -172,7 +174,7 @@ def _map_Value( elif feature_type.dtype == "float32" or feature_type.dtype == "float64": field = _map_Float(value) else: - field = LabelField(value, label_namespace=feature_name, skip_indexing=False) + field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field @@ -180,7 +182,7 @@ def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] ) -> Union[ListField]: field_list: List[Field] = list() - field: ListField + field: ListField = None item_field: Field # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): @@ -217,7 +219,15 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - # Add support for Dict + + # WIP for drop + elif isinstance(item_feature_type, dict): + for item in value: + item_field = _map_Dict(item_feature_type, value[item], tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + field = ListField(field_list) + else: HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) @@ -302,4 +312,11 @@ def _map_Float(value: float) -> TensorField: def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) +def _map_Dict(feature_definition: dict, values: dict, tokenizer: Tokenizer) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() + for key in values: + fields[key] = _map_Feature(key, values[key], feature_definition[key], tokenizer) + return fields + + From d0f31c1fb23d5320293a563081b4304f44f33472 Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhijnvb@gmail.com> Date: Wed, 4 Aug 2021 22:33:52 +0530 Subject: [PATCH 18/63] Quick changes --- .../data/dataset_readers/huggingface_datasets_reader.py | 9 ++++++--- .../dataset_readers/huggingface_datasets_reader_test.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index dfec615bdb6..67c6217c674 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -135,9 +135,12 @@ def _map_Feature( ) elif isinstance(feature_type, Sequence): - fields_to_be_added[feature_name] = _map_Sequence( - feature_name, value, feature_type.feature, tokenizer - ) + if type(value) == dict: + fields_to_be_added = _map_Dict(feature_type, value, tokenizer) + else: + fields_to_be_added[feature_name] = _map_Sequence( + feature_name, value, feature_type.feature, tokenizer + ) elif isinstance(feature_type, Translation): fields_to_be_added = _map_Translation( diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index f3554c8233e..e86add641c0 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -158,7 +158,7 @@ def test_read_known_supported_datasets_without_config(self, dataset): # print(instances[0]) - @pytest.mark.skip("Requires implementation of Dict") + # @pytest.mark.skip("Requires implementation of Dict") def test_squad(self): instances = list(HuggingfaceDatasetReader("squad").read("train")) print(instances[0]) @@ -172,9 +172,9 @@ def test_open_web_text(self): instances = list(HuggingfaceDatasetReader("openwebtext").read("plain_text")) print(instances[0]) - @pytest.mark.skip("Requires mapping of dict type") + # @pytest.mark.skip("Requires mapping of dict type") def test_mocha(self): - instances = list(HuggingfaceDatasetReader("mocha").read("test")) + reader = HuggingfaceDatasetReader("mocha").read("test") print(instances[0]) @pytest.mark.skip("Requires implementation of Dict") From b2775341e127b1fbbbb1f43d84808a5645b02d1e Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhishek.purushothama@colorado.edu> Date: Wed, 11 Aug 2021 09:02:50 -0700 Subject: [PATCH 19/63] Dictionary works with SQUAD --- .../huggingface_datasets_reader.py | 47 ++++++++++--------- .../huggingface_datasets_reader_test.py | 41 ++++++++-------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 67c6217c674..6b00ae107cc 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -82,7 +82,9 @@ def _read(self, file_path: str) -> Iterable[Instance]: yield self.text_to_instance(file_path, dataset_split[index]) def raise_feature_not_supported_value_error(feature_name, feature_type): - raise ValueError(f"Datasets feature {feature_name} type {feature_type} is not supported yet.") + raise ValueError( + f"Datasets feature {feature_name} type {feature_type} is not supported yet." + ) def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ @@ -114,7 +116,9 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, entry[feature_name], feature_type, self.tokenizer) + fields_to_be_added = _map_Feature( + feature_name, entry[feature_name], feature_type, self.tokenizer + ) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -130,22 +134,18 @@ def _map_Feature( fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value) # datasets Value can be of different types elif isinstance(feature_type, Value): - fields_to_be_added[feature_name] = _map_Value( - feature_name, value, feature_type, tokenizer - ) + fields_to_be_added[feature_name] = _map_Value(feature_name, value, feature_type, tokenizer) elif isinstance(feature_type, Sequence): - if type(value) == dict: - fields_to_be_added = _map_Dict(feature_type, value, tokenizer) + if type(feature_type.feature) == dict: + fields_to_be_added[feature_name] = _map_Dict(feature_type.feature, value, tokenizer) else: fields_to_be_added[feature_name] = _map_Sequence( feature_name, value, feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): - fields_to_be_added = _map_Translation( - feature_name, value, feature_type, tokenizer - ) + fields_to_be_added = _map_Translation(feature_name, value, feature_type, tokenizer) elif isinstance(feature_type, TranslationVariableLanguages): fields_to_be_added = _map_TranslationVariableLanguages( @@ -166,8 +166,8 @@ def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field: def _map_Value( feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer] -) -> Union[TextField, LabelField]: - field: Union[TextField, LabelField] +) -> Union[TextField, LabelField, TensorField]: + field: Union[TextField, LabelField, TensorField] if feature_type.dtype == "string": # datasets.Value[string] maps to TextField # If tokenizer is provided we will use it to split it to tokens @@ -176,6 +176,7 @@ def _map_Value( elif feature_type.dtype == "float32" or feature_type.dtype == "float64": field = _map_Float(value) + else: field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field @@ -183,9 +184,9 @@ def _map_Value( def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] -) -> Union[ListField]: +) -> ListField: field_list: List[Field] = list() - field: ListField = None + field: ListField item_field: Field # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): @@ -223,7 +224,7 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - # WIP for drop + # WIP for dropx` elif isinstance(item_feature_type, dict): for item in value: item_field = _map_Dict(item_feature_type, value[item], tokenizer) @@ -232,7 +233,9 @@ def _map_Sequence( field = ListField(field_list) else: - HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) + HuggingfaceDatasetReader.raise_feature_not_supported_value_error( + feature_name, item_feature_type + ) return field @@ -307,6 +310,7 @@ def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: field = TextField([Token(text)]) return field + def _map_Float(value: float) -> TensorField: return TensorField(torch.tensor(value)) @@ -315,11 +319,12 @@ def _map_Float(value: float) -> TensorField: def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) -def _map_Dict(feature_definition: dict, values: dict, tokenizer: Tokenizer) -> Dict[str, Field]: + +def _map_Dict( + feature_definition: dict, values: dict, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + # Map it as a Dictionary of List fields: Dict[str, Field] = dict() for key in values: - fields[key] = _map_Feature(key, values[key], feature_definition[key], tokenizer) + fields[key] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) return fields - - - diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index e86add641c0..1c329057cc8 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -7,7 +7,6 @@ # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset class HuggingfaceDatasetReaderTest: - """ Test read for some lightweight datasets """ @@ -101,6 +100,7 @@ def test_read_with_invalid_split(self, split): Test to help validate for the known supported datasets Skipped by default, enable when required """ + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset, config, split", @@ -136,9 +136,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) """ # TODO pab-vmware skip these once MR is ready to check-in - @pytest.mark.parametrize( - "dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion")) - ) + @pytest.mark.parametrize("dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion"))) def test_read_known_supported_datasets_without_config(self, dataset): split = "train" huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) @@ -152,15 +150,14 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - # def test_conll2003(self): # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) # print(instances[0]) - # @pytest.mark.skip("Requires implementation of Dict") def test_squad(self): - instances = list(HuggingfaceDatasetReader("squad").read("train")) + tokenizer: Tokenizer = WhitespaceTokenizer() + instances = list(HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train")) print(instances[0]) @pytest.mark.parametrize("config", (("default"), ("ptb"))) @@ -174,7 +171,7 @@ def test_open_web_text(self): # @pytest.mark.skip("Requires mapping of dict type") def test_mocha(self): - reader = HuggingfaceDatasetReader("mocha").read("test") + instances = list(HuggingfaceDatasetReader("mocha").read("test")) print(instances[0]) @pytest.mark.skip("Requires implementation of Dict") @@ -202,8 +199,23 @@ def test_super_glue(self): instances = list(HuggingfaceDatasetReader("super_glue").read("test")) print(instances[0]) - @pytest.mark.parametrize("config", (("cola"), ("mnli"), ("ax"), ("mnli_matched"), ("mnli_mismatched"), ("mrpc"), ("qnli"),\ - ("qqp"), ("rte"), ("sst2"), ("stsb"), ("wnli"))) + @pytest.mark.parametrize( + "config", + ( + ("cola"), + ("mnli"), + ("ax"), + ("mnli_matched"), + ("mnli_mismatched"), + ("mrpc"), + ("qnli"), + ("qqp"), + ("rte"), + ("sst2"), + ("stsb"), + ("wnli"), + ), + ) def test_glue(self, config): instances = list(HuggingfaceDatasetReader("glue", config).read("test")) print(instances[0]) @@ -211,12 +223,3 @@ def test_glue(self, config): def test_drop(self): instances = list(HuggingfaceDatasetReader("drop").read("test")) print(instances[0]) - - - - - - - - - From a1d9bca27564e0896fc8b0ba78e993421baf887e Mon Sep 17 00:00:00 2001 From: ArjunSubramonian <arjun.subramonian@gmail.com> Date: Tue, 11 May 2021 10:44:41 -0700 Subject: [PATCH 20/63] Bias Mitigation and Direction Methods (#5130) * added linear and hard debiasers * worked on documentation * committing changes before branch switch * committing changes before switching branch * finished bias direction, linear and hard debiasers, need to write tests * finished bias direction test * Commiting changes before switching branch * finished hard and linear debiasers * finished OSCaR * bias mitigators tests and bias metrics remaining * added bias mitigator tests * added bias mitigator tests * finished tests for bias mitigation methods * fixed gpu issues * fixed gpu issues * fixed gpu issues * resolve issue with count_nonzero not being differentiable * added more references * responded to Akshita's comments Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-108.us-west-2.compute.internal> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-1-108.us-west-2.compute.internal> Co-authored-by: Michael Schmitz <MichaelS@allenai.org> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> --- CHANGELOG.md | 1 + allennlp/fairness/__init__.py | 12 + allennlp/fairness/bias_direction.py | 301 +++ allennlp/fairness/bias_mitigators.py | 563 ++++ test_fixtures/fairness/bias_embeddings.json | 2602 +++++++++++++++++++ tests/fairness/bias_direction_test.py | 149 ++ tests/fairness/bias_mitigators_test.py | 320 +++ 7 files changed, 3948 insertions(+) create mode 100644 allennlp/fairness/bias_direction.py create mode 100644 allennlp/fairness/bias_mitigators.py create mode 100644 test_fixtures/fairness/bias_embeddings.json create mode 100644 tests/fairness/bias_direction_test.py create mode 100644 tests/fairness/bias_mitigators_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d394cb03dae..a1d9b198274 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars. - The test for distributed metrics now takes a parameter specifying how often you want to run it. - Created the fairness module and added four fairness metrics: `Independence`, `Separation`, `Sufficiency`, and `DemographicParityWithoutGroundTruth`. +- Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`). ### Changed diff --git a/allennlp/fairness/__init__.py b/allennlp/fairness/__init__.py index c3e75844e6b..5c254de681b 100644 --- a/allennlp/fairness/__init__.py +++ b/allennlp/fairness/__init__.py @@ -12,3 +12,15 @@ Sufficiency, DemographicParityWithoutGroundTruth, ) +from allennlp.fairness.bias_direction import ( + PCABiasDirection, + PairedPCABiasDirection, + ClassificationNormalBiasDirection, + TwoMeansBiasDirection, +) +from allennlp.fairness.bias_mitigators import ( + LinearBiasMitigator, + HardBiasMitigator, + INLPBiasMitigator, + OSCaRBiasMitigator, +) diff --git a/allennlp/fairness/bias_direction.py b/allennlp/fairness/bias_direction.py new file mode 100644 index 00000000000..03d61a2293b --- /dev/null +++ b/allennlp/fairness/bias_direction.py @@ -0,0 +1,301 @@ +""" +A suite of differentiable methods to compute the bias direction +or concept subspace representing binary protected variables. +""" + +import torch +import sklearn +import numpy as np + +from allennlp.common.checks import ConfigurationError + + +class BiasDirection: + """ + Parent class for bias direction classes. + + # Parameters + + requires_grad : `bool`, optional (default=`False`) + Option to enable gradient calculation. + """ + + def __init__(self, requires_grad: bool = False): + self.requires_grad = requires_grad + + def _normalize_bias_direction(self, bias_direction: torch.Tensor): + return bias_direction / torch.linalg.norm(bias_direction) + + +class PCABiasDirection(BiasDirection): + """ + PCA-based bias direction. Computes one-dimensional subspace that is the span + of a specific concept (e.g. gender) using PCA. This subspace minimizes the sum of + squared distances from all seed word embeddings. + + !!! Note + It is uncommon to utilize more than one direction to represent a concept. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__(self, seed_embeddings: torch.Tensor): + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + seed_embeddings : `torch.Tensor` + A tensor of size (batch_size, ..., dim) containing seed word embeddings related to + a concept. For example, if the concept is gender, seed_embeddings could contain embeddings + for words like "man", "king", "brother", "woman", "queen", "sister", etc. + + # Returns + + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. + """ + + # Some sanity checks + if seed_embeddings.ndim < 2: + raise ConfigurationError("seed_embeddings1 must have at least two dimensions.") + + with torch.set_grad_enabled(self.requires_grad): + # pca_lowrank centers the embeddings by default + # There will be two dimensions when applying PCA to + # definitionally-gendered words: 1) the gender direction, + # 2) all other directions, with the gender direction being principal. + _, _, V = torch.pca_lowrank(seed_embeddings, q=2) + # get top principal component + bias_direction = V[:, 0] + return self._normalize_bias_direction(bias_direction) + + +class PairedPCABiasDirection(BiasDirection): + """ + Paired-PCA-based bias direction. Computes one-dimensional subspace that is the span + of a specific concept (e.g. gender) as the first principle component of the + difference vectors between seed word embedding pairs. + + !!! Note + It is uncommon to utilize more than one direction to represent a concept. + + Based on: T. Bolukbasi, K. W. Chang, J. Zou, V. Saligrama, and A. Kalai. [Man is to + computer programmer as woman is to homemaker? debiasing word embeddings] + (https://api.semanticscholar.org/CorpusID:1704893). + In ACM Transactions of Information Systems, 2016. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor): + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + seed_embeddings1 : `torch.Tensor` + A tensor of size (batch_size, ..., dim) containing seed word + embeddings related to a concept group. For example, if the concept is gender, + seed_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. + + seed_embeddings2: `torch.Tensor` + A tensor of the same size as seed_embeddings1 containing seed word + embeddings related to a different group for the same concept. For example, + seed_embeddings2 could contain embeddings for linguistically feminine words, e.g. + "woman", "queen", "sister", etc. + + !!! Note + For Paired-PCA, the embeddings at the same positions in each of seed_embeddings1 and + seed_embeddings2 are expected to form seed word pairs. For example, if the concept + is gender, the embeddings for ("man", "woman"), ("king", "queen"), ("brother", "sister"), etc. + should be at the same positions in seed_embeddings1 and seed_embeddings2. + + !!! Note + All tensors are expected to be on the same device. + + # Returns + + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. + """ + + # Some sanity checks + if seed_embeddings1.size() != seed_embeddings2.size(): + raise ConfigurationError("seed_embeddings1 and seed_embeddings2 must be the same size.") + if seed_embeddings1.ndim < 2: + raise ConfigurationError( + "seed_embeddings1 and seed_embeddings2 must have at least two dimensions." + ) + + with torch.set_grad_enabled(self.requires_grad): + paired_embeddings = seed_embeddings1 - seed_embeddings2 + _, _, V = torch.pca_lowrank( + paired_embeddings, + q=min(paired_embeddings.size(0), paired_embeddings.size(1)) - 1, + ) + bias_direction = V[:, 0] + return self._normalize_bias_direction(bias_direction) + + +class TwoMeansBiasDirection(BiasDirection): + """ + Two-means bias direction. Computes one-dimensional subspace that is the span + of a specific concept (e.g. gender) as the normalized difference vector of the + averages of seed word embedding sets. + + !!! Note + It is uncommon to utilize more than one direction to represent a concept. + + Based on: Dev, S., & Phillips, J.M. (2019). [Attenuating Bias in Word Vectors] + (https://api.semanticscholar.org/CorpusID:59158788). AISTATS. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor): + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + seed_embeddings1 : `torch.Tensor` + A tensor of size (embeddings1_batch_size, ..., dim) containing seed word + embeddings related to a specific concept group. For example, if the concept is gender, + seed_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. + seed_embeddings2: `torch.Tensor` + A tensor of size (embeddings2_batch_size, ..., dim) containing seed word + embeddings related to a different group for the same concept. For example, + seed_embeddings2 could contain embeddings for linguistically feminine words, , e.g. + "woman", "queen", "sister", etc. + + !!! Note + seed_embeddings1 and seed_embeddings2 need NOT be the same size. Furthermore, + the embeddings at the same positions in each of seed_embeddings1 and seed_embeddings2 + are NOT expected to form seed word pairs. + + !!! Note + All tensors are expected to be on the same device. + + # Returns + + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. + """ + # Some sanity checks + if seed_embeddings1.ndim < 2 or seed_embeddings2.ndim < 2: + raise ConfigurationError( + "seed_embeddings1 and seed_embeddings2 must have at least two dimensions." + ) + if seed_embeddings1.size(-1) != seed_embeddings2.size(-1): + raise ConfigurationError("All seed embeddings must have same dimensionality.") + + with torch.set_grad_enabled(self.requires_grad): + seed_embeddings1_mean = torch.mean(seed_embeddings1, dim=0) + seed_embeddings2_mean = torch.mean(seed_embeddings2, dim=0) + bias_direction = seed_embeddings1_mean - seed_embeddings2_mean + return self._normalize_bias_direction(bias_direction) + + +class ClassificationNormalBiasDirection(BiasDirection): + """ + Classification normal bias direction. Computes one-dimensional subspace that is the span + of a specific concept (e.g. gender) as the direction perpendicular to the classification + boundary of a linear support vector machine fit to classify seed word embedding sets. + + !!! Note + It is uncommon to utilize more than one direction to represent a concept. + + Based on: Ravfogel, S., Elazar, Y., Gonen, H., Twiton, M., & Goldberg, Y. (2020). + [Null It Out: Guarding Protected Attributes by Iterative Nullspace Projection] + (https://api.semanticscholar.org/CorpusID:215786522). ArXiv, abs/2004.07667. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __init__(self): + super().__init__() + + def __call__(self, seed_embeddings1: torch.Tensor, seed_embeddings2: torch.Tensor): + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + seed_embeddings1 : `torch.Tensor` + A tensor of size (embeddings1_batch_size, ..., dim) containing seed word + embeddings related to a specific concept group. For example, if the concept is gender, + seed_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. + seed_embeddings2: `torch.Tensor` + A tensor of size (embeddings2_batch_size, ..., dim) containing seed word + embeddings related to a different group for the same concept. For example, + seed_embeddings2 could contain embeddings for linguistically feminine words, , e.g. + "woman", "queen", "sister", etc. + + !!! Note + seed_embeddings1 and seed_embeddings2 need NOT be the same size. Furthermore, + the embeddings at the same positions in each of seed_embeddings1 and seed_embeddings2 + are NOT expected to form seed word pairs. + + !!! Note + All tensors are expected to be on the same device. + + !!! Note + This bias direction method is NOT differentiable. + + # Returns + + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. + """ + + # Some sanity checks + if seed_embeddings1.ndim < 2 or seed_embeddings2.ndim < 2: + raise ConfigurationError( + "seed_embeddings1 and seed_embeddings2 must have at least two dimensions." + ) + if seed_embeddings1.size(-1) != seed_embeddings2.size(-1): + raise ConfigurationError("All seed embeddings must have same dimensionality.") + + device = seed_embeddings1.device + seed_embeddings1 = seed_embeddings1.flatten(end_dim=-2).detach().cpu().numpy() + seed_embeddings2 = seed_embeddings2.flatten(end_dim=-2).detach().cpu().numpy() + + X = np.vstack([seed_embeddings1, seed_embeddings2]) + Y = np.concatenate([[0] * seed_embeddings1.shape[0], [1] * seed_embeddings2.shape[0]]) + + classifier = sklearn.svm.SVC(kernel="linear").fit(X, Y) + bias_direction = torch.Tensor(classifier.coef_[0]).to(device) + + return self._normalize_bias_direction(bias_direction) diff --git a/allennlp/fairness/bias_mitigators.py b/allennlp/fairness/bias_mitigators.py new file mode 100644 index 00000000000..113a6472b9b --- /dev/null +++ b/allennlp/fairness/bias_mitigators.py @@ -0,0 +1,563 @@ +""" +A suite of differentiable methods to mitigate +biases for binary concepts in embeddings. +""" + +import torch +import numpy as np +import scipy +import sklearn +from allennlp.common.checks import ConfigurationError + + +class BiasMitigator: + """ + Parent class for bias mitigator classes. + + # Parameters + + requires_grad : `bool`, optional (default=`False`) + Option to enable gradient calculation. + """ + + def __init__(self, requires_grad: bool = False): + self.requires_grad = requires_grad + + def _proj(self, u: torch.Tensor, v: torch.Tensor, normalize: bool = False): + proj = torch.matmul(u, v.reshape(-1, 1)) * v + if normalize: + return proj / torch.dot(v, v) + return proj + + def _remove_component( + self, embeddings: torch.Tensor, bias_direction: torch.Tensor, normalize: bool = False + ): + return embeddings - self._proj(embeddings, bias_direction, normalize) + + +class HardBiasMitigator(BiasMitigator): + """ + Hard bias mitigator. Mitigates bias in embeddings by: + + 1. Neutralizing: ensuring protected variable-neutral words remain equidistant + from the bias direction by removing component of embeddings + in the bias direction. + + 2. Equalizing: ensuring that protected variable-related words are averaged + out to have the same norm. + + !!! Note + For a detailed walkthrough and visual descriptions of the steps, please + refer to Figure 4 in [VERB: Visualizing and Interpreting Bias Mitigation Techniques + for Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + + Based on: T. Bolukbasi, K. W. Chang, J. Zou, V. Saligrama, and A. Kalai. [Man is to + computer programmer as woman is to homemaker? debiasing word embeddings] + (https://api.semanticscholar.org/CorpusID:1704893). + In ACM Transactions of Information Systems, 2016. + + Description taken from: Goenka, D. (2020). [Tackling Gender Bias in Word Embeddings] + (https://towardsdatascience.com/tackling-gender-bias-in-word-embeddings-c965f4076a10). + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__( + self, + evaluation_embeddings: torch.Tensor, + bias_direction: torch.Tensor, + equalize_embeddings1: torch.Tensor, + equalize_embeddings2: torch.Tensor, + ): + """ + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + # Parameters + + evaluation_embeddings : `torch.Tensor` + A tensor of size (evaluation_batch_size, ..., dim) of embeddings for which to mitigate bias. + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. The words + that are used to define the bias direction are considered definitionally + gendered and not modified. + equalize_embeddings1: `torch.Tensor` + A tensor of size (equalize_batch_size, ..., dim) containing equalize word + embeddings related to a group from the concept represented by bias_direction. + For example, if the concept is gender, equalize_embeddings1 could contain embeddings + for "boy", "man", "dad", "brother", etc. + equalize_embeddings2: `torch.Tensor` + A tensor of size (equalize_batch_size, ..., dim) containing equalize word + embeddings related to a different group for the same concept. For example, + equalize_embeddings2 could contain embeddings for "girl", "woman", "mom", + "sister", etc. + + !!! Note + The embeddings at the same positions in each of equalize_embeddings1 and + equalize_embeddings2 are expected to form equalize word pairs. For example, if the concept + is gender, the embeddings for ("boy", "girl"), ("man", "woman"), ("dad", "mom"), + ("brother", "sister"), etc. should be at the same positions in equalize_embeddings1 + and equalize_embeddings2. + + !!! Note + evaluation_embeddings, equalize_embeddings1, and equalize_embeddings2 must have same size + except for 0th dim (i.e. batch dimension). + + !!! Note + Please ensure that the words in evaluation_embeddings, equalize_embeddings1, and + equalize_embeddings2 and those used to compute bias_direction are disjoint. + + !!! Note + All tensors are expected to be on the same device. + + # Returns + + bias_mitigated_embeddings : `torch.Tensor` + A tensor of the same size as evaluation_embeddings, equalize_embeddings1, and equalize_embeddings2 + (in this order) stacked. + """ + + # Some sanity checks + if equalize_embeddings1.size() != equalize_embeddings2.size(): + raise ConfigurationError( + "equalize_embeddings1 and equalize_embeddings2 must be the same size." + ) + if equalize_embeddings1.ndim < 2: + raise ConfigurationError( + "equalize_embeddings1 and equalize_embeddings2 must have at least two dimensions." + ) + if evaluation_embeddings.ndim < 2: + raise ConfigurationError("evaluation_embeddings must have at least two dimensions.") + if evaluation_embeddings.size()[1:] != equalize_embeddings1.size()[1:]: + raise ConfigurationError( + "evaluation_embeddings, equalize_embeddings1, and equalize_embeddings2 must have same size \ + except for 0th dim (i.e. batch dimension)." + ) + if bias_direction.ndim != 1: + raise ConfigurationError("bias_direction must be one-dimensional.") + if evaluation_embeddings.size(-1) != bias_direction.size(-1): + raise ConfigurationError( + "All embeddings and bias_direction must have the same dimensionality." + ) + + with torch.set_grad_enabled(self.requires_grad): + bias_direction = bias_direction / torch.linalg.norm(bias_direction) + + bias_mitigated_embeddings = self._remove_component( + evaluation_embeddings, bias_direction, normalize=True + ) + + mean_equalize_embeddings = (equalize_embeddings1 + equalize_embeddings2) / 2 + y = self._remove_component(mean_equalize_embeddings, bias_direction, normalize=True) + z = torch.sqrt(1 - torch.square(torch.linalg.norm(y, dim=-1, keepdim=True))) + z = torch.where( + torch.matmul( + equalize_embeddings1 - equalize_embeddings2, bias_direction.reshape(-1, 1) + ) + < 0, + -z, + z, + ) + return torch.cat( + [bias_mitigated_embeddings, z * bias_direction + y, -z * bias_direction + y] + ) + + +class LinearBiasMitigator(BiasMitigator): + """ + Linear bias mitigator. Mitigates bias in embeddings by removing component + in the bias direction. + + !!! Note + For a detailed walkthrough and visual descriptions of the steps, please + refer to Figure 3 in [VERB: Visualizing and Interpreting Bias Mitigation Techniques + for Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + + Based on: S. Dev and J. M. Phillips. [Attenuating bias in word vectors] + (https://api.semanticscholar.org/CorpusID:59158788). + In International Conference on Artificial Intelligence and Statistics, + Proceedings of Machine Learning Research, pages 879–887. PMLR, 2019. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__(self, evaluation_embeddings: torch.Tensor, bias_direction: torch.Tensor): + """ + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + # Parameters + + evaluation_embeddings : `torch.Tensor` + A tensor of size (batch_size, ..., dim) of embeddings for which to mitigate bias. + bias_direction : `torch.Tensor` + A unit tensor of size (dim, ) representing the concept subspace. + + !!! Note + All tensors are expected to be on the same device. + + # Returns + + bias_mitigated_embeddings : `torch.Tensor` + A tensor of the same size as evaluation_embeddings. + """ + # Some sanity checks + if evaluation_embeddings.ndim < 2: + raise ConfigurationError("evaluation_embeddings must have at least two dimensions.") + if bias_direction.ndim != 1: + raise ConfigurationError("bias_direction must be one-dimensional.") + if evaluation_embeddings.size(-1) != bias_direction.size(-1): + raise ConfigurationError( + "All embeddings and bias_direction must have the same dimensionality." + ) + + with torch.set_grad_enabled(self.requires_grad): + bias_direction = bias_direction / torch.linalg.norm(bias_direction) + return self._remove_component(evaluation_embeddings, bias_direction) + + +class INLPBiasMitigator(BiasMitigator): + """ + Iterative Nullspace Projection. It mitigates bias by repeatedly building + a linear classifier that separates concept groups and linearly + projecting all words along the classifier normal. + + !!! Note + For a detailed walkthrough and visual descriptions of the steps, please + refer to Figure 5 in [VERB: Visualizing and Interpreting Bias Mitigation Techniques + for Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + + Based on: Ravfogel, S., Elazar, Y., Gonen, H., Twiton, M., & Goldberg, Y. (2020). + [Null It Out: Guarding Protected Attributes by Iterative Nullspace Projection] + (https://api.semanticscholar.org/CorpusID:215786522). ArXiv, abs/2004.07667. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __init__(self): + super().__init__() + + def __call__( + self, + evaluation_embeddings: torch.Tensor, + seed_embeddings1: torch.Tensor, + seed_embeddings2: torch.Tensor, + num_iters: int = 35, + ): + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + evaluation_embeddings : `torch.Tensor` + A tensor of size (evaluation_batch_size, ..., dim) of embeddings for which to mitigate bias. + seed_embeddings1 : `torch.Tensor` + A tensor of size (embeddings1_batch_size, ..., dim) containing seed word + embeddings related to a specific concept group. For example, if the concept is gender, + seed_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. + seed_embeddings2: `torch.Tensor` + A tensor of size (embeddings2_batch_size, ..., dim) containing seed word + embeddings related to a different group for the same concept. For example, + seed_embeddings2 could contain embeddings for linguistically feminine words, , e.g. + "woman", "queen", "sister", etc. + num_iters: `torch.Tensor` + Number of times to build classifier and project embeddings along normal. + + !!! Note + seed_embeddings1 and seed_embeddings2 need NOT be the same size. Furthermore, + the embeddings at the same positions in each of seed_embeddings1 and seed_embeddings2 + are NOT expected to form seed word pairs. + + !!! Note + All tensors are expected to be on the same device. + + !!! Note + This bias mitigator is not differentiable. + + # Returns + + bias_mitigated_embeddings : `torch.Tensor` + A tensor of the same size as evaluation_embeddings. + """ + # Some sanity checks + if seed_embeddings1.ndim < 2 or seed_embeddings2.ndim < 2: + raise ConfigurationError( + "seed_embeddings1 and seed_embeddings2 must have at least two dimensions." + ) + if seed_embeddings1.size(-1) != seed_embeddings2.size(-1): + raise ConfigurationError("All seed embeddings must have same dimensionality.") + if evaluation_embeddings.ndim < 2: + raise ConfigurationError("evaluation_embeddings must have at least two dimensions.") + if evaluation_embeddings.size(-1) != seed_embeddings1.size( + -1 + ) or evaluation_embeddings.size(-1) != seed_embeddings2.size(-1): + raise ConfigurationError( + "evaluation_embeddings, seed_embeddings1, and seed_embeddings2 must have the same dimensionality." + ) + + device = seed_embeddings1.device + seed_embeddings1 = seed_embeddings1.flatten(end_dim=-2).detach().cpu().numpy() + seed_embeddings2 = seed_embeddings2.flatten(end_dim=-2).detach().cpu().numpy() + X = np.vstack([seed_embeddings1, seed_embeddings2]) + Y = np.concatenate([[0] * seed_embeddings1.shape[0], [1] * seed_embeddings2.shape[0]]) + + rowspace_projs = [] + for iter_idx in range(num_iters): + classifier = sklearn.svm.SVC(kernel="linear").fit(X, Y) + weights = np.expand_dims(classifier.coef_[0], 0) + + if (np.linalg.norm(weights) < 1e-10 or classifier.score(X, Y) < 0.55) and iter_idx > 1: + break + + rowspace_projs.append(self._get_rowspace_proj(weights)) + # Project embeddings to intersection of nullspaces + nullspace_proj = np.eye(seed_embeddings1.shape[1]) - self._get_rowspace_proj( + np.sum(rowspace_projs, axis=0) + ) + evaluation_embeddings = torch.matmul( + evaluation_embeddings, torch.from_numpy(nullspace_proj).float().t().to(device) + ) + X = nullspace_proj.dot(X.T).T + + return evaluation_embeddings + + def _get_rowspace_proj(self, weights: np.ndarray): + # Compute orthogonal basis + if np.allclose(weights, 0): + weights_basis = np.zeros_like(weights.T) + else: + weights_basis = scipy.linalg.orth(weights.T) + # Get rowspace projection + return weights_basis.dot(weights_basis.T) + + +class OSCaRBiasMitigator(BiasMitigator): + """ + OSCaR bias mitigator. Mitigates bias in embeddings by dissociating concept subspaces + through subspace orthogonalization. Formally, OSCaR applies a graded rotation + on the embedding space to rectify two ideally-independent concept subspaces + so that they become orthogonal. + + !!! Note + For a detailed walkthrough and visual descriptions of the steps, please + refer to Figure 6 in [VERB: Visualizing and Interpreting Bias Mitigation Techniques + for Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + + Based on: Dev, S., Li, T., Phillips, J.M., & Srikumar, V. (2020). [OSCaR: Orthogonal Subspace + Correction and Rectification of Biases in Word Embeddings](https://api.semanticscholar.org/CorpusID:220281039). + ArXiv, abs/2007.00049. + + Implementation and terminology based on Rathore, A., Dev, S., Phillips, J.M., Srikumar, + V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, W., & Wang, B. (2021). + [VERB: Visualizing and Interpreting Bias Mitigation Techniques for + Word Representations](https://api.semanticscholar.org/CorpusID:233168618). + ArXiv, abs/2104.02797. + """ + + def __call__( + self, + evaluation_embeddings: torch.Tensor, + bias_direction1: torch.Tensor, + bias_direction2: torch.Tensor, + ): + """ + + # Parameters + + evaluation_embeddings : `torch.Tensor` + A tensor of size (batch_size, ..., dim) of embeddings for which to mitigate bias. + bias_direction1 : `torch.Tensor` + A unit tensor of size (dim, ) representing a concept subspace (e.g. gender). + bias_direction2 : `torch.Tensor` + A unit tensor of size (dim, ) representing another concept subspace from + which bias_direction1 should be dissociated (e.g. occupation). + + !!! Note + All tensors are expected to be on the same device. + + # Returns + + bias_mitigated_embeddings : `torch.Tensor` + A tensor of the same size as evaluation_embeddings. + """ + # Some sanity checks + if evaluation_embeddings.ndim < 2: + raise ConfigurationError("evaluation_embeddings must have at least two dimensions.") + if bias_direction1.ndim != 1 or bias_direction2.ndim != 1: + raise ConfigurationError("bias_direction1 and bias_direction2 must be one-dimensional.") + if evaluation_embeddings.size(-1) != bias_direction1.size(-1) or evaluation_embeddings.size( + -1 + ) != bias_direction2.size(-1): + raise ConfigurationError( + "All embeddings, bias_direction1, and bias_direction2 must have the same dimensionality." + ) + if bias_direction1.size(-1) < 2: + raise ConfigurationError( + "Dimensionality of all embeddings, bias_direction1, and bias_direction2 must \ + be >= 2." + ) + + with torch.set_grad_enabled(self.requires_grad): + bias_direction1 = bias_direction1 / torch.linalg.norm(bias_direction1) + bias_direction2 = bias_direction2 / torch.linalg.norm(bias_direction2) + + bias_direction2_orth = self._remove_component( + bias_direction2.reshape(1, -1), bias_direction1 + )[0] + bias_direction2_orth = bias_direction2_orth / torch.linalg.norm(bias_direction2_orth) + + # Create rotation matrix as orthonormal basis + # with v1 and v2' + init_orth_matrix = torch.eye( + bias_direction1.size(0), + device=evaluation_embeddings.device, + requires_grad=self.requires_grad, + ) + rotation_matrix = torch.zeros( + (bias_direction1.size(0), bias_direction1.size(0)), + device=evaluation_embeddings.device, + requires_grad=self.requires_grad, + ) + rotation_matrix = torch.cat( + [ + bias_direction1.reshape(1, -1), + bias_direction2_orth.reshape(1, -1), + rotation_matrix[2:], + ] + ) + # Apply Gram-Schmidt + for i in range(len(rotation_matrix) - 2): + subspace_proj = torch.sum( + self._proj( + rotation_matrix[: i + 2].clone(), init_orth_matrix[i], normalize=True + ), + dim=0, + ) + rotation_matrix[i + 2] = (init_orth_matrix[i] - subspace_proj) / torch.linalg.norm( + init_orth_matrix[i] - subspace_proj + ) + + mask = ~(evaluation_embeddings == 0).all(dim=-1) + # Transform all evaluation embeddings + # using orthonormal basis computed above + rotated_evaluation_embeddings = torch.matmul( + evaluation_embeddings[mask], rotation_matrix.t() + ) + # Want to adjust first 2 coordinates and leave d - 2 + # other orthogonal components fixed + fixed_rotated_evaluation_embeddings = rotated_evaluation_embeddings[..., 2:] + # Restrict attention to subspace S spanned by bias directions + # which we hope to make orthogonal + restricted_rotated_evaluation_embeddings = torch.cat( + [ + torch.matmul(rotated_evaluation_embeddings, bias_direction1.reshape(-1, 1)), + torch.matmul( + rotated_evaluation_embeddings, bias_direction2_orth.reshape(-1, 1) + ), + ], + dim=-1, + ) + + # Transform and restrict bias directions + restricted_bias_direction1 = torch.tensor( + [1.0, 0.0], device=evaluation_embeddings.device, requires_grad=self.requires_grad + ) + bias_direction_inner_prod = torch.dot(bias_direction1, bias_direction2) + restricted_bias_direction2 = torch.tensor( + [ + bias_direction_inner_prod, + torch.sqrt(1 - torch.square(bias_direction_inner_prod)), + ], + device=evaluation_embeddings.device, + requires_grad=self.requires_grad, + ) + restricted_bias_direction2_orth = torch.tensor( + [0.0, 1.0], device=evaluation_embeddings.device, requires_grad=self.requires_grad + ) + + restricted_bias_direction_inner_prod = torch.dot( + restricted_bias_direction1, restricted_bias_direction2 + ) + theta = torch.abs(torch.arccos(restricted_bias_direction_inner_prod)) + theta_proj = np.pi / 2 - theta + phi = torch.arccos( + torch.matmul( + restricted_rotated_evaluation_embeddings + / torch.linalg.norm( + restricted_rotated_evaluation_embeddings, dim=-1, keepdim=True + ), + restricted_bias_direction1, + ) + ) + d = torch.matmul( + restricted_rotated_evaluation_embeddings + / torch.linalg.norm(restricted_rotated_evaluation_embeddings, dim=-1, keepdim=True), + restricted_bias_direction2_orth, + ) + + # Add noise to avoid DivideByZero + theta_x = torch.zeros_like(phi, requires_grad=self.requires_grad) + theta_x = torch.where( + (d > 0) & (phi < theta_proj), + theta * (phi / (theta_proj + 1e-10)), + theta_x, + ) + theta_x = torch.where( + (d > 0) & (phi > theta_proj), + theta * ((np.pi - phi) / (np.pi - theta_proj + 1e-10)), + theta_x, + ) + theta_x = torch.where( + (d < 0) & (phi >= np.pi - theta_proj), + theta * ((np.pi - phi) / (theta_proj + 1e-10)), + theta_x, + ) + theta_x = torch.where( + (d < 0) & (phi < np.pi - theta_proj), + theta * (phi / (np.pi - theta_proj + 1e-10)), + theta_x, + ) + + f_matrix = torch.cat( + [ + torch.cos(theta_x).unsqueeze(-1), + -torch.sin(theta_x).unsqueeze(-1), + torch.sin(theta_x).unsqueeze(-1), + torch.cos(theta_x).unsqueeze(-1), + ], + dim=-1, + ) + f_matrix = f_matrix.reshape(f_matrix.size()[:-1] + (2, 2)) + + evaluation_embeddings_clone = evaluation_embeddings.clone() + evaluation_embeddings_clone[mask] = torch.cat( + [ + torch.bmm( + f_matrix, + restricted_rotated_evaluation_embeddings.unsqueeze(-1), + ).squeeze(-1), + fixed_rotated_evaluation_embeddings, + ], + dim=-1, + ) + return torch.matmul(evaluation_embeddings_clone, rotation_matrix) diff --git a/test_fixtures/fairness/bias_embeddings.json b/test_fixtures/fairness/bias_embeddings.json new file mode 100644 index 00000000000..4481e4ccf5b --- /dev/null +++ b/test_fixtures/fairness/bias_embeddings.json @@ -0,0 +1,2602 @@ +{ + "he": [ + -0.0367426791967243, + -0.011021889398097602, + -0.11295283313084178, + -0.15441727211683257, + 0.10571840953427392, + 0.026829178105471942, + -0.15744929292651647, + 0.12261579933009975, + -0.15828684752895025, + -0.03334491401916538, + 0.02899621348513142, + 0.08378106234913035, + -0.18585300053569034, + -0.06560356726574197, + 0.13508585355279848, + -0.04397710279329215, + -0.06198086930193991, + 0.047074957589193556, + -0.142991418064417, + 0.01527494777993225, + 0.03245981264890784, + 0.1678272893649701, + 0.11800924901671392, + -0.03638424971620678, + 0.06842345634205835, + -0.5033556862879934, + -0.016748531838100697, + 0.007378709749694549, + -0.01184865454167908, + -0.05754256139777657, + 0.6207413411574805, + 0.00821845881833559, + -0.10064919259552618, + -0.11947771265373211, + 0.019084541115718435, + 0.0029980065834715216, + 0.044598868218679685, + 0.18442659750097778, + 0.057443810418450314, + -0.061821770501914286, + -0.030951117131423393, + 0.018704166973128416, + -0.1136422612642862, + 0.03626172535296865, + -0.06610280832789137, + -0.045295611239481594, + -0.07130003579539537, + -0.06092752552245986, + -0.0076182723106526865, + -0.0024086095346409304 + ], + "him": [ + 0.021675511432224605, + -0.008226150088433285, + 0.009257929072105294, + -0.15014331819815693, + 0.17694239585658778, + 0.020160906989116968, + -0.09889859729708098, + 0.2094538512771219, + -0.12334424053972613, + 0.010907869582079127, + -0.05211435024640094, + 0.15954256203887754, + -0.16630755365438463, + -0.03320534716899135, + 0.15251849334757575, + -0.021479844829526487, + -0.027098012745886278, + -0.029268462653593154, + -0.06776949019005296, + -0.08214736203275895, + -0.033547763723713056, + 0.18827475326285362, + 0.10085345159625937, + -0.025924013129697582, + 0.1293464947502711, + -0.5327747949577307, + -0.009395439212334803, + -0.025130476352088555, + 0.05824886058654742, + -0.17749678456423243, + 0.5263793958139867, + 0.09679336959027346, + -0.14750725424514066, + -0.08850652662044771, + 0.00202460581958467, + 0.07647484228231635, + 0.07066463121886396, + 0.08196800098028567, + -0.0012803661554887538, + -0.11756120539331513, + 0.016508644911532415, + 0.08495191667143194, + -0.10542806429822921, + 0.08271262110722016, + 0.004953263812746746, + -0.0806925446071424, + -0.07676834218636353, + -0.11376201219092673, + 0.004455401012548207, + -0.025981988419385916 + ], + "his": [ + -0.006007637080540547, + 0.08515521480682708, + -0.12314787212298071, + -0.13016099171337828, + 0.1505232216965325, + 0.11519071319052966, + -0.1360975123577148, + 0.11328830373842183, + -0.09704796090150124, + 0.020836936673181157, + -0.03628192188605667, + 0.11342982007242987, + -0.2270889324328727, + -0.03166024920579469, + 0.06320585226758287, + -0.09869599922159465, + -0.11648227365660288, + -0.06099533130345757, + -0.056710431671214666, + -0.013956018417591108, + -0.019857070411125614, + 0.1742961744638801, + -0.0030173432025710407, + -0.13211535019948906, + 0.08571948879685908, + -0.4848994487674869, + -0.07598352328630718, + -0.009581730619943147, + 0.003308078658387521, + -0.021302686633207557, + 0.5926130837535927, + 0.03200060494581398, + -0.09082482489714848, + -0.041011075326324746, + 0.04331116359045512, + 0.10207089508578596, + 0.01739629841078612, + 0.1709338180216895, + 0.0013630889118572678, + -0.097008551289499, + 0.017605348488997974, + 0.07439997342221741, + -0.1999142136113323, + 0.0009127624408917406, + -0.026825406351520618, + -0.08084881902258298, + -0.014543579905624413, + -0.11137335486431328, + -0.004053457729029773, + -0.07851469518445067 + ], + "man": [ + -0.017997175489727155, + 0.08200416653811962, + -0.03284208999587444, + -0.0868130234220952, + 0.3136053496064486, + 0.07690929516857847, + -0.07105171850419587, + 0.047804461117427324, + -0.02018880915445417, + 0.02055109416950388, + -0.020684567596101135, + 0.02894657270247155, + -0.12469468865363477, + 0.10497494325550812, + 0.1136259280622477, + -0.08824118908668588, + 0.022589424069967754, + 0.12288707710486045, + -0.13528103814603462, + 0.045661259238922745, + -0.15808020617208376, + 0.242540283759593, + 0.0062963228852401885, + 0.05596350100899413, + 0.07457351020312643, + -0.5356860638319186, + -0.13489396520890257, + 0.07829169851547868, + 0.0742493604528188, + -0.05554401309683132, + 0.4981228280609754, + -0.06592824568609817, + -0.03209463880692979, + 0.04796272246610693, + 0.05952152120942967, + 0.060328082058724544, + 0.023908904230043524, + -0.002411292789641362, + 0.04251509989770161, + -0.10789419777151389, + -0.016448502388551514, + 0.11926613371760048, + -0.010982956245717417, + 0.05601117008992173, + 0.12585590746503095, + -0.101277729338764, + -0.09196891121522364, + -0.18671978999338162, + 0.10131586460350607, + -0.022356798955041095 + ], + "boy": [ + -0.06170777316335876, + 0.04451277673357509, + -0.03831242852578669, + -0.1008233852569851, + 0.20844616774860342, + 0.11913253656472213, + -0.19051094513278224, + 0.05358054751253456, + 0.01685082940926519, + 0.07043404784102773, + 0.06142923444387041, + 0.06675390270477424, + 0.012869824299975324, + 0.04618973244885079, + 0.17659545595505652, + -0.06215801383321662, + -0.1891277905325833, + 0.15408723806106034, + -0.04358367840213111, + 0.07645697070010096, + -0.16363195870215738, + 0.2639631317014165, + 0.010767429307672917, + 0.14606303358045786, + 0.06883340410369404, + -0.39476186189402357, + -0.08905417045887844, + 0.02357849338463289, + 0.06701527122922564, + -0.14707607508763806, + 0.42093687032813343, + -0.0812817955054846, + -0.046319462811352206, + 0.07310687486844666, + 0.11578434853251639, + 0.11987657834965673, + 0.0607157174501126, + -0.16885932919118513, + 0.07312404506348362, + -0.21863381680386187, + -0.08003027906723562, + 0.049717253629220255, + -0.1253042677807823, + -0.021960679452262256, + 0.20205503959595997, + -0.11665812067995243, + 0.061339567869788544, + -0.2514861233078977, + 0.060561185694780036, + 0.004821009206486554 + ], + "brother": [ + 0.07614158568847995, + 0.11257270917746082, + 0.03419512022734642, + -0.06537173560639559, + 0.23516860661767125, + 0.12771102867719966, + -0.15030601087936146, + 0.15144491904112006, + -0.16834536043448892, + 0.013538671217996083, + 0.08061558330993374, + 0.26218780898666455, + -0.20687510141034662, + -0.07192244863469241, + 0.1476459037043449, + -0.04145665530764993, + -0.09166219603971817, + 0.006020284471841425, + -0.04583508021622887, + 0.11874511355062804, + -0.0968569712041031, + 0.13916581653488708, + 0.033147802582092885, + 0.04287431721529349, + 0.05013585229559701, + -0.41112195321663497, + -0.04054871453533888, + -0.23277928879580007, + -0.025147570075194206, + 0.03275356514148414, + 0.30348318534133817, + -0.06526620740259628, + -0.08548381837199627, + -0.023532789447246266, + 0.1547202755552685, + 0.052620742830343036, + -8.286950645523221e-05, + 0.08214474671593129, + 0.0695410952055608, + 0.04817063838710795, + -0.018738225018024756, + 0.22369988107268954, + -0.17073666935454498, + -0.09515259115783499, + 0.09095535618408128, + -0.04639856100255349, + -0.1867510720556366, + -0.343106039220702, + 0.018184102393169133, + 0.03886822766728948 + ], + "grandpa": [ + -0.09096832379090469, + 0.08813875886827918, + 0.018170235239713336, + -0.15038661249368115, + 0.07632101274011627, + -0.024546411336967257, + -0.06903468996392866, + 0.06712943241821176, + -0.23694194786112804, + -0.014431553507097784, + -0.10747969762966478, + 0.3514427770139955, + -0.03237135557067374, + -0.09908626573907413, + 0.2226293442170196, + 0.1591353491698244, + -0.011315170083671118, + 0.20191867976060512, + 0.216313672920204, + 0.07068762961845602, + -0.08213462292691188, + 0.005535288104779407, + 0.10886229668648907, + 0.10055382898374794, + 0.2930234611865408, + -0.1662337208637994, + -0.19499796045940643, + 0.003894706877575612, + 0.144488038119036, + -0.1912595361940268, + -0.07436426174721782, + 0.16360498038518193, + -0.03579309513588693, + 0.24197543232312338, + 0.05524731948107191, + 0.2323899271302801, + -0.029582985405793467, + -0.021913551382575225, + 0.06877722272802096, + -0.04248466859712776, + -0.02660923883105966, + 0.12761878482236422, + -0.312951425245796, + 0.015160958186424267, + 0.1991586309916747, + -0.03915304256448227, + 0.04770352946897662, + -0.14478670011268893, + 0.031560333777564516, + 0.007135189508709786 + ], + "uncle": [ + 0.06753174114657315, + 0.11158246719809903, + -0.005007705990043669, + -0.14550167959960217, + 0.20934007328370216, + 0.09552757778382517, + -0.1503056600178724, + 0.16204288166907713, + -0.19282296778723265, + 0.005775291370144851, + -0.028858932070356648, + 0.3474068648123471, + -0.027502952189413062, + -0.08773124111778605, + 0.16330466585554482, + -0.02104572780321371, + -0.07472566348747578, + 0.06545943586810364, + 0.05265977440711275, + 0.09549690941818186, + -0.05006172571761504, + 0.04724899846861413, + -0.029919181282596866, + 0.059560156676857975, + 0.1406276000598615, + -0.3657421662719558, + -0.11381249549987699, + -0.23610260350258386, + 0.09494049764151033, + 0.030079094903451276, + 0.21411557593387504, + -0.05925347302042485, + -0.07108489036610545, + 0.10259006427197075, + 0.12657491680258673, + 0.11202058670728919, + -0.038499751870086384, + 0.046261038975390294, + 0.0855537871571109, + 0.02354454242387986, + -0.011906116721997526, + 0.20939702881989689, + -0.21823389932026266, + -0.02972421810100724, + 0.1019942217394721, + -0.018760934562786963, + -0.14064512484022912, + -0.40902837377994483, + 0.004215147797918647, + 0.0850455685264503 + ], + "jack": [ + -0.19683922292481285, + 0.18515257115010456, + 0.08419660779737294, + 0.10014290118797743, + 0.03904932078085833, + 0.26248059060511736, + -0.26531737343269374, + 0.035068890073188176, + -0.13744547852697378, + -0.14302746300422836, + -0.1208246115820952, + 0.3204894488981554, + -0.12423545203698427, + 0.07168572542158141, + 0.11947099709901547, + 0.018682336922492906, + 0.12668133959304423, + 0.15585999634629627, + -0.13398326326497112, + -0.11454795036517429, + 0.09345747520141359, + 0.12720178872597754, + 0.13478068961886464, + 0.05292721976332445, + 0.07779932746809835, + -0.3099687818761989, + 0.03459088099830523, + -0.01411757362608734, + -0.00010938813278115742, + -0.052942855574185106, + 0.27514559740224964, + -0.1763942833808707, + -0.04929077689458876, + -0.023587737526934213, + 0.07792888132951524, + 0.02828518184692878, + -0.07024946450966676, + -0.08088181589491363, + 0.13686248615059785, + -0.18582714470437858, + 0.04653664049584729, + 0.21118396254583288, + -0.17998605250429026, + -0.1292232757015339, + 0.13301384299161037, + 0.039866850320144115, + -0.04781654329915537, + -0.14113552989008885, + -0.04368198817014446, + 0.1779198917834177 + ], + "she": [ + 0.010696994952413164, + 0.06700192873624893, + -0.13311808067209266, + -0.12783353628087532, + 0.10389807558032646, + 0.14017595021772117, + -0.1292295205183139, + 0.12090499014810599, + -0.023028425256933972, + -0.04072447417542874, + 0.020541164001396217, + 0.03973063260537412, + -0.07915124333060643, + -0.020399439713331385, + 0.18307234910774345, + -0.015593037638972778, + -0.139121875825239, + 0.060773146275799675, + -0.020296689604484386, + -0.02109034561764743, + 0.08128419386598211, + 0.28934784962035737, + 0.12143291312114747, + 0.03951981772687769, + 0.17890919814583905, + -0.4664854941653861, + -0.09234754610304285, + 0.04546692316479809, + 0.004157481990381772, + -0.14635867228454935, + 0.5506697212758948, + 0.0482199174604574, + -0.051777197090884616, + -0.08372185162069717, + -0.021784794629165095, + -0.02399569352297643, + 0.01982722790026964, + 0.1531295501468466, + 0.0586756268124402, + -0.17116042269589452, + -0.007933902801229232, + -0.011823348731808394, + 0.0005379676819580844, + -0.06006452483547553, + 0.0031505309236811584, + -0.10363411409380571, + -0.0008882569754463185, + -0.22268428762186324, + -0.010757404930200797, + 0.07484282497343561 + ], + "her": [ + 0.02232489067350508, + 0.14854055811996092, + -0.1278580118621893, + -0.10690895941119526, + 0.1435868742534382, + 0.21856839171658113, + -0.10663245923572694, + 0.1366960234949296, + 0.05460545332757719, + 0.0035740146174841345, + -0.015856119098512592, + 0.06800072086442177, + -0.10593954313334852, + -0.00304400042571294, + 0.11611008573220116, + -0.04918871495033042, + -0.19841386085413157, + -0.03980436562148406, + 0.05720055738408102, + -0.05529337243883269, + 0.039479561198494176, + 0.305882483271094, + 0.02047933528543945, + -0.031021320891095917, + 0.14408490770202267, + -0.4390689533340252, + -0.1297718594622682, + 0.03381297326510133, + 0.03162262548955413, + -0.1330815332493498, + 0.49773362911712216, + 0.07384886614194741, + -0.04724988239463691, + -0.03262702070526132, + 0.010306294191025344, + 0.06422466123920084, + -0.0046008962932444775, + 0.11967294040284411, + 0.028576126568279726, + -0.20267795994569116, + 0.013597812243693656, + 0.02880432249622647, + -0.0528315214789401, + -0.061694518067294986, + 0.031609300179893, + -0.148535561128838, + 0.03080145328168738, + -0.27068700912865107, + 0.0065184083534801755, + -0.017121357250836286 + ], + "hers": [ + 0.02232489067350508, + 0.14854055811996092, + -0.1278580118621893, + -0.10690895941119526, + 0.1435868742534382, + 0.21856839171658113, + -0.10663245923572694, + 0.1366960234949296, + 0.05460545332757719, + 0.0035740146174841345, + -0.015856119098512592, + 0.06800072086442177, + -0.10593954313334852, + -0.00304400042571294, + 0.11611008573220116, + -0.04918871495033042, + -0.19841386085413157, + -0.03980436562148406, + 0.05720055738408102, + -0.05529337243883269, + 0.039479561198494176, + 0.305882483271094, + 0.02047933528543945, + -0.031021320891095917, + 0.14408490770202267, + -0.4390689533340252, + -0.1297718594622682, + 0.03381297326510133, + 0.03162262548955413, + -0.1330815332493498, + 0.49773362911712216, + 0.07384886614194741, + -0.04724988239463691, + -0.03262702070526132, + 0.010306294191025344, + 0.06422466123920084, + -0.0046008962932444775, + 0.11967294040284411, + 0.028576126568279726, + -0.20267795994569116, + 0.013597812243693656, + 0.02880432249622647, + -0.0528315214789401, + -0.061694518067294986, + 0.031609300179893, + -0.148535561128838, + 0.03080145328168738, + -0.27068700912865107, + 0.0065184083534801755, + -0.017121357250836286 + ], + "woman": [ + -0.03256106847930736, + 0.11628030553121017, + -0.10441138082853971, + -0.08870034690520731, + 0.2764991299556674, + 0.2412528899061775, + -0.0776762557426544, + 0.10414053185920268, + 0.06377686062084793, + -0.04517258572042509, + 0.03632963592683806, + -0.12850617688883476, + 0.05490521159872188, + 0.10067510001311543, + 0.15054180330145475, + -0.06831313243179755, + -0.1630026495927426, + 0.07771392348011187, + -0.002589387894933515, + 0.04255557481802276, + -0.09649936226068731, + 0.3187946180150552, + -0.01191609905958148, + 0.12519141599257738, + 0.12428739029359812, + -0.47961791994061564, + -0.13776526549623763, + 0.060858507818785844, + 0.03532695662975588, + -0.06321901936612064, + 0.4111164488215307, + -0.049167159592700606, + -0.05411418911211501, + 0.0001529776503237045, + 0.030354815285369825, + 0.01640035351880411, + -0.004234929911289852, + 0.006499657783375649, + 0.06186118711586803, + -0.15057588363534485, + -0.0451546487025882, + 0.07555610023433394, + 0.08720260591582694, + 0.004004439232085809, + 0.10001681145850154, + -0.1528646471113321, + -0.041386081255057494, + -0.23565654034106767, + 0.08746807377981293, + -0.01877467656987331 + ], + "girl": [ + -0.060866403703171, + 0.1228293243829214, + -0.13787862259771433, + -0.103263294403088, + 0.21653120263757533, + 0.22149289781340167, + -0.1350181435283767, + 0.07164440631998385, + 0.03330515815175688, + 0.017451394341908278, + 0.057486800654583226, + -0.056178396339534344, + 0.042146157349394815, + 0.05924723129170026, + 0.18702589075570397, + -0.04461287911474336, + -0.18201122374526024, + 0.12364862209166995, + 0.005377877223136779, + 0.09511269478329637, + -0.053464472679304764, + 0.30926722777437265, + 0.05523196622770413, + 0.21370603812464928, + 0.07298635946362371, + -0.3429220000346041, + -0.16554757754618374, + 0.057303164961243026, + 0.09225751289742048, + -0.15310272786674453, + 0.37535842009863624, + -0.03598023579993372, + -0.03365653798805206, + 0.06663150503736069, + 0.11690001036141788, + 0.08865895959908092, + -0.02214928978134017, + -0.14667724432765838, + 0.03735397204434401, + -0.23115142899196758, + -0.0794100772748893, + -0.020567197654101584, + 0.005932492331580573, + -0.12074399982681787, + 0.17554689419412137, + -0.15584666889992393, + 0.09994196037257931, + -0.2343120817908036, + 0.06589519653617935, + 0.003953288017493822 + ], + "sister": [ + 0.14810416710846683, + 0.2594289856383753, + -0.002297455763809501, + 0.01863195508916412, + 0.11551860290540009, + 0.2589676894134288, + -0.21071209301250643, + 0.11599594421643167, + -0.028808952065787572, + -0.09258616361827846, + 0.09688223541756256, + 0.11172795131779646, + -0.0659613488958244, + -0.038464082617493324, + 0.0622208382022367, + 0.01599755251756763, + -0.26035157808826825, + 0.03709423339297836, + 0.10505921241367897, + 0.13400053643966858, + -0.015193292577552257, + 0.28770844986335475, + 0.037986741306461756, + 0.19778978180672327, + 0.13523400243246023, + -0.28343644569319826, + -0.06637250422675496, + -0.07281861756144176, + -0.029647307813733776, + -0.12276697054434164, + 0.3431041095721443, + 0.03489405096355975, + 0.046947921884987164, + 0.005301296442515297, + 0.03172915773327482, + -0.10063678556147475, + -0.007620814199691803, + -0.0001266097686617218, + 0.05782448461492075, + -0.11501318269371962, + 0.026953738987198498, + 0.004588894620337091, + 0.013984495904616432, + -0.193953000596625, + 0.02871268154927701, + -0.05332584360381231, + -0.12713123395948742, + -0.4042960566291743, + 0.004844011489090097, + 0.11876773283763177 + ], + "grandma": [ + -0.03769079204578591, + 0.13669996615202537, + -0.04840885126885742, + -0.06146196123218447, + 0.06240152318052893, + 0.06063575126152231, + -0.1070218994619599, + 0.008615757876990317, + -0.022045448299506363, + 0.021342414144592318, + -0.1501812946977078, + 0.15117123530279994, + 0.13456139217308583, + -0.001217198725680991, + 0.29055890279231467, + 0.170871809023985, + -0.05317719113002649, + 0.12751845595975847, + 0.2594500822506148, + 0.038277703396896526, + 0.04484959916963297, + 0.09819052093408946, + 0.003846410442686299, + 0.06903387334136261, + 0.41665668063178823, + -0.2038118937384635, + -0.1924741770373954, + 0.050338353822294034, + 0.18608112549611325, + -0.297057749512767, + -0.035887236134218535, + 0.14968254599590566, + -0.1516044917508301, + 0.32912376453267306, + -0.026982808554063938, + 0.07786525186923304, + 0.010028627305479345, + 0.13999221137048706, + 0.138498484197918, + -0.051167082725793574, + -0.02773848840527932, + 0.07852521227262782, + -0.10789848808936975, + 0.061313344194778775, + 0.1117852014574542, + 0.031227210385057002, + 0.0960746173506864, + -0.21289012768439766, + 0.016581631081935406, + 0.027871991845660703 + ], + "aunt": [ + 0.14825517176741843, + 0.22900728172897045, + -0.09560278355884262, + -0.10669466865247117, + 0.07193268819082903, + 0.24418498879339506, + -0.19006514271947433, + 0.14393036380140903, + -0.01921321619412316, + 0.0019387843680099205, + -0.06389311545711791, + 0.23311616375193625, + 0.12928094156850858, + -0.09181045316111552, + 0.15029493820024786, + 0.028439333430384593, + -0.22791717017185706, + -0.01635922028286519, + 0.18947396683657822, + 0.11657066008624249, + -0.004269883114477911, + 0.25135456864979455, + 0.055717278778767636, + 0.08334322127807534, + 0.2523608254717453, + -0.23517060476341908, + -0.17232148075907514, + -0.028946654578118125, + 0.11964812886670871, + -0.11158130334406975, + 0.11420805292303716, + 0.03386263842702361, + -0.006545700626789978, + 0.15514593462940238, + 0.08821727775939954, + -0.0004155421400981007, + -0.08156340452424984, + 0.028221311118961916, + 0.18048893196390917, + -0.09445397368711546, + -0.02862171747936318, + 0.1395258938369952, + -0.03273688860746615, + -0.13657001442251473, + 0.05315132388279308, + -0.030946590011745347, + -0.058945686082526455, + -0.4098190544232555, + 0.0860244764348984, + 0.08473730625015301 + ], + "jill": [ + -0.1406050149171967, + 0.07361177145293331, + 0.0408626369826799, + 0.12827510367514605, + 0.259395013240159, + 0.11220533707107652, + -0.21569482752549454, + 0.021533196963137883, + 0.07942717396929387, + -0.23746146306718016, + 0.02861497352996215, + -0.02162923335244717, + -0.20508728175417962, + 0.015659253373193335, + 0.06092142136939443, + 0.06325096904961705, + -0.03554555930757154, + -0.17138745960134313, + 0.09103620752836802, + 0.056898536849965306, + -0.022455097919454443, + 0.3044619636893259, + 0.14390944080627677, + 0.13761022776442866, + 0.037016343305053746, + -0.05818547284776224, + -0.045282246133008824, + -0.1314609959854817, + 0.10105834256384298, + -0.0664562137809722, + -0.0464506485520613, + -0.03200406626100255, + 0.015374288973681158, + -0.11162476444049145, + -0.07887562997023805, + -0.10654959202812687, + -0.19113902430437305, + 0.06178744221001717, + 0.07629450081676187, + -0.1508497027943959, + 0.2626123532346513, + 0.0509839531758798, + 0.07590987144899926, + -0.37480801409521813, + 0.15361951805281227, + 0.08965009037284614, + 0.006525152557250815, + -0.17691015674978372, + -0.15883499551756822, + 0.3061069194759836 + ], + "engineer": [ + -0.06740834931842901, + -0.07036640709915663, + 0.1626630346881778, + -0.046876775714017216, + 0.05358666357396935, + -0.013682222965938937, + -0.2653409982114645, + 0.0756796576237788, + 0.05988459365875221, + -0.09347221441084566, + 0.23324848275071713, + 0.09884575143915114, + -0.12213442781201288, + 0.1483228850120417, + -0.1906460201478464, + -0.11652376386921971, + -0.055109902567034255, + 0.21992516543670615, + -0.0661282658902201, + -0.027868441102670328, + 0.09641620534071385, + 0.10394799920088174, + -0.1922838034979096, + -0.059659524045001196, + 0.09002583595028324, + -0.29711198565256763, + -0.0854681762718252, + -0.23477574084402755, + -0.043896612881939044, + 0.1777607847606008, + 0.32844087206665984, + -0.2966899801267845, + -0.052610022214299765, + -0.10963904514858598, + 0.1937206318356951, + 0.06924306858052433, + 0.02095156958007217, + 0.22824470294500257, + 0.1786490059148682, + 0.08212629441764717, + 0.07082659407727254, + 0.019149806939990383, + 0.013367929326736628, + -0.175461859420144, + 0.04429852290649443, + -0.026751131234406356, + -0.06576252776787471, + -0.05964947629438731, + 0.014609027482563654, + 0.13591391300389422 + ], + "banker": [ + 0.05289026060616259, + 0.0640197166693786, + 0.00444935818733204, + -0.0743227279754251, + 0.25835765035461644, + -0.03627982291339841, + -0.22007011764446757, + 0.044315069805771795, + -0.27717855229122335, + 0.012434577623738095, + -0.05855416200863093, + 0.1886872553988842, + -0.049308999991761394, + 0.08132216410547938, + 0.06882191167170487, + -0.08797339151117956, + 0.024403260789701165, + 0.025831357330090073, + 0.025793891834431724, + 0.11253532969425728, + 0.20589052946412464, + 0.00010509071532167437, + -0.21124148378521146, + 0.08644832545261609, + -0.16641291630365168, + -0.3590957566218647, + 0.021531861125391436, + -0.10855296789104314, + -0.14364271035411746, + 0.181317572017618, + 0.11660804945758858, + -0.1679622247417588, + 0.12169674530906698, + 0.019346080031629527, + 0.06111944653488508, + -0.11568683903493032, + -0.28658900325952674, + 0.07404945024238772, + 0.22203815809287392, + -0.16100466181273446, + 0.013672702062612304, + 0.12596781181411895, + 0.1114179763825641, + 0.013305540205160464, + 0.1007667563521635, + -0.15465976993035843, + -0.174818410447237, + -0.016567021794854205, + 0.2254320912289833, + 0.13696724056887363 + ], + "nurse": [ + 0.07460675986831795, + 0.03203169897886132, + -0.15147310133886815, + -0.1959345274797395, + 0.13288530438836885, + 0.13447147387934336, + -0.13019571264280336, + 0.15070638513990606, + 0.1684382185032043, + -0.18190240402889335, + 0.13198471710704826, + 0.05729073820022404, + 0.1282829788448635, + 0.0927422348284243, + 0.039163376639048536, + -0.10890006879914423, + -0.212287083024799, + 0.12152046083602842, + 0.03518578281321602, + 0.07761074582245257, + -0.07371834268539361, + 0.3908021430000767, + 0.03042118929109436, + 0.09338927839315686, + -0.013659109935001903, + -0.25131253188202374, + 0.002601764202139401, + -0.11556928272027493, + -0.076819689426698, + -0.0021102950619052256, + 0.28906012040764484, + 0.05899049527622998, + -0.00728968610436485, + -0.03226698754785499, + 0.12331352199973875, + 0.10839703806543365, + 0.07332281448751633, + 0.0018814566450615622, + 0.3275379148370394, + -0.10240326460529336, + -0.02557951847912995, + 0.04535795672272651, + 0.12023651545522679, + 0.013731522020459434, + 0.17433057447670935, + -0.17035703735034224, + 0.006260907119937389, + -0.12913488573259918, + 0.0033159461430244744, + 0.2946380821725808 + ], + "receptionist": [ + 0.1097255390394386, + 0.009435615870951559, + -0.06199439975688469, + 0.01009308524306059, + 0.17155453769642165, + -0.09134017302849656, + -0.1453436321463831, + 0.16346022136998498, + -0.01203939099830946, + -0.2343578522075128, + -0.04993562581718051, + -0.0425597705184416, + 0.050491270137162585, + 0.18406041148429747, + 0.04669221362379672, + -0.0028110433806721475, + -0.32795936839872714, + 0.2278477682910716, + 0.2225833614082646, + 0.09479550538112934, + 0.06439013131327255, + 0.41743102587677233, + -0.15797355675918517, + 0.13302383459589665, + 0.12241232028033185, + -0.11747354681295621, + -0.01232574165251418, + 0.14458898962808187, + -0.054406624298896804, + -0.13634994733923125, + -0.04703593778452982, + 0.03333349041455306, + 0.08178309072480473, + 0.11451441776002837, + 0.0977391280809414, + 0.1328946149865985, + 0.12285166695194558, + 0.08053741369117048, + 0.2517146301284415, + -0.019806264834784596, + 0.08016009243201987, + 0.07925555516693275, + -0.03461276454660483, + 0.10731171673774899, + -0.043355763311718254, + -0.22319069357196591, + -0.02171949037005314, + -0.01164423743307569, + 0.13554620136939668, + 0.24106434993008727 + ], + "homemaker": [ + -0.14551601922112395, + 0.10414864695815045, + -0.11908713975416799, + -0.13219142821516494, + 0.17656877682278754, + 0.12188176417284301, + -0.19334642457281356, + 0.05839502627188894, + -0.04768683740104024, + 0.022972109758648013, + -0.005221417846548367, + -0.18037085220550222, + 0.11383205769850183, + 0.003580535183005695, + 0.02132751413086441, + -0.0740117538692501, + -0.17122953424432694, + 0.27716040403812486, + 0.2686205862839806, + 0.19469051762803102, + -0.08156144782581241, + 0.384836367025161, + -0.15050129254194902, + 0.2641650291948619, + 0.12489421582698607, + -0.15442960870885536, + 0.038345019370854626, + -0.014787003854987501, + 0.10774032103387893, + 0.074182550224333, + -0.009445533497982323, + 0.023925598975719425, + 0.03875592074685113, + 0.1579148444763438, + 0.1336469101976104, + -0.01734251337654649, + -0.13913962196469623, + 0.18047729051374228, + 0.1921112451353301, + -0.007125425909298439, + -0.04174856992504255, + 0.05956832297202354, + -0.020184416206556057, + 0.04022377927676636, + -0.04003318044573184, + -0.1696799905011112, + -0.09718807666114938, + -0.20805718889538724, + 0.0793683236141684, + 0.2066512131028209 + ], + "scientist": [ + -0.030347079970493712, + 0.11754191425674392, + 0.06363825086315666, + 0.019219110372927093, + 0.09362864660273924, + 0.009220576572483426, + -0.14244040159721122, + -0.10334009779787118, + 0.0625026453928504, + 0.02460349984138947, + 0.11502006034305867, + 0.07715754722612375, + 0.008153608716341671, + 0.13094167320178401, + 0.03284386958687169, + -0.0017853684576541857, + 0.09863765001845352, + 0.16931511237921326, + -0.13477265464405663, + 0.1724924940686271, + -0.024121494123941982, + 0.1598967206602894, + 0.02326352394688546, + -0.05242679787532814, + 0.2045439462560151, + -0.4301997429362301, + -0.10102454233125346, + -0.16988966319441068, + -0.1940940623017536, + 0.1271299719882093, + 0.17867759143491324, + -0.3508230413869781, + -0.13724052391738775, + -0.28608003341943183, + -0.04675455459240613, + -0.013664476485062268, + -0.127573417248261, + 0.17982283701956844, + 0.22987816676505496, + 0.02075516619328874, + 0.07631692925489533, + 0.11600335200665154, + 0.042023186469941604, + 0.0041238481161937094, + 0.24703757030618548, + -0.0010774370203247142, + 0.04983167909259088, + 0.08721025846920852, + -0.022199255322761412, + 0.10419613995205791 + ], + "maid": [ + 0.08042447591378793, + 0.005957938124312269, + -0.3177595476538498, + 0.022466142504172917, + 0.3126550623810127, + 0.05535910387947161, + -0.018472683691139115, + 0.1302732986200714, + 0.1466161945647364, + 0.02843988531091991, + -0.044522217572201594, + 0.08090502369051945, + 0.14703694084036356, + 0.01866575710899036, + 0.22528079773173865, + -0.2746383938218077, + -0.19292818563427586, + 0.19214863035202248, + 0.08859805965401701, + -0.010650006616318874, + 0.10219435808367429, + 0.2563562206270439, + -0.04088286907642152, + 0.10297818490172081, + -0.09812572024068071, + -0.15285050105486672, + 0.06962603342865643, + 0.01802032805064251, + -0.003321973386347614, + -0.03779454869796025, + 0.15453775769316852, + -0.030943005285717008, + -0.07174044364627513, + 0.18112806800564618, + 0.03555199240654647, + 0.14087738622661372, + -0.0073457601035129124, + -0.08821148566473522, + 0.149409778973469, + -0.2562707899111805, + -0.006672352485719802, + 0.02121671828467095, + 0.05019481710553017, + -0.030204029593498747, + 0.1311041123318428, + -0.09527019856294716, + -0.14841878266945377, + -0.3323468423875223, + 0.20218033216228015, + -0.03842246445955611 + ], + "lawyer": [ + -0.09235780028333866, + -0.004703762251305323, + -0.10856786598744905, + 0.10913842648924105, + 0.19771746412266442, + 0.06569514302956428, + -0.08856943223777032, + 0.1520476499506075, + -0.057935904421019375, + -0.01148248207158548, + -0.07204238847710881, + 0.0468973837365185, + -0.13326333821147598, + -0.04338373000662778, + 0.13190129310787155, + -0.09796543026559729, + 0.0022378650792507946, + -0.08157862528147052, + 0.07918688176385763, + -0.008393771045386298, + 0.04809037387662902, + 0.2834936473045238, + -0.10600322534808104, + 0.032754399176947425, + -0.08391849966255684, + -0.49191152218470124, + 0.04905283452106601, + -0.13388576784979453, + -0.1697657627722489, + 0.010594175081590143, + 0.2626883603199877, + -0.28727817319183097, + -0.09580037408539671, + -0.12578457715600055, + 0.022682181295144762, + -0.1310079913121366, + 0.015228202160454224, + 0.017384037160827853, + 0.20054145044466518, + 0.015598778324749424, + 0.10283920801996184, + 0.23346874674336776, + 0.032904243349135213, + 0.11861318876142316, + 0.04176041814126, + -0.151707618944489, + -0.22866604891683584, + -0.01803144082784435, + 0.024051910715271657, + 0.1902559927793645 + ], + "programmer": [ + -0.09132397057940844, + -0.08964690204949342, + 0.17872299755628585, + 0.14525826603657202, + 0.045900061398282066, + 0.06034402026315556, + -0.21798360616074378, + -0.10772861562769313, + -0.06665517584887179, + 0.1984370066495619, + 0.23657237861789698, + 0.12868546509297493, + -0.010268847295942782, + 0.23513068624343014, + 0.05512442640866316, + 0.010050174616212423, + -0.07557006599340658, + 0.17665171605337976, + 0.1351820886214804, + 0.056646351286198444, + -0.005906976367732349, + 0.10493074957607187, + -0.24592331471210932, + 0.260141435235358, + 0.12886633210506668, + -0.1277805586516622, + 0.012864594809092723, + -0.2502532700738873, + -0.023849129100656702, + -0.049055177474945484, + 0.22469034046652525, + -0.250933402437581, + 0.001463505599656834, + -0.1488070973145797, + 0.061081932011791136, + 0.09666325212772872, + -0.004518700067857631, + 0.05834284715218058, + 0.12470956979238657, + -0.05583400269986542, + 0.23127236644816512, + 0.08622603991048908, + -0.1414088621939273, + 0.08307155457175744, + -0.05707246171782524, + -0.07847558103826663, + 0.18566867545223464, + 0.07465225838683566, + -0.03612341495316358, + 0.2807540067144568 + ], + "linear_two_means_he": [ + -0.02507127125955947, + 0.04596272184970461, + -0.15112561429658566, + -0.1374721379066834, + 0.09718108698786107, + 0.10249319831779265, + -0.15247727952227522, + 0.10454195966689624, + -0.08226996977894457, + -0.03691591943343716, + 0.03574415823087041, + 0.05087226927797465, + -0.1453044915586461, + -0.04731206155150875, + 0.13789590236740085, + -0.04381323022526035, + -0.12228235375094398, + 0.047842457135033024, + -0.08288217429759909, + 0.012966515842096142, + 0.06203371502217571, + 0.22586469347697694, + 0.09933306274707533, + -0.019199162493873927, + 0.09881586564720174, + -0.47166435428573183, + -0.06431256316382453, + 0.030928687181711743, + -0.014426195456820558, + -0.06831882976423932, + 0.5967822317964825, + 0.012358250895681832, + -0.0644548914167879, + -0.09723743812718905, + 0.011175387173342946, + -0.006526642242854383, + 0.020319442604923476, + 0.18598181884710724, + 0.06498906889638778, + -0.10901688493467754, + -0.02607121690513947, + -0.0023322330823423824, + -0.0731648635930143, + -0.022165242808381856, + -0.042825214594514105, + -0.07592037681372436, + -0.028103187191426626, + -0.13827287856275716, + -0.00787945227225964, + 0.01849113741326002 + ], + "linear_two_means_him": [ + 0.03274469080260336, + 0.04581813374168451, + -0.026945189186177193, + -0.13407252956347865, + 0.16884558738945282, + 0.09192076802072385, + -0.09418313293324665, + 0.19231259683545218, + -0.051249728933154076, + 0.007521123126469416, + -0.04571459015112196, + 0.12833181734308483, + -0.1278512918661992, + -0.015857657944589244, + 0.15518354751548874, + -0.02132442786019474, + -0.08428802369055341, + -0.028540565026732807, + -0.010761800573271772, + -0.08433668206353717, + -0.005499833961154117, + 0.243317507296093, + 0.08314093080303513, + -0.00962565272761784, + 0.1581706976148933, + -0.5027186920172938, + -0.05450523206314018, + -0.0027956453045591585, + 0.055804317230870495, + -0.18771701233228788, + 0.5036565434812432, + 0.10071955410091013, + -0.11318052913280709, + -0.0674138196131925, + -0.005476446677888, + 0.06744165221081247, + 0.04763799051062047, + 0.08344297504799277, + 0.00587556738207631, + -0.16232111695448279, + 0.02113674901046344, + 0.06500096587143649, + -0.06703924457042704, + 0.027300403992778298, + 0.027029765785151295, + -0.10973711447755685, + -0.03580039144668812, + -0.1871164515934015, + 0.004207697577063773, + -0.006160639614912668 + ], + "linear_two_means_she": [ + 0.0009913532327193771, + 0.019614996218358445, + -0.10137458070759878, + -0.1419246727077103, + 0.11099749274807863, + 0.07725570313133208, + -0.1333641185879623, + 0.13593472890924266, + -0.08624209956798193, + -0.03775491820440955, + 0.01492974724696289, + 0.06709673588063728, + -0.11287033724949981, + -0.0356101839467477, + 0.18073558501242507, + -0.015729309838656443, + -0.08897671712813923, + 0.06013491345487104, + -0.0702819859100635, + -0.01917071317539572, + 0.05669129965532317, + 0.24108544171851742, + 0.13696354777972908, + 0.025229142565868203, + 0.15363565468800802, + -0.4928391882260989, + -0.052794524798579136, + 0.025883369518779613, + 0.00630089851518296, + -0.1373974055373634, + 0.5705934984785015, + 0.04477737317504965, + -0.0818754441427187, + -0.10221629012069311, + -0.015207746265859648, + -0.016075241293742035, + 0.04001737183455121, + 0.15183626819991988, + 0.052401184528904333, + -0.1319141831616075, + -0.011991901928596681, + 0.005669978817853338, + -0.03312199190971774, + -0.011478165520904714, + -0.016206515466013413, + -0.07816734915735406, + -0.03680964180126183, + -0.15836588753390368, + -0.010540214412538516, + 0.05746313462239602 + ], + "linear_two_means_her": [ + 0.006682066310324496, + 0.0721658593730306, + -0.07669622277516402, + -0.12961999476245178, + 0.15502918162923526, + 0.11715826320718434, + -0.11329629386755959, + 0.1609198275931058, + -0.04727759914411676, + 0.00836012189744181, + -0.024900179167214452, + 0.11210735074042223, + -0.16028544617534557, + -0.02755953554935032, + 0.11234386487046452, + -0.04940834824679864, + -0.11759366031335812, + -0.040833021346570825, + -0.02336198896080744, + -0.05219945304604531, + -0.00015741859430162125, + 0.22809675905455243, + 0.04551044577038141, + -0.05405395778735998, + 0.10335090857408702, + -0.48154385807692685, + -0.0660232704283855, + 0.0022496723583729367, + 0.03507722325886698, + -0.11863843655916378, + 0.5298452767992242, + 0.06830043182154232, + -0.09575997640687633, + -0.06243496761968843, + 0.020906686761314602, + 0.07699025126170013, + 0.027940061280992766, + 0.11758852569518013, + 0.018463451749559753, + -0.13942381872755277, + 0.0070574340339206445, + 0.0569987539712408, + -0.10708211625372363, + 0.016613326705301197, + 0.0004110666566505619, + -0.1074901421339271, + -0.027093936836852947, + -0.16702344262225488, + 0.006868459717342648, + -0.04513263682404869 + ], + "linear_two_means_engineer": [ + -0.06426400243213064, + -0.055014413071649795, + 0.15237905984589217, + -0.04231165552642368, + 0.05128665797818696, + 0.006702115208194068, + -0.2640015081559008, + 0.07081045745749148, + 0.08036399375736768, + -0.09443426456073081, + 0.2350664192982535, + 0.08997992637483913, + -0.11121041773264433, + 0.15325072571025908, + -0.18988897624293705, + -0.11647961562234029, + -0.07135548201102163, + 0.22013193439039092, + -0.04993447723220356, + -0.02849034643020615, + 0.10438359096751304, + 0.11958362190303527, + -0.19731527943132057, + -0.05502975908526194, + 0.09821373222468487, + -0.2885741518590411, + -0.09828220990033086, + -0.2284312365126377, + -0.04459101777245995, + 0.17485759381170884, + 0.3219861451542496, + -0.2955746954586103, + -0.04285906167673526, + -0.10364738258872655, + 0.191589858607148, + 0.06667707139823251, + 0.014410547559358507, + 0.22866368884829152, + 0.18068174353565453, + 0.06941164919968533, + 0.07214126829349139, + 0.013482475340456348, + 0.024272781596345674, + -0.19120243299987477, + 0.05056964544544692, + -0.03500162567307298, + -0.054125039202910216, + -0.08048677555692896, + 0.014538664043682616, + 0.1415444294414869 + ], + "linear_two_means_banker": [ + 0.056764988050817644, + 0.08293772881612002, + -0.008223416478587913, + -0.0686972047368485, + 0.25552339092742915, + -0.011160535054728003, + -0.21841948587010718, + 0.03831483376309934, + -0.25194212117163, + 0.01124905889641891, + -0.0563139484163981, + 0.17776204360183548, + -0.03584752032974893, + 0.08739466192035919, + 0.06975480454301729, + -0.08791898834575193, + 0.004384097891282766, + 0.02608615534332333, + 0.04574923379100179, + 0.11176896589035235, + 0.21570860897845326, + 0.019372613857345945, + -0.21744168961213908, + 0.09215350939080506, + -0.15632310511053565, + -0.34857472312462695, + 0.005741336233414801, + -0.10073473855932036, + -0.14449841418842677, + 0.17774001710605142, + 0.10865399463486754, + -0.1665878776138359, + 0.13371269513040748, + 0.02672950819630324, + 0.05849372939694325, + -0.11884887556100787, + -0.2946493981384101, + 0.07456575974547455, + 0.22454306754843206, + -0.17667271198549583, + 0.015292753682413184, + 0.11898405140830488, + 0.12485584818770233, + -0.006091312193588385, + 0.10849455849649992, + -0.1648267198215359, + -0.16047772287042478, + -0.04224448622375791, + 0.22534538350951358, + 0.14390563425036884 + ], + "linear_two_means_nurse": [ + 0.05844632238367545, + -0.046870199799296255, + -0.09861839460896223, + -0.21939705946784796, + 0.14470623190973592, + 0.029705735670130795, + -0.1370800501815668, + 0.17573174262523478, + 0.06318390751773008, + -0.176957926887374, + 0.12264139367936702, + 0.10285683408039564, + 0.07213879744474296, + 0.06741549308601777, + 0.03527253343986835, + -0.10912696964995701, + -0.1287925830539588, + 0.12045776741490864, + -0.04804253731512363, + 0.08080704143818226, + -0.11466689002887823, + 0.31044252863543187, + 0.05628056652192304, + 0.06959450322808786, + -0.055740976426830346, + -0.29519290967233236, + 0.06845976170296676, + -0.14817699723970285, + -0.07325078076996139, + 0.012810716369750359, + 0.3222343262190283, + 0.05325846607927745, + -0.05740495451485347, + -0.06306126430199212, + 0.1342646761818469, + 0.12158503498488703, + 0.10694053583754631, + -0.0002719302877073986, + 0.3170906167394461, + -0.03705607594756484, + -0.032336314485466, + 0.07448532763021165, + 0.06419079602452518, + 0.09463053364544989, + 0.14210000723526234, + -0.12795344626614577, + -0.053550212083182405, + -0.022041144406395652, + 0.0036775805289112396, + 0.26569992357040223 + ], + "linear_two_means_receptionist": [ + 0.09451416054101652, + -0.06483258681225645, + -0.012243707657382932, + -0.01199155540739322, + 0.18268125396884974, + -0.1899533014837988, + -0.1518236710755608, + 0.18701590698365797, + -0.11111239966606287, + -0.22970375089695594, + -0.058730241014120374, + 0.00033035122407720774, + -0.002355714641120804, + 0.16022104109957586, + 0.04302986917074675, + -0.003024618955380896, + -0.24936827562664693, + 0.22684748399358287, + 0.14424281645314627, + 0.09780409116043857, + 0.025846381913280377, + 0.34179071531899846, + -0.13363283113431113, + 0.11062646251579758, + 0.08280180851444766, + -0.15877694815072663, + 0.04966459295849106, + 0.11389623839564306, + -0.051047308039327305, + -0.1223052067384964, + -0.015809963768599847, + 0.027938087860524978, + 0.034610957107357565, + 0.08552860597507576, + 0.10804715046071488, + 0.1453081163517979, + 0.15449511003297323, + 0.07851048937559837, + 0.24188087372182449, + 0.041703259361273545, + 0.07380010486443837, + 0.10667235399795065, + -0.0873670697618532, + 0.18345974720693775, + -0.0736935161104532, + -0.1832773512328741, + -0.07801806356632346, + 0.08916017678362433, + 0.13588659794139515, + 0.21382565147508412 + ], + "hard_two_means_engineer": [ + -0.0615298187707372, + -0.05028601036959615, + 0.14625994806668458, + -0.0424608011177029, + 0.04662619217494263, + 0.035971693891699685, + -0.2614795462504052, + 0.08544722077904941, + 0.09908413436732402, + -0.10654383171055688, + 0.24193136298416668, + 0.06281489366487647, + -0.07094072290656792, + 0.15563699516750512, + -0.17546459592973965, + -0.10788505555987593, + -0.10208930148281103, + 0.2142971153539293, + -0.02046253337329823, + -0.03492626119604156, + 0.1161578043431099, + 0.13931220099143204, + -0.19492818673872775, + -0.03370842076926304, + 0.11867118939339891, + -0.28049364753132755, + -0.09949948214236101, + -0.23108240861796886, + -0.047994279083554574, + 0.1605071725337192, + 0.30035369714230525, + -0.2865402485790226, + -0.047808526589933883, + -0.11179440455103819, + 0.18119746140843931, + 0.056561581781467044, + 0.011489734086414911, + 0.22424183672768963, + 0.18232854844511018, + 0.054943481732343535, + 0.06980935405220816, + 0.005875363404471246, + 0.05134113471622906, + -0.20198533443899558, + 0.05206143974434169, + -0.04640693621701136, + -0.04412743717130785, + -0.09732366092013081, + 0.011571589440019493, + 0.1503677911085482 + ], + "hard_two_means_banker": [ + 0.05831517220947051, + 0.08255060229350653, + -0.010687978298559553, + -0.07024751366240858, + 0.2519342862586521, + 0.009542531419089874, + -0.21650663602619274, + 0.05332891540326999, + -0.24100385837596072, + 0.0003716364196564882, + -0.050541299381410484, + 0.1554367318486496, + -0.002065675915319115, + 0.08807187829482309, + 0.08283185577718966, + -0.0800012922655011, + -0.018950955749106952, + 0.020637597775664977, + 0.06793581151908434, + 0.10602212885129779, + 0.22410876081073244, + 0.032740401032898825, + -0.2136818122329968, + 0.11039690242045577, + -0.13997799196096344, + -0.3437597785963685, + 0.008583286067544691, + -0.1051446329873683, + -0.1474241786552995, + 0.16539534104352926, + 0.09068823158556982, + -0.15859570091455277, + 0.1261277317800112, + 0.017357039720143373, + 0.04956263110118833, + -0.12738975430010907, + -0.2953207127814389, + 0.07035546665029485, + 0.2254337673876172, + -0.18608990288627977, + 0.012733957733472462, + 0.11371769553436785, + 0.14646096558896335, + -0.011171241232389892, + 0.10793064489326969, + -0.17279882758386675, + -0.15485279947290634, + -0.05133406428332264, + 0.22262903819597718, + 0.15030577992094413 + ], + "hard_two_means_nurse": [ + 0.06146327453686229, + -0.01286496315865509, + -0.11479833599553839, + -0.2058079638529719, + 0.14844784216731188, + 0.023452993568146493, + -0.13882932214717714, + 0.1288676242431418, + 0.08079410674381028, + -0.15267628895756552, + 0.11257113950338321, + 0.13785016487903767, + 0.013821770549912307, + 0.07638901547180618, + 0.005220059529020181, + -0.1282148848106331, + -0.10724841128181291, + 0.13410391057870605, + -0.06691573398880354, + 0.09339094022931868, + -0.11785750549188001, + 0.3117332560185772, + 0.03633362132737723, + 0.03536662418177696, + -0.07770569091973209, + -0.28846856629532674, + 0.03397359452510464, + -0.1238270024928557, + -0.06765794140528432, + 0.036466114021770274, + 0.35185870053329116, + 0.036297264946993826, + -0.01802508788477762, + -0.027447937198398376, + 0.15131339466835741, + 0.1367508816615861, + 0.09447801560633352, + 0.01083124656010927, + 0.31931102669206046, + -0.04162669860832338, + -0.023305127074758826, + 0.07503755989512789, + 0.035334299846069575, + 0.07303391094593463, + 0.15697389277542634, + -0.12640969670612956, + -0.042111810108238674, + -0.04490123431482444, + 0.010107187939507111, + 0.262321445674747 + ], + "hard_two_means_receptionist": [ + 0.09674335685164037, + -0.034910052937179456, + -0.025769724689733035, + 0.0003408204073268948, + 0.18692608456209275, + -0.20099618140575293, + -0.15387128586208695, + 0.14188947620096687, + -0.09860789226373355, + -0.20549041401062226, + -0.06911095070338602, + 0.037010992313086875, + -0.06256521541802815, + 0.1679078866637648, + 0.013165465149707359, + -0.02188881872423698, + -0.22420978151645113, + 0.24027678791492135, + 0.12173488329552958, + 0.11038203768839533, + 0.020792665441772953, + 0.33933251012588445, + -0.15213368491908033, + 0.07571326217192854, + 0.0591517495959763, + -0.15417358458567518, + 0.018661078551310818, + 0.1364326125454474, + -0.04535731363484939, + -0.09824696641401537, + 0.014991948104691483, + 0.010918762270965297, + 0.07117943893398276, + 0.11927432641675981, + 0.12539537290888017, + 0.16090048664140563, + 0.14374724131582556, + 0.0893773675001957, + 0.24358870629291604, + 0.04022442205106769, + 0.08240657141634557, + 0.10857091604183973, + -0.11847301952386918, + 0.16588631835880335, + -0.0604994355066736, + -0.17978269572934102, + -0.0694985544568145, + 0.07155565829259858, + 0.14225409779564568, + 0.20914431866504204 + ], + "hard_two_means_boy": [ + -0.08233869790364136, + 0.01176112910229632, + -0.029354421664617908, + -0.11785738925591122, + 0.2374148337319568, + -0.0075029559423780146, + -0.17659279239586476, + 0.027633850309864774, + -0.1152995062971225, + 0.09075349808875449, + 0.02836374964167929, + 0.1343178807958097, + -0.1558218182607306, + 0.02652591735658287, + 0.12744446546535132, + -0.08432153027611451, + -0.017331551904295, + 0.1590225437950705, + -0.18263648425053264, + 0.11105959675158651, + -0.17924486839119158, + 0.15997241429057224, + 0.04246950026749356, + 0.08695102346300654, + -0.031672011588002466, + -0.42835387214362464, + -0.07705335565597732, + 0.027214634797224094, + 0.09431054709518588, + -0.0883024797955068, + 0.498730645137716, + -0.09497822573168696, + -0.05718263942349022, + 0.07758774887517442, + 0.16118891300167687, + 0.14968144938848388, + 0.05316699891799519, + -0.1434336200114328, + 0.04206219646662961, + -0.12754823516106828, + -0.07607733927253879, + 0.062112146146456576, + -0.19567175761378108, + 0.023630893562061474, + 0.16100118048659245, + -0.06586297910813095, + 0.0031633274726710625, + -0.10798405610676251, + 0.07410556241487891, + -0.04737364343000658 + ], + "hard_two_means_brother": [ + 0.09355176730086777, + 0.12256403520417863, + 0.06776850191177644, + -0.03732057833194877, + 0.1973327179716123, + 0.03647561569899152, + -0.1927079246169542, + 0.10286331869848622, + -0.2224140473182231, + 0.0017713408944394562, + 0.061318463134264894, + 0.30078445384743735, + -0.29814637663525245, + -0.07829957366464688, + 0.056973105522299054, + -0.040020452271316814, + -0.027592324015172268, + 0.03933706482881305, + -0.11465243907201403, + 0.14866947649366527, + -0.11839163417576236, + 0.10171662000727881, + 0.04392125246078521, + 0.03834884533659155, + 0.0021902060019948627, + -0.39977887905454507, + -0.009133730411963414, + -0.16446671183826925, + -0.01445233209192701, + 0.009499897611979814, + 0.4120250039636035, + -0.047250515132056914, + -0.03443655174881199, + -0.002306661428680433, + 0.13278718253020522, + 0.016054588321409494, + 0.02603943417088129, + 0.0536546887730865, + 0.052058594849243044, + 0.05245307618307668, + 0.007321362039567724, + 0.1560802316952669, + -0.19833881114270957, + -0.060761388445296785, + 0.035309867066888365, + 0.012233264543807625, + -0.22528946293735996, + -0.25468297237496323, + 0.021109753145762584, + 0.03315613585319681 + ], + "hard_two_means_girl": [ + -0.059051184553146355, + 0.09130864580381053, + -0.09433445291364079, + -0.1003637202968386, + 0.20984126422924856, + 0.1891986248941208, + -0.1612958379331935, + 0.06632757751512357, + 0.039987571005311515, + 0.038970920808602316, + 0.0627605581156269, + -0.008416612908176607, + 0.04697955816742365, + 0.055500409623059965, + 0.18758494041746404, + -0.05009970673487836, + -0.203438159767199, + 0.1367272966450184, + -0.0017339020456125648, + 0.08310038512270256, + -0.10103948243654516, + 0.3000659826824898, + 0.031993904448984534, + 0.18975505911288815, + 0.08180516528881598, + -0.3625211325292369, + -0.13263769261468908, + 0.04184559109223322, + 0.07807784138721574, + -0.1566518269355896, + 0.3874646653623079, + -0.05477055699459896, + -0.03816174762959326, + 0.06904939724928574, + 0.11157898137858574, + 0.09944435514507138, + 0.01568439700054018, + -0.1592907801843134, + 0.056638525551177606, + -0.23523162783503626, + -0.08010708624283328, + 0.009526082580762835, + -0.04524274813689534, + -0.08144056495839333, + 0.19175359863488856, + -0.14372849621884304, + 0.08886968844489498, + -0.2572285093461013, + 0.06207289908291371, + 0.009884692917453634 + ], + "hard_two_means_sister": [ + 0.12248107633941492, + 0.22138328907948263, + -0.012954045706107253, + -0.015588770907164453, + 0.1630789828718641, + 0.28083149581140743, + -0.17370502309033872, + 0.15093125859943654, + -0.029506037718332184, + -0.06255644534891791, + 0.10404848275060993, + 0.12347010445017062, + -0.04621292183347117, + -0.042305518464498454, + 0.13168363212206985, + 0.0024921893399441654, + -0.2587864189771069, + 0.01164041533313236, + 0.11007686661686816, + 0.1139366709118734, + -0.021239665019178636, + 0.27575023507492097, + 0.030907765609037744, + 0.1660589049279059, + 0.14315915658432826, + -0.3179970399850531, + -0.07818431719550811, + -0.14629115784856864, + -0.03461768649837656, + -0.07540823990526033, + 0.27380295054152354, + 0.0026981443537368206, + -0.010807525814932504, + -0.012913573774685147, + 0.0711584022604951, + -0.046353295366842115, + -0.020523964768226453, + 0.033955862186895676, + 0.0701662872588102, + -0.08131844755576975, + 0.0023153404134622197, + 0.09075430096508427, + -0.011465818821832785, + -0.19128819261369961, + 0.0735125808948881, + -0.08449649676813556, + -0.11881928014809101, + -0.4400844295880006, + 0.006161972725925848, + 0.10428627681320381 + ], + "inlp_engineer": [ + -0.052782267332077026, + -0.0534612238407135, + 0.14106474816799164, + -0.026172973215579987, + 0.06754761189222336, + 0.017533697187900543, + -0.260309636592865, + 0.07288291305303574, + 0.1343180537223816, + -0.11478697508573532, + 0.2469760626554489, + 0.02129148691892624, + -0.08431467413902283, + 0.1557355523109436, + -0.18594929575920105, + -0.09864699095487595, + -0.10250400006771088, + 0.17043571174144745, + -0.0017871428281068802, + -0.014588537625968456, + 0.11120335757732391, + 0.16553060710430145, + -0.18866267800331116, + -0.028261572122573853, + 0.11305459588766098, + -0.24370981752872467, + -0.11005059629678726, + -0.22849202156066895, + -0.025211429223418236, + 0.1481746882200241, + 0.2578311562538147, + -0.27018457651138306, + -0.04134102910757065, + -0.11011195182800293, + 0.1624029278755188, + 0.028553470969200134, + -0.0067904251627624035, + 0.2444734275341034, + 0.18399406969547272, + 0.05907474458217621, + 0.09790921956300735, + -0.01432067435234785, + 0.09672029316425323, + -0.23263487219810486, + 0.048395272344350815, + -0.02799120359122753, + -0.0380919948220253, + -0.09178897738456726, + 0.0034351442009210587, + 0.16631698608398438 + ], + "inlp_homemaker": [ + -0.18342512845993042, + 0.05450461059808731, + -0.05481652170419693, + -0.1319858729839325, + 0.21286684274673462, + 0.031099338084459305, + -0.20349711179733276, + 0.046906277537345886, + -0.13989706337451935, + 0.02225230261683464, + -0.0062011610716581345, + -0.08698759227991104, + 0.00638633593916893, + -0.013031614944338799, + -0.01284074503928423, + -0.09018639475107193, + -0.09081724286079407, + 0.29148492217063904, + 0.2009771168231964, + 0.20561586320400238, + -0.14814838767051697, + 0.2990812063217163, + -0.15312905609607697, + 0.23493033647537231, + 0.05826185271143913, + -0.15942531824111938, + 0.07459422200918198, + -0.07452902942895889, + 0.10226386785507202, + 0.13873976469039917, + -0.012295234948396683, + -0.013750902377068996, + 0.043406546115875244, + 0.14195837080478668, + 0.15917563438415527, + 0.03275652229785919, + -0.13516400754451752, + 0.15392933785915375, + 0.1699870228767395, + 0.04027242586016655, + -0.01579558104276657, + 0.10894948989152908, + -0.13225291669368744, + 0.060379303991794586, + -0.012283140793442726, + -0.13622595369815826, + -0.13849352300167084, + -0.1582598090171814, + 0.05875536426901817, + 0.20893695950508118 + ], + "oscar_programmer": [ + -0.07830896973609924, + -0.07528749108314514, + 0.18948820233345032, + 0.148856520652771, + 0.06805361062288284, + 0.09882345795631409, + -0.20125211775302887, + -0.11760184168815613, + -0.047656722366809845, + 0.2133728563785553, + 0.23543985188007355, + 0.09724324941635132, + 0.03016793727874756, + 0.2407066524028778, + 0.06938372552394867, + 0.026680808514356613, + -0.10048334300518036, + 0.18456003069877625, + 0.188383087515831, + 0.043253201991319656, + -0.021929919719696045, + 0.12237273901700974, + -0.2533322274684906, + 0.28144603967666626, + 0.145513653755188, + -0.09027981758117676, + -0.02162027545273304, + -0.22467951476573944, + -0.008976730518043041, + -0.06909440457820892, + 0.1673312932252884, + -0.24996359646320343, + 0.02877015806734562, + -0.1258096545934677, + 0.06085729971528053, + 0.08735853433609009, + -0.012516401708126068, + 0.06429371237754822, + 0.1277480572462082, + -0.07779265195131302, + 0.23643597960472107, + 0.07029429078102112, + -0.11198242008686066, + 0.1061592698097229, + -0.05602099001407623, + -0.08935096859931946, + 0.18985341489315033, + 0.04543842002749443, + -0.0005864029517397285, + 0.0003419801068957895 + ], + "oscar_grandpa": [ + -0.07416621595621109, + 0.08458206057548523, + -0.053309064358472824, + -0.16716738045215607, + 0.09583862870931625, + -0.026479464024305344, + -0.054730307310819626, + 0.10439373552799225, + -0.22542959451675415, + -0.03889299929141998, + -0.140243798494339, + 0.3228200376033783, + 0.006546759977936745, + -0.10860257595777512, + 0.2503794729709625, + 0.1383737474679947, + -0.08268776535987854, + 0.22803069651126862, + 0.2692744731903076, + 0.07927031069993973, + -0.08930720388889313, + 0.06765256077051163, + 0.11146003007888794, + 0.14737895131111145, + 0.2791425883769989, + -0.10617826879024506, + -0.15906700491905212, + 0.059100620448589325, + 0.1821293979883194, + -0.23608064651489258, + -0.10596588999032974, + 0.25953924655914307, + -0.01979101449251175, + 0.3231821060180664, + 0.06740891933441162, + 0.2706506848335266, + 0.0015214867889881134, + -0.05631883069872856, + 0.0701802670955658, + -0.06867749989032745, + -0.047012969851493835, + 0.1160019040107727, + -0.3308452069759369, + 0.02871404029428959, + 0.16991694271564484, + -0.06927385926246643, + 0.04270310699939728, + -0.19192083179950714, + 0.027655234560370445, + 0.007677375338971615 + ], + "oscar_grandma": [ + -0.018536921590566635, + 0.12609684467315674, + -0.15519794821739197, + -0.08694198727607727, + 0.08195910602807999, + 0.04313560575246811, + -0.09287772327661514, + 0.06591854244470596, + -0.012779037468135357, + -0.019507993012666702, + -0.19680896401405334, + 0.12208959460258484, + 0.174989253282547, + -0.017019614577293396, + 0.32496196031570435, + 0.13468721508979797, + -0.14616119861602783, + 0.1619987040758133, + 0.3151642680168152, + 0.05573020502924919, + 0.040678128600120544, + 0.18073932826519012, + 0.010412609204649925, + 0.1281396895647049, + 0.3903489112854004, + -0.13189978897571564, + -0.12766925990581512, + 0.11984847486019135, + 0.23445692658424377, + -0.3537692725658417, + -0.05933094397187233, + 0.28710752725601196, + -0.13906823098659515, + 0.43696609139442444, + -0.00942843034863472, + 0.13638177514076233, + 0.05776602774858475, + 0.08829684555530548, + 0.13935112953186035, + -0.08038724958896637, + -0.05902135744690895, + 0.06793523579835892, + -0.14485999941825867, + 0.07194627076387405, + 0.06938131153583527, + -0.007875805720686913, + 0.08729098737239838, + -0.26941344141960144, + 0.039947230368852615, + 0.010896616615355015 + ], + "oscar_bias1": [ + 0.08474383216258977, + 0.0945811531047026, + 0.07428173109827416, + 0.024479834032266147, + 0.14469250544588316, + 0.2530798343782911, + 0.10930049304612847, + -0.06673947901906906, + 0.12434172735746349, + 0.0993960907920231, + -0.005839121876168384, + -0.20531481852387545, + 0.2639451804182061, + 0.037126206819276295, + 0.09238763082193571, + 0.11035712444527741, + -0.16029347056639812, + 0.05071308343986695, + 0.3471755325221579, + -0.08847484367886994, + -0.10499192878117405, + 0.11162700081145335, + -0.048837720672568766, + 0.13777196765170105, + 0.11012939179972218, + 0.24360542041955216, + -0.2284846261778304, + 0.1654286047628525, + 0.09593342221771839, + -0.12955089953583826, + -0.3755603276721494, + 0.0016714486406543773, + 0.1787442511314885, + 0.14721545106478545, + -0.0020732543727022246, + -0.0630506707229516, + -0.05410667297409925, + 0.04081149967427376, + 0.019907888418869278, + -0.14308377250011778, + 0.0349490091871661, + -0.10417437038931851, + 0.1943431938272854, + 0.15112671438161962, + 0.008346949102905113, + -0.07002366108042024, + 0.027757956870002695, + -0.1897564411718682, + -0.005211557668612942, + 0.0018718653159160395 + ], + "oscar_bias2": [ + -0.07375836241905759, + -0.003531244283090557, + 0.23980375746558857, + 0.055099282155637724, + -0.09354853471244139, + -0.03623296399517668, + -0.0691153173191639, + -0.12025046794231448, + -0.061813955528240296, + 0.06947850278037167, + 0.11670592968750768, + 0.13601944641445932, + -0.18235506325127593, + 0.027291841412269383, + -0.11372389081903568, + 0.05454688013100257, + 0.27933836816694524, + -0.10084788999856882, + -0.2461105105991461, + -0.015257886139356254, + 0.043195013966402775, + -0.23837120799364592, + -0.0008649711515565736, + -0.18881254808973197, + 0.030285064291280767, + -0.2535447782593246, + -0.0880216817353963, + -0.2231164800064211, + -0.14925981567619542, + 0.18033630869008552, + 0.17550613393187223, + -0.33909388125586026, + -0.08692648004795574, + -0.3118441049458602, + -0.04259817267933873, + -0.1243972891709841, + -0.10064526922062178, + 0.11456477598533303, + -0.008342291120067315, + 0.11684959186158074, + 0.06611318716932911, + 0.058751772007228635, + 0.030128999216968336, + -0.07357844288626657, + 0.10185220889088065, + 0.11829129739777486, + 0.012937075774670055, + 0.1987487011878333, + -0.09678284262303441, + -0.02743255064684724 + ] +} \ No newline at end of file diff --git a/tests/fairness/bias_direction_test.py b/tests/fairness/bias_direction_test.py new file mode 100644 index 00000000000..c4f32ba10f7 --- /dev/null +++ b/tests/fairness/bias_direction_test.py @@ -0,0 +1,149 @@ +import torch +from torch import allclose +import pytest +import math + +from allennlp.common.checks import ConfigurationError +from allennlp.common.testing import AllenNlpTestCase, multi_device +from allennlp.fairness.bias_direction import ( + PCABiasDirection, + PairedPCABiasDirection, + TwoMeansBiasDirection, + ClassificationNormalBiasDirection, +) + + +class PCABiasDirectionTest(AllenNlpTestCase): + def test_pca_invalid_dims(self): + pca = PCABiasDirection() + with pytest.raises(ConfigurationError): + pca(torch.zeros(2)) + + @multi_device + def test_pca_without_grad(self, device: str): + seed_embeddings = torch.eye(2, device=device) + pca = PCABiasDirection() + + const = 1 / math.sqrt(2) + expected_bias_direction = torch.tensor([const, -const], device=device) + test_bias_direction = pca(seed_embeddings) + k = expected_bias_direction / test_bias_direction + assert k[0].item() == pytest.approx(k[1].item()) + assert seed_embeddings.grad is None + + @multi_device + def test_pca_with_grad(self, device: str): + # add noise to avoid "RuntimeError: triangular_solve_cpu: U(2,2) is zero, singular U." + torch.manual_seed(0) + seed_embeddings = torch.eye(2, device=device) + (1 - torch.eye(2, device=device)) * 1e-1 + seed_embeddings = seed_embeddings.requires_grad_() + assert seed_embeddings.grad is None + + pca = PCABiasDirection(requires_grad=True) + test_bias_direction = pca(seed_embeddings) + test_bias_direction.sum().backward() + assert seed_embeddings.grad is not None + + +class PairedPCABiasDirectionTest(AllenNlpTestCase): + def test_paired_pca_invalid_dims(self): + paired_pca = PairedPCABiasDirection() + with pytest.raises(ConfigurationError): + paired_pca(torch.zeros(2), torch.zeros(3)) + + with pytest.raises(ConfigurationError): + paired_pca(torch.zeros(2), torch.zeros(2)) + + @multi_device + def test_paired_pca_without_grad(self, device: str): + seed_embeddings1 = torch.eye(2, device=device) + seed_embeddings2 = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=device) + paired_pca = PairedPCABiasDirection() + + const = 1 / math.sqrt(2) + expected_bias_direction = torch.tensor([const, -const], device=device) + test_bias_direction = paired_pca(seed_embeddings1, seed_embeddings2) + k = expected_bias_direction / test_bias_direction + assert k[0].item() == pytest.approx(k[1].item()) + assert seed_embeddings1.grad is None + assert seed_embeddings2.grad is None + + @multi_device + def test_paired_pca_with_grad(self, device: str): + # add noise to avoid "RuntimeError: triangular_solve_cpu: U(2,2) is zero, singular U." + torch.manual_seed(0) + seed_embeddings1 = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=device) + seed_embeddings2 = (1 - torch.eye(2, device=device)) * 9e-1 + seed_embeddings1 = seed_embeddings1.requires_grad_() + seed_embeddings2 = seed_embeddings2.requires_grad_() + assert seed_embeddings1.grad is None + assert seed_embeddings2.grad is None + + paired_pca = PairedPCABiasDirection(requires_grad=True) + test_bias_direction = paired_pca(seed_embeddings1, seed_embeddings2) + test_bias_direction.sum().backward() + assert seed_embeddings1.grad is not None + assert seed_embeddings2.grad is not None + + +class TwoMeansBiasDirectionTest(AllenNlpTestCase): + def test_two_means_invalid_dims(self): + two_means = TwoMeansBiasDirection() + with pytest.raises(ConfigurationError): + two_means(torch.zeros(2), torch.zeros(2)) + + with pytest.raises(ConfigurationError): + two_means(torch.zeros(2, 2), torch.zeros(2, 3)) + + @multi_device + def test_two_means_without_grad(self, device: str): + seed_embeddings1 = torch.eye(2, device=device) + seed_embeddings2 = 1 - torch.eye(2, device=device) + two_means = TwoMeansBiasDirection() + + expected_bias_direction = torch.tensor([float("nan"), float("nan")], device=device) + test_bias_direction = two_means(seed_embeddings1, seed_embeddings2) + assert allclose(expected_bias_direction, test_bias_direction, equal_nan=True) + assert seed_embeddings1.grad is None + assert seed_embeddings2.grad is None + + @multi_device + def test_two_means_with_grad(self, device: str): + seed_embeddings1 = torch.eye(2, device=device) + seed_embeddings2 = 1 - torch.eye(2, device=device) + seed_embeddings1 = seed_embeddings1.requires_grad_() + seed_embeddings2 = seed_embeddings2.requires_grad_() + assert seed_embeddings1.grad is None + assert seed_embeddings2.grad is None + + two_means = TwoMeansBiasDirection(requires_grad=True) + test_bias_direction = two_means(seed_embeddings1, seed_embeddings2) + test_bias_direction.sum().backward() + assert seed_embeddings1.grad is not None + assert seed_embeddings2.grad is not None + + +class ClassificationNormalBiasDirectionTest(AllenNlpTestCase): + def test_classification_normal_invalid_dims(self): + classification_normal = ClassificationNormalBiasDirection() + with pytest.raises(ConfigurationError): + classification_normal(torch.zeros(2), torch.zeros(2)) + + with pytest.raises(ConfigurationError): + classification_normal(torch.zeros(2, 2), torch.zeros(2, 3)) + + @multi_device + def test_classification_normal_without_grad(self, device: str): + seed_embeddings1 = torch.eye(2, device=device) + seed_embeddings2 = torch.tensor([[1.0, 1.0], [1.0, 1.0]], device=device) + classification_normal = ClassificationNormalBiasDirection() + test_bias_direction = classification_normal(seed_embeddings1, seed_embeddings2) + const = 1 / math.sqrt(2) + assert ( + allclose(test_bias_direction, torch.Tensor([const, const]).to(device)) + or allclose(test_bias_direction, torch.Tensor([-const, -const]).to(device)) + or allclose(test_bias_direction, torch.Tensor([const, -const]).to(device)) + or allclose(test_bias_direction, torch.Tensor([-const, const]).to(device)) + ) + assert seed_embeddings1.grad is None + assert seed_embeddings2.grad is None diff --git a/tests/fairness/bias_mitigators_test.py b/tests/fairness/bias_mitigators_test.py new file mode 100644 index 00000000000..c7837be77ff --- /dev/null +++ b/tests/fairness/bias_mitigators_test.py @@ -0,0 +1,320 @@ +import torch +from torch import allclose +import pytest +import json + +from allennlp.common.checks import ConfigurationError +from allennlp.common.testing import AllenNlpTestCase, multi_device +from allennlp.fairness.bias_mitigators import ( + LinearBiasMitigator, + HardBiasMitigator, + INLPBiasMitigator, + OSCaRBiasMitigator, +) +from allennlp.fairness.bias_direction import TwoMeansBiasDirection + + +class LinearBiasMitigatorTest(AllenNlpTestCase): + def setup_method(self): + super().setup_method() + + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + seed_embeddings1 = torch.cat( + [ + torch.Tensor(emb_data["he"]).reshape(1, -1), + torch.Tensor(emb_data["him"]).reshape(1, -1), + ] + ) + seed_embeddings2 = torch.cat( + [ + torch.Tensor(emb_data["she"]).reshape(1, -1), + torch.Tensor(emb_data["her"]).reshape(1, -1), + ] + ) + tm = TwoMeansBiasDirection() + self.bias_direction = tm(seed_embeddings1, seed_embeddings2) + + evaluation_embeddings = [] + expected_bias_mitigated_embeddings = [] + for word in ["engineer", "banker", "nurse", "receptionist"]: + evaluation_embeddings.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + expected_bias_mitigated_embeddings.append( + torch.Tensor(emb_data["linear_two_means_" + word]).reshape(1, -1) + ) + self.evaluation_embeddings = torch.cat(evaluation_embeddings).reshape(2, 2, -1) + self.expected_bias_mitigated_embeddings = torch.cat( + expected_bias_mitigated_embeddings + ).reshape(2, 2, -1) + + def test_invalid_dims(self): + lbm = LinearBiasMitigator() + with pytest.raises(ConfigurationError): + lbm(torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + lbm(torch.zeros(2), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + lbm(torch.zeros((2, 3)), torch.zeros(2)) + + @multi_device + def test_lbm_without_grad(self, device: str): + self.bias_direction = self.bias_direction.to(device) + self.evaluation_embeddings = self.evaluation_embeddings.to(device) + self.expected_bias_mitigated_embeddings = self.expected_bias_mitigated_embeddings.to(device) + + lbm = LinearBiasMitigator() + test_bias_mitigated_embeddings = lbm(self.evaluation_embeddings, self.bias_direction) + assert allclose( + self.expected_bias_mitigated_embeddings, test_bias_mitigated_embeddings, atol=1e-6 + ) + + @multi_device + def test_lbm_with_grad(self, device: str): + self.bias_direction = self.bias_direction.to(device).requires_grad_() + self.evaluation_embeddings = self.evaluation_embeddings.to(device).requires_grad_() + assert self.bias_direction.grad is None + assert self.evaluation_embeddings.grad is None + + lbm = LinearBiasMitigator(requires_grad=True) + test_bias_mitigated_embeddings = lbm(self.evaluation_embeddings, self.bias_direction) + test_bias_mitigated_embeddings.sum().backward() + assert self.bias_direction.grad is not None + assert self.evaluation_embeddings.grad is not None + + +class HardBiasMitigatorTest(AllenNlpTestCase): + def setup_method(self): + super().setup_method() + + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + seed_embeddings1 = torch.cat( + [ + torch.Tensor(emb_data["he"]).reshape(1, -1), + torch.Tensor(emb_data["man"]).reshape(1, -1), + ] + ) + seed_embeddings2 = torch.cat( + [ + torch.Tensor(emb_data["she"]).reshape(1, -1), + torch.Tensor(emb_data["woman"]).reshape(1, -1), + ] + ) + tm = TwoMeansBiasDirection() + self.bias_direction = tm(seed_embeddings1, seed_embeddings2) + + self.equalize_embeddings1 = torch.cat( + [ + torch.Tensor(emb_data["boy"]).reshape(1, -1), + torch.Tensor(emb_data["brother"]).reshape(1, -1), + ] + ).unsqueeze(0) + self.equalize_embeddings2 = torch.cat( + [ + torch.Tensor(emb_data["girl"]).reshape(1, -1), + torch.Tensor(emb_data["sister"]).reshape(1, -1), + ] + ).unsqueeze(0) + + evaluation_embeddings = [] + expected_bias_mitigated_embeddings = [] + for word in ["engineer", "banker", "nurse", "receptionist"]: + evaluation_embeddings.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + expected_bias_mitigated_embeddings.append( + torch.Tensor(emb_data["hard_two_means_" + word]).reshape(1, -1) + ) + for word in ["boy", "brother", "girl", "sister"]: + expected_bias_mitigated_embeddings.append( + torch.Tensor(emb_data["hard_two_means_" + word]).reshape(1, -1) + ) + self.evaluation_embeddings = torch.cat(evaluation_embeddings).reshape(2, 2, -1) + self.expected_bias_mitigated_embeddings = torch.cat( + expected_bias_mitigated_embeddings + ).reshape(4, 2, -1) + + def test_invalid_dims(self): + hbm = HardBiasMitigator() + with pytest.raises(ConfigurationError): + hbm(torch.zeros(2), torch.zeros(2), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + hbm(torch.zeros(2), torch.zeros(2), torch.zeros((2, 2)), torch.zeros((3, 2))) + with pytest.raises(ConfigurationError): + hbm(torch.zeros(2), torch.zeros(2), torch.zeros((2, 2)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + hbm(torch.zeros((3, 3)), torch.zeros(2), torch.zeros((2, 2)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + hbm(torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((2, 2)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + hbm(torch.zeros((3, 2)), torch.zeros(3), torch.zeros((2, 2)), torch.zeros((2, 2))) + + @multi_device + def test_hbm_without_grad(self, device: str): + self.bias_direction = self.bias_direction.to(device) + self.evaluation_embeddings = self.evaluation_embeddings.to(device) + self.equalize_embeddings1 = self.equalize_embeddings1.to(device) + self.equalize_embeddings2 = self.equalize_embeddings2.to(device) + self.expected_bias_mitigated_embeddings = self.expected_bias_mitigated_embeddings.to(device) + + hbm = HardBiasMitigator() + test_bias_mitigated_embeddings = hbm( + self.evaluation_embeddings, + self.bias_direction, + self.equalize_embeddings1, + self.equalize_embeddings2, + ) + assert allclose( + self.expected_bias_mitigated_embeddings, test_bias_mitigated_embeddings, atol=1e-6 + ) + + @multi_device + def test_hbm_with_grad(self, device: str): + self.bias_direction = self.bias_direction.to(device).requires_grad_() + self.evaluation_embeddings = self.evaluation_embeddings.to(device).requires_grad_() + self.equalize_embeddings1 = self.equalize_embeddings1.to(device).requires_grad_() + self.equalize_embeddings2 = self.equalize_embeddings2.to(device).requires_grad_() + assert self.bias_direction.grad is None + assert self.evaluation_embeddings.grad is None + assert self.equalize_embeddings1.grad is None + assert self.equalize_embeddings2.grad is None + + hbm = HardBiasMitigator(requires_grad=True) + test_bias_mitigated_embeddings = hbm( + self.evaluation_embeddings, + self.bias_direction, + self.equalize_embeddings1, + self.equalize_embeddings2, + ) + test_bias_mitigated_embeddings.sum().backward() + assert self.bias_direction.grad is not None + assert self.evaluation_embeddings.grad is not None + assert self.equalize_embeddings1.grad is not None + assert self.equalize_embeddings2.grad is not None + + +class INLPBiasMitigatorTest(AllenNlpTestCase): + def setup_method(self): + super().setup_method() + + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + seed_embeddings1 = [] + for word in ["man", "he", "his", "boy", "grandpa", "uncle", "jack"]: + seed_embeddings1.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + self.seed_embeddings1 = torch.cat(seed_embeddings1) + + seed_embeddings2 = [] + for word in ["woman", "she", "her", "girl", "grandma", "aunt", "jill"]: + seed_embeddings2.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + self.seed_embeddings2 = torch.cat(seed_embeddings2) + + evaluation_embeddings = [] + expected_bias_mitigated_embeddings = [] + for word in ["engineer", "homemaker"]: + evaluation_embeddings.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + expected_bias_mitigated_embeddings.append( + torch.Tensor(emb_data["inlp_" + word]).reshape(1, -1) + ) + self.evaluation_embeddings = torch.cat(evaluation_embeddings) + self.expected_bias_mitigated_embeddings = torch.cat(expected_bias_mitigated_embeddings) + + def test_invalid_dims(self): + ibm = INLPBiasMitigator() + with pytest.raises(ConfigurationError): + ibm(torch.zeros(2), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + ibm(torch.zeros(2), torch.zeros((2, 2)), torch.zeros((2, 3))) + with pytest.raises(ConfigurationError): + ibm(torch.zeros(2), torch.zeros((2, 2)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + ibm(torch.zeros((2, 3)), torch.zeros((2, 2)), torch.zeros((2, 2))) + + @multi_device + def test_inlp(self, device: str): + self.seed_embeddings1 = self.seed_embeddings1.to(device) + self.seed_embeddings2 = self.seed_embeddings2.to(device) + self.evaluation_embeddings = self.evaluation_embeddings.to(device) + self.expected_bias_mitigated_embeddings = self.expected_bias_mitigated_embeddings.to(device) + + ibm = INLPBiasMitigator() + test_bias_mitigated_embeddings = ibm( + self.evaluation_embeddings, self.seed_embeddings1, self.seed_embeddings2 + ) + assert allclose( + self.expected_bias_mitigated_embeddings, test_bias_mitigated_embeddings, atol=1e-6 + ) + + +class OSCaRBiasMitigatorTest(AllenNlpTestCase): + def setup_method(self): + super().setup_method() + + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + self.bias_direction1 = torch.Tensor(emb_data["oscar_bias1"]) + self.bias_direction2 = torch.Tensor(emb_data["oscar_bias2"]) + + evaluation_embeddings = [] + expected_bias_mitigated_embeddings = [] + for word in ["programmer", "grandpa", "grandma"]: + evaluation_embeddings.append(torch.Tensor(emb_data[word]).reshape(1, -1)) + expected_bias_mitigated_embeddings.append( + torch.Tensor(emb_data["oscar_" + word]).reshape(1, -1) + ) + self.evaluation_embeddings = torch.cat(evaluation_embeddings) + self.expected_bias_mitigated_embeddings = torch.cat(expected_bias_mitigated_embeddings) + + def test_invalid_dims(self): + ibm = INLPBiasMitigator() + with pytest.raises(ConfigurationError): + ibm(torch.zeros(2), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + ibm(torch.zeros(2), torch.zeros((2, 2)), torch.zeros((2, 3))) + with pytest.raises(ConfigurationError): + ibm(torch.zeros((2, 3)), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + ibm(torch.zeros((2, 1)), torch.zeros(1), torch.zeros(1)) + + @multi_device + def test_oscar_without_grad(self, device: str): + self.bias_direction1 = self.bias_direction1.to(device) + self.bias_direction2 = self.bias_direction2.to(device) + self.evaluation_embeddings = self.evaluation_embeddings.to(device) + self.expected_bias_mitigated_embeddings = self.expected_bias_mitigated_embeddings.to(device) + + obm = OSCaRBiasMitigator() + test_bias_mitigated_embeddings = obm( + self.evaluation_embeddings, self.bias_direction1, self.bias_direction2 + ) + assert allclose( + self.expected_bias_mitigated_embeddings, test_bias_mitigated_embeddings, atol=1e-6 + ) + + @multi_device + def test_oscar_with_grad(self, device: str): + self.bias_direction1 = self.bias_direction1.to(device).requires_grad_() + self.bias_direction2 = self.bias_direction2.to(device).requires_grad_() + self.evaluation_embeddings = self.evaluation_embeddings.to(device).requires_grad_() + assert self.bias_direction1.grad is None + assert self.bias_direction2.grad is None + assert self.evaluation_embeddings.grad is None + + obm = OSCaRBiasMitigator(requires_grad=True) + test_bias_mitigated_embeddings = obm( + self.evaluation_embeddings, self.bias_direction1, self.bias_direction2 + ) + test_bias_mitigated_embeddings.sum().backward() + assert self.bias_direction1.grad is not None + assert self.bias_direction2.grad is not None + assert self.evaluation_embeddings.grad is not None From 5dce9f5b8d657ac507cafb90fc63e928c6d4dc18 Mon Sep 17 00:00:00 2001 From: ArjunSubramonian <arjun.subramonian@gmail.com> Date: Wed, 12 May 2021 18:06:36 -0700 Subject: [PATCH 21/63] Bias Metrics (#5139) * finished WEAT * finished bias metrics * updated CHANGELOG * fixed gpu issu * fixed gpu issue * expanded NLI to include more NLI scores and work in batched and distributed settings * removed evaluate bias mitigation command from this PR Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-108.us-west-2.compute.internal> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-1-108.us-west-2.compute.internal> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal> --- CHANGELOG.md | 3 +- allennlp/fairness/__init__.py | 11 +- allennlp/fairness/bias_metrics.py | 639 ++++++++++++++++++++++++ allennlp/fairness/fairness_metrics.py | 270 +--------- tests/fairness/bias_metrics_test.py | 368 ++++++++++++++ tests/fairness/fairness_metrics_test.py | 201 +------- 6 files changed, 1017 insertions(+), 475 deletions(-) create mode 100644 allennlp/fairness/bias_metrics.py create mode 100644 tests/fairness/bias_metrics_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a1d9b198274..da0b9df6211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,7 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`. - Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars. - The test for distributed metrics now takes a parameter specifying how often you want to run it. -- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, `Sufficiency`, and `DemographicParityWithoutGroundTruth`. +- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, and `Sufficiency`. +- Added three bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`. - Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`). ### Changed diff --git a/allennlp/fairness/__init__.py b/allennlp/fairness/__init__.py index 5c254de681b..976ada2d076 100644 --- a/allennlp/fairness/__init__.py +++ b/allennlp/fairness/__init__.py @@ -6,11 +6,12 @@ 3. debias embeddings during training time and post-processing """ -from allennlp.fairness.fairness_metrics import ( - Independence, - Separation, - Sufficiency, - DemographicParityWithoutGroundTruth, +from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency +from allennlp.fairness.bias_metrics import ( + WordEmbeddingAssociationTest, + EmbeddingCoherenceTest, + NaturalLanguageInference, + AssociationWithoutGroundTruth, ) from allennlp.fairness.bias_direction import ( PCABiasDirection, diff --git a/allennlp/fairness/bias_metrics.py b/allennlp/fairness/bias_metrics.py new file mode 100644 index 00000000000..e7be2763c1c --- /dev/null +++ b/allennlp/fairness/bias_metrics.py @@ -0,0 +1,639 @@ +""" + +A suite of metrics to quantify how much bias is encoded by word embeddings +and determine the effectiveness of bias mitigation. + +Bias metrics are based on: + +1. Caliskan, A., Bryson, J., & Narayanan, A. (2017). [Semantics derived automatically +from language corpora contain human-like biases](https://api.semanticscholar.org/CorpusID:23163324). +Science, 356, 183 - 186. + +2. Dev, S., & Phillips, J.M. (2019). [Attenuating Bias in Word Vectors] +(https://api.semanticscholar.org/CorpusID:59158788). AISTATS. + +3. Dev, S., Li, T., Phillips, J.M., & Srikumar, V. (2020). [On Measuring and Mitigating +Biased Inferences of Word Embeddings](https://api.semanticscholar.org/CorpusID:201670701). +ArXiv, abs/1908.09369. + +4. Rathore, A., Dev, S., Phillips, J.M., Srikumar, V., Zheng, Y., Yeh, C.M., Wang, J., Zhang, +W., & Wang, B. (2021). [VERB: Visualizing and Interpreting Bias Mitigation Techniques for +Word Representations](https://api.semanticscholar.org/CorpusID:233168618). +ArXiv, abs/2104.02797. + +5. Aka, O.; Burke, K.; Bäuerle, A.; Greer, C.; and Mitchell, M. 2021. +[Measuring model biases in the absence of ground truth](https://api.semanticscholar.org/CorpusID:232135043). +arXiv preprint arXiv:2103.03417. + +""" + +from typing import Optional, Dict, Union, List + +from overrides import overrides +import torch +import torch.distributed as dist + +from allennlp.common.util import is_distributed +from allennlp.common.checks import ConfigurationError +from allennlp.nn.util import dist_reduce_sum +from allennlp.training.metrics.metric import Metric + + +class WordEmbeddingAssociationTest: + """ + Word Embedding Association Test (WEAT) score measures the unlikelihood there is no + difference between two sets of target words in terms of their relative similarity + to two sets of attribute words by computing the probability that a random + permutation of attribute words would produce the observed (or greater) difference + in sample means. Analog of Implicit Association Test from psychology for word embeddings. + + Based on: Caliskan, A., Bryson, J., & Narayanan, A. (2017). [Semantics derived automatically + from language corpora contain human-like biases](https://api.semanticscholar.org/CorpusID:23163324). + Science, 356, 183 - 186. + """ + + def __call__( + self, + target_embeddings1: torch.Tensor, + target_embeddings2: torch.Tensor, + attribute_embeddings1: torch.Tensor, + attribute_embeddings2: torch.Tensor, + ) -> torch.FloatTensor: + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + target_embeddings1 : `torch.Tensor`, required. + A tensor of size (target_embeddings_batch_size, ..., dim) containing target word + embeddings related to a concept group. For example, if the concept is gender, + target_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. Represented as X. + + target_embeddings2 : `torch.Tensor`, required. + A tensor of the same size as target_embeddings1 containing target word + embeddings related to a different group for the same concept. For example, + target_embeddings2 could contain embeddings for linguistically feminine words, e.g. + "woman", "queen", "sister", etc. Represented as Y. + + attribute_embeddings1 : `torch.Tensor`, required. + A tensor of size (attribute_embeddings1_batch_size, ..., dim) containing attribute word + embeddings related to a concept group associated with the concept group for target_embeddings1. + For example, if the concept is professions, attribute_embeddings1 could contain embeddings for + stereotypically male professions, e.g. "doctor", "banker", "engineer", etc. Represented as A. + + attribute_embeddings2 : `torch.Tensor`, required. + A tensor of size (attribute_embeddings2_batch_size, ..., dim) containing attribute word + embeddings related to a concept group associated with the concept group for target_embeddings2. + For example, if the concept is professions, attribute_embeddings2 could contain embeddings for + stereotypically female professions, e.g. "nurse", "receptionist", "homemaker", etc. Represented as B. + + !!! Note + While target_embeddings1 and target_embeddings2 must be the same size, attribute_embeddings1 and + attribute_embeddings2 need not be the same size. + + # Returns + + weat_score : `torch.FloatTensor` + The unlikelihood there is no difference between target_embeddings1 and target_embeddings2 in + terms of their relative similarity to attribute_embeddings1 and attribute_embeddings2. + Typical values are around [-1, 1], with values closer to 0 indicating less biased associations. + + """ + + # Some sanity checks + if target_embeddings1.ndim < 2 or target_embeddings2.ndim < 2: + raise ConfigurationError( + "target_embeddings1 and target_embeddings2 must have at least two dimensions." + ) + if attribute_embeddings1.ndim < 2 or attribute_embeddings2.ndim < 2: + raise ConfigurationError( + "attribute_embeddings1 and attribute_embeddings2 must have at least two dimensions." + ) + if target_embeddings1.size() != target_embeddings2.size(): + raise ConfigurationError( + "target_embeddings1 and target_embeddings2 must be of the same size." + ) + if attribute_embeddings1.size(dim=-1) != attribute_embeddings2.size( + dim=-1 + ) or attribute_embeddings1.size(dim=-1) != target_embeddings1.size(dim=-1): + raise ConfigurationError("All embeddings must have the same dimensionality.") + + target_embeddings1 = target_embeddings1.flatten(end_dim=-2) + target_embeddings2 = target_embeddings2.flatten(end_dim=-2) + attribute_embeddings1 = attribute_embeddings1.flatten(end_dim=-2) + attribute_embeddings2 = attribute_embeddings2.flatten(end_dim=-2) + + # Normalize + target_embeddings1 = torch.nn.functional.normalize(target_embeddings1, p=2, dim=-1) + target_embeddings2 = torch.nn.functional.normalize(target_embeddings2, p=2, dim=-1) + attribute_embeddings1 = torch.nn.functional.normalize(attribute_embeddings1, p=2, dim=-1) + attribute_embeddings2 = torch.nn.functional.normalize(attribute_embeddings2, p=2, dim=-1) + + # Compute cosine similarities + X_sim_A = torch.mm(target_embeddings1, attribute_embeddings1.t()) + X_sim_B = torch.mm(target_embeddings1, attribute_embeddings2.t()) + Y_sim_A = torch.mm(target_embeddings2, attribute_embeddings1.t()) + Y_sim_B = torch.mm(target_embeddings2, attribute_embeddings2.t()) + X_union_Y_sim_A = torch.cat([X_sim_A, Y_sim_A]) + X_union_Y_sim_B = torch.cat([X_sim_B, Y_sim_B]) + + s_X_A_B = torch.mean(X_sim_A, dim=-1) - torch.mean(X_sim_B, dim=-1) + s_Y_A_B = torch.mean(Y_sim_A, dim=-1) - torch.mean(Y_sim_B, dim=-1) + s_X_Y_A_B = torch.mean(s_X_A_B) - torch.mean(s_Y_A_B) + S_X_union_Y_A_B = torch.mean(X_union_Y_sim_A, dim=-1) - torch.mean(X_union_Y_sim_B, dim=-1) + return s_X_Y_A_B / torch.std(S_X_union_Y_A_B, unbiased=False) + + +class EmbeddingCoherenceTest: + """ + Embedding Coherence Test (ECT) score measures if groups of words + have stereotypical associations by computing the Spearman Coefficient + of lists of attribute embeddings sorted based on their similarity to + target embeddings. + + Based on: Dev, S., & Phillips, J.M. (2019). [Attenuating Bias in Word Vectors] + (https://api.semanticscholar.org/CorpusID:59158788). AISTATS. + """ + + def __call__( + self, + target_embeddings1: torch.Tensor, + target_embeddings2: torch.Tensor, + attribute_embeddings: torch.Tensor, + ) -> torch.FloatTensor: + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + target_embeddings1 : `torch.Tensor`, required. + A tensor of size (target_embeddings_batch_size, ..., dim) containing target word + embeddings related to a concept group. For example, if the concept is gender, + target_embeddings1 could contain embeddings for linguistically masculine words, e.g. + "man", "king", "brother", etc. Represented as X. + + target_embeddings2 : `torch.Tensor`, required. + A tensor of the same size as target_embeddings1 containing target word + embeddings related to a different group for the same concept. For example, + target_embeddings2 could contain embeddings for linguistically feminine words, e.g. + "woman", "queen", "sister", etc. Represented as Y. + + attribute_embeddings : `torch.Tensor`, required. + A tensor of size (attribute_embeddings_batch_size, ..., dim) containing attribute word + embeddings related to a concept associated with target_embeddings1 and target_embeddings2. + For example, if the concept is professions, attribute_embeddings could contain embeddings for + "doctor", "banker", "engineer", etc. Represented as AB. + + # Returns + + ect_score : `torch.FloatTensor` + The Spearman Coefficient measuring the similarity of lists of attribute embeddings sorted + based on their similarity to the target embeddings. Ranges from [-1, 1], with values closer + to 1 indicating less biased associations. + + """ + # Some sanity checks + if target_embeddings1.ndim < 2 or target_embeddings2.ndim < 2: + raise ConfigurationError( + "target_embeddings1 and target_embeddings2 must have at least two dimensions." + ) + if attribute_embeddings.ndim < 2: + raise ConfigurationError("attribute_embeddings must have at least two dimensions.") + if target_embeddings1.size() != target_embeddings2.size(): + raise ConfigurationError( + "target_embeddings1 and target_embeddings2 must be of the same size." + ) + if attribute_embeddings.size(dim=-1) != target_embeddings1.size(dim=-1): + raise ConfigurationError("All embeddings must have the same dimensionality.") + + mean_target_embedding1 = target_embeddings1.flatten(end_dim=-2).mean(dim=0) + mean_target_embedding2 = target_embeddings2.flatten(end_dim=-2).mean(dim=0) + attribute_embeddings = attribute_embeddings.flatten(end_dim=-2) + + # Normalize + mean_target_embedding1 = torch.nn.functional.normalize(mean_target_embedding1, p=2, dim=-1) + mean_target_embedding2 = torch.nn.functional.normalize(mean_target_embedding2, p=2, dim=-1) + attribute_embeddings = torch.nn.functional.normalize(attribute_embeddings, p=2, dim=-1) + + # Compute cosine similarities + AB_sim_m = torch.matmul(attribute_embeddings, mean_target_embedding1) + AB_sim_f = torch.matmul(attribute_embeddings, mean_target_embedding2) + + return self.spearman_correlation(AB_sim_m, AB_sim_f) + + def _get_ranks(self, x: torch.Tensor) -> torch.Tensor: + tmp = x.argsort() + ranks = torch.zeros_like(tmp) + ranks[tmp] = torch.arange(x.size(0), device=ranks.device) + return ranks + + def spearman_correlation(self, x: torch.Tensor, y: torch.Tensor): + x_rank = self._get_ranks(x) + y_rank = self._get_ranks(y) + + n = x.size(0) + upper = 6 * torch.sum((x_rank - y_rank).pow(2)) + down = n * (n ** 2 - 1.0) + return 1.0 - (upper / down) + + +@Metric.register("nli") +class NaturalLanguageInference(Metric): + """ + Natural language inference scores measure the effect biased associations have on decisions + made downstream, given neutrally-constructed pairs of sentences differing only in the subject. + + 1. Net Neutral (NN): The average probability of the neutral label + across all sentence pairs. + + 2. Fraction Neutral (FN): The fraction of sentence pairs predicted neutral. + + 3. Threshold:tau (T:tau): A parameterized measure that reports the fraction + of examples whose probability of neutral is above tau. + + neutral_label : `int`, optional (default=`2`) + The discrete integer label corresponding to a neutral entailment prediction. + taus : `List[float]`, optional (default=`[0.5, 0.7]`) + All the taus for which to compute Threshold:tau. + + Based on: Dev, S., Li, T., Phillips, J.M., & Srikumar, V. (2020). [On Measuring and Mitigating + Biased Inferences of Word Embeddings](https://api.semanticscholar.org/CorpusID:201670701). + ArXiv, abs/1908.09369. + """ + + def __init__(self, neutral_label: int = 2, taus: List[float] = [0.5, 0.7]): + self.neutral_label = neutral_label + self.taus = taus + + self._nli_probs_sum = 0.0 + self._num_neutral_predictions = 0.0 + self._num_neutral_above_taus = {tau: 0.0 for tau in taus} + self._total_predictions = 0 + + @overrides + def __call__(self, nli_probabilities: torch.Tensor) -> None: + """ + + # Parameters + + !!! Note + In the examples below, we treat gender identity as binary, which does not accurately + characterize gender in real life. + + nli_probabilities : `torch.Tensor`, required. + A tensor of size (batch_size, ..., 3) containing natural language inference + (i.e. entailment, contradiction, and neutral) probabilities for neutrally-constructed + pairs of sentences differing only in the subject. For example, if the concept is gender, + nli_probabilities could contain the natural language inference probabilities of: + + - "The driver owns a cabinet." -> "The man owns a cabinet." + + - "The driver owns a cabinet." -> "The woman owns a cabinet." + + - "The doctor eats an apple." -> "The man eats an apple." + + - "The doctor eats an apple." -> "The woman eats an apple." + """ + nli_probabilities = nli_probabilities.detach() + + # Some sanity checks + if nli_probabilities.dim() < 2: + raise ConfigurationError( + "nli_probabilities must have at least two dimensions but " + "found tensor of shape: {}".format(nli_probabilities.size()) + ) + if nli_probabilities.size(-1) != 3: + raise ConfigurationError( + "Last dimension of nli_probabilities must have dimensionality of 3 but " + "found tensor of shape: {}".format(nli_probabilities.size()) + ) + + _nli_neutral_probs = nli_probabilities[..., self.neutral_label] + + self._nli_probs_sum += dist_reduce_sum(_nli_neutral_probs.sum().item()) + self._num_neutral_predictions += dist_reduce_sum( + (nli_probabilities.argmax(dim=-1) == self.neutral_label).float().sum().item() + ) + for tau in self.taus: + self._num_neutral_above_taus[tau] += dist_reduce_sum( + (_nli_neutral_probs > tau).float().sum().item() + ) + self._total_predictions += dist_reduce_sum(_nli_neutral_probs.numel()) + + def get_metric(self, reset: bool = False): + """ + # Returns + + nli_scores : `Dict[str, float]` + Contains the following keys: + + 1. "`net_neutral`" : The average probability of the neutral label across + all sentence pairs. A value closer to 1 suggests lower bias, as bias will result in a higher + probability of entailment or contradiction. + + 2. "`fraction_neutral`" : The fraction of sentence pairs predicted neutral. + A value closer to 1 suggests lower bias, as bias will result in a higher + probability of entailment or contradiction. + + 3. "`threshold_{taus}`" : For each tau, the fraction of examples whose probability of + neutral is above tau. For each tau, a value closer to 1 suggests lower bias, as bias + will result in a higher probability of entailment or contradiction. + + """ + if self._total_predictions == 0: + nli_scores = { + "net_neutral": 0.0, + "fraction_neutral": 0.0, + **{"threshold_{}".format(tau): 0.0 for tau in self.taus}, + } + else: + nli_scores = { + "net_neutral": self._nli_probs_sum / self._total_predictions, + "fraction_neutral": self._num_neutral_predictions / self._total_predictions, + **{ + "threshold_{}".format(tau): self._num_neutral_above_taus[tau] + / self._total_predictions + for tau in self.taus + }, + } + if reset: + self.reset() + return nli_scores + + @overrides + def reset(self): + self._nli_probs_sum = 0.0 + self._num_neutral_predictions = 0.0 + self._num_neutral_above_taus = {tau: 0.0 for tau in self.taus} + self._total_predictions = 0 + + +@Metric.register("association_without_ground_truth") +class AssociationWithoutGroundTruth(Metric): + """ + Association without ground truth, from: Aka, O.; Burke, K.; Bäuerle, A.; + Greer, C.; and Mitchell, M. 2021. Measuring model biases in the absence of ground + truth. arXiv preprint arXiv:2103.03417. + + # Parameters + + num_classes : `int` + Number of classes. + num_protected_variable_labels : `int` + Number of protected variable labels. + association_metric : `str`, optional (default = `"npmixy"`). + A generic association metric A(x, y), where x is an identity label and y is any other label. + Examples include: nPMIxy (`"npmixy"`), nPMIy (`"npmiy"`), PMI^2 (`"pmisq"`), PMI (`"pmi"`) + Empirically, nPMIxy and nPMIy are more capable of capturing labels across a range of + marginal frequencies. + gap_type : `str`, optional (default = `"ova"`). + Either one-vs-all (`"ova"`) or pairwise (`"pairwise"`). One-vs-all gap is equivalent to + A(x, y) - E[A(x', y)], where x' is in the set of all protected variable labels setminus {x}. + Pairwise gaps are A(x, y) - A(x', y), for all x' in the set of all protected variable labels + setminus {x}. + + !!! Note + Assumes integer predictions, with each item to be classified + having a single correct class. + """ + + def __init__( + self, + num_classes: int, + num_protected_variable_labels: int, + association_metric: str = "npmixy", + gap_type: str = "ova", + ) -> None: + self._num_classes = num_classes + self._num_protected_variable_labels = num_protected_variable_labels + self._joint_counts_by_protected_variable_label = torch.zeros( + (num_protected_variable_labels, num_classes) + ) + self._protected_variable_label_counts = torch.zeros(num_protected_variable_labels) + self._y_counts = torch.zeros(num_classes) + self._total_predictions = torch.tensor(0) + + self.IMPLEMENTED_ASSOCIATION_METRICS = set(["npmixy", "npmiy", "pmisq", "pmi"]) + if association_metric in self.IMPLEMENTED_ASSOCIATION_METRICS: + self.association_metric = association_metric + else: + raise NotImplementedError( + f"Association metric {association_metric} has not been implemented!" + ) + + if gap_type == "ova": + self.gap_func = self._ova_gap + elif gap_type == "pairwise": + self.gap_func = self._pairwise_gaps + else: + raise NotImplementedError(f"Gap type {gap_type} has not been implemented!") + + def __call__( + self, + predicted_labels: torch.Tensor, + protected_variable_labels: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + ) -> None: + """ + # Parameters + + predicted_labels : `torch.Tensor`, required. + A tensor of predicted integer class labels of shape (batch_size, ...). Represented as Y. + protected_variable_labels : `torch.Tensor`, required. + A tensor of integer protected variable labels of shape (batch_size, ...). It must be the same + shape as the `predicted_labels` tensor. Represented as X. + mask : `torch.BoolTensor`, optional (default = `None`). + A tensor of the same shape as `predicted_labels`. + + !!! Note + All tensors are expected to be on the same device. + """ + predicted_labels, protected_variable_labels, mask = self.detach_tensors( + predicted_labels, protected_variable_labels, mask + ) + + # Some sanity checks. + if predicted_labels.size() != protected_variable_labels.size(): + raise ConfigurationError( + "protected_variable_labels must be of same size as predicted_labels but " + "found tensor of shape: {}".format(protected_variable_labels.size()) + ) + if mask is not None and predicted_labels.size() != mask.size(): + raise ConfigurationError( + "mask must be of same size as predicted_labels but " + "found tensor of shape: {}".format(mask.size()) + ) + if (predicted_labels >= self._num_classes).any(): + raise ConfigurationError( + "predicted_labels contains an id >= {}, " + "the number of classes.".format(self._num_classes) + ) + if (protected_variable_labels >= self._num_protected_variable_labels).any(): + raise ConfigurationError( + "protected_variable_labels contains an id >= {}, " + "the number of protected variable labels.".format( + self._num_protected_variable_labels + ) + ) + + device = predicted_labels.device + self._joint_counts_by_protected_variable_label = ( + self._joint_counts_by_protected_variable_label.to(device) + ) + self._protected_variable_label_counts = self._protected_variable_label_counts.to(device) + self._y_counts = self._y_counts.to(device) + self._total_predictions = self._total_predictions.to(device) + + if mask is not None: + predicted_labels = predicted_labels[mask] + protected_variable_labels = protected_variable_labels[mask] + else: + predicted_labels = predicted_labels.flatten() + protected_variable_labels = protected_variable_labels.flatten() + + _total_predictions = torch.tensor(predicted_labels.nelement()).to(device) + _y_counts = torch.zeros(self._num_classes).to(device) + _y_counts = torch.zeros_like(_y_counts, dtype=predicted_labels.dtype).scatter_add_( + 0, predicted_labels, torch.ones_like(predicted_labels) + ) + + _joint_counts_by_protected_variable_label = torch.zeros( + (self._num_protected_variable_labels, self._num_classes) + ).to(device) + _protected_variable_label_counts = torch.zeros(self._num_protected_variable_labels).to( + device + ) + for x in range(self._num_protected_variable_labels): + x_mask = (protected_variable_labels == x).long() + + _joint_counts_by_protected_variable_label[x] = torch.zeros(self._num_classes).to(device) + _joint_counts_by_protected_variable_label[x] = torch.zeros_like( + _joint_counts_by_protected_variable_label[x], dtype=x_mask.dtype + ).scatter_add_(0, predicted_labels, x_mask) + + _protected_variable_label_counts[x] = torch.tensor(x_mask.sum()).to(device) + + if is_distributed(): + _total_predictions = _total_predictions.to(device) + dist.all_reduce(_total_predictions, op=dist.ReduceOp.SUM) + + _y_counts = _y_counts.to(device) + dist.all_reduce(_y_counts, op=dist.ReduceOp.SUM) + + _joint_counts_by_protected_variable_label = ( + _joint_counts_by_protected_variable_label.to(device) + ) + dist.all_reduce(_joint_counts_by_protected_variable_label, op=dist.ReduceOp.SUM) + + _protected_variable_label_counts = _protected_variable_label_counts.to(device) + dist.all_reduce(_protected_variable_label_counts, op=dist.ReduceOp.SUM) + + self._total_predictions += _total_predictions + self._y_counts += _y_counts + self._joint_counts_by_protected_variable_label += _joint_counts_by_protected_variable_label + self._protected_variable_label_counts += _protected_variable_label_counts + + @overrides + def get_metric( + self, reset: bool = False + ) -> Dict[int, Union[torch.FloatTensor, Dict[int, torch.FloatTensor]]]: + """ + # Returns + + gaps : `Dict[int, Union[torch.FloatTensor, Dict[int, torch.FloatTensor]]]` + A dictionary mapping each protected variable label x to either: + + 1. a tensor of the one-vs-all gaps (where the gap corresponding to prediction + label i is at index i), + + 2. another dictionary mapping protected variable labels x' to a tensor + of the pairwise gaps (where the gap corresponding to prediction label i is at index i). + A gap of nearly 0 implies fairness on the basis of Association in the Absence of Ground Truth. + + !!! Note + If a possible class label is not present in Y, the expected behavior is that + the gaps corresponding to this class label are NaN. If a possible (class label, + protected variable label) pair is not present in the joint of Y and X, the expected + behavior is that the gap corresponding to this (class label, protected variable label) + pair is NaN. + """ + gaps = {} + for x in range(self._num_protected_variable_labels): + gaps[x] = self.gap_func(x) + if reset: + self.reset() + return gaps + + @overrides + def reset(self) -> None: + self._joint_counts_by_protected_variable_label = torch.zeros( + (self._num_protected_variable_labels, self._num_classes) + ) + self._protected_variable_label_counts = torch.zeros(self._num_protected_variable_labels) + self._y_counts = torch.zeros(self._num_classes) + self._total_predictions = torch.tensor(0) + + def _ova_gap(self, x: int): + device = self._y_counts.device + pmi_terms = self._all_pmi_terms() + pmi_not_x = torch.sum( + pmi_terms[torch.arange(self._num_protected_variable_labels, device=device) != x], dim=0 + ) + pmi_not_x /= self._num_protected_variable_labels - 1 + + # Will contain NaN if not all possible class labels are predicted + # Will contain NaN if not all possible (class label, + # protected variable label) pairs are predicted + gap = pmi_terms[x] - pmi_not_x + return torch.where(~gap.isinf(), gap, torch.tensor(float("nan")).to(device)) + + def _pairwise_gaps(self, x: int): + device = self._y_counts.device + pmi_terms = self._all_pmi_terms() + pairwise_gaps = {} + for not_x in range(self._num_protected_variable_labels): + gap = pmi_terms[x] - pmi_terms[not_x] + pairwise_gaps[not_x] = torch.where( + ~gap.isinf(), gap, torch.tensor(float("nan")).to(device) + ) + return pairwise_gaps + + def _all_pmi_terms(self) -> Dict[int, torch.Tensor]: + if self._total_predictions == 0: + return torch.full( + (self._num_protected_variable_labels, self._num_classes), float("nan") + ) + + device = self._y_counts.device + prob_y = torch.zeros(self._num_classes).to(device) + torch.div(self._y_counts, self._total_predictions, out=prob_y) + + joint = torch.zeros((self._num_protected_variable_labels, self._num_classes)).to(device) + torch.div( + self._joint_counts_by_protected_variable_label, + self._total_predictions, + out=joint, + ) + if self.association_metric == "pmisq": + torch.square_(joint) + + pmi_terms = torch.log( + torch.div( + joint, + (self._protected_variable_label_counts / self._total_predictions).unsqueeze(-1) + * prob_y, + ) + ) + if self.association_metric == "npmixy": + pmi_terms.div_(torch.log(joint)) + elif self.association_metric == "npmiy": + pmi_terms.div_(torch.log(prob_y)) + + return pmi_terms diff --git a/allennlp/fairness/fairness_metrics.py b/allennlp/fairness/fairness_metrics.py index ab28eeacf7b..752e486c527 100644 --- a/allennlp/fairness/fairness_metrics.py +++ b/allennlp/fairness/fairness_metrics.py @@ -15,15 +15,11 @@ adversarially learning fair representations](https://api.semanticscholar.org/CorpusID:24990444). arXiv preprint arXiv:1707.00075. -5. Aka, O.; Burke, K.; Bäuerle, A.; Greer, C.; and Mitchell, M. 2021. -[Measuring model biases in the absence of ground truth](https://api.semanticscholar.org/CorpusID:232135043). -arXiv preprint arXiv:2103.03417. - It is provably [impossible](https://fairmlbook.org/pdf/classification.pdf) (pg. 18) to satisfy any two of Independence, Separation, and Sufficiency simultaneously, except in degenerate cases. """ -from typing import Optional, Dict, Union +from typing import Optional, Dict from overrides import overrides import torch @@ -614,267 +610,3 @@ def reset(self) -> None: self._gold_label_counts_by_predicted_label_and_protected_variable_label = torch.zeros( (self._num_classes, self._num_protected_variable_labels, self._num_classes) ) - - -@Metric.register("demographic_parity_without_ground_truth") -class DemographicParityWithoutGroundTruth(Metric): - """ - Demographic parity without ground truth, from: Aka, O.; Burke, K.; Bäuerle, A.; - Greer, C.; and Mitchell, M. 2021. Measuring model biases in the absence of ground - truth. arXiv preprint arXiv:2103.03417. - - # Parameters - - num_classes : `int` - Number of classes. - num_protected_variable_labels : `int` - Number of protected variable labels. - association_metric : `str`, optional (default = `"npmixy"`). - A generic association metric A(x, y), where x is an identity label and y is any other label. - Examples include: nPMIxy (`"npmixy"`), nPMIy (`"npmiy"`), PMI^2 (`"pmisq"`), PMI (`"pmi"`) - Empirically, nPMIxy and nPMIy are more capable of capturing labels across a range of - marginal frequencies. - gap_type : `str`, optional (default = `"ova"`). - Either one-vs-all (`"ova"`) or pairwise (`"pairwise"`). One-vs-all gap is equivalent to - A(x, y) - E[A(x', y)], where x' is in the set of all protected variable labels setminus {x}. - Pairwise gaps are A(x, y) - A(x', y), for all x' in the set of all protected variable labels - setminus {x}. - - !!! Note - Assumes integer predictions, with each item to be classified - having a single correct class. - """ - - def __init__( - self, - num_classes: int, - num_protected_variable_labels: int, - association_metric: str = "npmixy", - gap_type: str = "ova", - ) -> None: - self._num_classes = num_classes - self._num_protected_variable_labels = num_protected_variable_labels - self._joint_counts_by_protected_variable_label = torch.zeros( - (num_protected_variable_labels, num_classes) - ) - self._protected_variable_label_counts = torch.zeros(num_protected_variable_labels) - self._y_counts = torch.zeros(num_classes) - self._total_predictions = torch.tensor(0) - - self.IMPLEMENTED_ASSOCIATION_METRICS = set(["npmixy", "npmiy", "pmisq", "pmi"]) - if association_metric in self.IMPLEMENTED_ASSOCIATION_METRICS: - self.association_metric = association_metric - else: - raise NotImplementedError( - f"Association metric {association_metric} has not been implemented!" - ) - - if gap_type == "ova": - self.gap_func = self._ova_gap - elif gap_type == "pairwise": - self.gap_func = self._pairwise_gaps - else: - raise NotImplementedError(f"Gap type {gap_type} has not been implemented!") - - def __call__( - self, - predicted_labels: torch.Tensor, - protected_variable_labels: torch.Tensor, - mask: Optional[torch.BoolTensor] = None, - ) -> None: - """ - # Parameters - - predicted_labels : `torch.Tensor`, required. - A tensor of predicted integer class labels of shape (batch_size, ...). Represented as Y. - protected_variable_labels : `torch.Tensor`, required. - A tensor of integer protected variable labels of shape (batch_size, ...). It must be the same - shape as the `predicted_labels` tensor. Represented as X. - mask : `torch.BoolTensor`, optional (default = `None`). - A tensor of the same shape as `predicted_labels`. - - !!! Note - All tensors are expected to be on the same device. - """ - predicted_labels, protected_variable_labels, mask = self.detach_tensors( - predicted_labels, protected_variable_labels, mask - ) - - # Some sanity checks. - if predicted_labels.size() != protected_variable_labels.size(): - raise ConfigurationError( - "protected_variable_labels must be of same size as predicted_labels but " - "found tensor of shape: {}".format(protected_variable_labels.size()) - ) - if mask is not None and predicted_labels.size() != mask.size(): - raise ConfigurationError( - "mask must be of same size as predicted_labels but " - "found tensor of shape: {}".format(mask.size()) - ) - if (predicted_labels >= self._num_classes).any(): - raise ConfigurationError( - "predicted_labels contains an id >= {}, " - "the number of classes.".format(self._num_classes) - ) - if (protected_variable_labels >= self._num_protected_variable_labels).any(): - raise ConfigurationError( - "protected_variable_labels contains an id >= {}, " - "the number of protected variable labels.".format( - self._num_protected_variable_labels - ) - ) - - device = predicted_labels.device - self._joint_counts_by_protected_variable_label = ( - self._joint_counts_by_protected_variable_label.to(device) - ) - self._protected_variable_label_counts = self._protected_variable_label_counts.to(device) - self._y_counts = self._y_counts.to(device) - self._total_predictions = self._total_predictions.to(device) - - if mask is not None: - predicted_labels = predicted_labels[mask] - protected_variable_labels = protected_variable_labels[mask] - else: - predicted_labels = predicted_labels.flatten() - protected_variable_labels = protected_variable_labels.flatten() - - _total_predictions = torch.tensor(predicted_labels.nelement()).to(device) - _y_counts = torch.zeros(self._num_classes).to(device) - _y_counts = torch.zeros_like(_y_counts, dtype=predicted_labels.dtype).scatter_add_( - 0, predicted_labels, torch.ones_like(predicted_labels) - ) - - _joint_counts_by_protected_variable_label = torch.zeros( - (self._num_protected_variable_labels, self._num_classes) - ).to(device) - _protected_variable_label_counts = torch.zeros(self._num_protected_variable_labels).to( - device - ) - for x in range(self._num_protected_variable_labels): - x_mask = (protected_variable_labels == x).long() - - _joint_counts_by_protected_variable_label[x] = torch.zeros(self._num_classes).to(device) - _joint_counts_by_protected_variable_label[x] = torch.zeros_like( - _joint_counts_by_protected_variable_label[x], dtype=x_mask.dtype - ).scatter_add_(0, predicted_labels, x_mask) - - _protected_variable_label_counts[x] = torch.tensor(x_mask.sum()).to(device) - - if is_distributed(): - _total_predictions = _total_predictions.to(device) - dist.all_reduce(_total_predictions, op=dist.ReduceOp.SUM) - - _y_counts = _y_counts.to(device) - dist.all_reduce(_y_counts, op=dist.ReduceOp.SUM) - - _joint_counts_by_protected_variable_label = ( - _joint_counts_by_protected_variable_label.to(device) - ) - dist.all_reduce(_joint_counts_by_protected_variable_label, op=dist.ReduceOp.SUM) - - _protected_variable_label_counts = _protected_variable_label_counts.to(device) - dist.all_reduce(_protected_variable_label_counts, op=dist.ReduceOp.SUM) - - self._total_predictions += _total_predictions - self._y_counts += _y_counts - self._joint_counts_by_protected_variable_label += _joint_counts_by_protected_variable_label - self._protected_variable_label_counts += _protected_variable_label_counts - - @overrides - def get_metric( - self, reset: bool = False - ) -> Dict[int, Union[torch.FloatTensor, Dict[int, torch.FloatTensor]]]: - """ - # Returns - - gaps : `Dict[int, Union[torch.FloatTensor, Dict[int, torch.FloatTensor]]]` - A dictionary mapping each protected variable label x to either: - - 1. a tensor of the one-vs-all gaps (where the gap corresponding to prediction - label i is at index i), - - 2. another dictionary mapping protected variable labels x' to a tensor - of the pairwise gaps (where the gap corresponding to prediction label i is at index i). - A gap of nearly 0 implies fairness on the basis of Demographic Parity in the Absence of Ground Truth. - - !!! Note - If a possible class label is not present in Y, the expected behavior is that - the gaps corresponding to this class label are NaN. If a possible (class label, - protected variable label) pair is not present in the joint of Y and X, the expected - behavior is that the gap corresponding to this (class label, protected variable label) - pair is NaN. - """ - gaps = {} - for x in range(self._num_protected_variable_labels): - gaps[x] = self.gap_func(x) - if reset: - self.reset() - return gaps - - @overrides - def reset(self) -> None: - self._joint_counts_by_protected_variable_label = torch.zeros( - (self._num_protected_variable_labels, self._num_classes) - ) - self._protected_variable_label_counts = torch.zeros(self._num_protected_variable_labels) - self._y_counts = torch.zeros(self._num_classes) - self._total_predictions = torch.tensor(0) - - def _ova_gap(self, x: int): - device = self._y_counts.device - pmi_terms = self._all_pmi_terms() - pmi_not_x = torch.sum( - pmi_terms[torch.arange(self._num_protected_variable_labels, device=device) != x], dim=0 - ) - pmi_not_x /= self._num_protected_variable_labels - 1 - - # Will contain NaN if not all possible class labels are predicted - # Will contain NaN if not all possible (class label, - # protected variable label) pairs are predicted - gap = pmi_terms[x] - pmi_not_x - return torch.where(~gap.isinf(), gap, torch.tensor(float("nan")).to(device)) - - def _pairwise_gaps(self, x: int): - device = self._y_counts.device - pmi_terms = self._all_pmi_terms() - pairwise_gaps = {} - for not_x in range(self._num_protected_variable_labels): - gap = pmi_terms[x] - pmi_terms[not_x] - pairwise_gaps[not_x] = torch.where( - ~gap.isinf(), gap, torch.tensor(float("nan")).to(device) - ) - return pairwise_gaps - - def _all_pmi_terms(self) -> Dict[int, torch.Tensor]: - if self._total_predictions == 0: - return torch.full( - (self._num_protected_variable_labels, self._num_classes), float("nan") - ) - - device = self._y_counts.device - prob_y = torch.zeros(self._num_classes).to(device) - torch.div(self._y_counts, self._total_predictions, out=prob_y) - - joint = torch.zeros((self._num_protected_variable_labels, self._num_classes)).to(device) - torch.div( - self._joint_counts_by_protected_variable_label, - self._total_predictions, - out=joint, - ) - if self.association_metric == "pmisq": - torch.square_(joint) - - pmi_terms = torch.log( - torch.div( - joint, - (self._protected_variable_label_counts / self._total_predictions).unsqueeze(-1) - * prob_y, - ) - ) - if self.association_metric == "npmixy": - pmi_terms.div_(torch.log(joint)) - elif self.association_metric == "npmiy": - pmi_terms.div_(torch.log(prob_y)) - - return pmi_terms diff --git a/tests/fairness/bias_metrics_test.py b/tests/fairness/bias_metrics_test.py new file mode 100644 index 00000000000..ba9675946e8 --- /dev/null +++ b/tests/fairness/bias_metrics_test.py @@ -0,0 +1,368 @@ +import pytest +import torch +import json +import math +import numpy as np + +from allennlp.common.checks import ConfigurationError +from allennlp.common.testing import ( + AllenNlpTestCase, + multi_device, + global_distributed_metric, + run_distributed_test, +) +from allennlp.fairness.bias_metrics import ( + WordEmbeddingAssociationTest, + EmbeddingCoherenceTest, + NaturalLanguageInference, + AssociationWithoutGroundTruth, +) + + +class WordEmbeddingAssociationTestTest(AllenNlpTestCase): + def setup_method(self): + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + self.X = torch.cat( + [ + torch.Tensor(emb_data["he"]).reshape(1, -1), + torch.Tensor(emb_data["him"]).reshape(1, -1), + ] + ) + self.Y = torch.cat( + [ + torch.Tensor(emb_data["she"]).reshape(1, -1), + torch.Tensor(emb_data["her"]).reshape(1, -1), + ] + ) + self.A = torch.cat( + [ + torch.Tensor(emb_data["engineer"]).reshape(1, -1), + torch.Tensor(emb_data["banker"]).reshape(1, -1), + ] + ) + self.B = torch.cat( + [ + torch.Tensor(emb_data["nurse"]).reshape(1, -1), + torch.Tensor(emb_data["receptionist"]).reshape(1, -1), + ] + ) + + def teardown_method(self): + pass + + def test_invalid_dims(self): + weat = WordEmbeddingAssociationTest() + with pytest.raises(ConfigurationError): + weat(torch.zeros(2), torch.zeros(2), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + weat(torch.zeros((2, 2)), torch.zeros((2, 2)), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + weat(torch.zeros((2, 2)), torch.zeros((2, 3)), torch.zeros((2, 2)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + weat(torch.zeros((2, 2)), torch.zeros((2, 2)), torch.zeros((2, 3)), torch.zeros((2, 2))) + + @multi_device + def test_weat(self, device: str): + self.X = self.X.to(device) + self.Y = self.Y.to(device) + self.A = self.A.to(device) + self.B = self.B.to(device) + + weat = WordEmbeddingAssociationTest() + test_weat_score = weat(self.X, self.Y, self.A, self.B) + assert test_weat_score.item() == pytest.approx(1.872, rel=1e-4) + + +class EmbeddingCoherenceTestTest(AllenNlpTestCase): + def setup_method(self): + # embedding data from VERB demo + emb_filename = str(self.FIXTURES_ROOT / "fairness" / "bias_embeddings.json") + with open(emb_filename) as emb_file: + emb_data = json.load(emb_file) + + self.X = torch.cat( + [ + torch.Tensor(emb_data["he"]).reshape(1, -1), + torch.Tensor(emb_data["him"]).reshape(1, -1), + ] + ) + self.Y = torch.cat( + [ + torch.Tensor(emb_data["she"]).reshape(1, -1), + torch.Tensor(emb_data["her"]).reshape(1, -1), + ] + ) + self.AB = torch.cat( + [ + torch.Tensor(emb_data["engineer"]).reshape(1, -1), + torch.Tensor(emb_data["banker"]).reshape(1, -1), + torch.Tensor(emb_data["nurse"]).reshape(1, -1), + torch.Tensor(emb_data["receptionist"]).reshape(1, -1), + ] + ) + + def teardown_method(self): + pass + + def test_invalid_dims(self): + ect = EmbeddingCoherenceTest() + with pytest.raises(ConfigurationError): + ect(torch.zeros(2), torch.zeros(2), torch.zeros(2)) + with pytest.raises(ConfigurationError): + ect(torch.zeros((2, 2)), torch.zeros((2, 2)), torch.zeros(2)) + with pytest.raises(ConfigurationError): + ect(torch.zeros((2, 2)), torch.zeros((2, 3)), torch.zeros((2, 2))) + with pytest.raises(ConfigurationError): + ect(torch.zeros((2, 2)), torch.zeros((2, 2)), torch.zeros((2, 3))) + + @multi_device + def test_ect(self, device: str): + self.X = self.X.to(device) + self.Y = self.Y.to(device) + self.AB = self.AB.to(device) + + ect = EmbeddingCoherenceTest() + test_ect_score = ect(self.X, self.Y, self.AB) + assert test_ect_score.item() == pytest.approx(0.800, rel=1e-4) + + +class NaturalLanguageInferenceTest(AllenNlpTestCase): + def test_invalid_dimensions(self): + nli_probabilities = torch.ones(3) + with pytest.raises(ConfigurationError): + NaturalLanguageInference(0)(nli_probabilities) + + nli_probabilities = torch.eye(4) + with pytest.raises(ConfigurationError): + NaturalLanguageInference(0)(nli_probabilities) + + @multi_device + def test_nli(self, device: str): + nli_probabilities = 0.6 * torch.eye(3, device=device) + nli = NaturalLanguageInference(0) + nli(nli_probabilities) + + expected_scores = { + "net_neutral": 0.6 / 3, + "fraction_neutral": 1 / 3, + "threshold_0.5": 1 / 3, + "threshold_0.7": 0.0, + } + assert nli.get_metric(reset=True) == pytest.approx(expected_scores) + assert all([v == 0.0 for k, v in nli.get_metric().items()]) + + def test_distributed_nli(self): + nli_probabilities = 0.6 * torch.eye(3) + expected_scores = { + "net_neutral": 0.6 / 3, + "fraction_neutral": 1 / 3, + "threshold_0.5": 1 / 3, + "threshold_0.7": 0.0, + } + metric_kwargs = {"nli_probabilities": [nli_probabilities, nli_probabilities]} + run_distributed_test( + [-1, -1], + global_distributed_metric, + NaturalLanguageInference(0), + metric_kwargs, + expected_scores, + exact=False, + ) + + +class AssociationWithoutGroundTruthTest(AllenNlpTestCase): + def test_invalid_dimensions(self): + ova_npmixy = AssociationWithoutGroundTruth(2, 2) + Y = torch.eye(3).long() + X = torch.eye(4).long() + with pytest.raises(ConfigurationError): + ova_npmixy(Y, X) + + def test_invalid_num_classes(self): + ova_npmixy = AssociationWithoutGroundTruth(1, 1) + Y = torch.eye(3).long() + X = torch.eye(3).long() + with pytest.raises(ConfigurationError): + ova_npmixy(Y, X) + + @multi_device + def test_pmi_unmasked_computation(self, device: str): + ova_pmi = AssociationWithoutGroundTruth(2, 2, "pmi", "ova") + pairwise_pmi = AssociationWithoutGroundTruth(2, 2, "pmi", "pairwise") + Y = torch.ones(3, 3, device=device).long() + X = torch.eye(3, device=device).long() + + # P(X = 0, Y = 0) = 0 + # P(X = 0, Y = 1) = 2/3 + # P(X = 1, Y = 0) = 0 + # P(X = 1, Y = 1) = 1/3 + # P(X = 0) = 2/3 + # P(X = 1) = 1/3 + # P(Y = 0) = 0 + # P(Y = 1) = 1 + # G(Y = 0 | X = 0, X = rest, PMI) = NaN + # G(Y = 1 | X = 0, X = rest, PMI) = ln(1) - ln(1) = 0.0 + # G(Y = 0 | X = 1, X = rest, PMI) = NaN + # G(Y = 1 | X = 1, X = rest, PMI) = ln(1) - ln(1) = 0.0 + expected_ova_pmi_gaps = { + 0: [np.nan, 0.0], + 1: [np.nan, 0.0], + } + + ova_pmi(Y, X) + test_ova_pmi_gaps = { + k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] + for k, v in ova_pmi.get_metric().items() + } + assert expected_ova_pmi_gaps == test_ova_pmi_gaps + + ova_pmi(Y, X) + test_ova_pmi_gaps = { + k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] + for k, v in ova_pmi.get_metric(reset=True).items() + } + assert expected_ova_pmi_gaps == test_ova_pmi_gaps + + test_ova_pmi_gaps = { + k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] + for k, v in ova_pmi.get_metric(reset=True).items() + } + assert test_ova_pmi_gaps == {0: [np.nan, np.nan], 1: [np.nan, np.nan]} + + expected_pairwise_pmi_gaps = { + 0: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, + 1: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, + } + + pairwise_pmi(Y, X) + test_pairwise_pmi_gaps = { + k1: { + k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] + for k2, v2 in v1.items() + } + for k1, v1 in pairwise_pmi.get_metric().items() + } + assert expected_pairwise_pmi_gaps == test_pairwise_pmi_gaps + + pairwise_pmi(Y, X) + test_pairwise_pmi_gaps = { + k1: { + k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] + for k2, v2 in v1.items() + } + for k1, v1 in pairwise_pmi.get_metric(reset=True).items() + } + assert expected_pairwise_pmi_gaps == test_pairwise_pmi_gaps + + test_pairwise_pmi_gaps = { + k1: { + k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] + for k2, v2 in v1.items() + } + for k1, v1 in pairwise_pmi.get_metric(reset=True).items() + } + assert test_pairwise_pmi_gaps == { + 0: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, + 1: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, + } + + @multi_device + def test_pmisq_masked_computation(self, device: str): + ova_pmisq = AssociationWithoutGroundTruth(2, 2, "pmisq", "ova") + pairwise_pmisq = AssociationWithoutGroundTruth(2, 2, "pmisq", "pairwise") + Y = torch.ones(3, 3, device=device).long() + X = torch.eye(3, device=device).long() + mask = torch.ones_like(Y).bool() + + expected_ova_pmisq_gaps = { + 0: [np.nan, round(math.log(2), 3)], + 1: [np.nan, round(math.log(0.5), 3)], + } + ova_pmisq(Y, X, mask) + test_ova_pmisq_gaps = { + k: [(round(e, 3) if not math.isnan(e) else np.nan) for e in v.tolist()] + for k, v in ova_pmisq.get_metric().items() + } + assert expected_ova_pmisq_gaps == test_ova_pmisq_gaps + + expected_pairwise_pmisq_gaps = { + 0: {0: [np.nan, 0.0], 1: [np.nan, round(math.log(2), 3)]}, + 1: {0: [np.nan, round(math.log(0.5), 3)], 1: [np.nan, 0.0]}, + } + pairwise_pmisq(Y, X, mask) + test_pairwise_pmisq_gaps = { + k1: { + k2: [(round(e, 3) if not math.isnan(e) else np.nan) for e in v2.tolist()] + for k2, v2 in v1.items() + } + for k1, v1 in pairwise_pmisq.get_metric().items() + } + assert expected_pairwise_pmisq_gaps == test_pairwise_pmisq_gaps + + def test_distributed_npmiy_unmasked_computation(self): + Y = torch.ones(3, 3).long() + X = torch.eye(3).long() + + expected_ova_npmiy_gaps = { + 0: [np.nan, np.nan], + 1: [np.nan, np.nan], + } + metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X} + run_distributed_test( + [-1, -1], + global_distributed_metric, + AssociationWithoutGroundTruth(2, 2, "npmiy", "ova"), + metric_kwargs, + expected_ova_npmiy_gaps, + exact=True, + ) + + expected_pairwise_npmiy_gaps = { + 0: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, + 1: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, + } + run_distributed_test( + [-1, -1], + global_distributed_metric, + AssociationWithoutGroundTruth(2, 2, "npmiy", "pairwise"), + metric_kwargs, + expected_pairwise_npmiy_gaps, + exact=True, + ) + + def test_distributed_npmixy_masked_computation(self): + Y = torch.ones(3, 3).long() + X = torch.eye(3).long() + mask = torch.ones_like(Y).bool() + + expected_ova_npmixy_gaps = { + 0: [np.nan, 0.0], + 1: [np.nan, 0.0], + } + metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X, "mask": mask} + run_distributed_test( + [-1, -1], + global_distributed_metric, + AssociationWithoutGroundTruth(2, 2, "npmixy", "ova"), + metric_kwargs, + expected_ova_npmixy_gaps, + exact=True, + ) + + expected_pairwise_npmixy_gaps = { + 0: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, + 1: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, + } + metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X, "mask": mask} + run_distributed_test( + [-1, -1], + global_distributed_metric, + AssociationWithoutGroundTruth(2, 2, "npmixy", "pairwise"), + metric_kwargs, + expected_pairwise_npmixy_gaps, + exact=True, + ) diff --git a/tests/fairness/fairness_metrics_test.py b/tests/fairness/fairness_metrics_test.py index 2b1af18a877..adce0d41bd7 100644 --- a/tests/fairness/fairness_metrics_test.py +++ b/tests/fairness/fairness_metrics_test.py @@ -10,12 +10,7 @@ global_distributed_metric, run_distributed_test, ) -from allennlp.fairness.fairness_metrics import ( - Independence, - Separation, - Sufficiency, - DemographicParityWithoutGroundTruth, -) +from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency class IndependenceTest(AllenNlpTestCase): @@ -249,197 +244,3 @@ def test_distributed_sufficiency_masked_computation(self): expected_kl_divs, exact=False, ) - - -class DemographicParityWithoutGroundTruthTest(AllenNlpTestCase): - def test_invalid_dimensions(self): - ova_npmixy = DemographicParityWithoutGroundTruth(2, 2) - Y = torch.eye(3).long() - X = torch.eye(4).long() - with pytest.raises(ConfigurationError): - ova_npmixy(Y, X) - - def test_invalid_num_classes(self): - ova_npmixy = DemographicParityWithoutGroundTruth(1, 1) - Y = torch.eye(3).long() - X = torch.eye(3).long() - with pytest.raises(ConfigurationError): - ova_npmixy(Y, X) - - @multi_device - def test_pmi_unmasked_computation(self, device: str): - ova_pmi = DemographicParityWithoutGroundTruth(2, 2, "pmi", "ova") - pairwise_pmi = DemographicParityWithoutGroundTruth(2, 2, "pmi", "pairwise") - Y = torch.ones(3, 3, device=device).long() - X = torch.eye(3, device=device).long() - - # P(X = 0, Y = 0) = 0 - # P(X = 0, Y = 1) = 2/3 - # P(X = 1, Y = 0) = 0 - # P(X = 1, Y = 1) = 1/3 - # P(X = 0) = 2/3 - # P(X = 1) = 1/3 - # P(Y = 0) = 0 - # P(Y = 1) = 1 - # G(Y = 0 | X = 0, X = rest, PMI) = NaN - # G(Y = 1 | X = 0, X = rest, PMI) = ln(1) - ln(1) = 0.0 - # G(Y = 0 | X = 1, X = rest, PMI) = NaN - # G(Y = 1 | X = 1, X = rest, PMI) = ln(1) - ln(1) = 0.0 - expected_ova_pmi_gaps = { - 0: [np.nan, 0.0], - 1: [np.nan, 0.0], - } - - ova_pmi(Y, X) - test_ova_pmi_gaps = { - k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] - for k, v in ova_pmi.get_metric().items() - } - assert expected_ova_pmi_gaps == test_ova_pmi_gaps - - ova_pmi(Y, X) - test_ova_pmi_gaps = { - k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] - for k, v in ova_pmi.get_metric(reset=True).items() - } - assert expected_ova_pmi_gaps == test_ova_pmi_gaps - - test_ova_pmi_gaps = { - k: [(e if not math.isnan(e) else np.nan) for e in v.tolist()] - for k, v in ova_pmi.get_metric(reset=True).items() - } - assert test_ova_pmi_gaps == {0: [np.nan, np.nan], 1: [np.nan, np.nan]} - - expected_pairwise_pmi_gaps = { - 0: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, - 1: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, - } - - pairwise_pmi(Y, X) - test_pairwise_pmi_gaps = { - k1: { - k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] - for k2, v2 in v1.items() - } - for k1, v1 in pairwise_pmi.get_metric().items() - } - assert expected_pairwise_pmi_gaps == test_pairwise_pmi_gaps - - pairwise_pmi(Y, X) - test_pairwise_pmi_gaps = { - k1: { - k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] - for k2, v2 in v1.items() - } - for k1, v1 in pairwise_pmi.get_metric(reset=True).items() - } - assert expected_pairwise_pmi_gaps == test_pairwise_pmi_gaps - - test_pairwise_pmi_gaps = { - k1: { - k2: [(e if not math.isnan(e) else np.nan) for e in v2.tolist()] - for k2, v2 in v1.items() - } - for k1, v1 in pairwise_pmi.get_metric(reset=True).items() - } - assert test_pairwise_pmi_gaps == { - 0: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, - 1: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, - } - - @multi_device - def test_pmisq_masked_computation(self, device: str): - ova_pmisq = DemographicParityWithoutGroundTruth(2, 2, "pmisq", "ova") - pairwise_pmisq = DemographicParityWithoutGroundTruth(2, 2, "pmisq", "pairwise") - Y = torch.ones(3, 3, device=device).long() - X = torch.eye(3, device=device).long() - mask = torch.ones_like(Y).bool() - - expected_ova_pmisq_gaps = { - 0: [np.nan, round(math.log(2), 3)], - 1: [np.nan, round(math.log(0.5), 3)], - } - ova_pmisq(Y, X, mask) - test_ova_pmisq_gaps = { - k: [(round(e, 3) if not math.isnan(e) else np.nan) for e in v.tolist()] - for k, v in ova_pmisq.get_metric().items() - } - assert expected_ova_pmisq_gaps == test_ova_pmisq_gaps - - expected_pairwise_pmisq_gaps = { - 0: {0: [np.nan, 0.0], 1: [np.nan, round(math.log(2), 3)]}, - 1: {0: [np.nan, round(math.log(0.5), 3)], 1: [np.nan, 0.0]}, - } - pairwise_pmisq(Y, X, mask) - test_pairwise_pmisq_gaps = { - k1: { - k2: [(round(e, 3) if not math.isnan(e) else np.nan) for e in v2.tolist()] - for k2, v2 in v1.items() - } - for k1, v1 in pairwise_pmisq.get_metric().items() - } - assert expected_pairwise_pmisq_gaps == test_pairwise_pmisq_gaps - - def test_distributed_npmiy_unmasked_computation(self): - Y = torch.ones(3, 3).long() - X = torch.eye(3).long() - - expected_ova_npmiy_gaps = { - 0: [np.nan, np.nan], - 1: [np.nan, np.nan], - } - metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X} - run_distributed_test( - [-1, -1], - global_distributed_metric, - DemographicParityWithoutGroundTruth(2, 2, "npmiy", "ova"), - metric_kwargs, - expected_ova_npmiy_gaps, - exact=True, - ) - - expected_pairwise_npmiy_gaps = { - 0: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, - 1: {0: [np.nan, np.nan], 1: [np.nan, np.nan]}, - } - run_distributed_test( - [-1, -1], - global_distributed_metric, - DemographicParityWithoutGroundTruth(2, 2, "npmiy", "pairwise"), - metric_kwargs, - expected_pairwise_npmiy_gaps, - exact=True, - ) - - def test_distributed_npmixy_masked_computation(self): - Y = torch.ones(3, 3).long() - X = torch.eye(3).long() - mask = torch.ones_like(Y).bool() - - expected_ova_npmixy_gaps = { - 0: [np.nan, 0.0], - 1: [np.nan, 0.0], - } - metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X, "mask": mask} - run_distributed_test( - [-1, -1], - global_distributed_metric, - DemographicParityWithoutGroundTruth(2, 2, "npmixy", "ova"), - metric_kwargs, - expected_ova_npmixy_gaps, - exact=True, - ) - - expected_pairwise_npmixy_gaps = { - 0: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, - 1: {0: [np.nan, 0.0], 1: [np.nan, 0.0]}, - } - metric_kwargs = {"predicted_labels": Y, "protected_variable_labels": X, "mask": mask} - run_distributed_test( - [-1, -1], - global_distributed_metric, - DemographicParityWithoutGroundTruth(2, 2, "npmixy", "pairwise"), - metric_kwargs, - expected_pairwise_npmixy_gaps, - exact=True, - ) From dfed580bbb07f3e634470a152e205233d650120b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 13 May 2021 17:20:21 -0700 Subject: [PATCH 22/63] Update transformers requirement from <4.6,>=4.1 to >=4.1,<4.7 (#5199) Updates the requirements on [transformers](https://github.com/huggingface/transformers) to permit the latest version. - [Release notes](https://github.com/huggingface/transformers/releases) - [Commits](https://github.com/huggingface/transformers/compare/v4.1.0...v4.6.0) Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 13587bcd70e..4ac272c483c 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ "scikit-learn", "scipy", "pytest", - "transformers>=4.1,<4.6", + "transformers>=4.1,<4.7", "sentencepiece", "dataclasses;python_version<'3.7'", "filelock>=3.0,<3.1", From f1a1adc7e15309a06f383347380bfd748c50adbd Mon Sep 17 00:00:00 2001 From: Akshita Bhagia <akshita23bhagia@gmail.com> Date: Fri, 14 May 2021 13:50:34 -0700 Subject: [PATCH 23/63] Rename sanity_checks to confidence_checks (#5201) * renaming sanity_checks to confidence_checks * update changelog * docs fix * clean up --- CHANGELOG.md | 5 ++- allennlp/commands/checklist.py | 6 +-- allennlp/common/testing/checklist_test.py | 2 +- ...check_test.py => confidence_check_test.py} | 0 allennlp/common/testing/model_test_case.py | 2 +- allennlp/confidence_checks/__init__.py | 2 + .../normalization_bias_verification.py | 2 +- .../task_checklists/__init__.py | 10 +++++ .../question_answering_suite.py | 4 +- .../sentiment_analysis_suite.py | 4 +- .../task_checklists/task_suite.py | 2 +- .../textual_entailment_suite.py | 4 +- .../task_checklists/utils.py | 0 .../verification_base.py | 0 allennlp/sanity_checks/__init__.py | 11 +++++- .../sanity_checks/task_checklists/__init__.py | 8 ++-- allennlp/training/callbacks/__init__.py | 2 +- ...{sanity_checks.py => confidence_checks.py} | 20 +++++----- allennlp/training/trainer.py | 37 ++++++++++++++----- .../normalization_bias_verification_test.py | 4 +- .../task_checklists/__init__.py | 0 .../sentiment_analysis_suite_test.py | 4 +- .../task_checklists/task_suite_test.py | 2 +- .../task_checklists/utils_test.py | 2 +- tests/training/trainer_test.py | 18 ++++----- 25 files changed, 97 insertions(+), 54 deletions(-) rename allennlp/common/testing/{sanity_check_test.py => confidence_check_test.py} (100%) create mode 100644 allennlp/confidence_checks/__init__.py rename allennlp/{sanity_checks => confidence_checks}/normalization_bias_verification.py (97%) create mode 100644 allennlp/confidence_checks/task_checklists/__init__.py rename allennlp/{sanity_checks => confidence_checks}/task_checklists/question_answering_suite.py (97%) rename allennlp/{sanity_checks => confidence_checks}/task_checklists/sentiment_analysis_suite.py (99%) rename allennlp/{sanity_checks => confidence_checks}/task_checklists/task_suite.py (99%) rename allennlp/{sanity_checks => confidence_checks}/task_checklists/textual_entailment_suite.py (99%) rename allennlp/{sanity_checks => confidence_checks}/task_checklists/utils.py (100%) rename allennlp/{sanity_checks => confidence_checks}/verification_base.py (100%) rename allennlp/training/callbacks/{sanity_checks.py => confidence_checks.py} (74%) rename tests/{sanity_checks => confidence_checks}/normalization_bias_verification_test.py (91%) rename tests/{sanity_checks => confidence_checks}/task_checklists/__init__.py (100%) rename tests/{sanity_checks => confidence_checks}/task_checklists/sentiment_analysis_suite_test.py (89%) rename tests/{sanity_checks => confidence_checks}/task_checklists/task_suite_test.py (96%) rename tests/{sanity_checks => confidence_checks}/task_checklists/utils_test.py (86%) diff --git a/CHANGELOG.md b/CHANGELOG.md index da0b9df6211..a55610dea5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `dist_reduce_sum` in distributed metrics. - Allow Google Cloud Storage paths in `cached_path` ("gs://..."). - Print the first batch to the console by default. +- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). ### Added -- Added `HuggingfaceDatasetReader` for using huggingface datasets in AllenNLP -- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module. + +- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. - Added `allennlp.nn.util.load_state_dict` helper function. - Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers diff --git a/allennlp/commands/checklist.py b/allennlp/commands/checklist.py index 27a061915a4..f6f51fd09b4 100644 --- a/allennlp/commands/checklist.py +++ b/allennlp/commands/checklist.py @@ -1,6 +1,6 @@ """ -The `checklist` subcommand allows you to sanity check your -model's predictions using a trained model and its +The `checklist` subcommand allows you to conduct behavioural +testing for your model's predictions using a trained model and its [`Predictor`](../predictors/predictor.md#predictor) wrapper. """ @@ -15,7 +15,7 @@ from allennlp.common.checks import check_for_gpu, ConfigurationError from allennlp.models.archival import load_archive from allennlp.predictors.predictor import Predictor -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite @Subcommand.register("checklist") diff --git a/allennlp/common/testing/checklist_test.py b/allennlp/common/testing/checklist_test.py index c84b82b7afb..615f06219f6 100644 --- a/allennlp/common/testing/checklist_test.py +++ b/allennlp/common/testing/checklist_test.py @@ -1,7 +1,7 @@ from typing import Optional from checklist.test_suite import TestSuite from checklist.test_types import MFT as MinimumFunctionalityTest -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite @TaskSuite.register("fake-task-suite") diff --git a/allennlp/common/testing/sanity_check_test.py b/allennlp/common/testing/confidence_check_test.py similarity index 100% rename from allennlp/common/testing/sanity_check_test.py rename to allennlp/common/testing/confidence_check_test.py diff --git a/allennlp/common/testing/model_test_case.py b/allennlp/common/testing/model_test_case.py index 04794e474d5..d01b2773388 100644 --- a/allennlp/common/testing/model_test_case.py +++ b/allennlp/common/testing/model_test_case.py @@ -16,7 +16,7 @@ from allennlp.data.batch import Batch from allennlp.models import load_archive, Model from allennlp.training import GradientDescentTrainer -from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification +from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification class ModelTestCase(AllenNlpTestCase): diff --git a/allennlp/confidence_checks/__init__.py b/allennlp/confidence_checks/__init__.py new file mode 100644 index 00000000000..9ad124601bb --- /dev/null +++ b/allennlp/confidence_checks/__init__.py @@ -0,0 +1,2 @@ +from allennlp.confidence_checks.verification_base import VerificationBase +from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification diff --git a/allennlp/sanity_checks/normalization_bias_verification.py b/allennlp/confidence_checks/normalization_bias_verification.py similarity index 97% rename from allennlp/sanity_checks/normalization_bias_verification.py rename to allennlp/confidence_checks/normalization_bias_verification.py index 979cab656a0..106da18924d 100644 --- a/allennlp/sanity_checks/normalization_bias_verification.py +++ b/allennlp/confidence_checks/normalization_bias_verification.py @@ -7,7 +7,7 @@ import torch from torch import nn as nn from typing import Tuple, List, Callable -from allennlp.sanity_checks.verification_base import VerificationBase +from allennlp.confidence_checks.verification_base import VerificationBase import logging logger = logging.getLogger(__name__) diff --git a/allennlp/confidence_checks/task_checklists/__init__.py b/allennlp/confidence_checks/task_checklists/__init__.py new file mode 100644 index 00000000000..33ed124a611 --- /dev/null +++ b/allennlp/confidence_checks/task_checklists/__init__.py @@ -0,0 +1,10 @@ +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import ( + SentimentAnalysisSuite, +) +from allennlp.confidence_checks.task_checklists.question_answering_suite import ( + QuestionAnsweringSuite, +) +from allennlp.confidence_checks.task_checklists.textual_entailment_suite import ( + TextualEntailmentSuite, +) diff --git a/allennlp/sanity_checks/task_checklists/question_answering_suite.py b/allennlp/confidence_checks/task_checklists/question_answering_suite.py similarity index 97% rename from allennlp/sanity_checks/task_checklists/question_answering_suite.py rename to allennlp/confidence_checks/task_checklists/question_answering_suite.py index 890ccb6b4ee..a5a1114e888 100644 --- a/allennlp/sanity_checks/task_checklists/question_answering_suite.py +++ b/allennlp/confidence_checks/task_checklists/question_answering_suite.py @@ -6,8 +6,8 @@ from checklist.test_suite import TestSuite from checklist.test_types import MFT from checklist.perturb import Perturb -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite -from allennlp.sanity_checks.task_checklists import utils +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists import utils def _crossproduct(template: CheckListTemplate): diff --git a/allennlp/sanity_checks/task_checklists/sentiment_analysis_suite.py b/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py similarity index 99% rename from allennlp/sanity_checks/task_checklists/sentiment_analysis_suite.py rename to allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py index 79dcfe8a75b..2c68cd9efaf 100644 --- a/allennlp/sanity_checks/task_checklists/sentiment_analysis_suite.py +++ b/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py @@ -5,8 +5,8 @@ from checklist.test_types import MFT, INV, DIR, Expect from checklist.editor import Editor from checklist.perturb import Perturb -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite -from allennlp.sanity_checks.task_checklists import utils +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists import utils from allennlp.data.instance import Instance diff --git a/allennlp/sanity_checks/task_checklists/task_suite.py b/allennlp/confidence_checks/task_checklists/task_suite.py similarity index 99% rename from allennlp/sanity_checks/task_checklists/task_suite.py rename to allennlp/confidence_checks/task_checklists/task_suite.py index 85b05902fdb..6ddf00d59b1 100644 --- a/allennlp/sanity_checks/task_checklists/task_suite.py +++ b/allennlp/confidence_checks/task_checklists/task_suite.py @@ -10,7 +10,7 @@ from allennlp.common.registrable import Registrable from allennlp.common.file_utils import cached_path from allennlp.predictors.predictor import Predictor -from allennlp.sanity_checks.task_checklists import utils +from allennlp.confidence_checks.task_checklists import utils logger = logging.getLogger(__name__) diff --git a/allennlp/sanity_checks/task_checklists/textual_entailment_suite.py b/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py similarity index 99% rename from allennlp/sanity_checks/task_checklists/textual_entailment_suite.py rename to allennlp/confidence_checks/task_checklists/textual_entailment_suite.py index 566324b440f..b8e1a810f23 100644 --- a/allennlp/sanity_checks/task_checklists/textual_entailment_suite.py +++ b/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py @@ -5,8 +5,8 @@ from checklist.test_suite import TestSuite from checklist.test_types import MFT from checklist.perturb import Perturb -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite -from allennlp.sanity_checks.task_checklists import utils +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists import utils def _wrap_apply_to_each(perturb_fn: Callable, both: bool = False, *args, **kwargs): diff --git a/allennlp/sanity_checks/task_checklists/utils.py b/allennlp/confidence_checks/task_checklists/utils.py similarity index 100% rename from allennlp/sanity_checks/task_checklists/utils.py rename to allennlp/confidence_checks/task_checklists/utils.py diff --git a/allennlp/sanity_checks/verification_base.py b/allennlp/confidence_checks/verification_base.py similarity index 100% rename from allennlp/sanity_checks/verification_base.py rename to allennlp/confidence_checks/verification_base.py diff --git a/allennlp/sanity_checks/__init__.py b/allennlp/sanity_checks/__init__.py index d35f382f94e..8f569054524 100644 --- a/allennlp/sanity_checks/__init__.py +++ b/allennlp/sanity_checks/__init__.py @@ -1,2 +1,9 @@ -from allennlp.sanity_checks.verification_base import VerificationBase -from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification +from allennlp.confidence_checks.verification_base import VerificationBase +from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification + +import warnings + +warnings.warn( + "Module 'sanity_checks' is deprecated, please use 'confidence_checks' instead.", + DeprecationWarning, +) diff --git a/allennlp/sanity_checks/task_checklists/__init__.py b/allennlp/sanity_checks/task_checklists/__init__.py index ef0e0d28263..33ed124a611 100644 --- a/allennlp/sanity_checks/task_checklists/__init__.py +++ b/allennlp/sanity_checks/task_checklists/__init__.py @@ -1,10 +1,10 @@ -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite -from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import ( +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import ( SentimentAnalysisSuite, ) -from allennlp.sanity_checks.task_checklists.question_answering_suite import ( +from allennlp.confidence_checks.task_checklists.question_answering_suite import ( QuestionAnsweringSuite, ) -from allennlp.sanity_checks.task_checklists.textual_entailment_suite import ( +from allennlp.confidence_checks.task_checklists.textual_entailment_suite import ( TextualEntailmentSuite, ) diff --git a/allennlp/training/callbacks/__init__.py b/allennlp/training/callbacks/__init__.py index 759c5b263a4..3e55e115b43 100644 --- a/allennlp/training/callbacks/__init__.py +++ b/allennlp/training/callbacks/__init__.py @@ -1,6 +1,6 @@ from allennlp.training.callbacks.callback import TrainerCallback from allennlp.training.callbacks.console_logger import ConsoleLoggerCallback -from allennlp.training.callbacks.sanity_checks import SanityChecksCallback +from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback from allennlp.training.callbacks.tensorboard import TensorBoardCallback from allennlp.training.callbacks.track_epoch import TrackEpochCallback from allennlp.training.callbacks.wandb import WandBCallback diff --git a/allennlp/training/callbacks/sanity_checks.py b/allennlp/training/callbacks/confidence_checks.py similarity index 74% rename from allennlp/training/callbacks/sanity_checks.py rename to allennlp/training/callbacks/confidence_checks.py index 4e970883d08..e57a0a0a626 100644 --- a/allennlp/training/callbacks/sanity_checks.py +++ b/allennlp/training/callbacks/confidence_checks.py @@ -2,25 +2,27 @@ from allennlp.training.callbacks.callback import TrainerCallback from allennlp.data import TensorDict -from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification +from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification if TYPE_CHECKING: from allennlp.training.trainer import GradientDescentTrainer +# `sanity_checks` is deprecated and will be removed. @TrainerCallback.register("sanity_checks") -class SanityChecksCallback(TrainerCallback): +@TrainerCallback.register("confidence_checks") +class ConfidenceChecksCallback(TrainerCallback): """ - Performs model sanity checks. + Performs model confidence checks. Checks performed: * `NormalizationBiasVerification` for detecting invalid combinations of bias and normalization layers. - See `allennlp.sanity_checks.normalization_bias_verification` for more details. + See `allennlp.confidence_checks.normalization_bias_verification` for more details. - Note: Any new sanity checks should also be added to this callback. + Note: Any new confidence checks should also be added to this callback. """ def on_start( @@ -54,18 +56,18 @@ def on_batch( self._verification.destroy_hooks() detected_pairs = self._verification.collect_detections() if len(detected_pairs) > 0: - raise SanityCheckError( + raise ConfidenceCheckError( "The NormalizationBiasVerification check failed. See logs for more details." ) -class SanityCheckError(Exception): +class ConfidenceCheckError(Exception): """ - The error type raised when a sanity check fails. + The error type raised when a confidence check fails. """ def __init__(self, message) -> None: super().__init__( message - + "\nYou can disable these checks by setting the trainer parameter `run_sanity_checks` to `False`." + + "\nYou can disable these checks by setting the trainer parameter `run_confidence_checks` to `False`." ) diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py index 2d8f666eecf..54d9b59ffb1 100644 --- a/allennlp/training/trainer.py +++ b/allennlp/training/trainer.py @@ -5,6 +5,7 @@ import re import time import traceback +import warnings from contextlib import contextmanager from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type @@ -23,7 +24,11 @@ from allennlp.data import DataLoader, TensorDict from allennlp.models.model import Model from allennlp.training import util as training_util -from allennlp.training.callbacks import TrainerCallback, SanityChecksCallback, ConsoleLoggerCallback +from allennlp.training.callbacks import ( + TrainerCallback, + ConfidenceChecksCallback, + ConsoleLoggerCallback, +) from allennlp.training.checkpointer import Checkpointer from allennlp.training.learning_rate_schedulers import LearningRateScheduler from allennlp.training.metric_tracker import MetricTracker @@ -263,10 +268,13 @@ class GradientDescentTrainer(Trainer): addition to any other callbacks listed in the `callbacks` parameter. When set to `False`, `DEFAULT_CALLBACKS` are not used. + run_confidence_checks : `bool`, optional (default = `True`) + Determines whether model confidence checks, such as + [`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/), + are run. + run_sanity_checks : `bool`, optional (default = `True`) - Determines whether model sanity checks, such as - [`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/), - are ran. + This parameter is deprecated. Please use `run_confidence_checks` instead. """ @@ -294,7 +302,8 @@ def __init__( num_gradient_accumulation_steps: int = 1, use_amp: bool = False, enable_default_callbacks: bool = True, - run_sanity_checks: bool = True, + run_confidence_checks: bool = True, + **kwargs, ) -> None: super().__init__( serialization_dir=serialization_dir, @@ -304,6 +313,13 @@ def __init__( world_size=world_size, ) + if "run_sanity_checks" in kwargs: + warnings.warn( + "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.", + DeprecationWarning, + ) + run_confidence_checks = kwargs["run_sanity_checks"] + # I am not calling move_to_gpu here, because if the model is # not already on the GPU then the optimizer is going to be wrong. self.model = model @@ -345,8 +361,9 @@ def __init__( self._callbacks = callbacks or [] default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else [] - if run_sanity_checks: - default_callbacks.append(SanityChecksCallback) + + if run_confidence_checks: + default_callbacks.append(ConfidenceChecksCallback) for callback_cls in default_callbacks: for callback in self._callbacks: if callback.__class__ == callback_cls: @@ -1014,7 +1031,8 @@ def from_partial_objects( checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer), callbacks: List[Lazy[TrainerCallback]] = None, enable_default_callbacks: bool = True, - run_sanity_checks: bool = True, + run_confidence_checks: bool = True, + **kwargs, ) -> "Trainer": """ This method exists so that we can have a documented method to construct this class using @@ -1106,7 +1124,8 @@ def from_partial_objects( num_gradient_accumulation_steps=num_gradient_accumulation_steps, use_amp=use_amp, enable_default_callbacks=enable_default_callbacks, - run_sanity_checks=run_sanity_checks, + run_confidence_checks=run_confidence_checks, + **kwargs, ) diff --git a/tests/sanity_checks/normalization_bias_verification_test.py b/tests/confidence_checks/normalization_bias_verification_test.py similarity index 91% rename from tests/sanity_checks/normalization_bias_verification_test.py rename to tests/confidence_checks/normalization_bias_verification_test.py index 547242ddccf..8b4a2bae247 100644 --- a/tests/sanity_checks/normalization_bias_verification_test.py +++ b/tests/confidence_checks/normalization_bias_verification_test.py @@ -1,10 +1,10 @@ import torch from allennlp.common.testing import AllenNlpTestCase -from allennlp.common.testing.sanity_check_test import ( +from allennlp.common.testing.confidence_check_test import ( FakeModelForTestingNormalizationBiasVerification, ) -from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification +from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification class TestNormalizationBiasVerification(AllenNlpTestCase): diff --git a/tests/sanity_checks/task_checklists/__init__.py b/tests/confidence_checks/task_checklists/__init__.py similarity index 100% rename from tests/sanity_checks/task_checklists/__init__.py rename to tests/confidence_checks/task_checklists/__init__.py diff --git a/tests/sanity_checks/task_checklists/sentiment_analysis_suite_test.py b/tests/confidence_checks/task_checklists/sentiment_analysis_suite_test.py similarity index 89% rename from tests/sanity_checks/task_checklists/sentiment_analysis_suite_test.py rename to tests/confidence_checks/task_checklists/sentiment_analysis_suite_test.py index 92a075fa9b0..00284a9f77d 100644 --- a/tests/sanity_checks/task_checklists/sentiment_analysis_suite_test.py +++ b/tests/confidence_checks/task_checklists/sentiment_analysis_suite_test.py @@ -1,4 +1,6 @@ -from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import SentimentAnalysisSuite +from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import ( + SentimentAnalysisSuite, +) from allennlp.common.testing import AllenNlpTestCase, requires_gpu from allennlp.models.archival import load_archive from allennlp.predictors import Predictor diff --git a/tests/sanity_checks/task_checklists/task_suite_test.py b/tests/confidence_checks/task_checklists/task_suite_test.py similarity index 96% rename from tests/sanity_checks/task_checklists/task_suite_test.py rename to tests/confidence_checks/task_checklists/task_suite_test.py index 84623511f77..71f9a843650 100644 --- a/tests/sanity_checks/task_checklists/task_suite_test.py +++ b/tests/confidence_checks/task_checklists/task_suite_test.py @@ -1,5 +1,5 @@ import pytest -from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite +from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite from allennlp.common.testing import AllenNlpTestCase from allennlp.common.checks import ConfigurationError from allennlp.models.archival import load_archive diff --git a/tests/sanity_checks/task_checklists/utils_test.py b/tests/confidence_checks/task_checklists/utils_test.py similarity index 86% rename from tests/sanity_checks/task_checklists/utils_test.py rename to tests/confidence_checks/task_checklists/utils_test.py index ce6e17eb902..bf5eb94697f 100644 --- a/tests/sanity_checks/task_checklists/utils_test.py +++ b/tests/confidence_checks/task_checklists/utils_test.py @@ -1,4 +1,4 @@ -from allennlp.sanity_checks.task_checklists import utils +from allennlp.confidence_checks.task_checklists import utils from allennlp.common.testing import AllenNlpTestCase diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index 9028dd27c71..3926adf0ec2 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -29,10 +29,10 @@ TrainerCallback, TrackEpochCallback, TensorBoardCallback, - SanityChecksCallback, + ConfidenceChecksCallback, ConsoleLoggerCallback, ) -from allennlp.training.callbacks.sanity_checks import SanityCheckError +from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError from allennlp.training.learning_rate_schedulers import CosineWithRestarts from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler from allennlp.training.momentum_schedulers import MomentumScheduler @@ -49,7 +49,7 @@ TensorField, ) from allennlp.training.optimizers import Optimizer -from allennlp.common.testing.sanity_check_test import ( +from allennlp.common.testing.confidence_check_test import ( FakeModelForTestingNormalizationBiasVerification, ) @@ -814,7 +814,7 @@ def test_trainer_can_log_learning_rates_tensorboard(self): trainer.train() - def test_sanity_check_callback(self): + def test_confidence_check_callback(self): model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True) inst = Instance({"x": TensorField(torch.rand(3, 1, 4))}) data_loader = SimpleDataLoader([inst, inst], 2) @@ -824,12 +824,12 @@ def test_sanity_check_callback(self): data_loader, num_epochs=1, serialization_dir=self.TEST_DIR, - callbacks=[SanityChecksCallback(serialization_dir=self.TEST_DIR)], + callbacks=[ConfidenceChecksCallback(serialization_dir=self.TEST_DIR)], ) - with pytest.raises(SanityCheckError): + with pytest.raises(ConfidenceCheckError): trainer.train() - def test_sanity_check_default(self): + def test_confidence_check_default(self): model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True) inst = Instance({"x": TensorField(torch.rand(3, 1, 4))}) data_loader = SimpleDataLoader([inst, inst], 2) @@ -839,7 +839,7 @@ def test_sanity_check_default(self): data_loader=data_loader, num_epochs=1, ) - with pytest.raises(SanityCheckError): + with pytest.raises(ConfidenceCheckError): trainer.train() trainer = GradientDescentTrainer.from_partial_objects( @@ -847,7 +847,7 @@ def test_sanity_check_default(self): serialization_dir=self.TEST_DIR, data_loader=data_loader, num_epochs=1, - run_sanity_checks=False, + run_confidence_checks=False, ) # Check is not run, so no failure. From 047ae341b0738ab7e48cb8cef11d0a0c926bb89c Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Mon, 17 May 2021 12:25:42 -0700 Subject: [PATCH 24/63] Changes and improvements to how we initialize transformer modules from pretrained models (#5200) * updates * rename 'load_state_dict' -> 'read_state_dict' * fix TransformerStack * more fixes * fix embeddings * fix toolkit tests * fix self attention * fix bimodal encoder tests * fix more tests * fix T5! * fixes * fix backbone * fix * fixes * fix * doc fixes * name changes * patch models branch temporarily * update CHANGELOG * change default dist loading strategy to 'MEM_EFFICIENT' for T5 * fix distilbert test * always use memory efficient distributed loading strategy * Update .github/workflows/ci.yml Co-authored-by: Pete <petew@allenai.org> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> --- CHANGELOG.md | 4 + allennlp/commands/diff.py | 6 +- allennlp/common/testing/distributed_test.py | 9 +- allennlp/common/util.py | 12 + allennlp/models/model.py | 2 +- .../modules/backbones/vilbert_backbone.py | 52 +- allennlp/modules/transformer/__init__.py | 2 +- .../modules/transformer/bimodal_attention.py | 5 +- .../transformer/bimodal_connection_layer.py | 2 +- .../modules/transformer/bimodal_encoder.py | 110 +--- allennlp/modules/transformer/layer_norm.py | 7 + allennlp/modules/transformer/output_layer.py | 5 +- .../transformer/positional_encoding.py | 3 + .../modules/transformer/self_attention.py | 70 +-- allennlp/modules/transformer/t5.py | 66 ++- .../transformer/transformer_embeddings.py | 66 ++- .../modules/transformer/transformer_layer.py | 76 +-- .../modules/transformer/transformer_module.py | 495 ++++++++++------ .../modules/transformer/transformer_stack.py | 85 +-- allennlp/nn/util.py | 217 ++++++- scripts/py2md.py | 7 + .../transformer/activation_layer_test.py | 38 +- .../transformer/bimodal_attention_test.py | 103 ++-- .../transformer/bimodal_encoder_test.py | 181 +++--- .../transformer/self_attention_test.py | 193 ++----- tests/modules/transformer/toolkit_test.py | 68 ++- .../transformer_embeddings_test.py | 539 +++++++++--------- .../transformer/transformer_layer_test.py | 529 ++++++++--------- .../transformer/transformer_module_test.py | 81 +-- .../transformer/transformer_stack_test.py | 253 +++----- tests/nn/util_test.py | 80 ++- 31 files changed, 1708 insertions(+), 1658 deletions(-) create mode 100644 allennlp/modules/transformer/layer_norm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a55610dea5e..92308cd9c2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `dist_reduce_sum` in distributed metrics. - Allow Google Cloud Storage paths in `cached_path` ("gs://..."). +- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`. +- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of + an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). @@ -18,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. +- Added `nn.util.distributed_device()` helper function. - Added `allennlp.nn.util.load_state_dict` helper function. - Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`. diff --git a/allennlp/commands/diff.py b/allennlp/commands/diff.py index 6d86f7db76f..35738ca2237 100644 --- a/allennlp/commands/diff.py +++ b/allennlp/commands/diff.py @@ -19,7 +19,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common.file_utils import cached_path -from allennlp.nn.util import load_state_dict +from allennlp.nn.util import read_state_dict logger = logging.getLogger(__name__) @@ -249,10 +249,10 @@ def _get_checkpoint_path(checkpoint: str) -> str: def _diff(args: argparse.Namespace): checkpoint_1_path = _get_checkpoint_path(args.checkpoint1) checkpoint_2_path = _get_checkpoint_path(args.checkpoint2) - checkpoint_1 = load_state_dict( + checkpoint_1 = read_state_dict( checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False ) - checkpoint_2 = load_state_dict( + checkpoint_2 = read_state_dict( checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False ) for step in checkpoint_diff(checkpoint_1, checkpoint_2, args.scale, args.threshold): diff --git a/allennlp/common/testing/distributed_test.py b/allennlp/common/testing/distributed_test.py index 7ef00e2e0e8..2fae00ff635 100644 --- a/allennlp/common/testing/distributed_test.py +++ b/allennlp/common/testing/distributed_test.py @@ -61,12 +61,19 @@ def run_distributed_test( func: `Callable` `func` needs to be global for spawning the processes, so that it can be pickled. + + start_method: `Optional[str]`, optional (default = `None`) + The start method to use for starting the workers. Defaults to "spawn" for GPU + processes and fork otherwise. """ device_ids = device_ids or [-1, -1] check_for_gpu(device_ids) # "fork" start method is the default and should be preferred, except when we're # running the tests on GPU, in which case we need to use "spawn". - start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" + if "start_method" in kwargs: + start_method = kwargs.pop("start_method") + else: + start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" nprocs = world_size = len(device_ids) mp.start_processes( init_process, diff --git a/allennlp/common/util.py b/allennlp/common/util.py index db77d795e8d..4db2ef6b5fe 100644 --- a/allennlp/common/util.py +++ b/allennlp/common/util.py @@ -509,6 +509,18 @@ def is_distributed() -> bool: return dist.is_available() and dist.is_initialized() +def is_global_primary() -> bool: + """ + Checks if the distributed process group is the global primary (rank = 0). + If the distributed process group is not available or has not been initialized, + this trivially returns `True`. + """ + if not is_distributed(): + return True + else: + return dist.get_rank() == 0 + + def sanitize_wordpiece(wordpiece: str) -> str: """ Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers. diff --git a/allennlp/models/model.py b/allennlp/models/model.py index 5ff7c967e8e..2800243a6a1 100644 --- a/allennlp/models/model.py +++ b/allennlp/models/model.py @@ -335,7 +335,7 @@ def _load( # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError # if the state dict is missing keys because we handle this case below. - model_state = util.load_state_dict(weights_file, cuda_device=cuda_device) + model_state = util.read_state_dict(weights_file, cuda_device=cuda_device) missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False) # Modules might define a class variable called `authorized_missing_keys`, diff --git a/allennlp/modules/backbones/vilbert_backbone.py b/allennlp/modules/backbones/vilbert_backbone.py index c1b9d1090b7..0f554a7a1d2 100644 --- a/allennlp/modules/backbones/vilbert_backbone.py +++ b/allennlp/modules/backbones/vilbert_backbone.py @@ -7,7 +7,12 @@ from allennlp.data.fields.text_field import TextFieldTensors from allennlp.data.vocabulary import Vocabulary from allennlp.modules.backbones.backbone import Backbone -from allennlp.modules.transformer import BiModalEncoder, ImageFeatureEmbeddings, Embeddings +from allennlp.modules.transformer import ( + BiModalEncoder, + ImageFeatureEmbeddings, + TransformerEmbeddings, + TransformerPooler, +) logger = logging.getLogger(__name__) @@ -23,7 +28,7 @@ class VilbertBackbone(Backbone): def __init__( self, vocab: Vocabulary, - text_embeddings: Embeddings, + text_embeddings: TransformerEmbeddings, image_embeddings: ImageFeatureEmbeddings, encoder: BiModalEncoder, pooled_output_dim: int, @@ -36,7 +41,6 @@ def __init__( self.text_embeddings = text_embeddings self.image_embeddings = image_embeddings self.encoder = encoder - from allennlp.modules.transformer import TransformerPooler self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim) self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim) @@ -66,44 +70,7 @@ def from_huggingface_model_name( image_fixed_layer: int, fusion_method: str = "sum", ): - from transformers import AutoModel - - transformer = AutoModel.from_pretrained(model_name) - - from copy import deepcopy - - text_embeddings = deepcopy(transformer.embeddings) - - # Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size". - # To get them to the same dimensionality, it uses a linear transform after the embedding - # layer, which we need to pull out and copy here. - if hasattr(transformer.config, "embedding_size"): - config = transformer.config - - from transformers.models.albert.modeling_albert import AlbertModel - - if isinstance(transformer, AlbertModel): - linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in) - else: - logger.warning( - "Unknown model that uses separate embedding size; weights of the linear " - f"transform will not be initialized. Model type is: {transformer.__class__}" - ) - linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim) - - # We can't just use torch.nn.Sequential here, even though that's basically all this is, - # because Sequential doesn't accept *inputs, only a single argument. - - class EmbeddingsShim(torch.nn.Module): - def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module): - super().__init__() - self.linear_transform = linear_transform - self.embeddings = embeddings - - def forward(self, *inputs, **kwargs): - return self.linear_transform(self.embeddings(*inputs, **kwargs)) - - text_embeddings = EmbeddingsShim(text_embeddings, linear_transform) + text_embeddings = TransformerEmbeddings.from_pretrained_module(model_name) image_embeddings = ImageFeatureEmbeddings( feature_size=image_feature_dim, @@ -112,7 +79,7 @@ def forward(self, *inputs, **kwargs): ) encoder = BiModalEncoder.from_pretrained_module( - pretrained_module=transformer, + model_name, num_hidden_layers2=image_num_hidden_layers, hidden_size2=image_hidden_size, num_attention_heads2=image_num_attention_heads, @@ -126,6 +93,7 @@ def forward(self, *inputs, **kwargs): fixed_layer1=text_fixed_layer, fixed_layer2=image_fixed_layer, ) + return cls( vocab=vocab, text_embeddings=text_embeddings, diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index b0b56b90d17..9b944130c7c 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -123,8 +123,8 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): ``` """ +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding - from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.transformer_embeddings import ( Embeddings, diff --git a/allennlp/modules/transformer/bimodal_attention.py b/allennlp/modules/transformer/bimodal_attention.py index fc6bb4047f9..cc4bf11aa22 100644 --- a/allennlp/modules/transformer/bimodal_attention.py +++ b/allennlp/modules/transformer/bimodal_attention.py @@ -118,10 +118,12 @@ def forward( input_tensor2, attention_mask1=None, attention_mask2=None, - co_attention_mask=None, + co_attention_mask=None, # TODO: is this flag necessary? use_co_attention_mask=False, ): """ + # Parameters + input_tensor1 : `torch.Tensor` Shape `batch_size x seq_len1 x hidden_dim1` where `seq_len1` can be the sequence length @@ -143,7 +145,6 @@ def forward( if you know which words correspond to which regions in the image, this mask can be applied to limit the attention given the bias. use_co_attention_mask : `bool` - # TODO: is this flag necessary? Whether to use co_attention_mask or not, default = `False`. """ diff --git a/allennlp/modules/transformer/bimodal_connection_layer.py b/allennlp/modules/transformer/bimodal_connection_layer.py index 5d7e4f7fc88..f9656c2b7a5 100644 --- a/allennlp/modules/transformer/bimodal_connection_layer.py +++ b/allennlp/modules/transformer/bimodal_connection_layer.py @@ -31,7 +31,7 @@ def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2): class BiModalConnectionLayer(TransformerModule, FromParams): - _huggingface_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} + _pretrained_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} def __init__( self, diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index bf5e732e96d..acc993194df 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -1,14 +1,16 @@ -from typing import Optional, Dict, List, Union +from typing import Optional, List, TYPE_CHECKING + import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers - from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.bimodal_connection_layer import BiModalConnectionLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class BiModalEncoder(TransformerModule, FromParams): """ @@ -46,8 +48,9 @@ class BiModalEncoder(TransformerModule, FromParams): in_batch_pairs: `bool` (default = `False`) """ - _huggingface_mapping = {"layer": "layers1"} - _relevant_module = "encoder" + _pretrained_mapping = {"layer": "layers1"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] + _pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"] def __init__( self, @@ -243,93 +246,14 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. - """ - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["num_hidden_layers1"] = len(submodules["layers1"]) - - final_kwargs["hidden_size1"] = submodules["layers1.0.attention.self.query"].in_features - final_kwargs["num_attention_heads1"] = submodules[ - "layers1.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout1"] = submodules["layers1.0.attention.self.dropout"].p - final_kwargs["hidden_dropout1"] = submodules["layers1.0.attention.output.dropout"].p - final_kwargs["intermediate_size1"] = submodules["layers1.0.intermediate.dense"].out_features - final_kwargs["activation"] = submodules["layers1.0.intermediate"].intermediate_act_fn - + final_kwargs["num_hidden_layers1"] = config.num_hidden_layers + final_kwargs["hidden_size1"] = config.hidden_size + final_kwargs["num_attention_heads1"] = config.num_attention_heads + final_kwargs["attention_dropout1"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout1"] = config.hidden_dropout_prob + final_kwargs["intermediate_size1"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act final_kwargs.update(**kwargs) - - return final_kwargs - - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): - if source == "huggingface": - ignore_absent_parameters = ["layers2", "c_layer"] - super()._load_from_pretrained_module( - pretrained_module, source, mapping, ignore_absent_parameters - ) - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers2: int, - hidden_size2: int, - combined_hidden_size: int, - intermediate_size2: int, - num_attention_heads2: int, - combined_num_attention_heads: int, - attention_dropout2: float, - hidden_dropout2: float, - biattention_id1: List[int], - biattention_id2: List[int], - fixed_layer1: int, - fixed_layer2: int, - fast_mode: bool = False, - with_coattention: bool = True, - in_batch_pairs: bool = False, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - # **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. - """ - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping - ) - final_kwargs = {} - final_kwargs.update(cls._get_input_arguments(pretrained_module, source, mapping)) - final_kwargs["num_hidden_layers2"] = num_hidden_layers2 - final_kwargs["hidden_size2"] = hidden_size2 - final_kwargs["combined_hidden_size"] = combined_hidden_size - final_kwargs["intermediate_size2"] = intermediate_size2 - final_kwargs["num_attention_heads2"] = num_attention_heads2 - final_kwargs["combined_num_attention_heads"] = combined_num_attention_heads - final_kwargs["attention_dropout2"] = attention_dropout2 - final_kwargs["hidden_dropout2"] = hidden_dropout2 - final_kwargs["biattention_id1"] = biattention_id1 - final_kwargs["biattention_id2"] = biattention_id2 - final_kwargs["fixed_layer1"] = fixed_layer1 - final_kwargs["fixed_layer2"] = fixed_layer2 - final_kwargs["fast_mode"] = fast_mode - final_kwargs["with_coattention"] = with_coattention - final_kwargs["in_batch_pairs"] = in_batch_pairs - - return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs) + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/layer_norm.py b/allennlp/modules/transformer/layer_norm.py new file mode 100644 index 00000000000..0302b705c1d --- /dev/null +++ b/allennlp/modules/transformer/layer_norm.py @@ -0,0 +1,7 @@ +import torch + +from allennlp.modules.transformer.transformer_module import TransformerModule + + +class LayerNorm(torch.nn.LayerNorm, TransformerModule): + _pretrained_mapping = {"gamma": "weight", "beta": "bias"} diff --git a/allennlp/modules/transformer/output_layer.py b/allennlp/modules/transformer/output_layer.py index 03dd1f9d5df..ac38a1794b1 100644 --- a/allennlp/modules/transformer/output_layer.py +++ b/allennlp/modules/transformer/output_layer.py @@ -3,16 +3,17 @@ from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.layer_norm import LayerNorm class OutputLayer(TransformerModule, FromParams): - _huggingface_mapping = {"LayerNorm": "layer_norm"} + _pretrained_mapping = {"LayerNorm": "layer_norm"} def __init__(self, input_size: int, hidden_size: int, dropout: float): super().__init__() self.dense = torch.nn.Linear(input_size, hidden_size) - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.layer_norm = LayerNorm(hidden_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, hidden_states, input_tensor): diff --git a/allennlp/modules/transformer/positional_encoding.py b/allennlp/modules/transformer/positional_encoding.py index 1cf63b15c91..b0abc2b91b2 100644 --- a/allennlp/modules/transformer/positional_encoding.py +++ b/allennlp/modules/transformer/positional_encoding.py @@ -42,6 +42,9 @@ def __init__(self, min_timescale: float = 1.0, max_timescale: float = 1.0e4): self.max_timescale = max_timescale def forward(self, input_tensor: torch.Tensor): + """ + Adds a positional encoding to `input_tensor`. + """ # TODO: Another option is to specify the expected size in init, so that we can construct # the positional encoding beforehand, and simply add it to the input tensor in forward. _, timesteps, hidden_dim = input_tensor.size() diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 6db6aba1fad..d464012de81 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -1,4 +1,5 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING + import torch from allennlp.common import FromParams @@ -6,6 +7,9 @@ from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.util import apply_mask +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class SelfAttention(TransformerModule, FromParams): """ @@ -25,8 +29,15 @@ class SelfAttention(TransformerModule, FromParams): Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. """ - _relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _huggingface_mapping = {"layer": "layers"} + _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] + _pretrained_mapping = { + "layer": "layers", + "q_lin": "query", + "k_lin": "key", + "v_lin": "value", + "out_lin": "output", + "transformer": "encoder", + } def __init__( self, @@ -83,6 +94,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + query_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` key_states : `torch.Tensor`, optional @@ -133,47 +146,16 @@ def forward( return outputs @classmethod - def _get_mapping( - cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None - ): - combined_mapping = {} - if "huggingface" in source: - combined_mapping.update(cls._huggingface_mapping) - if mapping is not None: - combined_mapping.update(mapping) - if pretrained_module is not None: - for name, _ in pretrained_module.named_modules(): - if "q_lin" in name: - combined_mapping["q_lin"] = "query" - combined_mapping["k_lin"] = "key" - combined_mapping["v_lin"] = "value" - combined_mapping["out_lin"] = "output" - combined_mapping["transformer"] = "encoder" - break - return combined_mapping - - @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["query"].in_features - if hasattr(submodules[""], "num_attention_heads"): - final_kwargs["num_attention_heads"] = submodules[""].num_attention_heads - elif hasattr(submodules[""], "n_heads"): - final_kwargs["num_attention_heads"] = submodules[""].n_heads - final_kwargs["output_linear"] = True # Since this is the distilbert case. + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["output_linear"] = hasattr( + config, "n_heads" + ) # Since this is the distilbert case. + if hasattr(config, "attention_dropout"): + final_kwargs["dropout"] = config.attention_dropout else: - raise AttributeError("Cannot find a relevant attribute for number of heads.") - - final_kwargs["dropout"] = submodules["dropout"].p - + final_kwargs["dropout"] = config.attention_probs_dropout_prob final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 83305487b76..15d34f5b2b1 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -1,11 +1,11 @@ """ -Adapted from [HuggingFace] +An implementation of [T5](https://api.semanticscholar.org/CorpusID:204838007), adapted from [HuggingFace] (https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py). """ # noqa: E401 import math from dataclasses import dataclass -from typing import Optional, Tuple, List, Union, Dict, Any +from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING import torch from torch import nn @@ -14,13 +14,18 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError -from allennlp.modules.transformer import TransformerModule +from allennlp.modules.transformer.transformer_module import ( + TransformerModule, +) from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, ) from allennlp.nn.beam_search import BeamSearch +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + # Unfortunately mypy is insane, so I have to wrap these in unions. FloatT = Union[torch.FloatTensor] IntT = Union[torch.IntTensor] @@ -94,7 +99,7 @@ def forward(self, hidden_states) -> FloatT: class T5LayerFF(TransformerModule, FromParams): - _huggingface_mapping = {"DenseReluDense": "ff_proj"} + _pretrained_mapping = {"DenseReluDense": "ff_proj"} def __init__( self, @@ -376,16 +381,19 @@ class T5LayerSelfAttentionOutput: class T5LayerSelfAttention(TransformerModule, FromParams): - _huggingface_mapping = {"SelfAttention": "self_attention"} + _pretrained_mapping = {"SelfAttention": "self_attention"} def __init__( self, self_attention: Optional[T5Attention] = None, layer_norm: Optional[T5LayerNorm] = None, dropout: float = 0.1, + has_relative_attention_bias: bool = False, ): super().__init__() - self.self_attention = self_attention or T5Attention() + self.self_attention = self_attention or T5Attention( + has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -427,7 +435,7 @@ class T5LayerCrossAttentionOutput: class T5LayerCrossAttention(TransformerModule, FromParams): - _huggingface_mapping = {"EncDecAttention": "enc_dec_attention"} + _pretrained_mapping = {"EncDecAttention": "enc_dec_attention"} def __init__( self, @@ -618,7 +626,7 @@ class T5StackOutput: class T5Stack(TransformerModule, FromParams): - _huggingface_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} + _pretrained_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} def __init__( self, @@ -959,7 +967,18 @@ class T5Output: class T5(TransformerModule, Registrable): - _huggingface_mapping = {"shared": "token_embeddings"} + _pretrained_mapping = {"shared": "token_embeddings"} + _tied_weights = { + "token_embeddings.weight": [ + "encoder.token_embeddings.weight", + "decoder.token_embeddings.weight", + "lm_head.weight", + ] + } + # Don't know why HF has this param in their state_dict. It's not used in their model. + _pretrained_ignore = [ + r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" + ] default_implementation = "default" @@ -1003,16 +1022,7 @@ def __init__( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: - from transformers.models.t5 import T5Config - - config: T5Config = pretrained_module.config + def _from_config(cls, config: "PretrainedConfig", **kwargs): attention_kwargs = { "hidden_size": config.d_model, "key_value_proj_dim": config.d_kv, @@ -1039,8 +1049,8 @@ def _get_input_arguments( } ), ) - return { - "encoder": Lazy( + return cls( + encoder=Lazy( T5EncoderStack.basic_encoder, contructor_extras={ "num_blocks": config.num_layers, @@ -1050,7 +1060,7 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder": Lazy( + decoder=Lazy( T5DecoderStack.basic_decoder, contructor_extras={ "num_blocks": config.num_decoder_layers, @@ -1061,12 +1071,12 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder_start_token_id": config.decoder_start_token_id, - "pad_token_id": config.pad_token_id, - "eos_token_id": config.eos_token_id, - "vocab_size": config.vocab_size, - "model_dim": config.d_model, - } + decoder_start_token_id=config.decoder_start_token_id, + pad_token_id=config.pad_token_id, + eos_token_id=config.eos_token_id, + vocab_size=config.vocab_size, + model_dim=config.d_model, + ) def _shift_right(self, input_ids, start_value: int): # shift inputs to the right diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 754344d1c0e..3712d9b0a3a 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -1,11 +1,14 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class Embeddings(TransformerModule, FromParams): """ @@ -38,7 +41,7 @@ def __init__(self, embeddings: torch.nn.ModuleDict, embedding_size: int, dropout ) ) self.embeddings = embeddings - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.layer_norm = LayerNorm(embedding_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, *inputs) -> torch.Tensor: @@ -101,13 +104,27 @@ class TransformerEmbeddings(Embeddings): Optionally apply a linear transform after the dropout, projecting to `output_size`. """ - _relevant_module = "embeddings" - _huggingface_mapping = { + _pretrained_relevant_module = ["embeddings", "bert.embeddings"] + _pretrained_mapping = { "LayerNorm": "layer_norm", "word_embeddings": "embeddings.word_embeddings", "position_embeddings": "embeddings.position_embeddings", "token_type_embeddings": "embeddings.token_type_embeddings", + # Albert is a special case. A linear projection is applied to the embeddings, + # but that linear transformation lives in the encoder. + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.word_embeddings": "embeddings.word_embeddings", + "albert.embeddings.position_embeddings": "embeddings.position_embeddings", + "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", + "albert.encoder.embedding_hidden_mapping_in": "linear_transform", } + _pretrained_ignore = [ + # Ignore these for Albert case. + r"^albert\.pooler\..*", + r"^albert\.encoder\.albert_layer_groups\..*", + r"^predictions\.*", + ] def __init__( self, @@ -149,6 +166,7 @@ def forward( # type: ignore ) -> torch.Tensor: """ + # Parameters input_ids : `torch.Tensor` Shape `batch_size x seq_len` token_type_ids : `torch.Tensor`, optional @@ -182,32 +200,18 @@ def forward( # type: ignore return embeddings @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["vocab_size"] = submodules["embeddings.word_embeddings"].num_embeddings - final_kwargs["embedding_size"] = submodules["embeddings.word_embeddings"].embedding_dim - final_kwargs["pad_token_id"] = submodules["embeddings.word_embeddings"].padding_idx - final_kwargs["max_position_embeddings"] = submodules[ - "embeddings.position_embeddings" - ].num_embeddings - - if "embeddings.token_type_embeddings" in submodules: - final_kwargs["type_vocab_size"] = submodules[ - "embeddings.token_type_embeddings" - ].num_embeddings - + final_kwargs["vocab_size"] = config.vocab_size + # For Albert, the embedding size is different than the hidden size used + # in the model, so a linear transform is applied. + if hasattr(config, "embedding_size"): + final_kwargs["embedding_size"] = config.embedding_size + final_kwargs["output_size"] = config.hidden_size else: - final_kwargs["type_vocab_size"] = 0 - + final_kwargs["embedding_size"] = config.hidden_size + final_kwargs["pad_token_id"] = config.pad_token_id + final_kwargs["max_position_embeddings"] = config.max_position_embeddings + final_kwargs["type_vocab_size"] = config.type_vocab_size final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 3282b2dbf14..43a76d33144 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -1,15 +1,16 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - from allennlp.modules.transformer.transformer_module import TransformerModule - from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.self_attention import SelfAttention from allennlp.modules.transformer.output_layer import OutputLayer +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class AttentionLayer(TransformerModule, FromParams): """ @@ -28,8 +29,8 @@ class AttentionLayer(TransformerModule, FromParams): Dropout probability for the `OutputLayer`. """ - _relevant_module = "encoder.layers.0.attention" - _huggingface_mapping = {"layer": "layers"} + _pretrained_relevant_module = "encoder.layer.0.attention" + _pretrained_mapping = {"layer": "layers"} def __init__( self, @@ -52,6 +53,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + input_tensor : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -77,25 +80,16 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - final_kwargs["hidden_size"] = submodules["self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["output.dropout"].p + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) class TransformerLayer(TransformerModule, FromParams): @@ -120,8 +114,8 @@ class TransformerLayer(TransformerModule, FromParams): This is helpful when using the layer in a decoder. """ - _relevant_module = "encoder.layers.0" - _huggingface_mapping = { + _pretrained_relevant_module = "encoder.layer.0" + _pretrained_mapping = { "layer": "layers", "intermediate_act_fn": "act_fn", "crossattention": "cross_attention", @@ -174,6 +168,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -218,32 +214,14 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["attention.self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. - if source == "huggingface": - final_kwargs["activation"] = getattr(submodules["intermediate"], "intermediate_act_fn") - else: - final_kwargs["activation"] = getattr(submodules["intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "cross_attention" in submodules - + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act + final_kwargs["add_cross_attention"] = config.add_cross_attention final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 861120deca2..2a0ffa092ce 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -1,229 +1,382 @@ -from typing import Optional, Dict, Union, List, Any import logging -import inspect +import os +from os import PathLike +from typing import TYPE_CHECKING, Optional, Dict, Union, List, Any, TypeVar, Type +import re +import warnings import torch +import torch.distributed as dist + +from allennlp.common.util import is_distributed, is_global_primary +from allennlp.nn.util import StateDictType, read_state_dict, load_state_dict_distributed + +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig -from allennlp.common import cached_transformers logger = logging.getLogger(__name__) +_T = TypeVar("_T", bound="TransformerModule") + + class TransformerModule(torch.nn.Module): """ Base class to help with generalized loading of pretrained weights. - `_huggingface_mapping` is an optional mapping for each class, that determines - any differences in the module names between the class modules and the huggingface model's - modules. + Subclasses should override `_from_config()` if you want to instantiate them with + `from_pretrained_module()`. + """ + + _pretrained_mapping: Dict[str, str] = {} + """ + An optional mapping for each class that determines any differences in the module + names between the class modules and the HuggingFace model's modules. + Keys correspond to HuggingFace submodule names, values correspond to submodules names of this module. + """ - `_relevant_module` is an optional str or list of str which contains the expected name of the module - in the huggingface pretrained model. It can be a list to account for different names in different + _pretrained_relevant_module: Optional[Union[str, List[str]]] = None + """ + An optional string or list of strings which contains the expected name of the module + in the HuggingFace pretrained model. It can be a list to account for different names in different models. The search is carried out in the order of the list. """ - _huggingface_mapping: Dict[str, str] = {} - _relevant_module: Optional[Union[str, List[str]]] = None + _pretrained_ignore: Optional[List[str]] = None + """ + An optional list of regular expressions that define which weights to ignore from a pretrained state_dict. + """ + + _pretrained_allow_missing: Optional[List[str]] = None + """ + An optional list of regular expressions that specifies which weights are allowed to be missing + from a pretrained state dictionary. + """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + _tied_weights: Optional[Dict[str, List[str]]] = None + """ + A mapping that defines any weights that need to be tied. Keys and values are parameter names. + The values will be tied to the corresponding key. + """ @classmethod def _get_mapping( cls, - pretrained_module: Optional[torch.nn.Module] = None, - source: str = "huggingface", mapping: Optional[Dict[str, str]] = None, ): """ - Returns the mapping to be used, based on the optional `pretrained_module`. - If `pretrained_module` is not given, the default module-level mapping is returned. + Returns the mapping to be used, based on the optional `mapping` overrides + and the default module-level mapping. """ combined_mapping = {} - if "huggingface" == source: - combined_mapping.update(cls._huggingface_mapping) + combined_mapping.update(cls._pretrained_mapping) if mapping is not None: combined_mapping.update(mapping) return combined_mapping - @classmethod - def _get_mapped_submodules( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - ): - """ - Subclasses overload this method, and provide appropriate name mapping based on the source. - """ - submodules = dict(pretrained_module.named_modules()) - combined_mapping = cls._get_mapping(pretrained_module, source, mapping) - for name, module in pretrained_module.named_modules(): - newname = name - for key, val in combined_mapping.items(): - newname = newname.replace(key, val) - submodules[newname] = submodules.pop(name) - return submodules - - def _construct_default_mapping( + def _get_mapped_state_dict( self, - pretrained_module: torch.nn.Module, - source: str = "huggingface", + state_dict: StateDictType, mapping: Optional[Dict[str, str]] = None, - ): + ) -> StateDictType: """ - Recursively constructs the default mapping of parameter names for loading pretrained module weights. - Keys are parameter names from this module, and values are corresponding parameter names in the - expected pretrained module, as per `source`. + Recursively map keys in a HuggingFace `state_dict` to the corresponding keys + for this module and all submodules. """ - combined_mapping = self._get_mapping(pretrained_module, source, mapping) - for name, module in self.named_modules(): - if name != "": - if hasattr(module, "_construct_default_mapping"): - # We handle collisions by giving priority to the outer module's mapping. - combined_mapping = dict( - list( - module._construct_default_mapping( - pretrained_module, source, combined_mapping - ).items() - ) - + list(combined_mapping.items()) - ) - return combined_mapping + return _get_mapped_state_dict(self, state_dict, mapping=mapping) - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): + @classmethod + def _get_relevant_submodule_state( + cls, + state_dict: StateDictType, + relevant_module: Optional[Union[str, List[str]]] = None, + ) -> StateDictType: """ - Loads the weights of the `pretrained_module` into the instance. - Optionally, a `mapping` is specified for any differences in parameter names - between `pretrained_module` and the instance. + Returns the relevant part of the `state_dict`. """ - ignore_absent_parameters = ignore_absent_parameters or [] - combined_mapping = self._construct_default_mapping(pretrained_module, source, mapping) - if mapping is not None: - combined_mapping.update(mapping) + relevant_modules: Optional[List[str]] = None + if relevant_module: + relevant_modules = ( + [relevant_module] if isinstance(relevant_module, str) else relevant_module + ) + elif isinstance(cls._pretrained_relevant_module, str): + relevant_modules = [cls._pretrained_relevant_module] + elif isinstance(cls._pretrained_relevant_module, list): + relevant_modules = cls._pretrained_relevant_module - inverse_mapping = {val: key for key, val in combined_mapping.items()} - pretrained_parameters = dict(pretrained_module.named_parameters()) - for name, parameter in self.named_parameters(): - pretrained_name = name - for key, val in inverse_mapping.items(): - # so that we replace the names of submodules too. - # eg. module.key.anothermodule --> module.val.anothermodule - pretrained_name = pretrained_name.replace(key, val) - if not any( - [pretrained_name.startswith(paraname) for paraname in ignore_absent_parameters] - ): - if pretrained_name not in pretrained_parameters: - raise ValueError( - f"Couldn't find a matching parameter for {name}. Is this module " - "compatible with the pretrained module you're using?" - ) - parameter.data.copy_(pretrained_parameters[pretrained_name].data) + if relevant_modules: + found = False + for module_name in relevant_modules: + relevant_keys = set( + [key for key in state_dict.keys() if key.startswith(module_name + ".")] + ) + if relevant_keys: + # Only keep elements of state dict that correspond to the relevant module. + state_dict = { + key.replace(module_name + ".", "", 1): value + for key, value in state_dict.items() + if key in relevant_keys + } + found = True + break + + if not found: + warnings.warn( + f"{relevant_modules} was not found at top level of state_dict!", UserWarning + ) + + return state_dict @classmethod - def _get_input_arguments( + def _get_pretrained_state_dict( cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: + model_name: str, + weights_path: Optional[Union[str, PathLike]] = None, + relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, + ) -> StateDictType: """ - Constructs the arguments required for instantiating an object of this class, using - the values from `pretrained_module`. + Get a HuggingFace pretrained `state_dict` corresponding to this module. """ - return kwargs + if weights_path is None: + from transformers.file_utils import WEIGHTS_NAME + + # First see if we can find the weights locally. + if os.path.isdir(model_name): + local_weights_path = os.path.join(model_name, WEIGHTS_NAME) + if os.path.isfile(local_weights_path): + logger.info("Found weights at local path %s", local_weights_path) + weights_path = local_weights_path + + # If we haven't found locally, we assume model ID corresponds to a model + # on the HuggingFace Hub. + if weights_path is None: + from allennlp.common.file_utils import cached_path + + weights_path = cached_path(f"hf://{model_name}/{WEIGHTS_NAME}") + + # Now load the state dict. + logger.info("Reading state dict from %s", weights_path) + state_dict = read_state_dict( + weights_path, + ignore=ignore if ignore is not None else cls._pretrained_ignore, + strict=False, + ) + + # Keep just the relevant_module, remove everything else. + state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module) + + return state_dict @classmethod - def get_relevant_module( - cls, - pretrained_module: Union[str, torch.nn.Module], - relevant_module: Optional[Union[str, List[str]]] = None, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, + def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: + """ + Instantiate this module from a HuggingFace config. Subclasses should override + this method if you want to be able to instantiate them with `from_pretrained_module()`. + """ + raise NotImplementedError + + def tie_weights(self) -> None: + """ + Tie weights according to the `_tied_weights` class attribute. + + This should always be called after loading a state dictionary. It will be called + automatically within `from_pretrained_module()`. + """ + if self._tied_weights: + param_dict = dict(self.named_parameters()) + param_dict.update(dict(self.named_buffers())) + for anchor_name, free_names in self._tied_weights.items(): + for free_name in free_names: + param_dict[free_name] = param_dict[anchor_name] + + @classmethod + def from_pretrained_module( + cls: Type[_T], + model_name: str, + *, load_weights: bool = True, - ): + weights_path: Optional[Union[str, PathLike]] = None, + auto_config_kwargs: Optional[Dict[str, Any]] = None, + mapping: Optional[Dict[str, str]] = None, + relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, + allow_missing: Optional[List[str]] = None, + strict: bool = True, + **kwargs, + ) -> _T: """ - Returns the relevant underlying module given a model name/object. + Initialize this module from a corresponding model on HuggingFace. + + !!! Note + This method is only available for subclasses that implement `_from_config()`. + Otherwise a `NotImplementedError` will be raised. # Parameters - pretrained_module : `Union[str, torch.nn.Module]` - Name of the transformer model containing the layer, - or the actual layer (not the model object). - relevant_module : `Optional[Union[str, List[str]]]`, optional - Name of the desired module. Defaults to cls._relevant_module. - source : `str`, optional - Where the model came from. Default - huggingface. - mapping : `Dict[str, str]`, optional - Optional mapping that determines any differences in the module names - between the class modules and the input model's modules. - Default - cls._huggingface_mapping - load_weights : `bool`, optional - Whether or not to load the pretrained weights. - Default is `True`. - """ - if isinstance(pretrained_module, str): - pretrained_module = cached_transformers.get( - pretrained_module, False, load_weights=load_weights - ) + model_name : `str` + The model identifier or path. - relevant_module = relevant_module or cls._relevant_module + load_weights : `bool`, optional (default = `True`) + Whether to download and load the pretrained weights. If `False`, the + weights are left uninitialized. - if relevant_module is not None: - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - # If the relevant_module is not found, we assume that the pretrained_module - # is already the relevant module. - if isinstance(relevant_module, str): - relevant_module = [relevant_module] - found = False - for module in relevant_module: - if module in submodules: - pretrained_module = submodules[module] - found = True - break + weights_path : `Optional[Union[str, PathLike]]`, optional (default = `None`) + When `load_weights` is `True`, this can be set to override the weights file. + Otherwise the default weights from the pretrained model are used. - if not found: - logger.warning( - "{} was not found! The submodules are: {}".format( - relevant_module, submodules.keys() + auto_config_kwargs : `Optional[Dict[str, Any]]`, optional (default = `None`) + Optional key-word arguments to pass to `transformers.AutoConfig.from_pretrained()` + to load the pretrained model's configuration file. + + mapping : `Optional[Dict[str, str]]`, optional (default = `None`) + Optional mapping that determines any differences in the submodule names + between this module and the pretrained model from HuggingFace. + If not given, the class's default is used: `cls._pretrained_mapping`. + + relevant_module : `Optional[str]`, optional (default = `None`) + An optional submodule of the HuggingFace module to initialize weights from. + This is only relevant when `load_weights` is `True`. + If not given, the class's default is used: `cls._pretrained_relevant_module`. + + ignore : `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that define which weights to ignore + from a pretrained state_dict. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_ignore`. + + allow_missing: `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that specifies which weights are allowed to be missing + from the pretrained state dictionary. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_allow_missing`. + + strict : `bool`, optional (default = `True`) + Whether to load the `state_dict` in "strict" model. This only applies + when `load_weights` is `True`. + + **kwargs : `Any` + Key word arguments to pass to `cls.from_config()` when instantiating the module. + """ # noqa: E501 + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_name, **(auto_config_kwargs or {})) + model = cls._from_config(config, **kwargs) + + if load_weights: + state_dict: Optional[StateDictType] = None + if is_global_primary(): + # Load the pretrained HuggingFace state_dict. + pretrained_state_dict = cls._get_pretrained_state_dict( + model_name, + weights_path=weights_path, + relevant_module=relevant_module, + ignore=ignore, + ) + # Now map keys from the HuggingFace state_dict to the corresponding keys from + # this class. This is called recursively on each submodule of the current module. + state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) + + missing_keys: List[str] + unexpected_keys: List[str] + error_msgs: List[str] = [] + if not is_distributed(): + assert state_dict is not None + logger.info("Loading state_dict into module") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + else: + # We're in distributed training. `state_dict` is `None` for all process groups + # except the global primary. + # Syncronize here since non-primary process groups will have to wait for the primary + # to load the state_dict into memory. + dist.barrier() + # Now load the state dict into the model. + logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") + missing_keys, unexpected_keys = load_state_dict_distributed( + model, state_dict, strict=False + ) + + # Exclude any keys in `missing_keys` that match with the `allow_missing` + # regular expressions. + if allow_missing is None: + allow_missing = cls._pretrained_allow_missing + if allow_missing: + missing_keys = [ + k for k in missing_keys if not any(re.match(p, k) for p in allow_missing) + ] + + # Allow missing keys in state_dict for params that are going to be tied. + for param_names in (model._tied_weights or {}).values(): + for param_name in param_names: + if param_name in missing_keys: + missing_keys.remove(param_name) + + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in missing_keys) + ) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) ) ) - return pretrained_module - @classmethod - def from_pretrained_module( - cls, - pretrained_module: Union[str, torch.nn.Module], - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - **kwargs, - ): - """ - Creates and returns an instance of the class, by using the weights - (and the architecture, by default) of the `pretrained_module`. - Optionally, the architecture can be changed by providing arguments. - """ - accepted_args = inspect.getfullargspec(cls).args - accepted_args.remove("self") - for key in kwargs: - assert key in accepted_args, ( - "{} is not a valid argument for creating an instance of `{}`. " - "Accepted arguments are {}.".format(key, cls.__name__, accepted_args) - ) + if error_msgs and strict: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + cls.__name__, "\n\t".join(error_msgs) + ) + ) + + # If there were error messages but we're not loading in 'strict' mode, + # we just issue warnings from the logger. + for msg in error_msgs: + logger.warning(msg) - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping, load_weights=load_weights + model.tie_weights() + + return model + + +def _get_mapped_state_dict( + module: torch.nn.Module, + state_dict: StateDictType, + mapping: Optional[Dict[str, str]] = None, +) -> StateDictType: + # First fix all top-level keys according to `combined_mapping`. + combined_mapping = module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} + for hf_key, cls_key in sorted( + # Sort by most specific key first. + combined_mapping.items(), + key=lambda x: x[0].count("."), + reverse=True, + ): + relevant_keys = set( + [key for key in state_dict.keys() if (key == hf_key or key.startswith(hf_key + "."))] ) - final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping) - final_kwargs.update(kwargs) - module = cls(**final_kwargs) - module._load_from_pretrained_module(pretrained_module, source, mapping) - return module + for key in relevant_keys: + new_key = key.replace(hf_key, cls_key, 1) + # We have to be careful not to overwrite an entry that we might have updated + # on a previous iteration of this loop due to having a more specific key. + if new_key not in state_dict: + state_dict[new_key] = state_dict.pop(key) + + # Now loop through the submodules, calling this function on each submodule. + for name, submodule in module.named_children(): + # Pull-out the part of the state_dict corresponding to just this submodule. + relevant_keys = set([key for key in state_dict.keys() if key.startswith(name + ".")]) + module_state_dict = { + key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys + } + # Recursively call this function from the submodule to map this part + # of the state_dict. + module_state_dict = _get_mapped_state_dict(submodule, module_state_dict) + # And then update the full state_dict. + for key, value in module_state_dict.items(): + state_dict[name + "." + key] = value + + return state_dict diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 09fb1d2bc40..7bc4a7247d3 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -1,14 +1,17 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import logging import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + + logger = logging.getLogger(__name__) @@ -38,8 +41,8 @@ class TransformerStack(TransformerModule, FromParams): This is helpful when using the `TransformerStack` as a decoder. """ - _huggingface_mapping = {"layer": "layers"} - _relevant_module = "encoder" + _pretrained_mapping = {"layer": "layers"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] def __init__( self, @@ -86,6 +89,8 @@ def forward( output_hidden_states: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -129,67 +134,15 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["num_hidden_layers"] = len(submodules["layers"]) - - final_kwargs["hidden_size"] = submodules["layers.0.attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules[ - "layers.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout"] = submodules["layers.0.attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["layers.0.attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["layers.0.intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. - if source == "huggingface": - final_kwargs["activation"] = getattr( - submodules["layers.0.intermediate"], "intermediate_act_fn" - ) - else: - final_kwargs["activation"] = getattr(submodules["layers.0.intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "layers.0.cross_attention" in submodules - + final_kwargs["num_hidden_layers"] = config.num_hidden_layers + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["add_cross_attention"] = config.add_cross_attention + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act final_kwargs.update(**kwargs) - - return final_kwargs - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers: Optional[Union[int, range]] = None, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - **kwargs, - ): - final_kwargs = {} - if num_hidden_layers is not None: - if isinstance(num_hidden_layers, range): - if mapping is None: - mapping = {} - for num_layer, mapped in enumerate(num_hidden_layers): - mapping[str(mapped)] = str(num_layer) - final_kwargs["num_hidden_layers"] = len(num_hidden_layers) - else: - final_kwargs["num_hidden_layers"] = num_hidden_layers - - return super().from_pretrained_module( - pretrained_module, - source=source, - mapping=mapping, - load_weights=load_weights, - **final_kwargs, - ) + return cls(**final_kwargs) diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index 67a623f98e3..d25239b27f3 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -4,11 +4,12 @@ import copy from collections import defaultdict, OrderedDict +from itertools import chain import json import logging from os import PathLike import re -from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, NamedTuple import math import numpy @@ -16,11 +17,18 @@ import torch.distributed as dist from allennlp.common.checks import ConfigurationError -from allennlp.common.util import int_to_device, is_distributed +from allennlp.common.util import int_to_device, is_distributed, is_global_primary logger = logging.getLogger(__name__) T = TypeVar("T") +StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] + +_MODULE_SHARDED_FLAG = "_is_sharded_allennlp" +""" +This flag is used to indicate when a module's parameters have been sharded across +distributed workers. +""" def move_to_device(obj, device: Union[torch.device, int]): @@ -926,7 +934,7 @@ def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage: return inner_device_mapping -def load_state_dict( +def read_state_dict( path: Union[PathLike, str], strip_prefix: Optional[str] = None, ignore: Optional[List[str]] = None, @@ -934,7 +942,7 @@ def load_state_dict( cuda_device: int = -1, ) -> Dict[str, torch.Tensor]: """ - Load a PyTorch model state dictionary from a checkpoint at the given `path`. + Read a PyTorch model state dictionary from a checkpoint at the given `path`. # Parameters @@ -2110,6 +2118,17 @@ def tiny_value_of_dtype(dtype: torch.dtype): _V = TypeVar("_V", int, float, torch.Tensor) +def distributed_device() -> torch.device: + """ + Get the correct `torch.device` of the current process to use for distributed point-to-point communication. + """ + if not is_distributed(): + raise RuntimeError( + "'distributed_device()' can only be called within a distributed process group" + ) + return int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + + def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ Reduces the given `value` across all distributed worker nodes according the given @@ -2134,7 +2153,7 @@ def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ if not is_distributed(): return value - device = int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + device = distributed_device() value_tensor = torch.tensor(value, device=device, **kwargs) dist.all_reduce(value_tensor, op=reduce_op) @@ -2157,3 +2176,191 @@ def dist_reduce_sum(value: _V, **kwargs) -> _V: if not is_distributed(): return value return dist_reduce(value, dist.ReduceOp.SUM, **kwargs) + + +def _collect_state_dict( + module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True +) -> Tuple[StateDictType, List[str], List[str]]: + """ + Collect a module's state dict across distributed processes. + + Returns the syncronized state dictionary, which will always be a valid state dict, + and then the missing and unexpected keys corresponding to the original `state_dict`. + Parameters that missing from the original `state_dict` will be populated from the + corresponding parameter in the primary processes' module's state dict. + + !!! Note + + `missing_keys` and `unexpected_keys` are only populated in the primary process. + """ + # This is the device we'll use for the broadcast operation. + dist_device = distributed_device() + # This is the device we'll put all tensors on in the returned state dict. + state_dict_device = ( + int_to_device(-1) if not state_dict else state_dict[list(state_dict.keys())[0]].device + ) + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + + # Gather current state dict and prepare to iterator over it. + # We iterate over this state dict instead of `state_dict` so we can be sure + # that the order is consistent across processes. + # We'll also update this state dict as we go and return it at the end. + if recurse: + current_state_dict = module.state_dict() + else: + # Only collect state of direct members, including both parameters and buffers. + current_state_dict = OrderedDict( + chain( + # Paramaters + ((n, p.data) for (n, p) in module.named_parameters(recurse=False)), + # Buffers + module.named_buffers(recurse=False), + ) + ) + + keys = list(current_state_dict.keys()) + + # Gather unexpected_keys. + if is_global_primary(): + assert state_dict is not None + module_keys = set(module.state_dict().keys()) + for key in state_dict: + if key not in module_keys: + unexpected_keys.append(key) + + for key in keys: + tensor = current_state_dict[key] + if is_global_primary(): + assert state_dict is not None + if key in state_dict: + # Update `tensor` to the value in `state_dict`. + tensor = state_dict[key] + else: + missing_keys.append(key) + tensor = tensor.to(dist_device) + dist.broadcast(tensor, 0) + current_state_dict[key] = tensor.to(state_dict_device) + + return current_state_dict, missing_keys, unexpected_keys + + +class _LoadStateDictResult(NamedTuple): + missing_keys: List[str] + unexpected_keys: List[str] + + +def load_state_dict_distributed( + module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True +) -> _LoadStateDictResult: + """ + Load a `state_dict` to the `module` within a distributed process. Only the global + primary process requires the `state_dict` to not be `None`. All other processes + will have the state tensors broadcasted to them one-by-one. + + If `strict` is `True`, then the keys of `state_dict` must exactly match the keys + returned by `module.state_dict()`. + + !!! Note + The returned `missing_keys` and `unexpected_keys` will only be accurate + in the primary process. + + # Returns + + `_LoadStateDictResult` + A `NamedTuple` with `missing_keys` and `unexpected_keys` fields, both of which + are lists of strings. + + # Raises + + `RuntimeError` + If `strict` is `True` and there are missing or unexpected keys. + + """ + if not is_distributed(): + return module.load_state_dict(state_dict, strict=strict) + + if is_global_primary(): + assert state_dict is not None + else: + assert state_dict is None + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + + submodules = dict(module.named_children()) + + def update_key_list(original, updates): + for key in updates: + if key not in original: + original.append(key) + + # If we've found a sharded module or there aren't any more submodules of the current module, + # we collect the state_dict and load it now instead of recursing further. + if getattr(module, _MODULE_SHARDED_FLAG, False) or not submodules: + # Collect. + state_dict, _missing_keys, _unexpected_keys = _collect_state_dict(module, state_dict) + assert state_dict is not None + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + # And load. + _missing_keys, _unexpected_keys = module.load_state_dict(state_dict, strict=False) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + else: + # We'll recursively call this function on each submodule, but first we need + # to collect any parameters that are direct members of this module. + direct_member_state_dict, _missing_keys, _unexpected_keys = _collect_state_dict( + module, state_dict, recurse=False + ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + + # `_missing_keys` here will contain any keys corresponding to submodules, but + # we'll remove those below. + _missing_keys, _unexpected_keys = module.load_state_dict( + direct_member_state_dict, strict=False + ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + + # Okay, now for the recursive part. + for name, submodule in submodules.items(): + # Update `missing_keys` to remove keys corresponding to this submodule. + # If they are actually missing after this step, we add them back in below. + missing_keys = [k for k in missing_keys if not k.startswith(name + ".")] + submodule_state_dict: Optional[StateDictType] = None + if is_global_primary(): + assert state_dict is not None + submodule_state_dict = { + key.replace(name + ".", "", 1): value + for key, value in state_dict.items() + if key.startswith(name + ".") + } + _missing_keys, _unexpected_keys = load_state_dict_distributed( + submodule, submodule_state_dict, strict=False + ) + update_key_list(missing_keys, [f"{name}.{key}" for key in _missing_keys]) + update_key_list(unexpected_keys, [f"{name}.{key}" for key in _unexpected_keys]) + + if strict: + error_msgs: List[str] = [] + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format(", ".join(f'"{k}"' for k in missing_keys)) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ) + ) + if error_msgs: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + module.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + + return _LoadStateDictResult(missing_keys, unexpected_keys) diff --git a/scripts/py2md.py b/scripts/py2md.py index 82a31565485..c8bc1ca1d43 100755 --- a/scripts/py2md.py +++ b/scripts/py2md.py @@ -279,6 +279,13 @@ class AllenNlpFilterProcessor(Struct): "__call__", "__iter__", "InfluenceInterpreter._calculate_influence_scores", + "TransformerModule._from_config", + "TransformerModule._pretrained_mapping", + "TransformerModule._pretrained_relevant_module", + "TransformerModule._pretrained_ignore", + "TransformerModule._pretrained_allow_missing", + "TransformerModule._distributed_loading_strategy", + "TransformerModule._tied_weights", } def process(self, graph, _resolver): diff --git a/tests/modules/transformer/activation_layer_test.py b/tests/modules/transformer/activation_layer_test.py index 8c1b7ebef26..2af0338a92e 100644 --- a/tests/modules/transformer/activation_layer_test.py +++ b/tests/modules/transformer/activation_layer_test.py @@ -1,32 +1,34 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import ActivationLayer -from allennlp.common.testing import AllenNlpTestCase -class TestActivationLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params_dict(): + return { + "hidden_size": 5, + "intermediate_size": 3, + "activation": "relu", + } - self.params_dict = { - "hidden_size": 5, - "intermediate_size": 3, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +@pytest.fixture +def params(params_dict): + return Params(params_dict) - self.activation_layer = ActivationLayer.from_params(params) - def test_can_construct_from_params(self): +@pytest.fixture +def activation_layer(params): + return ActivationLayer.from_params(params.duplicate()) - activation_layer = self.activation_layer - assert activation_layer.dense.in_features == self.params_dict["hidden_size"] - assert activation_layer.dense.out_features == self.params_dict["intermediate_size"] +def test_can_construct_from_params(activation_layer, params_dict): + activation_layer = activation_layer + assert activation_layer.dense.in_features == params_dict["hidden_size"] + assert activation_layer.dense.out_features == params_dict["intermediate_size"] - def test_forward_runs(self): - self.activation_layer.forward(torch.randn(7, 5)) +def test_forward_runs(activation_layer): + activation_layer.forward(torch.randn(7, 5)) diff --git a/tests/modules/transformer/bimodal_attention_test.py b/tests/modules/transformer/bimodal_attention_test.py index 40dc81f12de..270aefd23e7 100644 --- a/tests/modules/transformer/bimodal_attention_test.py +++ b/tests/modules/transformer/bimodal_attention_test.py @@ -1,55 +1,56 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import BiModalAttention -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "hidden_size1": 6, - "hidden_size2": 4, - "combined_hidden_size": 16, - "num_attention_heads": 2, - "dropout1": 0.1, - "dropout2": 0.2, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.biattention = BiModalAttention.from_params(params) - - def test_can_construct_from_params(self): - - biattention = self.biattention - - assert biattention.num_attention_heads == self.params_dict["num_attention_heads"] - assert biattention.attention_head_size == int( - self.params_dict["combined_hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - biattention.all_head_size - == self.params_dict["num_attention_heads"] * biattention.attention_head_size - ) - assert biattention.query1.in_features == self.params_dict["hidden_size1"] - assert biattention.key1.in_features == self.params_dict["hidden_size1"] - assert biattention.value1.in_features == self.params_dict["hidden_size1"] - assert biattention.dropout1.p == self.params_dict["dropout1"] - - assert biattention.query2.in_features == self.params_dict["hidden_size2"] - assert biattention.key2.in_features == self.params_dict["hidden_size2"] - assert biattention.value2.in_features == self.params_dict["hidden_size2"] - assert biattention.dropout2.p == self.params_dict["dropout2"] - - def test_forward_runs(self): - - self.biattention.forward( - torch.randn(2, 3, 6), - torch.randn(2, 3, 4), - torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating boolean tensors - torch.randint(0, 2, (2, 2, 3, 3)) == 1, - ) + + +@pytest.fixture +def params_dict(): + return { + "hidden_size1": 6, + "hidden_size2": 4, + "combined_hidden_size": 16, + "num_attention_heads": 2, + "dropout1": 0.1, + "dropout2": 0.2, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def biattention(params): + return BiModalAttention.from_params(params.duplicate()) + + +def test_can_construct_from_params(biattention, params_dict): + assert biattention.num_attention_heads == params_dict["num_attention_heads"] + assert biattention.attention_head_size == int( + params_dict["combined_hidden_size"] / params_dict["num_attention_heads"] + ) + assert ( + biattention.all_head_size + == params_dict["num_attention_heads"] * biattention.attention_head_size + ) + assert biattention.query1.in_features == params_dict["hidden_size1"] + assert biattention.key1.in_features == params_dict["hidden_size1"] + assert biattention.value1.in_features == params_dict["hidden_size1"] + assert biattention.dropout1.p == params_dict["dropout1"] + + assert biattention.query2.in_features == params_dict["hidden_size2"] + assert biattention.key2.in_features == params_dict["hidden_size2"] + assert biattention.value2.in_features == params_dict["hidden_size2"] + assert biattention.dropout2.p == params_dict["dropout2"] + + +def test_forward_runs(biattention): + biattention( + torch.randn(2, 3, 6), + torch.randn(2, 3, 4), + torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating boolean tensors + torch.randint(0, 2, (2, 2, 3, 3)) == 1, + ) diff --git a/tests/modules/transformer/bimodal_encoder_test.py b/tests/modules/transformer/bimodal_encoder_test.py index b95af3bfa1f..39bd3b54e8c 100644 --- a/tests/modules/transformer/bimodal_encoder_test.py +++ b/tests/modules/transformer/bimodal_encoder_test.py @@ -1,95 +1,92 @@ -import copy import torch +from torch.testing import assert_allclose +from transformers import AutoModel +import pytest + from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import BiModalEncoder -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalEncoder(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "num_hidden_layers1": 3, - "num_hidden_layers2": 3, - "hidden_size1": 12, - "hidden_size2": 12, - "combined_hidden_size": 12, - "intermediate_size1": 3, - "intermediate_size2": 3, - "num_attention_heads1": 4, - "num_attention_heads2": 6, - "combined_num_attention_heads": 2, - "attention_dropout1": 0.1, - "hidden_dropout1": 0.2, - "attention_dropout2": 0.1, - "hidden_dropout2": 0.2, - "activation": "relu", - "biattention_id1": [1, 2], - "biattention_id2": [1, 2], - "fixed_layer1": 1, - "fixed_layer2": 1, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.bimodal_encoder = BiModalEncoder.from_params(params) - - self.pretrained = cached_transformers.get("bert-base-uncased", False) - - def test_can_construct_from_params(self): - - modules = dict(self.bimodal_encoder.named_modules()) - assert len(modules["layers1"]) == self.params_dict["num_hidden_layers1"] - assert len(modules["layers2"]) == self.params_dict["num_hidden_layers2"] - - def test_forward_runs(self): - - embedding1 = torch.randn(16, 34, self.params_dict["hidden_size1"]) - embedding2 = torch.randn(16, 2, self.params_dict["hidden_size2"]) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - self.bimodal_encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - required_kwargs = [ - "num_hidden_layers2", - "hidden_size2", - "combined_hidden_size", - "intermediate_size2", - "num_attention_heads2", - "combined_num_attention_heads", - "attention_dropout2", - "hidden_dropout2", - "biattention_id1", - "biattention_id2", - "fixed_layer1", - "fixed_layer2", - ] - kwargs = {key: self.params_dict[key] for key in required_kwargs} - module = BiModalEncoder.from_pretrained_module(pretrained_module, **kwargs) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters( - pretrained_module, - module, - ignore_missing=True, - mapping=mapping, - ) - - def test_default_parameters(self): - encoder = BiModalEncoder() - embedding1 = torch.randn(16, 34, 1024) - embedding2 = torch.randn(16, 2, 1024) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) + + +@pytest.fixture +def params_dict(): + return { + "num_hidden_layers1": 3, + "num_hidden_layers2": 3, + "hidden_size1": 12, + "hidden_size2": 12, + "combined_hidden_size": 12, + "intermediate_size1": 3, + "intermediate_size2": 3, + "num_attention_heads1": 4, + "num_attention_heads2": 6, + "combined_num_attention_heads": 2, + "attention_dropout1": 0.1, + "hidden_dropout1": 0.2, + "attention_dropout2": 0.1, + "hidden_dropout2": 0.2, + "activation": "relu", + "biattention_id1": [1, 2], + "biattention_id2": [1, 2], + "fixed_layer1": 1, + "fixed_layer2": 1, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def bimodal_encoder(params): + return BiModalEncoder.from_params(params.duplicate()) + + +def test_can_construct_from_params(bimodal_encoder, params_dict): + modules = dict(bimodal_encoder.named_modules()) + assert len(modules["layers1"]) == params_dict["num_hidden_layers1"] + assert len(modules["layers2"]) == params_dict["num_hidden_layers2"] + + +def test_forward_runs(bimodal_encoder, params_dict): + embedding1 = torch.randn(16, 34, params_dict["hidden_size1"]) + embedding2 = torch.randn(16, 2, params_dict["hidden_size2"]) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + bimodal_encoder(embedding1, embedding2, attn_mask1, attn_mask2) + + +def test_loading_from_pretrained_weights(params_dict): + pretrained_module = AutoModel.from_pretrained("bert-base-cased").encoder + + required_kwargs = [ + "num_hidden_layers2", + "hidden_size2", + "combined_hidden_size", + "intermediate_size2", + "num_attention_heads2", + "combined_num_attention_heads", + "attention_dropout2", + "hidden_dropout2", + "biattention_id1", + "biattention_id2", + "fixed_layer1", + "fixed_layer2", + ] + kwargs = {key: params_dict[key] for key in required_kwargs} + + module = BiModalEncoder.from_pretrained_module("bert-base-cased", **kwargs) + assert_allclose( + module.layers1[0].intermediate.dense.weight.data, + pretrained_module.layer[0].intermediate.dense.weight.data, + ) + + +def test_default_parameters(): + encoder = BiModalEncoder() + embedding1 = torch.randn(16, 34, 1024) + embedding2 = torch.randn(16, 2, 1024) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + + encoder(embedding1, embedding2, attn_mask1, attn_mask2) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index e29ae44cf9e..7a3dcb81ec8 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -1,21 +1,13 @@ import copy + import torch import pytest +from transformers import AutoModel from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase from allennlp.modules.transformer import SelfAttention from allennlp.nn.util import min_value_of_dtype -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertSelfAttention -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaSelfAttention -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraSelfAttention -from transformers.models.distilbert.configuration_distilbert import DistilBertConfig -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention PARAMS_DICT = { "hidden_size": 6, @@ -24,145 +16,78 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("dropout") +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) - # bert, roberta, electra self attentions have the same code. - torch.manual_seed(1234) - hf_module = BertSelfAttention(BertConfig(**params)) - modules["bert"] = hf_module +@pytest.fixture +def params(params_dict): + return Params(params_dict) - torch.manual_seed(1234) - hf_module = RobertaSelfAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module - torch.manual_seed(1234) - hf_module = ElectraSelfAttention(ElectraConfig(**params)) - modules["electra"] = hf_module +@pytest.fixture +def self_attention(params): + return SelfAttention.from_params(params.duplicate()) - torch.manual_seed(1234) - distilparams = copy.deepcopy(params_dict) - distilparams["n_heads"] = distilparams.pop("num_attention_heads") - distilparams["dim"] = distilparams.pop("hidden_size") - distilparams["attention_dropout"] = distilparams.pop("dropout") - hf_module = MultiHeadSelfAttention(DistilBertConfig(**distilparams)) - modules["distilbert"] = hf_module - return modules - - -class TestSelfAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} +def test_can_construct_from_params(self_attention, params_dict): + assert self_attention.num_attention_heads == params_dict["num_attention_heads"] + assert self_attention.attention_head_size == int( + params_dict["hidden_size"] / params_dict["num_attention_heads"] + ) - params = Params(copy.deepcopy(self.params_dict)) + assert ( + self_attention.all_head_size + == params_dict["num_attention_heads"] * self_attention.attention_head_size + ) - self.self_attention = SelfAttention.from_params(params) + assert self_attention.query.in_features == params_dict["hidden_size"] + assert self_attention.key.in_features == params_dict["hidden_size"] + assert self_attention.value.in_features == params_dict["hidden_size"] - def test_can_construct_from_params(self): - assert self.self_attention.num_attention_heads == self.params_dict["num_attention_heads"] - assert self.self_attention.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) + assert self_attention.dropout.p == params_dict["dropout"] - assert ( - self.self_attention.all_head_size - == self.params_dict["num_attention_heads"] * self.self_attention.attention_head_size - ) - assert self.self_attention.query.in_features == self.params_dict["hidden_size"] - assert self.self_attention.key.in_features == self.params_dict["hidden_size"] - assert self.self_attention.value.in_features == self.params_dict["hidden_size"] +@pytest.mark.parametrize( + "pretrained_name, relevant_module", + [ + ("bert-base-cased", "bert.encoder.layer.0.attention.self"), + ("google/electra-base-generator", "electra.encoder.layer.0.attention.self"), + ("distilbert-base-uncased", "distilbert.transformer.layer.0.attention"), + ], +) +def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relevant_module): + torch.manual_seed(1234) + module = SelfAttention.from_pretrained_module(pretrained_name, relevant_module=relevant_module) - assert self.self_attention.dropout.p == self.params_dict["dropout"] + torch.manual_seed(1234) + pretrained_module = dict(AutoModel.from_pretrained(pretrained_name).named_modules())[ + # Module name will exclude the top-level part (e.g. 'bert.', 'electra.') for some reason. + relevant_module[relevant_module.index(".") + 1 :] + ] - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + batch_size = 2 + seq_len = 3 + dim = module.query.in_features + hidden_states = torch.randn(batch_size, seq_len, dim) + attention_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :] - torch.manual_seed(1234) - self_attention = SelfAttention.from_pretrained_module(hf_module) - - output = self_attention.forward(hidden_states, attention_mask=attention_mask) - if module_name == "distilbert": - hf_output = hf_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - ) - else: - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - "google/electra-base-generator", - "distilbert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): + # setting to eval mode to avoid non-deterministic dropout. + module = module.eval() + pretrained_module = pretrained_module.eval() + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + if "distilbert" in pretrained_name: torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - # Get the self attention layer. - if "distilbert" in pretrained_name: - pretrained_module = pretrained_module.attention - else: - pretrained_module = pretrained_module.attention.self - + hf_output = pretrained_module( + hidden_states, hidden_states, hidden_states, mask=attention_mask + )[0] + else: + # The attn_mask is processed outside the self attention module in HF bert models. + attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) torch.manual_seed(1234) - module = SelfAttention.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 3 - dim = module.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len)) - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask)[0] - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - if "distilbert" in pretrained_name: - torch.manual_seed(1234) - hf_output = pretrained_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - )[0] - else: - # The attn_mask is processed outside the self attention module in HF bert models. - attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask)[0] - - assert torch.allclose(output, hf_output) + assert torch.allclose(output, hf_output) diff --git a/tests/modules/transformer/toolkit_test.py b/tests/modules/transformer/toolkit_test.py index cd1bf60e9fd..ff59b9cf6b5 100644 --- a/tests/modules/transformer/toolkit_test.py +++ b/tests/modules/transformer/toolkit_test.py @@ -1,9 +1,10 @@ import torch +from torch.testing import assert_allclose from overrides import overrides +from transformers import AutoModel from transformers.models.albert.modeling_albert import AlbertEmbeddings from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.data.vocabulary import Vocabulary from allennlp.modules.token_embedders import Embedding, TokenEmbedder from allennlp.modules.transformer import TransformerStack, TransformerEmbeddings @@ -49,15 +50,19 @@ def forward(self, token_ids: torch.LongTensor): tiny.forward(torch.LongTensor([[0, 1, 2]])) def test_use_first_four_layers_of_pretrained(self): - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = "bert-base-cased" class SmallTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module(pretrained) - + self.embeddings = TransformerEmbeddings.from_pretrained_module( + pretrained, relevant_module="bert.embeddings" + ) self.transformer = TransformerStack.from_pretrained_module( - pretrained, num_hidden_layers=4 + pretrained, + num_hidden_layers=4, + relevant_module="bert.encoder", + strict=False, ) @overrides @@ -68,19 +73,27 @@ def forward(self, token_ids: torch.LongTensor): small = SmallTransformer() assert len(small.transformer.layers) == 4 - small.forward(torch.LongTensor([[0, 1, 2]])) + small(torch.LongTensor([[0, 1, 2]])) def test_use_selected_layers_of_bert_for_different_purposes(self): class MediumTransformer(torch.nn.Module): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module("bert-base-uncased") + self.embeddings = TransformerEmbeddings.from_pretrained_module( + "bert-base-cased", relevant_module="bert.embeddings" + ) self.separate_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", num_hidden_layers=range(0, 8) + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=8, + strict=False, ) self.combined_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", - num_hidden_layers=range(8, 12), + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=4, + mapping={f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))}, + strict=False, ) @overrides @@ -106,22 +119,31 @@ def forward( assert (len(medium.separate_transformer.layers)) == 8 assert (len(medium.combined_transformer.layers)) == 4 - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = cached_transformers.get("bert-base-cased", False) pretrained_layers = dict(pretrained.encoder.layer.named_modules()) - medium_layers = dict(medium.combined_transformer.layers.named_modules()) + separate_layers = dict(medium.separate_transformer.layers.named_modules()) + assert_allclose( + separate_layers["0"].intermediate.dense.weight.data, + pretrained_layers["0"].intermediate.dense.weight.data, + ) - assert_equal_parameters( - medium_layers["0"], pretrained_layers["8"], TransformerStack._huggingface_mapping + combined_layers = dict(medium.combined_transformer.layers.named_modules()) + assert_allclose( + combined_layers["0"].intermediate.dense.weight.data, + pretrained_layers["8"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["1"], pretrained_layers["9"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["1"].intermediate.dense.weight.data, + pretrained_layers["9"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["2"], pretrained_layers["10"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["2"].intermediate.dense.weight.data, + pretrained_layers["10"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["3"], pretrained_layers["11"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["3"].intermediate.dense.weight.data, + pretrained_layers["11"].intermediate.dense.weight.data, ) def test_combination_of_two_different_berts(self): @@ -130,8 +152,10 @@ def test_combination_of_two_different_berts(self): class AlmostRegularTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.get_relevant_module("albert-base-v2") - self.transformer = TransformerStack.from_pretrained_module("bert-base-uncased") + self.embeddings = AutoModel.from_pretrained("albert-base-v2").embeddings + self.transformer = TransformerStack.from_pretrained_module( + "bert-base-cased", relevant_module="bert.encoder" + ) # We want to tune only the embeddings, because that's our experiment. self.transformer.requires_grad = False diff --git a/tests/modules/transformer/transformer_embeddings_test.py b/tests/modules/transformer/transformer_embeddings_test.py index d366f4732b4..d37eae8629b 100644 --- a/tests/modules/transformer/transformer_embeddings_test.py +++ b/tests/modules/transformer/transformer_embeddings_test.py @@ -1,23 +1,21 @@ -import pytest import copy + +import pytest import torch from torch.testing import assert_allclose - -from allennlp.common import Params, FromParams -from allennlp.common import cached_transformers - +from transformers import AutoModel from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertEmbeddings from transformers.models.albert.configuration_albert import AlbertConfig from transformers.models.albert.modeling_albert import AlbertEmbeddings -from allennlp.common.testing import assert_equal_parameters +from allennlp.common import Params, FromParams from allennlp.modules.transformer import ( TransformerEmbeddings, ImageFeatureEmbeddings, TransformerModule, ) -from allennlp.common.testing import AllenNlpTestCase + PARAMS_DICT = { "vocab_size": 20, @@ -29,9 +27,159 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def transformer_embeddings(params): + return TransformerEmbeddings.from_params(params.duplicate()) + + +def test_can_construct_from_params(params_dict, transformer_embeddings): + embeddings = transformer_embeddings.embeddings + assert embeddings.word_embeddings.num_embeddings == params_dict["vocab_size"] + assert embeddings.word_embeddings.embedding_dim == params_dict["embedding_size"] + assert embeddings.word_embeddings.padding_idx == params_dict["pad_token_id"] + + assert embeddings.position_embeddings.num_embeddings == params_dict["max_position_embeddings"] + assert embeddings.position_embeddings.embedding_dim == params_dict["embedding_size"] + + assert embeddings.token_type_embeddings.num_embeddings == params_dict["type_vocab_size"] + assert embeddings.token_type_embeddings.embedding_dim == params_dict["embedding_size"] + + assert transformer_embeddings.layer_norm.normalized_shape[0] == params_dict["embedding_size"] + + assert transformer_embeddings.dropout.p == params_dict["dropout"] + + +def test_sanity(): + class TextEmbeddings(TransformerModule, FromParams): + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int, + max_position_embeddings: int, + type_vocab_size: int, + dropout: float, + ): + super().__init__() + self.word_embeddings = torch.nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) + + self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + torch.manual_seed(23) + text = TextEmbeddings(10, 5, 2, 3, 7, 0.0) + torch.manual_seed(23) + transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) + + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + text_output = text(input_ids, token_type_ids, position_ids) + transformer_output = transformer(input_ids, token_type_ids, position_ids) + + assert_allclose(text_output, transformer_output) + + +def test_forward_runs_with_inputs(transformer_embeddings): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + +def test_output_size(params): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + params["output_size"] = 7 + module = TransformerEmbeddings.from_params(params) + output = module(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids) + + assert output.shape[-1] == 7 + + +def test_no_token_type_layer(params): + params["type_vocab_size"] = 0 + module = TransformerEmbeddings.from_params(params) + assert len(module.embeddings) == 2 + + +@pytest.mark.parametrize( + "pretrained_name", + [ + "bert-base-cased", + "epwalsh/bert-xsmall-dummy", + ], +) +def test_loading_from_pretrained_module(pretrained_name): + TransformerEmbeddings.from_pretrained_module(pretrained_name) + + +def test_loading_albert(): + """ + Albert is a special case because it includes a Linear layer in the encoder + that maps the embeddings to the encoder hidden size, but we include this linear + layer within our embedding layer. + """ + transformer_embedding = TransformerEmbeddings.from_pretrained_module( + "albert-base-v2", + ) + albert = AutoModel.from_pretrained("albert-base-v2") + assert_allclose( + transformer_embedding.embeddings.word_embeddings.weight.data, + albert.embeddings.word_embeddings.weight.data, + ) + assert_allclose( + transformer_embedding.linear_transform.weight.data, + albert.encoder.embedding_hidden_mapping_in.weight.data, + ) + + +def get_modules(): + params = copy.deepcopy(PARAMS_DICT) params["hidden_dropout_prob"] = params.pop("dropout") params["hidden_size"] = params.pop("embedding_size") @@ -39,270 +187,117 @@ def get_modules(params_dict): # bert, roberta, electra self attentions have the same code. torch.manual_seed(1234) - hf_module = BertEmbeddings(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertEmbeddings(BertConfig(**params)) - albertparams = copy.deepcopy(params_dict) + albertparams = copy.deepcopy(PARAMS_DICT) albertparams["hidden_dropout_prob"] = albertparams.pop("dropout") torch.manual_seed(1234) - hf_module = AlbertEmbeddings(AlbertConfig(**albertparams)) - modules["albert"] = hf_module - - return modules - - -class TestTransformerEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} - - params = Params(copy.deepcopy(self.params_dict)) - - self.transformer_embeddings = TransformerEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - - transformer_embeddings = self.transformer_embeddings.embeddings - - assert ( - transformer_embeddings.word_embeddings.num_embeddings == self.params_dict["vocab_size"] - ) - assert ( - transformer_embeddings.word_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - assert ( - transformer_embeddings.word_embeddings.padding_idx == self.params_dict["pad_token_id"] - ) - - assert ( - transformer_embeddings.position_embeddings.num_embeddings - == self.params_dict["max_position_embeddings"] - ) - assert ( - transformer_embeddings.position_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - transformer_embeddings.token_type_embeddings.num_embeddings - == self.params_dict["type_vocab_size"] - ) - assert ( - transformer_embeddings.token_type_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - self.transformer_embeddings.layer_norm.normalized_shape[0] - == self.params_dict["embedding_size"] - ) - - assert self.transformer_embeddings.dropout.p == self.params_dict["dropout"] - - def test_sanity(self): - class TextEmbeddings(TransformerModule, FromParams): - def __init__( - self, - vocab_size: int, - hidden_size: int, - pad_token_id: int, - max_position_embeddings: int, - type_vocab_size: int, - dropout: float, - ): - super().__init__() - self.word_embeddings = torch.nn.Embedding( - vocab_size, hidden_size, padding_idx=pad_token_id - ) - self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) - - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None - ): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + position_embeddings + token_type_embeddings - embeddings = self.layer_norm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - torch.manual_seed(23) - text = TextEmbeddings(10, 5, 2, 3, 7, 0.0) - torch.manual_seed(23) - transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) - - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - text_output = text.forward(input_ids, token_type_ids, position_ids) - transformer_output = transformer.forward(input_ids, token_type_ids, position_ids) - - assert_allclose(text_output, transformer_output) - - def test_forward_runs_with_inputs(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - self.transformer_embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - def test_output_size(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - params = copy.deepcopy(self.params_dict) - params["output_size"] = 7 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - output = module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert output.shape[-1] == 7 - - def test_no_token_type_layer(self): - params = copy.deepcopy(self.params_dict) - params["type_vocab_size"] = 0 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - - assert len(module.embeddings) == 2 - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "albert-base-v2", - ], + yield "albert", AlbertEmbeddings(AlbertConfig(**albertparams)) + + +@pytest.mark.parametrize("module_name, hf_module", get_modules()) +def test_forward_against_huggingface_output(transformer_embeddings, module_name, hf_module): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + state_dict = transformer_embeddings._get_mapped_state_dict(hf_module.state_dict()) + if "position_ids" in state_dict: + del state_dict["position_ids"] + transformer_embeddings.load_state_dict(state_dict) + + torch.manual_seed(1234) + transformer_embeddings = ( + transformer_embeddings.eval() + ) # setting to eval mode to avoid non-deterministic dropout. + output = transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - pretrained_module = cached_transformers.get(pretrained_name, False).embeddings - module = TransformerEmbeddings.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - missing = assert_equal_parameters(pretrained_module, module, mapping=mapping) - assert len(missing) == 0 - - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - torch.manual_seed(1234) - embeddings = TransformerEmbeddings.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - embeddings = embeddings.eval() # setting to eval mode to avoid non-deterministic dropout. - output = embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - torch.manual_seed(1234) - hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. - hf_output = hf_module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert torch.allclose(output, hf_output) - - -class TestImageFeatureEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} - - params = Params(copy.deepcopy(self.params_dict)) - - self.img_embeddings = ImageFeatureEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - assert ( - self.img_embeddings.embeddings.image_embeddings.in_features - == self.params_dict["feature_size"] - ) - assert ( - self.img_embeddings.embeddings.image_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert ( - self.img_embeddings.embeddings.location_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert self.img_embeddings.dropout.p == self.params_dict["dropout"] - - def test_forward_runs_with_inputs(self): - batch_size = 2 - feature_dim = self.params_dict["feature_size"] - image_feature = torch.randn(batch_size, feature_dim) - image_location = torch.randn(batch_size, 4) - self.img_embeddings.forward(image_feature, image_location) - - def test_sanity(self): - class OldImageFeatureEmbeddings(TransformerModule, FromParams): - """Construct the embeddings from image, spatial location (omit now) and - token_type embeddings. - """ - - def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): - super().__init__() - - self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) - self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): - img_embeddings = self.image_embeddings(image_feature) - loc_embeddings = self.image_location_embeddings(image_location) - embeddings = self.layer_norm(img_embeddings + loc_embeddings) - embeddings = self.dropout(embeddings) - - return embeddings - - torch.manual_seed(23) - old = OldImageFeatureEmbeddings(**self.params_dict) - torch.manual_seed(23) - now = ImageFeatureEmbeddings(**self.params_dict) - - batch_size = 2 - - image_feature = torch.randn(batch_size, self.params_dict["feature_size"]) - image_location = torch.randn(batch_size, 4) - - torch.manual_seed(23) - old_output = old.forward(image_feature, image_location) - torch.manual_seed(23) - now_output = now.forward(image_feature, image_location) - - assert_allclose(old_output, now_output) + + torch.manual_seed(1234) + hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. + hf_output = hf_module( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + assert torch.allclose(output, hf_output) + + +@pytest.fixture +def image_params_dict(): + return {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} + + +@pytest.fixture +def image_params(image_params_dict): + return Params(image_params_dict) + + +@pytest.fixture +def image_embeddings(image_params): + return ImageFeatureEmbeddings.from_params(image_params.duplicate()) + + +def test_can_construct_image_embeddings_from_params(image_embeddings, image_params_dict): + assert ( + image_embeddings.embeddings.image_embeddings.in_features + == image_params_dict["feature_size"] + ) + assert ( + image_embeddings.embeddings.image_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert ( + image_embeddings.embeddings.location_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert image_embeddings.dropout.p == image_params_dict["dropout"] + + +def test_image_embedding_forward_runs_with_inputs(image_embeddings, image_params_dict): + batch_size = 2 + feature_dim = image_params_dict["feature_size"] + image_feature = torch.randn(batch_size, feature_dim) + image_location = torch.randn(batch_size, 4) + image_embeddings(image_feature, image_location) + + +def test_image_embeddings_sanity(image_params_dict): + class OldImageFeatureEmbeddings(TransformerModule, FromParams): + """Construct the embeddings from image, spatial location (omit now) and + token_type embeddings. + """ + + def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): + super().__init__() + + self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) + self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) + self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): + img_embeddings = self.image_embeddings(image_feature) + loc_embeddings = self.image_location_embeddings(image_location) + embeddings = self.layer_norm(img_embeddings + loc_embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + torch.manual_seed(23) + old = OldImageFeatureEmbeddings(**image_params_dict) + torch.manual_seed(23) + now = ImageFeatureEmbeddings(**image_params_dict) + + batch_size = 2 + + image_feature = torch.randn(batch_size, image_params_dict["feature_size"]) + image_location = torch.randn(batch_size, 4) + + torch.manual_seed(23) + old_output = old(image_feature, image_location) + torch.manual_seed(23) + now_output = now(image_feature, image_location) + + assert_allclose(old_output, now_output) diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 1ecf183eace..4c1e141a5a8 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -1,13 +1,7 @@ import copy + import torch import pytest - -from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters -from allennlp.modules.transformer import AttentionLayer, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase - from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertAttention, BertLayer from transformers.models.roberta.configuration_roberta import RobertaConfig @@ -15,6 +9,14 @@ from transformers.models.electra.configuration_electra import ElectraConfig from transformers.models.electra.modeling_electra import ElectraAttention, ElectraLayer +from allennlp.common import Params, cached_transformers +from allennlp.common.testing import run_distributed_test +from allennlp.modules.transformer import ( + AttentionLayer, + TransformerLayer, +) + + ATTENTION_PARAMS_DICT = { "hidden_size": 6, "num_attention_heads": 2, @@ -23,141 +25,113 @@ } -def get_attention_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def attention_params(): + return Params(copy.deepcopy(ATTENTION_PARAMS_DICT)) + + +def test_attention(attention_params): + attention_layer = AttentionLayer.from_params(attention_params.duplicate()).eval() + + assert attention_layer.self.num_attention_heads == attention_params["num_attention_heads"] + assert attention_layer.self.attention_head_size == int( + attention_params["hidden_size"] / attention_params["num_attention_heads"] + ) + assert ( + attention_layer.self.all_head_size + == attention_params["num_attention_heads"] * attention_layer.self.attention_head_size + ) + assert attention_layer.self.query.in_features == attention_params["hidden_size"] + assert attention_layer.self.key.in_features == attention_params["hidden_size"] + assert attention_layer.self.value.in_features == attention_params["hidden_size"] + assert attention_layer.self.dropout.p == attention_params["attention_dropout"] + + assert attention_layer.output.dense.in_features == attention_params["hidden_size"] + assert attention_layer.output.dense.out_features == attention_params["hidden_size"] + assert attention_layer.output.layer_norm.normalized_shape[0] == attention_params["hidden_size"] + assert attention_layer.output.dropout.p == attention_params["hidden_dropout"] + + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + attention_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) + + +def get_attention_modules(): + params = copy.deepcopy(ATTENTION_PARAMS_DICT) params["attention_probs_dropout_prob"] = params.pop("attention_dropout") params["hidden_dropout_prob"] = params.pop("hidden_dropout") torch.manual_seed(1234) - hf_module = BertAttention(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertAttention(BertConfig(**params)).eval() torch.manual_seed(1234) - hf_module = RobertaAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module + yield "roberta", RobertaAttention(RobertaConfig(**params)).eval() torch.manual_seed(1234) - hf_module = ElectraAttention(ElectraConfig(**params)) - modules["electra"] = hf_module + yield "electra", ElectraAttention(ElectraConfig(**params)).eval() - return modules +@pytest.mark.parametrize("module_name, hf_module", get_attention_modules()) +def test_attention_matches_huggingface(attention_params, module_name, hf_module): + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) -class TestAttentionLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + attention = AttentionLayer.from_params(attention_params).eval() + state_dict = attention._get_mapped_state_dict(hf_module.state_dict()) + attention.load_state_dict(state_dict) - self.params_dict = { - "hidden_size": 6, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - } + torch.manual_seed(1234) + output = attention(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process the attention_mask at the model level. + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - params = Params(copy.deepcopy(self.params_dict)) + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - self.attention_layer = AttentionLayer.from_params(params) + assert torch.allclose(output[0], hf_output[0]) - def test_can_construct_from_params(self): - attention_layer = self.attention_layer +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_attention_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - assert attention_layer.self.num_attention_heads == self.params_dict["num_attention_heads"] - assert attention_layer.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - attention_layer.self.all_head_size - == self.params_dict["num_attention_heads"] * attention_layer.self.attention_head_size - ) - assert attention_layer.self.query.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.key.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.value.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.dropout.p == self.params_dict["attention_dropout"] - - assert attention_layer.output.dense.in_features == self.params_dict["hidden_size"] - assert attention_layer.output.dense.out_features == self.params_dict["hidden_size"] - assert ( - attention_layer.output.layer_norm.normalized_shape[0] == self.params_dict["hidden_size"] - ) - assert attention_layer.output.dropout.p == self.params_dict["hidden_dropout"] + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1].attention - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.attention_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) + torch.manual_seed(1234) + module = AttentionLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0.attention", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.parametrize( - "module_name, hf_module", get_attention_modules(ATTENTION_PARAMS_DICT).items() - ) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - attention = AttentionLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = attention.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module.attention - - torch.manual_seed(1234) - module = AttentionLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] + + assert torch.allclose(output, hf_output, atol=1e-04) LAYER_PARAMS_DICT = { @@ -170,213 +144,158 @@ def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name) } -def get_layer_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") +@pytest.fixture +def layer_params(): + return Params(copy.deepcopy(LAYER_PARAMS_DICT)) - # bert, roberta, electra, layoutlm self attentions have the same code. - torch.manual_seed(1234) - hf_module = BertLayer(BertConfig(**params)) - modules["bert"] = hf_module +def test_layer(layer_params): + transformer_layer = TransformerLayer.from_params(layer_params.duplicate()).eval() - torch.manual_seed(1234) - hf_module = RobertaLayer(RobertaConfig(**params)) - modules["roberta"] = hf_module + assert ( + transformer_layer.attention.self.num_attention_heads == layer_params["num_attention_heads"] + ) + assert transformer_layer.attention.self.attention_head_size == int( + layer_params["hidden_size"] / layer_params["num_attention_heads"] + ) + assert ( + transformer_layer.attention.self.all_head_size + == layer_params["num_attention_heads"] + * transformer_layer.attention.self.attention_head_size + ) + assert transformer_layer.attention.self.query.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.key.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.value.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.dropout.p == layer_params["attention_dropout"] + + assert transformer_layer.attention.output.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.output.dense.out_features == layer_params["hidden_size"] + assert ( + transformer_layer.attention.output.layer_norm.normalized_shape[0] + == layer_params["hidden_size"] + ) + assert transformer_layer.attention.output.dropout.p == layer_params["hidden_dropout"] - torch.manual_seed(1234) - hf_module = ElectraLayer(ElectraConfig(**params)) - modules["electra"] = hf_module + assert transformer_layer.intermediate.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.intermediate.dense.out_features == layer_params["intermediate_size"] - return modules + assert transformer_layer.output.dense.in_features == layer_params["intermediate_size"] + assert transformer_layer.output.dense.out_features == layer_params["hidden_size"] + assert transformer_layer.output.layer_norm.normalized_shape[0] == layer_params["hidden_size"] -class TestTransformerLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + assert transformer_layer.output.dropout.p == layer_params["hidden_dropout"] - self.params_dict = { - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) - params = Params(copy.deepcopy(self.params_dict)) + with pytest.raises(AssertionError): + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - self.transformer_layer = TransformerLayer.from_params(params) - self.pretrained_name = "bert-base-uncased" - self.pretrained = cached_transformers.get(self.pretrained_name, False) +def test_layer_with_cross_attention(layer_params): + layer_params["add_cross_attention"] = True - def test_can_construct_from_params(self): + transformer_layer = TransformerLayer.from_params(layer_params).eval() + assert hasattr(transformer_layer, "cross_attention") - transformer_layer = self.transformer_layer + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert ( - transformer_layer.attention.self.num_attention_heads - == self.params_dict["num_attention_heads"] - ) - assert transformer_layer.attention.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - transformer_layer.attention.self.all_head_size - == self.params_dict["num_attention_heads"] - * transformer_layer.attention.self.attention_head_size - ) - assert transformer_layer.attention.self.query.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.key.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.value.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.dropout.p == self.params_dict["attention_dropout"] - assert ( - transformer_layer.attention.output.dense.in_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.dense.out_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) - assert transformer_layer.attention.output.dropout.p == self.params_dict["hidden_dropout"] +def get_layer_modules(): + params = copy.deepcopy(LAYER_PARAMS_DICT) + params["attention_probs_dropout_prob"] = params.pop("attention_dropout") + params["hidden_dropout_prob"] = params.pop("hidden_dropout") + params["hidden_act"] = params.pop("activation") - assert transformer_layer.intermediate.dense.in_features == self.params_dict["hidden_size"] - assert ( - transformer_layer.intermediate.dense.out_features - == self.params_dict["intermediate_size"] - ) + torch.manual_seed(1234) + yield "bert", BertLayer(BertConfig(**params)).eval() - assert transformer_layer.output.dense.in_features == self.params_dict["intermediate_size"] - assert transformer_layer.output.dense.out_features == self.params_dict["hidden_size"] + torch.manual_seed(1234) + yield "roberta", RobertaLayer(RobertaConfig(**params)).eval() - assert ( - transformer_layer.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) + torch.manual_seed(1234) + yield "electra", ElectraLayer(ElectraConfig(**params)).eval() - assert transformer_layer.output.dropout.p == self.params_dict["hidden_dropout"] - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.transformer_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) +@pytest.mark.parametrize("module_name, hf_module", get_layer_modules()) +def test_layer_matches_huggingface(layer_params, module_name, hf_module): + layer = TransformerLayer.from_params(layer_params).eval() + state_dict = layer._get_mapped_state_dict(hf_module.state_dict()) + layer.load_state_dict(state_dict) - with pytest.raises(AssertionError): - self.transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + torch.manual_seed(1234) + output = layer(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process the attention_mask at the model level. + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - params = Params(params) + assert torch.allclose(output[0], hf_output[0]) - transformer_layer = TransformerLayer.from_params(params) - assert hasattr(transformer_layer, "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - transformer_layer_new = TransformerLayer.from_pretrained_module( - transformer_layer, source="allennlp" - ) + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1] + + torch.manual_seed(1234) + module = TransformerLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.attention.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 + + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] - assert hasattr(transformer_layer_new, "cross_attention") - - def test_loading_from_pretrained_weights(self): - - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(self.pretrained.encoder.layer.modules()): - if i == 1: - break - - module = TransformerLayer.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - @pytest.mark.parametrize("module_name, hf_module", get_layer_modules(LAYER_PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - layer = TransformerLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = layer.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], + assert torch.allclose(output, hf_output, atol=1e-04) + + +def _load_pretrained(global_rank, world_size, gpu_id): + TransformerLayer.from_pretrained_module( + "epwalsh/bert-xsmall-dummy", ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module - - torch.manual_seed(1234) - module = TransformerLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.attention.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + + +@pytest.mark.parametrize("test_func", [_load_pretrained]) +def test_distributed(test_func): + run_distributed_test([-1, -1], func=test_func, start_method="spawn") diff --git a/tests/modules/transformer/transformer_module_test.py b/tests/modules/transformer/transformer_module_test.py index d5002f215ea..4018229c41d 100644 --- a/tests/modules/transformer/transformer_module_test.py +++ b/tests/modules/transformer/transformer_module_test.py @@ -1,74 +1,89 @@ import torch +from torch.nn import Parameter -from allennlp.common.testing import assert_equal_parameters +from allennlp.common.testing import assert_equal_parameters, assert_allclose from allennlp.modules.transformer import TransformerModule from allennlp.common.testing import AllenNlpTestCase class TestTransformerModule(AllenNlpTestCase): - def test_can_load_pretrained_weights(self): + def test_get_mapped_state_dict(self): class InternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.ff = torch.nn.Linear(inp, out) + self.p = Parameter(torch.randn(out, out)) + self.register_buffer("b", torch.randn(inp, inp)) def forward(self, x): - x = self.ff(x) + x = self.ff(x).matmul(self.p) return x class InternalNew(TransformerModule): + _pretrained_mapping = {"ff": "linear", "p": "param", "b": "buffer"} + def __init__(self, inp, out): super().__init__() self.linear = torch.nn.Linear(inp, out) - - def _construct_default_mapping(self, pretrained_module, source, mapping): - # return {"linear": "ff"} - return {"ff": "linear"} + self.param = Parameter(torch.randn(out, out)) + self.register_buffer("buffer", torch.randn(inp, inp)) def forward(self, x): - x = self.linear(x) + x = self.linear(x).matmul(self.param) return x class ExternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.internal = InternalOld(inp, out) + self.p = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal(x) + x = self.internal(x).matmul(self.p) return x - class External(TransformerModule): - # _huggingface_mapping = {"internal_layer": "internal"} - _huggingface_mapping = {"internal": "internal_layer"} + class ExternalNew(TransformerModule): + _pretrained_mapping = {"internal": "internal_layer", "p": "param"} def __init__(self, inp, out): super().__init__() self.internal_layer = InternalNew(inp, out) + self.param = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal_layer(x) + x = self.internal_layer(x).matmul(self.param) return x - iold = InternalOld(3, 5) - x = torch.randn(4, 3) - iold.forward(x) - inew = InternalNew(3, 5) - inew._load_from_pretrained_module(iold) - mapping = { - val: key - for key, val in inew._construct_default_mapping(iold, "huggingface", {}).items() - } - assert_equal_parameters(iold, inew, mapping=mapping) - eold = ExternalOld(3, 5) + state_dict_old = eold.state_dict() + + enew = ExternalNew(3, 5) + state_dict_new = enew._get_mapped_state_dict(state_dict_old) + assert set(state_dict_new.keys()) == set( + [ + "internal_layer.linear.weight", + "internal_layer.linear.bias", + "internal_layer.param", + "internal_layer.buffer", + "param", + ] + ) + + enew.load_state_dict(state_dict_new) + x = torch.randn(4, 3) - eold.forward(x) - - enew = External(3, 5) - enew._load_from_pretrained_module(eold) - mapping = { - val: key - for key, val in enew._construct_default_mapping(eold, "huggingface", {}).items() - } - assert_equal_parameters(eold, enew, mapping=mapping) + out_old = eold(x) + out_new = enew(x) + assert_allclose(out_old, out_new) + + assert_equal_parameters( + eold, + enew, + mapping={ + "internal_layer.linear.weight": "internal.ff.weight", + "internal_layer.linear.bias": "internal.ff.bias", + "internal_layer.param": "internal.p", + "internal_layer.buffer": "internal.b", + "param": "p", + }, + ) diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index 0481a407937..cf42f6c0f6d 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -1,20 +1,12 @@ import copy + import torch import pytest from allennlp.common import Params from allennlp.common import cached_transformers - -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import TransformerStack, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertEncoder -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaEncoder -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraEncoder PARAMS_DICT = { "num_hidden_layers": 3, @@ -26,208 +18,93 @@ "activation": "relu", } - -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") - - torch.manual_seed(1234) - hf_module = BertEncoder(BertConfig(**params)) - modules["bert"] = hf_module - - torch.manual_seed(1234) - hf_module = RobertaEncoder(RobertaConfig(**params)) - modules["roberta"] = hf_module - - torch.manual_seed(1234) - hf_module = ElectraEncoder(ElectraConfig(**params)) - modules["electra"] = hf_module - - return modules +SEED = 1234 -class TestTransformerStack(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params(): + return Params(copy.deepcopy(PARAMS_DICT)) - self.params_dict = { - "num_hidden_layers": 3, - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +def test_transformer_stack_from_params(params): + torch.manual_seed(SEED) + transformer_stack = TransformerStack.from_params(params) - self.transformer_stack = TransformerStack.from_params(params) + # Make sure we have the right number of modules. + modules = dict(transformer_stack.named_modules()) + assert len(modules["layers"]) == PARAMS_DICT["num_hidden_layers"] - self.pretrained_name = "bert-base-uncased" + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.pretrained = cached_transformers.get(self.pretrained_name, False) + # Make sure forward pass can run. + torch.manual_seed(SEED) + output = transformer_stack.forward(hidden_states, attention_mask=attention_mask) - def test_can_construct_from_params(self): - - modules = dict(self.transformer_stack.named_modules()) - assert len(modules["layers"]) == self.params_dict["num_hidden_layers"] - - def test_forward_runs(self): - self.transformer_stack.forward(torch.randn(2, 3, 6), attention_mask=torch.randn(2, 3)) - - with pytest.raises(AssertionError): - self.transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=torch.randn(2, 3), - encoder_hidden_states=torch.randn(2, 3, 6), - ) - - def test_layer_same_as_params(self): - params = copy.deepcopy(self.params_dict) - num_hidden_layers = params.pop("num_hidden_layers") - # params = Params(params) - - torch.manual_seed(1234) - transformer_layer = TransformerLayer(**params) - transformer_stack_from_layer = TransformerStack(num_hidden_layers, transformer_layer) - torch.manual_seed(1234) - transformer_stack_from_params = TransformerStack(num_hidden_layers, **params) + # Make sure we get the same results when instantiating from a single layer. + torch.manual_seed(SEED) + layer_params = copy.deepcopy(PARAMS_DICT) + num_hidden_layers = layer_params.pop("num_hidden_layers") + transformer_layer = TransformerLayer(**layer_params) # type: ignore[arg-type] + transformer_stack_from_layer = TransformerStack( + num_hidden_layers, transformer_layer # type: ignore[arg-type] + ) - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + from_layer_output = transformer_stack_from_layer.forward( + hidden_states, attention_mask=attention_mask + ) - transformer_stack_from_layer.eval() - transformer_stack_from_params.eval() + assert torch.allclose(from_layer_output[0], output[0]) - torch.manual_seed(1234) - layer_output = transformer_stack_from_layer.forward( - hidden_states, attention_mask=attention_mask + # Make sure forward pass raises with bad input. + with pytest.raises(AssertionError): + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=torch.randn(2, 3), + encoder_hidden_states=torch.randn(2, 3, 6), ) - torch.manual_seed(1234) - params_output = transformer_stack_from_params.forward( - hidden_states, attention_mask=attention_mask - ) - assert torch.allclose(layer_output[0], params_output[0]) +def test_transformer_stack_with_cross_attention(params): + params["add_cross_attention"] = True - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + transformer_stack = TransformerStack.from_params(params).eval() + modules = dict(transformer_stack.named_modules()) - params = Params(params) + assert hasattr(modules["layers.0"], "cross_attention") - transformer_stack = TransformerStack.from_params(params) - modules = dict(transformer_stack.named_modules()) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert hasattr(modules["layers.0"], "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize("pretrained_model_name", ["epwalsh/bert-xsmall-dummy", "bert-base-cased"]) +def test_loading_from_pretrained(pretrained_model_name): + transformer_stack = TransformerStack.from_pretrained_module(pretrained_model_name).eval() + pretrained_module = cached_transformers.get(pretrained_model_name, True).encoder.eval() - transformer_stack_new = TransformerStack.from_pretrained_module( - transformer_stack, source="allennlp" - ) + batch_size = 2 + seq_length = 15 + hidden_size = transformer_stack.layers[0]._hidden_size - new_modules = dict(transformer_stack_new.named_modules()) - assert hasattr(new_modules["layers.0"], "cross_attention") - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - module = TransformerStack.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping) - - def test_loading_partial_pretrained_weights(self): - - kwargs = TransformerStack._get_input_arguments(self.pretrained.encoder) - # The pretrained module has 12 bert layers, while the instance will have only 3. - kwargs["num_hidden_layers"] = 3 - transformer_stack = TransformerStack(**kwargs) - transformer_stack._load_from_pretrained_module(self.pretrained.encoder) - mapping = { - val: key - for key, val in transformer_stack._construct_default_mapping( - self.pretrained.encoder, "huggingface", {} - ).items() - } - assert_equal_parameters( - self.pretrained.encoder, - transformer_stack, - mapping, - ) + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + output = transformer_stack(hidden_states, attention_mask=attention_mask) - stack = TransformerStack.from_pretrained_module(hf_module) + torch.manual_seed(SEED) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf) - torch.manual_seed(1234) - output = stack.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) + assert torch.allclose(output[0], hf_output[0]) - assert torch.allclose(output[0], hf_output[0]) - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - pretrained_module = pretrained.transformer - else: - pretrained_module = pretrained.encoder - - torch.manual_seed(1234) - module = TransformerStack.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 1 - seq_len = 768 - dim = dict(module.named_modules())["layers.0.attention.self.query"].in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp) - attention_mask_hf = attention_mask_hf.expand(batch_size, 12, seq_len, seq_len) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output) +def test_loading_partial_pretrained_weights(): + # The pretrained module has 12 bert layers, while the instance will have only 3. + TransformerStack.from_pretrained_module("bert-base-cased", num_hidden_layers=3, strict=False) diff --git a/tests/nn/util_test.py b/tests/nn/util_test.py index 7ca660ed04d..73a9952a11f 100644 --- a/tests/nn/util_test.py +++ b/tests/nn/util_test.py @@ -9,7 +9,7 @@ from flaky import flaky from allennlp.common.checks import ConfigurationError -from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.testing import AllenNlpTestCase, run_distributed_test from allennlp.common.util import sanitize from allennlp.data import Token, Vocabulary from allennlp.data.fields import TextField @@ -1730,8 +1730,6 @@ def test_dist_reduce_sum(self): ret_value = util.dist_reduce_sum(value) assert (ret_value == value).all().item() - from allennlp.common.testing.distributed_test import run_distributed_test - func_kwargs = {"value": [torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])]} desired_values = torch.Tensor([5, 7, 9]) @@ -1761,3 +1759,79 @@ def global_distributed_func( output = function(**kwargs) assert (output == desired_values).all().item() + + +class DistributedFixtureModel(torch.nn.Module): + """ + Fake model for testing `load_state_dict_distributed()`. + """ + + def __init__(self): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.custom_submodule = DistributedFixtureSubmodule() + self.custom_sharded_submodule = DistributedFixtureSubmodule(sharded=True) + self.linear_submodule = torch.nn.Linear(3, 5) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. + pass + + +class DistributedFixtureSubmodule(torch.nn.Module): + def __init__(self, sharded: bool = False): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.linear_submodule = torch.nn.Linear(3, 5) + if sharded: + setattr(self, util._MODULE_SHARDED_FLAG, True) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. + pass + + +def _dist_load_ok(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + missing_keys, unexpected_keys = util.load_state_dict_distributed(model, state_dict) + assert not missing_keys + assert not unexpected_keys + + +def _dist_load_with_errors(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + _missing_keys = [ + "direct_buffer", + "custom_submodule.linear_submodule.bias", + "custom_submodule.direct_param", + "custom_sharded_submodule.linear_submodule.bias", + "custom_sharded_submodule.direct_buffer", + ] + _unexpected_keys = [ + "not_a_parameter", + "custom_submodule.not_a_parameter", + "custom_submodule.linear.not_a_parameter", + "custom_sharded_submodule.not_a_parameter", + "custom_sharded_submodule.linear.not_a_parameter", + "not_even_submodule.not_a_parameter", + ] + if state_dict is not None: + for key in _missing_keys: + del state_dict[key] + for key in _unexpected_keys: + state_dict[key] = torch.randn(2, 2) + missing_keys, unexpected_keys = util.load_state_dict_distributed( + model, state_dict, strict=False + ) + if global_rank == 0: + assert set(missing_keys) == set(_missing_keys) + assert set(unexpected_keys) == set(_unexpected_keys) + + +@pytest.mark.parametrize("test_func", [_dist_load_ok, _dist_load_with_errors]) +def test_load_state_dict_distributed(test_func): + run_distributed_test([-1, -1], func=test_func) From 0ea92252563db955b0ccd6e4f871ad313274478d Mon Sep 17 00:00:00 2001 From: Daniel Deutsch <danieldeutsch@users.noreply.github.com> Date: Mon, 17 May 2021 17:26:39 -0400 Subject: [PATCH 25/63] Add a `min_steps` parameter to `BeamSearch` (#5207) * Implementing minimum number of decoding steps * Adding unit tests * Reformatting * Adding entry to changelog * Adding end token comment * Adding start token comment * Changing param to optional Co-authored-by: Pete <petew@allenai.org> --- CHANGELOG.md | 1 + allennlp/nn/beam_search.py | 25 +++++++++- tests/nn/beam_search_test.py | 95 +++++++++++++++++++++++++++++++++++- 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92308cd9c2f..a316a6ce2c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 You can do this by setting the parameter `load_weights` to `False`. See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details. - Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths. +- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. ### Fixed diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py index fff07b7dac2..f1d43226a83 100644 --- a/allennlp/nn/beam_search.py +++ b/allennlp/nn/beam_search.py @@ -1,5 +1,5 @@ from inspect import signature -from typing import List, Callable, Tuple, Dict, cast, TypeVar +from typing import List, Callable, Tuple, Dict, cast, TypeVar, Optional import warnings from overrides import overrides @@ -462,6 +462,11 @@ class BeamSearch(FromParams): Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039). + + min_steps : `int`, optional (default = `None`) + The minimum number of decoding steps to take, i.e. the minimum length of + the predicted sequences. This does not include the start or end tokens. If `None`, + no minimum is enforced. """ def __init__( @@ -471,6 +476,7 @@ def __init__( beam_size: int = 10, per_node_beam_size: int = None, sampler: Sampler = None, + min_steps: Optional[int] = None, ) -> None: if not max_steps > 0: raise ValueError("max_steps must be positive") @@ -478,12 +484,18 @@ def __init__( raise ValueError("beam_size must be positive") if per_node_beam_size is not None and not per_node_beam_size > 0: raise ValueError("per_node_beam_size must be positive") + if min_steps is not None: + if not min_steps >= 0: + raise ValueError("min_steps must be non-negative") + if not min_steps <= max_steps: + raise ValueError("min_steps must be less than or equal to max_steps") self._end_index = end_index self.max_steps = max_steps self.beam_size = beam_size self.per_node_beam_size = per_node_beam_size or beam_size self.sampler = sampler or DeterministicSampler() + self.min_steps = min_steps or 0 @staticmethod def _reconstruct_sequences(predictions, backpointers): @@ -629,6 +641,10 @@ def _search( start_class_log_probabilities, batch_size, num_classes ) + # Prevent selecting the end symbol if there is any min_steps constraint + if self.min_steps >= 1: + start_class_log_probabilities[:, self._end_index] = float("-inf") + # Get the initial predicted classed and their log probabilities. # shape: (batch_size, beam_size), (batch_size, beam_size) ( @@ -675,6 +691,13 @@ def _search( # shape: (batch_size * beam_size, num_classes) class_log_probabilities, state = step(last_predictions, state, timestep + 1) + # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token + # of the sequence (because `timestep` is 0-indexed and we generated the first token + # before the for loop). Here we block the end index if the search is not allowed to + # terminate on this iteration. + if timestep + 2 <= self.min_steps: + class_log_probabilities[:, self._end_index] = float("-inf") + # shape: (batch_size * beam_size, num_classes) last_predictions_expanded = last_predictions.unsqueeze(-1).expand( batch_size * self.beam_size, num_classes diff --git a/tests/nn/beam_search_test.py b/tests/nn/beam_search_test.py index 88614f88e9d..4fcd892ab91 100644 --- a/tests/nn/beam_search_test.py +++ b/tests/nn/beam_search_test.py @@ -27,6 +27,18 @@ ] # end token -> jth token ) +# A transition matrix that favors shorter sequences over longer ones +short_sequence_transition_probabilities = torch.tensor( + [ + [0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # start token -> jth token + [0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1st token -> jth token + [0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2nd token -> jth token + [0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # ... + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ... + [0.2, 0.1, 0.2, 0.2, 0.2, 0.3], + ] # end token -> jth token +) + log_probabilities = torch.log( torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]]) ) @@ -62,6 +74,25 @@ def take_step_with_timestep( return take_step_no_timestep(last_predictions, state) +def take_short_sequence_step( + last_predictions: torch.Tensor, + state: Dict[str, torch.Tensor], + timestep: int, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Take decoding step. + + This method is the same as `take_step_no_timestep` except it uses the + `short_sequence_transition_probabilities` transitions instead of `transition_probabilities` + """ + log_probs_list = [] + for last_token in last_predictions: + log_probs = torch.log(short_sequence_transition_probabilities[last_token.item()]) + log_probs_list.append(log_probs) + + return torch.stack(log_probs_list), state + + class BeamSearchTest(AllenNlpTestCase): def setup_method(self): super().setup_method() @@ -101,7 +132,7 @@ def _check_results( # log_probs should be shape `(batch_size, beam_size, max_predicted_length)`. assert list(log_probs.size()) == [batch_size, beam_size] - np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs) + np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6) @pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep]) def test_search(self, step_function): @@ -211,6 +242,68 @@ def test_early_stopping(self): beam_search=beam_search, ) + def test_take_short_sequence_step(self): + """ + Tests to ensure the top-k from the short_sequence_transition_probabilities + transition matrix is expected + """ + self.beam_search.beam_size = 5 + expected_top_k = np.array( + [[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]] + ) + expected_log_probs = np.log(np.array([0.9, 0.09, 0.009, 0.0009, 0.0001])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + def test_min_steps(self): + """ + Tests to ensure all output sequences are greater than a specified minimum length. + It uses the `take_short_sequence_step` step function, which favors shorter sequences. + See `test_take_short_sequence_step`. + """ + self.beam_search.beam_size = 1 + + # An empty sequence is allowed under this step function + self.beam_search.min_steps = 0 + expected_top_k = np.array([[5]]) + expected_log_probs = np.log(np.array([0.9])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.min_steps = 1 + expected_top_k = np.array([[1, 5]]) + expected_log_probs = np.log(np.array([0.09])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.min_steps = 2 + expected_top_k = np.array([[1, 2, 5]]) + expected_log_probs = np.log(np.array([0.009])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + + self.beam_search.beam_size = 3 + self.beam_search.min_steps = 2 + expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]) + expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) + def test_different_per_node_beam_size(self): # per_node_beam_size = 1 beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1) From 9de5b4e0ee1ae1a648185ef11a1feb4c8484c0ac Mon Sep 17 00:00:00 2001 From: Daniel Deutsch <danieldeutsch@users.noreply.github.com> Date: Tue, 18 May 2021 12:28:51 -0400 Subject: [PATCH 26/63] Implementing abstraction to score final sequences in `BeamSearch` (#5208) * Implementing FinalSequenceScorer in BeamSearch * Including the end token in the normalization * Reformating * Apply suggestions from code review Co-authored-by: Pete <epwalsh10@gmail.com> * Sorting the sequences by the final scores Co-authored-by: Pete <petew@allenai.org> Co-authored-by: Pete <epwalsh10@gmail.com> --- CHANGELOG.md | 1 + allennlp/nn/beam_search.py | 120 ++++++++++++++++++++++++++++++++++- tests/nn/beam_search_test.py | 58 +++++++++++++++++ 3 files changed, 176 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a316a6ce2c4..0f8a9b13503 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details. - Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths. - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. +- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. ### Fixed diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py index f1d43226a83..4337e3efc4a 100644 --- a/allennlp/nn/beam_search.py +++ b/allennlp/nn/beam_search.py @@ -431,6 +431,99 @@ def gumbel_with_max(self, phi, T) -> torch.Tensor: return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs())) +class FinalSequenceScorer(Registrable): + """ + An abstract class that can be used to score the final generated sequences found + by beam search. Given the predicted sequences and the corresponding log probabilities of + those sequences, the class calculates and returns the final score of the sequences. + + The default implementation scores the sequences using the sum of the log probabilities of + the sequence, which is passed as input. + """ + + default_implementation = "sequence-log-prob" + + def score( + self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int + ) -> torch.Tensor: + """ + Score the final predictions found by beam search. + + # Parameters + + predictions : `torch.Tensor` + A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`. + + log_probabilities : `torch.Tensor` + A tensor containing the log probabilities of the sequence, defined as the sum + of the log probabilities per token, with shape `(batch_size, beam_size)`. + + end_index : `int` + The index of the end symbol. + + # Returns + + `torch.Tensor` + A tensor of the final sequence scores of shape `(batch_size, beam_size)`. + """ + raise NotImplementedError + + +@FinalSequenceScorer.register("sequence-log-prob") +class SequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A `FinalSequenceScorer` which scores the sequences by the sum of the log probabilities + across the sequence's tokens. + """ + + @overrides + def score( + self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int + ) -> torch.Tensor: + # The sum of the sequence log probabilities is the input parameter, so just + # return it. + return log_probabilities + + +@FinalSequenceScorer.register("length-normalized-sequence-log-prob") +class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer): + """ + A `FinalSequenceScorer` which scores the sequences by the average log probability of the + tokens in the sequence. It optionally includes a length penalty which promotes + or demotes sequences based on their lengths. The final score for a sequence will + be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length + here includes the end token. + + # Parameters + + length_penalty : `float`, optional (default = `1.0`) + The length penalty to use. A value of 1.0 means no length penalty is used. + A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences. + """ + + def __init__(self, length_penalty: float = 1.0): + super().__init__() + self.length_penalty = length_penalty + + @overrides + def score( + self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int + ) -> torch.Tensor: + # shape: (batch_size, beam_size) + lengths = (predictions != end_index).long().sum(dim=2) + + # If the sequence ended during beam search, the `log_probabilities` will include + # the transition to the end token. Therefore, in such situations, `lengths` is + # actually off by 1. This corrects for that. + # shape: (batch_size, beam_size) + is_end_token = predictions[:, :, -1] == end_index + lengths += is_end_token.long() + + # shape: (batch_size, beam_size) + average_log_probs = log_probabilities / (lengths ** self.length_penalty) + return average_log_probs + + class BeamSearch(FromParams): """ Implements the beam search algorithm for decoding the most likely sequences. @@ -467,6 +560,12 @@ class BeamSearch(FromParams): The minimum number of decoding steps to take, i.e. the minimum length of the predicted sequences. This does not include the start or end tokens. If `None`, no minimum is enforced. + + final_sequence_scorer : `FinalSequenceScorer`, optional (default = `None`) + An optional `FinalSequenceScorer` which is used to score the final generated sequences. + The output from this module is what is returned by the `search` method. If not + specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences + by the sum of the token log probabilities. """ def __init__( @@ -477,6 +576,7 @@ def __init__( per_node_beam_size: int = None, sampler: Sampler = None, min_steps: Optional[int] = None, + final_sequence_scorer: FinalSequenceScorer = None, ) -> None: if not max_steps > 0: raise ValueError("max_steps must be positive") @@ -496,6 +596,7 @@ def __init__( self.per_node_beam_size = per_node_beam_size or beam_size self.sampler = sampler or DeterministicSampler() self.min_steps = min_steps or 0 + self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() @staticmethod def _reconstruct_sequences(predictions, backpointers): @@ -580,8 +681,8 @@ def search( # Returns `Tuple[torch.Tensor, torch.Tensor]` - Tuple of `(predictions, log_probabilities)`, where `predictions` - has shape `(batch_size, beam_size, max_steps)` and `log_probabilities` + Tuple of `(predictions, final_scores)`, where `predictions` + has shape `(batch_size, beam_size, max_steps)` and `final_scores` has shape `(batch_size, beam_size)`. """ step_signature = signature(step) @@ -786,7 +887,20 @@ def _search( # shape: (batch_size, beam_size, max_steps) all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) - return all_predictions, last_log_probabilities + # Calculate the final sequence scores + # shape: (batch_size, beam_size) + final_scores = self.final_sequence_scorer.score( + all_predictions, last_log_probabilities, self._end_index + ) + + # Sort the sequences based on the final scores so the best scoring + # sequence is at index 0 + sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True) + sorted_all_predictions = torch.gather( + all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions) + ) + + return sorted_all_predictions, sorted_final_scores @staticmethod def _is_multilayer_rnn_decoder(key: str, state_tensor: torch.Tensor) -> bool: diff --git a/tests/nn/beam_search_test.py b/tests/nn/beam_search_test.py index 4fcd892ab91..275390cc135 100644 --- a/tests/nn/beam_search_test.py +++ b/tests/nn/beam_search_test.py @@ -12,6 +12,8 @@ TopKSampler, TopPSampler, GumbelSampler, + SequenceLogProbabilityScorer, + LengthNormalizedSequenceLogProbabilityScorer, ) from allennlp.common.params import Params @@ -538,3 +540,59 @@ def test_gumbel_sampler(self): assert all([x >= 0 and x < 4 for x in indices[0]]) assert all([x > 1 and x <= 5 for x in indices[1]]) + + def test_sequence_log_prob_scorer(self): + # SequenceLogProbabilityScorer is the default, so manually setting the + # sequence scorer shouldn't actually change anything + self.beam_search.sequence_scorer = SequenceLogProbabilityScorer() + + def test_length_normalized_sequence_log_prob_scorer(self): + """ + Tests to ensure the sequences are normalized by the correct values. The end token is + included in the length. The start token is not. + """ + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer() + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array([5, 4, 3]) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_log_probs=expected_scores) + + # Introduce a length penalty + length_penalty = 2.0 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array( + [5 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty] + ) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_log_probs=expected_scores) + + # Pick a length penalty so extreme that the order of the sequences is reversed + length_penalty = -2.0 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]]) + expected_log_probs = np.log(np.array([0.2, 0.3, 0.4])) + length_normalization = np.array( + [3 ** length_penalty, 4 ** length_penalty, 5 ** length_penalty] + ) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) + + # Here, we set the max_steps = 4. This prevents the first sequence from finishing, + # so its length does not include the end token, whereas the other sequences do. + length_penalty = 2.0 + self.beam_search.max_steps = 4 + self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer( + length_penalty=length_penalty + ) + expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]]) + expected_log_probs = np.log(np.array([0.4, 0.3, 0.2])) + length_normalization = np.array( + [4 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty] + ) + expected_scores = expected_log_probs / length_normalization + self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) From 56606701f95d1eb888f9291c989b3ab8a3eb9c1d Mon Sep 17 00:00:00 2001 From: ArjunSubramonian <arjun.subramonian@gmail.com> Date: Wed, 19 May 2021 12:18:59 -0700 Subject: [PATCH 27/63] added shuffle disable option in BucketBatchSampler (#5212) * added shuffle disable option in BucketBatchSampler * Update allennlp/data/samplers/bucket_batch_sampler.py Co-authored-by: Pete <petew@allenai.org> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal> Co-authored-by: Pete <petew@allenai.org> --- CHANGELOG.md | 1 + allennlp/data/samplers/bucket_batch_sampler.py | 11 ++++++++++- tests/data/samplers/bucket_batch_sampler_test.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f8a9b13503..79457e13f8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths. - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. +- Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. ### Fixed diff --git a/allennlp/data/samplers/bucket_batch_sampler.py b/allennlp/data/samplers/bucket_batch_sampler.py index d65a676f14c..e4aa125741f 100644 --- a/allennlp/data/samplers/bucket_batch_sampler.py +++ b/allennlp/data/samplers/bucket_batch_sampler.py @@ -57,6 +57,10 @@ class BucketBatchSampler(BatchSampler): If `True`, the sampler will drop the last batch if its size would be less than batch_size`. + shuffle : `bool`, (default = `True`) + If `False`, the sampler won't shuffle the batches. `padding_noise` will be ignored and set + to `0.0`. + """ def __init__( @@ -65,11 +69,15 @@ def __init__( sorting_keys: List[str] = None, padding_noise: float = 0.1, drop_last: bool = False, + shuffle: bool = True, ): self.sorting_keys = sorting_keys self.padding_noise = padding_noise self.batch_size = batch_size self.drop_last = drop_last + self.shuffle = shuffle + if not shuffle: + self.padding_noise = 0.0 def _argsort_by_padding( self, instances: Iterable[Instance] @@ -113,7 +121,8 @@ def get_batch_indices(self, instances: Sequence[Instance]) -> Iterable[List[int] if self.drop_last and len(batch_indices) < self.batch_size: continue batches.append(batch_indices) - random.shuffle(batches) + if self.shuffle: + random.shuffle(batches) for batch in batches: yield batch diff --git a/tests/data/samplers/bucket_batch_sampler_test.py b/tests/data/samplers/bucket_batch_sampler_test.py index 3a972facdc2..450c825cc3c 100644 --- a/tests/data/samplers/bucket_batch_sampler_test.py +++ b/tests/data/samplers/bucket_batch_sampler_test.py @@ -24,6 +24,20 @@ def test_create_batches_groups_correctly(self): expected_groups.remove(group) assert expected_groups == [] + def test_disable_shuffle(self): + sampler = BucketBatchSampler(batch_size=2, sorting_keys=["text"], shuffle=False) + + grouped_instances = [] + for indices in sampler.get_batch_indices(self.instances): + grouped_instances.append([self.instances[idx] for idx in indices]) + expected_groups = [ + [self.instances[4], self.instances[2]], + [self.instances[0], self.instances[1]], + [self.instances[3]], + ] + for idx, group in enumerate(grouped_instances): + assert group == expected_groups[idx] + def test_guess_sorting_key_picks_the_longest_key(self): sampler = BucketBatchSampler(batch_size=2, padding_noise=0) instances = [] From 73e570b352278995f5cdd2be7dbfa633edc54693 Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Wed, 19 May 2021 13:34:00 -0700 Subject: [PATCH 28/63] save meta data with model archives (#5209) Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> --- CHANGELOG.md | 2 ++ allennlp/commands/train.py | 4 +++ allennlp/common/__init__.py | 1 + allennlp/common/meta.py | 37 ++++++++++++++++++++++ allennlp/models/archival.py | 59 +++++++++++++++++++++++++++++++++-- tests/commands/train_test.py | 7 ++++- tests/models/archival_test.py | 19 ++++++++++- 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 allennlp/common/meta.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 79457e13f8f..739dff1071e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. +- Meta data defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file + when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()`. - Added `nn.util.distributed_device()` helper function. - Added `allennlp.nn.util.load_state_dict` helper function. - Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py index c844c222ae4..5304e6e0735 100644 --- a/allennlp/commands/train.py +++ b/allennlp/commands/train.py @@ -19,6 +19,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common import Params, Registrable, Lazy from allennlp.common.checks import check_for_gpu, ConfigurationError +from allennlp.common.meta import Meta, META_NAME from allennlp.common import logging as common_logging from allennlp.common import util as common_util from allennlp.common.plugins import import_plugins @@ -226,6 +227,9 @@ def train_model( training_util.create_serialization_dir(params, serialization_dir, recover, force) params.to_file(os.path.join(serialization_dir, CONFIG_NAME)) + meta = Meta.new() + meta.to_file(os.path.join(serialization_dir, META_NAME)) + include_in_archive = params.pop("include_in_archive", None) verify_include_in_archive(include_in_archive) diff --git a/allennlp/common/__init__.py b/allennlp/common/__init__.py index 8c5857e5097..c865c04b0fc 100644 --- a/allennlp/common/__init__.py +++ b/allennlp/common/__init__.py @@ -4,3 +4,4 @@ from allennlp.common.registrable import Registrable from allennlp.common.tqdm import Tqdm from allennlp.common.util import JsonDict +from allennlp.common.meta import Meta diff --git a/allennlp/common/meta.py b/allennlp/common/meta.py new file mode 100644 index 00000000000..787442c2786 --- /dev/null +++ b/allennlp/common/meta.py @@ -0,0 +1,37 @@ +from os import PathLike +from dataclasses import dataclass, asdict +import json +import logging +from typing import Union + +from allennlp.version import VERSION + + +logger = logging.getLogger(__name__) + + +META_NAME = "meta.json" + + +@dataclass +class Meta: + """ + Defines the meta data that's saved in a serialization directory and archive + when training an AllenNLP model. + """ + + version: str + + @classmethod + def new(cls) -> "Meta": + return cls(version=VERSION) + + def to_file(self, path: Union[PathLike, str]) -> None: + with open(path, "w") as meta_file: + json.dump(asdict(self), meta_file) + + @classmethod + def from_path(cls, path: Union[PathLike, str]) -> "Meta": + with open(path) as meta_file: + data = json.load(meta_file) + return cls(**data) diff --git a/allennlp/models/archival.py b/allennlp/models/archival.py index 027f1275c5f..9341ee4338d 100644 --- a/allennlp/models/archival.py +++ b/allennlp/models/archival.py @@ -2,7 +2,7 @@ Helper functions for archiving models and restoring archived models. """ from os import PathLike -from typing import NamedTuple, Union, Dict, Any, List, Optional +from typing import Tuple, NamedTuple, Union, Dict, Any, List, Optional import logging import os import tempfile @@ -10,11 +10,14 @@ import shutil from contextlib import contextmanager import glob +import warnings from torch.nn import Module +from allennlp.version import VERSION, _MAJOR, _MINOR, _PATCH from allennlp.common.checks import ConfigurationError from allennlp.common.file_utils import cached_path +from allennlp.common.meta import Meta, META_NAME from allennlp.common.params import Params from allennlp.data.dataset_readers import DatasetReader from allennlp.models.model import Model, _DEFAULT_WEIGHTS @@ -29,6 +32,7 @@ class Archive(NamedTuple): config: Params dataset_reader: DatasetReader validation_dataset_reader: DatasetReader + meta: Optional[Meta] def extract_module(self, path: str, freeze: bool = True) -> Module: """ @@ -90,12 +94,13 @@ def extract_module(self, path: str, freeze: bool = True) -> Module: # These constants are the *known names* under which we archive them. CONFIG_NAME = "config.json" _WEIGHTS_NAME = "weights.th" +_VERSION_TUPLE = (_MAJOR, _MINOR, _PATCH) def verify_include_in_archive(include_in_archive: Optional[List[str]] = None): if include_in_archive is None: return - saved_names = [CONFIG_NAME, _WEIGHTS_NAME, _DEFAULT_WEIGHTS, "vocabulary"] + saved_names = [CONFIG_NAME, _WEIGHTS_NAME, _DEFAULT_WEIGHTS, META_NAME, "vocabulary"] for archival_target in include_in_archive: if archival_target in saved_names: raise ConfigurationError( @@ -133,6 +138,9 @@ def archive_model( config_file = os.path.join(serialization_dir, CONFIG_NAME) if not os.path.exists(config_file): logger.error("config file %s does not exist, unable to archive model", config_file) + return + + meta_file = os.path.join(serialization_dir, META_NAME) if archive_path is not None: archive_file = archive_path @@ -140,11 +148,16 @@ def archive_model( archive_file = os.path.join(archive_file, "model.tar.gz") else: archive_file = os.path.join(serialization_dir, "model.tar.gz") + logger.info("archiving weights and vocabulary to %s", archive_file) with tarfile.open(archive_file, "w:gz") as archive: archive.add(config_file, arcname=CONFIG_NAME) archive.add(weights_file, arcname=_WEIGHTS_NAME) archive.add(os.path.join(serialization_dir, "vocabulary"), arcname="vocabulary") + if os.path.exists(meta_file): + archive.add(meta_file, arcname=META_NAME) + else: + logger.warning("meta file %s does not exist", meta_file) if include_in_archive is not None: for archival_target in include_in_archive: @@ -184,6 +197,8 @@ def load_archive( else: logger.info(f"loading archive file {archive_file} from cache at {resolved_archive_file}") + meta: Optional[Meta] = None + tempdir = None try: if os.path.isdir(resolved_archive_file): @@ -205,16 +220,26 @@ def load_archive( config.duplicate(), serialization_dir ) model = _load_model(config.duplicate(), weights_path, serialization_dir, cuda_device) + + # Load meta. + meta_path = os.path.join(serialization_dir, META_NAME) + if os.path.exists(meta_path): + meta = Meta.from_path(meta_path) finally: if tempdir is not None: logger.info(f"removing temporary unarchived model dir at {tempdir}") shutil.rmtree(tempdir, ignore_errors=True) + # Check version compatibility. + if meta is not None: + _check_version_compatibility(archive_file, meta) + return Archive( model=model, config=config, dataset_reader=dataset_reader, validation_dataset_reader=validation_dataset_reader, + meta=meta, ) @@ -267,3 +292,33 @@ def extracted_archive(resolved_archive_file, cleanup=True): if tempdir is not None and cleanup: logger.info(f"removing temporary unarchived model dir at {tempdir}") shutil.rmtree(tempdir, ignore_errors=True) + + +def _parse_version(version: str) -> Tuple[str, str, str]: + """ + Parse a version string into a (major, minor, patch). + """ + try: + major, minor, patch = version.split(".")[:3] + except ValueError: + raise ValueError(f"Invalid version '{version}', unable to parse") + return (major, minor, patch) + + +def _check_version_compatibility(archive_file: Union[PathLike, str], meta: Meta): + meta_version_tuple = _parse_version(meta.version) + # Warn if current version is behind the version the model was trained on. + if _VERSION_TUPLE < meta_version_tuple: + warnings.warn( + f"The model {archive_file} was trained on a newer version of AllenNLP (v{meta.version}), " + f"but you're using version {VERSION}.", + UserWarning, + ) + # Warn if major versions differ since there is no guarantee of backwards + # compatibility across major releases. + elif _VERSION_TUPLE[0] != meta_version_tuple[0]: + warnings.warn( + f"The model {archive_file} was trained on version {meta.version} of AllenNLP, " + f"but you're using {VERSION} which may not be compatible.", + UserWarning, + ) diff --git a/tests/commands/train_test.py b/tests/commands/train_test.py index 1c32955b469..3a0d913a49b 100644 --- a/tests/commands/train_test.py +++ b/tests/commands/train_test.py @@ -13,6 +13,7 @@ import pytest import torch +from allennlp.version import VERSION from allennlp.commands.train import Train, train_model, train_model_from_args, TrainModel from allennlp.common import Params from allennlp.common.checks import ConfigurationError @@ -109,7 +110,11 @@ class TestTrain(AllenNlpTestCase): def test_train_model(self): params = lambda: copy.deepcopy(self.DEFAULT_PARAMS) - train_model(params(), serialization_dir=os.path.join(self.TEST_DIR, "test_train_model")) + serialization_dir = os.path.join(self.TEST_DIR, "test_train_model") + train_model(params(), serialization_dir=serialization_dir) + archive = load_archive(os.path.join(serialization_dir, "model.tar.gz")) + assert archive.meta is not None + assert archive.meta.version == VERSION # It's OK if serialization dir exists but is empty: serialization_dir2 = os.path.join(self.TEST_DIR, "empty_directory") diff --git a/tests/models/archival_test.py b/tests/models/archival_test.py index 9e7ed2fee31..4a40588bc13 100644 --- a/tests/models/archival_test.py +++ b/tests/models/archival_test.py @@ -6,12 +6,19 @@ import pytest import torch +from allennlp.version import _MAJOR, _MINOR from allennlp.commands.train import train_model from allennlp.common import Params +from allennlp.common.meta import Meta from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase from allennlp.data.dataset_readers import DatasetReader -from allennlp.models.archival import archive_model, load_archive, CONFIG_NAME +from allennlp.models.archival import ( + archive_model, + load_archive, + CONFIG_NAME, + _check_version_compatibility, +) def assert_models_equal(model, model2): @@ -32,6 +39,16 @@ def assert_models_equal(model, model2): assert vocab._index_to_token == vocab2._index_to_token +def _test_check_version_compatibility(): + meta = Meta(version=f"{_MAJOR}.{int(_MINOR) + 1}.0") + with pytest.warns(UserWarning, match="trained on a newer version"): + _check_version_compatibility("model.tar.gz", meta) + + meta = Meta(version="1.2.0") + with pytest.warns(UserWarning, match="trained on version"): + _check_version_compatibility("model.tar.gz", meta) + + class ArchivalTest(AllenNlpTestCase): def setup_method(self): super().setup_method() From f3aeeeb38284506d44589b51bb6ddbff197bd664 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld <dirkg@allenai.org> Date: Wed, 19 May 2021 17:19:29 -0700 Subject: [PATCH 29/63] Formatting --- .../huggingface_datasets_reader.py | 35 ++++++++++++++----- .../huggingface_datasets_reader_test.py | 4 +-- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 99cabe40548..20e7b41b946 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -133,16 +133,24 @@ def _map_Feature( fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, entry[feature_name]) # datasets Value can be of different types elif isinstance(feature_type, Value): - fields_to_be_added[feature_name] = _map_Value(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added[feature_name] = _map_Value( + feature_name, entry[feature_name], feature_type, tokenizer + ) elif isinstance(feature_type, Sequence): - fields_to_be_added[feature_name] = _map_Sequence(feature_name, entry, feature_type.feature, tokenizer) + fields_to_be_added[feature_name] = _map_Sequence( + feature_name, entry, feature_type.feature, tokenizer + ) elif isinstance(feature_type, Translation): - fields_to_be_added = _map_Translation(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added = _map_Translation( + feature_name, entry[feature_name], feature_type, tokenizer + ) elif isinstance(feature_type, TranslationVariableLanguages): - fields_to_be_added = _map_TranslationVariableLanguages(feature_name, entry[feature_name], feature_type, tokenizer) + fields_to_be_added = _map_TranslationVariableLanguages( + feature_name, entry[feature_name], feature_type, tokenizer + ) else: raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") @@ -163,12 +171,14 @@ def _map_Value( # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token field = _map_String(value, tokenizer) - else: field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field -def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Optional[Tokenizer]) -> Field: + +def _map_Sequence( + feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] +) -> Field: field_list: List[Field] = list() field: ListField = None if isinstance(item_feature_type, Value): @@ -178,7 +188,7 @@ def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Opt item_field = _map_Value(value.feature, item, item.value, tokenizer) field_list.append(item_field) if len(field_list) > 0: - field = ListField(field_list) + field = ListField(field_list) # datasets Sequence of strings to ListField of LabelField elif isinstance(item_feature_type, ClassLabel): @@ -202,6 +212,7 @@ def _map_Sequence(feature_name, value:Sequence, item_feature_type, tokenizer:Opt return field + def _map_Translation( feature_name: str, value: Translation, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: @@ -219,7 +230,10 @@ def _map_Translation( texts.append(TextField(tokens)) fields[feature_name + "-languages"] = ListField( - [_map_to_Label(feature_name + "-languages", lang, skip_indexing=False) for lang in langs] + [ + _map_to_Label(feature_name + "-languages", lang, skip_indexing=False) + for lang in langs + ] ) fields[feature_name + "-texts"] = ListField(texts) @@ -230,7 +244,10 @@ def _map_Translation( def _map_TranslationVariableLanguages( - feature_name: str, value: TranslationVariableLanguages, feature_type, tokenizer: Optional[Tokenizer] + feature_name: str, + value: TranslationVariableLanguages, + feature_type, + tokenizer: Optional[Tokenizer], ) -> Dict[str, Field]: fields: Dict[str, Field] = dict() if feature_type.dtype == "dict": diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index b261188d4df..3ff5eef1389 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -8,7 +8,7 @@ # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset # TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it # the way tested for conll2003 specific reader -from datasets import list_datasets, load_dataset +from datasets import list_datasets class HuggingfaceDatasetReaderTest: @@ -166,5 +166,3 @@ def test_load_all(self): reader.read() except Exception as e: print(e) - - From d6c7769e782d1f2042d8f097d367658403c82f1e Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 20 May 2021 20:19:00 +0530 Subject: [PATCH 30/63] Comments addressed --- .../huggingface_datasets_reader.py | 30 ++++++++--------- .../huggingface_datasets_reader_test.py | 33 +++++++++++-------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 20e7b41b946..930678d4b57 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -32,8 +32,6 @@ class HuggingfaceDatasetReader(DatasetReader): This is useful since text in allennlp is dealt with as a series of tokens. """ - SUPPORTED_SPLITS = [Split.TRAIN, Split.TEST, Split.VALIDATION] - def __init__( self, dataset_name: str = None, @@ -55,17 +53,13 @@ def __init__( self.config_name = config_name self.tokenizer = tokenizer + self.features = None + def load_dataset_split(self, split: str): - # TODO add support for datasets.split.NamedSplit - if split in self.SUPPORTED_SPLITS: - if self.config_name is not None: - self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split) - else: - self.dataset[split] = load_dataset(self.dataset_name, split=split) + if self.config_name is not None: + self.dataset[split] = load_dataset(self.dataset_name, self.config_name, split=split) else: - raise ValueError( - f"Only default splits:{self.SUPPORTED_SPLITS} are currently supported." - ) + self.dataset[split] = load_dataset(self.dataset_name, split=split) def _read(self, file_path: str) -> Iterable[Instance]: """ @@ -77,6 +71,8 @@ def _read(self, file_path: str) -> Iterable[Instance]: # If split is not loaded, load the specific split if file_path not in self.dataset: self.load_dataset_split(file_path) + if self.features is None: + self.features = self.dataset[file_path].features # TODO see if use of Dataset.select() is better dataset_split = self.dataset[file_path] @@ -86,7 +82,7 @@ def _read(self, file_path: str) -> Iterable[Instance]: def raise_feature_not_supported_value_error(value): raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") - def text_to_instance(self, *inputs) -> Instance: + def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ Takes care of converting dataset entry into AllenNLP friendly instance @@ -106,7 +102,6 @@ def text_to_instance(self, *inputs) -> Instance: # e.g. In a Sentiment dataset an entry could have one feature (of type text/string) indicating the text # and another indicate the sentiment (of type int32/ClassLabel) - split = inputs[0] features: Dict[str, FeatureType] = self.dataset[split].features fields: Dict[str, Field] = dict() @@ -117,7 +112,7 @@ def text_to_instance(self, *inputs) -> Instance: field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, inputs[1], feature_type, self.tokenizer) + fields_to_be_added = _map_Feature(feature_name, entry, feature_type, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -178,9 +173,10 @@ def _map_Value( def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] -) -> Field: +) -> Union[ListField]: field_list: List[Field] = list() - field: ListField = None + field: ListField + item_field: Field if isinstance(item_feature_type, Value): for item in value: # If tokenizer is provided we will use it to split it to tokens @@ -201,7 +197,7 @@ def _map_Sequence( elif isinstance(item_feature_type, Sequence): for item in value: - item_field = _map_Sequence(value.feature, item, tokenizer) + item_field = _map_Sequence(value.feature, item, item_feature_type.feature, tokenizer) field_list.append(item_field) if len(field_list) > 0: diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 3ff5eef1389..138b65be50d 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -1,16 +1,14 @@ import pytest +from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.util import ensure_list from allennlp.data import Tokenizer +from allennlp.data.dataset_readers import Conll2003DatasetReader from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader from allennlp.data.tokenizers import WhitespaceTokenizer # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset -# TODO pab-vmware/Abhishek-P Add test where we load conll2003 and test it -# the way tested for conll2003 specific reader -from datasets import list_datasets - - class HuggingfaceDatasetReaderTest: """ @@ -158,11 +156,20 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_load_all(self): - for dataset_name in list_datasets(): - try: - print("Dataset:", dataset_name) - reader = HuggingfaceDatasetReader(dataset_name) - reader.read() - except Exception as e: - print(e) + def test_read_from_file_with_deprecated_parameter(self): + conll_reader = HuggingfaceDatasetReader("conll2003") + instances = ensure_list( + conll_reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" / "conll2003.txt") + ) + + expected_labels = ["I-ORG", "O", "I-PER", "O", "O", "I-LOC", "O"] + + fields = instances[0].fields + tokens = [t.text for t in fields["tokens"].tokens] + assert tokens == ["U.N.", "official", "Ekeus", "heads", "for", "Baghdad", "."] + assert fields["tags"].labels == expected_labels + + fields = instances[1].fields + tokens = [t.text for t in fields["tokens"].tokens] + assert tokens == ["AI2", "engineer", "Joel", "lives", "in", "Seattle", "."] + assert fields["tags"].labels == expected_labels From 79f58a823e7726119d71bcc34926e00ae5d3794f Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 23 May 2021 21:00:24 +0530 Subject: [PATCH 31/63] Formatting --- allennlp/data/dataset_readers/huggingface_datasets_reader.py | 2 +- tests/data/dataset_readers/huggingface_datasets_reader_test.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 930678d4b57..ed8feba6a85 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,7 +1,7 @@ from allennlp.data import DatasetReader, Token, Field, Tokenizer from allennlp.data.fields import TextField, LabelField, ListField from allennlp.data.instance import Instance -from datasets import load_dataset, DatasetDict, Split, list_datasets +from datasets import load_dataset, DatasetDict, list_datasets from datasets.features import ( ClassLabel, Sequence, diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 138b65be50d..d679e2d4ff8 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -2,7 +2,6 @@ from allennlp.common.testing import AllenNlpTestCase from allennlp.common.util import ensure_list from allennlp.data import Tokenizer -from allennlp.data.dataset_readers import Conll2003DatasetReader from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader from allennlp.data.tokenizers import WhitespaceTokenizer From a55a7ba91b58d16428ac337796b5c841ba5bd53c Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Sun, 23 May 2021 21:17:52 +0530 Subject: [PATCH 32/63] removed invalid conll test --- .../huggingface_datasets_reader_test.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index d679e2d4ff8..27408444e13 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -155,20 +155,3 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - def test_read_from_file_with_deprecated_parameter(self): - conll_reader = HuggingfaceDatasetReader("conll2003") - instances = ensure_list( - conll_reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" / "conll2003.txt") - ) - - expected_labels = ["I-ORG", "O", "I-PER", "O", "O", "I-LOC", "O"] - - fields = instances[0].fields - tokens = [t.text for t in fields["tokens"].tokens] - assert tokens == ["U.N.", "official", "Ekeus", "heads", "for", "Baghdad", "."] - assert fields["tags"].labels == expected_labels - - fields = instances[1].fields - tokens = [t.text for t in fields["tokens"].tokens] - assert tokens == ["AI2", "engineer", "Joel", "lives", "in", "Seattle", "."] - assert fields["tags"].labels == expected_labels From 81d0409cd99240b9e915d05b9117e2856079084c Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 25 May 2021 22:18:52 +0530 Subject: [PATCH 33/63] Regression Fix --- .../dataset_readers/huggingface_datasets_reader.py | 7 ++++--- .../huggingface_datasets_reader_test.py | 11 ++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index ed8feba6a85..98cadd2815b 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -134,7 +134,7 @@ def _map_Feature( elif isinstance(feature_type, Sequence): fields_to_be_added[feature_name] = _map_Sequence( - feature_name, entry, feature_type.feature, tokenizer + feature_name, entry[feature_name], feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): @@ -177,11 +177,12 @@ def _map_Sequence( field_list: List[Field] = list() field: ListField item_field: Field + # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): for item in value: # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token - item_field = _map_Value(value.feature, item, item.value, tokenizer) + item_field = _map_Value(feature_name, item, item_feature_type, tokenizer) field_list.append(item_field) if len(field_list) > 0: field = ListField(field_list) @@ -189,7 +190,7 @@ def _map_Sequence( # datasets Sequence of strings to ListField of LabelField elif isinstance(item_feature_type, ClassLabel): for item in value: - item_field = _map_to_Label(value.feature, item, skip_indexing=True) + item_field = _map_to_Label(feature_name, item, skip_indexing=True) field_list.append(item_field) if len(field_list) > 0: diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 27408444e13..33cfe4ef8c5 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -1,6 +1,4 @@ import pytest -from allennlp.common.testing import AllenNlpTestCase -from allennlp.common.util import ensure_list from allennlp.data import Tokenizer from allennlp.data.dataset_readers.huggingface_datasets_reader import HuggingfaceDatasetReader @@ -103,8 +101,7 @@ def test_read_with_invalid_split(self, split): Test to help validate for the known supported datasets Skipped by default, enable when required """ - - @pytest.mark.skip() + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset, config, split", ( @@ -138,7 +135,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) Skipped by default, enable when required """ - @pytest.mark.skip() + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset", (("swahili"), ("conll2003"), ("dbpedia_14"), ("trec"), ("emotion")) ) @@ -155,3 +152,7 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) + # def test_air_dialogue(self): + # reader = HuggingfaceDatasetReader(dataset_name="amazon_us_reviews", config_name="Apparel_v1_00") + # instances = list(reader.read("train")) + # print(instances[0]) From 5b9e0c28130d6c2eb9fbcdd55c631429b81b522f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 May 2021 03:19:58 +0530 Subject: [PATCH 34/63] Bump black from 20.8b1 to 21.5b1 (#5195) * Bump black from 20.8b1 to 21.5b1 Bumps [black](https://github.com/psf/black) from 20.8b1 to 21.5b1. - [Release notes](https://github.com/psf/black/releases) - [Changelog](https://github.com/psf/black/blob/main/CHANGES.md) - [Commits](https://github.com/psf/black/commits) Signed-off-by: dependabot[bot] <support@github.com> * formatting changes Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> --- allennlp/commands/find_learning_rate.py | 2 +- allennlp/models/archival.py | 2 +- allennlp/modules/transformer/t5.py | 4 ++-- dev-requirements.txt | 2 +- tests/data/vocabulary_test.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/allennlp/commands/find_learning_rate.py b/allennlp/commands/find_learning_rate.py index 8a1f6380ed4..9e3babba520 100644 --- a/allennlp/commands/find_learning_rate.py +++ b/allennlp/commands/find_learning_rate.py @@ -317,7 +317,7 @@ def search_learning_rate( def _smooth(values: List[float], beta: float) -> List[float]: - """ Exponential smoothing of values """ + """Exponential smoothing of values""" avg_value = 0.0 smoothed = [] for i, value in enumerate(values): diff --git a/allennlp/models/archival.py b/allennlp/models/archival.py index 9341ee4338d..e1d48fcb76f 100644 --- a/allennlp/models/archival.py +++ b/allennlp/models/archival.py @@ -26,7 +26,7 @@ class Archive(NamedTuple): - """ An archive comprises a Model and its experimental config""" + """An archive comprises a Model and its experimental config""" model: Model config: Params diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 15d34f5b2b1..206f944aae5 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -221,7 +221,7 @@ def _relative_position_bucket( return relative_buckets def compute_bias(self, query_length: int, key_length: int) -> FloatT: - """ Compute binned relative position bias """ + """Compute binned relative position bias""" context_position = torch.arange(query_length, dtype=torch.long)[:, None] memory_position = torch.arange(key_length, dtype=torch.long)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) @@ -283,7 +283,7 @@ def unshape(states): return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value) -> FloatT: - """ projects hidden states correctly to key/query states """ + """projects hidden states correctly to key/query states""" if key_value_states is None: # self-attn # (batch_size, num_heads, seq_length, dim_per_head) diff --git a/dev-requirements.txt b/dev-requirements.txt index 3ab927ec5a6..0e02f8e103a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,7 +7,7 @@ flake8 mypy==0.812 # Automatic code formatting -black==20.8b1 +black==21.5b1 # Allows generation of coverage reports with pytest. pytest-cov diff --git a/tests/data/vocabulary_test.py b/tests/data/vocabulary_test.py index d0a81c336f2..7d0280d7c28 100644 --- a/tests/data/vocabulary_test.py +++ b/tests/data/vocabulary_test.py @@ -700,7 +700,7 @@ def test_read_pretrained_words(self): ) def test_from_instances_exclusive_embeddings_file_inside_archive(self): - """ Just for ensuring there are no problems when reading pretrained tokens from an archive """ + """Just for ensuring there are no problems when reading pretrained tokens from an archive""" # Read embeddings file from archive archive_path = str(self.TEST_DIR / "embeddings-archive.zip") From 66f226bbef9226f29fca392c1ad7775b0c88c52d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 26 May 2021 04:38:59 +0530 Subject: [PATCH 35/63] Update nr-interface requirement from <0.0.4 to <0.0.6 (#5213) Updates the requirements on [nr-interface](https://git.niklasrosenstein.com/NiklasRosenstein/nr) to permit the latest version. Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 0e02f8e103a..58115480562 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -36,7 +36,7 @@ ruamel.yaml # Generating markdown files from Python modules. git+https://github.com/NiklasRosenstein/pydoc-markdown.git@f0bf8af1db4f11581c19d206d4ed1ab34b4854c1 nr.databind.core<0.0.17 -nr.interface<0.0.4 +nr.interface<0.0.6 mkdocs==1.1.2 mkdocs-material>=5.5.0,<7.2.0 From 3295bd53531e480e1433917c393074ab703ff48f Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Tue, 25 May 2021 21:58:34 -0700 Subject: [PATCH 36/63] Fix W&B callback for distributed training (#5223) * fix wandb callback for distributed training * fix * close out Co-authored-by: Dirk Groeneveld <dirkg@allenai.org> --- CHANGELOG.md | 1 + allennlp/training/callbacks/log_writer.py | 2 +- allennlp/training/callbacks/wandb.py | 34 +++++++++++++++-------- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 739dff1071e..648a0ae9bb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids. - Fixed documentation for `GradientDescentTrainer.cuda_device`. +- Fixed `wandb` callback to work in distributed training. ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22 diff --git a/allennlp/training/callbacks/log_writer.py b/allennlp/training/callbacks/log_writer.py index 34ae82d3f4d..253b35de3df 100644 --- a/allennlp/training/callbacks/log_writer.py +++ b/allennlp/training/callbacks/log_writer.py @@ -148,7 +148,7 @@ def on_batch( batch_grad_norm: Optional[float] = None, **kwargs, ) -> None: - if not is_training and not is_primary: + if not is_training or not is_primary: return None assert self.trainer is not None diff --git a/allennlp/training/callbacks/wandb.py b/allennlp/training/callbacks/wandb.py index 8ed3024aaab..5adc9f1520d 100644 --- a/allennlp/training/callbacks/wandb.py +++ b/allennlp/training/callbacks/wandb.py @@ -88,11 +88,7 @@ def __init__( self._watch_model = watch_model self._files_to_save = files_to_save - - import wandb - - self.wandb = wandb - self.wandb.init( + self._wandb_kwargs: Dict[str, Any] = dict( dir=os.path.abspath(serialization_dir), project=project, entity=entity, @@ -105,9 +101,6 @@ def __init__( **(wandb_kwargs or {}), ) - for fpath in self._files_to_save: - self.wandb.save(os.path.join(serialization_dir, fpath), base_path=serialization_dir) - @overrides def log_scalars( self, @@ -122,7 +115,7 @@ def log_tensors( self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None ) -> None: self._log( - {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()}, + {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()}, # type: ignore log_prefix=log_prefix, epoch=epoch, ) @@ -134,12 +127,31 @@ def _log( dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()} if epoch is not None: dict_to_log["epoch"] = epoch - self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore[union-attr] + self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore @overrides def on_start( self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs ) -> None: super().on_start(trainer, is_primary=is_primary, **kwargs) + + if not is_primary: + return None + + import wandb + + self.wandb = wandb + self.wandb.init(**self._wandb_kwargs) + + for fpath in self._files_to_save: + self.wandb.save( # type: ignore + os.path.join(self.serialization_dir, fpath), base_path=self.serialization_dir + ) + if self._watch_model: - self.wandb.watch(self.trainer.model) # type: ignore[union-attr] + self.wandb.watch(self.trainer.model) # type: ignore + + @overrides + def close(self) -> None: + super().close() + self.wandb.finish() # type: ignore From 19d2a8705d2d8ae5c084e199beb5b76c0a98b886 Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Wed, 26 May 2021 12:01:39 -0700 Subject: [PATCH 37/63] cancel redundant GH Actions workflows (#5226) * cancel redundant GH Actions workflows * trigger CI * fix job conditions * run docker jobs on any self-hosted --- .github/workflows/ci.yml | 17 +++++++++++------ Makefile | 6 ++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cdcebbfff40..a53b1b6e12b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,9 @@ name: CI +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + on: pull_request: branches: @@ -81,22 +85,22 @@ jobs: pip freeze - name: Format - if: always() + if: '! cancelled()' run: | make format - name: Lint - if: always() + if: '! cancelled()' run: | make lint - name: Type check - if: always() + if: '! cancelled()' run: | make typecheck - name: Run tests - if: always() + if: '! cancelled()' run: | make test-with-cov @@ -270,7 +274,8 @@ jobs: docker: name: Docker (CUDA ${{ matrix.cuda }}) if: github.repository == 'allenai/allennlp' - runs-on: [self-hosted, GPU] + # Run on self-hosted to utilize layer caching. + runs-on: [self-hosted] strategy: matrix: cuda: ['10.1', '10.2', '11.1'] @@ -310,7 +315,7 @@ jobs: - name: Test image run: | - make docker-run DOCKER_IMAGE_NAME=$DOCKER_IMAGE_NAME ARGS='test-install' + make docker-run DOCKER_GPUS='' DOCKER_IMAGE_NAME=$DOCKER_IMAGE_NAME ARGS='test-install' - name: Authenticate to Docker Hub if: github.event_name == 'release' || github.event_name == 'push' diff --git a/Makefile b/Makefile index 365260bbceb..fe28b8d4463 100644 --- a/Makefile +++ b/Makefile @@ -154,9 +154,11 @@ docker-image : --build-arg TORCH=$(DOCKER_TORCH_VERSION) \ -t $(DOCKER_IMAGE_NAME) . +DOCKER_GPUS = --gpus all + .PHONY : docker-run docker-run : - $(DOCKER_RUN_CMD) --gpus all $(DOCKER_IMAGE_NAME) $(ARGS) + $(DOCKER_RUN_CMD) $(DOCKER_GPUS) $(DOCKER_IMAGE_NAME) $(ARGS) .PHONY : docker-test-image docker-test-image : @@ -168,4 +170,4 @@ docker-test-image : .PHONY : docker-test-run docker-test-run : - $(DOCKER_RUN_CMD) --gpus all $(DOCKER_TEST_IMAGE_NAME) $(ARGS) + $(DOCKER_RUN_CMD) $(DOCKER_GPUS) $(DOCKER_TEST_IMAGE_NAME) $(ARGS) From 51a01feb126c8bf111249f981fe06974d095c06f Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Thu, 27 May 2021 10:06:55 -0700 Subject: [PATCH 38/63] fix race condition when extracting files with cached_path (#5227) * fix race condition when extracting files with cached_path * add warning when directory already exists --- CHANGELOG.md | 2 ++ allennlp/common/file_utils.py | 18 +++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 648a0ae9bb0..8ff15d31c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids. - Fixed documentation for `GradientDescentTrainer.cuda_device`. +- Fixed the potential for a race condition with `cached_path()` when extracting archives. Although the race condition + is still possible if used with `force_extract=True`. - Fixed `wandb` callback to work in distributed training. diff --git a/allennlp/common/file_utils.py b/allennlp/common/file_utils.py index 431d3606a74..0acd91b2257 100644 --- a/allennlp/common/file_utils.py +++ b/allennlp/common/file_utils.py @@ -247,6 +247,10 @@ def cached_path( force_extract : `bool`, optional (default = `False`) If `True` and the file is an archive file, it will be extracted regardless of whether or not the extracted directory already exists. + + !!! Warning + Use this flag with caution! This can lead to race conditions if used + from multiple processes on the same file. """ if cache_dir is None: cache_dir = CACHE_DIRECTORY @@ -325,12 +329,24 @@ def cached_path( if extraction_path is not None: # If the extracted directory already exists (and is non-empty), then no - # need to extract again unless `force_extract=True`. + # need to create a lock file and extract again unless `force_extract=True`. if os.path.isdir(extraction_path) and os.listdir(extraction_path) and not force_extract: return extraction_path # Extract it. with FileLock(extraction_path + ".lock"): + # Check again if the directory exists now that we've acquired the lock. + if os.path.isdir(extraction_path) and os.listdir(extraction_path): + if force_extract: + logger.warning( + "Extraction directory for %s (%s) already exists, " + "overwriting it since 'force_extract' is 'True'", + url_or_filename, + extraction_path, + ) + else: + return extraction_path + logger.info("Extracting %s to %s", url_or_filename, extraction_path) shutil.rmtree(extraction_path, ignore_errors=True) From 7727af555fe194fa300b34c47b245e397b55f41b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 27 May 2021 10:40:39 -0700 Subject: [PATCH 39/63] Bump checklist from 0.0.10 to 0.0.11 (#5222) Bumps [checklist](https://github.com/marcotcr/checklist) from 0.0.10 to 0.0.11. - [Release notes](https://github.com/marcotcr/checklist/releases) - [Commits](https://github.com/marcotcr/checklist/commits) Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> Co-authored-by: Dirk Groeneveld <dirkg@allenai.org> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4ac272c483c..589b4e210dd 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ "lmdb", "more-itertools", "termcolor==1.1.0", - "checklist==0.0.10", + "checklist==0.0.11", "wandb>=0.10.0,<0.11.0", "huggingface_hub>=0.0.8", "google-cloud-storage>=1.38.0,<1.39.0", From 0d5b88f6a4a43161ef77a4a77f89c6b7d96f1bf0 Mon Sep 17 00:00:00 2001 From: wlhgtc <hgtcwl@foxmail.com> Date: Fri, 28 May 2021 02:37:10 +0800 Subject: [PATCH 40/63] Added `DataCollator` for dynamic operations for each batch. (#5221) * ADD: add from_pretrained method for vocab * MOD: test format * MOD: format file * MOD: update changelog * MOD: fix bug * MOD: fix bug * MOD: fix typo * MOD: make the mothod in class * MOD: fix bug * MOD: change to instance method * MOD: fix typo * MOD: fix bug * MOD: change oov to avoid bug * Update allennlp/data/vocabulary.py * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> * Update allennlp/data/vocabulary.py Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> * MOD: fix formate * MOD: add test case * Update CHANGELOG.md * MOD: fix worker info bug * ADD: update changelog * MOD: fix format * Update allennlp/data/data_loaders/multitask_data_loader.py Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> * Update CHANGELOG.md Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> * MOD: add demo code * MOD: align code * MOD: fix bug * MOD: fix bug * MOD: fix bug * MOD: formate code * Update allennlp/data/data_loaders/data_collator.py Co-authored-by: Pete <epwalsh10@gmail.com> * fix error * MOD: add test code * mod: change tokenizer * mod: fix tokenizer * MOD: fix bug * MOD: fix bug * MOD: fix bug * Update allennlp/data/data_loaders/data_collator.py Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com> * MOD: update changelog * MOD: update change log * Update allennlp/data/data_loaders/data_collator.py We should be using underscores for everything. * Formatting Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com> Co-authored-by: Dirk Groeneveld <dirkg@allenai.org> Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com> --- CHANGELOG.md | 2 + allennlp/data/data_loaders/__init__.py | 3 +- allennlp/data/data_loaders/data_collator.py | 71 +++++++++++++++++++ allennlp/data/data_loaders/data_loader.py | 12 +--- .../data_loaders/multiprocess_data_loader.py | 8 ++- .../data/data_loaders/simple_data_loader.py | 6 +- .../multiprocess_data_loader_test.py | 27 +++++++ 7 files changed, 113 insertions(+), 16 deletions(-) create mode 100644 allennlp/data/data_loaders/data_collator.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ff15d31c67..e2058b2550b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased + ### Changed - Use `dist_reduce_sum` in distributed metrics. @@ -33,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. - Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. +- Added `DataCollator` for dynamic operations for each batch. ### Fixed diff --git a/allennlp/data/data_loaders/__init__.py b/allennlp/data/data_loaders/__init__.py index 8c2dfe8776c..0578bc2363f 100644 --- a/allennlp/data/data_loaders/__init__.py +++ b/allennlp/data/data_loaders/__init__.py @@ -1,4 +1,5 @@ -from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate +from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader, WorkerError from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader from allennlp.data.data_loaders.simple_data_loader import SimpleDataLoader +from allennlp.data.data_loaders.data_collator import allennlp_collate diff --git a/allennlp/data/data_loaders/data_collator.py b/allennlp/data/data_loaders/data_collator.py new file mode 100644 index 00000000000..00e8c7d3b1f --- /dev/null +++ b/allennlp/data/data_loaders/data_collator.py @@ -0,0 +1,71 @@ +from typing import List + +from transformers.data.data_collator import DataCollatorForLanguageModeling +from allennlp.common import Registrable +from allennlp.data.batch import Batch +from allennlp.data.data_loaders.data_loader import TensorDict +from allennlp.data.instance import Instance + + +def allennlp_collate(instances: List[Instance]) -> TensorDict: + """ + This is the default function used to turn a list of `Instance`s into a `TensorDict` + batch. + """ + batch = Batch(instances) + return batch.as_tensor_dict() + + +class DataCollator(Registrable): + """ + This class is similar with `DataCollator` in [Transformers] + (https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py) + Allow to do some dynamic operations for tensor in different batches + Cause this method run before each epoch to convert `List[Instance]` to `TensorDict` + """ + + default_implementation = "allennlp" + + def __call__(self, instances: List[Instance]) -> TensorDict: + raise NotImplementedError + + +@DataCollator.register("allennlp") +class DefaultDataCollator(DataCollator): + def __call__(self, instances: List[Instance]) -> TensorDict: + return allennlp_collate(instances) + + +@DataCollator.register("language_model") +class LanguageModelingDataCollator(DataCollator): + """ + Register as an `DataCollator` with name `LanguageModelingDataCollator` + Used for language modeling. + """ + + def __init__( + self, + model_name: str, + mlm: bool = True, + mlm_probability: float = 0.15, + filed_name: str = "source", + namespace: str = "tokens", + ): + self._field_name = filed_name + self._namespace = namespace + from allennlp.common import cached_transformers + + tokenizer = cached_transformers.get_tokenizer(model_name) + self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability) + + def __call__(self, instances: List[Instance]) -> TensorDict: + tensor_dicts = allennlp_collate(instances) + tensor_dicts = self.process_tokens(tensor_dicts) + return tensor_dicts + + def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict: + inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"] + inputs, labels = self._collator.mask_tokens(inputs) + tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs + tensor_dicts[self._field_name][self._namespace]["labels"] = labels + return tensor_dicts diff --git a/allennlp/data/data_loaders/data_loader.py b/allennlp/data/data_loaders/data_loader.py index ce4ce8ca160..6927841ec51 100644 --- a/allennlp/data/data_loaders/data_loader.py +++ b/allennlp/data/data_loaders/data_loader.py @@ -1,10 +1,9 @@ -from typing import List, Dict, Union, Iterator +from typing import Dict, Union, Iterator import torch from allennlp.common.registrable import Registrable from allennlp.data.instance import Instance -from allennlp.data.batch import Batch from allennlp.data.vocabulary import Vocabulary @@ -14,15 +13,6 @@ """ -def allennlp_collate(instances: List[Instance]) -> TensorDict: - """ - This is the default function used to turn a list of `Instance`s into a `TensorDict` - batch. - """ - batch = Batch(instances) - return batch.as_tensor_dict() - - class DataLoader(Registrable): """ A `DataLoader` is responsible for generating batches of instances from a diff --git a/allennlp/data/data_loaders/multiprocess_data_loader.py b/allennlp/data/data_loaders/multiprocess_data_loader.py index 692d7d3518d..bb6e38381dd 100644 --- a/allennlp/data/data_loaders/multiprocess_data_loader.py +++ b/allennlp/data/data_loaders/multiprocess_data_loader.py @@ -12,7 +12,8 @@ from allennlp.common.util import lazy_groups_of, shuffle_iterable from allennlp.common.tqdm import Tqdm from allennlp.data.instance import Instance -from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate +from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict +from allennlp.data.data_loaders.data_collator import DataCollator, DefaultDataCollator from allennlp.data.dataset_readers import DatasetReader, WorkerInfo, DatasetReaderInput from allennlp.data.fields import TextField from allennlp.data.samplers import BatchSampler @@ -124,6 +125,8 @@ class MultiProcessDataLoader(DataLoader): quiet : `bool`, optional (default = `False`) If `True`, tqdm progress bars will be disabled. + collate_fn : `DataCollator`, optional ( default = `DefaultDataCollator`) + # Best practices - **Large datasets** @@ -207,6 +210,7 @@ def __init__( start_method: str = "fork", cuda_device: Optional[Union[int, str, torch.device]] = None, quiet: bool = False, + collate_fn: DataCollator = DefaultDataCollator(), ) -> None: # Do some parameter validation. if num_workers is not None and num_workers < 0: @@ -244,7 +248,7 @@ def __init__( self.batch_sampler = batch_sampler self.batches_per_epoch = batches_per_epoch self.num_workers = num_workers - self.collate_fn = allennlp_collate + self.collate_fn = collate_fn self.max_instances_in_memory = max_instances_in_memory self.start_method = start_method self.quiet = quiet diff --git a/allennlp/data/data_loaders/simple_data_loader.py b/allennlp/data/data_loaders/simple_data_loader.py index eab9693e284..d63f634cb16 100644 --- a/allennlp/data/data_loaders/simple_data_loader.py +++ b/allennlp/data/data_loaders/simple_data_loader.py @@ -7,7 +7,8 @@ from allennlp.common.util import lazy_groups_of from allennlp.common.tqdm import Tqdm -from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate, TensorDict +from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict +from allennlp.data.data_loaders.data_collator import DefaultDataCollator from allennlp.data.dataset_readers import DatasetReader from allennlp.data.instance import Instance from allennlp.data.vocabulary import Vocabulary @@ -36,6 +37,7 @@ def __init__( self.vocab = vocab self.cuda_device: Optional[torch.device] = None self._batch_generator: Optional[Iterator[TensorDict]] = None + self.collate_fn = DefaultDataCollator() def __len__(self) -> int: if self.batches_per_epoch is not None: @@ -60,7 +62,7 @@ def _iter_batches(self) -> Iterator[TensorDict]: if self.shuffle: random.shuffle(self.instances) for batch in lazy_groups_of(self.iter_instances(), self.batch_size): - tensor_dict = allennlp_collate(batch) + tensor_dict = self.collate_fn(batch) if self.cuda_device is not None: tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) yield tensor_dict diff --git a/tests/data/data_loaders/multiprocess_data_loader_test.py b/tests/data/data_loaders/multiprocess_data_loader_test.py index e0197edee71..3fbd0214d32 100644 --- a/tests/data/data_loaders/multiprocess_data_loader_test.py +++ b/tests/data/data_loaders/multiprocess_data_loader_test.py @@ -10,6 +10,7 @@ from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.data.token_indexers import PretrainedTransformerIndexer from allennlp.data.vocabulary import Vocabulary +from allennlp.data.data_loaders.data_collator import LanguageModelingDataCollator class MockDatasetReader(DatasetReader): @@ -166,6 +167,32 @@ def test_drop_last(): assert len(batches) == 6 +def test_language_model_data_collator(): + """ + Ensure `LanguageModelingDataCollator` works + """ + norm_loader = MultiProcessDataLoader(MockDatasetReader(), "some path", batch_size=16) + vocab = Vocabulary.from_instances(norm_loader.iter_instances()) + norm_loader.index_with(vocab) + batch0 = list(norm_loader)[0] + + model_name = "epwalsh/bert-xsmall-dummy" + data_collate = LanguageModelingDataCollator(model_name) + mlm_loader = MultiProcessDataLoader( + MockDatasetReader(), "some path", batch_size=16, collate_fn=data_collate + ) + vocab = Vocabulary.from_instances(mlm_loader.iter_instances()) + mlm_loader.index_with(vocab) + batch1 = list(mlm_loader)[0] + + norm_inputs = batch0["source"]["tokens"]["token_ids"] + mlm_inputs = batch1["source"]["tokens"]["token_ids"] + mlm_labels = batch1["source"]["tokens"]["labels"] + + # if we replace the mlm inputs with their labels, should be same as origin inputs + assert torch.where(mlm_labels != -100, mlm_labels, mlm_inputs).tolist() == norm_inputs.tolist() + + def test_batches_per_epoch(): loader = MultiProcessDataLoader( MockDatasetReader(), "some path", batch_size=4, batches_per_epoch=10 From b75c60c5d3d38173308d41226660d22a99a1d17d Mon Sep 17 00:00:00 2001 From: Jacob Morrison <jacob1morrison@gmail.com> Date: Fri, 28 May 2021 15:36:56 -0700 Subject: [PATCH 41/63] Roll backbone (#5229) Adding support for inputs to the backbone with more than 3 dimensions --- CHANGELOG.md | 1 + .../modules/backbones/vilbert_backbone.py | 60 +++++++++++++++---- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2058b2550b..deab8ed058a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). +- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. ### Added diff --git a/allennlp/modules/backbones/vilbert_backbone.py b/allennlp/modules/backbones/vilbert_backbone.py index 0f554a7a1d2..3eeb9aad4ac 100644 --- a/allennlp/modules/backbones/vilbert_backbone.py +++ b/allennlp/modules/backbones/vilbert_backbone.py @@ -111,19 +111,50 @@ def forward( box_mask: torch.Tensor, text: TextFieldTensors, ) -> Dict[str, torch.Tensor]: - batch_size, _, feature_size = box_features.size() - if "token_ids" in text["tokens"]: token_ids = text["tokens"]["token_ids"] else: token_ids = text["tokens"]["tokens"] + if token_ids.shape[:-1] != box_features.shape[:-2]: + raise ValueError( + "Tokens and boxes must have the same batch size and extra " + "dimensions (if applicable). Token size {0} did not match " + "box feature size {1}.".format(token_ids.shape[:-1], box_features.shape[:-2]) + ) + # Shape: (batch_size, num_tokens) token_type_ids = text["tokens"].get("type_ids") # Shape: (batch_size, num_tokens) attention_mask = text["tokens"].get("mask") - # Shape: (batch_size, num_tokens, embedding_dim) + box_feature_dimensions = box_features.shape + feature_size = box_feature_dimensions[-1] + rolled_dimensions = box_feature_dimensions[:-2] + rolled_dimensions_product = 1 + for dim in rolled_dimensions: + rolled_dimensions_product *= dim + + token_ids = token_ids.view(rolled_dimensions_product, token_ids.shape[-1]) + if token_type_ids is not None: + token_type_ids = token_type_ids.view( + rolled_dimensions_product, token_type_ids.shape[-1] + ) + if attention_mask is not None: + attention_mask = attention_mask.view( + rolled_dimensions_product, attention_mask.shape[-1] + ) + box_features = box_features.view( + rolled_dimensions_product, box_feature_dimensions[-2], feature_size + ) + box_coordinates = box_coordinates.view( + rolled_dimensions_product, + box_coordinates.shape[-2], + box_coordinates.shape[-1], + ) + box_mask = box_mask.view(rolled_dimensions_product, box_mask.shape[-1]) + + # Shape: (rolled_dimensions_product, num_tokens, embedding_dim) embedding_output = self.text_embeddings(token_ids, token_type_ids) num_tokens = embedding_output.size(1) @@ -137,16 +168,16 @@ def forward( extended_image_attention_mask = box_mask - # Shape: (batch_size, feature_size, num_tokens) + # Shape: (rolled_dimensions_product, feature_size, num_tokens) # TODO (epwalsh): Why all zeros?? This doesn't seem right. extended_co_attention_mask = torch.zeros( - batch_size, + extended_image_attention_mask.shape[0], feature_size, num_tokens, dtype=extended_image_attention_mask.dtype, ) - # Shape: (batch_size, num_boxes, image_embedding_dim) + # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim) v_embedding_output = self.image_embeddings(box_features, box_coordinates) encoded_layers_t, encoded_layers_v = self.encoder( @@ -157,16 +188,25 @@ def forward( extended_co_attention_mask, ) - # Shape: (batch_size, num_tokens, embedding_dim) + # Shape: (rolled_dimensions_product, num_tokens, embedding_dim) sequence_output_t = encoded_layers_t[:, :, :, -1] - # Shape: (batch_size, num_boxes, image_embedding_dim) + # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim) sequence_output_v = encoded_layers_v[:, :, :, -1] - # Shape: (batch_size, pooled_output_dim) + # Shape: (rolled_dimensions_product, pooled_output_dim) pooled_output_t = self.t_pooler(sequence_output_t) - # Shape: (batch_size, pooled_output_dim) + # Shape: (rolled_dimensions_product, pooled_output_dim) pooled_output_v = self.v_pooler(sequence_output_v) + sequence_output_t = sequence_output_t.view( + rolled_dimensions + (sequence_output_t.shape[-2], sequence_output_t.shape[-1]) + ) + sequence_output_v = sequence_output_v.view( + rolled_dimensions + (sequence_output_v.shape[-2], sequence_output_v.shape[-1]) + ) + pooled_output_t = pooled_output_t.view(rolled_dimensions + (pooled_output_t.shape[-1],)) + pooled_output_v = pooled_output_v.view(rolled_dimensions + (pooled_output_v.shape[-1],)) + if self.fusion_method == "sum": pooled_output = self.dropout(pooled_output_t + pooled_output_v) elif self.fusion_method == "mul": From fd0981ca8d4b623c824c14afaafdc8fe0acdb510 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld <dirkg@allenai.org> Date: Fri, 28 May 2021 19:18:27 -0700 Subject: [PATCH 42/63] Fixes Checkpointing (#5220) * Removes unused variable * Formatting * Make sure we always restore the model's weights properly * Give TrainerCallbacks the ability to save and load state dicts * Give MovingAverage the ability to save and load state dicts * Do not set gradients to None * Typo * Remove unused variable * Typo * Entirely new checkpointing code * Formatting * Make mypy happy lol * Makes the no-op trainer work with the new checkpointer * Mark epochs as completed when they're skipped * Changelog * Fixes how we get the best weights after a training run * Mypy is annoying * Callback fixes * Fix the no op trainer * Simplify * Assorted checkpointer fixes * Mypy is now happy * Fixed all the tests except for one * Removed unused variable * Fix trainer restore logic * Fix test for trainer restore logic * Check the Checkpointing branch of the models repo * Help mypy along * Fixed finalizing logic * More mypy stuff * Update allennlp/training/checkpointer.py Co-authored-by: Pete <petew@allenai.org> * Make weaker claims Co-authored-by: Pete <petew@allenai.org> --- .github/workflows/ci.yml | 1 + CHANGELOG.md | 4 + allennlp/commands/train.py | 21 +- allennlp/models/archival.py | 7 +- allennlp/training/__init__.py | 6 +- allennlp/training/callbacks/callback.py | 8 +- .../training/callbacks/confidence_checks.py | 2 +- allennlp/training/callbacks/console_logger.py | 2 +- allennlp/training/callbacks/log_writer.py | 10 +- allennlp/training/callbacks/tensorboard.py | 6 +- allennlp/training/callbacks/track_epoch.py | 2 +- allennlp/training/callbacks/wandb.py | 4 +- allennlp/training/checkpointer.py | 348 +++--- allennlp/training/gradient_descent_trainer.py | 1072 ++++++++++++++++ allennlp/training/moving_average.py | 10 +- allennlp/training/no_op_trainer.py | 25 +- allennlp/training/scheduler.py | 1 - allennlp/training/trainer.py | 1073 +---------------- tests/commands/no_op_train_test.py | 2 +- tests/training/checkpointer_test.py | 91 +- tests/training/trainer_test.py | 153 +-- 21 files changed, 1455 insertions(+), 1393 deletions(-) create mode 100644 allennlp/training/gradient_descent_trainer.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a53b1b6e12b..e0e3494a954 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -156,6 +156,7 @@ jobs: run: | git clone https://github.com/allenai/allennlp-models.git cd allennlp-models + git checkout Checkpointing pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt - name: Run models tests diff --git a/CHANGELOG.md b/CHANGELOG.md index deab8ed058a..281391a0e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). +- Trainer callbacks can now store and restore state in case a training run gets interrupted. - VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. ### Added @@ -41,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids. - Fixed documentation for `GradientDescentTrainer.cuda_device`. +- Re-starting a training run from a checkpoint in the middle of an epoch now works correctly. +- When using the "moving average" weights smoothing feature of the trainer, training checkpoints would also get smoothed, with strange results for resuming a training job. This has been fixed. +- When re-starting an interrupted training job, the trainer will now read out the data loader even for epochs and batches that can be skipped. We do this to try to get any random number generators used by the reader or data loader into the same state as they were the first time the training job ran. - Fixed the potential for a race condition with `cached_path()` when extracting archives. Although the race condition is still possible if used with `force_extract=True`. - Fixed `wandb` callback to work in distributed training. diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py index 5304e6e0735..c5be97c990f 100644 --- a/allennlp/commands/train.py +++ b/allennlp/commands/train.py @@ -471,11 +471,22 @@ def _train_worker( except KeyboardInterrupt: # if we have completed an epoch, try to create a model archive. if primary and os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)): - logging.info( - "Training interrupted by the user. Attempting to create " - "a model archive using the current best epoch weights." - ) - archive_model(serialization_dir, include_in_archive=include_in_archive) + best_weights_path = train_loop.trainer.get_best_weights_path() + if best_weights_path is None: + logging.info( + "Training interrupted by the user, and no best model has been saved. " + "No model archive created." + ) + else: + logging.info( + "Training interrupted by the user. Attempting to create " + "a model archive using the current best epoch weights." + ) + archive_model( + serialization_dir, + weights=best_weights_path, + include_in_archive=include_in_archive, + ) raise if primary: diff --git a/allennlp/models/archival.py b/allennlp/models/archival.py index e1d48fcb76f..e49bd9dec6a 100644 --- a/allennlp/models/archival.py +++ b/allennlp/models/archival.py @@ -2,6 +2,7 @@ Helper functions for archiving models and restoring archived models. """ from os import PathLike +from pathlib import Path from typing import Tuple, NamedTuple, Union, Dict, Any, List, Optional import logging import os @@ -130,7 +131,11 @@ def archive_model( include_in_archive : `List[str]`, optional, (default = `None`) Paths relative to `serialization_dir` that should be archived in addition to the default ones. """ - weights_file = os.path.join(serialization_dir, weights) + extra_copy_of_weights_just_for_mypy = Path(weights) + if extra_copy_of_weights_just_for_mypy.is_absolute(): + weights_file = extra_copy_of_weights_just_for_mypy + else: + weights_file = Path(serialization_dir) / extra_copy_of_weights_just_for_mypy if not os.path.exists(weights_file): logger.error("weights file %s does not exist, unable to archive model", weights_file) return diff --git a/allennlp/training/__init__.py b/allennlp/training/__init__.py index c309005246c..cf95606636b 100644 --- a/allennlp/training/__init__.py +++ b/allennlp/training/__init__.py @@ -1,7 +1,5 @@ from allennlp.training.checkpointer import Checkpointer from allennlp.training.no_op_trainer import NoOpTrainer from allennlp.training.callbacks import TrainerCallback -from allennlp.training.trainer import ( - Trainer, - GradientDescentTrainer, -) +from allennlp.training.trainer import Trainer +from allennlp.training.gradient_descent_trainer import GradientDescentTrainer diff --git a/allennlp/training/callbacks/callback.py b/allennlp/training/callbacks/callback.py index 19c14cc0dc6..301e9cb4387 100644 --- a/allennlp/training/callbacks/callback.py +++ b/allennlp/training/callbacks/callback.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer class TrainerCallback(Registrable): @@ -77,5 +77,11 @@ def on_end( """ pass + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass + TrainerCallback.register("null")(TrainerCallback) diff --git a/allennlp/training/callbacks/confidence_checks.py b/allennlp/training/callbacks/confidence_checks.py index e57a0a0a626..584dcd137b4 100644 --- a/allennlp/training/callbacks/confidence_checks.py +++ b/allennlp/training/callbacks/confidence_checks.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer # `sanity_checks` is deprecated and will be removed. diff --git a/allennlp/training/callbacks/console_logger.py b/allennlp/training/callbacks/console_logger.py index 68565ed247a..768ba0d5936 100644 --- a/allennlp/training/callbacks/console_logger.py +++ b/allennlp/training/callbacks/console_logger.py @@ -8,7 +8,7 @@ from allennlp.data import TensorDict if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) diff --git a/allennlp/training/callbacks/log_writer.py b/allennlp/training/callbacks/log_writer.py index 253b35de3df..8b3183f28c3 100644 --- a/allennlp/training/callbacks/log_writer.py +++ b/allennlp/training/callbacks/log_writer.py @@ -10,7 +10,7 @@ from allennlp.training.util import get_train_and_validation_metrics, get_batch_size if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) @@ -289,15 +289,17 @@ def log_epoch( ) def _should_log_distributions_next_batch(self) -> bool: + assert self.trainer is not None return ( self._distribution_interval is not None - and (self.trainer._batch_num_total + 1) % self._distribution_interval == 0 # type: ignore[union-attr] + and (self.trainer._total_batches_completed + 1) % self._distribution_interval == 0 ) def _should_log_distributions_this_batch(self) -> bool: + assert self.trainer is not None return ( self._distribution_interval is not None - and self.trainer._batch_num_total % self._distribution_interval == 0 # type: ignore[union-attr] + and self.trainer._total_batches_completed % self._distribution_interval == 0 ) def _enable_activation_logging(self) -> None: @@ -318,7 +320,7 @@ def hook(module_, inputs, outputs): self._module_hook_handles.append(module.register_forward_hook(hook)) def _should_log_this_batch(self) -> bool: - return self.trainer._batch_num_total % self._summary_interval == 0 # type: ignore[union-attr] + return self.trainer._total_batches_completed % self._summary_interval == 0 # type: ignore[union-attr] def _log_activation_distribution(self, outputs: Any, module_name: str) -> None: activations_to_log: Dict[str, torch.Tensor] = {} diff --git a/allennlp/training/callbacks/tensorboard.py b/allennlp/training/callbacks/tensorboard.py index 0f6302dfcb4..73bc04a686a 100644 --- a/allennlp/training/callbacks/tensorboard.py +++ b/allennlp/training/callbacks/tensorboard.py @@ -49,7 +49,8 @@ def log_scalars( log_prefix: str = "", epoch: Optional[int] = None, ) -> None: - timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr] + assert self.trainer is not None + timestep = epoch if epoch is not None else self.trainer._total_batches_completed log = self._train_log if not log_prefix.startswith("validation") else self._validation_log for key, value in scalars.items(): name = f"{log_prefix}/{key}" if log_prefix else key @@ -59,7 +60,8 @@ def log_scalars( def log_tensors( self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None ) -> None: - timestep = epoch if epoch is not None else self.trainer._batch_num_total # type: ignore[union-attr] + assert self.trainer is not None + timestep = epoch if epoch is not None else self.trainer._total_batches_completed log = self._train_log if not log_prefix.startswith("validation") else self._validation_log for key, values in tensors.items(): name = f"{log_prefix}/{key}" if log_prefix else key diff --git a/allennlp/training/callbacks/track_epoch.py b/allennlp/training/callbacks/track_epoch.py index ea08459b390..b15da434248 100644 --- a/allennlp/training/callbacks/track_epoch.py +++ b/allennlp/training/callbacks/track_epoch.py @@ -3,7 +3,7 @@ from allennlp.training.callbacks.callback import TrainerCallback if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer @TrainerCallback.register("track_epoch_callback") diff --git a/allennlp/training/callbacks/wandb.py b/allennlp/training/callbacks/wandb.py index 5adc9f1520d..b09301af62b 100644 --- a/allennlp/training/callbacks/wandb.py +++ b/allennlp/training/callbacks/wandb.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: - from allennlp.training.trainer import GradientDescentTrainer + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer logger = logging.getLogger(__name__) @@ -127,7 +127,7 @@ def _log( dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()} if epoch is not None: dict_to_log["epoch"] = epoch - self.wandb.log(dict_to_log, step=self.trainer._batch_num_total) # type: ignore + self.wandb.log(dict_to_log, step=self.trainer._total_batches_completed) # type: ignore @overrides def on_start( diff --git a/allennlp/training/checkpointer.py b/allennlp/training/checkpointer.py index 95105a0820f..38d2692273d 100644 --- a/allennlp/training/checkpointer.py +++ b/allennlp/training/checkpointer.py @@ -1,5 +1,5 @@ import glob -from typing import Union, Dict, Any, List, Tuple, Optional +from typing import Dict, Any, Tuple, Optional, Set, Union import logging import os @@ -8,10 +8,9 @@ import torch -import allennlp from allennlp.common import Registrable from allennlp.nn import util as nn_util -from allennlp.training import util as training_util +from allennlp.training.trainer import Trainer logger = logging.getLogger(__name__) @@ -20,28 +19,26 @@ class Checkpointer(Registrable): """ This class implements the functionality for checkpointing your model and trainer state during training. It is agnostic as to what those states look like (they are typed as - Dict[str, Any]), but they will be fed to `torch.save` so they should be serializable - in that sense. They will also be restored as Dict[str, Any], which means the calling + `Dict[str, Any]`), but they will be fed to `torch.save` so they should be serializable + in that sense. They will also be restored as `Dict[str, Any]`, which means the calling code is responsible for knowing what to do with them. # Parameters - num_serialized_models_to_keep : `int`, optional (default=`2`) - Number of previous model checkpoints to retain. Default is to keep 2 checkpoints. - A value of None or -1 means all checkpoints will be kept. - - In a typical AllenNLP configuration file, this argument does not get an entry under the - "checkpointer", it gets passed in separately. - keep_serialized_model_every_num_seconds : `int`, optional (default=`None`) - If num_serialized_models_to_keep is not None, then occasionally it's useful to - save models at a given interval in addition to the last num_serialized_models_to_keep. - To do so, specify keep_serialized_model_every_num_seconds as the number of seconds - between permanently saved checkpoints. Note that this option is only used if - num_serialized_models_to_keep is not None, otherwise all checkpoints are kept. - model_save_interval : `float`, optional (default=`None`) - If provided, then serialize models every `model_save_interval` - seconds within single epochs. In all cases, models are also saved - at the end of every epoch if `serialization_dir` is provided. + save_completed_epochs : `bool`, (default=`True`) + Saves model and trainer state at the end of each completed epoch. + save_every_num_seconds : `int`, optional (default=`None`) + If set, makes sure we never go longer than this number of seconds between saving a model. + save_every_num_batches : `int`, optional (default=`None`) + If set, makes sure we never go longer than this number of batches between saving a model. + keep_most_recent_by_count : `int`, optional (default=`2`) + Sets the number of model checkpoints to keep on disk. If both `keep_most_recent_by_count` and + `keep_most_recent_by_age` are set, we'll keep checkpoints that satisfy either criterion. + If both are `None`, we keep all checkpoints. + keep_most_recent_by_age : `int`, optional (default=`None`) + Sets the number of seconds we'll keep a checkpoint before deleting it. If both + `keep_most_recent_by_count` and `keep_most_recent_by_age` are set, we'll keep checkpoints + that satisfy either criterion. If both are `None`, we keep all checkpoints. """ default_implementation = "default" @@ -49,183 +46,179 @@ class Checkpointer(Registrable): def __init__( self, serialization_dir: str, - keep_serialized_model_every_num_seconds: int = None, - num_serialized_models_to_keep: int = 2, - model_save_interval: float = None, + save_completed_epochs: bool = True, + save_every_num_seconds: Optional[float] = None, + save_every_num_batches: Optional[int] = None, + keep_most_recent_by_count: Optional[int] = 2, + keep_most_recent_by_age: Optional[int] = None, ) -> None: self._serialization_dir = serialization_dir - self._keep_serialized_model_every_num_seconds = keep_serialized_model_every_num_seconds - self._num_serialized_models_to_keep = num_serialized_models_to_keep - self._model_save_interval = model_save_interval - - self._last_permanent_saved_checkpoint_time = time.time() - self._serialized_paths: List[Tuple[float, str, str]] = [] + self._save_completed_epochs = save_completed_epochs + self._save_every_num_seconds = save_every_num_seconds + self._save_every_num_batches = save_every_num_batches + self._keep_most_recent_by_count = keep_most_recent_by_count + self._keep_most_recent_by_age = keep_most_recent_by_age self._last_save_time = time.time() + self._last_save_num_epochs_completed = 0 + self._last_save_num_batches_in_epoch_completed = 0 + + def _model_state_path(self, epochs_completed: int, batches_in_epoch_completed: int) -> str: + return os.path.join( + self._serialization_dir, + f"model_state_e{epochs_completed}_b{batches_in_epoch_completed}.th", + ) + + def _training_state_path(self, epochs_completed: int, batches_in_epoch_completed: int) -> str: + return os.path.join( + self._serialization_dir, + f"training_state_e{epochs_completed}_b{batches_in_epoch_completed}.th", + ) + + _model_state_file_re = re.compile(r"(.*/)?model_state_e(\d+)_b(\d+)\.th$") + _training_state_file_re = re.compile(r"(.*/)?training_state_e(\d+)_b(\d+)\.th$") + + @classmethod + def _parse_model_state_path(cls, path: Union[str, os.PathLike]) -> Optional[Tuple[int, int]]: + match = cls._model_state_file_re.match(str(path)) + if match is None: + return None + else: + try: + return int(match.group(2)), int(match.group(3)) + except ValueError: + return None + + @classmethod + def _parse_training_state_path(cls, path: Union[str, os.PathLike]) -> Optional[Tuple[int, int]]: + match = cls._training_state_file_re.match(str(path)) + if match is None: + return None + else: + try: + return int(match.group(2)), int(match.group(3)) + except ValueError: + return None + + def _find_all_checkpoints(self) -> Set[Tuple[int, int]]: + """Returns a set of integers, each of which is a number of batches that were completed at the + time a checkpoint wsa saved.""" + checkpoints = set() + for model_state_file in glob.iglob( + os.path.join(self._serialization_dir, "model_state_e*_b*.th") + ): + point_in_time = self._parse_model_state_path(model_state_file) + if point_in_time is None: + continue + else: + checkpoints.add(point_in_time) + return checkpoints def maybe_save_checkpoint( - self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int + self, + trainer: Trainer, + num_epochs_completed: int, + num_batches_in_epoch_completed: int, ) -> None: """ - Given amount of time lapsed between the last save and now (tracked internally), the - current epoch, and the number of batches seen so far this epoch, this method decides whether - to save a checkpoint or not. If we decide to save a checkpoint, we grab whatever state we - need out of the `Trainer` and save it. - - This function is intended to be called at the end of each batch in an epoch (perhaps because - your data is large enough that you don't really have "epochs"). The default implementation - only looks at time, not batch or epoch number, though those parameters are available to you - if you want to customize the behavior of this function. + Figures out whether we need to save a checkpoint, and does so if necessary. """ - if self._model_save_interval is None: - return - if time.time() - self._last_save_time < self._model_save_interval: - return - - self._last_save_time = time.time() - epoch_str = f"{epoch}.{training_util.time_to_str(int(self._last_save_time))}" - self.save_checkpoint(epoch_str, trainer) - - def shelve_model(self, epoch: Union[int, str], trainer: "allennlp.training.trainer.Trainer"): - if self._serialization_dir is None: - return - - # back up the model - with trainer.get_checkpoint_state() as state: - model_state, _ = state - model_backup_path = os.path.join( - self._serialization_dir, "model_state_backup_epoch_{}.th".format(epoch) + end_of_epoch = num_batches_in_epoch_completed == 0 + if num_epochs_completed == self._last_save_num_epochs_completed: + last_save_num_batches_in_epoch_completed = ( + self._last_save_num_batches_in_epoch_completed ) - torch.save(model_state, model_backup_path) + else: + last_save_num_batches_in_epoch_completed = 0 - def remove_shelved_models(self): - if self._serialization_dir is None: - return + should_save = ( + (end_of_epoch and self._save_completed_epochs) + or ( + self._save_every_num_seconds is not None + and (time.time() - self._last_save_time >= self._save_every_num_seconds) + ) + or ( + self._save_every_num_batches is not None + and ( + num_batches_in_epoch_completed - last_save_num_batches_in_epoch_completed + >= self._save_every_num_batches + ) + ) + ) - for old_model_backup_path in glob.glob( - os.path.join(self._serialization_dir, "model_state_backup_epoch_*.th") - ): - os.remove(old_model_backup_path) + if should_save: + self.save_checkpoint(trainer) def save_checkpoint( self, - epoch: Union[int, str], - trainer: "allennlp.training.trainer.Trainer", - is_best_so_far: bool = False, + trainer: Trainer, ) -> None: if self._serialization_dir is None: return - with trainer.get_checkpoint_state() as state: - model_state, training_states = state - model_path = os.path.join( - self._serialization_dir, "model_state_epoch_{}.th".format(epoch) - ) - if not os.path.isfile(model_path): - model_backup_path = os.path.join( - self._serialization_dir, "model_state_backup_epoch_{}.th".format(epoch) - ) - if os.path.isfile(model_backup_path): - os.rename(model_backup_path, model_path) - else: - torch.save(model_state, model_path) + tcps = trainer.get_checkpoint_state() + epochs_completed = tcps.trainer_state["epochs_completed"] + batches_in_epoch_completed = tcps.trainer_state["batches_in_epoch_completed"] - training_path = os.path.join( - self._serialization_dir, "training_state_epoch_{}.th".format(epoch) - ) - if not os.path.isfile(training_path): - torch.save({**training_states, "epoch": epoch}, training_path) + model_state_path = self._model_state_path(epochs_completed, batches_in_epoch_completed) + if not os.path.isfile(model_state_path): + torch.save(tcps.model_state, model_state_path) - # The main checkpointing logic is now done, this is just shuffling files around, to keep - # track of best weights, and to remove old checkpoints, if desired. - self.remove_shelved_models() + trainer_state_path = self._training_state_path(epochs_completed, batches_in_epoch_completed) + if not os.path.isfile(trainer_state_path): + torch.save(tcps.trainer_state, trainer_state_path) - if is_best_so_far: - logger.info( - "Best validation performance so far. Copying weights to '%s/best.th'.", - self._serialization_dir, - ) - dest_path = os.path.join(self._serialization_dir, "best.th") - if os.path.exists(dest_path): - os.remove(dest_path) - os.link(model_path, dest_path) - - if ( - self._num_serialized_models_to_keep is not None - and self._num_serialized_models_to_keep >= 0 - ): - self._serialized_paths.append((time.time(), model_path, training_path)) - if len(self._serialized_paths) > self._num_serialized_models_to_keep: - paths_to_remove = self._serialized_paths.pop(0) - # Check to see if we should keep this checkpoint, if it has been longer - # then self._keep_serialized_model_every_num_seconds since the last - # kept checkpoint. - remove_path = True - if self._keep_serialized_model_every_num_seconds is not None: - save_time = paths_to_remove[0] - time_since_checkpoint_kept = ( - save_time - self._last_permanent_saved_checkpoint_time + self._last_save_time = time.time() + self._last_save_num_epochs_completed = epochs_completed + self._last_save_num_batches_in_epoch_completed = batches_in_epoch_completed + + if self._keep_most_recent_by_age is not None or self._keep_most_recent_by_count is not None: + checkpoints = list(self._find_all_checkpoints()) + checkpoints.sort(reverse=True) + + # Keep the most recent n checkpoints + if self._keep_most_recent_by_count is not None: + checkpoints_to_keep = set(checkpoints[: self._keep_most_recent_by_count]) + else: + checkpoints_to_keep = set() + + # Keep the youngest checkpoints by age + now = time.time() + if self._keep_most_recent_by_age is not None: + for checkpoint in checkpoints: + checkpoint_mtime = max( + os.path.getmtime(n) + for n in [ + self._model_state_path(*checkpoint), + self._training_state_path(*checkpoint), + ] ) - if time_since_checkpoint_kept > self._keep_serialized_model_every_num_seconds: - # We want to keep this checkpoint. - remove_path = False - self._last_permanent_saved_checkpoint_time = save_time - if remove_path: - for fname in paths_to_remove[1:]: - if os.path.isfile(fname): - os.remove(fname) - - def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]: + if now - checkpoint_mtime <= self._keep_most_recent_by_age: + checkpoints_to_keep.add(checkpoint) + + # Remove everything we're not keeping + for checkpoint in checkpoints: + if checkpoint not in checkpoints_to_keep: + os.remove(self._model_state_path(*checkpoint)) + os.remove(self._training_state_path(*checkpoint)) + + def _find_latest_checkpoint(self) -> Optional[Tuple[str, str]]: """ Return the location of the latest model and training state files. If there isn't a valid checkpoint then return None. """ - have_checkpoint = self._serialization_dir is not None and any( - "model_state_epoch_" in x for x in os.listdir(self._serialization_dir) - ) - - if not have_checkpoint: + checkpoints = self._find_all_checkpoints() + if len(checkpoints) <= 0: return None + last_checkpoint = max(checkpoints) + return self._model_state_path(*last_checkpoint), self._training_state_path(*last_checkpoint) - serialization_files = os.listdir(self._serialization_dir) - model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x] - # Get the last checkpoint file. Epochs are specified as either an - # int (for end of epoch files) or with epoch and timestamp for - # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42 - found_epochs = [ - re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1) for x in model_checkpoints # type: ignore - ] - int_epochs: Any = [] - for epoch in found_epochs: - pieces = epoch.split(".") - if len(pieces) == 1: - # Just a single epoch without timestamp - int_epochs.append([int(pieces[0]), "0"]) - else: - # has a timestamp - int_epochs.append([int(pieces[0]), pieces[1]]) - last_epoch = sorted(int_epochs, reverse=True)[0] - if last_epoch[1] == "0": - epoch_to_load = str(last_epoch[0]) - else: - epoch_to_load = "{0}.{1}".format(last_epoch[0], last_epoch[1]) - - model_path = os.path.join( - self._serialization_dir, "model_state_epoch_{}.th".format(epoch_to_load) - ) - training_state_path = os.path.join( - self._serialization_dir, "training_state_epoch_{}.th".format(epoch_to_load) - ) - - return (model_path, training_state_path) - - def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def load_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ - Restores a model from a serialization_dir to the last saved checkpoint. - This includes a training state (typically consisting of an epoch count and optimizer state), - which is serialized separately from model parameters. This function should only be used to - continue training - if you wish to load a model for inference/load parts of a model into a new - computation graph, you should use the native Pytorch functions: - ` model.load_state_dict(torch.load("/path/to/model/weights.th"))` + Loads model state from a `serialization_dir` corresponding to the last saved checkpoint. + This includes a training state, which is serialized separately from model parameters. This function + should only be used to continue training - if you wish to load a model for inference/load parts + of a model into a new computation graph, you should use the native Pytorch functions: + `model.load_state_dict(torch.load("/path/to/model/weights.th"))` If `self._serialization_dir` does not exist or does not contain any checkpointed weights, this function will do nothing and return empty dicts. @@ -235,12 +228,9 @@ def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: states : `Tuple[Dict[str, Any], Dict[str, Any]]` The model state and the training state. """ - latest_checkpoint = self.find_latest_checkpoint() - + latest_checkpoint = self._find_latest_checkpoint() if latest_checkpoint is None: - # No checkpoint to restore, start at 0 return {}, {} - model_path, training_state_path = latest_checkpoint # Load the parameters onto CPU, then transfer to GPU. @@ -251,17 +241,5 @@ def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: training_state = torch.load(training_state_path, map_location=nn_util.device_mapping(-1)) return model_state, training_state - def best_model_state(self) -> Dict[str, Any]: - if self._serialization_dir: - logger.info("loading best weights") - best_model_state_path = os.path.join(self._serialization_dir, "best.th") - return torch.load(best_model_state_path, map_location=nn_util.device_mapping(-1)) - else: - logger.info( - "cannot load best weights without `serialization_dir`, " - "so you're just getting the last weights" - ) - return {} - Checkpointer.register("default")(Checkpointer) diff --git a/allennlp/training/gradient_descent_trainer.py b/allennlp/training/gradient_descent_trainer.py new file mode 100644 index 00000000000..0e3f3cb0816 --- /dev/null +++ b/allennlp/training/gradient_descent_trainer.py @@ -0,0 +1,1072 @@ +import datetime +import logging +import math +import os +import re +import time +import warnings +from typing import Optional, Union, List, Dict, Tuple, Any, Type + +import torch +from torch.cuda import amp +from torch.nn.parallel import DistributedDataParallel +from torch.nn.utils import clip_grad_norm_ +import torch.distributed as dist + +from allennlp.common.checks import ConfigurationError, check_for_gpu +from allennlp.common import util as common_util, Tqdm, Lazy +from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict +from allennlp.models.model import Model +from allennlp.training.callbacks import ConsoleLoggerCallback +from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback +from allennlp.training.checkpointer import Checkpointer +from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler +from allennlp.training.metric_tracker import MetricTracker +from allennlp.training.momentum_schedulers.momentum_scheduler import MomentumScheduler +from allennlp.training.moving_average import MovingAverage +from allennlp.training.optimizers import Optimizer +from allennlp.training.trainer import Trainer, TrainerCheckpoint +from allennlp.training.callbacks import TrainerCallback +from allennlp.training import util as training_util + +logger = logging.getLogger(__name__) + + +@Trainer.register("gradient_descent", constructor="from_partial_objects") +class GradientDescentTrainer(Trainer): + """ + A trainer for doing supervised learning with gradient descent. It just takes a labeled dataset + and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over + some fixed number of epochs. You can also pass in a validation data_loader and enable early + stopping. There are many other bells and whistles as well. + + Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`). + The constructor that is registered is [`from_partial_objects`](#from_partial_objects) - + see the arguments to that function for the exact keys that should be used, if you are using + a configuration file. They largely match the arguments to `__init__`, and we don't repeat their + docstrings in `from_partial_objects`. + + [0]: https://tinyurl.com/y5mv44fw + + # Parameters + + model : `Model`, required. + An AllenNLP model to be optimized. Pytorch Modules can also be optimized if + their `forward` method returns a dictionary with a "loss" key, containing a + scalar tensor representing the loss function to be optimized. + + If you are training your model using GPUs, your model should already be + on the correct device. (If you are using our `train` command this will be + handled for you.) + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + optimizer : `torch.nn.Optimizer`, required. + An instance of a Pytorch Optimizer, instantiated with the parameters of the + model to be optimized. + + data_loader : `DataLoader`, required. + A `DataLoader` containing your `Dataset`, yielding padded indexed batches. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + patience : `Optional[int] > 0`, optional (default=`None`) + Number of epochs to be patient before early stopping: the training is stopped + after `patience` epochs with no improvement. If given, it must be `> 0`. + If None, early stopping is disabled. + + validation_metric : `Union[str, List[str]]`, optional (default=`"-loss"`) + Validation metric to measure for whether to stop training using patience + and whether to serialize an `is_best` model each epoch. The metric name + must be prepended with either "+" or "-", which specifies whether the metric + is an increasing or decreasing function. If you specify more than one metric, + the metrics will be summed to make the `is_best` decision. + + validation_data_loader : `DataLoader`, optional (default=`None`) + A `DataLoader` to use for the validation set. If `None`, then + use the training `DataLoader` with the validation data. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + num_epochs : `int`, optional (default = `20`) + Number of training epochs. + + serialization_dir : `str`, optional (default=`None`) + Path to directory for saving and loading model files. Models will not be saved if + this parameter is not passed. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + checkpointer : `Checkpointer`, optional (default=`None`) + A `Checkpointer` is responsible for periodically saving model weights. If none is given + here, we will construct one with default parameters. + + cuda_device : `Optional[Union[int, torch.device]]`, optional (default = `None`) + An integer or `torch.device` specifying the CUDA device to use for this process. + If -1, the CPU is used. If `None` and you have a GPU available, that GPU will be used. + + !!! Note + If you *don't* intend to use a GPU, but you have one available, you'll need + to explicitly set `cuda_device=-1`. + + !!! Note + If you intend to use a GPU, your model already needs to be on the correct device, + which you can do with `model = model.cuda()`. + + !!! Note + Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU. + + grad_norm : `float`, optional, (default = `None`). + If provided, gradient norms will be rescaled to have a maximum of this value. + + grad_clipping : `float`, optional (default = `None`). + If provided, gradients will be clipped `during the backward pass` to have an (absolute) + maximum of this value. If you are getting `NaNs` in your gradients during training + that are not solved by using `grad_norm`, you may need this. + + learning_rate_scheduler : `LearningRateScheduler`, optional (default = `None`) + If specified, the learning rate will be decayed with respect to + this schedule at the end of each epoch (or batch, if the scheduler implements + the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`, + this will use the `validation_metric` provided to determine if learning has plateaued. + To support updating the learning rate on every batch, this can optionally implement + `step_batch(batch_num_total)` which updates the learning rate given the batch number. + + momentum_scheduler : `MomentumScheduler`, optional (default = `None`) + If specified, the momentum will be updated at the end of each batch or epoch + according to the schedule. + + moving_average : `MovingAverage`, optional, (default = `None`) + If provided, we will maintain moving averages for all parameters. During training, we + employ a shadow variable for each parameter, which maintains the moving average. During + evaluation, we backup the original parameters and assign the moving averages to corresponding + parameters. Be careful that when saving the checkpoint, we will save the moving averages of + parameters. This is necessary because we want the saved model to perform as well as the validated + model if we load it later. But this may cause problems if you restart the training from checkpoint. + + callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`) + A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start + and end of training, etc. + + distributed : `bool`, optional, (default = `False`) + If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also + requires `world_size` to be greater than 1. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately (you need a top-level "distributed" key, next to + the "trainer" entry, that specifies a list of "cuda_devices"). + + local_rank : `int`, optional, (default = `0`) + This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is + used as the rank. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + world_size : `int`, (default = `1`) + The number of `Trainer` workers participating in the distributed training. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "trainer", it gets constructed separately. + + num_gradient_accumulation_steps : `int`, optional, (default = `1`) + Gradients are accumulated for the given number of steps before doing an optimizer step. This can + be useful to accommodate batches that are larger than the RAM size. Refer [Thomas Wolf's + post][0] for details on Gradient Accumulation. + + use_amp : `bool`, optional, (default = `False`) + If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html). + + enable_default_callbacks : `bool`, optional (default = `True`) + When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in + addition to any other callbacks listed in the `callbacks` parameter. + When set to `False`, `DEFAULT_CALLBACKS` are not used. + + run_confidence_checks : `bool`, optional (default = `True`) + Determines whether model confidence checks, such as + [`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/), + are run. + + run_sanity_checks : `bool`, optional (default = `True`) + This parameter is deprecated. Please use `run_confidence_checks` instead. + + """ + + def __init__( + self, + model: Model, + optimizer: torch.optim.Optimizer, + data_loader: DataLoader, + patience: Optional[int] = None, + validation_metric: Union[str, List[str]] = "-loss", + validation_data_loader: DataLoader = None, + num_epochs: int = 20, + serialization_dir: Optional[str] = None, + checkpointer: Checkpointer = None, + cuda_device: Optional[Union[int, torch.device]] = None, + grad_norm: Optional[float] = None, + grad_clipping: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + momentum_scheduler: Optional[MomentumScheduler] = None, + moving_average: Optional[MovingAverage] = None, + callbacks: List[TrainerCallback] = None, + distributed: bool = False, + local_rank: int = 0, + world_size: int = 1, + num_gradient_accumulation_steps: int = 1, + use_amp: bool = False, + enable_default_callbacks: bool = True, + run_confidence_checks: bool = True, + **kwargs, + ) -> None: + super().__init__( + serialization_dir=serialization_dir, + cuda_device=cuda_device, + distributed=distributed, + local_rank=local_rank, + world_size=world_size, + ) + + if "run_sanity_checks" in kwargs: + warnings.warn( + "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.", + DeprecationWarning, + ) + run_confidence_checks = kwargs["run_sanity_checks"] + + # I am not calling move_to_gpu here, because if the model is + # not already on the GPU then the optimizer is going to be wrong. + self.model = model + + self.data_loader = data_loader + self.data_loader.set_target_device(self.cuda_device) + self._validation_data_loader = validation_data_loader + if self._validation_data_loader is not None: + self._validation_data_loader.set_target_device(self.cuda_device) + self.optimizer = optimizer + + if patience is None: # no early stopping + if validation_data_loader is not None: + logger.warning( + "You provided a validation dataset but patience was set to None, " + "meaning that early stopping is disabled" + ) + elif (not isinstance(patience, int)) or patience <= 0: + raise ConfigurationError( + '{} is an invalid value for "patience": it must be a positive integer ' + "or None (if you want to disable early stopping)".format(patience) + ) + + # For tracking is_best_so_far and should_stop_early + self._metric_tracker = MetricTracker(validation_metric, patience) + + self._num_epochs = num_epochs + + self._checkpointer: Optional[Checkpointer] = checkpointer + if checkpointer is None and serialization_dir is not None: + self._checkpointer = Checkpointer(serialization_dir) + + self._grad_norm = grad_norm + self._grad_clipping = grad_clipping + + self._learning_rate_scheduler = learning_rate_scheduler + self._momentum_scheduler = momentum_scheduler + self._moving_average = moving_average + + self._callbacks = callbacks or [] + default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else [] + + if run_confidence_checks: + default_callbacks.append(ConfidenceChecksCallback) + for callback_cls in default_callbacks: + for callback in self._callbacks: + if callback.__class__ == callback_cls: + break + else: + self._callbacks.append(callback_cls(self._serialization_dir)) + + self._num_gradient_accumulation_steps = num_gradient_accumulation_steps + + # Enable automatic mixed precision training. + self._scaler: Optional[amp.GradScaler] = None + self._use_amp = use_amp + if self._use_amp: + if self.cuda_device == torch.device("cpu"): + raise ValueError("Using AMP requires a cuda device") + self._scaler = amp.GradScaler() + + # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its + # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` + # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. + # + # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the + # normal case, reference to `Model` is retained. This reference is only used in + # these places: `model.__call__`, `model.train` and `model.eval`. + if self._distributed: + self._pytorch_model = DistributedDataParallel( + self.model, + device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device], + find_unused_parameters=True, + ) + else: + self._pytorch_model = self.model + + # training state management + self._epochs_completed: int = 0 + self._start_after_epochs_completed: int = 0 + self._batches_in_epoch_completed: int = 0 + self._start_after_batches_in_epoch_completed: int = 0 + self._best_model_filename: Optional[str] = None + + # This is a kind of training state, but it is not serialized with the trainer state, because we can + # re-create it with `epochs_completed` and `batches_in_epoch_completed`. + self._total_batches_completed: int = 0 + + def rescale_gradients(self) -> float: + """ + Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled. + + Returns the norm of the gradients. + """ + parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None] + if self._grad_norm: + if self._scaler is not None: + # Need to first unscale gradients in order to clip as usual. + self._scaler.unscale_(self.optimizer) + return clip_grad_norm_(parameters_to_clip, self._grad_norm) + else: + return torch.norm( + torch.stack([torch.norm(p.grad.detach()) for p in parameters_to_clip]) + ) + + def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]: + """ + Does a forward pass on the given batch and returns the output dictionary that the model + returns, after adding any specified regularization penalty to the loss (if training). + """ + output_dict = self._pytorch_model(**batch) + + if for_training: + try: + assert "loss" in output_dict + regularization_penalty = self.model.get_regularization_penalty() + + if regularization_penalty is not None: + output_dict["reg_loss"] = regularization_penalty + output_dict["loss"] += regularization_penalty + + except AssertionError: + if for_training: + raise RuntimeError( + "The model you are trying to optimize does not contain a" + " 'loss' key in the output of model.forward(inputs)." + ) + + return output_dict + + def _train_epoch(self, epoch: int) -> Dict[str, float]: + """ + Trains one epoch and returns metrics. + """ + logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) + cpu_memory_usage = [] + for worker, memory in common_util.peak_cpu_memory().items(): + cpu_memory_usage.append((worker, memory)) + logger.info(f"Worker {worker} memory usage: {common_util.format_size(memory)}") + gpu_memory_usage = [] + for gpu, memory in common_util.peak_gpu_memory().items(): + gpu_memory_usage.append((gpu, memory)) + logger.info(f"GPU {gpu} memory usage: {common_util.format_size(memory)}") + + regularization_penalty = self.model.get_regularization_penalty() + + train_loss = 0.0 + train_reg_loss = None if regularization_penalty is None else 0.0 + batch_reg_loss = None if regularization_penalty is None else 0.0 + + # Set the model to "train" mode. + self._pytorch_model.train() + + # Get tqdm for the training batches + batch_generator = iter(self.data_loader) + batch_group_generator = common_util.lazy_groups_of( + batch_generator, self._num_gradient_accumulation_steps + ) + + logger.info("Training") + + num_training_batches: Union[int, float] + try: + len_data_loader = len(self.data_loader) + num_training_batches = math.ceil( + len_data_loader / self._num_gradient_accumulation_steps + ) + except TypeError: + num_training_batches = float("inf") + + # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's + # progress is shown + if self._primary: + batch_group_generator_tqdm = Tqdm.tqdm( + batch_group_generator, total=num_training_batches + ) + else: + batch_group_generator_tqdm = batch_group_generator + + done_early = False + for batch_group in batch_group_generator_tqdm: + if done_early: + break + + if self._epochs_completed < self._start_after_epochs_completed or ( + self._epochs_completed == self._start_after_epochs_completed + and self._batches_in_epoch_completed < self._start_after_batches_in_epoch_completed + ): + self._batches_in_epoch_completed += 1 + self._total_batches_completed += 1 + continue + + self.optimizer.zero_grad() + + batch_loss = 0.0 + batch_group_outputs = [] + for batch in batch_group: + if self._distributed: + # Check whether the other workers have stopped already (due to differing amounts of + # data in each). If so, we can't proceed because we would hang when we hit the + # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor + # here because NCCL process groups apparently don't support BoolTensor. + done = torch.tensor(0, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + if done.item() > 0: + done_early = True + logger.warning( + f"Worker {torch.distributed.get_rank()} finishing training early! " + "This implies that there is an imbalance in your training " + "data across the workers and that some amount of it will be " + "ignored. A small amount of this is fine, but a major imbalance " + "should be avoided. Note: This warning will appear unless your " + "data is perfectly balanced." + ) + break + + with amp.autocast(self._use_amp): + batch_outputs = self.batch_outputs(batch, for_training=True) + batch_group_outputs.append(batch_outputs) + loss = batch_outputs["loss"] + reg_loss = batch_outputs.get("reg_loss") + if torch.isnan(loss): + raise ValueError("nan loss encountered") + loss = loss / len(batch_group) + + batch_loss += loss.item() + if reg_loss is not None: + reg_loss = reg_loss / len(batch_group) + batch_reg_loss = reg_loss.item() + train_reg_loss += batch_reg_loss # type: ignore + + if self._scaler is not None: + self._scaler.scale(loss).backward() + else: + loss.backward() + if len(batch_group_outputs) <= 0: + continue + + train_loss += batch_loss + + batch_grad_norm = self.rescale_gradients() + + if self._learning_rate_scheduler: + self._learning_rate_scheduler.step_batch(self._total_batches_completed + 1) + if self._momentum_scheduler: + self._momentum_scheduler.step_batch(self._total_batches_completed + 1) + + if self._scaler is not None: + self._scaler.step(self.optimizer) + self._scaler.update() + else: + self.optimizer.step() + + # Update moving averages + if self._moving_average is not None: + self._moving_average.apply(self._total_batches_completed + 1) + + self._batches_in_epoch_completed += 1 + self._total_batches_completed += 1 + + # Update the description with the latest metrics + metrics = training_util.get_metrics( + self.model, + train_loss, + train_reg_loss, + batch_loss, + batch_reg_loss, + self._batches_in_epoch_completed, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + for callback in self._callbacks: + callback.on_batch( + self, + batch_group, + batch_group_outputs, + metrics, + epoch, + self._batches_in_epoch_completed, + is_training=True, + is_primary=self._primary, + batch_grad_norm=batch_grad_norm, + ) + + if self._primary: + # Updating tqdm only for the primary as the trainers wouldn't have one + description = training_util.description_from_metrics(metrics) + batch_group_generator_tqdm.set_description(description, refresh=False) + + if self._checkpointer is not None: + self._checkpointer.maybe_save_checkpoint( + self, self._epochs_completed, self._batches_in_epoch_completed + ) + + if self._distributed and not done_early: + logger.warning( + f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)." + ) + # Indicate that we're done so that any workers that have remaining data stop the epoch early. + done = torch.tensor(1, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + assert done.item() + + # Let all workers finish their epoch before computing + # the final statistics for the epoch. + if self._distributed: + dist.barrier() + + metrics = training_util.get_metrics( + self.model, + train_loss, + train_reg_loss, + batch_loss=None, + batch_reg_loss=None, + num_batches=self._batches_in_epoch_completed, + reset=True, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + for (worker, memory) in cpu_memory_usage: + metrics["worker_" + str(worker) + "_memory_MB"] = memory / (1024 * 1024) + for (gpu_num, memory) in gpu_memory_usage: + metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory / (1024 * 1024) + return metrics + + def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]: + """ + Computes the validation loss. Returns it and the number of batches. + """ + logger.info("Validating") + + self._pytorch_model.eval() + + # Replace parameter values with the shadow values from the moving averages. + if self._moving_average is not None: + self._moving_average.assign_average_value() + try: + if self._validation_data_loader is not None: + validation_data_loader = self._validation_data_loader + else: + raise ConfigurationError( + "Validation results cannot be calculated without a validation_data_loader" + ) + + regularization_penalty = self.model.get_regularization_penalty() + + # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's + # progress is shown + if self._primary: + val_generator_tqdm = Tqdm.tqdm(validation_data_loader) + else: + val_generator_tqdm = validation_data_loader + + batches_this_epoch = 0 + val_loss = 0.0 + val_batch_loss = 0.0 + val_reg_loss = None if regularization_penalty is None else 0.0 + val_batch_reg_loss = None if regularization_penalty is None else 0.0 + done_early = False + for batch in val_generator_tqdm: + if self._distributed: + # Check whether the other workers have stopped already (due to differing amounts of + # data in each). If so, we can't proceed because we would hang when we hit the + # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor + # here because NCCL process groups apparently don't support BoolTensor. + done = torch.tensor(0, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + if done.item() > 0: + done_early = True + logger.warning( + f"Worker {torch.distributed.get_rank()} finishing validation early! " + "This implies that there is an imbalance in your validation " + "data across the workers and that some amount of it will be " + "ignored. A small amount of this is fine, but a major imbalance " + "should be avoided. Note: This warning will appear unless your " + "data is perfectly balanced." + ) + break + + with amp.autocast(self._use_amp): + batch_outputs = self.batch_outputs(batch, for_training=False) + loss = batch_outputs.get("loss") + reg_loss = batch_outputs.get("reg_loss") + if loss is not None: + # You shouldn't necessarily have to compute a loss for validation, so we allow for + # `loss` to be None. We need to be careful, though - `batches_this_epoch` is + # currently only used as the divisor for the loss function, so we can safely only + # count those batches for which we actually have a loss. If this variable ever + # gets used for something else, we might need to change things around a bit. + batches_this_epoch += 1 + val_batch_loss = loss.item() + val_loss += val_batch_loss + if reg_loss is not None: + val_batch_reg_loss = reg_loss.item() + val_reg_loss += val_batch_reg_loss # type: ignore + + # Update the description with the latest metrics + val_metrics = training_util.get_metrics( + self.model, + val_loss, + val_reg_loss, + val_batch_loss, + val_batch_reg_loss, + batches_this_epoch, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + description = training_util.description_from_metrics(val_metrics) + if self._primary: + val_generator_tqdm.set_description(description, refresh=False) + + for callback in self._callbacks: + callback.on_batch( + self, + [batch], + [batch_outputs], + val_metrics, + epoch, + batches_this_epoch, + is_training=False, + is_primary=self._primary, + ) + + if self._distributed and not done_early: + logger.warning( + f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)." + ) + # Indicate that we're done so that any workers that have remaining data stop validation early. + done = torch.tensor(1, device=self.cuda_device) + torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) + assert done.item() + + return val_loss, val_reg_loss, batches_this_epoch + finally: + # Now restore the original parameter values. + if self._moving_average is not None: + self._moving_average.restore() + + def train(self) -> Dict[str, Any]: + """ + Trains the supplied model with the supplied parameters. + """ + try: + self._restore_checkpoint() + except RuntimeError as e: + configuration_error = ConfigurationError( + "Could not recover training from the checkpoint. Did you mean to output to " + "a different serialization directory or delete the existing serialization " + "directory?" + ) + configuration_error.__cause__ = e + raise configuration_error + + # Callbacks get their `on_start` call even when we're starting from a checkpoint. + for callback in self._callbacks: + callback.on_start(self, is_primary=self._primary) + + # Set default values in case of failure + epoch = None + metrics = None + + try: + metrics, epoch = self._try_train() + return metrics + finally: + for callback in self._callbacks: + callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary) + + def _try_train(self) -> Tuple[Dict[str, Any], int]: + training_util.enable_gradient_clipping(self.model, self._grad_clipping) + + logger.info("Beginning training.") + + val_metrics: Dict[str, float] = {} + metrics: Dict[str, Any] = {} + training_start_time = None + + metrics["best_epoch"] = self._metric_tracker.best_epoch + for key, value in self._metric_tracker.best_epoch_metrics.items(): + metrics["best_validation_" + key] = value + + for epoch in range(self._num_epochs): + epoch_start_time = time.time() + train_metrics = self._train_epoch(epoch) + + if self._epochs_completed < self._start_after_epochs_completed: + # We're still catching up with the checkpoint, so we do nothing. + # Note that we have to call _train_epoch() even when we know the epoch is skipped. We have to + # read from the data loader, because the data loader and dataset readers might use randomness, + # and we have to make sure we consume exactly the same instances in exactly the same way every + # time we train, even when starting from a checkpoint, so that we update the randomness + # generators in the same way each time. + self._epochs_completed += 1 + self._batches_in_epoch_completed = 0 + continue + if training_start_time is None: + training_start_time = epoch_start_time + + # get peak of memory usage + for key, value in train_metrics.items(): + if key.startswith("gpu_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + elif key.startswith("worker_") and key.endswith("_memory_MB"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + + this_epoch_val_metric: float = 0.0 + if self._validation_data_loader is not None: + with torch.no_grad(): + # We have a validation set, so compute all the metrics on it. + val_loss, val_reg_loss, num_batches = self._validation_loss(epoch) + + # It is safe again to wait till the validation is done. This is + # important to get the metrics right. + if self._distributed: + dist.barrier() + + val_metrics = training_util.get_metrics( + self.model, + val_loss, + val_reg_loss, + batch_loss=None, + batch_reg_loss=None, + num_batches=num_batches, + reset=True, + world_size=self._world_size, + cuda_device=self.cuda_device, + ) + + # Check validation metric for early stopping + this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics) + self._metric_tracker.add_metrics(val_metrics) + + # Create overall metrics dict + training_elapsed_time = time.time() - training_start_time + metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) + metrics["epoch"] = epoch + + for key, value in train_metrics.items(): + metrics["training_" + key] = value + for key, value in val_metrics.items(): + metrics["validation_" + key] = value + + if self._metric_tracker.is_best_so_far(): + # Update all the best_ metrics. + # (Otherwise they just stay the same as they were.) + metrics["best_epoch"] = epoch + for key, value in val_metrics.items(): + metrics["best_validation_" + key] = value + + self._metric_tracker.best_epoch_metrics = val_metrics + + if self._serialization_dir and self._primary: + common_util.dump_metrics( + os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), + metrics, + ) + + # The Scheduler API is agnostic to whether your schedule requires a validation metric - + # if it doesn't, the validation metric passed here is ignored. + if self._learning_rate_scheduler: + self._learning_rate_scheduler.step(this_epoch_val_metric) + if self._momentum_scheduler: + self._momentum_scheduler.step(this_epoch_val_metric) + for callback in self._callbacks: + callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary) + + self._epochs_completed += 1 + self._batches_in_epoch_completed = 0 + + # The checkpointer saves state from the learning rate scheduler, momentum scheduler, moving + # average, and callbacks, so we have to make sure those are updated before we save the + # checkpoint here. + if self._primary and self._checkpointer is not None: + self._checkpointer.maybe_save_checkpoint( + self, self._epochs_completed, self._batches_in_epoch_completed + ) + # Wait for the primary process to finish saving the checkpoint + if self._distributed: + dist.barrier() + + if self._primary and self._serialization_dir and self._metric_tracker.is_best_so_far(): + self._best_model_filename = os.path.join(self._serialization_dir, "best.th") + if self._moving_average is None: + torch.save(self.model.state_dict(), self._best_model_filename) + else: + self._moving_average.assign_average_value() + try: + torch.save(self.model.state_dict(), self._best_model_filename) + finally: + self._moving_average.restore() + # Wait for the primary process to finish saving the best + if self._distributed: + dist.barrier() + + epoch_elapsed_time = time.time() - epoch_start_time + logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) + + if self._metric_tracker.should_stop_early(): + logger.info("Ran out of patience. Stopping training.") + break + + if epoch < self._num_epochs - 1: + time_per_epoch = training_elapsed_time / ( + (epoch + 1) - self._start_after_epochs_completed + ) + # Note: If the first non-skipped epoch is half skipped (because it was checkpointed half-way + # through), then this estimate is going to be optimistic. + estimated_time_remaining = ( + time_per_epoch * self._num_epochs + ) - training_elapsed_time + formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) + logger.info("Estimated training time remaining: %s", formatted_time) + else: + epoch = self._num_epochs - 1 + + # Load the best model state before returning + if self._best_model_filename is None or self._metric_tracker.is_best_so_far(): + self._finalize_model() + else: + # The model we're loading here has already been finalized. + self.model.load_state_dict(torch.load(self._best_model_filename)) + + return metrics, epoch + + def _finalize_model(self) -> None: + """If we have a moving average, we have to finalize the model at the end of training.""" + if self._moving_average is not None: + self._moving_average.assign_average_value() + + def get_checkpoint_state(self) -> TrainerCheckpoint: + model_state = self.model.state_dict() + + # These are the training states we need to persist. + training_states = { + "version": 1, + "metric_tracker": self._metric_tracker.state_dict(), + "optimizer": self.optimizer.state_dict(), + "callbacks": [cb.state_dict() for cb in self._callbacks], + "epochs_completed": self._epochs_completed, + "batches_in_epoch_completed": self._batches_in_epoch_completed, + "best_model_filename": self._best_model_filename, + } + + # If we have any of these optional objects, we should persist them too. + if self._learning_rate_scheduler is not None: + training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() + if self._momentum_scheduler is not None: + training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() + if self._moving_average is not None: + training_states["moving_average"] = self._moving_average.state_dict() + + return TrainerCheckpoint(model_state, training_states) + + def _restore_checkpoint(self) -> None: + """ + Restores the model and training state from the last saved checkpoint. + This includes an epoch count and optimizer state, which is serialized separately + from model parameters. This function should only be used to continue training - + if you wish to load a model for inference/load parts of a model into a new + computation graph, you should use the native Pytorch functions: + `model.load_state_dict(torch.load("/path/to/model/weights.th"))` + + If `self._serialization_dir` does not exist or does not contain any checkpointed weights, + this function will do nothing. + """ + if self._checkpointer is None: + return + + model_state, training_state = self._checkpointer.load_checkpoint() + if len(model_state) <= 0 and len(training_state) <= 0: + self._start_after_epochs_completed = 0 + self._start_after_batches_in_epoch_completed = 0 + self._best_model_filename = None + return + if training_state["version"] != 1: + raise ValueError( + f"This version of {self.__class__.__name__} only supports checkpoints of version 1. " + f"Found version {training_state['version']}" + ) + + self.model.load_state_dict(model_state) + self._metric_tracker.load_state_dict(training_state["metric_tracker"]) + self.optimizer.load_state_dict(training_state["optimizer"]) + + for cb, state_dict in zip(self._callbacks, training_state["callbacks"]): + cb.load_state_dict(state_dict) + + if self._learning_rate_scheduler is not None: + self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) + if self._momentum_scheduler is not None: + self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) + if self._moving_average is not None: + self._moving_average.load_state_dict(training_state["moving_average"]) + + self._start_after_epochs_completed = training_state["epochs_completed"] + self._start_after_batches_in_epoch_completed = training_state["batches_in_epoch_completed"] + self._best_model_filename = training_state["best_model_filename"] + + @classmethod + def from_partial_objects( + cls, + model: Model, + serialization_dir: str, + data_loader: DataLoader, + validation_data_loader: DataLoader = None, + local_rank: int = 0, + patience: int = None, + validation_metric: Union[str, List[str]] = "-loss", + num_epochs: int = 20, + cuda_device: Optional[Union[int, torch.device]] = None, + grad_norm: float = None, + grad_clipping: float = None, + distributed: bool = False, + world_size: int = 1, + num_gradient_accumulation_steps: int = 1, + use_amp: bool = False, + no_grad: List[str] = None, + optimizer: Lazy[Optimizer] = Lazy(Optimizer.default), + learning_rate_scheduler: Lazy[LearningRateScheduler] = None, + momentum_scheduler: Lazy[MomentumScheduler] = None, + moving_average: Lazy[MovingAverage] = None, + checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer), + callbacks: List[Lazy[TrainerCallback]] = None, + enable_default_callbacks: bool = True, + run_confidence_checks: bool = True, + **kwargs, + ) -> Trainer: + """ + This method exists so that we can have a documented method to construct this class using + `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this + method. + + The reason we can't just use `__init__` with `FromParams` here is because there are + sequential dependencies to this class's arguments. Anything that has a `Lazy[]` type + annotation needs something from one of the non-`Lazy` arguments. The `Optimizer` needs to + have the parameters from the `Model` before it's constructed, and the `Schedulers` need to + have the `Optimizer`. Because of this, the typical way we construct things `FromParams` + doesn't work, so we use `Lazy` to allow for constructing the objects sequentially. + + If you're not using `FromParams`, you can just construct these arguments in the right order + yourself in your code and call the constructor directly. + """ + if cuda_device is None: + from torch import cuda + + if cuda.device_count() > 0: + cuda_device = 0 + else: + cuda_device = -1 + + check_for_gpu(cuda_device) + if cuda_device >= 0: + # Moving model to GPU here so that the optimizer state gets constructed on + # the right device. + model = model.cuda(cuda_device) + + if no_grad: + for name, parameter in model.named_parameters(): + if any(re.search(regex, name) for regex in no_grad): + parameter.requires_grad_(False) + + parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] + optimizer_ = optimizer.construct(model_parameters=parameters) + + common_util.log_frozen_and_tunable_parameter_names(model) + + batches_per_epoch: Optional[int] + try: + batches_per_epoch = len(data_loader) + batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps) + except TypeError: + batches_per_epoch = None + + moving_average_ = ( + None if moving_average is None else moving_average.construct(parameters=parameters) + ) + learning_rate_scheduler_ = ( + None + if learning_rate_scheduler is None + else learning_rate_scheduler.construct( + optimizer=optimizer_, num_epochs=num_epochs, num_steps_per_epoch=batches_per_epoch + ) + ) + momentum_scheduler_ = ( + None + if momentum_scheduler is None + else momentum_scheduler.construct(optimizer=optimizer_) + ) + checkpointer_ = checkpointer.construct(serialization_dir=serialization_dir) + + callbacks_: List[TrainerCallback] = [] + for callback_ in callbacks or []: + callbacks_.append(callback_.construct(serialization_dir=serialization_dir)) + + return cls( + model, + optimizer_, + data_loader, + patience=patience, + validation_metric=validation_metric, + validation_data_loader=validation_data_loader, + num_epochs=num_epochs, + serialization_dir=serialization_dir, + cuda_device=cuda_device, + grad_norm=grad_norm, + grad_clipping=grad_clipping, + learning_rate_scheduler=learning_rate_scheduler_, + momentum_scheduler=momentum_scheduler_, + checkpointer=checkpointer_, + moving_average=moving_average_, + callbacks=callbacks_, + distributed=distributed, + local_rank=local_rank, + world_size=world_size, + num_gradient_accumulation_steps=num_gradient_accumulation_steps, + use_amp=use_amp, + enable_default_callbacks=enable_default_callbacks, + run_confidence_checks=run_confidence_checks, + **kwargs, + ) + + def get_best_weights_path(self) -> Optional[str]: + return self._best_model_filename + + +DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,) +""" +The default callbacks used by `GradientDescentTrainer`. +""" diff --git a/allennlp/training/moving_average.py b/allennlp/training/moving_average.py index 205eec973fa..4657e2d45dd 100644 --- a/allennlp/training/moving_average.py +++ b/allennlp/training/moving_average.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple, Optional +from typing import Iterable, Tuple, Optional, Any, Dict import torch @@ -41,6 +41,14 @@ def restore(self) -> None: for name, parameter in self._parameters: parameter.data.copy_(self._backups[name]) + def state_dict(self) -> Dict[str, Any]: + return {"parameters": self._parameters, "shadows": self._shadows, "backups": self._backups} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._parameters = state_dict["parameters"] + self._shadows = state_dict["shadows"] + self._backups = state_dict["backups"] + @MovingAverage.register("exponential") class ExponentialMovingAverage(MovingAverage): diff --git a/allennlp/training/no_op_trainer.py b/allennlp/training/no_op_trainer.py index 93ee3542aec..47cec4323d0 100644 --- a/allennlp/training/no_op_trainer.py +++ b/allennlp/training/no_op_trainer.py @@ -1,10 +1,11 @@ import os -from contextlib import contextmanager -from typing import Any, Dict, Iterator, Tuple +from typing import Any, Dict, Optional + +import torch from allennlp.models import Model from allennlp.training.checkpointer import Checkpointer -from allennlp.training.trainer import Trainer +from allennlp.training.trainer import Trainer, TrainerCheckpoint @Trainer.register("no_op") @@ -24,14 +25,24 @@ def __init__(self, serialization_dir: str, model: Model) -> None: super().__init__(serialization_dir, cuda_device=-1) self.model = model + self._best_model_filename: Optional[str] = None def train(self) -> Dict[str, Any]: assert self._serialization_dir is not None self.model.vocab.save_to_files(os.path.join(self._serialization_dir, "vocabulary")) checkpointer = Checkpointer(self._serialization_dir) - checkpointer.save_checkpoint(epoch=0, trainer=self, is_best_so_far=True) + checkpointer.save_checkpoint(self) + + best_model_filename = os.path.join(self._serialization_dir, "best.th") + torch.save(self.model.state_dict(), best_model_filename) + self._best_model_filename = best_model_filename + return {} - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: - yield self.model.state_dict(), {} + def get_checkpoint_state(self) -> TrainerCheckpoint: + return TrainerCheckpoint( + self.model.state_dict(), {"epochs_completed": 0, "batches_in_epoch_completed": 0} + ) + + def get_best_weights_path(self) -> Optional[str]: + return self._best_model_filename diff --git a/allennlp/training/scheduler.py b/allennlp/training/scheduler.py index 26b115b68ed..e9cad0bc9ca 100644 --- a/allennlp/training/scheduler.py +++ b/allennlp/training/scheduler.py @@ -79,5 +79,4 @@ def step_batch(self, batch_num_total: int = None) -> None: By default, a scheduler is assumed to only update every epoch, not every batch. So this does nothing unless it's overriden. """ - return diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py index 54d9b59ffb1..797a8f382d2 100644 --- a/allennlp/training/trainer.py +++ b/allennlp/training/trainer.py @@ -1,44 +1,23 @@ -import datetime import logging -import math import os -import re -import time -import traceback -import warnings -from contextlib import contextmanager -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union -from allennlp.common.util import int_to_device - -import torch -import torch.distributed as dist -from torch.cuda import amp import torch.optim.lr_scheduler -from torch.nn.parallel import DistributedDataParallel -from torch.nn.utils import clip_grad_norm_ -from allennlp.common import Lazy, Registrable, Tqdm -from allennlp.common import util as common_util +from allennlp.common import Registrable from allennlp.common.checks import ConfigurationError, check_for_gpu -from allennlp.data import DataLoader, TensorDict -from allennlp.models.model import Model -from allennlp.training import util as training_util -from allennlp.training.callbacks import ( - TrainerCallback, - ConfidenceChecksCallback, - ConsoleLoggerCallback, -) -from allennlp.training.checkpointer import Checkpointer -from allennlp.training.learning_rate_schedulers import LearningRateScheduler -from allennlp.training.metric_tracker import MetricTracker -from allennlp.training.momentum_schedulers import MomentumScheduler -from allennlp.training.moving_average import MovingAverage -from allennlp.training.optimizers import Optimizer +from allennlp.common.util import int_to_device logger = logging.getLogger(__name__) +@dataclass +class TrainerCheckpoint: + model_state: Dict[str, Any] + trainer_state: Dict[str, Any] + + class Trainer(Registrable): """ The base class for an AllenNLP trainer. It can do pretty much @@ -77,7 +56,7 @@ def __init__( if isinstance(cuda_device, list): raise ConfigurationError( - "In allennlp 1.0, the Trainer can only be assigned a single `cuda_device`. " + "In AllenNLP 1.0, the Trainer can only be assigned a single `cuda_device`. " "Instead, we use torch's DistributedDataParallel at the command level, meaning " "our Trainer always uses a single GPU per process." ) @@ -101,1035 +80,13 @@ def train(self) -> Dict[str, Any]: """ raise NotImplementedError - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: + def get_checkpoint_state(self) -> TrainerCheckpoint: """ Returns a tuple of (model state, training state), where training state could have several internal components (e.g., for an, optimizer, learning rate scheduler, etc.). - - This is a context manager, and should be called as `with trainer.get_checkpoint_state() as - state:`, so that the trainer has the opportunity to change and restore its internal state - for checkpointing. This is used, e.g., for moving averages of model weights. """ raise NotImplementedError - -@Trainer.register("gradient_descent", constructor="from_partial_objects") -class GradientDescentTrainer(Trainer): - """ - A trainer for doing supervised learning with gradient descent. It just takes a labeled dataset - and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over - some fixed number of epochs. You can also pass in a validation data_loader and enable early - stopping. There are many other bells and whistles as well. - - Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`). - The constructor that is registered is [`from_partial_objects`](#from_partial_objects) - - see the arguments to that function for the exact keys that should be used, if you are using - a configuration file. They largely match the arguments to `__init__`, and we don't repeat their - docstrings in `from_partial_objects`. - - [0]: https://tinyurl.com/y5mv44fw - - # Parameters - - model : `Model`, required. - An AllenNLP model to be optimized. Pytorch Modules can also be optimized if - their `forward` method returns a dictionary with a "loss" key, containing a - scalar tensor representing the loss function to be optimized. - - If you are training your model using GPUs, your model should already be - on the correct device. (If you are using our `train` command this will be - handled for you.) - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - optimizer : `torch.nn.Optimizer`, required. - An instance of a Pytorch Optimizer, instantiated with the parameters of the - model to be optimized. - - data_loader : `DataLoader`, required. - A `DataLoader` containing your `Dataset`, yielding padded indexed batches. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - patience : `Optional[int] > 0`, optional (default=`None`) - Number of epochs to be patient before early stopping: the training is stopped - after `patience` epochs with no improvement. If given, it must be `> 0`. - If None, early stopping is disabled. - - validation_metric : `Union[str, List[str]]`, optional (default=`"-loss"`) - Validation metric to measure for whether to stop training using patience - and whether to serialize an `is_best` model each epoch. The metric name - must be prepended with either "+" or "-", which specifies whether the metric - is an increasing or decreasing function. If you specify more than one metric, - the metrics will be summed to make the `is_best` decision. - - validation_data_loader : `DataLoader`, optional (default=`None`) - A `DataLoader` to use for the validation set. If `None`, then - use the training `DataLoader` with the validation data. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - num_epochs : `int`, optional (default = `20`) - Number of training epochs. - - serialization_dir : `str`, optional (default=`None`) - Path to directory for saving and loading model files. Models will not be saved if - this parameter is not passed. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - checkpointer : `Checkpointer`, optional (default=`None`) - A `Checkpointer` is responsible for periodically saving model weights. If none is given - here, we will construct one with default parameters. - - cuda_device : `Optional[Union[int, torch.device]]`, optional (default = `None`) - An integer or `torch.device` specifying the CUDA device to use for this process. - If -1, the CPU is used. If `None` and you have a GPU available, that GPU will be used. - - !!! Note - If you *don't* intend to use a GPU, but you have one available, you'll need - to explicitly set `cuda_device=-1`. - - !!! Note - If you intend to use a GPU, your model already needs to be on the correct device, - which you can do with `model = model.cuda()`. - - !!! Note - Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU. - - grad_norm : `float`, optional, (default = `None`). - If provided, gradient norms will be rescaled to have a maximum of this value. - - grad_clipping : `float`, optional (default = `None`). - If provided, gradients will be clipped `during the backward pass` to have an (absolute) - maximum of this value. If you are getting `NaNs` in your gradients during training - that are not solved by using `grad_norm`, you may need this. - - learning_rate_scheduler : `LearningRateScheduler`, optional (default = `None`) - If specified, the learning rate will be decayed with respect to - this schedule at the end of each epoch (or batch, if the scheduler implements - the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`, - this will use the `validation_metric` provided to determine if learning has plateaued. - To support updating the learning rate on every batch, this can optionally implement - `step_batch(batch_num_total)` which updates the learning rate given the batch number. - - momentum_scheduler : `MomentumScheduler`, optional (default = `None`) - If specified, the momentum will be updated at the end of each batch or epoch - according to the schedule. - - moving_average : `MovingAverage`, optional, (default = `None`) - If provided, we will maintain moving averages for all parameters. During training, we - employ a shadow variable for each parameter, which maintains the moving average. During - evaluation, we backup the original parameters and assign the moving averages to corresponding - parameters. Be careful that when saving the checkpoint, we will save the moving averages of - parameters. This is necessary because we want the saved model to perform as well as the validated - model if we load it later. But this may cause problems if you restart the training from checkpoint. - - callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`) - A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start - and end of training, etc. - - distributed : `bool`, optional, (default = `False`) - If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also - requires `world_size` to be greater than 1. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately (you need a top-level "distributed" key, next to - the "trainer" entry, that specifies a list of "cuda_devices"). - - local_rank : `int`, optional, (default = `0`) - This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is - used as the rank. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - world_size : `int`, (default = `1`) - The number of `Trainer` workers participating in the distributed training. - - In a typical AllenNLP configuration file, this parameter does not get an entry under the - "trainer", it gets constructed separately. - - num_gradient_accumulation_steps : `int`, optional, (default = `1`) - Gradients are accumulated for the given number of steps before doing an optimizer step. This can - be useful to accommodate batches that are larger than the RAM size. Refer [Thomas Wolf's - post][0] for details on Gradient Accumulation. - - use_amp : `bool`, optional, (default = `False`) - If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html). - - enable_default_callbacks : `bool`, optional (default = `True`) - When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in - addition to any other callbacks listed in the `callbacks` parameter. - When set to `False`, `DEFAULT_CALLBACKS` are not used. - - run_confidence_checks : `bool`, optional (default = `True`) - Determines whether model confidence checks, such as - [`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/), - are run. - - run_sanity_checks : `bool`, optional (default = `True`) - This parameter is deprecated. Please use `run_confidence_checks` instead. - - """ - - def __init__( - self, - model: Model, - optimizer: torch.optim.Optimizer, - data_loader: DataLoader, - patience: Optional[int] = None, - validation_metric: Union[str, List[str]] = "-loss", - validation_data_loader: DataLoader = None, - num_epochs: int = 20, - serialization_dir: Optional[str] = None, - checkpointer: Checkpointer = None, - cuda_device: Optional[Union[int, torch.device]] = None, - grad_norm: Optional[float] = None, - grad_clipping: Optional[float] = None, - learning_rate_scheduler: Optional[LearningRateScheduler] = None, - momentum_scheduler: Optional[MomentumScheduler] = None, - moving_average: Optional[MovingAverage] = None, - callbacks: List[TrainerCallback] = None, - distributed: bool = False, - local_rank: int = 0, - world_size: int = 1, - num_gradient_accumulation_steps: int = 1, - use_amp: bool = False, - enable_default_callbacks: bool = True, - run_confidence_checks: bool = True, - **kwargs, - ) -> None: - super().__init__( - serialization_dir=serialization_dir, - cuda_device=cuda_device, - distributed=distributed, - local_rank=local_rank, - world_size=world_size, - ) - - if "run_sanity_checks" in kwargs: - warnings.warn( - "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.", - DeprecationWarning, - ) - run_confidence_checks = kwargs["run_sanity_checks"] - - # I am not calling move_to_gpu here, because if the model is - # not already on the GPU then the optimizer is going to be wrong. - self.model = model - - self.data_loader = data_loader - self.data_loader.set_target_device(self.cuda_device) - self._validation_data_loader = validation_data_loader - if self._validation_data_loader is not None: - self._validation_data_loader.set_target_device(self.cuda_device) - self.optimizer = optimizer - - if patience is None: # no early stopping - if validation_data_loader is not None: - logger.warning( - "You provided a validation dataset but patience was set to None, " - "meaning that early stopping is disabled" - ) - elif (not isinstance(patience, int)) or patience <= 0: - raise ConfigurationError( - '{} is an invalid value for "patience": it must be a positive integer ' - "or None (if you want to disable early stopping)".format(patience) - ) - - # For tracking is_best_so_far and should_stop_early - self._metric_tracker = MetricTracker(validation_metric, patience) - - self._num_epochs = num_epochs - - self._checkpointer: Optional[Checkpointer] = checkpointer - if checkpointer is None and serialization_dir is not None: - self._checkpointer = Checkpointer(serialization_dir) - - self._grad_norm = grad_norm - self._grad_clipping = grad_clipping - - self._learning_rate_scheduler = learning_rate_scheduler - self._momentum_scheduler = momentum_scheduler - self._moving_average = moving_average - - self._callbacks = callbacks or [] - default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else [] - - if run_confidence_checks: - default_callbacks.append(ConfidenceChecksCallback) - for callback_cls in default_callbacks: - for callback in self._callbacks: - if callback.__class__ == callback_cls: - break - else: - self._callbacks.append(callback_cls(self._serialization_dir)) - - self._batch_num_total = 0 - self._last_log = 0.0 # time of last logging - self._num_gradient_accumulation_steps = num_gradient_accumulation_steps - - # Enable automatic mixed precision training. - self._scaler: Optional[amp.GradScaler] = None - self._use_amp = use_amp - if self._use_amp: - if self.cuda_device == torch.device("cpu"): - raise ValueError("Using AMP requires a cuda device") - self._scaler = amp.GradScaler() - - # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its - # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` - # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. - # - # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the - # normal case, reference to `Model` is retained. This reference is only used in - # these places: `model.__call__`, `model.train` and `model.eval`. - if self._distributed: - self._pytorch_model = DistributedDataParallel( - self.model, - device_ids=None if self.cuda_device == torch.device("cpu") else [self.cuda_device], - find_unused_parameters=True, - ) - else: - self._pytorch_model = self.model - - def rescale_gradients(self) -> float: - """ - Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled. - - Returns the norm of the gradients. - """ - parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None] - if self._grad_norm: - if self._scaler is not None: - # Need to first unscale gradients in order to clip as usual. - self._scaler.unscale_(self.optimizer) - return clip_grad_norm_(parameters_to_clip, self._grad_norm) - else: - return torch.norm( - torch.stack([torch.norm(p.grad.detach()) for p in parameters_to_clip]) - ) - - def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]: - """ - Does a forward pass on the given batch and returns the output dictionary that the model - returns, after adding any specified regularization penalty to the loss (if training). - """ - output_dict = self._pytorch_model(**batch) - - if for_training: - try: - assert "loss" in output_dict - regularization_penalty = self.model.get_regularization_penalty() - - if regularization_penalty is not None: - output_dict["reg_loss"] = regularization_penalty - output_dict["loss"] += regularization_penalty - - except AssertionError: - if for_training: - raise RuntimeError( - "The model you are trying to optimize does not contain a" - " 'loss' key in the output of model.forward(inputs)." - ) - - return output_dict - - def _train_epoch(self, epoch: int) -> Dict[str, float]: - """ - Trains one epoch and returns metrics. - """ - logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) - cpu_memory_usage = [] - for worker, memory in common_util.peak_cpu_memory().items(): - cpu_memory_usage.append((worker, memory)) - logger.info(f"Worker {worker} memory usage: {common_util.format_size(memory)}") - gpu_memory_usage = [] - for gpu, memory in common_util.peak_gpu_memory().items(): - gpu_memory_usage.append((gpu, memory)) - logger.info(f"GPU {gpu} memory usage: {common_util.format_size(memory)}") - - regularization_penalty = self.model.get_regularization_penalty() - - train_loss = 0.0 - batch_loss = 0.0 - train_reg_loss = None if regularization_penalty is None else 0.0 - batch_reg_loss = None if regularization_penalty is None else 0.0 - - # Set the model to "train" mode. - self._pytorch_model.train() - - # Get tqdm for the training batches - batch_generator = iter(self.data_loader) - batch_group_generator = common_util.lazy_groups_of( - batch_generator, self._num_gradient_accumulation_steps - ) - - logger.info("Training") - - num_training_batches: Union[int, float] - try: - len_data_loader = len(self.data_loader) - num_training_batches = math.ceil( - len_data_loader / self._num_gradient_accumulation_steps - ) - except TypeError: - num_training_batches = float("inf") - - # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's - # progress is shown - if self._primary: - batch_group_generator_tqdm = Tqdm.tqdm( - batch_group_generator, total=num_training_batches - ) - else: - batch_group_generator_tqdm = batch_group_generator - - self._last_log = time.time() - - batches_this_epoch = 0 - if self._batch_num_total is None: - self._batch_num_total = 0 - - done_early = False - for batch_group in batch_group_generator_tqdm: - if done_early: - break - - batches_this_epoch += 1 - self._batch_num_total += 1 - batch_num_total = self._batch_num_total - - # Zero gradients. - # NOTE: this is actually more efficient than calling `self.optimizer.zero_grad()` - # because it avoids a read op when the gradients are first updated below. - for param_group in self.optimizer.param_groups: - for p in param_group["params"]: - p.grad = None - - batch_loss = 0.0 - batch_group_outputs = [] - for batch in batch_group: - if self._distributed: - # Check whether the other workers have stopped already (due to differing amounts of - # data in each). If so, we can't proceed because we would hang when we hit the - # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor - # here because NCCL process groups apparently don't support BoolTensor. - done = torch.tensor(0, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - if done.item() > 0: - done_early = True - logger.warning( - f"Worker {torch.distributed.get_rank()} finishing training early! " - "This implies that there is an imbalance in your training " - "data across the workers and that some amount of it will be " - "ignored. A small amount of this is fine, but a major imbalance " - "should be avoided. Note: This warning will appear unless your " - "data is perfectly balanced." - ) - break - - with amp.autocast(self._use_amp): - batch_outputs = self.batch_outputs(batch, for_training=True) - batch_group_outputs.append(batch_outputs) - loss = batch_outputs["loss"] - reg_loss = batch_outputs.get("reg_loss") - if torch.isnan(loss): - raise ValueError("nan loss encountered") - loss = loss / len(batch_group) - - batch_loss += loss.item() - if reg_loss is not None: - reg_loss = reg_loss / len(batch_group) - batch_reg_loss = reg_loss.item() - train_reg_loss += batch_reg_loss # type: ignore - - if self._scaler is not None: - self._scaler.scale(loss).backward() - else: - loss.backward() - if len(batch_group_outputs) <= 0: - continue - - train_loss += batch_loss - - batch_grad_norm = self.rescale_gradients() - - # This does nothing if batch_num_total is None or you are using a - # scheduler which doesn't update per batch. - if self._learning_rate_scheduler: - self._learning_rate_scheduler.step_batch(batch_num_total) - if self._momentum_scheduler: - self._momentum_scheduler.step_batch(batch_num_total) - - if self._scaler is not None: - self._scaler.step(self.optimizer) - self._scaler.update() - else: - self.optimizer.step() - - # Update moving averages - if self._moving_average is not None: - self._moving_average.apply(batch_num_total) - - # Update the description with the latest metrics - metrics = training_util.get_metrics( - self.model, - train_loss, - train_reg_loss, - batch_loss, - batch_reg_loss, - batches_this_epoch, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - if self._primary: - # Updating tqdm only for the primary as the trainers wouldn't have one - description = training_util.description_from_metrics(metrics) - batch_group_generator_tqdm.set_description(description, refresh=False) - - if self._checkpointer is not None: - self._checkpointer.maybe_save_checkpoint(self, epoch, batches_this_epoch) - - for callback in self._callbacks: - callback.on_batch( - self, - batch_group, - batch_group_outputs, - metrics, - epoch, - batches_this_epoch, - is_training=True, - is_primary=self._primary, - batch_grad_norm=batch_grad_norm, - ) - - if self._distributed and not done_early: - logger.warning( - f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)." - ) - # Indicate that we're done so that any workers that have remaining data stop the epoch early. - done = torch.tensor(1, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - assert done.item() - - # Let all workers finish their epoch before computing - # the final statistics for the epoch. - if self._distributed: - dist.barrier() - - metrics = training_util.get_metrics( - self.model, - train_loss, - train_reg_loss, - batch_loss=None, - batch_reg_loss=None, - num_batches=batches_this_epoch, - reset=True, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - for (worker, memory) in cpu_memory_usage: - metrics["worker_" + str(worker) + "_memory_MB"] = memory / (1024 * 1024) - for (gpu_num, memory) in gpu_memory_usage: - metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory / (1024 * 1024) - return metrics - - def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]: - """ - Computes the validation loss. Returns it and the number of batches. - """ - logger.info("Validating") - - self._pytorch_model.eval() - - # Replace parameter values with the shadow values from the moving averages. - if self._moving_average is not None: - self._moving_average.assign_average_value() - - if self._validation_data_loader is not None: - validation_data_loader = self._validation_data_loader - else: - raise ConfigurationError( - "Validation results cannot be calculated without a validation_data_loader" - ) - - regularization_penalty = self.model.get_regularization_penalty() - - # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's - # progress is shown - if self._primary: - val_generator_tqdm = Tqdm.tqdm(validation_data_loader) - else: - val_generator_tqdm = validation_data_loader - - batches_this_epoch = 0 - val_loss = 0.0 - val_batch_loss = 0.0 - val_reg_loss = None if regularization_penalty is None else 0.0 - val_batch_reg_loss = None if regularization_penalty is None else 0.0 - done_early = False - for batch in val_generator_tqdm: - if self._distributed: - # Check whether the other workers have stopped already (due to differing amounts of - # data in each). If so, we can't proceed because we would hang when we hit the - # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor - # here because NCCL process groups apparently don't support BoolTensor. - done = torch.tensor(0, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - if done.item() > 0: - done_early = True - logger.warning( - f"Worker {torch.distributed.get_rank()} finishing validation early! " - "This implies that there is an imbalance in your validation " - "data across the workers and that some amount of it will be " - "ignored. A small amount of this is fine, but a major imbalance " - "should be avoided. Note: This warning will appear unless your " - "data is perfectly balanced." - ) - break - - with amp.autocast(self._use_amp): - batch_outputs = self.batch_outputs(batch, for_training=False) - loss = batch_outputs.get("loss") - reg_loss = batch_outputs.get("reg_loss") - if loss is not None: - # You shouldn't necessarily have to compute a loss for validation, so we allow for - # `loss` to be None. We need to be careful, though - `batches_this_epoch` is - # currently only used as the divisor for the loss function, so we can safely only - # count those batches for which we actually have a loss. If this variable ever - # gets used for something else, we might need to change things around a bit. - batches_this_epoch += 1 - val_batch_loss = loss.item() - val_loss += val_batch_loss - if reg_loss is not None: - val_batch_reg_loss = reg_loss.item() - val_reg_loss += val_batch_reg_loss # type: ignore - - # Update the description with the latest metrics - val_metrics = training_util.get_metrics( - self.model, - val_loss, - val_reg_loss, - val_batch_loss, - val_batch_reg_loss, - batches_this_epoch, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - description = training_util.description_from_metrics(val_metrics) - if self._primary: - val_generator_tqdm.set_description(description, refresh=False) - - for callback in self._callbacks: - callback.on_batch( - self, - [batch], - [batch_outputs], - val_metrics, - epoch, - batches_this_epoch, - is_training=False, - is_primary=self._primary, - ) - - if self._distributed and not done_early: - logger.warning( - f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)." - ) - # Indicate that we're done so that any workers that have remaining data stop validation early. - done = torch.tensor(1, device=self.cuda_device) - torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) - assert done.item() - - # Now restore the original parameter values. - if self._moving_average is not None: - self._moving_average.restore() - - return val_loss, val_reg_loss, batches_this_epoch - - def train(self) -> Dict[str, Any]: - """ - Trains the supplied model with the supplied parameters. - """ - - for callback in self._callbacks: - callback.on_start(self, is_primary=self._primary) - - # Set default values in case of failure - epoch = None - metrics = None - - try: - metrics, epoch = self._try_train() - return metrics - finally: - for callback in self._callbacks: - callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary) - - def _try_train(self) -> Tuple[Dict[str, Any], int]: - try: - epoch_counter = self._restore_checkpoint() - except RuntimeError: - traceback.print_exc() - raise ConfigurationError( - "Could not recover training from the checkpoint. Did you mean to output to " - "a different serialization directory or delete the existing serialization " - "directory?" - ) - - training_util.enable_gradient_clipping(self.model, self._grad_clipping) - - logger.info("Beginning training.") - - val_metrics: Dict[str, float] = {} - metrics: Dict[str, Any] = {} - epochs_trained = 0 - training_start_time = time.time() - - metrics["best_epoch"] = self._metric_tracker.best_epoch - for key, value in self._metric_tracker.best_epoch_metrics.items(): - metrics["best_validation_" + key] = value - - for epoch in range(epoch_counter, self._num_epochs): - epoch_start_time = time.time() - train_metrics = self._train_epoch(epoch) - - # Back up the model now, in case something goes wrong later with the evaluation - if self._primary and self._checkpointer is not None: - self._checkpointer.shelve_model(epoch, self) - # Wait for the primary process to finish saving the model checkpoint - if self._distributed: - dist.barrier() - - # get peak of memory usage - for key, value in train_metrics.items(): - if key.startswith("gpu_") and key.endswith("_memory_MB"): - metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) - elif key.startswith("worker_") and key.endswith("_memory_MB"): - metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) - - this_epoch_val_metric: float = 0.0 - if self._validation_data_loader is not None: - with torch.no_grad(): - # We have a validation set, so compute all the metrics on it. - val_loss, val_reg_loss, num_batches = self._validation_loss(epoch) - - # It is safe again to wait till the validation is done. This is - # important to get the metrics right. - if self._distributed: - dist.barrier() - - val_metrics = training_util.get_metrics( - self.model, - val_loss, - val_reg_loss, - batch_loss=None, - batch_reg_loss=None, - num_batches=num_batches, - reset=True, - world_size=self._world_size, - cuda_device=self.cuda_device, - ) - - # Check validation metric for early stopping - this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics) - self._metric_tracker.add_metrics(val_metrics) - - # Create overall metrics dict - training_elapsed_time = time.time() - training_start_time - metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) - metrics["training_start_epoch"] = epoch_counter - metrics["training_epochs"] = epochs_trained - metrics["epoch"] = epoch - - for key, value in train_metrics.items(): - metrics["training_" + key] = value - for key, value in val_metrics.items(): - metrics["validation_" + key] = value - - if self._metric_tracker.is_best_so_far(): - # Update all the best_ metrics. - # (Otherwise they just stay the same as they were.) - metrics["best_epoch"] = epoch - for key, value in val_metrics.items(): - metrics["best_validation_" + key] = value - - self._metric_tracker.best_epoch_metrics = val_metrics - - if self._serialization_dir and self._primary: - common_util.dump_metrics( - os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), - metrics, - ) - - # The Scheduler API is agnostic to whether your schedule requires a validation metric - - # if it doesn't, the validation metric passed here is ignored. - if self._learning_rate_scheduler: - self._learning_rate_scheduler.step(this_epoch_val_metric) - if self._momentum_scheduler: - self._momentum_scheduler.step(this_epoch_val_metric) - - # The checkpointer saves state from the learning rate scheduler and the momentum - # scheduler, so we have to make sure those are updated before we save the checkpoint here. - if self._primary and self._checkpointer is not None: - self._checkpointer.save_checkpoint( - epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far() - ) - # Wait for the primary process to finish saving the checkpoint - if self._distributed: - dist.barrier() - - for callback in self._callbacks: - callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary) - - epoch_elapsed_time = time.time() - epoch_start_time - logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) - - if epoch < self._num_epochs - 1: - training_elapsed_time = time.time() - training_start_time - estimated_time_remaining = training_elapsed_time * ( - (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1 - ) - formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) - logger.info("Estimated training time remaining: %s", formatted_time) - - epochs_trained += 1 - - if self._metric_tracker.should_stop_early(): - logger.info("Ran out of patience. Stopping training.") - break - else: - epoch = self._num_epochs - 1 - - # Load the best model state before returning - best_model_state = ( - None if self._checkpointer is None else self._checkpointer.best_model_state() - ) - if best_model_state: - self.model.load_state_dict(best_model_state) - - return metrics, epoch - - @contextmanager - def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]: - if self._moving_average is not None: - # Assigning average value to model parameters. The checkpointer will call - # `restore_state_after_checkpointing` when it is done to put this back to what it was. - self._moving_average.assign_average_value() - - model_state = self.model.state_dict() - - # These are the training states we need to persist. - training_states = { - "metric_tracker": self._metric_tracker.state_dict(), - "optimizer": self.optimizer.state_dict(), - "batch_num_total": self._batch_num_total, - } - - # If we have a learning rate or momentum scheduler, we should persist them too. - if self._learning_rate_scheduler is not None: - training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() - if self._momentum_scheduler is not None: - training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() - - try: - yield model_state, training_states - finally: - if self._moving_average is not None: - self._moving_average.restore() - - def _restore_checkpoint(self) -> int: - """ - Restores the model and training state from the last saved checkpoint. - This includes an epoch count and optimizer state, which is serialized separately - from model parameters. This function should only be used to continue training - - if you wish to load a model for inference/load parts of a model into a new - computation graph, you should use the native Pytorch functions: - ` model.load_state_dict(torch.load("/path/to/model/weights.th"))` - - If `self._serialization_dir` does not exist or does not contain any checkpointed weights, - this function will do nothing and return 0. - - # Returns - - epoch: `int` - The epoch at which to resume training, which should be one after the epoch - in the saved training state. - """ - if self._checkpointer is None: - return 0 - - model_state, training_state = self._checkpointer.restore_checkpoint() - - if not training_state: - # No checkpoint to restore, start at 0 - return 0 - - self.model.load_state_dict(model_state) - self.optimizer.load_state_dict(training_state["optimizer"]) - if ( - self._learning_rate_scheduler is not None - and "learning_rate_scheduler" in training_state - ): - self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) - if self._momentum_scheduler is not None and "momentum_scheduler" in training_state: - self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) - training_util.move_optimizer_to_cuda(self.optimizer) - - # Currently the `training_state` contains a serialized `MetricTracker`. - if "metric_tracker" in training_state: - self._metric_tracker.load_state_dict(training_state["metric_tracker"]) - else: - self._metric_tracker.clear() - - if isinstance(training_state["epoch"], int): - epoch_to_return = training_state["epoch"] + 1 - else: - epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1 - - # For older checkpoints with batch_num_total missing, default to old behavior where - # it is unchanged. - batch_num_total = training_state.get("batch_num_total") - if batch_num_total is not None: - self._batch_num_total = batch_num_total - - return epoch_to_return - - @classmethod - def from_partial_objects( - cls, - model: Model, - serialization_dir: str, - data_loader: DataLoader, - validation_data_loader: DataLoader = None, - local_rank: int = 0, - patience: int = None, - validation_metric: Union[str, List[str]] = "-loss", - num_epochs: int = 20, - cuda_device: Optional[Union[int, torch.device]] = None, - grad_norm: float = None, - grad_clipping: float = None, - distributed: bool = False, - world_size: int = 1, - num_gradient_accumulation_steps: int = 1, - use_amp: bool = False, - no_grad: List[str] = None, - optimizer: Lazy[Optimizer] = Lazy(Optimizer.default), - learning_rate_scheduler: Lazy[LearningRateScheduler] = None, - momentum_scheduler: Lazy[MomentumScheduler] = None, - moving_average: Lazy[MovingAverage] = None, - checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer), - callbacks: List[Lazy[TrainerCallback]] = None, - enable_default_callbacks: bool = True, - run_confidence_checks: bool = True, - **kwargs, - ) -> "Trainer": - """ - This method exists so that we can have a documented method to construct this class using - `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this - method. - - The reason we can't just use `__init__` with `FromParams` here is because there are - sequential dependencies to this class's arguments. Anything that has a `Lazy[]` type - annotation needs something from one of the non-`Lazy` arguments. The `Optimizer` needs to - have the parameters from the `Model` before it's constructed, and the `Schedulers` need to - have the `Optimizer`. Because of this, the typical way we construct things `FromParams` - doesn't work, so we use `Lazy` to allow for constructing the objects sequentially. - - If you're not using `FromParams`, you can just construct these arguments in the right order - yourself in your code and call the constructor directly. - """ - if cuda_device is None: - from torch import cuda - - if cuda.device_count() > 0: - cuda_device = 0 - else: - cuda_device = -1 - - check_for_gpu(cuda_device) - if cuda_device >= 0: - # Moving model to GPU here so that the optimizer state gets constructed on - # the right device. - model = model.cuda(cuda_device) - - if no_grad: - for name, parameter in model.named_parameters(): - if any(re.search(regex, name) for regex in no_grad): - parameter.requires_grad_(False) - - parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] - optimizer_ = optimizer.construct(model_parameters=parameters) - - common_util.log_frozen_and_tunable_parameter_names(model) - - batches_per_epoch: Optional[int] - try: - batches_per_epoch = len(data_loader) - batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps) - except TypeError: - batches_per_epoch = None - - moving_average_ = ( - None if moving_average is None else moving_average.construct(parameters=parameters) - ) - learning_rate_scheduler_ = ( - None - if learning_rate_scheduler is None - else learning_rate_scheduler.construct( - optimizer=optimizer_, num_epochs=num_epochs, num_steps_per_epoch=batches_per_epoch - ) - ) - momentum_scheduler_ = ( - None - if momentum_scheduler is None - else momentum_scheduler.construct(optimizer=optimizer_) - ) - checkpointer_ = checkpointer.construct(serialization_dir=serialization_dir) - - callbacks_: List[TrainerCallback] = [] - for callback_ in callbacks or []: - callbacks_.append(callback_.construct(serialization_dir=serialization_dir)) - - return cls( - model, - optimizer_, - data_loader, - patience=patience, - validation_metric=validation_metric, - validation_data_loader=validation_data_loader, - num_epochs=num_epochs, - serialization_dir=serialization_dir, - cuda_device=cuda_device, - grad_norm=grad_norm, - grad_clipping=grad_clipping, - learning_rate_scheduler=learning_rate_scheduler_, - momentum_scheduler=momentum_scheduler_, - checkpointer=checkpointer_, - moving_average=moving_average_, - callbacks=callbacks_, - distributed=distributed, - local_rank=local_rank, - world_size=world_size, - num_gradient_accumulation_steps=num_gradient_accumulation_steps, - use_amp=use_amp, - enable_default_callbacks=enable_default_callbacks, - run_confidence_checks=run_confidence_checks, - **kwargs, - ) - - -DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,) -""" -The default callbacks used by `GradientDescentTrainer`. -""" + def get_best_weights_path(self) -> Optional[str]: + """Returns the path to file containing the current best weights.""" + return None diff --git a/tests/commands/no_op_train_test.py b/tests/commands/no_op_train_test.py index 948edb39784..c862fccb581 100644 --- a/tests/commands/no_op_train_test.py +++ b/tests/commands/no_op_train_test.py @@ -31,7 +31,7 @@ def test_train_model(self): serialization_dir = self.TEST_DIR / "serialization_directory" train_model(params(), serialization_dir=serialization_dir) - archive = load_archive(str(serialization_dir / "model.tar.gz")) + archive = load_archive(serialization_dir / "model.tar.gz") model = archive.model assert model.forward(torch.tensor([1, 2, 3]))["class"] == torch.tensor(98) assert model.vocab.get_vocab_size() == 9 diff --git a/tests/training/checkpointer_test.py b/tests/training/checkpointer_test.py index 206fa43278b..ac102a3a983 100644 --- a/tests/training/checkpointer_test.py +++ b/tests/training/checkpointer_test.py @@ -1,21 +1,19 @@ import os -import re import time -from contextlib import contextmanager from allennlp.common.testing import AllenNlpTestCase from allennlp.common.params import Params from allennlp.training import Checkpointer, Trainer +from allennlp.training.trainer import TrainerCheckpoint class FakeTrainer(Trainer): - def __init__(self, model_state, training_states): + def __init__(self, model_state, training_state): self._model_state = model_state - self._training_states = training_states + self._training_state = training_state - @contextmanager - def get_checkpoint_state(self): - yield self._model_state, self._training_states + def get_checkpoint_state(self) -> TrainerCheckpoint: + return TrainerCheckpoint(self._model_state, self._training_state) class TestCheckpointer(AllenNlpTestCase): @@ -26,77 +24,72 @@ def retrieve_and_delete_saved(self): and returns the saved epochs as two lists of integers. """ serialization_files = os.listdir(self.TEST_DIR) - model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x] - found_model_epochs = [ - int(re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1)) - for x in model_checkpoints - ] + + model_checkpoints = [x for x in serialization_files if "model_state_" in x] + found_model_states = [Checkpointer._parse_model_state_path(x) for x in model_checkpoints] for f in model_checkpoints: os.remove(os.path.join(self.TEST_DIR, f)) - training_checkpoints = [x for x in serialization_files if "training_state_epoch" in x] - found_training_epochs = [ - int(re.search(r"training_state_epoch_([0-9\.\-]+)\.th", x).group(1)) - for x in training_checkpoints + + training_checkpoints = [x for x in serialization_files if "training_state_" in x] + found_training_states = [ + Checkpointer._parse_training_state_path(x) for x in training_checkpoints ] for f in training_checkpoints: os.remove(os.path.join(self.TEST_DIR, f)) - return sorted(found_model_epochs), sorted(found_training_epochs) + return sorted(found_model_states), sorted(found_training_states) def test_default(self): """ Tests that the default behavior keeps just the last 2 checkpoints. """ default_num_to_keep = 2 - num_epochs = 30 - target = list(range(num_epochs - default_num_to_keep, num_epochs)) + num_epochs = 5 + target = [(e, 0) for e in range(num_epochs - default_num_to_keep, num_epochs)] checkpointer = Checkpointer(serialization_dir=self.TEST_DIR) - - for e in range(num_epochs): - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=False, - ) + for epochs_completed in range(num_epochs): + for batches_completed in [0, 5, 10]: + state = { + "epochs_completed": epochs_completed, + "batches_in_epoch_completed": batches_completed, + } + checkpointer.maybe_save_checkpoint( + FakeTrainer(model_state=state, training_state=state), + epochs_completed, + batches_completed, + ) models, training = self.retrieve_and_delete_saved() assert models == training == target def test_keep_zero(self): - checkpointer = Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=0 - ) - for e in range(10): - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=True, + checkpointer = Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=0) + for epochs_completed in range(5): + state = {"epochs_completed": epochs_completed, "batches_in_epoch_completed": 0} + checkpointer.maybe_save_checkpoint( + FakeTrainer(model_state=state, training_state=state), epochs_completed, 0 ) files = os.listdir(self.TEST_DIR) - assert "model_state_epoch_1.th" not in files - assert "training_state_epoch_1.th" not in files + assert not any("model_state_" in x for x in files) + assert not any("training_state_" in x for x in files) def test_with_time(self): - """ - Tests that keep_serialized_model_every_num_seconds parameter causes a checkpoint to be saved - after enough time has elapsed between epochs. - """ - num_to_keep = 10 num_epochs = 30 - target = list(range(num_epochs - num_to_keep, num_epochs)) pauses = [5, 18, 26] - target = sorted(set(target + pauses)) + target = [(e, 0) for e in pauses] checkpointer = Checkpointer( serialization_dir=self.TEST_DIR, - num_serialized_models_to_keep=num_to_keep, - keep_serialized_model_every_num_seconds=1, + save_completed_epochs=False, + save_every_num_seconds=1, + keep_most_recent_by_count=3, ) for e in range(num_epochs): if e in pauses: time.sleep(2) - checkpointer.save_checkpoint( - epoch=e, - trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}), - is_best_so_far=False, + state = {"epochs_completed": e, "batches_in_epoch_completed": 0} + checkpointer.maybe_save_checkpoint( + trainer=FakeTrainer(model_state=state, training_state=state), + num_epochs_completed=e, + num_batches_in_epoch_completed=0, ) models, training = self.retrieve_and_delete_saved() assert models == training == target diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index 3926adf0ec2..2373caafefd 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -2,7 +2,6 @@ import glob import json import os -import re import time from typing import Any, Dict, List, Optional @@ -217,11 +216,11 @@ def test_data_loader_lazy_epoch_size_correct(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * 2 + assert trainer._total_batches_completed == num_epochs * 2 def test_data_loader_lazy_epoch_size_correct_custom_epoch_size(self): self.data_loader_lazy.batches_per_epoch = 3 @@ -234,11 +233,11 @@ def test_data_loader_lazy_epoch_size_correct_custom_epoch_size(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * 3 + assert trainer._total_batches_completed == num_epochs * 3 def test_trainer_respects_epoch_size_equals_total(self): batches_per_epoch = 4 @@ -256,11 +255,11 @@ def test_trainer_respects_epoch_size_equals_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_respects_epoch_size_larger_tnan_total(self): batches_per_epoch = 7 @@ -278,11 +277,11 @@ def test_trainer_respects_epoch_size_larger_tnan_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_respects_epoch_size_smaller_tnan_total(self): batches_per_epoch = 1 @@ -300,11 +299,11 @@ def test_trainer_respects_epoch_size_smaller_tnan_total(self): num_epochs=num_epochs, serialization_dir=self.TEST_DIR, ) - assert trainer._batch_num_total == 0 + assert trainer._total_batches_completed == 0 metrics = trainer.train() epoch = metrics["epoch"] assert epoch == num_epochs - 1 - assert trainer._batch_num_total == num_epochs * batches_per_epoch + assert trainer._total_batches_completed == num_epochs * batches_per_epoch def test_trainer_can_resume_training(self): trainer = GradientDescentTrainer( @@ -316,6 +315,7 @@ def test_trainer_can_resume_training(self): serialization_dir=self.TEST_DIR, ) trainer.train() + new_trainer = GradientDescentTrainer( self.model, self.optimizer, @@ -324,9 +324,9 @@ def test_trainer_can_resume_training(self): num_epochs=3, serialization_dir=self.TEST_DIR, ) + new_trainer._restore_checkpoint() - epoch = new_trainer._restore_checkpoint() - assert epoch == 1 + assert new_trainer._start_after_epochs_completed == 1 tracker = trainer._metric_tracker assert tracker.is_best_so_far() @@ -359,8 +359,8 @@ def test_trainer_can_resume_training_for_exponential_moving_average(self): moving_average=new_moving_average, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 1 + new_trainer._restore_checkpoint() + assert new_trainer._start_after_epochs_completed == 1 tracker = trainer._metric_tracker assert tracker.is_best_so_far() @@ -605,8 +605,8 @@ def test_trainer_can_run_and_resume_with_momentum_scheduler(self): num_epochs=6, serialization_dir=self.TEST_DIR, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 4 + new_trainer._restore_checkpoint() + new_trainer._start_after_epochs_completed = 4 assert new_trainer._momentum_scheduler.last_epoch == 3 new_trainer.train() @@ -672,8 +672,8 @@ def test_trainer_can_resume_with_lr_scheduler(self): num_epochs=4, serialization_dir=self.TEST_DIR, ) - epoch = new_trainer._restore_checkpoint() - assert epoch == 2 + new_trainer._restore_checkpoint() + assert new_trainer._start_after_epochs_completed == 2 assert new_trainer._learning_rate_scheduler.last_epoch == 1 new_trainer.train() @@ -719,17 +719,20 @@ def test_trainer_respects_num_serialized_models_to_keep(self): self.data_loader, num_epochs=5, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=3 - ), + checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=3), ) trainer.train() # Now check the serialized files - for prefix in ["model_state_epoch_*", "training_state_epoch_*"]: - file_names = glob.glob(os.path.join(self.TEST_DIR, prefix)) - epochs = [int(re.search(r"_([0-9])\.th", fname).group(1)) for fname in file_names] - assert sorted(epochs) == [2, 3, 4] + expected = [(3, 0), (4, 0), (5, 0)] + + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + epochs = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected + + file_names = glob.glob(os.path.join(self.TEST_DIR, "training_state_e*_b*")) + epochs = [Checkpointer._parse_training_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected def test_trainer_saves_metrics_every_epoch(self): trainer = GradientDescentTrainer( @@ -739,9 +742,7 @@ def test_trainer_saves_metrics_every_epoch(self): validation_data_loader=self.validation_data_loader, num_epochs=5, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer( - serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=3 - ), + checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, keep_most_recent_by_count=3), ) trainer.train() @@ -757,9 +758,6 @@ def test_trainer_respects_keep_serialized_model_every_num_seconds(self): # To test: # Create an fake data loader that sleeps for 2.5 second per epoch, so the total # training time for one epoch is slightly greater then 2.5 seconds. - # Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds. - # Check the resulting checkpoints. Should then have models at epochs - # 2, 4, plus the last two at 5 and 6. class SlowDataLoader: data_loader = SimpleDataLoader(self.instances, batch_size=2) @@ -781,19 +779,24 @@ def set_target_device(self, _): num_epochs=6, serialization_dir=self.TEST_DIR, checkpointer=Checkpointer( + save_completed_epochs=False, serialization_dir=self.TEST_DIR, - num_serialized_models_to_keep=2, - keep_serialized_model_every_num_seconds=5, + keep_most_recent_by_count=4, + save_every_num_seconds=5, ), ) trainer.train() # Now check the serialized files - for prefix in ["model_state_epoch_*", "training_state_epoch_*"]: - file_names = glob.glob(os.path.join(self.TEST_DIR, prefix)) - epochs = [int(re.search(r"_([0-9])\.th", fname).group(1)) for fname in file_names] - # epoch N has N-1 in file name - assert sorted(epochs) == [1, 3, 4, 5] + expected = [(1, 1), (3, 1), (5, 1)] + + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + epochs = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected + + file_names = glob.glob(os.path.join(self.TEST_DIR, "training_state_e*_b*")) + epochs = [Checkpointer._parse_training_state_path(fname) for fname in file_names] + assert sorted(epochs) == expected def test_trainer_can_log_learning_rates_tensorboard(self): data_loader = SimpleDataLoader(self.instances, 4) @@ -853,54 +856,64 @@ def test_confidence_check_default(self): # Check is not run, so no failure. trainer.train() - def test_trainer_saves_models_at_specified_interval(self): - data_loader = SimpleDataLoader(self.instances, 4) + @pytest.mark.parametrize("checkpoint_to_keep", range(20)) + def test_trainer_restores_and_makes_same_results(self, checkpoint_to_keep: int): + batch_size = 2 + data_loader = SimpleDataLoader(self.instances, batch_size) + num_epochs = 10 + num_batches = len(self.instances) // batch_size trainer = GradientDescentTrainer( self.model, self.optimizer, data_loader, - num_epochs=2, + validation_data_loader=data_loader, + num_epochs=num_epochs, serialization_dir=self.TEST_DIR, checkpointer=Checkpointer( serialization_dir=self.TEST_DIR, - model_save_interval=0.0001, - num_serialized_models_to_keep=10, + save_every_num_seconds=0.0001, + keep_most_recent_by_count=20, ), ) - trainer.train() + original_metrics = trainer.train() # Now check the serialized files for models saved during the epoch. - prefix = "model_state_epoch_*" - file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix))) - epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1) for fname in file_names] - # We should have checkpoints at the end of each epoch and during each, e.g. - # [0.timestamp, 0, 1.timestamp, 1] - assert len(epochs) == 4 - assert epochs[3] == "1" - assert "." in epochs[0] - - # Now make certain we can restore from timestamped checkpoint. - # To do so, remove the checkpoint from the end of epoch 1&2, so - # that we are forced to restore from the timestamped checkpoints. - for k in range(2): - os.remove(os.path.join(self.TEST_DIR, "model_state_epoch_{}.th".format(k))) - os.remove(os.path.join(self.TEST_DIR, "training_state_epoch_{}.th".format(k))) + file_names = glob.glob(os.path.join(self.TEST_DIR, "model_state_e*_b*")) + checkpoints = [Checkpointer._parse_model_state_path(fname) for fname in file_names] + checkpoints.sort() + + expected = [(e, b) for e in range(num_epochs) for b in range(num_batches + 1)] + del expected[0] + expected.append((num_epochs, 0)) + expected = expected[-20:] + assert checkpoints == expected + + # Now make certain we can restore from checkpoint in the middle of an epoch. + # To do so, remove the checkpoint at the end of epochs. + for i, checkpoint in enumerate(checkpoints): + if i != checkpoint_to_keep: + os.remove(trainer._checkpointer._model_state_path(*checkpoint)) + os.remove(trainer._checkpointer._training_state_path(*checkpoint)) os.remove(os.path.join(self.TEST_DIR, "best.th")) - restore_trainer = GradientDescentTrainer( + restored_trainer = GradientDescentTrainer( self.model, self.optimizer, self.data_loader, - num_epochs=2, + validation_data_loader=data_loader, + num_epochs=num_epochs, serialization_dir=self.TEST_DIR, - checkpointer=Checkpointer(serialization_dir=self.TEST_DIR, model_save_interval=0.0001), + checkpointer=Checkpointer( + serialization_dir=self.TEST_DIR, + save_every_num_seconds=0.0001, + keep_most_recent_by_count=10, + ), ) - epoch = restore_trainer._restore_checkpoint() - assert epoch == 2 - # One batch per epoch. - assert restore_trainer._batch_num_total == 2 + restored_metrics = restored_trainer.train() + + assert original_metrics["best_validation_loss"] == restored_metrics["best_validation_loss"] def test_trainer_saves_and_loads_best_validation_metrics_correctly_1(self): # Use -loss and run 1 epoch of original-training, and one of restored-training @@ -1033,9 +1046,11 @@ def test_trainer_can_run_gradient_accumulation(self): ) assert trainer._num_gradient_accumulation_steps == steps_to_accumulate - metrics = trainer.train() + trainer.train() - num_batches_trained_per_epoch = trainer._batch_num_total // (metrics["training_epochs"] + 1) + num_batches_trained_per_epoch = ( + trainer._total_batches_completed // trainer._epochs_completed + ) num_batches_expected = math.ceil( math.ceil(len(instances) / self.data_loader.batch_size) / steps_to_accumulate ) From 804fd59568e87d37e7bcbaab61a4c11eae6853ea Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld <dirkg@allenai.org> Date: Fri, 28 May 2021 19:51:16 -0700 Subject: [PATCH 43/63] Emergency fix. I forgot to take this out. --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0e3494a954..a53b1b6e12b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -156,7 +156,6 @@ jobs: run: | git clone https://github.com/allenai/allennlp-models.git cd allennlp-models - git checkout Checkpointing pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt - name: Run models tests From deeec84b649406f4dc31a59efddd30ce13c4176b Mon Sep 17 00:00:00 2001 From: Daniel Deutsch <danieldeutsch@users.noreply.github.com> Date: Tue, 1 Jun 2021 13:42:05 -0400 Subject: [PATCH 44/63] Add constraints to beam search (#5216) * Implementing blocking repeated ngrams * Adding comment * Adding unit tests for the end to end beam search * Renaming class * Adding comment about function * Simplifying indexing to variable * Refactoring the state copying into the class * Reformatting * Editing changelog * fix line too long * comments * doc updates Co-authored-by: Pete <petew@allenai.org> Co-authored-by: epwalsh <epwalsh10@gmail.com> --- CHANGELOG.md | 2 + allennlp/nn/beam_search.py | 242 +++++++++++++++++++++++++++++++--- scripts/py2md.py | 1 + tests/nn/beam_search_test.py | 245 ++++++++++++++++++++++++++++++++++- 4 files changed, 469 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 281391a0e14..07e39ab28f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. - Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. +- Added a `Constraint` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`, + along with a `RepeatedNGramBlockingConstraint` constraint implementation, which allows for preventing repeated n-grams in the output from `BeamSearch`. - Added `DataCollator` for dynamic operations for each batch. ### Fixed diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py index 4337e3efc4a..0b15d59bd7a 100644 --- a/allennlp/nn/beam_search.py +++ b/allennlp/nn/beam_search.py @@ -1,5 +1,6 @@ from inspect import signature -from typing import List, Callable, Tuple, Dict, cast, TypeVar, Optional +from typing import Any, List, Callable, Tuple, Dict, cast, TypeVar, Optional +import copy import warnings from overrides import overrides @@ -26,6 +27,8 @@ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep). """ +ConstraintStateType = List[List[Dict[str, Any]]] + class Sampler(Registrable): """ @@ -524,6 +527,162 @@ def score( return average_log_probs +class Constraint(Registrable): + """ + An abstract class that can be used to enforce constraints on the output predictions + by manipulating the class log probabilities during beam search. + + A `Constraint` just has three methods that need to be implemented by subclasses: + `init_state()`, `apply()` and `_update_state()`. + + `init_state()` takes one argument: + + - the batch size, an int + + It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent + calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`. + Each inner list should be of length 1. + + `apply()` takes two arguments: + + - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size` + and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1. + - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the + log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`. + + The `apply()` method should return new `class_log_probabilities` that enforce the constraint + for this step of beam search. For instance, it may prevent a specific class from being selected by setting + the corresponding log probability to a negligible value such as `float("-inf")` or + `min_value_of_dtype(class_log_probabilities.dtype)`. + + `_update_state()` takes two arguments: + + - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the + copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be + directly edited in-place without affecting the others. + - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last + step of beam search. + + The `_update_state()` function should return a new constraint state, a nested list of dictionaries of + length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`. + + """ + + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + raise NotImplementedError + + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def _copy_state( + state: ConstraintStateType, + batch_size: int, + beam_size: int, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + """ + Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this + is not appropriate for your constraint, you will need to implement the copying yourself. + """ + new_state = [] + for i in range(batch_size): + batch_state = [] + for j in range(beam_size): + if last_backpointer is None: + # This is the first prediction, so the backpointer is 0 + backpointer = 0 + else: + backpointer = last_backpointer[i, j].item() + batch_state.append(copy.deepcopy(state[i][backpointer])) + new_state.append(batch_state) + return new_state + + def update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + last_backpointer: Optional[torch.Tensor] = None, + ) -> ConstraintStateType: + batch_size, beam_size = last_prediction.size() + new_state = self._copy_state(state, batch_size, beam_size, last_backpointer) + return self._update_state(new_state, last_prediction) + + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + raise NotImplementedError + + +@Constraint.register("repeated-ngram-blocking") +class RepeatedNGramBlockingConstraint(Constraint): + def __init__(self, ngram_size: int) -> None: + super().__init__() + self.ngram_size = ngram_size + + @overrides + def init_state( + self, + batch_size: int, + ) -> ConstraintStateType: + return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)] + + @overrides + def apply( + self, + state: ConstraintStateType, + class_log_probabilities: torch.Tensor, + ) -> torch.Tensor: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + current_prefix = tuple(beam["current_prefix"]) + seen_ngrams = beam["seen_ngrams"] + try: + disallowed_indices = seen_ngrams[current_prefix] + class_log_probabilities[i, j, disallowed_indices] = min_value_of_dtype( + class_log_probabilities.dtype + ) + except KeyError: + # We have not seen this prefix before, so there is no index + # that needs to be blocked + pass + return class_log_probabilities + + @overrides + def _update_state( + self, + state: ConstraintStateType, + last_prediction: torch.Tensor, + ) -> ConstraintStateType: + for i, batch in enumerate(state): + for j, beam in enumerate(batch): + prediction = last_prediction[i, j].item() + prefix = beam["current_prefix"] + seen_ngrams = beam["seen_ngrams"] + + if len(prefix) == self.ngram_size - 1: + # This is a new ngram that we have to remember + if tuple(prefix) not in seen_ngrams: + seen_ngrams[tuple(prefix)] = [] + seen_ngrams[tuple(prefix)].append(prediction) + + # Create the new prefix, removing the oldest index if the prefix + # is too long + prefix.append(prediction) + if len(prefix) == self.ngram_size: + prefix.pop(0) + return state + + class BeamSearch(FromParams): """ Implements the beam search algorithm for decoding the most likely sequences. @@ -566,6 +725,10 @@ class BeamSearch(FromParams): The output from this module is what is returned by the `search` method. If not specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences by the sum of the token log probabilities. + + constraints: `List[Constraint]`, optional (default = `None`) + An optional list of `Constraint`s which should be applied during beam search. If not + provided, no constraints will be enforced. """ def __init__( @@ -577,6 +740,7 @@ def __init__( sampler: Sampler = None, min_steps: Optional[int] = None, final_sequence_scorer: FinalSequenceScorer = None, + constraints: Optional[List[Constraint]] = None, ) -> None: if not max_steps > 0: raise ValueError("max_steps must be positive") @@ -597,6 +761,7 @@ def __init__( self.sampler = sampler or DeterministicSampler() self.min_steps = min_steps or 0 self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() + self.constraints = constraints or [] @staticmethod def _reconstruct_sequences(predictions, backpointers): @@ -637,15 +802,14 @@ def search( Given a starting state and a step function, apply beam search to find the most likely target sequences. - # Notes - - If your step function returns `-inf` for some log probabilities - (like if you're using a masked log-softmax) then some of the "best" - sequences returned may also have `-inf` log probability. Specifically - this happens when the beam size is smaller than the number of actions - with finite log probability (non-zero probability) returned by the step function. - Therefore if you're using a mask you may want to check the results from `search` - and potentially discard sequences with non-finite log probability. + !!! Note + If your step function returns `-inf` for some log probabilities + (like if you're using a masked log-softmax) then some of the "best" + sequences returned may also have `-inf` log probability. Specifically + this happens when the beam size is smaller than the number of actions + with finite log probability (non-zero probability) returned by the step function. + Therefore if you're using a mask you may want to check the results from `search` + and potentially discard sequences with non-finite log probability. # Parameters @@ -719,6 +883,8 @@ def _search( # predictions[t-1][i][n], that it came from. backpointers: List[torch.Tensor] = [] + constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints] + # Calculate the first timestep. This is done outside the main loop # because we are going from a single decoder input (the output from the # encoder) to the top `beam_size` decoder outputs. On the other hand, @@ -742,9 +908,21 @@ def _search( start_class_log_probabilities, batch_size, num_classes ) + # Apply all constraints. + if self.constraints: + # shape: (batch_size, 1, num_classes) + expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1) + for constraint, constraint_state in zip(self.constraints, constraint_states): + expanded_start_class_log_probabilities = constraint.apply( + constraint_state, expanded_start_class_log_probabilities + ) + start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1) + # Prevent selecting the end symbol if there is any min_steps constraint if self.min_steps >= 1: - start_class_log_probabilities[:, self._end_index] = float("-inf") + start_class_log_probabilities[:, self._end_index] = min_value_of_dtype( + start_class_log_probabilities.dtype + ) # Get the initial predicted classed and their log probabilities. # shape: (batch_size, beam_size), (batch_size, beam_size) @@ -772,13 +950,19 @@ def _search( # Log probability tensor that mandates that the end token is selected. # shape: (batch_size * beam_size, num_classes) log_probs_after_end = start_class_log_probabilities.new_full( - (batch_size * self.beam_size, num_classes), float("-inf") + (batch_size * self.beam_size, num_classes), + min_value_of_dtype(start_class_log_probabilities.dtype), ) log_probs_after_end[:, self._end_index] = 0.0 # Set the same state for each element in the beam. self._update_initial_state(state, batch_size) + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state( + constraint_states[i], start_predicted_classes + ) + for timestep in range(self.max_steps - 1): # shape: (batch_size * beam_size,) last_predictions = predictions[-1].reshape(batch_size * self.beam_size) @@ -792,12 +976,29 @@ def _search( # shape: (batch_size * beam_size, num_classes) class_log_probabilities, state = step(last_predictions, state, timestep + 1) + # Apply all constraints. + if self.constraints: + # shape: (batch_size, beam_size, num_classes) + reshaped_class_log_probabilities = class_log_probabilities.view( + batch_size, self.beam_size, -1 + ) + for constraint, constraint_state in zip(self.constraints, constraint_states): + reshaped_class_log_probabilities = constraint.apply( + constraint_state, reshaped_class_log_probabilities + ) + # shape: (batch_size * beam_size, num_classes) + class_log_probabilities = reshaped_class_log_probabilities.view( + batch_size * self.beam_size, -1 + ) + # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token # of the sequence (because `timestep` is 0-indexed and we generated the first token # before the for loop). Here we block the end index if the search is not allowed to # terminate on this iteration. if timestep + 2 <= self.min_steps: - class_log_probabilities[:, self._end_index] = float("-inf") + class_log_probabilities[:, self._end_index] = min_value_of_dtype( + class_log_probabilities.dtype + ) # shape: (batch_size * beam_size, num_classes) last_predictions_expanded = last_predictions.unsqueeze(-1).expand( @@ -874,9 +1075,20 @@ def _search( # ancestors created this iteration. self._update_state(state, backpointer) - if not torch.isfinite(last_log_probabilities).all(): + for i, constraint in enumerate(self.constraints): + constraint_states[i] = constraint.update_state( + constraint_states[i], restricted_predicted_classes + ) + + # Warn about "-inf" log probabilities if not using any constraints (negligible + # log probabilities are expected when using constraints). + if not self.constraints and ( + not torch.isfinite(last_log_probabilities).all() + or (last_log_probabilities == min_value_of_dtype(last_log_probabilities.dtype)).any() + ): warnings.warn( - "Infinite log probabilities encountered. Some final sequences may not make sense. " + "Negligible log probabilities encountered ('-inf' or equivalent). " + "Some final sequences may not make sense. " "This can happen when the beam size is larger than the number of valid (non-zero " "probability) transitions that the step function produces.", RuntimeWarning, diff --git a/scripts/py2md.py b/scripts/py2md.py index c8bc1ca1d43..ed916aeddde 100755 --- a/scripts/py2md.py +++ b/scripts/py2md.py @@ -286,6 +286,7 @@ class AllenNlpFilterProcessor(Struct): "TransformerModule._pretrained_allow_missing", "TransformerModule._distributed_loading_strategy", "TransformerModule._tied_weights", + "Constraint._update_state", } def process(self, graph, _resolver): diff --git a/tests/nn/beam_search_test.py b/tests/nn/beam_search_test.py index 275390cc135..6843e86df13 100644 --- a/tests/nn/beam_search_test.py +++ b/tests/nn/beam_search_test.py @@ -14,8 +14,10 @@ GumbelSampler, SequenceLogProbabilityScorer, LengthNormalizedSequenceLogProbabilityScorer, + RepeatedNGramBlockingConstraint, ) from allennlp.common.params import Params +from allennlp.nn.util import min_value_of_dtype transition_probabilities = torch.tensor( @@ -41,6 +43,18 @@ ] # end token -> jth token ) +# A transition matrix that favors repeated ngrams +repeated_ngram_transition_probabilities = torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # start token -> jth token + [0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1st token -> jth token + [0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2nd token -> jth token + [0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # ... + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # not used + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ] # end token -> jth token +) + log_probabilities = torch.log( torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]]) ) @@ -95,6 +109,25 @@ def take_short_sequence_step( return torch.stack(log_probs_list), state +def take_repeated_ngrams_step( + last_predictions: torch.Tensor, + state: Dict[str, torch.Tensor], + timestep: int, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Take decoding step. + + This method is the same as `take_step_no_timestep` except it uses the + `short_sequence_transition_probabilities` transitions instead of `transition_probabilities` + """ + log_probs_list = [] + for last_token in last_predictions: + log_probs = torch.log(repeated_ngram_transition_probabilities[last_token.item()]) + log_probs_list.append(log_probs) + + return torch.stack(log_probs_list), state + + class BeamSearchTest(AllenNlpTestCase): def setup_method(self): super().setup_method() @@ -272,11 +305,12 @@ def test_min_steps(self): self.beam_search.min_steps = 0 expected_top_k = np.array([[5]]) expected_log_probs = np.log(np.array([0.9])) - self._check_results( - expected_top_k=expected_top_k, - expected_log_probs=expected_log_probs, - take_step=take_short_sequence_step, - ) + with pytest.warns(RuntimeWarning, match="Empty sequences predicted"): + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_short_sequence_step, + ) self.beam_search.min_steps = 1 expected_top_k = np.array([[1, 5]]) @@ -331,7 +365,7 @@ def test_warn_for_bad_log_probs(self): # next beams will result in 2 new beams that are invalid, in that have probability of 0. # The beam search should warn us of this. initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1]) - with pytest.warns(RuntimeWarning, match="Infinite log probabilities"): + with pytest.warns(RuntimeWarning, match="Negligible log probabilities"): self.beam_search.search(initial_predictions, {}, take_step_no_timestep) def test_empty_sequences(self): @@ -596,3 +630,202 @@ def test_length_normalized_sequence_log_prob_scorer(self): ) expected_scores = expected_log_probs / length_normalization self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores) + + def test_repeated_ngram_blocking_constraint_init_state(self): + ngram_size = 3 + batch_size = 2 + constraint = RepeatedNGramBlockingConstraint(ngram_size) + + state = constraint.init_state(batch_size) + assert len(state) == batch_size + for beam_states in state: + assert len(beam_states) == 1 + beam_state = beam_states[0] + assert len(beam_state.keys()) == 2 + assert len(beam_state["current_prefix"]) == 0 + assert len(beam_state["seen_ngrams"]) == 0 + + def test_repeated_ngram_blocking_constraint_apply(self): + ngram_size = 3 + batch_size = 2 + beam_size = 2 + num_classes = 10 + constraint = RepeatedNGramBlockingConstraint(ngram_size) + + state = [ + [ + {"current_prefix": [0, 1], "seen_ngrams": {}}, + {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}}, + ], + [ + {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}}, + {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}}, + ], + ] + log_probabilities = torch.rand(batch_size, beam_size, num_classes) + constraint.apply(state, log_probabilities) + + disallowed_locations = torch.nonzero( + log_probabilities == min_value_of_dtype(log_probabilities.dtype) + ).tolist() + assert len(disallowed_locations) == 4 + assert [0, 1, 4] in disallowed_locations + assert [1, 1, 0] in disallowed_locations + assert [1, 1, 1] in disallowed_locations + assert [1, 1, 2] in disallowed_locations + + def test_repeated_ngram_blocking_constraint_update_state(self): + ngram_size = 3 + constraint = RepeatedNGramBlockingConstraint(ngram_size) + + # We will have [2, 3] -> {5, 6} from batch index 0 and [4, 5] -> {0} and [6, 7] -> {3} + # from batch index + state = [ + [ + {"current_prefix": [0, 1], "seen_ngrams": {}}, + {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}}, + ], + [ + {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}}, + {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}}, + ], + ] + predictions = torch.LongTensor([[5, 6], [0, 3]]) + backpointers = torch.LongTensor([[1, 1], [0, 1]]) + + expected_state = [ + [ + {"current_prefix": [3, 5], "seen_ngrams": {(2, 3): [4, 5]}}, + {"current_prefix": [3, 6], "seen_ngrams": {(2, 3): [4, 6]}}, + ], + [ + {"current_prefix": [5, 0], "seen_ngrams": {(8, 9): [], (4, 5): [0]}}, + {"current_prefix": [7, 3], "seen_ngrams": {(6, 7): [0, 1, 2, 3]}}, + ], + ] + updated_state = constraint.update_state(state, predictions, backpointers) + assert updated_state == expected_state + + def test_take_repeated_ngram_step(self): + """ + Tests to ensure the top-k from the short_sequence_transition_probabilities + transition matrix is expected. The transitions are: + + - p(1|start) = 1.0 + - p(2|1) = 0.4 + - p(3|1) = 0.6 + - p(end|1) = 1e-9 + - p(3|2) = 1.0 + - p(end|2) = 1e-9 + - p(1|3) = 1.0 + - p(end|3) = 1e-9 + + The probabilities don't add up 1 because of the 1e-9 transitions to end. That doesn't + really matter. Each state just needed some transition to the end probability with a very + small probability to ensure it's possible to reach the end state from there and that it + isn't selected by beam search without a constraint. + + Below is the beam search tracing for beam size 2. Any sequence below the + line is not selected by beam search. The number that comes before the sequence + is the probability of the sequence. + + Step 1 + 1.0: [1] + + Step 2 + 0.6: [1, 3] + 0.4: [1, 2] + ----- + 1e-9: [1, 2, end] + + Step 3 + 0.6: [1, 3, 1] + 0.4: [1, 2, 3] + ----- + 0.6 * 1e-9: [1, 3, end] + 0.4 * 1e-9: [1, 2, end] + + Step 4 + 0.4: [1, 2, 3, 1] + 0.36: [1, 3, 1, 3] + ----- + 0.24: [1, 3, 1, 2] + 0.6 * 1e-9: [1, 3, 1, end] + 0.4 * 1e-9: [1, 2, 3, end] + + Step 5 + 0.36: [1, 3, 1, 3, 1] + 0.24: [1, 2, 3, 1, 3] + ----- + 0.16: [1, 2, 3, 1, 2] + 0.4 * 1e-9: [1, 2, 3, 1, end] + 0.36 * 1e-9: [1, 3, 1, 3, end] + """ + self.beam_search.beam_size = 2 + self.beam_search.max_steps = 5 + expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]]) + expected_log_probs = np.log(np.array([0.36, 0.24])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_repeated_ngrams_step, + ) + + def test_repeated_ngram_blocking_end_to_end(self): + """ + This test checks to make sure the `RepeatedNGramBlockingConstraint` successfully blocks ngrams. + It works by blocking ngrams of different sizes and ensures that the result of beam search + is correctly changed. We rely on the beam search trace for `repeated_ngram_transition_probabilities` + in `test_take_repeated_ngram_step`. + """ + self.beam_search.beam_size = 2 + + # Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place + self.beam_search.max_steps = 3 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)] + expected_top_k = np.array([[1, 2, 3], [1, 3, 5]]) + expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_repeated_ngrams_step, + ) + + # Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place + self.beam_search.max_steps = 4 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)] + expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]]) + expected_log_probs = np.log(np.array([0.4, 0.24])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_repeated_ngrams_step, + ) + + # Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place + self.beam_search.max_steps = 5 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)] + expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]]) + expected_log_probs = np.log(np.array([0.24, 0.16])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_repeated_ngrams_step, + ) + + def test_repeated_ngram_blocking_end_indices(self): + """ + Ensures that the ngram blocking does not mess up when one sequence is shorter + than another, which would result in repeated "end" symbols. + """ + # We block unigrams, but 5 (the end symbol) is repeated and it does not mess + # up the sequence's probability + self.beam_search.beam_size = 2 + self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)] + expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]]) + expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9])) + self._check_results( + expected_top_k=expected_top_k, + expected_log_probs=expected_log_probs, + take_step=take_repeated_ngrams_step, + ) From 0bdee9d135bd8c401ec09bba2377b7ea5b5d7f43 Mon Sep 17 00:00:00 2001 From: John Giorgi <johnmgiorgi@gmail.com> Date: Tue, 1 Jun 2021 13:59:41 -0400 Subject: [PATCH 45/63] Make BeamSearch Registrable (#5231) * Make BeamSearch Registrable * Update changelog * Remove unused import * Update CHANGELOG.md Co-authored-by: Pete <petew@allenai.org> Co-authored-by: Pete <epwalsh10@gmail.com> --- CHANGELOG.md | 1 + allennlp/nn/beam_search.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07e39ab28f1..9066641dd7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). - Trainer callbacks can now store and restore state in case a training run gets interrupted. - VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. +- `BeamSearch` is now a `Registrable` class. ### Added diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py index 0b15d59bd7a..3d0d3ae38b3 100644 --- a/allennlp/nn/beam_search.py +++ b/allennlp/nn/beam_search.py @@ -6,7 +6,7 @@ from overrides import overrides import torch -from allennlp.common import FromParams, Registrable +from allennlp.common import Registrable from allennlp.common.checks import ConfigurationError from allennlp.nn.util import min_value_of_dtype @@ -683,7 +683,7 @@ def _update_state( return state -class BeamSearch(FromParams): +class BeamSearch(Registrable): """ Implements the beam search algorithm for decoding the most likely sequences. @@ -731,6 +731,8 @@ class BeamSearch(FromParams): provided, no constraints will be enforced. """ + default_implementation = "beam_search" + def __init__( self, end_index: int, @@ -1180,3 +1182,6 @@ def _update_state(self, state: StateType, backpointer: torch.Tensor): .gather(1, expanded_backpointer) .reshape(batch_size * self.beam_size, *last_dims) ) + + +BeamSearch.register("beam_search")(BeamSearch) From 8e10f695c2537db44e2eab2f0451ddc8f0007148 Mon Sep 17 00:00:00 2001 From: epwalsh <epwalsh10@gmail.com> Date: Wed, 2 Jun 2021 11:46:43 -0700 Subject: [PATCH 46/63] tick version for nightly release --- allennlp/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/version.py b/allennlp/version.py index 0d88e117162..55ba0329c4e 100644 --- a/allennlp/version.py +++ b/allennlp/version.py @@ -1,7 +1,7 @@ import os _MAJOR = "2" -_MINOR = "4" +_MINOR = "5" # On main and in a nightly release the patch should be one ahead of the last # released build. _PATCH = "0" From 7b8e9e9f95e3ababcc05f865ef9052bf1fe52088 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia <akshita23bhagia@gmail.com> Date: Wed, 2 Jun 2021 14:24:57 -0700 Subject: [PATCH 47/63] Generalize T5 modules (#5166) * initial commit * general self attn * fixing bugs, adding tests, adding docs * updating other modules * refactor * bug fix * update changelog * fix shape * fix format * address feedback * small doc fix * Update allennlp/modules/transformer/transformer_stack.py Co-authored-by: Pete <petew@allenai.org> * remove old file Co-authored-by: epwalsh <epwalsh10@gmail.com> Co-authored-by: Pete <petew@allenai.org> --- CHANGELOG.md | 1 + allennlp/modules/transformer/__init__.py | 2 +- .../modules/transformer/attention_module.py | 621 ++++++++++++++++++ .../modules/transformer/bimodal_encoder.py | 12 +- .../modules/transformer/self_attention.py | 161 ----- allennlp/modules/transformer/t5.py | 293 +-------- .../modules/transformer/transformer_layer.py | 62 +- .../modules/transformer/transformer_stack.py | 26 +- allennlp/modules/transformer/util.py | 6 +- .../transformer/self_attention_test.py | 7 +- .../transformer/t5_self_attention_test.py | 126 ++++ .../transformer/transformer_layer_test.py | 12 +- .../transformer/transformer_stack_test.py | 4 +- 13 files changed, 861 insertions(+), 472 deletions(-) create mode 100644 allennlp/modules/transformer/attention_module.py delete mode 100644 allennlp/modules/transformer/self_attention.py create mode 100644 tests/modules/transformer/t5_self_attention_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9066641dd7f..acf6730854e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences. - Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`. - Added `shuffle` argument to `BucketBatchSampler` which allows for disabling shuffling. +- Added `allennlp.modules.transformer.attention_module` which contains a generalized `AttentionModule`. `SelfAttention` and `T5Attention` both inherit from this. - Added a `Constraint` abstract class to `BeamSearch`, which allows for incorporating constraints on the predictions found by `BeamSearch`, along with a `RepeatedNGramBlockingConstraint` constraint implementation, which allows for preventing repeated n-grams in the output from `BeamSearch`. - Added `DataCollator` for dynamic operations for each batch. diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index 9b944130c7c..40e99a67918 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -131,7 +131,7 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): TransformerEmbeddings, ImageFeatureEmbeddings, ) -from allennlp.modules.transformer.self_attention import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention, T5Attention from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.transformer_layer import AttentionLayer, TransformerLayer from allennlp.modules.transformer.transformer_stack import TransformerStack diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py new file mode 100644 index 00000000000..4d98caba4b2 --- /dev/null +++ b/allennlp/modules/transformer/attention_module.py @@ -0,0 +1,621 @@ +import math +from typing import Optional, Tuple, TYPE_CHECKING +from dataclasses import dataclass +import torch +import torch.nn.functional as F + +from allennlp.common import FromParams +from allennlp.common.checks import ConfigurationError +from allennlp.modules.attention import Attention +from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT + +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + + +@dataclass +class AttentionOutput: + """ + Encapsulates the outputs of the `Attention` module. + """ + + hidden_states: FloatT + key_value_state: Optional[Tuple[FloatT, FloatT]] = None + position_bias: Optional[FloatT] = None + attention_probs: Optional[FloatT] = None + + +class AttentionModule(TransformerModule, FromParams): + """ + This module computes self-attention (or cross-attention), similar to the architecture in BERT. + Details in the paper: + [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] + (https://api.semanticscholar.org/CorpusID:52967399) + + Additionally, it has the following functionality: + + * the attention scoring function can be specified. + * it can be used in encoders as well as decoders. + * `position_bias` can be used, which makes it suitable for + [T5-style attention](https://api.semanticscholar.org/CorpusID:204838007) as well. + + # Parameters + + hidden_size: `int` (default = `512`) + The size of the expected input tensor. + attention_head_size: `int` (default = `64`) + The size of a single attention head. + num_attention_heads: `int` (default = `8`) + The number of attention heads. + scoring_func: `str` (default = `scaled_dot_product`) + The name of the attention-calculating function to be used. + Eg. `additive`, `linear`, etc. For a complete list, please check + :mod:`allennlp.modules.attention.attention`. + output_linear: `bool` (default = `False`) + Whether to add an additional output linear layer at the end. + dropout: `float` (default = `0.0`) + The dropout probability. + bias: `bool` (default = `True`) + Whether to include bias weights in query, key, value (and output) linear layers. + normalize_weights: `bool` (default = `False`) + Whether to normalize the initial weights. + is_decoder: `bool` (default = `False`) + Whether this module is being used in a decoder stack or not. + is_cross_attention: `bool` (default = `False`) + Whether this module is being used for cross-attention in a decoder stack or not. + If `is_cross_attention` is `True`, then `is_decoder` must also be `True`. + relative_attention_num_buckets: `int`, optional (default = `None`) + The number of buckets to use in relative attention; if `None`, relative attention + will not be applied. + """ + + def __init__( + self, + hidden_size: int = 512, + attention_head_size: int = 64, + num_attention_heads: int = 8, + scoring_func: str = "scaled_dot_product", + output_linear: bool = False, + dropout: float = 0.0, + bias: bool = True, + normalize_weights: bool = False, + is_decoder: bool = False, + is_cross_attention: bool = False, + relative_attention_num_buckets: Optional[int] = None, + ): + + super().__init__() + + if hidden_size % num_attention_heads != 0: + raise ConfigurationError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads) + ) + + if is_cross_attention and not is_decoder: + raise ConfigurationError( + "The attention layer can be a cross-attention layer only " + "if it is within a decoder." + ) + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + self.key = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + self.value = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias) + + # out linear layer for distilbert, T5 etc. + if output_linear: + self.output = torch.nn.Linear(self.all_head_size, hidden_size, bias=bias) + + self.scoring_func = scoring_func + if self.scoring_func in ["additive", "linear", "bilinear"]: + self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size) + elif self.scoring_func == "scaled_dot_product": + self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False) + else: + self.attn = Attention.by_name(self.scoring_func)() + + self.relative_attention_num_buckets = relative_attention_num_buckets + + if self.relative_attention_num_buckets is not None: + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, self.num_attention_heads + ) + + self.dropout = dropout + + self.is_decoder = is_decoder + self.is_cross_attention = is_cross_attention + + if normalize_weights: + self._normalize() + + def _normalize(self) -> None: + self.query.weight.data.normal_( + mean=0.0, std=(self.hidden_size * self.attention_head_size) ** -0.5 + ) + self.key.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) + self.value.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) + + if hasattr(self, "output"): + self.output.weight.data.normal_( + mean=0.0, std=(self.num_attention_heads * self.attention_head_size) ** -0.5 + ) + + if hasattr(self, "relative_attention_bias"): + self.relative_attention_bias.weight.data.normal_(mean=0.0, std=self.hidden_size ** -0.5) + + def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def _query_layer(self, query_states: torch.Tensor) -> torch.Tensor: + mixed_query_layer = self.query(query_states) + query_layer = self._transpose_for_scores(mixed_query_layer) + return query_layer + + def _project( + self, + hidden_states: torch.Tensor, + layer: torch.nn.Linear, + source_states: Optional[torch.Tensor] = None, + past_key_or_value: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO: clarify logic in terms of is_decoder and is_cross_attention + # to make it more readable. + if source_states is None: + # self-attn + # (batch_size, num_heads, seq_length, dim_per_head) + hidden_states = self._transpose_for_scores(layer(hidden_states)) + elif past_key_or_value is None: + # cross-attn + # (batch_size, num_heads, seq_length, dim_per_head) + hidden_states = self._transpose_for_scores(layer(source_states)) + + if past_key_or_value is not None: + if source_states is None: + # self-attn + # (batch_size, num_heads, key_length, dim_per_head) + # if len(past_key_or_value.shape) == 3: + # past_key_or_value = self._transpose_for_scores(past_key_or_value) + hidden_states = torch.cat([past_key_or_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_or_value + return hidden_states + + def _position_bias( + self, + position_bias: Optional[torch.Tensor], + seq_lengths: Tuple[int, int, int], + past_key_states: Optional[torch.Tensor], + attention_scores: torch.Tensor, + ) -> torch.Tensor: + seq_length, real_seq_length, key_length = seq_lengths + + if position_bias is None: + if self.relative_attention_num_buckets is not None: + position_bias = self.compute_bias(real_seq_length, key_length) + else: + position_bias = torch.zeros( + (1, self.num_attention_heads, real_seq_length, key_length), + device=attention_scores.device, + dtype=attention_scores.dtype, + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_states is not None: + position_bias = position_bias[:, :, -seq_length:, :] + return position_bias + + def _get_attention_probs( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: torch.Tensor, + seq_lengths: Tuple[int, int, int], + position_bias: Optional[torch.Tensor] = None, + past_key_states: Optional[torch.Tensor] = None, + **kwargs, + ): + attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) + + position_bias = self._position_bias( + position_bias, seq_lengths, past_key_states, attention_scores + ) + + if attention_mask is not None: + # Shape: (batch_size, num_heads, seq_length, key_length) + position_bias = apply_mask(position_bias, attention_mask) + attention_scores += position_bias + + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = F.dropout(attention_probs, p=self.dropout, training=self.training) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + return attention_probs, position_bias + + def _output_layer(self, attention_probs: torch.Tensor, value_layer: torch.Tensor): + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + if hasattr(self, "output"): + context_layer = self.output(context_layer) + + return context_layer + + def _get_lengths( + self, + query_states: torch.Tensor, + past_key_states: Optional[torch.Tensor] = None, + source_states: Optional[torch.Tensor] = None, + query_length: Optional[int] = None, + ) -> Tuple[int, int, int]: + + seq_length = query_states.shape[1] + effective_seq_len = seq_length + + if past_key_states is not None: + # TODO: query_length from up the stack: move logic here. + # TODO: clarify the logic here in terms of encoder/decoder case. + effective_seq_len += past_key_states.shape[2] if query_length is None else query_length + + key_length = effective_seq_len if source_states is None else source_states.shape[1] + + return (seq_length, effective_seq_len, key_length) + + def forward( + self, + query_states: torch.Tensor, + past_key_states: Optional[torch.Tensor] = None, + past_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + source_states: Optional[torch.Tensor] = None, + source_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + position_bias: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + query_length: Optional[int] = None, + ): + """ + # Parameters + + query_states : `torch.Tensor` + Shape `batch_size x seq_len x hidden_dim` + past_key_states : `torch.Tensor`, optional + Shape `batch_size x seq_len x hidden_dim` + These are the key_states from the previous step of the decoder. + past_value_states : `torch.Tensor`, optional + Shape `batch_size x seq_len x hidden_dim` + These are the value_states from the previous step of the decoder. + attention_mask : `torch.BoolTensor`, optional + Shape `batch_size x seq_len` + source_states : `torch.Tensor`, optional + Shape `batch_size x source_seq_len x hidden_dim` + This is from the final state of attention over the source (encoder); + it is passed when this module is being used for cross-attention. + source_attention_mask : `torch.BoolTensor`, optional + Shape `batch_size x source_seq_len` + head_mask : `torch.BoolTensor`, optional + position_bias : `torch.Tensor`, optional + output_attentions : `bool` + Whether to also return the attention probabilities, default = `False` + + !!! Note + `source_states` needs to be passed in case of cross-attention. + + """ + query_layer = self._query_layer(query_states) + + key_layer = self._project( + query_states, + self.key, + source_states, + past_key_states, + ) + + value_layer = self._project( + query_states, + self.value, + source_states, + past_value_states, + ) + + if self.is_cross_attention: + attention_mask = source_attention_mask + + seq_lengths = self._get_lengths(query_states, past_key_states, source_states, query_length) + + attention_probs, position_bias = self._get_attention_probs( + query_layer, + key_layer, + attention_mask, + head_mask, + seq_lengths, + position_bias, + past_key_states, + ) + + context_layer = self._output_layer(attention_probs, value_layer) + + present_key_value_state = ( + (key_layer, value_layer) if (self.is_decoder and use_cache) else None + ) + + if not output_attentions: + attention_probs = None + + outputs = AttentionOutput( + context_layer, present_key_value_state, position_bias, attention_probs + ) + + return outputs + + @staticmethod + def _relative_position_bucket( + relative_position: IntT, + bidirectional: bool = True, + num_buckets: int = 32, + max_distance: int = 128, + ) -> IntT: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller + buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All + relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the + same bucket. This should allow for more graceful generalization to longer sequences than the model has been + trained on. + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range + [0, num_buckets) + """ + relative_buckets = relative_position.new_zeros(relative_position.shape) + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length: int, key_length: int) -> FloatT: + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, # type: ignore + ) + relative_position_bucket = relative_position_bucket.to( + self.relative_attention_bias.weight.device + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + +class T5Attention(AttentionModule): + + _pretrained_relevant_module = ["encoder.block.0.layer.0.SelfAttention"] + _pretrained_mapping = { + "q": "query", + "k": "key", + "v": "value", + "o": "output", + } + + def __init__( + self, + is_decoder: bool = False, + hidden_size: int = 512, + key_value_proj_dim: int = 64, + num_heads: int = 8, + has_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + dropout: float = 0.1, + normalize: bool = True, + is_cross_attention: bool = False, + ): + + if not has_relative_attention_bias: + relative_attention_num_buckets = None # type: ignore + + super().__init__( + hidden_size=hidden_size, + attention_head_size=key_value_proj_dim, + num_attention_heads=num_heads, + output_linear=True, + scoring_func="scaled_dot_product", + dropout=dropout, + bias=False, + normalize_weights=normalize, + is_decoder=is_decoder, + is_cross_attention=is_cross_attention, + relative_attention_num_buckets=relative_attention_num_buckets, + ) + + self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False) + + def forward( # type: ignore + self, + hidden_states: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + key_value_states: Optional[FloatT] = None, + position_bias: Optional[FloatT] = None, + past_key_value: Optional[ + Tuple[FloatT, FloatT] + ] = None, # this is used when taking decoding steps. + layer_head_mask: Optional[BoolT] = None, + query_length: Optional[int] = None, # only relevant in cross-attention. + use_cache: bool = False, + output_attentions: bool = False, + ) -> AttentionOutput: + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by + key_value_states). + """ + if past_key_value: + past_key_states = past_key_value[0] + past_value_states = past_key_value[1] + else: + past_key_states = None + past_value_states = None + + outputs = super().forward( + query_states=hidden_states, + past_key_states=past_key_states, + past_value_states=past_value_states, + attention_mask=mask, + source_states=key_value_states, + source_attention_mask=None, # TODO: is this a bug in current T5 code? + head_mask=layer_head_mask, + position_bias=position_bias, + output_attentions=output_attentions, + use_cache=use_cache, + query_length=query_length, + ) + + return outputs + + @classmethod + def _from_config(cls, config: "PretrainedConfig", **kwargs): + final_kwargs = {} + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["key_value_proj_dim"] = config.d_kv + + final_kwargs["is_decoder"] = getattr(config, "is_decoder", False) + final_kwargs["has_relative_attention_bias"] = getattr( + config, "has_relative_attention_bias", True + ) + final_kwargs["normalize"] = getattr(config, "normalize", True) + final_kwargs["is_cross_attention"] = getattr(config, "is_cross_attention", False) + + final_kwargs["relative_attention_num_buckets"] = config.relative_attention_num_buckets + final_kwargs["num_heads"] = config.num_attention_heads + + final_kwargs["dropout"] = config.dropout_rate + final_kwargs.update(**kwargs) + return cls(**final_kwargs) + + +class SelfAttention(AttentionModule): + """ + This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention + scoring function can be specified. + Details in the paper: + [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] + (https://api.semanticscholar.org/CorpusID:52967399) + + # Parameters + + hidden_size: `int` + num_attention_heads: `int` + dropout: `float` (default = `0.0`) + scoring_func: `str` (default = `scaled_dot_product`) + The name of the attention-calculating function to be used. + Eg. `additive`, `linear`, etc. For a complete list, please check + :mod:`allennlp.modules.attention.attention`. + """ + + _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] + _pretrained_mapping = { + "layer": "layers", + "q_lin": "query", + "k_lin": "key", + "v_lin": "value", + "out_lin": "output", + "transformer": "encoder", + } + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + scoring_func: str = "scaled_dot_product", + output_linear: bool = False, + is_decoder: bool = False, + is_cross_attention: bool = False, + ): + + attention_head_size = int(hidden_size / num_attention_heads) + + super().__init__( + hidden_size=hidden_size, + attention_head_size=attention_head_size, + num_attention_heads=num_attention_heads, + scoring_func=scoring_func, + output_linear=output_linear, + dropout=dropout, + bias=True, + is_decoder=is_decoder, + is_cross_attention=is_cross_attention, + ) + + @classmethod + def _from_config(cls, config: "PretrainedConfig", **kwargs): + final_kwargs = {} + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["output_linear"] = hasattr( + config, "n_heads" + ) # This is the distilbert case; they have a linear layer as the output. + if hasattr(config, "attention_dropout"): + final_kwargs["dropout"] = config.attention_dropout + else: + final_kwargs["dropout"] = config.attention_probs_dropout_prob + final_kwargs.update(**kwargs) + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index acc993194df..27634fc1f30 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -150,19 +150,19 @@ def forward( for idx in range(start1, self.fixed_layer1): with torch.no_grad(): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states start1 = self.fixed_layer1 for idx in range(start1, end1): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states for idx in range(start2, self.fixed_layer2): with torch.no_grad(): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states start2 = self.fixed_layer2 for idx in range(start2, end2): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states if count == 0 and self.in_batch_pairs: # new batch size is the batch_size ^2 @@ -230,10 +230,10 @@ def forward( all_encoder_layers2.append(embedding2) for idx in range(start2, len(self.layers2)): - embedding2 = self.layers2[idx](embedding2, attention_mask2)[0] + embedding2 = self.layers2[idx](embedding2, attention_mask2).hidden_states for idx in range(start1, len(self.layers1)): - embedding1 = self.layers1[idx](embedding1, attention_mask1)[0] + embedding1 = self.layers1[idx](embedding1, attention_mask1).hidden_states # add the end part to finish. if not output_all_encoded_layers: diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py deleted file mode 100644 index d464012de81..00000000000 --- a/allennlp/modules/transformer/self_attention.py +++ /dev/null @@ -1,161 +0,0 @@ -from typing import Optional, TYPE_CHECKING - -import torch - -from allennlp.common import FromParams -from allennlp.modules.attention import Attention -from allennlp.modules.transformer.transformer_module import TransformerModule -from allennlp.modules.transformer.util import apply_mask - -if TYPE_CHECKING: - from transformers.configuration_utils import PretrainedConfig - - -class SelfAttention(TransformerModule, FromParams): - """ - This module computes the self-attention, similar to the architecture in BERT. Additionally, the attention - scoring function can be specified. - Details in the paper: - [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019] - (https://api.semanticscholar.org/CorpusID:52967399) - - # Parameters - - hidden_size: `int` - num_attention_heads: `int` - dropout: `float` (default = `0.0`) - scoring_func: `str` (default = `scaled_dot_product`) - The name of the attention-calculating function to be used. - Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. - """ - - _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _pretrained_mapping = { - "layer": "layers", - "q_lin": "query", - "k_lin": "key", - "v_lin": "value", - "out_lin": "output", - "transformer": "encoder", - } - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - dropout: float = 0.0, - scoring_func: str = "scaled_dot_product", - output_linear: bool = False, - ): - super().__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads) - ) - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = torch.nn.Linear(hidden_size, self.all_head_size) - self.key = torch.nn.Linear(hidden_size, self.all_head_size) - self.value = torch.nn.Linear(hidden_size, self.all_head_size) - - self.scoring_func = scoring_func - if self.scoring_func in ["additive", "linear", "bilinear"]: - self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size) - elif self.scoring_func == "scaled_dot_product": - self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False) - else: - self.attn = Attention.by_name(self.scoring_func)() - - # out linear layer for distilbert. - if output_linear: - self.output = torch.nn.Linear(hidden_size, self.all_head_size) - - self.dropout = torch.nn.Dropout(dropout) - - def _transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - query_states: torch.Tensor, - key_states: Optional[torch.Tensor] = None, - value_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): - """ - # Parameters - - query_states : `torch.Tensor` - Shape `batch_size x seq_len x hidden_dim` - key_states : `torch.Tensor`, optional - Shape `batch_size x seq_len x hidden_dim` - value_states : `torch.Tensor`, optional - Shape `batch_size x seq_len x hidden_dim` - attention_mask : `torch.BoolTensor`, optional - Shape `batch_size x seq_len` - head_mask : `torch.BoolTensor`, optional - output_attentions : `bool` - Whether to also return the attention probabilities, default = `False` - """ - if key_states is None: - key_states = query_states - if value_states is None: - value_states = query_states - - mixed_query_layer = self.query(query_states) - mixed_key_layer = self.key(key_states) - mixed_value_layer = self.value(value_states) - - query_layer = self._transpose_for_scores(mixed_query_layer) - key_layer = self._transpose_for_scores(mixed_key_layer) - value_layer = self._transpose_for_scores(mixed_value_layer) - - attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2)) - - if attention_mask is not None: - attention_scores = apply_mask(attention_scores, attention_mask) - - attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - if hasattr(self, "output"): - context_layer = self.output(context_layer) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs - - @classmethod - def _from_config(cls, config: "PretrainedConfig", **kwargs): - final_kwargs = {} - final_kwargs["hidden_size"] = config.hidden_size - final_kwargs["num_attention_heads"] = config.num_attention_heads - final_kwargs["output_linear"] = hasattr( - config, "n_heads" - ) # Since this is the distilbert case. - if hasattr(config, "attention_dropout"): - final_kwargs["dropout"] = config.attention_dropout - else: - final_kwargs["dropout"] = config.attention_probs_dropout_prob - final_kwargs.update(**kwargs) - return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 206f944aae5..7b7cfc21206 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -3,7 +3,6 @@ (https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py). """ # noqa: E401 -import math from dataclasses import dataclass from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING @@ -14,23 +13,22 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError -from allennlp.modules.transformer.transformer_module import ( - TransformerModule, +from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.attention_module import ( + T5Attention, + AttentionOutput, ) from allennlp.modules.transformer.util import ( - apply_mask, get_extended_attention_mask, + FloatT, + IntT, + BoolT, ) from allennlp.nn.beam_search import BeamSearch if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig -# Unfortunately mypy is insane, so I have to wrap these in unions. -FloatT = Union[torch.FloatTensor] -IntT = Union[torch.IntTensor] -BoolT = Union[torch.BoolTensor] - class T5LayerNorm(TransformerModule, FromParams): """T5-style layer norm does not have bias and does not subtract the mean.""" @@ -119,259 +117,6 @@ def forward(self, hidden_states) -> FloatT: return hidden_states -@dataclass -class T5AttentionOutput: - hidden_states: FloatT - key_value_state: Optional[Tuple[FloatT, FloatT]] - position_bias: FloatT - attn_weights: Optional[FloatT] = None - - -class T5Attention(TransformerModule, FromParams): - def __init__( - self, - is_decoder: bool = False, - hidden_size: int = 512, - key_value_proj_dim: int = 64, - num_heads: int = 8, - has_relative_attention_bias: bool = False, - relative_attention_num_buckets: int = 32, - dropout: float = 0.1, - ): - super().__init__() - self.is_decoder = is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - - self.relative_attention_num_buckets = relative_attention_num_buckets - self.hidden_size = hidden_size - self.key_value_proj_dim = key_value_proj_dim - self.num_heads = num_heads - self.dropout = dropout - self.inner_dim = self.num_heads * self.key_value_proj_dim - - self.q = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.k = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.v = nn.Linear(self.hidden_size, self.inner_dim, bias=False) - self.o = nn.Linear(self.inner_dim, self.hidden_size, bias=False) - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.num_heads - ) - - self.q.weight.data.normal_(mean=0.0, std=(hidden_size * key_value_proj_dim) ** -0.5) - self.k.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - self.v.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - self.o.weight.data.normal_(mean=0.0, std=(num_heads * key_value_proj_dim) ** -0.5) - if self.has_relative_attention_bias: - self.relative_attention_bias.weight.data.normal_(mean=0.0, std=hidden_size ** -0.5) - - @staticmethod - def _relative_position_bucket( - relative_position: IntT, - bidirectional: bool = True, - num_buckets: int = 32, - max_distance: int = 128, - ) -> IntT: - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the - attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller - buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All - relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the - same bucket. This should allow for more graceful generalization to longer sequences than the model has been - trained on. - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range - [0, num_buckets) - """ - relative_buckets = relative_position.new_zeros(relative_position.shape) - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_postion_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_postion_if_large = torch.min( - relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) - ) - - relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) - return relative_buckets - - def compute_bias(self, query_length: int, key_length: int) -> FloatT: - """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - ) - relative_position_bucket = relative_position_bucket.to( - self.relative_attention_bias.weight.device - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states: torch.Tensor, - mask: Optional[torch.BoolTensor] = None, - key_value_states: Optional[FloatT] = None, - position_bias: Optional[FloatT] = None, - past_key_value: Optional[Tuple[FloatT, FloatT]] = None, - layer_head_mask: Optional[BoolT] = None, - query_length: Optional[int] = None, - use_cache: bool = False, - output_attentions: bool = False, - ) -> T5AttentionOutput: - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by - key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head) - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), "past_key_value should have 2 past states: keys and values. Got {} past states".format( - len(past_key_value) - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - return states.view(batch_size, -1, self.num_heads, self.key_value_proj_dim).transpose( - 1, 2 - ) - - def unshape(states): - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value) -> FloatT: - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, num_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, num_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, num_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, num_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if self.has_relative_attention_bias: - position_bias = self.compute_bias(real_seq_length, key_length) - else: - position_bias = torch.zeros( - (1, self.num_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -seq_length:, :] - - if mask is not None: - # Shape: (batch_size, num_heads, seq_length, key_length) - position_bias = apply_mask(position_bias, mask) - - scores += position_bias - attn_weights = F.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, num_heads, seq_length, key_length) - attn_weights = F.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, num_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = T5AttentionOutput(attn_output, present_key_value_state, position_bias) - if output_attentions: - outputs.attn_weights = attn_weights - return outputs - - @dataclass class T5LayerSelfAttentionOutput: hidden_states: FloatT @@ -397,6 +142,10 @@ def __init__( self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) + @property + def hidden_size(self) -> int: + return self.self_attention.hidden_size + def forward( self, hidden_states: FloatT, @@ -407,8 +156,10 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ) -> T5LayerSelfAttentionOutput: + normed_hidden_states = self.layer_norm(hidden_states) - attention_output: T5AttentionOutput = self.self_attention( + + attention_output: AttentionOutput = self.self_attention( normed_hidden_states, mask=attention_mask, position_bias=position_bias, @@ -417,12 +168,14 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) + hidden_states = hidden_states + self.dropout(attention_output.hidden_states) + return T5LayerSelfAttentionOutput( hidden_states, attention_output.key_value_state, attention_output.position_bias, - attention_output.attn_weights, + attention_output.attention_probs, ) @@ -445,7 +198,9 @@ def __init__( ): super().__init__() self.enc_dec_attention = enc_dec_attention or T5Attention( - is_decoder=True, has_relative_attention_bias=False + is_decoder=True, + has_relative_attention_bias=False, + is_cross_attention=True, ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.enc_dec_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -463,7 +218,7 @@ def forward( output_attentions: bool = False, ) -> T5LayerCrossAttentionOutput: normed_hidden_states = self.layer_norm(hidden_states) - attention_output: T5AttentionOutput = self.enc_dec_attention( + attention_output: AttentionOutput = self.enc_dec_attention( normed_hidden_states, mask=attention_mask, key_value_states=key_value_states, @@ -475,11 +230,12 @@ def forward( output_attentions=output_attentions, ) layer_output = hidden_states + self.dropout(attention_output.hidden_states) + return T5LayerCrossAttentionOutput( layer_output, attention_output.key_value_state, attention_output.position_bias, - attention_output.attn_weights, + attention_output.attention_probs, ) @@ -518,7 +274,7 @@ def __init__( @property def hidden_size(self) -> int: - return self.layer[0].self_attention.hidden_size + return self.layer[0].hidden_size def forward( self, @@ -967,6 +723,7 @@ class T5Output: class T5(TransformerModule, Registrable): + _pretrained_mapping = {"shared": "token_embeddings"} _tied_weights = { "token_embeddings.weight": [ diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 43a76d33144..f3814e6201e 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -1,12 +1,14 @@ from typing import Union, Optional, TYPE_CHECKING +from dataclasses import dataclass import torch from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.activation_layer import ActivationLayer -from allennlp.modules.transformer.self_attention import SelfAttention +from allennlp.modules.transformer.attention_module import SelfAttention, AttentionOutput from allennlp.modules.transformer.output_layer import OutputLayer +from allennlp.modules.transformer.util import FloatT if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -38,9 +40,17 @@ def __init__( num_attention_heads: int, attention_dropout: float = 0.0, hidden_dropout: float = 0.0, + is_cross_attention: bool = False, + is_decoder: bool = False, ): super().__init__() - self.self = SelfAttention(hidden_size, num_attention_heads, attention_dropout) + self.self = SelfAttention( + hidden_size, + num_attention_heads, + attention_dropout, + is_cross_attention=is_cross_attention, + is_decoder=is_decoder, + ) self.output = OutputLayer(hidden_size, hidden_size, hidden_dropout) def forward( @@ -69,14 +79,19 @@ def forward( self_output = self.self( input_tensor, - encoder_hidden_states, - encoder_hidden_states, - attention_mask, - head_mask, - output_attentions, + source_states=encoder_hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + attention_output = self.output(self_output.hidden_states, input_tensor) + outputs = AttentionOutput( + attention_output, + self_output.key_value_state, + self_output.position_bias, + self_output.attention_probs, ) - attention_output = self.output(self_output[0], input_tensor) - outputs = (attention_output,) + self_output[1:] # add attentions if we output them return outputs @classmethod @@ -92,6 +107,17 @@ def _from_config(cls, config: "PretrainedConfig", **kwargs): return cls(**final_kwargs) +@dataclass +class TransformerLayerOutput: + """ + Encapsulates the outputs of the `TransformerLayer` module. + """ + + hidden_states: FloatT + self_attention_probs: Optional[FloatT] = None + cross_attention_probs: Optional[FloatT] = None + + class TransformerLayer(TransformerModule, FromParams): """ This module is a single transformer layer, mapping to `BertLayer` in the architecture in BERT. @@ -149,6 +175,8 @@ def __init__( num_attention_heads=num_attention_heads, attention_dropout=attention_dropout, hidden_dropout=hidden_dropout, + is_cross_attention=True, + is_decoder=True, ) self.intermediate = ActivationLayer( @@ -166,7 +194,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - ): + ) -> TransformerLayerOutput: """ # Parameters @@ -186,8 +214,9 @@ def forward( head_mask, output_attentions=output_attentions, ) - attention_output = attention_outputs[0] - outputs = attention_outputs[1:] # add self attentions if we output attention weights + attention_output = attention_outputs.hidden_states + self_attention_probs = attention_outputs.attention_probs + cross_attention_probs = None if encoder_hidden_states is not None: assert hasattr( @@ -203,14 +232,13 @@ def forward( encoder_attention_mask, output_attentions, ) - attention_output = cross_attention_outputs[0] - outputs = ( - outputs + cross_attention_outputs[1:] - ) # add cross attentions if we output attention weights + attention_output = cross_attention_outputs.hidden_states + cross_attention_probs = cross_attention_outputs.attention_probs intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + outputs + + outputs = TransformerLayerOutput(layer_output, self_attention_probs, cross_attention_probs) return outputs @classmethod diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 7bc4a7247d3..3825ccac7cb 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -1,5 +1,6 @@ -from typing import Union, Optional, TYPE_CHECKING +from typing import Union, Optional, Tuple, TYPE_CHECKING import logging +from dataclasses import dataclass import torch @@ -7,6 +8,7 @@ from allennlp.modules.util import replicate_layers from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.util import FloatT if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -15,6 +17,18 @@ logger = logging.getLogger(__name__) +@dataclass +class TransformerStackOutput: + """ + Encapsulates the outputs of the `TransformerStack` module. + """ + + final_hidden_states: FloatT + all_hidden_states: Optional[Tuple] = None + all_self_attentions: Optional[Tuple] = None + all_cross_attentions: Optional[Tuple] = None + + class TransformerStack(TransformerModule, FromParams): """ This module is the basic transformer stack. @@ -87,7 +101,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, - ): + ) -> TransformerStackOutput: """ # Parameters @@ -118,7 +132,7 @@ def forward( encoder_attention_mask, output_attentions, ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs.hidden_states if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) # type: ignore if self._add_cross_attention: @@ -127,10 +141,8 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore - return tuple( - v - for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] - if v is not None + return TransformerStackOutput( + hidden_states, all_hidden_states, all_attentions, all_cross_attentions ) @classmethod diff --git a/allennlp/modules/transformer/util.py b/allennlp/modules/transformer/util.py index e9797bff68c..7d6a43cf198 100644 --- a/allennlp/modules/transformer/util.py +++ b/allennlp/modules/transformer/util.py @@ -1,8 +1,12 @@ from typing import Union, Tuple import torch - from allennlp.nn.util import min_value_of_dtype +# Unfortunately mypy is insane, so we have to wrap these in unions. +FloatT = Union[torch.FloatTensor] +IntT = Union[torch.IntTensor] +BoolT = Union[torch.BoolTensor] + def apply_mask( values: torch.FloatTensor, mask: Union[torch.BoolTensor, torch.IntTensor, torch.FloatTensor] diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index 7a3dcb81ec8..af1b4a4c43a 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -5,7 +5,8 @@ from transformers import AutoModel from allennlp.common import Params -from allennlp.modules.transformer import SelfAttention + +from allennlp.modules.transformer.attention_module import SelfAttention from allennlp.nn.util import min_value_of_dtype @@ -46,7 +47,7 @@ def test_can_construct_from_params(self_attention, params_dict): assert self_attention.key.in_features == params_dict["hidden_size"] assert self_attention.value.in_features == params_dict["hidden_size"] - assert self_attention.dropout.p == params_dict["dropout"] + assert self_attention.dropout == params_dict["dropout"] @pytest.mark.parametrize( @@ -78,7 +79,7 @@ def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relev pretrained_module = pretrained_module.eval() torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states if "distilbert" in pretrained_name: torch.manual_seed(1234) hf_output = pretrained_module( diff --git a/tests/modules/transformer/t5_self_attention_test.py b/tests/modules/transformer/t5_self_attention_test.py new file mode 100644 index 00000000000..96eb58bd806 --- /dev/null +++ b/tests/modules/transformer/t5_self_attention_test.py @@ -0,0 +1,126 @@ +import copy +import torch +import pytest + +from transformers import AutoModel + +from allennlp.common import Params + +from allennlp.modules.transformer.attention_module import T5Attention + +from transformers.models.t5.configuration_t5 import T5Config +from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention +from allennlp.nn.util import min_value_of_dtype + +PARAMS_DICT = { + "hidden_size": 6, + "num_heads": 2, + "key_value_proj_dim": 3, + "dropout": 0.0, + "relative_attention_num_buckets": 2, +} + + +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def t5_attention(params): + return T5Attention.from_params(params.duplicate()) + + +def test_can_construct_from_params(t5_attention, params_dict): + + assert t5_attention.num_attention_heads == params_dict["num_heads"] + assert t5_attention.attention_head_size == params_dict["key_value_proj_dim"] + + assert ( + t5_attention.all_head_size == params_dict["num_heads"] * params_dict["key_value_proj_dim"] + ) + + assert t5_attention.query.in_features == params_dict["hidden_size"] + assert t5_attention.key.in_features == params_dict["hidden_size"] + assert t5_attention.value.in_features == params_dict["hidden_size"] + assert t5_attention.output.in_features == params_dict["hidden_size"] + + assert t5_attention.dropout == params_dict["dropout"] + + +def test_forward_against_huggingface_output(params_dict): + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + + hf_kwargs = { + "d_model": params_dict["hidden_size"], + "d_kv": params_dict["key_value_proj_dim"], + "num_heads": params_dict["num_heads"], + "relative_attention_num_buckets": params_dict["relative_attention_num_buckets"], + "dropout_rate": params_dict["dropout"], + } + + torch.manual_seed(1234) + hf_module = HFT5Attention(T5Config(**hf_kwargs), has_relative_attention_bias=False) + + torch.manual_seed(1234) + + params = copy.deepcopy(params_dict) + params["normalize"] = False # only for this test, as HF does not normalize. + t5_attention = T5Attention(**params) + + # setting to eval mode to avoid non-deterministic dropout. + t5_attention = t5_attention.eval() + hf_module = hf_module.eval() + + output = t5_attention.forward(hidden_states, mask=attention_mask) + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand( + 2, 2, 3, 3 + ) * min_value_of_dtype(hidden_states.dtype) + hf_output = hf_module.forward(hidden_states, mask=attention_mask_hf) + + hs = output.hidden_states + + assert torch.allclose(hs, hf_output[0]) + + +@pytest.mark.parametrize( + "pretrained_name, relevant_module", + [ + ("t5-small", "encoder.block.0.layer.0.SelfAttention"), + ], +) +def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relevant_module): + + torch.manual_seed(1234) + module = T5Attention.from_pretrained_module(pretrained_name, relevant_module=relevant_module) + + torch.manual_seed(1234) + pretrained_module = dict(AutoModel.from_pretrained(pretrained_name).named_modules())[ + relevant_module + ] + + batch_size = 2 + seq_len = 3 + dim = module.query.in_features + hidden_states = torch.randn(batch_size, seq_len, dim) + attention_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :] + + # setting to eval mode to avoid non-deterministic dropout. + module = module.eval() + pretrained_module = pretrained_module.eval() + + torch.manual_seed(1234) + output = module(hidden_states, mask=attention_mask.squeeze()).hidden_states + + # The attn_mask is processed outside the self attention module in HF bert models. + attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, mask=attention_mask)[0] + + assert torch.allclose(output, hf_output) diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 4c1e141a5a8..538df89c9e0 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -44,7 +44,7 @@ def test_attention(attention_params): assert attention_layer.self.query.in_features == attention_params["hidden_size"] assert attention_layer.self.key.in_features == attention_params["hidden_size"] assert attention_layer.self.value.in_features == attention_params["hidden_size"] - assert attention_layer.self.dropout.p == attention_params["attention_dropout"] + assert attention_layer.self.dropout == attention_params["attention_dropout"] assert attention_layer.output.dense.in_features == attention_params["hidden_size"] assert attention_layer.output.dense.out_features == attention_params["hidden_size"] @@ -87,7 +87,7 @@ def test_attention_matches_huggingface(attention_params, module_name, hf_module) torch.manual_seed(1234) hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.hidden_states, hf_output[0]) @pytest.mark.parametrize( @@ -126,7 +126,7 @@ def test_attention_from_pretrained(pretrained_name, relevant_top_level_module): attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states torch.manual_seed(1234) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] @@ -166,7 +166,7 @@ def test_layer(layer_params): assert transformer_layer.attention.self.query.in_features == layer_params["hidden_size"] assert transformer_layer.attention.self.key.in_features == layer_params["hidden_size"] assert transformer_layer.attention.self.value.in_features == layer_params["hidden_size"] - assert transformer_layer.attention.self.dropout.p == layer_params["attention_dropout"] + assert transformer_layer.attention.self.dropout == layer_params["attention_dropout"] assert transformer_layer.attention.output.dense.in_features == layer_params["hidden_size"] assert transformer_layer.attention.output.dense.out_features == layer_params["hidden_size"] @@ -243,7 +243,7 @@ def test_layer_matches_huggingface(layer_params, module_name, hf_module): torch.manual_seed(1234) hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.hidden_states, hf_output[0]) @pytest.mark.parametrize( @@ -282,7 +282,7 @@ def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 torch.manual_seed(1234) - output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + output = module(hidden_states, attention_mask=attention_mask.squeeze()).hidden_states torch.manual_seed(1234) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index cf42f6c0f6d..4812a74f58b 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -55,7 +55,7 @@ def test_transformer_stack_from_params(params): hidden_states, attention_mask=attention_mask ) - assert torch.allclose(from_layer_output[0], output[0]) + assert torch.allclose(from_layer_output.final_hidden_states, output.final_hidden_states) # Make sure forward pass raises with bad input. with pytest.raises(AssertionError): @@ -102,7 +102,7 @@ def test_loading_from_pretrained(pretrained_model_name): torch.manual_seed(SEED) hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf) - assert torch.allclose(output[0], hf_output[0]) + assert torch.allclose(output.final_hidden_states, hf_output[0]) def test_loading_partial_pretrained_weights(): From 3916cf36520e8017276dac6236b6fb4a1ae8261a Mon Sep 17 00:00:00 2001 From: Kuo Liao <MagiaSN@yeah.net> Date: Thu, 3 Jun 2021 05:44:43 +0800 Subject: [PATCH 48/63] Fix tqdm logging into multiple files with allennlp-optuna (#5235) * Fix tqdm logging into multiple files with allennlp-optuna * Update changelog * Add unittest for resetting tqdm logger handlers Co-authored-by: Pete <petew@allenai.org> --- CHANGELOG.md | 1 + allennlp/common/logging.py | 1 + tests/common/logging_test.py | 18 +++++++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acf6730854e..3aba4a6d638 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed the potential for a race condition with `cached_path()` when extracting archives. Although the race condition is still possible if used with `force_extract=True`. - Fixed `wandb` callback to work in distributed training. +- Fixed `tqdm` logging into multiple files with `allennlp-optuna`. ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22 diff --git a/allennlp/common/logging.py b/allennlp/common/logging.py index a278ec2edb4..23a337a0ce1 100644 --- a/allennlp/common/logging.py +++ b/allennlp/common/logging.py @@ -126,4 +126,5 @@ def excepthook(exctype, value, traceback): # also log tqdm from allennlp.common.tqdm import logger as tqdm_logger + tqdm_logger.handlers.clear() tqdm_logger.addHandler(file_handler) diff --git a/tests/common/logging_test.py b/tests/common/logging_test.py index 643eb6ccbdf..c072efb3ecb 100644 --- a/tests/common/logging_test.py +++ b/tests/common/logging_test.py @@ -2,8 +2,9 @@ import logging import random -from allennlp.common.logging import AllenNlpLogger +from allennlp.common.logging import AllenNlpLogger, prepare_global_logging from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.tqdm import Tqdm class TestLogging(AllenNlpTestCase): @@ -64,3 +65,18 @@ def test_getLogger(self): logger = logging.getLogger("test_logger") assert isinstance(logger, AllenNlpLogger) + + def test_reset_tqdm_logger_handlers(self): + serialization_dir_a = os.path.join(self.TEST_DIR, "test_a") + os.makedirs(serialization_dir_a, exist_ok=True) + prepare_global_logging(serialization_dir_a) + serialization_dir_b = os.path.join(self.TEST_DIR, "test_b") + os.makedirs(serialization_dir_b, exist_ok=True) + prepare_global_logging(serialization_dir_b) + # Use range(1) to make sure there should be only 2 lines in the file (0% and 100%) + for _ in Tqdm.tqdm(range(1)): + pass + with open(os.path.join(serialization_dir_a, "out.log"), "r") as f: + assert len(f.readlines()) == 0 + with open(os.path.join(serialization_dir_b, "out.log"), "r") as f: + assert len(f.readlines()) == 2 From 4753906f10287a8dcc20b5a740d6c91f8aca0e4c Mon Sep 17 00:00:00 2001 From: Akshita Bhagia <akshita23bhagia@gmail.com> Date: Wed, 2 Jun 2021 15:01:59 -0700 Subject: [PATCH 49/63] Checklist fixes (#5239) * bug fix * common lexicons * update changelog * Update CHANGELOG.md --- .../sentiment_analysis_suite.py | 312 +++++++++--------- .../task_checklists/task_suite.py | 1 + .../textual_entailment_suite.py | 88 +---- .../task_checklists/utils.py | 86 +++++ 4 files changed, 247 insertions(+), 240 deletions(-) diff --git a/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py b/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py index 2c68cd9efaf..a6eeac149a0 100644 --- a/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py +++ b/allennlp/confidence_checks/task_checklists/sentiment_analysis_suite.py @@ -3,7 +3,6 @@ from overrides import overrides from checklist.test_suite import TestSuite from checklist.test_types import MFT, INV, DIR, Expect -from checklist.editor import Editor from checklist.perturb import Perturb from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite from allennlp.confidence_checks.task_checklists import utils @@ -60,7 +59,7 @@ def preds_and_confs_fn(data): for pred in predictions: label = pred["probs"].index(max(pred["probs"])) labels.append(label) - confs.append([pred["probs"][self._positive], pred["probs"][self._negative]]) + confs.append(pred["probs"]) return np.array(labels), np.array(confs) return preds_and_confs_fn @@ -97,156 +96,153 @@ def _default_tests(self, data: Optional[Iterable[str]], num_test_cases=100): self._default_negation_tests(data, num_test_cases) def _setup_editor(self): - if not hasattr(self, "editor"): - self.editor = Editor() - - pos_adj = [ - "good", - "great", - "excellent", - "amazing", - "extraordinary", - "beautiful", - "fantastic", - "nice", - "incredible", - "exceptional", - "awesome", - "perfect", - "fun", - "adorable", - "brilliant", - "exciting", - "sweet", - "wonderful", - ] - neg_adj = [ - "awful", - "bad", - "horrible", - "weird", - "rough", - "lousy", - "unhappy", - "average", - "difficult", - "poor", - "sad", - "frustrating", - "hard", - "lame", - "nasty", - "annoying", - "boring", - "creepy", - "dreadful", - "ridiculous", - "terrible", - "ugly", - "unpleasant", - ] - self.editor.add_lexicon("pos_adj", pos_adj, overwrite=True) - self.editor.add_lexicon("neg_adj", neg_adj, overwrite=True) - - pos_verb_present = [ - "like", - "enjoy", - "appreciate", - "love", - "recommend", - "admire", - "value", - "welcome", - ] - neg_verb_present = ["hate", "dislike", "regret", "abhor", "dread", "despise"] - pos_verb_past = [ - "liked", - "enjoyed", - "appreciated", - "loved", - "admired", - "valued", - "welcomed", - ] - neg_verb_past = ["hated", "disliked", "regretted", "abhorred", "dreaded", "despised"] - self.editor.add_lexicon("pos_verb_present", pos_verb_present, overwrite=True) - self.editor.add_lexicon("neg_verb_present", neg_verb_present, overwrite=True) - self.editor.add_lexicon("pos_verb_past", pos_verb_past, overwrite=True) - self.editor.add_lexicon("neg_verb_past", neg_verb_past, overwrite=True) - self.editor.add_lexicon("pos_verb", pos_verb_present + pos_verb_past, overwrite=True) - self.editor.add_lexicon("neg_verb", neg_verb_present + neg_verb_past, overwrite=True) - - noun = [ - "airline", - "movie", - "product", - "customer service", - "restaurant", - "hotel", - "food", - "staff", - "company", - "crew", - "service", - ] - self.editor.add_lexicon("noun", noun, overwrite=True) - - intens_adj = [ - "very", - "really", - "absolutely", - "truly", - "extremely", - "quite", - "incredibly", - "amazingly", - "especially", - "exceptionally", - "unbelievably", - "utterly", - "exceedingly", - "rather", - "totally", - "particularly", - ] - intens_verb = [ - "really", - "absolutely", - "truly", - "extremely", - "especially", - "utterly", - "totally", - "particularly", - "highly", - "definitely", - "certainly", - "genuinely", - "honestly", - "strongly", - "sure", - "sincerely", - ] - - self.editor.add_lexicon("intens_adj", intens_adj, overwrite=True) - self.editor.add_lexicon("intens_verb", intens_verb, overwrite=True) - - reducer_adj = [ - "somewhat", - "kinda", - "mostly", - "probably", - "generally", - "reasonably", - "a little", - "a bit", - "slightly", - ] - - self.editor.add_lexicon("reducer_adj", reducer_adj, overwrite=True) - - self.monotonic_label = Expect.monotonic(increasing=True, tolerance=0.1) - self.monotonic_label_down = Expect.monotonic(increasing=False, tolerance=0.1) + super()._setup_editor() + + pos_adj = [ + "good", + "great", + "excellent", + "amazing", + "extraordinary", + "beautiful", + "fantastic", + "nice", + "incredible", + "exceptional", + "awesome", + "perfect", + "fun", + "adorable", + "brilliant", + "exciting", + "sweet", + "wonderful", + ] + neg_adj = [ + "awful", + "bad", + "horrible", + "weird", + "rough", + "lousy", + "average", + "difficult", + "poor", + "sad", + "frustrating", + "lame", + "nasty", + "annoying", + "boring", + "creepy", + "dreadful", + "ridiculous", + "terrible", + "ugly", + "unpleasant", + ] + self.editor.add_lexicon("pos_adj", pos_adj, overwrite=True) + self.editor.add_lexicon("neg_adj", neg_adj, overwrite=True) + + pos_verb_present = [ + "like", + "enjoy", + "appreciate", + "love", + "recommend", + "admire", + "value", + "welcome", + ] + neg_verb_present = ["hate", "dislike", "regret", "abhor", "dread", "despise"] + pos_verb_past = [ + "liked", + "enjoyed", + "appreciated", + "loved", + "admired", + "valued", + "welcomed", + ] + neg_verb_past = ["hated", "disliked", "regretted", "abhorred", "dreaded", "despised"] + self.editor.add_lexicon("pos_verb_present", pos_verb_present, overwrite=True) + self.editor.add_lexicon("neg_verb_present", neg_verb_present, overwrite=True) + self.editor.add_lexicon("pos_verb_past", pos_verb_past, overwrite=True) + self.editor.add_lexicon("neg_verb_past", neg_verb_past, overwrite=True) + self.editor.add_lexicon("pos_verb", pos_verb_present + pos_verb_past, overwrite=True) + self.editor.add_lexicon("neg_verb", neg_verb_present + neg_verb_past, overwrite=True) + + noun = [ + "airline", + "movie", + "product", + "customer service", + "restaurant", + "hotel", + "food", + "staff", + "company", + "crew", + "service", + ] + self.editor.add_lexicon("noun", noun, overwrite=True) + + intens_adj = [ + "very", + "really", + "absolutely", + "truly", + "extremely", + "quite", + "incredibly", + "amazingly", + "especially", + "exceptionally", + "unbelievably", + "utterly", + "exceedingly", + "rather", + "totally", + "particularly", + ] + intens_verb = [ + "really", + "absolutely", + "truly", + "extremely", + "especially", + "utterly", + "totally", + "particularly", + "highly", + "definitely", + "certainly", + "genuinely", + "honestly", + "strongly", + "sure", + "sincerely", + ] + + self.editor.add_lexicon("intens_adj", intens_adj, overwrite=True) + self.editor.add_lexicon("intens_verb", intens_verb, overwrite=True) + + reducer_adj = [ + "somewhat", + "kinda", + "mostly", + "probably", + "generally", + "reasonably", + "a little", + "a bit", + "slightly", + ] + + self.editor.add_lexicon("reducer_adj", reducer_adj, overwrite=True) + + self.monotonic_label = Expect.monotonic(increasing=True, tolerance=0.1) + self.monotonic_label_down = Expect.monotonic(increasing=False, tolerance=0.1) def _default_vocabulary_tests(self, data: Optional[Iterable[str]], num_test_cases=100): @@ -371,7 +367,7 @@ def _default_vocabulary_tests(self, data: Optional[Iterable[str]], num_test_case templates=template.templates, name="Intensifiers", capability="Vocabulary", - description="Test is composed of pairs of sentences (x1, x2), where we add an intensifier" + description="Test is composed of pairs of sentences (x1, x2), where we add an intensifier " "such as 'really',or 'very' to x2 and expect the confidence to NOT go down " "(with tolerance=0.1). e.g.:" "x1 = 'That was a good movie'" @@ -400,7 +396,7 @@ def _default_vocabulary_tests(self, data: Optional[Iterable[str]], num_test_case templates=template.templates, name="Reducers", capability="Vocabulary", - description="Test is composed of pairs of sentences (x1, x2), where we add a reducer" + description="Test is composed of pairs of sentences (x1, x2), where we add a reducer " "such as 'somewhat', or 'kinda' to x2 and expect the confidence to NOT go up " " (with tolerance=0.1). e.g.:" "x1 = 'The staff was good.'" @@ -555,8 +551,8 @@ def _default_temporal_tests(self, data: Optional[Iterable[str]], num_test_cases= capability="Temporal", description="Have two conflicing statements, one about the past and " "one about the present." - "Expect the present to carry the sentiment. Examples:" - "I used to love this airline, now I hate it -> should be negative" + "Expect the present to carry the sentiment. Examples:\n" + "I used to love this airline, now I hate it -> should be negative\n" "I love this airline, although I used to hate it -> should be positive", ) @@ -604,13 +600,13 @@ def _default_fairness_tests(self, data: Optional[Iterable[str]], num_test_cases= for p, vals in protected.items(): template = self.editor.template( - ["{male} is %s {mask}." % r for r in vals], + ["{male} is %s {profession}." % r for r in vals], return_maps=False, nsamples=num_test_cases, save=True, ) template += self.editor.template( - ["{female} is %s {mask}." % r for r in vals], + ["{female} is %s {profession}." % r for r in vals], return_maps=False, nsamples=num_test_cases, save=True, diff --git a/allennlp/confidence_checks/task_checklists/task_suite.py b/allennlp/confidence_checks/task_checklists/task_suite.py index 6ddf00d59b1..0d7e1a1f688 100644 --- a/allennlp/confidence_checks/task_checklists/task_suite.py +++ b/allennlp/confidence_checks/task_checklists/task_suite.py @@ -378,6 +378,7 @@ def _setup_editor(self): """ if not hasattr(self, "editor"): self.editor = Editor() + utils.add_common_lexicons(self.editor) def add_test(self, test: Union[MFT, INV, DIR]): """ diff --git a/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py b/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py index b8e1a810f23..7e7fb30209d 100644 --- a/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py +++ b/allennlp/confidence_checks/task_checklists/textual_entailment_suite.py @@ -220,82 +220,6 @@ def _setup_editor(self): ] self.editor.add_lexicon("nouns", nouns, overwrite=True) - professions = [ - "journalist", - "historian", - "secretary", - "nurse", - "waitress", - "accountant", - "engineer", - "attorney", - "artist", - "editor", - "architect", - "model", - "interpreter", - "analyst", - "actor", - "actress", - "assistant", - "intern", - "economist", - "organizer", - "author", - "investigator", - "agent", - "administrator", - "executive", - "educator", - "investor", - "DJ", - "entrepreneur", - "auditor", - "advisor", - "instructor", - "activist", - "consultant", - "apprentice", - "reporter", - "expert", - "psychologist", - "examiner", - "painter", - "manager", - "contractor", - "therapist", - "programmer", - "musician", - "producer", - "associate", - "intermediary", - "designer", - "cook", - "salesperson", - "dentist", - "attorney", - "detective", - "banker", - "researcher", - "cop", - "driver", - "counselor", - "clerk", - "professor", - "tutor", - "coach", - "chemist", - "scientist", - "veterinarian", - "firefighter", - "baker", - "psychiatrist", - "prosecutor", - "director", - "technician", - ] - self.editor.add_lexicon("professions", professions, overwrite=True) - @overrides def _default_tests(self, data: Optional[Iterable[Tuple]], num_test_cases=100): super()._default_tests(data, num_test_cases) @@ -406,8 +330,8 @@ def _default_ner_tests(self, data: Optional[Iterable[Tuple]], num_test_cases=100 def _default_temporal_tests(self, data: Optional[Iterable[Tuple]], num_test_cases=100): template = self.editor.template( ( - "{first_name} works as {a:professions}", - "{first_name} used to work as a {professions}", + "{first_name} works as {a:profession}", + "{first_name} used to work as a {profession}", ), nsamples=num_test_cases, remove_duplicates=True, @@ -415,8 +339,8 @@ def _default_temporal_tests(self, data: Optional[Iterable[Tuple]], num_test_case template += self.editor.template( ( - "{first_name} {last_name} is {a:professions}", - "{first_name} {last_name} was {a:professions}", + "{first_name} {last_name} is {a:profession}", + "{first_name} {last_name} was {a:profession}", ), nsamples=num_test_cases, remove_duplicates=True, @@ -434,8 +358,8 @@ def _default_temporal_tests(self, data: Optional[Iterable[Tuple]], num_test_case template = self.editor.template( ( - "{first_name} was {a:professions1} before they were {a:professions2}", - "{first_name} was {a:professions1} after they were {a:professions2}", + "{first_name} was {a:profession1} before they were {a:profession2}", + "{first_name} was {a:profession1} after they were {a:profession2}", ), nsamples=num_test_cases, remove_duplicates=True, diff --git a/allennlp/confidence_checks/task_checklists/utils.py b/allennlp/confidence_checks/task_checklists/utils.py index 22ad9deedf1..236c1618372 100644 --- a/allennlp/confidence_checks/task_checklists/utils.py +++ b/allennlp/confidence_checks/task_checklists/utils.py @@ -2,6 +2,92 @@ from typing import Dict, Callable, List, Union import numpy as np import spacy +from checklist.editor import Editor + + +def add_common_lexicons(editor: Editor): + """ + Add commonly used lexicons to the editor object. These can be used in all + the task suites. + + Note: Updates the `editor` object in place. + """ + profession = [ + "journalist", + "historian", + "secretary", + "nurse", + "waitress", + "accountant", + "engineer", + "attorney", + "artist", + "editor", + "architect", + "model", + "interpreter", + "analyst", + "actor", + "actress", + "assistant", + "intern", + "economist", + "organizer", + "author", + "investigator", + "agent", + "administrator", + "executive", + "educator", + "investor", + "DJ", + "entrepreneur", + "auditor", + "advisor", + "instructor", + "activist", + "consultant", + "apprentice", + "reporter", + "expert", + "psychologist", + "examiner", + "painter", + "manager", + "contractor", + "therapist", + "programmer", + "musician", + "producer", + "associate", + "intermediary", + "designer", + "cook", + "salesperson", + "dentist", + "attorney", + "detective", + "banker", + "researcher", + "cop", + "driver", + "counselor", + "clerk", + "professor", + "tutor", + "coach", + "chemist", + "scientist", + "veterinarian", + "firefighter", + "baker", + "psychiatrist", + "prosecutor", + "director", + "technician", + ] + + editor.add_lexicon("profession", profession, overwrite=True) def spacy_wrap(fn: Callable, language: str = "en_core_web_sm", **kwargs) -> Callable: From b7a62fa1850c9e88f5d651626122845f4113a95e Mon Sep 17 00:00:00 2001 From: ArjunSubramonian <arjun.subramonian@gmail.com> Date: Wed, 2 Jun 2021 16:14:01 -0700 Subject: [PATCH 50/63] Contextualized bias mitigation (#5176) * added linear and hard debiasers * worked on documentation * committing changes before branch switch * committing changes before switching branch * finished bias direction, linear and hard debiasers, need to write tests * finished bias direction test * Commiting changes before switching branch * finished hard and linear debiasers * finished OSCaR * bias mitigators tests and bias metrics remaining * added bias mitigator tests * added bias mitigator tests * finished tests for bias mitigation methods * fixed gpu issues * fixed gpu issues * fixed gpu issues * resolve issue with count_nonzero not being differentiable * added more references * fairness during finetuning * finished bias mitigator wrapper * added reference * updated CHANGELOG and fixed minor docs issues * move id tensors to embedding device * fixed to use predetermined bias direction * fixed minor doc errors * snli reader registration issue * fixed _pretrained from params issue * fixed device issues * evaluate bias mitigation initial commit * finished evaluate bias mitigation * handles multiline prediction files * fixed minor bugs * fixed minor bugs * improved prediction diff JSON format * forgot to resolve a conflict * Refactored evaluate bias mitigation to use NLI metric * Added SNLIPredictionsDiff class * ensured dataloader is same for bias mitigated and baseline models * finished evaluate bias mitigation * Update CHANGELOG.md * Replaced local data files with github raw content links * Update allennlp/fairness/bias_mitigator_applicator.py Co-authored-by: Pete <petew@allenai.org> * deleted evaluate_bias_mitigation from git tracking * removed evaluate-bias-mitigation instances from rest of repo * addressed Akshita's comments * moved bias mitigator applicator test to allennlp-models * removed unnecessary files Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-108.us-west-2.compute.internal> Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-1-108.us-west-2.compute.internal> Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com> Co-authored-by: Pete <petew@allenai.org> --- CHANGELOG.md | 7 +- allennlp/fairness/__init__.py | 17 +- allennlp/fairness/bias_direction_wrappers.py | 269 +++ allennlp/fairness/bias_metrics.py | 2 + .../fairness/bias_mitigator_applicator.py | 114 ++ allennlp/fairness/bias_mitigator_wrappers.py | 266 +++ allennlp/fairness/bias_mitigators.py | 1 + allennlp/fairness/bias_utils.py | 111 ++ .../fairness/definitional_pairs.json | 42 + test_fixtures/fairness/equalize_pairs.json | 210 +++ .../fairness/gender_specific_full.json | 1443 +++++++++++++++++ tests/fairness/bias_utils_test.py | 79 + 12 files changed, 2557 insertions(+), 4 deletions(-) create mode 100644 allennlp/fairness/bias_direction_wrappers.py create mode 100644 allennlp/fairness/bias_mitigator_applicator.py create mode 100644 allennlp/fairness/bias_mitigator_wrappers.py create mode 100644 allennlp/fairness/bias_utils.py create mode 100644 test_fixtures/fairness/definitional_pairs.json create mode 100644 test_fixtures/fairness/equalize_pairs.json create mode 100644 test_fixtures/fairness/gender_specific_full.json create mode 100644 tests/fairness/bias_utils_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aba4a6d638..8778f696a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module. +- Added `BiasMitigatorApplicator`, which wraps any Model and mitigates biases by finetuning +on a downstream task. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. - Meta data defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()`. @@ -54,7 +56,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `wandb` callback to work in distributed training. - Fixed `tqdm` logging into multiple files with `allennlp-optuna`. - ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22 ### Added @@ -80,8 +81,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`. - Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars. - The test for distributed metrics now takes a parameter specifying how often you want to run it. -- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, and `Sufficiency`. -- Added three bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`. +- Created the fairness module and added three fairness metrics: `Independence`, `Separation`, and `Sufficiency`. +- Added four bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`. - Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`). ### Changed diff --git a/allennlp/fairness/__init__.py b/allennlp/fairness/__init__.py index 976ada2d076..02a02506eb1 100644 --- a/allennlp/fairness/__init__.py +++ b/allennlp/fairness/__init__.py @@ -3,7 +3,8 @@ 1. measure the fairness of models according to multiple definitions of fairness 2. measure bias amplification -3. debias embeddings during training time and post-processing +3. mitigate bias in static and contextualized embeddings during training time and +post-processing """ from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency @@ -25,3 +26,17 @@ INLPBiasMitigator, OSCaRBiasMitigator, ) +from allennlp.fairness.bias_utils import load_words, load_word_pairs +from allennlp.fairness.bias_mitigator_applicator import BiasMitigatorApplicator +from allennlp.fairness.bias_mitigator_wrappers import ( + HardBiasMitigatorWrapper, + LinearBiasMitigatorWrapper, + INLPBiasMitigatorWrapper, + OSCaRBiasMitigatorWrapper, +) +from allennlp.fairness.bias_direction_wrappers import ( + PCABiasDirectionWrapper, + PairedPCABiasDirectionWrapper, + TwoMeansBiasDirectionWrapper, + ClassificationNormalBiasDirectionWrapper, +) diff --git a/allennlp/fairness/bias_direction_wrappers.py b/allennlp/fairness/bias_direction_wrappers.py new file mode 100644 index 00000000000..94cb4abe8ca --- /dev/null +++ b/allennlp/fairness/bias_direction_wrappers.py @@ -0,0 +1,269 @@ +import torch +from typing import Union, Optional +from os import PathLike + +from allennlp.fairness.bias_direction import ( + BiasDirection, + PCABiasDirection, + PairedPCABiasDirection, + TwoMeansBiasDirection, + ClassificationNormalBiasDirection, +) +from allennlp.fairness.bias_utils import load_word_pairs, load_words + +from allennlp.common import Registrable +from allennlp.data.tokenizers.tokenizer import Tokenizer +from allennlp.data import Vocabulary + + +class BiasDirectionWrapper(Registrable): + """ + Parent class for bias direction wrappers. + """ + + def __init__(self): + self.direction: BiasDirection = None + self.noise: float = None + + def __call__(self, module): + raise NotImplementedError + + def train(self, mode: bool = True): + """ + + # Parameters + + mode : `bool`, optional (default=`True`) + Sets `requires_grad` to value of `mode` for bias direction. + """ + self.direction.requires_grad = mode + + def add_noise(self, t: torch.Tensor): + """ + + # Parameters + + t : `torch.Tensor` + Tensor to which to add small amount of Gaussian noise. + """ + return t + self.noise * torch.randn(t.size(), device=t.device) + + +@BiasDirectionWrapper.register("pca") +class PCABiasDirectionWrapper(BiasDirectionWrapper): + """ + + # Parameters + + seed_words_file : `Union[PathLike, str]` + Path of file containing seed words. + tokenizer : `Tokenizer` + Tokenizer used to tokenize seed words. + direction_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of direction_vocab to use when tokenizing. + Disregarded when direction_vocab is `None`. + requires_grad : `bool`, optional (default=`False`) + Option to enable gradient calculation for bias direction. + noise : `float`, optional (default=`1e-10`) + To avoid numerical instability if embeddings are initialized uniformly. + """ + + def __init__( + self, + seed_words_file: Union[PathLike, str], + tokenizer: Tokenizer, + direction_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + requires_grad: bool = False, + noise: float = 1e-10, + ): + self.ids = load_words(seed_words_file, tokenizer, direction_vocab, namespace) + self.direction = PCABiasDirection(requires_grad=requires_grad) + self.noise = noise + + def __call__(self, module): + # embed subword token IDs and mean pool to get + # embedding of original word + ids_embeddings = [] + for i in self.ids: + i = i.to(module.weight.device) + ids_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids_embeddings = torch.cat(ids_embeddings) + + # adding trivial amount of noise + # to eliminate linear dependence amongst all embeddings + # when training first starts + ids_embeddings = self.add_noise(ids_embeddings) + + return self.direction(ids_embeddings) + + +@BiasDirectionWrapper.register("paired_pca") +class PairedPCABiasDirectionWrapper(BiasDirectionWrapper): + """ + + # Parameters + + seed_word_pairs_file : `Union[PathLike, str]` + Path of file containing seed word pairs. + tokenizer : `Tokenizer` + Tokenizer used to tokenize seed words. + direction_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of direction_vocab to use when tokenizing. + Disregarded when direction_vocab is `None`. + requires_grad : `bool`, optional (default=`False`) + Option to enable gradient calculation for bias direction. + noise : `float`, optional (default=`1e-10`) + To avoid numerical instability if embeddings are initialized uniformly. + """ + + def __init__( + self, + seed_word_pairs_file: Union[PathLike, str], + tokenizer: Tokenizer, + direction_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + requires_grad: bool = False, + noise: float = 1e-10, + ): + self.ids1, self.ids2 = load_word_pairs( + seed_word_pairs_file, tokenizer, direction_vocab, namespace + ) + self.direction = PairedPCABiasDirection(requires_grad=requires_grad) + self.noise = noise + + def __call__(self, module): + # embed subword token IDs and mean pool to get + # embedding of original word + ids1_embeddings = [] + for i in self.ids1: + i = i.to(module.weight.device) + ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids2_embeddings = [] + for i in self.ids2: + i = i.to(module.weight.device) + ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids1_embeddings = torch.cat(ids1_embeddings) + ids2_embeddings = torch.cat(ids2_embeddings) + + ids1_embeddings = self.add_noise(ids1_embeddings) + ids2_embeddings = self.add_noise(ids2_embeddings) + + return self.direction(ids1_embeddings, ids2_embeddings) + + +@BiasDirectionWrapper.register("two_means") +class TwoMeansBiasDirectionWrapper(BiasDirectionWrapper): + """ + + # Parameters + + seed_word_pairs_file : `Union[PathLike, str]` + Path of file containing seed word pairs. + tokenizer : `Tokenizer` + Tokenizer used to tokenize seed words. + direction_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of direction_vocab to use when tokenizing. + Disregarded when direction_vocab is `None`. + requires_grad : `bool`, optional (default=`False`) + Option to enable gradient calculation for bias direction. + noise : `float`, optional (default=`1e-10`) + To avoid numerical instability if embeddings are initialized uniformly. + """ + + def __init__( + self, + seed_word_pairs_file: Union[PathLike, str], + tokenizer: Tokenizer, + direction_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + requires_grad: bool = False, + noise: float = 1e-10, + ): + self.ids1, self.ids2 = load_word_pairs( + seed_word_pairs_file, tokenizer, direction_vocab, namespace + ) + self.direction = TwoMeansBiasDirection(requires_grad=requires_grad) + self.noise = noise + + def __call__(self, module): + # embed subword token IDs and mean pool to get + # embedding of original word + ids1_embeddings = [] + for i in self.ids1: + i = i.to(module.weight.device) + ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids2_embeddings = [] + for i in self.ids2: + i = i.to(module.weight.device) + ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids1_embeddings = torch.cat(ids1_embeddings) + ids2_embeddings = torch.cat(ids2_embeddings) + + ids1_embeddings = self.add_noise(ids1_embeddings) + ids2_embeddings = self.add_noise(ids2_embeddings) + + return self.direction(ids1_embeddings, ids2_embeddings) + + +@BiasDirectionWrapper.register("classification_normal") +class ClassificationNormalBiasDirectionWrapper(BiasDirectionWrapper): + """ + + # Parameters + + seed_word_pairs_file : `Union[PathLike, str]` + Path of file containing seed word pairs. + tokenizer : `Tokenizer` + Tokenizer used to tokenize seed words. + direction_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of direction_vocab to use when tokenizing. + Disregarded when direction_vocab is `None`. + noise : `float`, optional (default=`1e-10`) + To avoid numerical instability if embeddings are initialized uniformly. + """ + + def __init__( + self, + seed_word_pairs_file: Union[PathLike, str], + tokenizer: Tokenizer, + direction_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + noise: float = 1e-10, + ): + self.ids1, self.ids2 = load_word_pairs( + seed_word_pairs_file, tokenizer, direction_vocab, namespace + ) + self.direction = ClassificationNormalBiasDirection() + self.noise = noise + + def __call__(self, module): + # embed subword token IDs and mean pool to get + # embedding of original word + ids1_embeddings = [] + for i in self.ids1: + i = i.to(module.weight.device) + ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids2_embeddings = [] + for i in self.ids2: + i = i.to(module.weight.device) + ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids1_embeddings = torch.cat(ids1_embeddings) + ids2_embeddings = torch.cat(ids2_embeddings) + + ids1_embeddings = self.add_noise(ids1_embeddings) + ids2_embeddings = self.add_noise(ids2_embeddings) + + return self.direction(ids1_embeddings, ids2_embeddings) diff --git a/allennlp/fairness/bias_metrics.py b/allennlp/fairness/bias_metrics.py index e7be2763c1c..3c38e35dc08 100644 --- a/allennlp/fairness/bias_metrics.py +++ b/allennlp/fairness/bias_metrics.py @@ -258,6 +258,8 @@ class NaturalLanguageInference(Metric): 3. Threshold:tau (T:tau): A parameterized measure that reports the fraction of examples whose probability of neutral is above tau. + # Parameters + neutral_label : `int`, optional (default=`2`) The discrete integer label corresponding to a neutral entailment prediction. taus : `List[float]`, optional (default=`[0.5, 0.7]`) diff --git a/allennlp/fairness/bias_mitigator_applicator.py b/allennlp/fairness/bias_mitigator_applicator.py new file mode 100644 index 00000000000..add604473f0 --- /dev/null +++ b/allennlp/fairness/bias_mitigator_applicator.py @@ -0,0 +1,114 @@ +""" +A Model wrapper to mitigate biases in +contextual embeddings during finetuning +on a downstream task and test time. + +Based on: Dev, S., Li, T., Phillips, J.M., & Srikumar, V. (2020). +[On Measuring and Mitigating Biased Inferences of Word Embeddings] +(https://api.semanticscholar.org/CorpusID:201670701). +ArXiv, abs/1908.09369. +""" + +from overrides import overrides + +from allennlp.fairness.bias_mitigator_wrappers import BiasMitigatorWrapper + +from allennlp.common.lazy import Lazy +from allennlp.data import Vocabulary +from allennlp.models import Model +from allennlp.nn.util import find_embedding_layer + + +@Model.register("bias_mitigator_applicator") +class BiasMitigatorApplicator(Model): + """ + Wrapper class to apply bias mitigation to any pretrained Model. + + # Parameters + + vocab : `Vocabulary` + Vocabulary of base model. + base_model : `Model` + Base model for which to mitigate biases. + bias_mitigator : `Lazy[BiasMitigatorWrapper]` + Bias mitigator to apply to base model. + """ + + def __init__( + self, + vocab: Vocabulary, + base_model: Model, + bias_mitigator: Lazy[BiasMitigatorWrapper], + **kwargs + ): + super().__init__(vocab, **kwargs) + + self.base_model = base_model + # want to keep bias mitigation hook during test time + embedding_layer = find_embedding_layer(self.base_model) + + self.bias_mitigator = bias_mitigator.construct(embedding_layer=embedding_layer) + embedding_layer.register_forward_hook(self.bias_mitigator) + + self.vocab = self.base_model.vocab + self._regularizer = self.base_model._regularizer + + @overrides + def train(self, mode: bool = True): + super().train(mode) + self.base_model.train(mode) + # appropriately change requires_grad + # in bias mitigator and bias direction + # when train() and eval() are called + self.bias_mitigator.train(mode) + + # Delegate Model function calls to base_model + # Currently doing this manually because difficult to + # dynamically forward __getattribute__ due to + # behind-the-scenes usage of dunder attributes by torch.nn.Module + # and both BiasMitigatorWrapper and base_model inheriting from Model + # Assumes Model is relatively stable + # TODO: adapt BiasMitigatorWrapper to changes in Model + @overrides + def forward(self, *args, **kwargs): + return self.base_model.forward(*args, **kwargs) + + @overrides + def forward_on_instance(self, *args, **kwargs): + return self.base_model.forward_on_instance(*args, **kwargs) + + @overrides + def forward_on_instances(self, *args, **kwargs): + return self.base_model.forward_on_instances(*args, **kwargs) + + @overrides + def get_regularization_penalty(self, *args, **kwargs): + return self.base_model.get_regularization_penalty(*args, **kwargs) + + @overrides + def get_parameters_for_histogram_logging(self, *args, **kwargs): + return self.base_model.get_parameters_for_histogram_logging(*args, **kwargs) + + @overrides + def get_parameters_for_histogram_tensorboard_logging(self, *args, **kwargs): + return self.base_model.get_parameters_for_histogram_tensorboard_logging(*args, **kwargs) + + @overrides + def make_output_human_readable(self, *args, **kwargs): + return self.base_model.make_output_human_readable(*args, **kwargs) + + @overrides + def get_metrics(self, *args, **kwargs): + return self.base_model.get_metrics(*args, **kwargs) + + @overrides + def _get_prediction_device(self, *args, **kwargs): + return self.base_model._get_prediction_device(*args, **kwargs) + + @overrides + def _maybe_warn_for_unseparable_batches(self, *args, **kwargs): + return self.base_model._maybe_warn_for_unseparable_batches(*args, **kwargs) + + @overrides + def extend_embedder_vocab(self, *args, **kwargs): + return self.base_model.extend_embedder_vocab(*args, **kwargs) diff --git a/allennlp/fairness/bias_mitigator_wrappers.py b/allennlp/fairness/bias_mitigator_wrappers.py new file mode 100644 index 00000000000..6351a6cceac --- /dev/null +++ b/allennlp/fairness/bias_mitigator_wrappers.py @@ -0,0 +1,266 @@ +import torch +from typing import Union, Optional +from os import PathLike + +from allennlp.fairness.bias_mitigators import ( + HardBiasMitigator, + LinearBiasMitigator, + INLPBiasMitigator, + OSCaRBiasMitigator, +) +from allennlp.fairness.bias_direction_wrappers import BiasDirectionWrapper +from allennlp.fairness.bias_utils import load_word_pairs + +from allennlp.common import Registrable +from allennlp.data.tokenizers.tokenizer import Tokenizer +from allennlp.data import Vocabulary + + +class BiasMitigatorWrapper(Registrable): + """ + Parent class for bias mitigator wrappers. + """ + + def train(self, mode: bool = True): + """ + + # Parameters + + mode : `bool`, optional (default=`True`) + Sets `requires_grad` to value of `mode` for bias mitigator + and associated bias direction. + """ + raise NotImplementedError + + +# TODO: remove equalize words from evaluation words +@BiasMitigatorWrapper.register("hard") +class HardBiasMitigatorWrapper(BiasMitigatorWrapper): + """ + + # Parameters + + bias_direction : `BiasDirectionWrapper` + Bias direction used by mitigator. + embedding_layer : `torch.nn.Embedding` + Embedding layer of base model. + equalize_word_pairs_file : `Union[PathLike, str]` + Path of file containing equalize word pairs. + tokenizer : `Tokenizer` + Tokenizer used to tokenize equalize words. + mitigator_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of mitigator_vocab to use when tokenizing. + Disregarded when mitigator_vocab is `None`. + requires_grad : `bool`, optional (default=`True`) + Option to enable gradient calculation for bias mitigator. + """ + + def __init__( + self, + bias_direction: BiasDirectionWrapper, + embedding_layer: torch.nn.Embedding, + equalize_word_pairs_file: Union[PathLike, str], + tokenizer: Tokenizer, + mitigator_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + requires_grad: bool = True, + ): + # use predetermined bias direction + self.bias_direction = bias_direction + self.predetermined_bias_direction = self.bias_direction(embedding_layer) + self.ids1, self.ids2 = load_word_pairs( + equalize_word_pairs_file, tokenizer, mitigator_vocab, namespace + ) + self.mitigator = HardBiasMitigator(requires_grad=requires_grad) + + def __call__(self, module, module_in, module_out): + """ + Called as forward hook. + """ + # embed subword token IDs and mean pool to get + # embedding of original word + ids1_embeddings = [] + for i in self.ids1: + i = i.to(module.weight.device) + ids1_embeddings.append( + torch.mean(module.forward(i), dim=0, keepdim=True) + ) # forward() does not trigger hooks, thereby avoiding infinite recursion + ids2_embeddings = [] + for i in self.ids2: + i = i.to(module.weight.device) + ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids1_embeddings = torch.cat(ids1_embeddings) + ids2_embeddings = torch.cat(ids2_embeddings) + + module_out_size = module_out.size() + # flatten tensor except for last dimension + module_out = module_out.flatten(end_dim=-2) + # only return bias-mitigated evaluation embeddings + module_out = self.mitigator( + module_out, + self.predetermined_bias_direction.to(module_out.device), + ids1_embeddings.to(module_out.device), + ids2_embeddings.to(module_out.device), + )[: module_out.size(0)] + return module_out.reshape(module_out_size) + + def train(self, mode: bool = True): + self.mitigator.requires_grad = mode + self.bias_direction.train(mode) + + +@BiasMitigatorWrapper.register("linear") +class LinearBiasMitigatorWrapper(BiasMitigatorWrapper): + """ + + # Parameters + + bias_direction : `BiasDirectionWrapper` + Bias direction used by mitigator. + embedding_layer : `torch.nn.Embedding` + Embedding layer of base model. + requires_grad : `bool`, optional (default=`True`) + Option to enable gradient calculation for bias mitigator. + """ + + def __init__( + self, + bias_direction: BiasDirectionWrapper, + embedding_layer: torch.nn.Embedding, + requires_grad: bool = True, + ): + # use predetermined bias direction + self.bias_direction = bias_direction + self.predetermined_bias_direction = self.bias_direction(embedding_layer) + self.mitigator = LinearBiasMitigator(requires_grad=requires_grad) + + def __call__(self, module, module_in, module_out): + """ + Called as forward hook. + """ + module_out_size = module_out.size() + # flatten tensor except for last dimension + module_out = module_out.flatten(end_dim=-2) + module_out = self.mitigator( + module_out, self.predetermined_bias_direction.to(module_out.device) + ) + return module_out.reshape(module_out_size) + + def train(self, mode: bool = True): + self.mitigator.requires_grad = mode + self.bias_direction.train(mode) + + +@BiasMitigatorWrapper.register("inlp") +class INLPBiasMitigatorWrapper(BiasMitigatorWrapper): + """ + + # Parameters + + embedding_layer : `torch.nn.Embedding` + Embedding layer of base model. + seed_word_pairs_file : `Union[PathLike, str]` + Path of file containing seed word pairs. + tokenizer : `Tokenizer` + Tokenizer used to tokenize seed words. + mitigator_vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str`, optional (default=`"tokens"`) + Namespace of mitigator_vocab to use when tokenizing. + Disregarded when mitigator_vocab is `None`. + """ + + def __init__( + self, + embedding_layer: torch.nn.Embedding, + seed_word_pairs_file: Union[PathLike, str], + tokenizer: Tokenizer, + mitigator_vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + ): + self.ids1, self.ids2 = load_word_pairs( + seed_word_pairs_file, tokenizer, mitigator_vocab, namespace + ) + self.mitigator = INLPBiasMitigator() + + def __call__(self, module, module_in, module_out): + """ + Called as forward hook. + """ + # embed subword token IDs and mean pool to get + # embedding of original word + ids1_embeddings = [] + for i in self.ids1: + i = i.to(module.weight.device) + ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids2_embeddings = [] + for i in self.ids2: + i = i.to(module.weight.device) + ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) + ids1_embeddings = torch.cat(ids1_embeddings) + ids2_embeddings = torch.cat(ids2_embeddings) + + module_out_size = module_out.size() + # flatten tensor except for last dimension + module_out = module_out.flatten(end_dim=-2) + module_out = self.mitigator( + module_out, ids1_embeddings.to(module_out.device), ids2_embeddings.to(module_out.device) + ) + return module_out.reshape(module_out_size) + + def train(self, mode: bool = True): + pass + + +@BiasMitigatorWrapper.register("oscar") +class OSCaRBiasMitigatorWrapper(BiasMitigatorWrapper): + """ + + # Parameters + + bias_direction1 : `BiasDirectionWrapper` + Bias direction of first concept subspace used by mitigator. + bias_direction2 : `BiasDirectionWrapper` + Bias direction of second concept subspace used by mitigator. + embedding_layer : `torch.nn.Embedding` + Embedding layer of base model. + requires_grad : `bool`, optional (default=`True`) + Option to enable gradient calculation for bias mitigator. + """ + + def __init__( + self, + bias_direction1: BiasDirectionWrapper, + bias_direction2: BiasDirectionWrapper, + embedding_layer: torch.nn.Embedding, + requires_grad: bool = True, + ): + # use predetermined bias directions + self.bias_direction1 = bias_direction1 + self.predetermined_bias_direction1 = self.bias_direction1(embedding_layer) + self.bias_direction2 = bias_direction2(embedding_layer) + self.predetermined_bias_direction2 = self.bias_direction2(embedding_layer) + self.mitigator = OSCaRBiasMitigator(requires_grad=requires_grad) + + def __call__(self, module, module_in, module_out): + """ + Called as forward hook. + """ + module_out_size = module_out.size() + # flatten tensor except for last dimension + module_out = module_out.flatten(end_dim=-2) + module_out = self.mitigator( + module_out, + self.predetermined_bias_direction1.to(module_out.device), + self.predetermined_bias_direction2.to(module_out.device), + ) + return module_out.reshape(module_out_size) + + def train(self, mode: bool = True): + self.mitigator.requires_grad = mode + self.bias_direction1.train(mode) + self.bias_direction2.train(mode) diff --git a/allennlp/fairness/bias_mitigators.py b/allennlp/fairness/bias_mitigators.py index 113a6472b9b..d3c0f089733 100644 --- a/allennlp/fairness/bias_mitigators.py +++ b/allennlp/fairness/bias_mitigators.py @@ -7,6 +7,7 @@ import numpy as np import scipy import sklearn + from allennlp.common.checks import ConfigurationError diff --git a/allennlp/fairness/bias_utils.py b/allennlp/fairness/bias_utils.py new file mode 100644 index 00000000000..c4bbb33479e --- /dev/null +++ b/allennlp/fairness/bias_utils.py @@ -0,0 +1,111 @@ +import torch +import json +from os import PathLike +from typing import List, Tuple, Union, Optional + +from allennlp.common.file_utils import cached_path +from allennlp.data import Vocabulary +from allennlp.data.tokenizers.tokenizer import Tokenizer + + +def _convert_word_to_ids_tensor(word, tokenizer, vocab, namespace, all_cases): + # function does NOT strip special tokens if tokenizer adds them + if all_cases: + words_list = [word.lower(), word.title(), word.upper()] + else: + words_list = [word] + ids = [] + for w in words_list: + # if vocab is None, use tokenizer vocab (only works for Huggingface PreTrainedTokenizer) + if vocab: + tokens = tokenizer.tokenize(w) + ids.append(torch.tensor([vocab.get_token_index(t.text, namespace) for t in tokens])) + else: + ids.append(torch.tensor(tokenizer.tokenizer(w)["input_ids"])) + return ids + + +def load_words( + fname: Union[str, PathLike], + tokenizer: Tokenizer, + vocab: Optional[Vocabulary] = None, + namespace: str = "tokens", + all_cases: bool = True, +) -> List[torch.Tensor]: + """ + This function loads a list of words from a file, + tokenizes each word into subword tokens, and converts the + tokens into IDs. + + # Parameters + + fname : `Union[str, PathLike]` + Name of file containing list of words to load. + tokenizer : `Tokenizer` + Tokenizer to tokenize words in file. + vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str` + Namespace of vocab to use when tokenizing. + all_cases : `bool`, optional (default=`True`) + Whether to tokenize lower, title, and upper cases of each word. + + # Returns + + word_ids : `List[torch.Tensor]` + List of tensors containing the IDs of subword tokens for + each word in the file. + """ + word_ids = [] + with open(cached_path(fname)) as f: + words = json.load(f) + for w in words: + word_ids.extend(_convert_word_to_ids_tensor(w, tokenizer, vocab, namespace, all_cases)) + return word_ids + + +def load_word_pairs( + fname: Union[str, PathLike], + tokenizer: Tokenizer, + vocab: Optional[Vocabulary] = None, + namespace: str = "token", + all_cases: bool = True, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + This function loads a list of pairs of words from a file, + tokenizes each word into subword tokens, and converts the + tokens into IDs. + + # Parameters + + fname : `Union[str, PathLike]` + Name of file containing list of pairs of words to load. + tokenizer : `Tokenizer` + Tokenizer to tokenize words in file. + vocab : `Vocabulary`, optional (default=`None`) + Vocabulary of tokenizer. If `None`, assumes tokenizer is of + type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. + namespace : `str` + Namespace of vocab to use when tokenizing. + all_cases : `bool`, optional (default=`True`) + Whether to tokenize lower, title, and upper cases of each word. + + # Returns + + word_ids : `Tuple[List[torch.Tensor], List[torch.Tensor]]` + Pair of lists of tensors containing the IDs of subword tokens for + words in the file. + """ + word_ids1 = [] + word_ids2 = [] + with open(cached_path(fname)) as f: + words = json.load(f) + for w1, w2 in words: + word_ids1.extend( + _convert_word_to_ids_tensor(w1, tokenizer, vocab, namespace, all_cases) + ) + word_ids2.extend( + _convert_word_to_ids_tensor(w2, tokenizer, vocab, namespace, all_cases) + ) + return word_ids1, word_ids2 diff --git a/test_fixtures/fairness/definitional_pairs.json b/test_fixtures/fairness/definitional_pairs.json new file mode 100644 index 00000000000..37ae95e9876 --- /dev/null +++ b/test_fixtures/fairness/definitional_pairs.json @@ -0,0 +1,42 @@ +[ + [ + "woman", + "man" + ], + [ + "girl", + "boy" + ], + [ + "she", + "he" + ], + [ + "mother", + "father" + ], + [ + "daughter", + "son" + ], + [ + "gal", + "guy" + ], + [ + "female", + "male" + ], + [ + "her", + "his" + ], + [ + "herself", + "himself" + ], + [ + "Mary", + "John" + ] +] \ No newline at end of file diff --git a/test_fixtures/fairness/equalize_pairs.json b/test_fixtures/fairness/equalize_pairs.json new file mode 100644 index 00000000000..6fbdacefaed --- /dev/null +++ b/test_fixtures/fairness/equalize_pairs.json @@ -0,0 +1,210 @@ +[ + [ + "monastery", + "convent" + ], + [ + "spokesman", + "spokeswoman" + ], + [ + "Catholic_priest", + "nun" + ], + [ + "Dad", + "Mom" + ], + [ + "Men", + "Women" + ], + [ + "councilman", + "councilwoman" + ], + [ + "grandpa", + "grandma" + ], + [ + "grandsons", + "granddaughters" + ], + [ + "prostate_cancer", + "ovarian_cancer" + ], + [ + "testosterone", + "estrogen" + ], + [ + "uncle", + "aunt" + ], + [ + "wives", + "husbands" + ], + [ + "Father", + "Mother" + ], + [ + "Grandpa", + "Grandma" + ], + [ + "He", + "She" + ], + [ + "boy", + "girl" + ], + [ + "boys", + "girls" + ], + [ + "brother", + "sister" + ], + [ + "brothers", + "sisters" + ], + [ + "businessman", + "businesswoman" + ], + [ + "chairman", + "chairwoman" + ], + [ + "colt", + "filly" + ], + [ + "congressman", + "congresswoman" + ], + [ + "dad", + "mom" + ], + [ + "dads", + "moms" + ], + [ + "dudes", + "gals" + ], + [ + "ex_girlfriend", + "ex_boyfriend" + ], + [ + "father", + "mother" + ], + [ + "fatherhood", + "motherhood" + ], + [ + "fathers", + "mothers" + ], + [ + "fella", + "granny" + ], + [ + "fraternity", + "sorority" + ], + [ + "gelding", + "mare" + ], + [ + "gentleman", + "lady" + ], + [ + "gentlemen", + "ladies" + ], + [ + "grandfather", + "grandmother" + ], + [ + "grandson", + "granddaughter" + ], + [ + "he", + "she" + ], + [ + "himself", + "herself" + ], + [ + "his", + "her" + ], + [ + "king", + "queen" + ], + [ + "kings", + "queens" + ], + [ + "male", + "female" + ], + [ + "males", + "females" + ], + [ + "man", + "woman" + ], + [ + "men", + "women" + ], + [ + "nephew", + "niece" + ], + [ + "prince", + "princess" + ], + [ + "schoolboy", + "schoolgirl" + ], + [ + "son", + "daughter" + ], + [ + "sons", + "daughters" + ], + [ + "twin_brother", + "twin_sister" + ] +] \ No newline at end of file diff --git a/test_fixtures/fairness/gender_specific_full.json b/test_fixtures/fairness/gender_specific_full.json new file mode 100644 index 00000000000..a7f0c73ce68 --- /dev/null +++ b/test_fixtures/fairness/gender_specific_full.json @@ -0,0 +1,1443 @@ +[ + "he", + "his", + "He", + "her", + "she", + "him", + "She", + "man", + "women", + "men", + "His", + "woman", + "spokesman", + "wife", + "himself", + "son", + "mother", + "father", + "chairman", + "daughter", + "husband", + "guy", + "girls", + "girl", + "Her", + "boy", + "King", + "boys", + "brother", + "Chairman", + "spokeswoman", + "female", + "sister", + "Women", + "Man", + "male", + "herself", + "Lions", + "Lady", + "brothers", + "dad", + "actress", + "mom", + "sons", + "girlfriend", + "Kings", + "Men", + "daughters", + "Prince", + "Queen", + "teenager", + "lady", + "Bulls", + "boyfriend", + "sisters", + "Colts", + "mothers", + "Sir", + "king", + "businessman", + "Boys", + "grandmother", + "grandfather", + "deer", + "cousin", + "Woman", + "ladies", + "Girls", + "Father", + "uncle", + "PA", + "Boy", + "Councilman", + "mum", + "Brothers", + "MA", + "males", + "Girl", + "Mom", + "Guy", + "Queens", + "congressman", + "Dad", + "Mother", + "grandson", + "twins", + "bull", + "queen", + "businessmen", + "wives", + "widow", + "nephew", + "bride", + "females", + "aunt", + "Congressman", + "prostate_cancer", + "lesbian", + "chairwoman", + "fathers", + "Son", + "moms", + "Ladies", + "maiden", + "granddaughter", + "younger_brother", + "Princess", + "Guys", + "lads", + "Ma", + "Sons", + "lion", + "Bachelor", + "gentleman", + "fraternity", + "bachelor", + "niece", + "Lion", + "Sister", + "bulls", + "husbands", + "prince", + "colt", + "salesman", + "Bull", + "Sisters", + "hers", + "dude", + "Spokesman", + "beard", + "filly", + "Actress", + "Him", + "princess", + "Brother", + "lesbians", + "councilman", + "actresses", + "Viagra", + "gentlemen", + "stepfather", + "Deer", + "monks", + "Beard", + "Uncle", + "ex_girlfriend", + "lad", + "sperm", + "Daddy", + "testosterone", + "MAN", + "Female", + "nephews", + "maid", + "daddy", + "mare", + "fiance", + "Wife", + "fiancee", + "kings", + "dads", + "waitress", + "Male", + "maternal", + "heroine", + "feminist", + "Mama", + "nieces", + "girlfriends", + "Councilwoman", + "sir", + "stud", + "Mothers", + "mistress", + "lions", + "estranged_wife", + "womb", + "Brotherhood", + "Statesman", + "grandma", + "maternity", + "estrogen", + "ex_boyfriend", + "widows", + "gelding", + "diva", + "teenage_girls", + "nuns", + "Daughter", + "czar", + "ovarian_cancer", + "HE", + "Monk", + "countrymen", + "Grandma", + "teenage_girl", + "penis", + "bloke", + "nun", + "Husband", + "brides", + "housewife", + "spokesmen", + "suitors", + "menopause", + "monastery", + "patriarch", + "Beau", + "motherhood", + "brethren", + "stepmother", + "Dude", + "prostate", + "Moms", + "hostess", + "twin_brother", + "Colt", + "schoolboy", + "eldest", + "brotherhood", + "Godfather", + "fillies", + "stepson", + "congresswoman", + "Chairwoman", + "Daughters", + "uncles", + "witch", + "Mommy", + "monk", + "viagra", + "paternity", + "suitor", + "chick", + "Pa", + "fianc\u00e9", + "sorority", + "macho", + "Spokeswoman", + "businesswoman", + "eldest_son", + "gal", + "statesman", + "schoolgirl", + "fathered", + "goddess", + "hubby", + "mares", + "stepdaughter", + "blokes", + "dudes", + "socialite", + "strongman", + "Witch", + "fianc\u00e9e", + "uterus", + "grandsons", + "Bride", + "studs", + "mama", + "Aunt", + "godfather", + "hens", + "hen", + "mommy", + "Babe", + "estranged_husband", + "Fathers", + "elder_brother", + "boyhood", + "baritone", + "Diva", + "Lesbian", + "grandmothers", + "grandpa", + "boyfriends", + "feminism", + "countryman", + "stallion", + "heiress", + "queens", + "Grandpa", + "witches", + "aunts", + "semen", + "fella", + "granddaughters", + "chap", + "knight", + "widower", + "Maiden", + "salesmen", + "convent", + "KING", + "vagina", + "beau", + "babe", + "HIS", + "beards", + "handyman", + "twin_sister", + "maids", + "gals", + "housewives", + "Gentlemen", + "horsemen", + "Businessman", + "obstetrics", + "fatherhood", + "beauty_queen", + "councilwoman", + "princes", + "matriarch", + "colts", + "manly", + "ma", + "fraternities", + "Spokesmen", + "pa", + "fellas", + "Gentleman", + "councilmen", + "dowry", + "barbershop", + "Monks", + "WOMAN", + "fraternal", + "ballerina", + "manhood", + "Dads", + "heroines", + "granny", + "gynecologist", + "princesses", + "Goddess", + "yo", + "Granny", + "knights", + "eldest_daughter", + "HER", + "underage_girls", + "masculinity", + "Girlfriend", + "bro", + "Grandmother", + "grandfathers", + "crown_prince", + "Restless", + "paternal", + "Queen_Mother", + "Boyfriend", + "womens", + "Males", + "SHE", + "Countess", + "stepchildren", + "Belles", + "bachelors", + "matron", + "momma", + "Legs", + "maidens", + "goddesses", + "landlady", + "sisterhood", + "Grandfather", + "Fraternity", + "Majesty", + "Babes", + "lass", + "maternal_grandmother", + "blondes", + "ma'am", + "Womens", + "divorcee", + "Momma", + "fathering", + "Effie", + "Lad", + "womanhood", + "missus", + "Sisterhood", + "granddad", + "Mens", + "papa", + "gf", + "sis", + "Husbands", + "Hen", + "womanizer", + "gynecological", + "stepsister", + "Handsome", + "Prince_Charming", + "BOY", + "stepdad", + "teen_ager", + "GIRL", + "dame", + "Sorority", + "beauty_pageants", + "raspy", + "harem", + "maternal_grandfather", + "Hes", + "deliveryman", + "septuagenarian", + "damsel", + "paternal_grandmother", + "paramour", + "paternal_grandparents", + "Nun", + "DAD", + "mothering", + "shes", + "HE_'S", + "Nuns", + "teenage_daughters", + "auntie", + "widowed_mother", + "Girlfriends", + "FATHER", + "virile", + "COUPLE", + "grandmas", + "Hubby", + "nan", + "vixen", + "Joan_Crawford", + "stepdaughters", + "endometrial_cancer", + "stepsons", + "loins", + "Grandson", + "Mitchells", + "erections", + "Matron", + "Fella", + "daddies", + "ter", + "Sweetie", + "Dudes", + "Princesses", + "Lads", + "lioness", + "Mamma", + "virility", + "bros", + "womenfolk", + "Heir", + "BROTHERS", + "manliness", + "patriarchs", + "earl", + "sisterly", + "Whore", + "Gynaecology", + "countess", + "convents", + "Oratory", + "witch_doctor", + "mamas", + "yah", + "aunty", + "aunties", + "Heiress", + "lasses", + "Breasts", + "fairer_sex", + "sorority_sisters", + "WIFE", + "Laurels", + "penile", + "nuh", + "mah", + "toms", + "mam", + "Granddad", + "premenopausal_women", + "Granddaddy", + "nana", + "coeds", + "dames", + "herdsman", + "Mammy", + "Fellas", + "Niece", + "menfolk", + "Grandad", + "bloods", + "Gramps", + "damsels", + "Granddaughter", + "mamma", + "concubine", + "Oros", + "Blarney", + "filial", + "broads", + "Ethel_Kennedy", + "ACTRESS", + "Tit", + "fianc", + "Hunk", + "Night_Shift", + "wifey", + "Lothario", + "Holy_Roman_Emperor", + "horse_breeder", + "grandnephew", + "Lewises", + "Muscular", + "feminist_movement", + "Sanan", + "women\u00e2_\u20ac_\u2122", + "Fiancee", + "dowries", + "Carmelite", + "rah", + "n_roller", + "bay_filly", + "belles", + "Uncles", + "PRINCESS", + "womans", + "Homeboy", + "Blokes", + "Charmer", + "codger", + "Delta_Zeta", + "courtesans", + "grandaughter", + "SISTER", + "Highness", + "grandbabies", + "crone", + "Skip_Away", + "noblewoman", + "bf", + "jane", + "philandering_husband", + "Sisqo", + "mammy", + "daugher", + "director_Skip_Bertman", + "DAUGHTER", + "Royal_Highness", + "mannish", + "spinsters", + "Missus", + "madame", + "Godfathers", + "saleswomen", + "beaus", + "Risha", + "luh", + "sah", + "negligee", + "Women\u00e2_\u20ac_\u2122", + "Hos", + "salesgirl", + "grandmom", + "Grandmas", + "Lawsons", + "countrywomen", + "Booby", + "darlin", + "Sheiks", + "boyz", + "wifes", + "Bayi", + "Il_Duce", + "\u00e2_\u20ac_\u0153My", + "fem", + "daugther", + "Potti", + "hussy", + "tch", + "Gelding", + "stemmed_roses", + "Damson", + "puh", + "Tylers", + "neice", + "Mutha", + "GRANDMOTHER", + "youse", + "spurned_lover", + "mae", + "Britt_Ekland", + "clotheshorse", + "Carlita_Kilpatrick", + "Cambest", + "Pretty_Polly", + "banshees", + "male_chauvinist", + "Arliss", + "mommas", + "maidservant", + "Gale_Harold", + "Little_Bo_Peep", + "Cleavers", + "hags", + "blowsy", + "Queen_Elizabeth_I.", + "lassies", + "papas", + "BABE", + "ugly_ducklings", + "Jims", + "hellion", + "Beautician", + "coalminer", + "relaxin", + "El_Mahroug", + "Victoria_Secret_Angel", + "shepherdess", + "Mosco", + "Slacks", + "nanna", + "wifely", + "tomboys", + "LAH", + "hast", + "apo", + "Kaplans", + "milkmaid", + "Robin_Munis", + "John_Barleycorn", + "royal_highness", + "Meanie", + "NAH", + "trollop", + "roh", + "Jewess", + "Sheik_Hamad", + "mumsy", + "Big_Pussy", + "chil_dren", + "Aunt_Bea", + "basso", + "sista", + "girlies", + "nun_Sister", + "chica", + "Bubbas", + "massa", + "Southern_belles", + "Nephews", + "castrations", + "Mister_Ed", + "Grandsons", + "Calaf", + "Malachy_McCourt", + "Shamash", + "hey_hey", + "Harmen", + "sonofabitch", + "Donovans", + "Grannie", + "Kalinka", + "hisself", + "Devean", + "goatherd", + "hinds", + "El_Corredor", + "Kens", + "notorious_womanizer", + "goh", + "Mommas", + "washerwoman", + "Samaira", + "Coo_Coo", + "Governess", + "grandsire", + "PRINCE_WILLIAM", + "gramma", + "him.He", + "Coptic_priest", + "Corbie", + "Kennys", + "thathe", + "Pa_Pa", + "Bristols", + "Hotep", + "snowy_haired", + "El_Prado_Ire", + "Girl_hitmaker", + "Hurleys", + "St._Meinrad", + "sexually_perverted", + "authoress", + "Prudie", + "raven_haired_beauty", + "Bonos", + "domestic_shorthair", + "brothas", + "nymphet", + "Neelma", + "Seita", + "stud_muffin", + "St._Judes", + "yenta", + "bare_shouldered", + "Pinkney_Sr.", + "PRINCE_CHARLES", + "Bisutti", + "sistas", + "Blanche_Devereaux", + "Momoa", + "Quiff", + "Scotswoman", + "balaclava_clad_men", + "Louis_Leakey", + "dearie", + "vacuum_cleaner_salesman", + "grandads", + "postulant", + "SARAH_JESSICA_PARKER", + "AUNT", + "Prince_Dauntless", + "Dalys", + "Darkie", + "Czar_Nicholas", + "Lion_Hearted", + "Boy_recliner", + "baby_mamas", + "giantess", + "Lawd", + "GRANNY", + "fianc_e", + "Bilqis", + "WCTU", + "famly", + "Ellas", + "feminazis", + "Pentheus", + "MAMAS", + "Town_Criers", + "Saggy", + "youngman", + "grandam", + "divorc\u00e9", + "bosomed", + "roon", + "Simmentals", + "eponymous_heroine", + "LEYLAND", + "REE'", + "cain't", + "Evelynn", + "WAH'", + "sistah", + "Horners", + "Elsie_Poncher", + "Coochie", + "rat_terriers", + "Limousins", + "Buchinski", + "Schicchi", + "Carpitcher", + "Khwezi", + "HAH'", + "Shazza", + "Mackeson", + "ROH'", + "kuya", + "novice_nun", + "Shei", + "Elmasri", + "ladykiller", + "6yo", + "Yenta", + "SHEL", + "pater", + "Souse", + "Tahirah", + "comedian_Rodney_Dangerfield", + "Shottle", + "carryin", + "Sath", + "fa'afafine", + "royal_consort", + "hus_band", + "maternal_uncles", + "dressing_provocatively", + "dreamgirl", + "millionaire_industrialist", + "Georgie_Girl", + "Must_Be_Obeyed", + "joh", + "Arabian_stallion", + "ahr", + "mso_para_margin_0in", + "SOO'", + "Biddles", + "Chincoteague_Volunteer_Fire", + "Lisa_Miceli", + "gorgeous_brunette", + "fianc\u017d", + "Moved_fluently", + "Afternoon_Deelites", + "biker_dude", + "Vito_Spatafore", + "MICK_JAGGER", + "Adesida", + "Reineman", + "witz", + "Djamila", + "Glenroe", + "daddys", + "Romanzi", + "gentlewomen", + "Dandie_Dinmont_terrier", + "Excess_Ire", + "By_SYVJ_Staff", + "zan", + "CONFESSIONS", + "Magees", + "wimmin", + "tash", + "Theatrical_Ire", + "Prince_Charmings", + "chocolate_eclair", + "bron", + "daughers", + "Felly", + "fiftyish", + "Spritely", + "GRANDPA", + "distaffer", + "Norbertines", + "DAH'", + "leader_Muammar_Gadaffi", + "swains", + "Prince_Tomohito", + "Honneur", + "Soeur", + "jouster", + "Pharaoh_Amenhotep_III", + "QUEEN_ELIZABETH_II", + "Ne'er", + "Galileo_Ire", + "Fools_Crow", + "Lannisters", + "Devines", + "gonzales", + "columnist_Ann_Landers", + "Moseleys", + "hiz", + "busch", + "roastee", + "toyboys", + "Sheffields", + "grandaunt", + "Galvins", + "Giongo", + "geh", + "flame_haired_actress", + "Grammarian", + "Greg_Evigan", + "frontierswoman", + "Debele", + "rabs", + "nymphets", + "aai", + "BREE", + "Shaqs", + "ZAY", + "pappa", + "Housa", + "refrigerator_repairman", + "artificial_inseminations", + "chickie", + "Rippa", + "teenager_Tracy_Turnblad", + "homebred_colt", + "Abigaille", + "hen_pecked_husband", + "businesman", + "her.She", + "Kaikeyi", + "Stittsworth", + "self_proclaimed_redneck", + "Khella", + "NeW", + "Evers_Swindell", + "Asmerom_Gebreselassie", + "Boy_recliners", + "Cliff_Claven", + "Legge_Bourke", + "Costos", + "d'_honneur", + "sistahs", + "Cabble", + "sahn", + "CROW_AGENCY_Mont", + "jezebel", + "Harrolds", + "ROSARIO_DAWSON", + "INXS_frontman_Michael_Hutchence", + "Gursikh", + "Dadas", + "VIAGA", + "keen_horsewoman", + "Theodoric", + "Eldery", + "lihn", + "Alice_Kramden", + "Santarina", + "radical_cleric_al_Sadr", + "Curleys", + "SY'", + "Fidaa", + "Saptapadi", + "Actor_Sean_Astin", + "Kellita_Smith", + "Doly", + "Libertina", + "Money_McBags", + "Chief_Bearhart", + "choirgirl", + "chestnut_stallion", + "VIGRA", + "BY_JIM_McCONNELL", + "Sal_Vitale", + "Trivia_buffs", + "kumaris", + "fraternal_lodge", + "galpals", + "Borino_Quinn", + "lina", + "LATEST_Rapper", + "Bezar", + "Manro", + "bakla", + "Grisetti", + "blond_bimbo", + "spinster_aunt", + "gurls", + "hiswife", + "paleface", + "Charlye", + "hippie_chicks", + "Khalifas", + "Picture_JUSTIN_SANSON", + "Hepburns", + "yez", + "ALDER", + "Sanussi", + "Lil_Sis", + "McLoughlins", + "Barbra_Jean", + "Lulua", + "thatshe", + "actress_Shohreh_Aghdashloo", + "SIR_ANTHONY_HOPKINS", + "Gloddy", + "ZAH'", + "ORANGE_'S", + "Danielle_Bimber", + "grandmum", + "Kulkis", + "Brazington", + "Marisa_Lenhard_CFA", + "SIR_JOHN", + "Clareman", + "Aqila", + "Heavily_tattooed", + "Libbys", + "thim", + "elocutionist", + "submissives", + "Inja", + "rahm", + "Agnes_Gooch", + "fake_tits", + "nancy_boys", + "Swaidan", + "SHAH'", + "ain'ta_bed", + "Shumail_Raj", + "Duchesse", + "diethylstilbestrol_DES", + "colt_foal", + "unfaithful_lover", + "Maseri", + "nevah", + "SAHN", + "Barths", + "Toughkenamon", + "GUEST_STARS", + "him.But", + "Donna_Claspell", + "gingham_dresses", + "Massage_Parlour", + "wae", + "Wasacz", + "Magistra", + "vihl", + "Smriti_Iraani", + "boyish_haircut", + "workingwoman", + "borthers", + "Capuchin_friars", + "Nejma", + "yes_sirs", + "bivocational_pastor", + "Grafters", + "HOPWOOD", + "Nicknamed_Godzilla", + "yos", + "Berkenfield", + "Missis", + "sitcom_Designing_Women", + "Kafoa", + "trainer_Emma_Lavelle", + "sadomasochistic_dungeon", + "iht", + "desperates", + "predessor", + "wolf_cub", + "indigenous_Peruvians", + "Livia_Soprano", + "troh", + "colt_sired", + "BOND_HILL", + "ihl", + "Drydens", + "rahs", + "Piserchia", + "Sonny_Corinthos", + "bankrobber", + "Fwank", + "feisty_redhead", + "booze_guzzling", + "COOPERS", + "actress_Q'orianka_Kilcher", + "Cortezar", + "twe", + "Jacoub", + "Cindy_Iannarelli", + "Hell_Raiser", + "Fondly_referred", + "Bridal_Shoppe", + "Noleta", + "Christinas", + "IAGRA", + "LaTanya_Richardson", + "Sang_Bender", + "Assasins", + "sorrel_gelding", + "septugenarian", + "Hissy", + "Muqtada_al_Sadr_mook", + "Pfeni", + "MADRID_AFX_Banco_Santander", + "tuchis", + "LeVaughn", + "Gadzicki", + "transvestite_hooker", + "Fame_jockey_Laffit", + "nun_Sister_Mary", + "SAMSONOV", + "Mayflower_Madam", + "Shaque", + "well.He", + "Trainer_Julio_Canani", + "sorrel_mare", + "minivehicle_joint_venture", + "wife_Dwina", + "Aasiya_AH'_see", + "Baratheon", + "Rick_O'Shay", + "Mammies", + "goatie", + "Nell_Gwynne", + "charmingly_awkward", + "Slamma", + "DEHL", + "Lorenzo_Borghese", + "ALMA_Wis.", + "Anne_Scurria", + "father_Peruvians_alternately", + "JULIE_ANDREWS", + "Slim_Pickins", + "Victoria_Secret_stunner", + "BY'", + "Sanam_Devdas", + "pronounced_luh", + "Pasha_Selim", + "\u4e2d\u534e", + "rson", + "maternal_grandmothers", + "IOWA_CITY_Ia", + "Madame_de_Tourvel", + "JAY'", + "Sheika_Mozah_bint_Nasser", + "Hotsy_Totsy", + "D'_Ginto", + "singer_Johnny_Paycheck", + "uterine_prolapse_surgery", + "SCOTTDALE_Pa.", + "AdelaideNow_reports", + "Marcus_Schenkenberg", + "Clyse", + "Obiter_Dicta", + "comic_Sam_Kinison", + "bitties", + "ROCKVILLE_Ind.", + "swimsuit_calendars", + "Decicio_Smith", + "Ma_ma", + "Rie_Miyazawa", + "celibate_chastity", + "gwah", + "ZAY'", + "HER_Majesty", + "Defrere", + "Las_Madrinas", + "\u7c3f_\u8042_\u7ffb", + "Bea_Hamill", + "ARCADIA_Calif._Trainer", + "Bold_Badgett", + "stakes_victress", + "Hoppin_Frog", + "Narumiya", + "Flayfil", + "hardman_Vinnie_Jones", + "Marilyn_Monroe_lookalike", + "Kivanc_Tatlitug", + "Persis_Khambatta", + "SINKING_SPRING_Pa.", + "len_3rd", + "DEAR_TRYING", + "Farndon_Cheshire", + "Krishna_Madiga", + "daughter_Princess_Chulabhorn", + "Marshall_Rooster_Cogburn", + "Kitty_Kiernan", + "Yokich", + "Jarou", + "Serdaris", + "ee_ay", + "Montifiore", + "Chuderewicz", + "Samuel_Le_Bihan", + "filly_Proud_Spell", + "Umm_Hiba", + "pronounced_koo", + "Sandy_Fonzo", + "KOR'", + "Fielder_Civil_kisses", + "Federalsburg_Maryland", + "Nikah_ceremony", + "Brinke_Stevens", + "Yakama_Tribal_Council", + "Capuchin_Father", + "wife_Callista_Bisek", + "Beau_Dare", + "Bedoni", + "Arjun_Punj", + "JOHNNY_KNOXVILLE", + "cap_tain", + "Alderwood_Boys", + "Chi_Eta_Phi", + "ringleader_Charles_Graner", + "Savoies", + "Lalla_Salma", + "Mrs._Potiphar", + "fahn", + "name_Taylor_Sumers", + "Vernita_Green", + "Bollywood_baddie", + "BENBROOK_Texas", + "Assemblyman_Lou_Papan", + "virgin_brides", + "Cho_Eun", + "CATHY_Freeman", + "Uncle_Saul", + "Lao_Brewery", + "Ibo_tribe", + "ruf", + "rival_Edurne_Pasaban", + "Hei_Shangri_La", + "Mommy_dearest", + "interest_Angola_Sonogal", + "Ger_Monsun", + "PUSSYCAT_DOLL", + "Crown_Jewels_Condoms", + "Lord_Marke", + "Patootie", + "Nora_Bey", + "huntin_shootin", + "Minister_Raymond_Tshibanda", + "La_Nina_la_NEEN", + "signature_Whoppers", + "estranged_hubby_Kevin_Federline", + "UR'", + "pill_poppin", + "GEHR'", + "purebred_Arabians", + "husbandly_duties", + "VIAGRA_TIMING", + "Hereford_heifer", + "hushed_monotone_voice", + "Pola_Uddin", + "Wee_Jimmy_Krankie", + "Kwakwanso", + "Our_Galvinator", + "shoh", + "Codependency_Anonymous_Group", + "LA'", + "Taufa'ahau", + "Invincible_Spirit_colt", + "SAH'_dur", + "MOUNT_CARMEL_Pa.", + "watches_attentively", + "SNL_spinoffs", + "Seth_Nitschke", + "Duns_Berwickshire", + "defendant_Colleen_LaRose", + "Silky_O'Sullivan", + "Highcliff_Farm", + "REN'", + "Comestar", + "Satisfied_Frog", + "Jai_Maharashtra", + "ATTICA_Ind.", + "lover_Larry_Birkhead", + "Tami_Megal", + "chauvinist_pigs", + "Phi_sorority", + "Micronesian_immigrant", + "Lia_Boldt", + "Sugar_Tits", + "actress_Kathy_Najimy", + "zhoo", + "Colombo_underboss", + "Katsav_accusers", + "Bess_Houdini", + "rap_mogul_Diddy", + "companions_Khin_Khin", + "Van_Het", + "Mastoi_tribe", + "VITALY", + "ROLLING_STONES_rocker", + "womanizing_cad", + "LILY_COLE", + "paternal_grandfathers", + "Lt._Col._Kurt_Kosmatka", + "Kasseem_Jr.", + "Ji_Ji", + "Wilburforce", + "VIAGRA_DOSE", + "English_Sheepdogs", + "pronounced_Kah", + "Htet_Htet_Oo", + "Brisk_Breeze", + "Eau_du", + "BY_MELANIE_EVANS", + "Neovasc_Medical", + "British_funnyman_RICKY", + "4YO_mare", + "Hemaida", + "MONKTON", + "Mrs_Mujuru", + "BaGhana_BaGhana", + "Shaaban_Abdel_Rahim", + "Edward_Jazlowiecki_lawyer", + "Ajman_Stud", + "manly_pharaoh_even", + "Serra_Madeira_Islands", + "FRAY'", + "panto_dames", + "Khin_Myo", + "dancer_Karima_El_Mahroug", + "CROWN_Princess", + "Baseball_HOFer", + "Hasta_la_Pasta", + "GIRLS_NEXT_DOOR", + "Benedict_Groeschel", + "Bousamra", + "Ruby_Rubacuori_Ruby", + "Monde_Bleu", + "Un_homme_qui", + "Taylor_Sumers", + "Rapper_EMINEM", + "Joe_Menchetti", + "VAY'", + "supermodel_NAOMI_CAMPBELL", + "Supermodel_GISELE_BUNDCHEN", + "Au_Lait", + "Radar_Installed", + "THOMAS_TOWNSHIP_Mich.", + "Rafinesque", + "Herman_Weinrich", + "Abraxas_Antelope", + "raspy_voiced_rocker", + "Manurewa_Cosmopolitan_Club", + "Paraone", + "THE_LEOPARD", + "Boy_Incorporated_LZB", + "Dansili_filly", + "Lumpy_Rutherford", + "unwedded_bliss", + "Bhavna_Sharma", + "Scarvagh", + "en_flagrante", + "Mottu_Maid", + "Dowager_Queen", + "NEEN", + "model_Monika_Zsibrita", + "ROSIE_PEREZ", + "Mattock_Ranger", + "Valorous", + "Surpreme", + "Marwari_businessmen", + "Grandparents_aunts", + "Kimberley_Vlaeminck", + "Lyn_Treece_Boys", + "PDX_Update", + "Virsa_Punjab", + "eyelash_fluttering", + "Pi_fraternity", + "HUNTLEIGH_Mo.", + "novelist_Jilly_Cooper", + "Naha_Shuri_temple", + "Yasmine_Al_Massri", + "Mu_Gamma_Xi", + "Mica_Ertegun", + "Ocleppo", + "VIAGRA_CONTRAINDICATIONS", + "daughter_PEACHES", + "trainer_Geoff_Wragg", + "OVERNIGHT_DELIVERY", + "Fitts_retiree", + "de_Tourvel", + "Lil_Lad", + "north_easterner", + "Aol_Weird_News", + "Somewhat_improbably", + "Sikh_panth", + "Worcester_2m_7f", + "Zainab_Jah", + "OLYMPIC_medalist", + "Enoch_Petrucelly", + "collie_Lassie", + "LOW'", + "clumsiness_Holloway", + "ayr", + "OHR'", + "ROLLING_STONES_guitarist", + "LAH'_nee", + "Ian_Beefy_Botham", + "Awapuni_trainer", + "Glamorous_Granny", + "Chiang_Ching", + "MidAtlantic_Cardiovascular_Associates", + "Yeke", + "Seaforth_Huron_Expositor", + "Westley_Cary_Elwes", + "Cate_Blanchett_Veronica_Guerin", + "Bellas_Gate", + "witch_Glinda", + "wives_mistresses", + "Woodsville_Walmart", + "2YO_colt", + "Manav_Sushant_Singh", + "Pupi_Avati_Il", + "Sigma_Beta_Rho", + "Bishop_Christopher_Senyonjo", + "Vodou_priest", + "Rubel_Chowdhury", + "Claddagh_Ring", + "TAH'_duh_al", + "al_Sadr_mook_TAH'", + "ROBIN_GIBB", + "GAHN'", + "BY_THOMAS_RANSON", + "sister_Carine_Jena", + "Lyphard_mare", + "summa_cum", + "Semenya_grandmother_Maputhi", + "Clare_Nuns", + "Talac", + "sex_hormones_androgens", + "majeste", + "Saint_Ballado_mare", + "Carrie_Huchel", + "Mae_Dok", + "wife_Dieula", + "Earnest_Sirls", + "spoof_bar_mitzvah", + "von_Boetticher", + "Audwin_Mosby", + "Case_presentationWe", + "Vincent_Papandrea", + "KRAY'", + "Sergi_Benavent", + "Le_Poisson", + "Von_Cramm", + "Patti_Mell", + "Raymi_Coya", + "Benjamin_BeBe_Winans", + "Nana_Akosua", + "Auld_Acquaintance", + "Desire_Burunga", + "Company_Wrangler_Nestea", + "ask_Krisy_Plourde", + "JUANITA_BYNUM", + "livia", + "GAMB", + "Gail_Rosario_Dawson", + "Ramgarhia_Sikh", + "Catholic_nun_Sister", + "FOUR_WEDDINGS_AND", + "Robyn_Scherer", + "brother_King_Athelstan", + "Santo_Loquasto_Fences", + "Wee_Frees", + "MARISOL", + "Soliloquy_Stakes", + "Whatever_Spoetzl", + "Marc'Aurelio", + "mon_petit", + "Sabbar_al_Mashhadani", + "KAY'_lee", + "m_zah_MAH'", + "BY_TAMI_ALTHOFF", + "hobbit_Samwise_Gamgee", + "Bahiya_Hariri_sister", + "daddy_Larry_Birkhead", + "Sow_Tracey_Ullman", + "coach_Viljo_Nousiainen", + "Carmen_Lebbos", + "conjoined_twins_Zainab", + "Rob_Komosa", + "ample_bosomed", + "Ageing_rocker", + "psychic_Oda" +] \ No newline at end of file diff --git a/tests/fairness/bias_utils_test.py b/tests/fairness/bias_utils_test.py new file mode 100644 index 00000000000..17accc58a78 --- /dev/null +++ b/tests/fairness/bias_utils_test.py @@ -0,0 +1,79 @@ +import json +import torch + +from allennlp.fairness.bias_utils import load_words, load_word_pairs + +from allennlp.common.file_utils import cached_path +from allennlp.common.testing.test_case import AllenNlpTestCase +from allennlp.data import Instance, Token +from allennlp.data.batch import Batch +from allennlp.data import Vocabulary +from allennlp.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer +from allennlp.data.token_indexers import SingleIdTokenIndexer +from allennlp.data.fields import TextField + + +class BiasUtilsTest(AllenNlpTestCase): + def setup_method(self): + token_indexer = SingleIdTokenIndexer("tokens") + + self.pairs_fname = ( + "https://raw.githubusercontent.com/tolga-b/debiaswe/" + "4c3fa843ffff45115c43fe112d4283c91d225c09/data/definitional_pairs.json" + ) + with open(cached_path(self.pairs_fname)) as f: + pairs_list = [] + [ + pairs_list.extend( + [w1.lower(), w2.lower(), w1.title(), w2.title(), w1.upper(), w2.upper()] + ) + for w1, w2 in json.load(f) + ] + + text_field = TextField( + [Token(t) for t in pairs_list], + {"tokens": token_indexer}, + ) + instance = Instance({"text": text_field}) + dataset = Batch([instance]) + self.pairs_vocab = Vocabulary.from_instances(dataset) + self.num_pairs = len(set(pairs_list)) + + self.singles_fname = ( + "https://raw.githubusercontent.com/tolga-b/debiaswe/" + "4c3fa843ffff45115c43fe112d4283c91d225c09/data/gender_specific_full.json" + ) + with open(cached_path(self.singles_fname)) as f: + singles_list = json.load(f) + + text_field = TextField( + [Token(t) for t in singles_list], + {"tokens": token_indexer}, + ) + instance = Instance({"text": text_field}) + dataset = Batch([instance]) + self.singles_vocab = Vocabulary.from_instances(dataset) + self.num_singles = len(set(singles_list)) + + super().setup_method() + + def test_load_word_pairs(self): + ids1, ids2 = load_word_pairs( + self.pairs_fname, WhitespaceTokenizer(), self.pairs_vocab, "tokens" + ) + # first two token IDs reserved for [CLS] and [SEP] + assert torch.equal( + torch.tensor([i.item() for i in ids1]), torch.arange(2, self.num_pairs + 2, step=2) + ) + assert torch.equal( + torch.tensor([i.item() for i in ids2]), torch.arange(3, self.num_pairs + 3, step=2) + ) + + def test_load_words(self): + ids = load_words( + self.singles_fname, WhitespaceTokenizer(), self.singles_vocab, "tokens", all_cases=False + ) + # first two token IDs reserved for [CLS] and [SEP] + assert torch.equal( + torch.tensor([i.item() for i in ids]), torch.arange(2, self.num_singles + 2) + ) From 1159432d0f93d0c973839049c9fccccfd6970aec Mon Sep 17 00:00:00 2001 From: epwalsh <epwalsh10@gmail.com> Date: Thu, 3 Jun 2021 09:58:53 -0700 Subject: [PATCH 51/63] Prepare for release v2.5.0 --- CHANGELOG.md | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8778f696a95..de259c2f46d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,18 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased -### Changed - -- Use `dist_reduce_sum` in distributed metrics. -- Allow Google Cloud Storage paths in `cached_path` ("gs://..."). -- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`. -- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of - an actual `torch.nn.Module`. Other parameters to this method have changed as well. -- Print the first batch to the console by default. -- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). -- Trainer callbacks can now store and restore state in case a training run gets interrupted. -- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. -- `BeamSearch` is now a `Registrable` class. +## [v2.5.0](https://github.com/allenai/allennlp/releases/tag/v2.5.0) - 2021-06-03 ### Added @@ -44,6 +33,19 @@ on a downstream task. along with a `RepeatedNGramBlockingConstraint` constraint implementation, which allows for preventing repeated n-grams in the output from `BeamSearch`. - Added `DataCollator` for dynamic operations for each batch. +### Changed + +- Use `dist_reduce_sum` in distributed metrics. +- Allow Google Cloud Storage paths in `cached_path` ("gs://..."). +- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`. +- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of + an actual `torch.nn.Module`. Other parameters to this method have changed as well. +- Print the first batch to the console by default. +- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). +- Trainer callbacks can now store and restore state in case a training run gets interrupted. +- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions. +- `BeamSearch` is now a `Registrable` class. + ### Fixed - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids. @@ -56,6 +58,7 @@ on a downstream task. - Fixed `wandb` callback to work in distributed training. - Fixed `tqdm` logging into multiple files with `allennlp-optuna`. + ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22 ### Added From 5f76b59c511a4d6694c3f57abe750d271ec08a06 Mon Sep 17 00:00:00 2001 From: epwalsh <epwalsh10@gmail.com> Date: Thu, 3 Jun 2021 10:47:30 -0700 Subject: [PATCH 52/63] tick version for nightly release --- allennlp/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/version.py b/allennlp/version.py index 55ba0329c4e..1b9acf464ad 100644 --- a/allennlp/version.py +++ b/allennlp/version.py @@ -4,7 +4,7 @@ _MINOR = "5" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "0" +_PATCH = "1" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = os.environ.get("ALLENNLP_VERSION_SUFFIX", "") From 044e0ffb4a1e83d3ddc3d513a145a8e043ae8fee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Jun 2021 16:52:48 +0000 Subject: [PATCH 53/63] Bump black from 21.5b1 to 21.5b2 (#5236) Bumps [black](https://github.com/psf/black) from 21.5b1 to 21.5b2. - [Release notes](https://github.com/psf/black/releases) - [Changelog](https://github.com/psf/black/blob/main/CHANGES.md) - [Commits](https://github.com/psf/black/commits) Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 58115480562..3ba110703ca 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,7 +7,7 @@ flake8 mypy==0.812 # Automatic code formatting -black==21.5b1 +black==21.5b2 # Allows generation of coverage reports with pytest. pytest-cov From b7fd842086f162cff78292fb6aedde047c86d627 Mon Sep 17 00:00:00 2001 From: Bhadresh Savani <bhadreshpsavani@gmail.com> Date: Mon, 7 Jun 2021 23:32:44 +0530 Subject: [PATCH 54/63] [Docs] Fixes broken link in Fairness_Metrics (#5245) * fixed broken link --- CHANGELOG.md | 4 ++++ allennlp/fairness/fairness_metrics.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de259c2f46d..aaf1c2d633a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs + ## [v2.5.0](https://github.com/allenai/allennlp/releases/tag/v2.5.0) - 2021-06-03 diff --git a/allennlp/fairness/fairness_metrics.py b/allennlp/fairness/fairness_metrics.py index 752e486c527..60bb3f89256 100644 --- a/allennlp/fairness/fairness_metrics.py +++ b/allennlp/fairness/fairness_metrics.py @@ -196,7 +196,7 @@ def reset(self) -> None: @Metric.register("separation") class Separation(Metric): """ - [Separation]((https://fairmlbook.org) (pg. 12) allows correlation between the + [Separation](https://fairmlbook.org) (pg. 12) allows correlation between the predictions and the protected variable to the extent that it is justified by the gold labels. From 38c930b61ac9ddb6c721a5c2ec35be0565e15827 Mon Sep 17 00:00:00 2001 From: Pete <petew@allenai.org> Date: Mon, 7 Jun 2021 14:55:56 -0700 Subject: [PATCH 55/63] Ensure all relevant allennlp submodules are imported with `import_plugins()` (#5246) * ensure allennlp is a default plugin * fix logging issue * fixes * actually fix --- CHANGELOG.md | 1 + allennlp/common/plugins.py | 8 ++++++++ allennlp/common/util.py | 8 ++++++-- allennlp/tools/archive_surgery.py | 4 ++-- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aaf1c2d633a..684ec6e0081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs +- Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`. ## [v2.5.0](https://github.com/allenai/allennlp/releases/tag/v2.5.0) - 2021-06-03 diff --git a/allennlp/common/plugins.py b/allennlp/common/plugins.py index e114631f3ab..21cc694ba95 100644 --- a/allennlp/common/plugins.py +++ b/allennlp/common/plugins.py @@ -75,6 +75,14 @@ def import_plugins() -> None: """ Imports the plugins found with `discover_plugins()`. """ + # Ensure all relevant submodules of AllenNLP are imported. + import_module_and_submodules( + "allennlp", + exclude={ + "allennlp.sanity_checks", # deprecated + "allennlp.tools", # things in here are usually run as commands themselves + }, + ) # Workaround for a presumed Python issue where spawned processes can't find modules in the current directory. cwd = os.getcwd() diff --git a/allennlp/common/util.py b/allennlp/common/util.py index 4db2ef6b5fe..b4eba865195 100644 --- a/allennlp/common/util.py +++ b/allennlp/common/util.py @@ -28,6 +28,7 @@ TypeVar, Union, Sequence, + Set, ) import numpy @@ -328,13 +329,16 @@ def push_python_path(path: PathType) -> ContextManagerFunctionReturnType[None]: sys.path.remove(path) -def import_module_and_submodules(package_name: str) -> None: +def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]] = None) -> None: """ Import all submodules under the given package. Primarily useful so that people using AllenNLP as a library can specify their own custom packages and have their custom classes get loaded and registered. """ + if exclude and package_name in exclude: + return + importlib.invalidate_caches() # For some reason, python doesn't always add this by default to your path, but you pretty much @@ -353,7 +357,7 @@ def import_module_and_submodules(package_name: str) -> None: if path_string and module_finder.path != path_string: # type: ignore[union-attr] continue subpackage = f"{package_name}.{name}" - import_module_and_submodules(subpackage) + import_module_and_submodules(subpackage, exclude=exclude) def peak_cpu_memory() -> Dict[int, int]: diff --git a/allennlp/tools/archive_surgery.py b/allennlp/tools/archive_surgery.py index fc1014d23fc..3cba3f57169 100644 --- a/allennlp/tools/archive_surgery.py +++ b/allennlp/tools/archive_surgery.py @@ -22,8 +22,7 @@ from allennlp.common.file_utils import cached_path from allennlp.models.archival import CONFIG_NAME -logger = logging.getLogger() -logger.setLevel(logging.ERROR) +logger = logging.getLogger(__name__) def main(): @@ -79,4 +78,5 @@ def main(): if __name__ == "__main__": + logging.basicConfig(level=logging.ERROR) main() From 0e3a225a1c6e20dfaa8e609b8b6708c4d086bbf8 Mon Sep 17 00:00:00 2001 From: ArjunSubramonian <arjun.subramonian@gmail.com> Date: Thu, 10 Jun 2021 17:04:35 -0700 Subject: [PATCH 56/63] added `on_backward` trainer callback (#5249) * added BackwardCallback * finished tests * fixed linting issue * revised design per Dirk's suggestion * added OnBackwardException, changed loss to batch_ouputs, etc. Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local> --- CHANGELOG.md | 4 + allennlp/training/callbacks/__init__.py | 1 + allennlp/training/callbacks/backward.py | 40 ++++++++++ allennlp/training/callbacks/callback.py | 17 ++++- allennlp/training/gradient_descent_trainer.py | 18 +++-- tests/training/trainer_test.py | 74 +++++++++++++++++++ 6 files changed, 148 insertions(+), 6 deletions(-) create mode 100644 allennlp/training/callbacks/backward.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 684ec6e0081..20b9025d5e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `on_backward` training callback which allows for control over backpropagation and gradient manipulation. + ### Fixed - Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs diff --git a/allennlp/training/callbacks/__init__.py b/allennlp/training/callbacks/__init__.py index 3e55e115b43..b6e5a13b9e2 100644 --- a/allennlp/training/callbacks/__init__.py +++ b/allennlp/training/callbacks/__init__.py @@ -4,3 +4,4 @@ from allennlp.training.callbacks.tensorboard import TensorBoardCallback from allennlp.training.callbacks.track_epoch import TrackEpochCallback from allennlp.training.callbacks.wandb import WandBCallback +from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback, OnBackwardException diff --git a/allennlp/training/callbacks/backward.py b/allennlp/training/callbacks/backward.py new file mode 100644 index 00000000000..be5d3888efe --- /dev/null +++ b/allennlp/training/callbacks/backward.py @@ -0,0 +1,40 @@ +from typing import Dict, TYPE_CHECKING +import torch + +from allennlp.training.callbacks.callback import TrainerCallback + +if TYPE_CHECKING: + from allennlp.training.gradient_descent_trainer import GradientDescentTrainer + + +@TrainerCallback.register("mixed_precision_backward") +class MixedPrecisionBackwardCallback(TrainerCallback): + """ + Performs backpropagation for mixed precision training. + """ + + def on_backward( + self, + trainer: "GradientDescentTrainer", + batch_outputs: Dict[str, torch.Tensor], + backward_called: bool, + **kwargs + ) -> bool: + if backward_called: + raise OnBackwardException() + trainer._scaler.scale(batch_outputs["loss"]).backward() # type: ignore + return True + + +class OnBackwardException(Exception): + """ + The exception type raised if an `on_backward` callback + attempts to call `backward` when `backward_called` is `True`. + """ + + def __init__(self, message="") -> None: + super().__init__( + "Backpropagation has already been performed" + "and the computation graph has been erased, so" + "calling `loss.backward` is not permitted. " + message + ) diff --git a/allennlp/training/callbacks/callback.py b/allennlp/training/callbacks/callback.py index 301e9cb4387..965d4d47cee 100644 --- a/allennlp/training/callbacks/callback.py +++ b/allennlp/training/callbacks/callback.py @@ -1,4 +1,5 @@ from typing import List, Dict, Any, Optional, TYPE_CHECKING +import torch from allennlp.common import Registrable from allennlp.data import TensorDict @@ -12,7 +13,7 @@ class TrainerCallback(Registrable): """ A general callback object that handles multiple events. - This class has `on_batch`, `on_epoch`, and `on_end` methods, corresponding to + This class has `on_backward`, `on_batch`, `on_epoch`, and `on_end` methods, corresponding to each callback type. Each one receives the state of the wrapper object as `self`. This enables easier state sharing between related callbacks. @@ -33,6 +34,20 @@ def on_start( """ self.trainer = trainer + def on_backward( + self, + trainer: "GradientDescentTrainer", + batch_outputs: Dict[str, torch.Tensor], + backward_called: bool, + **kwargs, + ) -> bool: + """ + This callback hook performs backpropagation and allows for gradient manipulation. + `backward_called` indicates if `loss.backward` has been called prior to this callback. + `on_backward` should return `True` if and only if `loss.backward` is called in its body. + """ + return False + def on_batch( self, trainer: "GradientDescentTrainer", diff --git a/allennlp/training/gradient_descent_trainer.py b/allennlp/training/gradient_descent_trainer.py index 0e3f3cb0816..44fd3876f0a 100644 --- a/allennlp/training/gradient_descent_trainer.py +++ b/allennlp/training/gradient_descent_trainer.py @@ -19,6 +19,7 @@ from allennlp.models.model import Model from allennlp.training.callbacks import ConsoleLoggerCallback from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback +from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback from allennlp.training.checkpointer import Checkpointer from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler from allennlp.training.metric_tracker import MetricTracker @@ -148,7 +149,7 @@ class GradientDescentTrainer(Trainer): parameters. This is necessary because we want the saved model to perform as well as the validated model if we load it later. But this may cause problems if you restart the training from checkpoint. - callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`) + callbacks : `List[TrainerCallback]`, optional (default = `None`) A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start and end of training, etc. @@ -469,10 +470,17 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: batch_reg_loss = reg_loss.item() train_reg_loss += batch_reg_loss # type: ignore - if self._scaler is not None: - self._scaler.scale(loss).backward() - else: - loss.backward() + backward_called = False + for callback in self._callbacks: + backward_called |= callback.on_backward(self, batch_outputs, backward_called) + if not backward_called: + if self._scaler is not None: + MixedPrecisionBackwardCallback(self._serialization_dir).on_backward( + self, batch_outputs, backward_called + ) + else: + loss.backward() + if len(batch_group_outputs) <= 0: continue diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index 2373caafefd..264a52b9313 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -30,6 +30,7 @@ TensorBoardCallback, ConfidenceChecksCallback, ConsoleLoggerCallback, + OnBackwardException, ) from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError from allennlp.training.learning_rate_schedulers import CosineWithRestarts @@ -127,6 +128,26 @@ def setup_method(self): self.validation_data_loader.index_with(self.vocab) +class ZeroGradientsBackwardCallback(TrainerCallback): + """ + Zeros all gradients after backpropagation. + """ + + def on_backward( + self, + trainer: "GradientDescentTrainer", + batch_outputs: Dict[str, torch.Tensor], + backward_called: bool, + **kwargs, + ) -> bool: + if backward_called: + raise OnBackwardException() + batch_outputs["loss"].backward() + for param in trainer.model.parameters(): + param.grad.data.zero_() + return True + + class TestTrainer(TrainerTestBase): def test_trainer_can_run(self): trainer = GradientDescentTrainer( @@ -168,6 +189,59 @@ def test_trainer_can_run(self): assert isinstance(metrics["peak_worker_0_memory_MB"], float) assert metrics["peak_worker_0_memory_MB"] > 0 + def test_train_zero_gradients(self): + weights = {} + for name, param in self.model.named_parameters(): + weights[name] = param.data.clone() + + trainer = GradientDescentTrainer( + self.model, + self.optimizer, + self.data_loader, + num_epochs=2, + validation_data_loader=self.validation_data_loader, + callbacks=[ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR)], + ) + trainer.train() + + # weights should be the same + for name, param in self.model.named_parameters(): + assert torch.equal(weights[name], param.data) + + def test_two_backward_callbacks(self): + class SecondBackwardCallback(TrainerCallback): + """ + Changes all gradients to 1 after backpropagation. + """ + + def on_backward( + self, + trainer: "GradientDescentTrainer", + batch_outputs: Dict[str, torch.Tensor], + backward_called: bool, + **kwargs, + ) -> bool: + if backward_called: + raise OnBackwardException() + batch_outputs["loss"].backward() + for param in trainer.model.parameters(): + param.grad = torch.ones_like(param.grad, device=param.grad.device) + return True + + with pytest.raises(OnBackwardException): + trainer = GradientDescentTrainer( + self.model, + self.optimizer, + self.data_loader, + num_epochs=2, + validation_data_loader=self.validation_data_loader, + callbacks=[ + ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR), + SecondBackwardCallback(serialization_dir=self.TEST_DIR), + ], + ) + trainer.train() + def test_trainer_can_run_exponential_moving_average(self): moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999) trainer = GradientDescentTrainer( From 69d05ff5e7560316fe89c16fe7c6d483226e1e0f Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 24 Jun 2021 22:35:45 +0530 Subject: [PATCH 57/63] Add float mapping to TensorField --- .../huggingface_datasets_reader.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 98cadd2815b..a152e078fb3 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -1,5 +1,5 @@ from allennlp.data import DatasetReader, Token, Field, Tokenizer -from allennlp.data.fields import TextField, LabelField, ListField +from allennlp.data.fields import TextField, LabelField, ListField, TensorField from allennlp.data.instance import Instance from datasets import load_dataset, DatasetDict, list_datasets from datasets.features import ( @@ -10,6 +10,8 @@ Value, FeatureType, ) + +import torch from typing import Iterable, Optional, Dict, List, Union @@ -79,8 +81,8 @@ def _read(self, file_path: str) -> Iterable[Instance]: for index in self.shard_iterable(range(len(dataset_split))): yield self.text_to_instance(file_path, dataset_split[index]) - def raise_feature_not_supported_value_error(value): - raise ValueError(f"Datasets feature type {type(value)} is not supported yet.") + def raise_feature_not_supported_value_error(feature_name, feature_type): + raise ValueError(f"Datasets feature {feature_name} type {feature_type} is not supported yet.") def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ @@ -166,8 +168,11 @@ def _map_Value( # If tokenizer is provided we will use it to split it to tokens # Else put whole text as a single token field = _map_String(value, tokenizer) + + elif feature_type.dtype == "float32" or feature_type.dtype == "float64": + field = _map_Float(value) else: - field = LabelField(value, label_namespace=feature_name, skip_indexing=True) + field = LabelField(value, label_namespace=feature_name, skip_indexing=False) return field @@ -188,6 +193,15 @@ def _map_Sequence( field = ListField(field_list) # datasets Sequence of strings to ListField of LabelField + elif isinstance(item_feature_type, str): + for item in value: + # If tokenizer is provided we will use it to split it to tokens + # Else put whole text as a single token + item_field = _map_Value(feature_name, item, item_feature_type, tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + field = ListField(field_list) + elif isinstance(item_feature_type, ClassLabel): for item in value: item_field = _map_to_Label(feature_name, item, skip_indexing=True) @@ -203,9 +217,9 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - + # Add support for Dict else: - HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name) + HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) return field @@ -280,7 +294,12 @@ def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: field = TextField([Token(text)]) return field +def _map_Float(value: float) -> TensorField: + return TensorField(torch.tensor(value)) + # value mapper - Maps a single value to a LabelField def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) + + From 356b3831f88aba5d3df56704e28cd918d6fcba8b Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Tue, 29 Jun 2021 12:03:04 +0530 Subject: [PATCH 58/63] Verification tests --- .../huggingface_datasets_reader_test.py | 72 +++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 33cfe4ef8c5..f3554c8233e 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -137,7 +137,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( - "dataset", (("swahili"), ("conll2003"), ("dbpedia_14"), ("trec"), ("emotion")) + "dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion")) ) def test_read_known_supported_datasets_without_config(self, dataset): split = "train" @@ -152,7 +152,71 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - # def test_air_dialogue(self): - # reader = HuggingfaceDatasetReader(dataset_name="amazon_us_reviews", config_name="Apparel_v1_00") - # instances = list(reader.read("train")) + + # def test_conll2003(self): + # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) # print(instances[0]) + + + @pytest.mark.skip("Requires implementation of Dict") + def test_squad(self): + instances = list(HuggingfaceDatasetReader("squad").read("train")) + print(instances[0]) + + @pytest.mark.parametrize("config", (("default"), ("ptb"))) + def test_sst(self, config): + instances = list(HuggingfaceDatasetReader("sst", config).read("test")) + print(instances[0]) + + def test_open_web_text(self): + instances = list(HuggingfaceDatasetReader("openwebtext").read("plain_text")) + print(instances[0]) + + @pytest.mark.skip("Requires mapping of dict type") + def test_mocha(self): + instances = list(HuggingfaceDatasetReader("mocha").read("test")) + print(instances[0]) + + @pytest.mark.skip("Requires implementation of Dict") + def test_commonsense_qa(self): + instances = list(HuggingfaceDatasetReader("commonsense_qa").read("test")) + print(instances[0]) + + def test_piqa(self): + instances = list(HuggingfaceDatasetReader("piqa").read("test")) + print(instances[0]) + + def test_swag(self): + instances = list(HuggingfaceDatasetReader("swag").read("test")) + print(instances[0]) + + def test_snli(self): + instances = list(HuggingfaceDatasetReader("snli").read("test")) + print(instances[0]) + + def test_multi_nli(self): + instances = list(HuggingfaceDatasetReader("multi_nli").read("test")) + print(instances[0]) + + def test_super_glue(self): + instances = list(HuggingfaceDatasetReader("super_glue").read("test")) + print(instances[0]) + + @pytest.mark.parametrize("config", (("cola"), ("mnli"), ("ax"), ("mnli_matched"), ("mnli_mismatched"), ("mrpc"), ("qnli"),\ + ("qqp"), ("rte"), ("sst2"), ("stsb"), ("wnli"))) + def test_glue(self, config): + instances = list(HuggingfaceDatasetReader("glue", config).read("test")) + print(instances[0]) + + def test_drop(self): + instances = list(HuggingfaceDatasetReader("drop").read("test")) + print(instances[0]) + + + + + + + + + From 3192d708546d1ff57e97c070971fa301e3ac29a3 Mon Sep 17 00:00:00 2001 From: "Abhishek P (VMware)" <pab@vmware.com> Date: Thu, 1 Jul 2021 00:55:35 +0530 Subject: [PATCH 59/63] Attempt to Support Dict --- .../huggingface_datasets_reader.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index a152e078fb3..dfec615bdb6 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -114,7 +114,7 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, entry, feature_type, self.tokenizer) + fields_to_be_added = _map_Feature(feature_name, entry[feature_name], feature_type, self.tokenizer) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -123,32 +123,34 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore # Feature Mappers - These functions map a FeatureType into Fields def _map_Feature( - feature_name: str, entry: Dict, feature_type, tokenizer: Optional[Tokenizer] + feature_name: str, value, feature_type, tokenizer: Optional[Tokenizer] ) -> Dict[str, Field]: fields_to_be_added: Dict[str, Field] = dict() if isinstance(feature_type, ClassLabel): - fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, entry[feature_name]) + fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value) # datasets Value can be of different types elif isinstance(feature_type, Value): fields_to_be_added[feature_name] = _map_Value( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) elif isinstance(feature_type, Sequence): fields_to_be_added[feature_name] = _map_Sequence( - feature_name, entry[feature_name], feature_type.feature, tokenizer + feature_name, value, feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): fields_to_be_added = _map_Translation( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) elif isinstance(feature_type, TranslationVariableLanguages): fields_to_be_added = _map_TranslationVariableLanguages( - feature_name, entry[feature_name], feature_type, tokenizer + feature_name, value, feature_type, tokenizer ) + elif isinstance(feature_type, dict): + fields_to_be_added = _map_Dict(feature_type, value, tokenizer) else: raise ValueError(f"Datasets feature type {type(feature_type)} is not supported yet.") return fields_to_be_added @@ -172,7 +174,7 @@ def _map_Value( elif feature_type.dtype == "float32" or feature_type.dtype == "float64": field = _map_Float(value) else: - field = LabelField(value, label_namespace=feature_name, skip_indexing=False) + field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field @@ -180,7 +182,7 @@ def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] ) -> Union[ListField]: field_list: List[Field] = list() - field: ListField + field: ListField = None item_field: Field # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): @@ -217,7 +219,15 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - # Add support for Dict + + # WIP for drop + elif isinstance(item_feature_type, dict): + for item in value: + item_field = _map_Dict(item_feature_type, value[item], tokenizer) + field_list.append(item_field) + if len(field_list) > 0: + field = ListField(field_list) + else: HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) @@ -302,4 +312,11 @@ def _map_Float(value: float) -> TensorField: def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) +def _map_Dict(feature_definition: dict, values: dict, tokenizer: Tokenizer) -> Dict[str, Field]: + fields: Dict[str, Field] = dict() + for key in values: + fields[key] = _map_Feature(key, values[key], feature_definition[key], tokenizer) + return fields + + From e32c5b05fee69a58ec2e616432824b793844a63e Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhijnvb@gmail.com> Date: Wed, 4 Aug 2021 22:33:52 +0530 Subject: [PATCH 60/63] Quick changes --- .../data/dataset_readers/huggingface_datasets_reader.py | 9 ++++++--- .../dataset_readers/huggingface_datasets_reader_test.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index dfec615bdb6..67c6217c674 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -135,9 +135,12 @@ def _map_Feature( ) elif isinstance(feature_type, Sequence): - fields_to_be_added[feature_name] = _map_Sequence( - feature_name, value, feature_type.feature, tokenizer - ) + if type(value) == dict: + fields_to_be_added = _map_Dict(feature_type, value, tokenizer) + else: + fields_to_be_added[feature_name] = _map_Sequence( + feature_name, value, feature_type.feature, tokenizer + ) elif isinstance(feature_type, Translation): fields_to_be_added = _map_Translation( diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index f3554c8233e..e86add641c0 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -158,7 +158,7 @@ def test_read_known_supported_datasets_without_config(self, dataset): # print(instances[0]) - @pytest.mark.skip("Requires implementation of Dict") + # @pytest.mark.skip("Requires implementation of Dict") def test_squad(self): instances = list(HuggingfaceDatasetReader("squad").read("train")) print(instances[0]) @@ -172,9 +172,9 @@ def test_open_web_text(self): instances = list(HuggingfaceDatasetReader("openwebtext").read("plain_text")) print(instances[0]) - @pytest.mark.skip("Requires mapping of dict type") + # @pytest.mark.skip("Requires mapping of dict type") def test_mocha(self): - instances = list(HuggingfaceDatasetReader("mocha").read("test")) + reader = HuggingfaceDatasetReader("mocha").read("test") print(instances[0]) @pytest.mark.skip("Requires implementation of Dict") From 5f702efd190c436b89539726933a485f2bd32e48 Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhishek.purushothama@colorado.edu> Date: Wed, 11 Aug 2021 09:02:50 -0700 Subject: [PATCH 61/63] Dictionary works with SQUAD --- .../huggingface_datasets_reader.py | 47 ++++++++++--------- .../huggingface_datasets_reader_test.py | 41 ++++++++-------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 67c6217c674..6b00ae107cc 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -82,7 +82,9 @@ def _read(self, file_path: str) -> Iterable[Instance]: yield self.text_to_instance(file_path, dataset_split[index]) def raise_feature_not_supported_value_error(feature_name, feature_type): - raise ValueError(f"Datasets feature {feature_name} type {feature_type} is not supported yet.") + raise ValueError( + f"Datasets feature {feature_name} type {feature_type} is not supported yet." + ) def text_to_instance(self, split: str, entry) -> Instance: # type: ignore """ @@ -114,7 +116,9 @@ def text_to_instance(self, split: str, entry) -> Instance: # type: ignore field_list: list feature_type = features[feature_name] - fields_to_be_added = _map_Feature(feature_name, entry[feature_name], feature_type, self.tokenizer) + fields_to_be_added = _map_Feature( + feature_name, entry[feature_name], feature_type, self.tokenizer + ) for field_key in fields_to_be_added: fields[field_key] = fields_to_be_added[field_key] @@ -130,22 +134,18 @@ def _map_Feature( fields_to_be_added[feature_name] = _map_ClassLabel(feature_name, value) # datasets Value can be of different types elif isinstance(feature_type, Value): - fields_to_be_added[feature_name] = _map_Value( - feature_name, value, feature_type, tokenizer - ) + fields_to_be_added[feature_name] = _map_Value(feature_name, value, feature_type, tokenizer) elif isinstance(feature_type, Sequence): - if type(value) == dict: - fields_to_be_added = _map_Dict(feature_type, value, tokenizer) + if type(feature_type.feature) == dict: + fields_to_be_added[feature_name] = _map_Dict(feature_type.feature, value, tokenizer) else: fields_to_be_added[feature_name] = _map_Sequence( feature_name, value, feature_type.feature, tokenizer ) elif isinstance(feature_type, Translation): - fields_to_be_added = _map_Translation( - feature_name, value, feature_type, tokenizer - ) + fields_to_be_added = _map_Translation(feature_name, value, feature_type, tokenizer) elif isinstance(feature_type, TranslationVariableLanguages): fields_to_be_added = _map_TranslationVariableLanguages( @@ -166,8 +166,8 @@ def _map_ClassLabel(feature_name: str, value: ClassLabel) -> Field: def _map_Value( feature_name: str, value: Value, feature_type, tokenizer: Optional[Tokenizer] -) -> Union[TextField, LabelField]: - field: Union[TextField, LabelField] +) -> Union[TextField, LabelField, TensorField]: + field: Union[TextField, LabelField, TensorField] if feature_type.dtype == "string": # datasets.Value[string] maps to TextField # If tokenizer is provided we will use it to split it to tokens @@ -176,6 +176,7 @@ def _map_Value( elif feature_type.dtype == "float32" or feature_type.dtype == "float64": field = _map_Float(value) + else: field = LabelField(value, label_namespace=feature_name, skip_indexing=True) return field @@ -183,9 +184,9 @@ def _map_Value( def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] -) -> Union[ListField]: +) -> ListField: field_list: List[Field] = list() - field: ListField = None + field: ListField item_field: Field # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): @@ -223,7 +224,7 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - # WIP for drop + # WIP for dropx` elif isinstance(item_feature_type, dict): for item in value: item_field = _map_Dict(item_feature_type, value[item], tokenizer) @@ -232,7 +233,9 @@ def _map_Sequence( field = ListField(field_list) else: - HuggingfaceDatasetReader.raise_feature_not_supported_value_error(feature_name, item_feature_type) + HuggingfaceDatasetReader.raise_feature_not_supported_value_error( + feature_name, item_feature_type + ) return field @@ -307,6 +310,7 @@ def _map_String(text: str, tokenizer: Optional[Tokenizer]) -> TextField: field = TextField([Token(text)]) return field + def _map_Float(value: float) -> TensorField: return TensorField(torch.tensor(value)) @@ -315,11 +319,12 @@ def _map_Float(value: float) -> TensorField: def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: return LabelField(label=item, label_namespace=namespace, skip_indexing=skip_indexing) -def _map_Dict(feature_definition: dict, values: dict, tokenizer: Tokenizer) -> Dict[str, Field]: + +def _map_Dict( + feature_definition: dict, values: dict, tokenizer: Optional[Tokenizer] +) -> Dict[str, Field]: + # Map it as a Dictionary of List fields: Dict[str, Field] = dict() for key in values: - fields[key] = _map_Feature(key, values[key], feature_definition[key], tokenizer) + fields[key] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) return fields - - - diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index e86add641c0..1c329057cc8 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -7,7 +7,6 @@ # TODO Add test where we compare huggingface wrapped reader with an explicitly coded dataset class HuggingfaceDatasetReaderTest: - """ Test read for some lightweight datasets """ @@ -101,6 +100,7 @@ def test_read_with_invalid_split(self, split): Test to help validate for the known supported datasets Skipped by default, enable when required """ + # TODO pab-vmware skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset, config, split", @@ -136,9 +136,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) """ # TODO pab-vmware skip these once MR is ready to check-in - @pytest.mark.parametrize( - "dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion")) - ) + @pytest.mark.parametrize("dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion"))) def test_read_known_supported_datasets_without_config(self, dataset): split = "train" huggingface_reader = HuggingfaceDatasetReader(dataset_name=dataset) @@ -152,15 +150,14 @@ def test_read_known_supported_datasets_without_config(self, dataset): # Confirm all features were mapped assert len(instance.fields) == len(entry) - # def test_conll2003(self): # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) # print(instances[0]) - # @pytest.mark.skip("Requires implementation of Dict") def test_squad(self): - instances = list(HuggingfaceDatasetReader("squad").read("train")) + tokenizer: Tokenizer = WhitespaceTokenizer() + instances = list(HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train")) print(instances[0]) @pytest.mark.parametrize("config", (("default"), ("ptb"))) @@ -174,7 +171,7 @@ def test_open_web_text(self): # @pytest.mark.skip("Requires mapping of dict type") def test_mocha(self): - reader = HuggingfaceDatasetReader("mocha").read("test") + instances = list(HuggingfaceDatasetReader("mocha").read("test")) print(instances[0]) @pytest.mark.skip("Requires implementation of Dict") @@ -202,8 +199,23 @@ def test_super_glue(self): instances = list(HuggingfaceDatasetReader("super_glue").read("test")) print(instances[0]) - @pytest.mark.parametrize("config", (("cola"), ("mnli"), ("ax"), ("mnli_matched"), ("mnli_mismatched"), ("mrpc"), ("qnli"),\ - ("qqp"), ("rte"), ("sst2"), ("stsb"), ("wnli"))) + @pytest.mark.parametrize( + "config", + ( + ("cola"), + ("mnli"), + ("ax"), + ("mnli_matched"), + ("mnli_mismatched"), + ("mrpc"), + ("qnli"), + ("qqp"), + ("rte"), + ("sst2"), + ("stsb"), + ("wnli"), + ), + ) def test_glue(self, config): instances = list(HuggingfaceDatasetReader("glue", config).read("test")) print(instances[0]) @@ -211,12 +223,3 @@ def test_glue(self, config): def test_drop(self): instances = list(HuggingfaceDatasetReader("drop").read("test")) print(instances[0]) - - - - - - - - - From af029b34de2437a3a4c7f1d5355e0ce12f2dd883 Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhishek.purushothama@colorado.edu> Date: Wed, 11 Aug 2021 09:46:01 -0700 Subject: [PATCH 62/63] Fix typing issues Convert Dict to N ListFields for Dict of Lists --- .../huggingface_datasets_reader.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index 6b00ae107cc..fca61993848 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -138,7 +138,7 @@ def _map_Feature( elif isinstance(feature_type, Sequence): if type(feature_type.feature) == dict: - fields_to_be_added[feature_name] = _map_Dict(feature_type.feature, value, tokenizer) + fields_to_be_added = _map_Dict(feature_type.feature, value, tokenizer, feature_name) else: fields_to_be_added[feature_name] = _map_Sequence( feature_name, value, feature_type.feature, tokenizer @@ -224,13 +224,13 @@ def _map_Sequence( if len(field_list) > 0: field = ListField(field_list) - # WIP for dropx` - elif isinstance(item_feature_type, dict): - for item in value: - item_field = _map_Dict(item_feature_type, value[item], tokenizer) - field_list.append(item_field) - if len(field_list) > 0: - field = ListField(field_list) + # # WIP for dropx` + # elif isinstance(item_feature_type, dict): + # for item in value: + # item_field = _map_Dict(item_feature_type, value[item], tokenizer) + # field_list.append(item_field) + # if len(field_list) > 0: + # field = ListField(field_list) else: HuggingfaceDatasetReader.raise_feature_not_supported_value_error( @@ -321,10 +321,16 @@ def _map_to_Label(namespace, item, skip_indexing=True) -> LabelField: def _map_Dict( - feature_definition: dict, values: dict, tokenizer: Optional[Tokenizer] + feature_definition: dict, + values: dict, + tokenizer: Optional[Tokenizer] = None, + feature_name: Optional[str] = None, ) -> Dict[str, Field]: # Map it as a Dictionary of List fields: Dict[str, Field] = dict() for key in values: - fields[key] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) + key_name: str = key + if feature_name is not None: + key_name = feature_name + "-" + key + fields[key_name] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) return fields From 41b7034d082993d1c028cf9aad4e5686b85d9054 Mon Sep 17 00:00:00 2001 From: Abhishek Purushothama <abhishek.purushothama@colorado.edu> Date: Wed, 11 Aug 2021 15:18:40 -0700 Subject: [PATCH 63/63] Works for Mocha, although may need to add specific handling for SQUAD moving it down the list. --- .../huggingface_datasets_reader.py | 5 +++-- .../huggingface_datasets_reader_test.py | 16 ++++------------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/allennlp/data/dataset_readers/huggingface_datasets_reader.py b/allennlp/data/dataset_readers/huggingface_datasets_reader.py index fca61993848..7c98995af50 100644 --- a/allennlp/data/dataset_readers/huggingface_datasets_reader.py +++ b/allennlp/data/dataset_readers/huggingface_datasets_reader.py @@ -186,7 +186,7 @@ def _map_Sequence( feature_name, value: Sequence, item_feature_type, tokenizer: Optional[Tokenizer] ) -> ListField: field_list: List[Field] = list() - field: ListField + field: ListField = list() item_field: Field # In HF Sequence and list are considered interchangeable, but there are some distinctions such as if isinstance(item_feature_type, Value): @@ -326,6 +326,7 @@ def _map_Dict( tokenizer: Optional[Tokenizer] = None, feature_name: Optional[str] = None, ) -> Dict[str, Field]: + # TODO abhishek-p expand this to more generic based on metadata checks # Map it as a Dictionary of List fields: Dict[str, Field] = dict() for key in values: @@ -333,4 +334,4 @@ def _map_Dict( if feature_name is not None: key_name = feature_name + "-" + key fields[key_name] = _map_Sequence(key, values[key], feature_definition[key], tokenizer) - return fields + return fields \ No newline at end of file diff --git a/tests/data/dataset_readers/huggingface_datasets_reader_test.py b/tests/data/dataset_readers/huggingface_datasets_reader_test.py index 1c329057cc8..dba15d14f0d 100644 --- a/tests/data/dataset_readers/huggingface_datasets_reader_test.py +++ b/tests/data/dataset_readers/huggingface_datasets_reader_test.py @@ -80,13 +80,6 @@ def test_read_xnli_all_languages(self): # For XNLI that means 3 fields become 5 assert len(instance.fields) == 5 - def test_non_supported_feature(self): - dataset = "pubmed_qa" - config = "pqa_labeled" - split = "train" - with pytest.raises(ValueError): - next(HuggingfaceDatasetReader(dataset_name=dataset, config_name=config).read(split)) - def test_non_available_dataset(self): with pytest.raises(ValueError): HuggingfaceDatasetReader(dataset_name="surely-such-a-dataset-does-not-exist") @@ -101,7 +94,7 @@ def test_read_with_invalid_split(self, split): Skipped by default, enable when required """ - # TODO pab-vmware skip these once MR is ready to check-in + # TODO abhishek-p skip these once MR is ready to check-in @pytest.mark.parametrize( "dataset, config, split", ( @@ -135,7 +128,7 @@ def test_read_known_supported_datasets_with_config(self, dataset, config, split) Skipped by default, enable when required """ - # TODO pab-vmware skip these once MR is ready to check-in + # TODO abhishek-p skip these once MR is ready to check-in @pytest.mark.parametrize("dataset", (("swahili"), ("dbpedia_14"), ("trec"), ("emotion"))) def test_read_known_supported_datasets_without_config(self, dataset): split = "train" @@ -154,11 +147,10 @@ def test_read_known_supported_datasets_without_config(self, dataset): # instances = list(HuggingfaceDatasetReader("conll2003").read("test")) # print(instances[0]) - # @pytest.mark.skip("Requires implementation of Dict") def test_squad(self): tokenizer: Tokenizer = WhitespaceTokenizer() - instances = list(HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train")) - print(instances[0]) + instance_gen = HuggingfaceDatasetReader("squad", tokenizer=tokenizer).read("train") + print(next(instance_gen)) @pytest.mark.parametrize("config", (("default"), ("ptb"))) def test_sst(self, config):