11 changes: 6 additions & 5 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -26,6 +26,7 @@
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -266,7 +267,7 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states


class GPTBigCodeBlock(nn.Module):
class GPTBigCodeBlock(GradientCheckpointingLayer):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@@ -291,9 +292,9 @@ def __init__(self, config, layer_idx=None):
def forward(
self,
hidden_states: Optional[tuple[torch.Tensor]],
encoder_hidden_states: Optional[torch.Tensor] = None,
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
Comment on lines +295 to -296
Contributor:
Let's not change the order here; we could break things for users. Rather, change the args/kwargs positions on the module call if necessary.

Contributor Author:
I'm not sure that this is possible. It is mandatory that we pass layer_past as a keyword argument, otherwise GradientCheckpointingLayer will not be able to remove it from the kwargs in case of gradient checkpointing. On the other hand, every input that may require gradients (hidden_states, encoder_hidden_states) must be passed as a positional argument for checkpoint() to work. Maybe I'm missing something, but I don't think we can bring those together without moving encoder_hidden_states up in the list.
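
To make the constraint concrete, here is a small self-contained sketch of the torch.utils.checkpoint behavior being described (an illustration only, not the actual GradientCheckpointingLayer code): tensors that should receive gradients go in positionally, while cache objects stay outside the recomputed call.

import torch
from torch.utils.checkpoint import checkpoint

def block_forward(hidden_states, encoder_hidden_states):
    # stand-in for a transformer block body
    return hidden_states * 2 + encoder_hidden_states

hidden_states = torch.randn(2, 4, requires_grad=True)
encoder_hidden_states = torch.randn(2, 4, requires_grad=True)

# Grad-carrying inputs are passed positionally so checkpoint() treats them as
# inputs to recompute against; a mutable Cache would be handled outside this call.
out = checkpoint(block_forward, hidden_states, encoder_hidden_states, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad is not None, encoder_hidden_states.grad is not None)  # True True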

Contributor:
I mean that the signature should stay the same, e.g. see

def forward(
self,
hidden_states: Optional[tuple[torch.Tensor]],
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,

The calls from the module above will need to be adjusted, like:

outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

Changing the signature is breaking a bit too much!

Contributor:
For visibility: as discussed internally, we need this change to be breaking.

encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
@@ -536,10 +537,10 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states,
past_key_values,
causal_mask,
hidden_states, # as a positional argument for gradient checkpointing
encoder_hidden_states, # as a positional argument for gradient checkpointing
layer_past=past_key_values, # as keyword argument so it can be removed by GradientCheckpointingLayer
attention_mask=causal_mask,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
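
For context on why the call above passes hidden_states and encoder_hidden_states positionally but layer_past and attention_mask as keywords, the dispatch in a GradientCheckpointingLayer-style base class could look roughly like this (a simplified sketch; the real implementation lives in modeling_layers.py and may handle caches, use_cache, and the checkpointing function differently):

from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointingLayerSketch(nn.Module):
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Assumption for this sketch: the cache kwarg is dropped because it
            # cannot be replayed during recomputation.
            kwargs.pop("layer_past", None)
            # Keyword args are bound outside checkpoint(); positional args (the
            # tensors that may require grad) flow through it.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)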
3 changes: 2 additions & 1 deletion src/transformers/models/swiftformer/modeling_swiftformer.py
@@ -21,6 +21,7 @@
from torch import nn

from ...activations import ACT2CLS
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
@@ -295,7 +296,7 @@ def forward(self, x):
return x


class SwiftFormerStage(nn.Module):
class SwiftFormerStage(GradientCheckpointingLayer):
"""
A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final
`SwiftFormerEncoderBlock`.
26 changes: 12 additions & 14 deletions src/transformers/models/xlstm/modeling_xlstm.py
@@ -22,17 +22,21 @@
from torch.nn import CrossEntropyLoss

from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available
from .configuration_xlstm import xLSTMConfig


if is_xlstm_available():
from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock
from xlstm.xlstm_large.model import mLSTMStateType, soft_cap
from xlstm.xlstm_large.model import mLSTMBlock, mLSTMStateType, soft_cap

external_xlstm = True

class xLSTMBlock(GradientCheckpointingLayer, mLSTMBlock):
pass

else:
from collections.abc import Callable
from functools import partial
@@ -1164,7 +1168,7 @@ def forward(
y = self.out_proj(h_out)
return y, state

class xLSTMBlock(nn.Module):
class xLSTMBlock(GradientCheckpointingLayer):
def __init__(self, config: xLSTMConfig):
super().__init__()
self.config = config
@@ -1457,17 +1461,11 @@ def forward(
else:
all_hidden_states = () if output_hidden_states else None
for layer_idx, xlstm_block in enumerate(self.blocks):
if self.gradient_checkpointing and self.training:
hidden_states, rnn_state = self._gradient_checkpointing_func(
xlstm_block.__call__,
hidden_states,
cache_params.rnn_state[layer_idx] if cache_params is not None else None,
)
else:
hidden_states, rnn_state = xlstm_block(
hidden_states,
state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
)
hidden_states, rnn_state = xlstm_block(
hidden_states,
cache_params.rnn_state[layer_idx] if cache_params is not None else None,
)

if cache_params:
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx]
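
The xLSTMBlock(GradientCheckpointingLayer, mLSTMBlock) definition above relies on Python's method resolution order so the checkpointing-aware __call__ wraps the external block's forward without modifying the xlstm package. A minimal illustration of the pattern (the class names below are stand-ins, not the real xlstm or transformers classes):

import torch
import torch.nn as nn

class CheckpointingMixin(nn.Module):
    # stands in for transformers' GradientCheckpointingLayer
    def __call__(self, *args, **kwargs):
        # a real mixin would reroute through torch.utils.checkpoint here
        return super().__call__(*args, **kwargs)

class ExternalBlock(nn.Module):
    # stands in for xlstm's mLSTMBlock
    def forward(self, x):
        return x + 1

class WrappedBlock(CheckpointingMixin, ExternalBlock):
    pass

# MRO: WrappedBlock -> CheckpointingMixin -> ExternalBlock -> nn.Module, so the
# mixin's __call__ runs first and still reaches ExternalBlock.forward.
print([c.__name__ for c in WrappedBlock.__mro__])
print(WrappedBlock()(torch.ones(1)))  # tensor([2.])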
42 changes: 15 additions & 27 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -32,6 +32,7 @@
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -639,7 +640,7 @@ def forward(
return outputs


class ZambaMambaDecoderLayer(nn.Module):
class ZambaMambaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: ZambaConfig, layer_idx: int):
super().__init__()
self.mamba = ZambaMambaMixer(config=config, layer_idx=layer_idx)
@@ -708,7 +709,7 @@ def forward(
return outputs


class ZambaHybridLayer(nn.Module):
class ZambaHybridLayer(GradientCheckpointingLayer):
def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
super().__init__()
self.shared_transf = shared_transf
@@ -942,31 +943,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
causal_mask=causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
layer_outputs = layer(
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]

if output_attentions:
45 changes: 16 additions & 29 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -34,6 +34,7 @@
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -1058,7 +1059,7 @@ def forward(
return outputs


class Zamba2MambaDecoderLayer(nn.Module):
class Zamba2MambaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Zamba2Config, layer_idx: int):
super().__init__()
self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
@@ -1127,7 +1128,7 @@ def forward(
return outputs


class Zamba2HybridLayer(nn.Module):
class Zamba2HybridLayer(GradientCheckpointingLayer):
def __init__(
self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
):
@@ -1344,33 +1345,19 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values,
output_attentions,
use_cache,
position_embeddings,
position_ids,
)
else:
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
causal_mask=causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
position_ids=position_ids,
)
layer_outputs = layer(
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
position_ids=position_ids,
)

hidden_states = layer_outputs[0]

if output_attentions:
40 changes: 13 additions & 27 deletions src/transformers/models/zamba2/modular_zamba2.py
@@ -1079,33 +1079,19 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values,
output_attentions,
use_cache,
position_embeddings,
position_ids,
)
else:
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
layer_idx=layer_idx,
attention_mask=attention_mask,
causal_mask=causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
position_ids=position_ids,
)
layer_outputs = layer(
hidden_states,
original_hidden_states,
layer_idx,
attention_mask,
causal_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
position_embeddings=position_embeddings,
position_ids=position_ids,
)

hidden_states = layer_outputs[0]

if output_attentions:
14 changes: 7 additions & 7 deletions tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -307,12 +307,16 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_ty
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
self,
config,
input_ids,
input_mask,
token_type_ids,
*args,
):
model = GPTBigCodeForCausalLM(config)
model.train()
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()

result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
@@ -463,10 +467,6 @@ def test_gpt_bigcode_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt_bigcode_for_token_classification(*config_and_inputs)

def test_gpt_bigcode_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

def test_gpt_bigcode_scale_attn_by_inverse_layer_idx(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(scale_attn_by_inverse_layer_idx=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
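
With the dedicated gradient-checkpointing test removed above, an equivalent manual check now relies on the layer-level machinery rather than a test-specific flag. A rough sketch (the tiny GPTBigCodeConfig values below are chosen only for illustration):

import torch
from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM

config = GPTBigCodeConfig(vocab_size=128, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPTBigCodeForCausalLM(config)
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
loss = model(input_ids, labels=input_ids).loss
loss.backward()  # each GPTBigCodeBlock's activations are recomputed here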