Reinit layers of pretrained transformer in cached_transformers.get() #5505
Conversation
@epwalsh mind taking a look?
```python
# Optionally, re-initialize the parameters of certain layers.
self._reinit_layers = cast(List[int], reinit_layers)
if self._reinit_layers and load_weights:
    num_layers = len(self.transformer_model.encoder.layer)
```
This should work fine for BERT, but won't for other models like XLM that name their modules differently. In general I think it'll be brittle to rely on guessing the internal module structure of transformer models.
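(Not code from this PR, just a minimal sketch of the naming difference being described here; the checkpoint names are only illustrative:)

```python
# Sketch: print the top-level submodule names of two architectures to show
# why indexing `model.encoder.layer` only works for BERT-style models.
from transformers import AutoModel

for model_name in ("bert-base-cased", "gpt2"):
    model = AutoModel.from_pretrained(model_name)
    print(model_name, [name for name, _ in model.named_children()])
    # BERT exposes its layers under `encoder.layer`, while GPT-2 keeps its
    # blocks under `h`, so `model.encoder.layer` raises AttributeError there.
```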
But I can think of an alternative which, although a little less friendly from a user perspective, still provides this useful functionality:
- Change the type of `reinit_layers` to `Optional[Union[int, List[int], List[str]]]`.
- When `reinit_layers` is `List[str]`, we interpret it as a list of regular expressions and we reinitialize all modules that match any of the regular expressions in the list.
- Otherwise, when `reinit_layers` is `int` or `List[int]`, we try the current approach, but if the given `transformer_model` does not have this module structure we throw a `ConfigurationError` and suggest that the user use the `List[str]` form of this parameter instead (see the sketch after this list).
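(A rough, hypothetical sketch of that dispatch, not the PR's actual code; `transformer_model` follows the snippet above and the helper name is made up:)

```python
import re
from typing import List, Union

from allennlp.common.checks import ConfigurationError


def _reinit(transformer_model, reinit_layers: Union[int, List[int], List[str]]) -> None:
    """Hypothetical helper illustrating the suggested behaviour."""
    if isinstance(reinit_layers, int):
        # An integer n means "the last n layers".
        reinit_layers = list(range(-reinit_layers, 0))
    if reinit_layers and all(isinstance(x, str) for x in reinit_layers):
        # Regex form: re-initialize every module whose name matches any pattern.
        for name, module in transformer_model.named_modules():
            if any(re.search(pattern, name) for pattern in reinit_layers):
                module.apply(transformer_model._init_weights)
    else:
        # Integer form: rely on a BERT-style `encoder.layer` module list.
        try:
            layers = transformer_model.encoder.layer
        except AttributeError:
            raise ConfigurationError(
                "This model does not expose `encoder.layer`; please pass "
                "`reinit_layers` as a list of regex strings instead."
            )
        for i in reinit_layers:
            layers[i].apply(transformer_model._init_weights)
```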
Okay, I took a crack at this.
- I renamed `reinit_layers` to `reinit_modules`, just because re-initializing modules (instead of whole layers) is possible when it is a list of regexes.
- Updated the type of the argument to `Optional[Union[int, List[int], List[str]]]`.
- A configuration error is thrown if the model's layers cannot be easily indexed. In practice, most models I tried from Hugging Face could (BERTs, RoBERTas), but you were right that XLM-based models can't.
- Added tests for when `reinit_modules` is a list of regex strings.
Unfortunately, I can't seem to get the weights to re-initialize. I figured that after finding a match between the regex and some module name I could just call `module.apply(self.transformer_model._init_weights)`, like so:

```python
for regex in self._reinit_modules:
    for name, module in self.transformer_model.named_modules():
        if re.search(regex, name):
            module.apply(self.transformer_model._init_weights)
```

but this doesn't actually re-init the module's weights? Strange. I will have to dig into the HF repo and docs to figure out why.
EDIT: My test was broken. I was making a shallow copy of the weights before re-initialization, so the "before" weights were getting mutated along with the model.
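(For context, a hedged sketch of a comparison that avoids that shallow-copy pitfall; the checkpoint and the choice of layer are illustrative:)

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")
layer = model.encoder.layer[-1]

# Clone the tensors before re-initializing; keeping references (a shallow copy)
# means the "before" weights get mutated in place along with the model.
before = {name: p.detach().clone() for name, p in layer.named_parameters()}

layer.apply(model._init_weights)

changed = [
    name
    for name, p in layer.named_parameters()
    if not torch.allclose(before[name], p)
]
assert changed  # at least some parameters should differ after re-init
```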
@epwalsh Okay, I believe everything works and this new functionality is covered by a bunch of tests. I think the only outstanding questions are:
That makes sense to me!
We could test with a GPT2 model instead since those don't have an `encoder` attribute.
Decided against this as I realized it's actually a little complicated. The caching doesn't work because `reinit_modules` is a list, which isn't hashable.
Perfect! I used tiny-gpt2 and it made the tests run quite a bit faster.
We could convert it to a tuple instead, which will be hashable.
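(A minimal sketch of the tuple idea; the helper name and the shape of the cache key are assumptions, not the PR's code:)

```python
from typing import List, Optional, Tuple, Union


def _as_hashable(
    reinit_modules: Optional[Union[int, List[int], List[str]]]
) -> Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]:
    # Lists aren't hashable, so they can't be part of a cache key; a tuple
    # preserves the contents while making the key hashable.
    if isinstance(reinit_modules, list):
        return tuple(reinit_modules)
    return reinit_modules


# e.g. a cache key can then safely include the value:
key = ("bert-base-cased", _as_hashable([-2, -1]))
assert hash(key) == hash(("bert-base-cased", (-2, -1)))
```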
This looks really great! I just have a couple minor comments
```
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. If this is a tuple of strings, they will be treated as regexes and any
module with a name matching the regex will be re-initialized. Re-initializing the last few
layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not `None`.
```
I think it would be helpful to have a note here about how we can't guarantee using the `int` or `list[int]` form will work. And if that fails, the user should use the `list[str]` form.
Good call! Updated this docstring in `PretrainedTransformerEmbedder` and `cached_transformers.get()` with a note warning that layer indices could fail.
```
be re-initialized. Note, because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
```
@epwalsh Updated that part of the docstring here.
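(As a hedged usage sketch of the behaviour the docstring describes; only `reinit_modules` is taken from this thread, and the other constructor arguments are left at their defaults:)

```python
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Integer form: re-initialize the last two layers (may raise a
# ConfigurationError if the model's layers can't be indexed this way).
embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=2)

# Regex form: works regardless of the model's internal module structure.
embedder = PretrainedTransformerEmbedder(
    "bert-base-cased",
    reinit_modules=(r"encoder\.layer\.11", r"pooler"),
)
```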
LGTM! I think #5529 should fix the failures in CI, so once that goes through and all tests pass I'll merge. Thanks @JohnGiorgi!
Awesome :) Thanks for your help as always!!
@JohnGiorgi, @epwalsh, I don't think this makes a ton of sense as implemented. You want to cache the weights because you want to re-use them in many places. But we can still have this be part of
Thanks!
Fixes #5491.
Changes proposed in this pull request:

- Added a new argument, `reinit_layers`, to `cached_transformers.get()` (see the usage sketch after this list).
- If `reinit_layers` is an integer, the parameters of the last `reinit_layers` layers of the transformer model are re-initialized.
- If `reinit_layers` is a list, the parameters of the layers indexed by `reinit_layers` of the transformer model are re-initialized.
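(For illustration, a hedged sketch of a call with each form; at this point in the thread the argument was still named `reinit_layers`, `make_copy` is part of the existing `get()` signature, and the checkpoint name is arbitrary:)

```python
from allennlp.common import cached_transformers

# Integer form: re-initialize the parameters of the last two layers.
model = cached_transformers.get("bert-base-cased", make_copy=False, reinit_layers=2)

# List form: re-initialize the layers at the given indices.
model = cached_transformers.get("bert-base-cased", make_copy=False, reinit_layers=[10, 11])
```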
Before submitting

- ...section of the CONTRIBUTING docs.
- ...Writing docstrings section of the CONTRIBUTING docs.

After submitting
- `codecov/patch` reports high test coverage (at least 90%). You can find this under the "Actions" tab of the pull request once the other checks have finished.