Implement gradient checkpointing in GPTBigCode #41818
Conversation
Support for gradient checkpointing was lost in the major refactoring in PR huggingface#38635 and this PR attempts to re-add it.

I extended the tests to
- test `use_reentrant=True` and `False`
- make sure `model.train` is called so that gradient checkpointing works; this is a limitation of the tests currently used by GPTBigCode
- make sure that one (the first) gradient checkpointing layer is called
- make sure that the same non-zero grads are present for the normal and the checkpointing run - this is something we tripped over before in PEFT due to the possibly incompletely stored runtime environment in the checkpointed forward step, see also peft#2826

Note that the invocation of `GPTBigCodeBlock.forward` has changed:
- `layer_past` is now passed as a keyword argument so that `GradientCheckpointingLayer.__call__` can see and filter this parameter (`use_reentrant=False` fails otherwise)
- `{encoder_}hidden_states` are still passed as positional arguments so that `torch.utils.checkpoint.checkpoint` receives them as positional args and computes gradients for these (kwargs would be filtered by `GradientCheckpointingLayer`); see the sketch below
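To make the two points above concrete, here is a simplified, hypothetical sketch of the described call pattern (not the exact code from the diff; the keyword/positional status of the remaining arguments is an assumption):

```python
# Illustrative sketch of the changed call pattern (not the exact diff):
# tensors that need gradients stay positional so torch.utils.checkpoint.checkpoint
# tracks them; layer_past becomes a keyword argument so GradientCheckpointingLayer
# can see and filter it when checkpointing is active.
outputs = block(
    hidden_states,                    # positional -> gradients flow through checkpoint()
    encoder_hidden_states,            # positional -> gradients flow through checkpoint()
    layer_past=layer_past,            # keyword -> can be filtered by GradientCheckpointingLayer
    attention_mask=attention_mask,    # remaining arguments as keywords (assumed here)
    head_mask=head_mask[i],
    encoder_attention_mask=encoder_attention_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
)
```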
[For maintainers] Suggested jobs to run (before merge): `run-slow: gpt_bigcode`
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The tests are neat, though I think we should move them to the common tests. Not exactly sure why it was specially treated here.
And I guess there will be a need for another round to check similar models where the checkpointing layer may have been accidentally overridden 😓 not necessarily in this PR though.
```diff
+    encoder_hidden_states: Optional[torch.Tensor] = None,
     layer_past: Optional[Cache] = None,
     attention_mask: Optional[torch.Tensor] = None,
-    encoder_hidden_states: Optional[torch.Tensor] = None,
```
Let's not change the order here, we could break things for users. Rather change the args/kwargs positions on the module call if necessary.
I'm not sure that this is possible. It is mandatory that we pass `layer_past` as a keyword argument, otherwise `GradientCheckpointingLayer` will not be able to remove it from the kwargs in case of gradient checkpointing. On the other hand, every input that may require gradients (`hidden_states`, `encoder_hidden_states`) must be passed as a positional argument for `checkpoint()` to work. Maybe I'm missing something, but I don't think we can bring those together without moving `encoder_hidden_states` up in the list.
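The constraint on `checkpoint()` can be illustrated with a toy example (plain PyTorch, not transformers code; the names are placeholders):

```python
import torch
from torch.utils.checkpoint import checkpoint

def toy_block(hidden_states, encoder_hidden_states):
    # Stand-in for a transformer block that consumes both tensors.
    return hidden_states * 2 + encoder_hidden_states

hidden_states = torch.randn(2, 4, requires_grad=True)
encoder_hidden_states = torch.randn(2, 4, requires_grad=True)

# Both tensors are passed positionally, so checkpoint() records them as inputs,
# recomputes the forward during backward(), and produces gradients for both.
out = checkpoint(toy_block, hidden_states, encoder_hidden_states, use_reentrant=False)
out.sum().backward()
assert hidden_states.grad is not None and encoder_hidden_states.grad is not None
```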
I mean that the signature should stay the same, e.g. see transformers/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py, lines 586 to 596 in 84d19be:
```python
def forward(
    self,
    hidden_states: Optional[tuple[torch.Tensor]],
    layer_past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    **kwargs,
```
The calls from the module above will need to be adjusted, like transformers/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py, lines 901 to 910 in 84d19be:
```python
outputs = block(
    hidden_states,
    layer_past,
    attention_mask,
    head_mask[i],
    encoder_hidden_states,  # as a positional argument for gradient checkpointing
    encoder_attention_mask=encoder_attention_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
)
```
Changing the signature is breaking a bit too much!
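For illustration only, one reading of this suggestion (keep the signature and adjust only the call site) is sketched below; it is hypothetical, not code from this PR, and note the author's objection above that `encoder_hidden_states` passed as a keyword would be filtered before `checkpoint()` and thus not receive gradients:

```python
# Hypothetical call adjustment that keeps the original forward signature
# (not code from this PR): only hidden_states stays positional, everything
# else is passed by keyword so GradientCheckpointingLayer can filter the cache.
outputs = block(
    hidden_states,
    layer_past=layer_past,
    attention_mask=attention_mask,
    head_mask=head_mask[i],
    encoder_hidden_states=encoder_hidden_states,  # as a kwarg this conflicts with the gradient requirement above
    encoder_attention_mask=encoder_attention_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
)
```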
```python
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
```
I'm a bit surprised that it was overridden here. It would be nicer if we could move this into `test_modeling_common` instead.
Agreed, but I'm not sure how to deal with the fact that not all models use a `GradientCheckpointingLayer`; some call the checkpoint function themselves. Do you have a suggestion for how to deal with that?
We can check if the layer exists somewhere, no? If we do not detect that it exists, raise an error and check which models fail --> all models should already have this, but bigger PRs apparently clashed so it isn't the case anymore.
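A minimal sketch of what such a common-test check could look like (hypothetical helper; assumes `GradientCheckpointingLayer` is importable from `transformers.modeling_layers`):

```python
from transformers.modeling_layers import GradientCheckpointingLayer

def check_has_gradient_checkpointing_layer(model):
    # Hypothetical helper: fail loudly if a model that claims to support
    # gradient checkpointing contains no GradientCheckpointingLayer module,
    # e.g. because it still calls torch.utils.checkpoint manually.
    if not any(isinstance(m, GradientCheckpointingLayer) for m in model.modules()):
        raise AssertionError(
            f"{model.__class__.__name__} supports gradient checkpointing but defines "
            "no GradientCheckpointingLayer."
        )
```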
```python
for _, p in trainable_params:
    p.grad = None

checkpointing_layer = next(m for m in model.modules() if isinstance(m, GradientCheckpointingLayer))
```
Ah ok, so we are bound to the new gradient checkpointing then. Guess there will be a need to check that all models use this properly.
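For context, a simplified sketch of how the detected layer can be verified to actually run during a checkpointed training step (illustrative only, not the exact test code; `input_ids` is a placeholder input):

```python
# Illustrative check that the first GradientCheckpointingLayer is exercised
# during a checkpointed forward/backward pass (not the exact test code).
calls = []
hook = checkpointing_layer.register_forward_hook(lambda module, inputs, output: calls.append(True))

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()  # gradient checkpointing is only active in training mode
loss = model(input_ids, labels=input_ids).loss
loss.backward()

hook.remove()
assert calls, "the first GradientCheckpointingLayer was never called"
```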
```python
result.loss.backward()

non_zero_grads_normal = {n for n, p in trainable_params if p.grad.abs().sum() > 0}
assert non_zero_grads_normal
```
Let's use
```python
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
```
instead of normal asserts. Depends on whether we move the test too, I guess.
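If the test stays here, a sketch of the grad-comparison checks in that style could look like this (`non_zero_grads_checkpointing` is a hypothetical name for the set collected in the checkpointed run):

```python
# Hypothetical rewrite of the bare asserts in the suggested unittest style:
non_zero_grads_normal = {n for n, p in trainable_params if p.grad.abs().sum() > 0}
self.parent.assertTrue(len(non_zero_grads_normal) > 0)

# ... run the gradient-checkpointed forward/backward, collect grads the same way, then compare:
self.parent.assertEqual(non_zero_grads_normal, non_zero_grads_checkpointing)
```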