This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Reinit layers of pretrained transformer in cached_transformers.get() #5505

Merged

Conversation

JohnGiorgi
Contributor

@JohnGiorgi JohnGiorgi commented Dec 10, 2021

Fixes #5491.

Changes proposed in this pull request:

  • Add a new parameter, reinit_layers, to cached_transformers.get() (see the sketch after this list).
  • If reinit_layers is an integer, the parameters of the last reinit_layers layers of the transformer model are re-initialized.
  • If reinit_layers is a list, the parameters of the layers indexed by reinit_layers are re-initialized.
  • Added a test to make sure this feature works as expected.
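For concreteness, a hypothetical call might look like the sketch below. The make_copy argument is assumed from the existing cached_transformers.get() signature, and the parameter name and placement changed during review (it was later renamed reinit_modules), so treat this as illustrative only, not the merged API.

from allennlp.common import cached_transformers

# Hypothetical usage of the proposed parameter (illustrative, not the merged API).
# Re-initialize the last two layers of the pretrained transformer:
model = cached_transformers.get("bert-base-uncased", make_copy=True, reinit_layers=2)

# Or re-initialize specific layers by index:
model = cached_transformers.get("bert-base-uncased", make_copy=True, reinit_layers=[10, 11])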

Before submitting

  • I've read and followed all steps in the Making a pull request
    section of the CONTRIBUTING docs.
  • I've updated or added any relevant docstrings following the syntax described in the
    Writing docstrings section of the CONTRIBUTING docs.
  • If this PR fixes a bug, I've added a test that will fail without my fix.
  • If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

After submitting

  • All GitHub Actions jobs for my pull request have passed.
  • codecov/patch reports high test coverage (at least 90%).
    You can find this under the "Actions" tab of the pull request once the other checks have finished.

@JohnGiorgi
Contributor Author

@epwalsh mind taking a look?

# Optionally, re-initialize the parameters of certain layers.
self._reinit_layers = cast(List[int], reinit_layers)
if self._reinit_layers and load_weights:
    num_layers = len(self.transformer_model.encoder.layer)
Member

This should work fine for BERT, but won't for other models like XLM that name their modules differently. In general I think it'll be brittle to rely on guessing the internal module structure of transformer models.

But I can think of an alternative which, although it's a little less friendly from a user perspective, still provides this useful functionality:

  • Change the type of reinit_layers to Optional[Union[int, List[int], List[str]]].
  • When reinit_layers is List[str], we interpret it as a list of regular expressions and we reinitialize all modules that match any of the regular expressions in the list.
  • Otherwise, when reinit_layers is int or List[int], we try the current approach, but if the given transformer_model does not have this module structure we throw a ConfigurationError and suggest that the user use the List[str] form of this parameter instead (a rough sketch of this dispatch follows).
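A rough sketch of that dispatch, assuming BERT-style models expose their layers as transformer_model.encoder.layer and that Hugging Face models provide _init_weights; the function name and structure here are illustrative, not the merged code:

import re
from typing import List, Optional, Union

from allennlp.common.checks import ConfigurationError

def reinit(transformer_model, reinit_layers: Optional[Union[int, List[int], List[str]]]) -> None:
    # Sketch of the proposed behaviour; not the merged implementation.
    if reinit_layers is None:
        return
    if isinstance(reinit_layers, int):
        # An int N means "the last N layers".
        reinit_layers = list(range(-reinit_layers, 0))
    if all(isinstance(i, int) for i in reinit_layers):
        try:
            layers = transformer_model.encoder.layer  # BERT-style module layout
        except AttributeError:
            raise ConfigurationError(
                "Could not index this model's layers by integer; "
                "please pass a list of regexes (List[str]) instead."
            )
        for i in reinit_layers:
            layers[i].apply(transformer_model._init_weights)
    else:
        # Treat each entry as a regex over module names.
        for regex in reinit_layers:
            for name, module in transformer_model.named_modules():
                if re.search(regex, name):
                    module.apply(transformer_model._init_weights)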

Contributor Author

@JohnGiorgi JohnGiorgi Dec 16, 2021

Okay, I took a crack at this.

  • I renamed reinit_layers to reinit_modules just because re-initializing modules (instead of whole layers) is possible when it is a list of regexes.
  • Updated the type of the argument to Optional[Union[int, List[int], List[str]]].
  • A ConfigurationError is thrown if the model's layers cannot be easily indexed. In practice, most models I tried from Hugging Face could be (BERT and RoBERTa variants), but you were right that XLM-based models can't.
  • Added tests for when reinit_modules is a list of regex strings.

Unfortunately, I can't seem to get the weights to re-initialize. I figured that after finding a match between the regex and some module name, I could just call module.apply(self.transformer_model._init_weights), like so:

for regex in self._reinit_modules:
    for name, module in self.transformer_model.named_modules():
        if re.search(regex, name):
            module.apply(self.transformer_model._init_weights)

but this doesn't actually re-init the module's weights? Strange. I will have to dig into the HF repo and docs to figure out why.

EDIT: My test was broken. I was making a shallow copy of the weights before re-initialization, so the saved "original" weights were being mutated along with the model.
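For reference, the fix in that EDIT amounts to cloning the tensors instead of holding references to them. A minimal self-contained sketch of the idea, using a tiny randomly-initialized BERT as a stand-in for the pretrained model (not the actual test in the PR):

import torch
from transformers import BertConfig, BertModel

# A tiny randomly-initialized BERT stands in for the pretrained model here.
model = BertModel(BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64))

# Clone the parameters *by value* before re-initialization; holding references
# (a shallow copy) would be mutated in place by _init_weights.
before = {name: param.detach().clone() for name, param in model.named_parameters()}

# Re-initialize the last layer, mirroring the feature under test.
model.encoder.layer[-1].apply(model._init_weights)

changed = [name for name, param in model.named_parameters()
           if not torch.equal(before[name], param)]
print(f"{len(changed)} parameter tensors changed")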

@JohnGiorgi
Contributor Author

@epwalsh Okay, I believe everything works and this new functionality is covered by a bunch of tests. I think the only outstanding questions are:

  • This code has gotten quite long and complicated. It might be better to place it in its own method rather than __init__.
  • In the unit test, I used "xlm-mlm-enfr-1024" as the XLM model to test, but it is quite big. Might be worth it to look for a smaller XLM model.
  • I am realizing now that it might have made more sense to add this logic to cached_transformers, so that it can be used outside of just PretrainedTransformerEmbedder. If you agree I can move it over.

@epwalsh
Member

epwalsh commented Dec 20, 2021

I am realizing now that it might have made more sense to add this logic to cached_transformers, so that it can be used outside of just PretrainedTransformerEmbedder. If you agree I can move it over.

That makes sense to me!

In the unit test, I used "xlm-mlm-enfr-1024" as the XLM model to test, but it is quite big. Might be worth it to look for a smaller XLM model.

We could test with a GPT2 model instead since those don't have an encoder module either. And there is https://huggingface.co/sshleifer/tiny-gpt2.
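For what it's worth, the reason a GPT-2 checkpoint exercises the non-BERT path is visible directly in the Hugging Face module layout; a quick check along these lines (not part of the PR):

from transformers import AutoModel

model = AutoModel.from_pretrained("sshleifer/tiny-gpt2")

# GPT-2 exposes its transformer blocks as `model.h`, not `model.encoder.layer`,
# so indexing layers the BERT way fails and the regex form is needed instead.
print(hasattr(model, "encoder"))  # False
print(type(model.h))              # a torch.nn.ModuleList of GPT-2 blocks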

@JohnGiorgi
Contributor Author

I am realizing now that it might have made more sense to add this logic to cached_transformers, so that it can be used outside of just PretrainedTransformerEmbedder. If you agree I can move it over.

That makes sense to me!

Decided against this as I realized it's actually a little complicated: the caching doesn't work because reinit_modules can be a list, and a list isn't hashable. Leaving this functionality in PretrainedTransformerEmbedder.

In the unit test, I used "xlm-mlm-enfr-1024" as the XLM model to test, but it is quite big. Might be worth it to look for a smaller XLM model.

We could test with a GPT2 model instead since those don't have an encoder module either. And there is https://huggingface.co/sshleifer/tiny-gpt2.

Perfect! I used tiny-gpt2 and it made the tests run quite a bit faster.

@epwalsh
Member

epwalsh commented Dec 22, 2021

Decided against this as I realized it's actually a little complicated. The caching doesn't work because reinit_modules can be a list and that isn't hashable.

We could convert it to a tuple instead, which will be hashable.
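A minimal illustration of the hashability point; the key shown is just a stand-in, not the actual TransformerSpec:

reinit_modules = ["layer.10", "layer.11"]

# Lists are unhashable, so they can't be part of a dict-backed cache key:
# cache = {("bert-base-uncased", reinit_modules): "..."}  # TypeError: unhashable type: 'list'

# The equivalent tuple is hashable and works fine:
cache = {}
cache[("bert-base-uncased", tuple(reinit_modules))] = "cached model goes here"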

@JohnGiorgi JohnGiorgi changed the title from "Reinit layers of pretrained transformer embedder" to "Reinit layers of pretrained transformer in cached_transformers.get()" Dec 23, 2021
Member

@epwalsh epwalsh left a comment

This looks really great! I just have a couple of minor comments.

allennlp/common/cached_transformers.py
Comment on lines 50 to 56
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. If this is a tuple of strings, they will be treated as regexes and any
module with a name matching the regex will be re-initialized. Re-initializing the last few
layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not None.
Member

I think it would be helpful to have a note here about how we can't guarantee using the int or list[int] form will work. And if that fails, the user should use the list[str] form.

Contributor Author

Good call! Updated this docstring in PretrainedTransformerEmbedder and cached_transformers.get() with a note warning that the layer indices could fail.

Comment on lines +52 to +55
be re-initialized. Note, because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
Contributor Author

@epwalsh Updated that part of the docstring here.

Member

@epwalsh epwalsh left a comment

LGTM! I think #5529 should fix the failures in CI, so once that goes through and all tests pass I'll merge. Thanks @JohnGiorgi!

@epwalsh epwalsh enabled auto-merge (squash) December 23, 2021 20:26
@JohnGiorgi
Contributor Author

LGTM! I think #5529 should fix the failures in CI, so once that goes through and all tests pass I'll merge. Thanks @JohnGiorgi!

Awesome :) Thanks for your help as always!!

@epwalsh epwalsh merged commit 06ec7f9 into allenai:main Dec 23, 2021
@dirkgr
Member

dirkgr commented Jan 7, 2022

@JohnGiorgi, @epwalsh, I don't think this makes a ton of sense as implemented. You want to cache the weights because you want to re-use them in many places. But _init_weights() randomizes the weights. The way it is now, you will get the same randomized weights every time. Even if that's what you want, randomizing the weights only makes sense if you train (i.e., alter) them later. When you do that, you need to make a copy of the weights, so you don't change the cached version (and thus affect other users of the cache).

We can still have this be part of cached_transformers.get(), but we should not cache the version with re-initialized weights. reinit_layers should not be part of TransformerSpec. The re-initialization should happen every time get() is called.
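A rough sketch of the flow being described, with the cache keyed only by the model name so the re-initialization always happens on a copy and is never cached; this is illustrative only, and the actual fix landed in #5543:

import copy
import re
from typing import Dict, Optional, Tuple

from transformers import AutoModel, PreTrainedModel

_CACHE: Dict[str, PreTrainedModel] = {}  # illustrative cache, keyed by model name only

def get(model_name: str, make_copy: bool, reinit_modules: Optional[Tuple[str, ...]] = None) -> PreTrainedModel:
    # Sketch of the suggested flow; not the merged implementation.
    if model_name not in _CACHE:
        _CACHE[model_name] = AutoModel.from_pretrained(model_name)
    transformer = _CACHE[model_name]
    if make_copy or reinit_modules:
        # Work on a copy so the cached pretrained weights are never mutated.
        transformer = copy.deepcopy(transformer)
    if reinit_modules:
        # Re-initialization happens on every call and the result is never cached.
        for regex in reinit_modules:
            for name, module in transformer.named_modules():
                if re.search(regex, name):
                    module.apply(transformer._init_weights)
    return transformer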

@JohnGiorgi JohnGiorgi mentioned this pull request Jan 16, 2022
@JohnGiorgi
Contributor Author

@dirkgr I see, good catch. Opened a PR with those fixes here: #5543

@dirkgr
Member

dirkgr commented Jan 18, 2022

Thanks!
