Conversation

@githubnemo
Contributor

Support for gradient checkpointing was lost in the major refactoring in PR #38635; this PR re-adds it.

I extended the tests to

  • test use_reentrant=True and False
  • make sure model.train() is called so that gradient checkpointing is actually active; this is a limitation of the tests currently used by GPTBigCode
  • make sure that one (the first) gradient checkpointing layer is called
  • make sure that the same non-zero grads are present for the normal and the checkpointed run; this is something we tripped over before in PEFT because the runtime environment stored for the checkpointed forward step can be incomplete, see also peft#2826 (a rough sketch of this check follows the list)
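
A rough sketch of that grad-comparison check (illustrative only, not the exact test code; model and inputs stand in for the tester fixtures, assuming the standard model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=...) entry point):

    model.train()  # gradient checkpointing is a no-op in eval mode
    trainable_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

    # baseline run without checkpointing
    model.gradient_checkpointing_disable()
    model(**inputs, labels=inputs["input_ids"]).loss.backward()
    grads_normal = {n for n, p in trainable_params if p.grad is not None and p.grad.abs().sum() > 0}

    # checkpointed run (repeated with use_reentrant=True and False in the tests)
    for _, p in trainable_params:
        p.grad = None
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model(**inputs, labels=inputs["input_ids"]).loss.backward()
    grads_ckpt = {n for n, p in trainable_params if p.grad is not None and p.grad.abs().sum() > 0}

    assert grads_normal and grads_normal == grads_ckpt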

Note that the invocation of GPTBigCodeBlock.forward has changed:

  • layer_past is now passed as a keyword argument so that GradientCheckpointingLayer.__call__ can see and filter this parameter (use_reentrant=False fails otherwise)
  • {encoder_}hidden_states are still passed as positional arguments so that torch.utils.checkpoint.checkpoint receives them as positional arguments and computes gradients for them (keyword arguments would be filtered by GradientCheckpointingLayer); see the sketch below
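
A minimal sketch of the resulting call pattern inside the model's forward, as described above (simplified illustration, not a verbatim excerpt; head_mask is omitted and the remaining arguments are shown as keywords for brevity):

    outputs = block(
        hidden_states,                      # positional: checkpoint() sees it and computes its gradient
        encoder_hidden_states,              # positional for the same reason
        layer_past=layer_past,              # keyword: GradientCheckpointingLayer.__call__ can filter it out
        attention_mask=attention_mask,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
    )
    hidden_states = outputs[0]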

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_bigcode

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@vasqu vasqu left a comment

The tests are neat; I think we should move them to the common tests though. Not exactly sure why it was treated specially here.

And I guess there will be a need for another round to check similar models whose checkpointing layer may have been accidentally overridden 😓 not necessarily in this PR though

Comment on lines +295 to -296

+        encoder_hidden_states: Optional[torch.Tensor] = None,
         layer_past: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
Contributor

Let's not change the order here; we could break things for users. Rather, change the args/kwargs positions on the module call if necessary.

Contributor Author

I'm not sure that this is possible. It is mandatory that we pass layer_past as a keyword argument, otherwise GradientCheckpointingLayer will not be able to remove it from the kwargs when gradient checkpointing is active. On the other hand, every input that may require gradients (hidden_states, encoder_hidden_states) must be passed as a positional argument for checkpoint() to work. Maybe I'm missing something, but I don't think we can bring those together without moving encoder_hidden_states up in the list.
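
For illustration, a minimal stand-in (hypothetical sketch, not the actual GradientCheckpointingLayer implementation) showing why the two constraints pull in different directions:

    from functools import partial

    import torch
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlock(torch.nn.Module):
        # Hypothetical sketch: kwargs are visible to __call__ and can be filtered,
        # but only the positional *args reach checkpoint(), so tensors that need
        # gradients must be passed positionally by the caller.
        gradient_checkpointing = True

        def __call__(self, *args, **kwargs):
            if self.training and self.gradient_checkpointing:
                kwargs.pop("layer_past", None)  # drop state that checkpointing cannot handle
                return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
            return super().__call__(*args, **kwargs)

        def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
            out = hidden_states * 2
            if encoder_hidden_states is not None:
                out = out + encoder_hidden_states
            return out

    block = CheckpointedBlock().train()
    x = torch.randn(2, 4, requires_grad=True)
    enc = torch.randn(2, 4, requires_grad=True)
    block(x, enc, layer_past=object(), attention_mask=None).sum().backward()
    assert x.grad is not None and enc.grad is not None  # positional inputs get gradients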

Contributor

I mean that the signature should stay the same, e.g. see

def forward(
    self,
    hidden_states: Optional[tuple[torch.Tensor]],
    layer_past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    **kwargs,

The calls from the module above will then need to be adjusted, like

outputs = block(
    hidden_states,
    layer_past,
    attention_mask,
    head_mask[i],
    encoder_hidden_states,  # as a positional argument for gradient checkpointing
    encoder_attention_mask=encoder_attention_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
)

Changing the signature is breaking a bit too much!

        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
Contributor

I'm a bit surprised that it was overridden here. It would be nicer if we could move this into test_modeling_common instead.

Contributor Author

Agreed, but I'm not sure how to deal with the fact that not all models use a GradientCheckpointingLayer; some call the checkpoint function themselves. Do you have a suggestion for how to handle that?

Contributor

We can check whether the layer exists somewhere, no? If we do not detect it, raise an error and check which models fail --> all models should already have this, but bigger PRs apparently clashed so it isn't the case anymore
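
A possible shape for such a check in the common tests (a sketch; the helper name is hypothetical, and it assumes GradientCheckpointingLayer is importable from transformers.modeling_layers):

    from transformers.modeling_layers import GradientCheckpointingLayer

    def _assert_uses_gradient_checkpointing_layer(self, model):
        # Surface models that still call torch.utils.checkpoint.checkpoint by hand
        # instead of going through GradientCheckpointingLayer.
        has_ckpt_layer = any(isinstance(m, GradientCheckpointingLayer) for m in model.modules())
        self.assertTrue(
            has_ckpt_layer,
            f"{model.__class__.__name__} does not use GradientCheckpointingLayer",
        )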

for _, p in trainable_params:
    p.grad = None

checkpointing_layer = next(m for m in model.modules() if isinstance(m, GradientCheckpointingLayer))
Contributor

Ah ok, so we are bound to the new gradient checkpointing then. Guess there will be a need to check that all models use this properly.

result.loss.backward()

non_zero_grads_normal = {n for n, p in trainable_params if p.grad.abs().sum() > 0}
assert non_zero_grads_normal
Contributor

Let's use

self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

instead of plain asserts. Depends on whether we move the test too, I guess.
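
Applied to the excerpt above, that would look something like this (sketch, assuming the check lives in the model tester where assertions go through self.parent):

    non_zero_grads_normal = {n for n, p in trainable_params if p.grad.abs().sum() > 0}
    self.parent.assertTrue(len(non_zero_grads_normal) > 0)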
