diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e1fc5784a067..f2a1d3ff009d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -698,6 +698,8 @@ title: T5 - local: model_doc/t5gemma title: T5Gemma + - local: model_doc/t5gemma2 + title: T5Gemma2 - local: model_doc/t5v1.1 title: T5v1.1 - local: model_doc/ul2 diff --git a/docs/source/en/model_doc/t5gemma2.md b/docs/source/en/model_doc/t5gemma2.md new file mode 100644 index 000000000000..7cf306069a7f --- /dev/null +++ b/docs/source/en/model_doc/t5gemma2.md @@ -0,0 +1,116 @@ + + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ +# T5Gemma 2 + +T5Gemma 2 is a family of pretrained encoder-decoder large language models with strong multilingual, multimodal and long-context capability, available in 270M-270M, 1B-1B and 4B-4B parameters. Following T5Gemma, it is built via model adaptation (based on Gemma 3) using UL2. The architecture is similar to T5Gemma and Gemma 3, enhanced with tied word embeddings and merged self- and cross-attention to save model parameters. + +> [!TIP] +> Click on the T5Gemma 2 models in the right sidebar for more examples of how to apply T5Gemma 2 to different language tasks. + +The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line. + + + + +```python +import torch +from transformers import pipeline + +generator = pipeline( + "image-text-to-text", + model="google/t5gemma-2-270m-270m", + dtype=torch.bfloat16, + device_map="auto", +) + +generator( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", + text=" in this image, there is", + generate_kwargs={"do_sample": False, "max_new_tokens": 50}, +) +``` + + + + +```python +import torch +import requests +from PIL import Image +from transformers import AutoProcessor, AutoModelForSeq2SeqLM + +processor = AutoProcessor.from_pretrained("google/t5gemma-2-270m-270m") +model = AutoModelForSeq2SeqLM.from_pretrained( + "google/t5gemma-2-270m-270m", + device_map="auto", + dtype=torch.bfloat16, +) + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" +image = Image.open(requests.get(url, stream=True).raw) +prompt = " in this image, there is" + +model_inputs = processor(text=prompt, images=image, return_tensors="pt") +generation = model.generate(**model_inputs, max_new_tokens=20, do_sample=False) +print(processor.decode(generation[0])) +``` + + + + +## T5Gemma2Config + +[[autodoc]] T5Gemma2Config + +## T5Gemma2TextConfig + +[[autodoc]] T5Gemma2TextConfig + +## T5Gemma2EncoderConfig + +[[autodoc]] T5Gemma2EncoderConfig + +## T5Gemma2DecoderConfig +[[autodoc]] T5Gemma2DecoderConfig + +## T5Gemma2Model + +[[autodoc]] T5Gemma2Model + - forward + +## T5Gemma2ForConditionalGeneration + +[[autodoc]] T5Gemma2ForConditionalGeneration + - forward + +## T5Gemma2ForSequenceClassification + +[[autodoc]] T5Gemma2ForSequenceClassification + - forward + +## T5Gemma2ForTokenClassification + +[[autodoc]] T5Gemma2ForTokenClassification + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 76b7a9a32ac6..31526419da08 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -339,6 +339,7 @@ from .switch_transformers import * from .t5 import * from .t5gemma import * + from .t5gemma2 import * from .table_transformer import * from .tapas import * from .textnet import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index cdb6cc834163..518f47cfffd2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -387,6 +387,7 @@ ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("t5gemma", "T5GemmaConfig"), + ("t5gemma2", "T5Gemma2Config"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), @@ -829,6 +830,7 @@ ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), ("t5gemma", "T5Gemma"), + ("t5gemma2", "T5Gemma2"), ("t5v1.1", 
"T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index f6f28ad04658..11677dbf51fd 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -187,6 +187,7 @@ ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("t5gemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")), ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")), ("timesformer", ("VideoMAEImageProcessor", None)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bedc66fc37bb..60fccf2efe3d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -374,6 +374,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("t5gemma", "T5GemmaModel"), + ("t5gemma2", "T5Gemma2Model"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), @@ -496,6 +497,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("unispeech", "UniSpeechForPreTraining"), ("unispeech-sat", "UniSpeechSatForPreTraining"), @@ -589,6 +591,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), ("whisper", "WhisperForConditionalGeneration"), @@ -1015,6 +1018,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), ("shieldgemma2", "Gemma3ForConditionalGeneration"), ("smolvlm", "SmolVLMForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), ("video_llama_3", "VideoLlama3ForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), @@ -1137,6 +1141,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("t5gemma", "T5GemmaForConditionalGeneration"), + ("t5gemma2", "T5Gemma2ForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("voxtral", "VoxtralForConditionalGeneration"), ] @@ -1260,6 +1265,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), ("t5gemma", "T5GemmaForSequenceClassification"), + ("t5gemma2", "T5Gemma2ForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), ("xlm", "XLMForSequenceClassification"), @@ -1458,6 +1464,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): 
("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), ("t5gemma", "T5GemmaForTokenClassification"), + ("t5gemma2", "T5Gemma2ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 584ac914323a..7cb263f27799 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -138,6 +138,7 @@ ("smolvlm", "SmolVLMProcessor"), ("speech_to_text", "Speech2TextProcessor"), ("speecht5", "SpeechT5Processor"), + ("t5gemma2", "Gemma3Processor"), ("trocr", "TrOCRProcessor"), ("tvp", "TvpProcessor"), ("udop", "UdopProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index be0a1f0dd754..b04cd376e14e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -718,6 +718,13 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "t5gemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("tapas", ("TapasTokenizer", None)), ("trocr", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)), ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/t5gemma/__init__.py b/src/transformers/models/t5gemma/__init__.py index aa8099e26782..0688bdb54cbe 100644 --- a/src/transformers/models/t5gemma/__init__.py +++ b/src/transformers/models/t5gemma/__init__.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: - from .configuration_encdecgemma2 import * - from .modeling_encdecgemma2 import * + from .configuration_t5gemma import * + from .modeling_t5gemma import * else: import sys diff --git a/src/transformers/models/t5gemma2/__init__.py b/src/transformers/models/t5gemma2/__init__.py new file mode 100644 index 000000000000..7d018bfe722a --- /dev/null +++ b/src/transformers/models/t5gemma2/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5gemma2 import * + from .modeling_t5gemma2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/t5gemma2/configuration_t5gemma2.py b/src/transformers/models/t5gemma2/configuration_t5gemma2.py new file mode 100644 index 000000000000..4ec073a9fc12 --- /dev/null +++ b/src/transformers/models/t5gemma2/configuration_t5gemma2.py @@ -0,0 +1,645 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Union + +from ...configuration_utils import PreTrainedConfig, layer_type_validation +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params +from ...utils import logging +from ..siglip import SiglipVisionConfig + + +logger = logging.get_logger(__name__) + + +class T5Gemma2TextConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2TextModel`]. It is used to instantiate an T5Gemma2Text + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the T5Gemma2Text-7B. + e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b) + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the T5Gemma2Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Gemma2TextModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. 
+ num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + In T5Gemma2Text, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + use_bidirectional_attention (`bool`, *optional*, defaults to `False`): + If True, the model will attend to all text tokens instead of using a causal mask. This does not change + behavior for vision tokens. 
+ + ```python + >>> from transformers import T5Gemma2TextModel, T5Gemma2TextConfig + >>> # Initializing a T5Gemma2Text t5gemma2_text-7b style configuration + >>> configuration = T5Gemma2TextConfig() + >>> # Initializing a model from the t5gemma2_text-7b style configuration + >>> model = T5Gemma2TextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "t5gemma2_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 262_208, + hidden_size: Optional[int] = 2304, + intermediate_size: Optional[int] = 9216, + num_hidden_layers: Optional[int] = 26, + num_attention_heads: Optional[int] = 8, + num_key_value_heads: Optional[int] = 4, + head_dim: Optional[int] = 256, + hidden_activation: Optional[str] = "gelu_pytorch_tanh", + max_position_embeddings: Optional[int] = 131_072, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + eos_token_id: Optional[int] = 1, + bos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = True, + attention_bias: Optional[bool] = False, + attention_dropout: Optional[float] = 0.0, + query_pre_attn_scalar: Optional[int] = 256, + sliding_window: Optional[int] = 4096, + layer_types: Optional[list[str]] = None, + final_logit_softcapping: Optional[float] = None, + attn_logit_softcapping: Optional[float] = None, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + use_bidirectional_attention: Optional[bool] = False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types + + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None: + if rope_parameters is None: + rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} + elif "full_attention" in rope_parameters: + rope_parameters["full_attention"].update(rope_scaling) + else: + 
rope_parameters.update(rope_scaling) + + self.rope_parameters = rope_parameters + self.use_bidirectional_attention = use_bidirectional_attention + if use_bidirectional_attention: + self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds + + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6) + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Validate the correctness of rotary position embeddings parameters + rope_theta = getattr(self, "rope_theta", 1_000_000.0) + rope_local_base_freq = getattr(self, "rope_local_base_freq", 10000.0) + standardize_rope_params( + self, rope_theta={"full_attention": rope_theta, "sliding_attention": rope_local_base_freq} + ) + rope_config_validation(self) + + +class T5Gemma2EncoderConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2EncoderForConditionalGeneration`]. It is used to instantiate an + T5Gemma2EncoderForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + text_config (`Union[T5Gemma2EncoderTextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + + Example: + + ```python + >>> from transformers import T5Gemma2EncoderForConditionalGeneration, T5Gemma2EncoderConfig, SiglipVisionConfig, T5Gemma2EncoderTextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a T5Gemma2Encoder Text config + >>> text_config = T5Gemma2EncoderTextConfig() + + >>> # Initializing a T5Gemma2Encoder gemma-3-4b style configuration + >>> configuration = T5Gemma2EncoderConfig(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = T5Gemma2EncoderTextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "t5gemma2_encoder" + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } + + sub_configs = { + "text_config": T5Gemma2TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Union[T5Gemma2TextConfig, dict[str, Any]]] = None, + vision_config: Optional[Union[SiglipVisionConfig, dict[str, Any]]] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = T5Gemma2TextConfig() + logger.info("text_config is None, using default T5Gemma2EncoderTextConfig text config.") + elif isinstance(text_config, dict): + text_config = T5Gemma2TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + elif vision_config is None: + vision_config = SiglipVisionConfig() + logger.info("vision_config is None, using default SiglipVisionConfig vision config.") + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +class T5Gemma2DecoderConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2DecoderModel`]. It is used to instantiate an T5Gemma2Decoder + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the T5Gemma2Decoder-7B. + e.g. [google/t5gemma2_text-7b](https://huggingface.co/google/t5gemma2_text-7b) + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the T5Gemma2Decoder model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Gemma2DecoderModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. 
+ num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + In T5Gemma2Decoder, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + use_bidirectional_attention (`bool`, *optional*, defaults to `False`): + If True, the model will attend to all text tokens instead of using a causal mask. This does not change + behavior for vision tokens. 
+ + ```python + >>> from transformers import T5Gemma2DecoderModel, T5Gemma2DecoderConfig + >>> # Initializing a T5Gemma2Decoder t5gemma2_text-7b style configuration + >>> configuration = T5Gemma2DecoderConfig() + >>> # Initializing a model from the t5gemma2_text-7b style configuration + >>> model = T5Gemma2DecoderModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "t5gemma2_decoder" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 262_208, + hidden_size: Optional[int] = 2304, + intermediate_size: Optional[int] = 9216, + num_hidden_layers: Optional[int] = 26, + num_attention_heads: Optional[int] = 8, + num_key_value_heads: Optional[int] = 4, + head_dim: Optional[int] = 256, + hidden_activation: Optional[str] = "gelu_pytorch_tanh", + max_position_embeddings: Optional[int] = 131_072, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + eos_token_id: Optional[int] = 1, + bos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = True, + attention_bias: Optional[bool] = False, + attention_dropout: Optional[float] = 0.0, + query_pre_attn_scalar: Optional[int] = 256, + sliding_window: Optional[int] = 4096, + layer_types: Optional[list[str]] = None, + final_logit_softcapping: Optional[float] = None, + attn_logit_softcapping: Optional[float] = None, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + use_bidirectional_attention: Optional[bool] = False, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types + + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + if (rope_scaling := kwargs.pop("rope_scaling", None)) is not None: + if rope_parameters is None: + rope_parameters = {"sliding_attention": {"rope_type": "default"}, "full_attention": rope_scaling} + elif "full_attention" in rope_parameters: + rope_parameters["full_attention"].update(rope_scaling) + else: + 
rope_parameters.update(rope_scaling) + + self.rope_parameters = rope_parameters + self.use_bidirectional_attention = use_bidirectional_attention + if use_bidirectional_attention: + self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds + + # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub + self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6) + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Validate the correctness of rotary position embeddings parameters + rope_theta = getattr(self, "rope_theta", 1_000_000.0) + rope_local_base_freq = getattr(self, "rope_local_base_freq", 10000.0) + standardize_rope_params( + self, rope_theta={"full_attention": rope_theta, "sliding_attention": rope_local_base_freq} + ) + rope_config_validation(self) + + +class T5Gemma2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2Model`]. It is used to instantiate an T5Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma3 encoder-decoder model. + e.g. [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m) + Configuration objects inherit from [PreTrainedConfig] and can be used to control the model outputs. Read the + documentation from [PreTrainedConfig] for more information. + + Args: + encoder (`Union[T5Gemma2EncoderConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5Gemma2DecoderConfig, dict]`, optional, *optional*): + Configuration for the decoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + image_token_index (`int`, *optional*, defaults to 256001): + The image token index to encode the image prompt. Defaults to 256001, which is right after the eoi_token_index. + Note this is different from Gemma 3. 
+ ```python + >>> from transformers import T5Gemma2Config, T5Gemma2Model + >>> t5gemma2_config = T5Gemma2Config.from_pretrained("google/t5gemma-270m-270m") + >>> model = T5Gemma2Model(t5gemma2_config) + ``` + """ + + model_type = "t5gemma2" + keys_to_ignore_at_inference = ["past_key_values"] + + sub_configs = { + "encoder": T5Gemma2EncoderConfig, + "decoder": T5Gemma2DecoderConfig, + } + + attribute_map = { + "image_token_id": "image_token_index", + "eoi_token_id": "eoi_token_index", + } + + def __init__( + self, + encoder: Optional[Union[T5Gemma2EncoderConfig, dict[str, Any]]] = None, + decoder: Optional[Union[T5Gemma2DecoderConfig, dict[str, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + classifier_dropout_rate: float = 0.0, + initializer_range: float = 0.02, + image_token_index: int = 256_001, + **kwargs, + ): + if isinstance(encoder, dict): + encoder = T5Gemma2EncoderConfig(**encoder) + elif encoder is None: + encoder = T5Gemma2EncoderConfig() + logger.info("encoder is None, using default T5Gemma2EncoderConfig encoder config.") + else: + if not isinstance(encoder, T5Gemma2EncoderConfig): + raise ValueError(f"{type(encoder)} is not supported.") + + if isinstance(decoder, dict): + decoder = T5Gemma2DecoderConfig(**decoder) + elif decoder is None: + decoder = T5Gemma2DecoderConfig() + logger.info("decoder is None, using default T5Gemma2DecoderConfig decoder config.") + else: + if not isinstance(decoder, T5Gemma2DecoderConfig): + raise ValueError(f"{type(decoder)} is not supported.") + + if encoder.text_config.hidden_size != decoder.hidden_size: + raise ValueError( + "Imbalanced encoder-decoder is not supported in T5Gemma2: " + f"encoder ({encoder.text_config.hidden_size}) vs decoder ({decoder.hidden_size})." + ) + + if not is_encoder_decoder: + raise ValueError("T5Gemma2Model only support encoder-decoder modeling.") + + if encoder.text_config.vocab_size != decoder.vocab_size: + raise ValueError( + "Imbalanced encoder-decoder vocabulary size is not supported in T5Gemma2: " + f"encoder ({encoder.text_config.vocab_size}) vs decoder ({decoder.vocab_size})." + ) + + # Encoder. + encoder.text_config.dropout_rate = dropout_rate + encoder.text_config.attention_dropout = attention_dropout + encoder.vision_config.attention_dropout = attention_dropout + encoder.image_token_index = image_token_index + self.encoder = encoder + + # Decoder. 
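+        # Mirror the encoder block above: the decoder sub-config also picks up the
+        # shared dropout settings from the top-level config.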
+ decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + self.decoder = decoder + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id", "vocab_size"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.is_encoder_decoder = is_encoder_decoder + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + self.initializer_range = initializer_range + self.eoi_token_index = encoder.eoi_token_index + self.image_token_index = image_token_index + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation_internal", + "dropout_rate", + "attention_dropout", + "vocab_size", + "dtype", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder.text_config, key, value) + setattr(self.encoder.vision_config, key, value) + setattr(self.decoder, key, value) + setattr(self.encoder, key, value) + super().__setattr__(key, value) + + +__all__ = ["T5Gemma2Config", "T5Gemma2TextConfig", "T5Gemma2EncoderConfig", "T5Gemma2DecoderConfig"] diff --git a/src/transformers/models/t5gemma2/modeling_t5gemma2.py b/src/transformers/models/t5gemma2/modeling_t5gemma2.py new file mode 100644 index 000000000000..d38437134d64 --- /dev/null +++ b/src/transformers/models/t5gemma2/modeling_t5gemma2.py @@ -0,0 +1,1551 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma2/modular_t5gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from collections.abc import Callable +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationConfig, GenerationMixin, GenerationMode +from ...masking_utils import create_bidirectional_mask, create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModel +from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig + + +class T5Gemma2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst T5Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class T5Gemma2MLP(nn.Module): + def __init__(self, config: T5Gemma2TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5Gemma2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: T5Gemma2TextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.layer_types = list(set(config.layer_types)) + self.rope_type = {} + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + self.rope_type[layer_type] = rope_params["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq) + 
setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: Optional[T5Gemma2TextConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + base = config.rope_parameters[layer_type]["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class T5Gemma2SelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5Gemma2TextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = not self.config.use_bidirectional_attention + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + 
self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.is_sliding = self.layer_type == "sliding_attention" + + self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class T5Gemma2MergedAttention(nn.Module): + """Merged self-attention and cross-attention for decoder.""" + + def __init__(self, config: T5Gemma2TextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = not self.config.use_bidirectional_attention + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = 
self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + self.is_sliding = self.layer_type == "sliding_attention" + + self.q_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = T5Gemma2RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + # decoder self-attention inputs + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + merged_attention_mask: Optional[torch.Tensor], + # cross-attention inputs + encoder_hidden_states: torch.Tensor, + # cache inputs + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + # others + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # attention shapes. + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_input_shape = encoder_hidden_states.shape[:-1] + cross_hidden_shape = (*cross_input_shape, -1, self.head_dim) + + # self-attention. + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # self-attention. + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + self_attention_cache = past_key_values.self_attention_cache + key_states, value_states = self_attention_cache.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # cross-attention. + is_updated = past_key_values.is_updated.get(self.layer_idx) + cross_attention_cache = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + + cross_key_states = self.k_norm(cross_key_states) + + if past_key_values is not None: + cross_key_states, cross_value_states = cross_attention_cache.update( + cross_key_states, cross_value_states, self.layer_idx + ) + past_key_values.is_updated[self.layer_idx] = True + else: + cross_key_states = cross_attention_cache.layers[self.layer_idx].keys + cross_value_states = cross_attention_cache.layers[self.layer_idx].values + + # merged attention. 
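+        # Self-attention keys/values (decoder) and cross-attention keys/values
+        # (encoder) are concatenated along the sequence dimension, so a single
+        # attention call covers both. `merged_attention_mask` is expected to mask
+        # the decoder block and the appended encoder block accordingly; the
+        # resulting weights are split back into self/cross parts below.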
+ query_states = query_states + cross_key_size = cross_input_shape[1] + key_states = torch.cat([key_states, cross_key_states], dim=2) + value_states = torch.cat([value_states, cross_value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + merged_attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # decompose merged attention weights into self & cross attention weights + if attn_weights is not None: + self_attn_weights = attn_weights[..., :-cross_key_size] + cross_attn_weights = attn_weights[..., -cross_key_size:] + else: + self_attn_weights, cross_attn_weights = None, None + return attn_output, self_attn_weights, cross_attn_weights + + +class T5Gemma2EncoderLayer(GradientCheckpointingLayer): + """Encoder sub-layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + + self.self_attn = T5Gemma2SelfAttention( + config=config, + layer_idx=layer_idx, + ) + self.pre_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.mlp = T5Gemma2MLP(config) + self.pre_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor,]: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2DecoderLayer(T5Gemma2EncoderLayer): + """Decoder sub-layer: merged attention instead of vanilla self-attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + # replace vanilla self-attention with merged attention to support joint cross-attention. 
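+ # The merged module shares its k/v projections between the two streams: keys/values are computed for
+ # both the decoder states and `encoder_hidden_states`, while queries come from the decoder only.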
+ self.self_attn = T5Gemma2MergedAttention( + config=config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + merged_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + + hidden_states, _, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + merged_attention_mask=merged_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2LMHead(nn.Module): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = nn.Linear(hidden_size, vocab_size, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.out_proj(hidden_states) + return logits + + +class T5Gemma2ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = nn.Dropout(p=classifier_dropout_rate) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5Gemma2MultiModalProjector(nn.Module): + def __init__(self, config: T5Gemma2EncoderConfig): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = T5Gemma2RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + 
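# Normalize the pooled patch tokens and project them into the text embedding space with a learned linear map.
+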
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +class T5Gemma2TextScaledWordEmbedding(nn.Embedding): + """T5Gemma2 Embedding: override to add eoi token embedding separately.""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: float = 1.0, + eoi_token_index: int = 256_000, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + self.eoi_token_index = eoi_token_index + self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim)) + + def forward(self, input_ids: torch.Tensor): + input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype) + return input_embeddings + + +@auto_docstring +class T5Gemma2PreTrainedModel(PreTrainedModel): + config: T5Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "T5Gemma2EncoderLayer", + "T5Gemma2DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer], + "attentions": [ + OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"), + ], + } + input_modalities = ["image", "text"] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, T5Gemma2MultiModalProjector): + init.zeros_(module.mm_input_projection_weight) + elif isinstance(module, T5Gemma2TextScaledWordEmbedding): + init.zeros_(module.eoi_embedding) + elif isinstance(module, T5Gemma2ClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + init.zeros_(module.out_proj.bias) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + init.zeros_(module.weight) + + def prepare_decoder_input_ids_from_labels(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_config = self.config.decoder + decoder_start_token_id = decoder_config.bos_token_id + pad_token_id = decoder_config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. 
") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable: + """ + This creates uni/bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if is_causal: + left_window_size, right_window_size = sliding_window, 0 + else: + left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1) + + dist = q_idx - kv_idx + left_mask = (dist >= 0) & (dist < left_window_size) + right_mask = (dist < 0) & (-dist < right_window_size) + return left_mask | right_mask + + return inner_mask + + +class T5Gemma2Encoder(T5Gemma2PreTrainedModel): + config: T5Gemma2EncoderConfig + _can_record_outputs = { + "attentions": T5Gemma2SelfAttention, + "hidden_states": T5Gemma2EncoderLayer, + } + + def __init__( + self, + config: T5Gemma2EncoderConfig, + eoi_token_index: int = 256_000, + ): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.text_config.vocab_size + + vision_config = config.vision_config + text_config = config.text_config + + # setup vision tower + self.vision_tower = AutoModel.from_config(config=vision_config) + self.multi_modal_projector = T5Gemma2MultiModalProjector(config) + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + text_config.vocab_size, + text_config.hidden_size, + self.padding_idx, + embed_scale=text_config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2EncoderLayer(text_config, layer_idx) for layer_idx in range(text_config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(text_config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(text_config) + + self.text_config = text_config + + # Initialize weights and apply final processing + self.post_init() + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Convert pixel image to image features via the encoder and projector.""" + # pixel_values: (batch_size, channels, height, width) + # image_features: Image feature tensor of shape (num_images, image_length, embed_dim). + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def get_image_placeholder_mask( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.FloatTensor], + image_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + image_token_id = self.config.image_token_id + if input_ids is None: + if inputs_embeds is None: + raise ValueError("Either `input_ids` or `inputs_embeds` has to be provided.") + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def preprocess_image_features( + self, + pixel_values: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Convert pixel images to image features and merge into input embeds.""" + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask = self.get_image_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + return inputs_embeds + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + # Unused for processor compatibility kept in signature. 
+ token_type_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + del token_type_ids + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + inputs_embeds = self.preprocess_image_features( + pixel_values, input_ids=input_ids, inputs_embeds=inputs_embeds + ) + + if position_ids is None: + position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + } + self_attn_mask_mapping = { + "full_attention": create_bidirectional_mask(**mask_kwargs), + "sliding_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=sliding_window_mask_function(self.text_config.sliding_window, is_causal=False), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.text_config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.text_config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +def bidirectional_mask_function(attention_mask: Optional[torch.Tensor]) -> Callable: + """ + This creates bidirectional attention mask. 
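+ It is used for the cross-attention half of the merged decoder mask: decoder queries may attend to every
+ encoder position that `attention_mask` marks as valid, regardless of their own position.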
+ """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if attention_mask is None: + return torch.ones((), dtype=torch.bool) + return attention_mask[batch_idx, kv_idx].to(torch.bool) + + return inner_mask + + +class T5Gemma2Decoder(T5Gemma2PreTrainedModel): + config: T5Gemma2DecoderConfig + _can_record_outputs = { + "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1), + "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2), + "hidden_states": T5Gemma2DecoderLayer, + } + + def __init__(self, config: T5Gemma2DecoderConfig, eoi_token_index: int = 256_000): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + config.pad_token_id, + embed_scale=config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(config) + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache()) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, + } + # this masking function did nothing to masking but forces `allow_is_causal_skip` to be False + # as we always need a mask during decoding for merged attention. 
+ mask_kwargs["and_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + "position_ids": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + merged_attn_mask_mapping = { + "full_attention": torch.cat( + [self_attn_mask_mapping["full_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1 + ), + "sliding_attention": torch.cat( + [self_attn_mask_mapping["sliding_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1 + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + merged_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + use_cache, + cache_position, + encoder_hidden_states, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class T5Gemma2Model(T5Gemma2PreTrainedModel): + _tied_weights_keys = { + "decoder.embed_tokens.weight": "encoder.embed_tokens.weight", + "decoder.embed_tokens.eoi_embedding": "encoder.embed_tokens.eoi_embedding", + } + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + # setup encoder and decoder + self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index) + self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + """ + # encoder + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + return_dict=True, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # decoder + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + return_dict=True, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin): + _tied_weights_keys = { + "lm_head.out_proj.weight": "model.encoder.embed_tokens.weight", + } + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + self.model = T5Gemma2Model(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.get_encoder().get_image_features(pixel_values) + + @property + def vision_tower(self): + return self.get_encoder().vision_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + 
labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + decoder_config = self.config.decoder + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: dict, + generation_mode: GenerationMode, + batch_size: int, + max_cache_length: int, + ) -> bool: + """Override cache preparation to support T5Gemma2-specific EncoderDecoder Cache.""" + + # Build cache and past_key_values structure first and then override as needed. + super()._prepare_cache_for_generation( + generation_config, + model_kwargs, + generation_mode, + batch_size, + max_cache_length, + ) + + # If use_cache is False, do not prepare the cache. 
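+ # (The rest of this method rebuilds the cache so the cross-attention part always spans the full
+ # encoder sequence; sliding windows only apply to decoder self-attention.)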
+ if generation_config.use_cache is False: + return + + cache_implementation = generation_config.cache_implementation + if cache_implementation is None: + offload_cache = False + else: + offload_cache = "offloaded" in generation_config.cache_implementation + + # Main change: use full cache for cross-attention. + cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True)) + + # cross-attention does not use sliding window + del cross_attn_config.sliding_window + del cross_attn_config.layer_types + + cross_attn_cache_kwargs = { + "config": cross_attn_config, + "offloading": offload_cache, + } + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is not None: + if not isinstance(past_key_values, EncoderDecoderCache): + raise ValueError( + "The `past_key_values` in `model_kwargs` must be of type `EncoderDecoderCache` for T5Gemma2 model." + ) + + # Cache already established, no need to re-initialize. + if len(past_key_values.is_updated) > 0 and past_key_values.is_updated.get(0): + return + + cross_attn_cls = type(past_key_values.cross_attention_cache) + if cross_attn_cls == StaticCache: + cross_attn_cache_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + # Update cross-attention cache only (switch from sliding_window to full). + past_key_values.cross_attention_cache = cross_attn_cls(**cross_attn_cache_kwargs) + else: + # Initialize new cache. + model_kwargs["past_key_values"] = EncoderDecoderCache( + DynamicCache( + **{ + "config": self.config.get_text_config(decoder=True), + "offloading": offload_cache, + } + ), # self-attention cache + DynamicCache(), # cross-attention cache + ) + + if hasattr(self, "_cache") and self._cache is not None: + if not isinstance(self._cache, EncoderDecoderCache): + raise ValueError("The internal cache must be of type `EncoderDecoderCache` for T5Gemma2 model.") + + self._cache = model_kwargs["past_key_values"] + + +@auto_docstring +class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel): + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + batch_size = input_ids.shape[0] + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel): + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = 
None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5Gemma2ForConditionalGeneration", + "T5Gemma2Model", + "T5Gemma2PreTrainedModel", + "T5Gemma2ForSequenceClassification", + "T5Gemma2ForTokenClassification", +] diff --git a/src/transformers/models/t5gemma2/modular_t5gemma2.py b/src/transformers/models/t5gemma2/modular_t5gemma2.py new file mode 100644 index 000000000000..3957d65823b8 --- /dev/null +++ b/src/transformers/models/t5gemma2/modular_t5gemma2.py @@ -0,0 +1,1335 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from collections.abc import Callable +from typing import Any, Optional, Union + +import torch +import torch.nn as nn + +from ... 
import initialization as init +from ...cache_utils import DynamicCache, EncoderDecoderCache, StaticCache +from ...configuration_utils import PreTrainedConfig +from ...generation import GenerationConfig, GenerationMixin, GenerationMode +from ...masking_utils import create_bidirectional_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from ...utils.generic import OutputRecorder, check_model_inputs +from ..auto import AutoModel +from ..gemma3.configuration_gemma3 import Gemma3Config, Gemma3TextConfig +from ..gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3MLP, + Gemma3MultiModalProjector, + Gemma3PreTrainedModel, + Gemma3RMSNorm, + Gemma3RotaryEmbedding, + Gemma3TextScaledWordEmbedding, + apply_rotary_pos_emb, + create_causal_mask, + create_sliding_window_causal_mask, + eager_attention_forward, +) +from ..siglip import SiglipVisionConfig +from ..t5gemma.modeling_t5gemma import ( + T5GemmaClassificationHead, + T5GemmaEncoderLayer, + T5GemmaLMHead, + bidirectional_mask_function, +) + + +logger = logging.get_logger(__name__) + + +class T5Gemma2TextConfig(Gemma3TextConfig): + model_type = "t5gemma2_text" + + +class T5Gemma2EncoderConfig(Gemma3Config): + model_type = "t5gemma2_encoder" + + sub_configs = { + "text_config": T5Gemma2TextConfig, + "vision_config": SiglipVisionConfig, + } + + +class T5Gemma2DecoderConfig(Gemma3TextConfig): + model_type = "t5gemma2_decoder" + + +class T5Gemma2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5Gemma2Model`]. It is used to instantiate an T5Gemma2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma3 encoder-decoder model. + e.g. [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m) + Configuration objects inherit from [PreTrainedConfig] and can be used to control the model outputs. Read the + documentation from [PreTrainedConfig] for more information. + + Args: + encoder (`Union[T5Gemma2EncoderConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5Gemma2DecoderConfig, dict]`, optional, *optional*): + Configuration for the decoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + image_token_index (`int`, *optional*, defaults to 256001): + The image token index to encode the image prompt. Defaults to 256001, which is right after the eoi_token_index. + Note this is different from Gemma 3. 
+ ```python
+ >>> from transformers import T5Gemma2Config, T5Gemma2Model
+ >>> t5gemma2_config = T5Gemma2Config.from_pretrained("google/t5gemma-2-270m-270m")
+ >>> model = T5Gemma2Model(t5gemma2_config)
+ ```
+ """
+
+ model_type = "t5gemma2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ sub_configs = {
+ "encoder": T5Gemma2EncoderConfig,
+ "decoder": T5Gemma2DecoderConfig,
+ }
+
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ "eoi_token_id": "eoi_token_index",
+ }
+
+ def __init__(
+ self,
+ encoder: Optional[Union[T5Gemma2EncoderConfig, dict[str, Any]]] = None,
+ decoder: Optional[Union[T5Gemma2DecoderConfig, dict[str, Any]]] = None,
+ is_encoder_decoder: bool = True,
+ dropout_rate: float = 0.0,
+ attention_dropout: float = 0.0,
+ classifier_dropout_rate: float = 0.0,
+ initializer_range: float = 0.02,
+ image_token_index: int = 256_001,
+ **kwargs,
+ ):
+ if isinstance(encoder, dict):
+ encoder = T5Gemma2EncoderConfig(**encoder)
+ elif encoder is None:
+ encoder = T5Gemma2EncoderConfig()
+ logger.info("encoder is None, using default T5Gemma2EncoderConfig encoder config.")
+ else:
+ if not isinstance(encoder, T5Gemma2EncoderConfig):
+ raise ValueError(f"{type(encoder)} is not supported.")
+
+ if isinstance(decoder, dict):
+ decoder = T5Gemma2DecoderConfig(**decoder)
+ elif decoder is None:
+ decoder = T5Gemma2DecoderConfig()
+ logger.info("decoder is None, using default T5Gemma2DecoderConfig decoder config.")
+ else:
+ if not isinstance(decoder, T5Gemma2DecoderConfig):
+ raise ValueError(f"{type(decoder)} is not supported.")
+
+ if encoder.text_config.hidden_size != decoder.hidden_size:
+ raise ValueError(
+ "Imbalanced encoder-decoder is not supported in T5Gemma2: "
+ f"encoder ({encoder.text_config.hidden_size}) vs decoder ({decoder.hidden_size})."
+ )
+
+ if not is_encoder_decoder:
+ raise ValueError("T5Gemma2Model only supports encoder-decoder modeling.")
+
+ if encoder.text_config.vocab_size != decoder.vocab_size:
+ raise ValueError(
+ "Imbalanced encoder-decoder vocabulary size is not supported in T5Gemma2: "
+ f"encoder ({encoder.text_config.vocab_size}) vs decoder ({decoder.vocab_size})."
+ )
+
+ # Encoder.
+ encoder.text_config.dropout_rate = dropout_rate
+ encoder.text_config.attention_dropout = attention_dropout
+ encoder.vision_config.attention_dropout = attention_dropout
+ encoder.image_token_index = image_token_index
+ self.encoder = encoder
+
+ # Decoder.
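+ # Propagate the shared dropout settings into the decoder sub-config, mirroring the encoder block above.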
+ decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + self.decoder = decoder + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id", "vocab_size"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.is_encoder_decoder = is_encoder_decoder + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + self.initializer_range = initializer_range + self.eoi_token_index = encoder.eoi_token_index + self.image_token_index = image_token_index + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation_internal", + "dropout_rate", + "attention_dropout", + "vocab_size", + "dtype", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder.text_config, key, value) + setattr(self.encoder.vision_config, key, value) + setattr(self.decoder, key, value) + setattr(self.encoder, key, value) + super().__setattr__(key, value) + + +class T5Gemma2RMSNorm(Gemma3RMSNorm): + pass + + +class T5Gemma2MLP(Gemma3MLP): + def __init__(self, config: T5Gemma2TextConfig): + super().__init__(config) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5Gemma2RotaryEmbedding(Gemma3RotaryEmbedding): + def __init__(self, config: T5Gemma2TextConfig, device=None): + super().__init__(config, device) + + @staticmethod + def compute_default_rope_parameters( + config: Optional[T5Gemma2TextConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple["torch.Tensor", float]: + return super().compute_default_rope_parameters(config, device, seq_len, layer_type) + + +class T5Gemma2SelfAttention(Gemma3Attention): + def __init__(self, config: T5Gemma2TextConfig, layer_idx: int): + super().__init__(config, layer_idx) + + +class T5Gemma2MergedAttention(Gemma3Attention): + """Merged self-attention and cross-attention for decoder.""" + + def __init__(self, config: T5Gemma2TextConfig, layer_idx: int): + super().__init__(config, layer_idx) + + def forward( + self, + # decoder self-attention inputs + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + merged_attention_mask: Optional[torch.Tensor], + # cross-attention inputs + encoder_hidden_states: torch.Tensor, + # cache inputs + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + # others + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + # attention shapes. + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_input_shape = encoder_hidden_states.shape[:-1] + cross_hidden_shape = (*cross_input_shape, -1, self.head_dim) + + # self-attention. 
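+ # Decoder-side projections: per-head RMS norm on queries/keys, then RoPE, then (when a cache is
+ # present) the self-attention cache update.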
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # self-attention. + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + self_attention_cache = past_key_values.self_attention_cache + key_states, value_states = self_attention_cache.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # cross-attention. + is_updated = past_key_values.is_updated.get(self.layer_idx) + cross_attention_cache = past_key_values.cross_attention_cache + + if past_key_values is None or not is_updated: + cross_key_states = self.k_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + cross_value_states = self.v_proj(encoder_hidden_states).view(cross_hidden_shape).transpose(1, 2) + + cross_key_states = self.k_norm(cross_key_states) + + if past_key_values is not None: + cross_key_states, cross_value_states = cross_attention_cache.update( + cross_key_states, cross_value_states, self.layer_idx + ) + past_key_values.is_updated[self.layer_idx] = True + else: + cross_key_states = cross_attention_cache.layers[self.layer_idx].keys + cross_value_states = cross_attention_cache.layers[self.layer_idx].values + + # merged attention. + query_states = query_states + cross_key_size = cross_input_shape[1] + key_states = torch.cat([key_states, cross_key_states], dim=2) + value_states = torch.cat([value_states, cross_value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + merged_attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # decompose merged attention weights into self & cross attention weights + if attn_weights is not None: + self_attn_weights = attn_weights[..., :-cross_key_size] + cross_attn_weights = attn_weights[..., -cross_key_size:] + else: + self_attn_weights, cross_attn_weights = None, None + return attn_output, self_attn_weights, cross_attn_weights + + +def sliding_window_mask_function(sliding_window: int, is_causal=True) -> Callable: + """ + This creates uni/bidirectional attention mask with sliding window. 
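+ For example, with `sliding_window=4` the causal variant lets position q attend to keys q-3..q, while the
+ bidirectional variant (used for encoder layers) lets it attend to keys q-1..q+2.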
+ """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + if is_causal: + left_window_size, right_window_size = sliding_window, 0 + else: + left_window_size, right_window_size = ((sliding_window + 1) // 2, (sliding_window) // 2 + 1) + + dist = q_idx - kv_idx + left_mask = (dist >= 0) & (dist < left_window_size) + right_mask = (dist < 0) & (-dist < right_window_size) + return left_mask | right_mask + + return inner_mask + + +class T5Gemma2EncoderLayer(T5GemmaEncoderLayer): + pass + + +class T5Gemma2DecoderLayer(T5Gemma2EncoderLayer): + """Decoder sub-layer: merged attention instead of vanilla self-attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + + # replace vanilla self-attention with merged attention to support joint cross-attention. + self.self_attn = T5Gemma2MergedAttention( + config=config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + merged_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + hidden_states = self.pre_self_attn_layernorm(hidden_states) + + hidden_states, _, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + merged_attention_mask=merged_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class T5Gemma2LMHead(T5GemmaLMHead): + pass + + +class T5Gemma2ClassificationHead(T5GemmaClassificationHead): + pass + + +class T5Gemma2MultiModalProjector(Gemma3MultiModalProjector): + def __init__(self, config: T5Gemma2EncoderConfig): + super().__init__(config) + + +class T5Gemma2TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): + """T5Gemma2 Embedding: override to add eoi token embedding separately.""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: float = 1.0, + eoi_token_index: int = 256_000, + ): + super().__init__(num_embeddings, embedding_dim, padding_idx, embed_scale) + self.eoi_token_index = eoi_token_index + self.eoi_embedding = nn.Parameter(torch.zeros(self.embedding_dim)) + + def forward(self, input_ids: torch.Tensor): + input_embeddings = super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + input_embeddings[input_ids == self.eoi_token_index] = self.eoi_embedding.to(input_embeddings.dtype) + return input_embeddings + + +@auto_docstring +class T5Gemma2PreTrainedModel(Gemma3PreTrainedModel): + config: T5Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "T5Gemma2EncoderLayer", + "T5Gemma2DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + 
"SiglipMultiheadAttentionPoolingHead", + ] + _can_record_outputs = { + "hidden_states": [T5Gemma2EncoderLayer, T5Gemma2DecoderLayer], + "attentions": [ + OutputRecorder(T5Gemma2SelfAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=1, layer_name="self_attn"), + OutputRecorder(T5Gemma2MergedAttention, index=2, layer_name="cross_attn"), + ], + } + + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + if isinstance(module, T5Gemma2MultiModalProjector): + init.zeros_(module.mm_input_projection_weight) + elif isinstance(module, T5Gemma2TextScaledWordEmbedding): + init.zeros_(module.eoi_embedding) + elif isinstance(module, T5Gemma2ClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + init.normal_(module.out_proj.weight, mean=0.0, std=self.config.initializer_range * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + init.zeros_(module.out_proj.bias) + # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) + elif "RMSNorm" in module.__class__.__name__: + init.zeros_(module.weight) + + def prepare_decoder_input_ids_from_labels(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_config = self.config.decoder + decoder_start_token_id = decoder_config.bos_token_id + pad_token_id = decoder_config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? 
+ # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Gemma2Encoder(T5Gemma2PreTrainedModel): + config: T5Gemma2EncoderConfig + _can_record_outputs = { + "attentions": T5Gemma2SelfAttention, + "hidden_states": T5Gemma2EncoderLayer, + } + + def __init__( + self, + config: T5Gemma2EncoderConfig, + eoi_token_index: int = 256_000, + ): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.text_config.vocab_size + + vision_config = config.vision_config + text_config = config.text_config + + # setup vision tower + self.vision_tower = AutoModel.from_config(config=vision_config) + self.multi_modal_projector = T5Gemma2MultiModalProjector(config) + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + text_config.vocab_size, + text_config.hidden_size, + self.padding_idx, + embed_scale=text_config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2EncoderLayer(text_config, layer_idx) for layer_idx in range(text_config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(text_config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(text_config) + + self.text_config = text_config + + # Initialize weights and apply final processing + self.post_init() + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Convert pixel image to image features via the encoder and projector.""" + # pixel_values: (batch_size, channels, height, width) + # image_features: Image feature tensor of shape (num_images, image_length, embed_dim). + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def get_image_placeholder_mask( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.FloatTensor], + image_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
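+
+        For example (illustrative): if `image_features` has shape `(num_images, mm_tokens_per_image, embed_dim)`,
+        the inputs must contain exactly `num_images * mm_tokens_per_image` image placeholder tokens.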
+ """ + image_token_id = self.config.image_token_id + if input_ids is None: + if inputs_embeds is None: + raise ValueError("Either `input_ids` or `inputs_embeds` has to be provided.") + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def preprocess_image_features( + self, + pixel_values: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Convert pixel images to image features and merge into input embeds.""" + image_features = self.get_image_features(pixel_values) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask = self.get_image_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + return inputs_embeds + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + # Unused for processor compatibility kept in signature. 
+ token_type_ids: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + del token_type_ids + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # As we want to pass `past_key_values=None` explicitly everywhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + inputs_embeds = self.preprocess_image_features( + pixel_values, input_ids=input_ids, inputs_embeds=inputs_embeds + ) + + if position_ids is None: + position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + } + self_attn_mask_mapping = { + "full_attention": create_bidirectional_mask(**mask_kwargs), + "sliding_attention": create_bidirectional_mask( + **mask_kwargs, + and_mask_function=sliding_window_mask_function(self.text_config.sliding_window, is_causal=False), + ), + } + + # input layer + hidden_states = inputs_embeds + + # global and local position embeddings + position_embeddings = {} + for layer_type in self.text_config.layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # dropout + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.text_config.num_hidden_layers]: + hidden_states = layer_module( + hidden_states, + position_embeddings[layer_module.attention_type], + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class T5Gemma2Decoder(T5Gemma2PreTrainedModel): + config: T5Gemma2DecoderConfig + _can_record_outputs = { + "attentions": OutputRecorder(T5Gemma2MergedAttention, index=1), + "cross_attentions": OutputRecorder(T5Gemma2MergedAttention, index=2), + "hidden_states": T5Gemma2DecoderLayer, + } + + def __init__(self, config: T5Gemma2DecoderConfig, eoi_token_index: int = 256_000): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = T5Gemma2TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + config.pad_token_id, + embed_scale=config.hidden_size**0.5, + eoi_token_index=eoi_token_index, + ) + self.norm = T5Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + self.rotary_emb = T5Gemma2RotaryEmbedding(config) + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, 
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPastAndCrossAttentions:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        if encoder_hidden_states is None:
+            raise ValueError("`encoder_hidden_states` must be provided to the decoder")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if not self.training and use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
+                "position_ids": position_ids,
+            }
+            # This mask function does not change the mask itself; it only forces `allow_is_causal_skip` to be False,
+            # because merged attention always needs an explicit mask during decoding.
+            mask_kwargs["and_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
+            self_attn_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }
+
+        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": encoder_hidden_states,
+                "attention_mask": encoder_attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": None,
+                "position_ids": None,
+            }
+            cross_attn_mask_mapping = {
+                "full_attention": create_causal_mask(
+                    **mask_kwargs,
+                    or_mask_function=bidirectional_mask_function(encoder_attention_mask),
+                ),
+            }
+
+        merged_attn_mask_mapping = {
+            "full_attention": torch.cat(
+                [self_attn_mask_mapping["full_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
+            ),
+            "sliding_attention": torch.cat(
+                [self_attn_mask_mapping["sliding_attention"], cross_attn_mask_mapping["full_attention"]], dim=-1
+            ),
+        }
+
+        # input layer
+        hidden_states = inputs_embeds
+
+        # global and local position embeddings
+        position_embeddings = {}
+        for layer_type in self.config.layer_types:
+            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
+
+        # dropout
+        hidden_states = self.dropout(hidden_states)
+
+        for layer_module in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = layer_module(
+                hidden_states,
+                position_embeddings[layer_module.attention_type],
+                merged_attn_mask_mapping[layer_module.attention_type],
+                position_ids,
+                past_key_values,
+                use_cache,
+                cache_position,
+                encoder_hidden_states,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class T5Gemma2Model(T5Gemma2PreTrainedModel):
+    _tied_weights_keys = {
+        "decoder.embed_tokens.weight": "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.eoi_embedding": "encoder.embed_tokens.eoi_embedding",
+    }
+
+    def 
__init__(self, config: T5Gemma2Config): + super().__init__(config) + + # setup encoder and decoder + self.encoder = T5Gemma2Encoder(config.encoder, config.eoi_token_index) + self.decoder = T5Gemma2Decoder(config.decoder, config.eoi_token_index) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + """ + # encoder + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + return_dict=True, + **kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # decoder + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + return_dict=True, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class T5Gemma2ForConditionalGeneration(T5Gemma2PreTrainedModel, GenerationMixin): + _tied_weights_keys = { + "lm_head.out_proj.weight": "model.encoder.embed_tokens.weight", + } + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + + self.model = T5Gemma2Model(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5Gemma2LMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLM" + + 
self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.get_encoder().get_image_features(pixel_values) + + @property + def vision_tower(self): + return self.get_encoder().vision_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder inputs + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder inputs + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others (mainly inference or cache related) + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
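+
+            For example (illustrative): with `labels = [[17, 5, -100]]` and no `decoder_input_ids`, the decoder
+            inputs are derived by shifting the labels to the right (see `prepare_decoder_input_ids_from_labels`),
+            and the position holding -100 does not contribute to the loss.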
+ """ + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + decoder_config = self.config.decoder + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: dict, + generation_mode: GenerationMode, + batch_size: int, + max_cache_length: int, + ) -> bool: + """Override cache preparation to support T5Gemma2-specific EncoderDecoder Cache.""" + + # Build cache and past_key_values structure first and then override as needed. + super()._prepare_cache_for_generation( + generation_config, + model_kwargs, + generation_mode, + batch_size, + max_cache_length, + ) + + # If use_cache is False, do not prepare the cache. + if generation_config.use_cache is False: + return + + cache_implementation = generation_config.cache_implementation + if cache_implementation is None: + offload_cache = False + else: + offload_cache = "offloaded" in generation_config.cache_implementation + + # Main change: use full cache for cross-attention. + cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True)) + + # cross-attention does not use sliding window + del cross_attn_config.sliding_window + del cross_attn_config.layer_types + + cross_attn_cache_kwargs = { + "config": cross_attn_config, + "offloading": offload_cache, + } + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is not None: + if not isinstance(past_key_values, EncoderDecoderCache): + raise ValueError( + "The `past_key_values` in `model_kwargs` must be of type `EncoderDecoderCache` for T5Gemma2 model." + ) + + # Cache already established, no need to re-initialize. 
+ if len(past_key_values.is_updated) > 0 and past_key_values.is_updated.get(0): + return + + cross_attn_cls = type(past_key_values.cross_attention_cache) + if cross_attn_cls == StaticCache: + cross_attn_cache_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + # Update cross-attention cache only (switch from sliding_window to full). + past_key_values.cross_attention_cache = cross_attn_cls(**cross_attn_cache_kwargs) + else: + # Initialize new cache. + model_kwargs["past_key_values"] = EncoderDecoderCache( + DynamicCache( + **{ + "config": self.config.get_text_config(decoder=True), + "offloading": offload_cache, + } + ), # self-attention cache + DynamicCache(), # cross-attention cache + ) + + if hasattr(self, "_cache") and self._cache is not None: + if not isinstance(self._cache, EncoderDecoderCache): + raise ValueError("The internal cache must be of type `EncoderDecoderCache` for T5Gemma2 model.") + + self._cache = model_kwargs["past_key_values"] + + +@auto_docstring +class T5Gemma2ForSequenceClassification(T5Gemma2PreTrainedModel): + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." 
+ ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + batch_size = input_ids.shape[0] + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (decoder_input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(decoder_input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + last_non_pad_token = torch.clamp(last_non_pad_token, max=decoder_input_ids.shape[-1] - 1) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@auto_docstring +class T5Gemma2ForTokenClassification(T5Gemma2PreTrainedModel): + def __init__(self, config: T5Gemma2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.hidden_size = config.decoder.hidden_size + + self.model = T5Gemma2Model(config) + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5Gemma2ClassificationHead(self.hidden_size, self.num_labels, classifier_dropout) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if inputs_embeds is not None or decoder_inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}." + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + if decoder_input_ids is None: + decoder_input_ids = self.prepare_decoder_input_ids_from_labels(input_ids) + + outputs: Seq2SeqModelOutput = self.model( + input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + **kwargs, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + + logits = self.score(last_hidden_state) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = [ + "T5Gemma2Config", + "T5Gemma2TextConfig", + "T5Gemma2EncoderConfig", + "T5Gemma2DecoderConfig", + "T5Gemma2ForConditionalGeneration", + "T5Gemma2Model", + "T5Gemma2PreTrainedModel", + "T5Gemma2ForSequenceClassification", + "T5Gemma2ForTokenClassification", +] diff --git a/tests/models/t5gemma2/__init__.py b/tests/models/t5gemma2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/t5gemma2/test_modeling_t5gemma2.py b/tests/models/t5gemma2/test_modeling_t5gemma2.py new file mode 100644 index 000000000000..2dcb5f240c63 --- /dev/null +++ b/tests/models/t5gemma2/test_modeling_t5gemma2.py @@ -0,0 +1,968 @@ +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch T5Gemma2 model.""" + +import copy +import unittest + +import pytest + +from transformers import ( + T5Gemma2Config, + T5Gemma2DecoderConfig, + T5Gemma2EncoderConfig, + T5Gemma2TextConfig, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + from transformers import ( + T5Gemma2ForConditionalGeneration, + T5Gemma2ForSequenceClassification, + T5Gemma2ForTokenClassification, + T5Gemma2Model, + ) + + +class T5Gemma2ModelTester: + config_class = T5Gemma2Config + text_config_class = T5Gemma2TextConfig + encoder_config_class = T5Gemma2EncoderConfig + decoder_config_class = T5Gemma2DecoderConfig + + if is_torch_available(): + model_class = T5Gemma2Model + causal_lm_class = T5Gemma2ForConditionalGeneration + sequence_classification_class = T5Gemma2ForSequenceClassification + token_classification_class = T5Gemma2ForTokenClassification + + def __init__( + self, + parent, + batch_size=13, + is_training=True, + use_attention_mask=True, + use_labels=True, + vocab_size=99, + # decoder-specific + seq_length=7, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=37, + # encoder-specific + encoder_seq_length=7, + encoder_hidden_size=32, + encoder_num_hidden_layers=2, + encoder_num_attention_heads=4, + encoder_num_key_value_heads=2, + encoder_intermediate_size=37, + # vision-specific + mm_tokens_per_image=2, + image_token_index=4, + boi_token_index=5, + eoi_token_index=6, + siglip_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + # common + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + # special ids + eos_token_id=1, + pad_token_id=0, + bos_token_id=2, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + # decoder + self.seq_length = seq_length + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + # encoder + self.encoder_seq_length = encoder_seq_length + self.encoder_hidden_size = encoder_hidden_size + self.encoder_num_hidden_layers = encoder_num_hidden_layers + self.encoder_num_attention_heads = encoder_num_attention_heads + self.encoder_num_key_value_heads = encoder_num_key_value_heads + self.encoder_intermediate_size = encoder_intermediate_size + # vision + self.mm_tokens_per_image = mm_tokens_per_image + self.image_token_index = image_token_index + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.siglip_config = siglip_config + self.num_channels = 
siglip_config["num_channels"] + self.image_size = siglip_config["image_size"] + # common + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.head_dim = self.hidden_size // self.num_attention_heads + # special ids + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def get_encoder_config(self): + return self.encoder_config_class( + text_config=self.text_config_class( + vocab_size=self.vocab_size, + hidden_size=self.encoder_hidden_size, + num_hidden_layers=self.encoder_num_hidden_layers, + num_attention_heads=self.encoder_num_attention_heads, + num_key_value_heads=self.encoder_num_key_value_heads, + intermediate_size=self.encoder_intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ), + # vision. + vision_config=self.siglip_config, + image_token_index=self.image_token_index, + boi_token_index=self.boi_token_index, + eoi_token_index=self.eoi_token_index, + mm_tokens_per_image=self.mm_tokens_per_image, + hidden_size=self.encoder_hidden_size, + ) + + def get_decoder_config(self): + return self.decoder_config_class( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + cross_attention_hidden_size=self.encoder_hidden_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=True, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_config(self, is_encoder_decoder=True): + return self.config_class( + encoder=self.get_encoder_config(), + decoder=self.get_decoder_config(), + is_encoder_decoder=is_encoder_decoder, + # vision. + image_token_index=self.image_token_index, + # Used for generation test. + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size - 1) + 1 + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1 + # Vision inputs. 
+ pixel_values = floats_tensor( + [ + self.batch_size, + self.siglip_config["num_channels"], + self.siglip_config["image_size"], + self.siglip_config["image_size"], + ] + ) + + # Remove BOS symbols from inputs. + input_ids = torch.where(input_ids == self.bos_token_id, 42, input_ids) + decoder_input_ids = torch.where(decoder_input_ids == self.bos_token_id, 42, decoder_input_ids) + + # Avoid leading PAD tokens from inputs. + decoder_input_ids[:, 0] = self.pad_token_id + 1 + + # set the 3 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.encoder.image_token_index] = self.pad_token_id + input_ids[:, :1] = config.encoder.image_token_index + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).eval() + + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual( + encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size) + ) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertIsNotNone(decoder_past) + self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers) + self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).eval() + + # _shift_right should be called on labels + shifted_labels = model.prepare_decoder_input_ids_from_labels(lm_labels) + + # first token should be decoder_start_token_id + self.parent.assertTrue(torch.all(shifted_labels[:, 0] == config.decoder.bos_token_id)) + + # the rest should be the labels shifted by one, with -100 replaced by pad_token_id + labels_without_ignore_index = lm_labels.masked_fill(lm_labels == -100, config.decoder.pad_token_id) + 
self.parent.assertTrue(torch.all(shifted_labels[:, 1:] == labels_without_ignore_index[:, :-1])) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.causal_lm_class(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + pixel_values=pixel_values, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_sequence_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device) + model = self.sequence_classification_class(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + labels=labels, + ) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_with_token_classification_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) + model = self.token_classification_class(config=config) + model = model.to(torch_device).eval() + outputs = model( + input_ids=input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + labels=labels, + ) + + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # first forward pass + outputs = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=True) + outputs_use_cache_conf = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states) + outputs_no_past = model(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids, encoder_hidden_states=encoder_hidden_states)["last_hidden_state"] + output_from_past = model( + next_tokens, encoder_hidden_states=encoder_hidden_states, past_key_values=past_key_values + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), 
output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # create attention mask + attn_mask = torch.ones(decoder_input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = decoder_input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model( + decoder_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask, use_cache=True + ).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + decoder_input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model( + next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + attention_mask=attn_mask, + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).get_decoder().to(torch_device).eval() + encoder_hidden_states = torch.ones( + (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size), dtype=torch.float32 + ).to(torch_device) + + # first forward pass + outputs = model( + decoder_input_ids, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_cache=True, + ) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + 
next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=next_attention_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, + encoder_hidden_states=encoder_hidden_states, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + )["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.causal_lm_class(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids, pixel_values=pixel_values, num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate( + input_ids, pixel_values=pixel_values, num_beams=2, max_length=5, do_sample=True + ) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + pixel_values, + ): + model = self.model_class(config=config).to(torch_device).half().eval() + output = model( + input_ids, + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + )["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + +@require_torch +class T5Gemma2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + T5Gemma2Model, + T5Gemma2ForConditionalGeneration, + T5Gemma2ForSequenceClassification, + T5Gemma2ForTokenClassification, + ) + if is_torch_available() + else () + ) + + _is_stateful = True + is_encoder_decoder = True + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = T5Gemma2ForConditionalGeneration if is_torch_available() else None + + # MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = T5Gemma2ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=T5Gemma2Config, + # For faking the testing. 
+ hidden_size=37, + vocab_size=self.model_tester.vocab_size, + num_attention_heads=self.model_tester.num_attention_heads, + num_hidden_layers=self.model_tester.num_hidden_layers, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (T5Gemma2Model, T5Gemma2ForConditionalGeneration): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_with_sequence_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) + + def test_with_token_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Can't do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + 
input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_single_label with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_multi_label with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.sequence_classification_class(config).to(torch_device).eval() + result = model(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_token_classification_model with Gemma -> T5Gemma2 (Add is_encoder_decoder option) + def test_T5Gemma2_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + decoder_input_ids = input_dict["decoder_input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + + for pixel_values in [None, input_dict["pixel_values"]]: + model = self.model_tester.token_classification_class(config).to(torch_device).eval() + + result = model( + input_ids, + decoder_input_ids=decoder_input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + labels=token_labels, + ) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, 
+            )
+
+    @unittest.skip("T5Gemma2 eager/FA2 attention outputs are expected to be different")
+    def test_flash_attn_2_equivalence(self):
+        pass
+
+    @unittest.skip("This test was not properly written; submodules need the attribute to be overwritten")
+    def test_attention_outputs(self):
+        pass
+
+    @unittest.skip("The mismatched-shapes issue doesn't exist in T5Gemma2.")
+    def test_load_with_mismatched_shapes(self):
+        pass
+
+    # Based on tests.generation.test_utils.GenerationTesterMixin.test_generate_continue_from_past_key_values
+    # Updated decoder_attention_mask to account for the appended BOS token
+    @pytest.mark.generate
+    def test_generate_continue_from_past_key_values(self):
+        # Tests that we can continue generating from past key values returned from a previous `generate` call
+        for model_class in self.all_generative_model_classes:
+            if model_class == self.model_tester.token_classification_class:
+                continue
+            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
+                self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
+            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+                self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
+
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+            if not hasattr(config.get_text_config(), "use_cache"):
+                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+            # Let's make it always:
+            # 1. use cache (for obvious reasons)
+            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value);
+            #    otherwise the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+            #    continuation would force it to generate beyond an EOS token)
+            # 3. ignore `token_type_ids` for simplicity
+            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+            #    active by default on some models
+            # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+            #    we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+            #    repetition exclusively from the prompt. This test relies on comparing one call vs two calls
+            #    with cache, and what is considered a prompt differs between the two cases.
+
+            if "token_type_ids" in inputs:
+                del inputs["token_type_ids"]
+
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
+            outputs = model(**inputs)
+            if "past_key_values" not in outputs:
+                self.skipTest(reason="This model doesn't return `past_key_values`")
+
+            generate_kwargs = {
+                "pad_token_id": -1,
+                "eos_token_id": -1,
+                "forced_eos_token_id": None,
+                "encoder_no_repeat_ngram_size": 0,
+                "use_cache": True,
+                "do_sample": False,
+                "return_dict_in_generate": True,
+                "output_scores": True,
+            }
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
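+            # First continuation chunk: generate 3 tokens now and reuse the returned cache for the final 1-token call below.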
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
+
+            # Continue from the tokens generated above, preparing the inputs accordingly
+            inputs["past_key_values"] = outputs_cached.past_key_values
+            new_attention_len = outputs_cached.sequences.shape[-1]
+
+            # The model must be an encoder-decoder model
+            self.assertTrue(config.is_encoder_decoder)
+
+            inputs["decoder_input_ids"] = outputs_cached.sequences
+            if "decoder_attention_mask" in inputs:
+                decoder_attention_mask = inputs["decoder_attention_mask"]
+
+                # Add BOS mask: the new sequence comes with a new BOS token, which is not included in the original inputs
+                padding_tensor = torch.ones_like(decoder_attention_mask[:, :1])
+                decoder_attention_mask = torch.cat([padding_tensor, decoder_attention_mask], dim=1)
+
+                inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+                    decoder_attention_mask,
+                    (0, new_attention_len - decoder_attention_mask.shape[1]),
+                    mode="constant",
+                    value=1,
+                )
+
+            first_caches_scores = outputs_cached.scores
+            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
+            full_cached_scores = first_caches_scores + outputs_cached.scores
+            outputs_cached.scores = full_cached_scores
+
+            # The two sets of generated text and past kv should be equal to each other
+            self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
+            self._check_caches_are_equal(outputs.past_key_values, outputs_cached.past_key_values)
+
+    @unittest.skip("T5Gemma 2 only supports final-layer hidden states.")
+    def test_hidden_states_output(self):
+        pass
+
+    # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_custom_4d_attention_mask
+    # Excludes the final token from the encoder input_ids
+    def test_custom_4d_attention_mask(self):
+        for model_class in self.all_generative_model_classes:
+            config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(device=torch_device, dtype=torch.float32)
+
+            (
+                input_ids,
+                position_ids,
+                input_ids_shared_prefix,
+                mask_shared_prefix,
+                position_ids_shared_prefix,
+            ) = self._get_custom_4d_mask_test_data()
+            mask_shared_prefix = mask_shared_prefix == 0.0
+
+            outputs = model.forward(
+                decoder_input_ids=input_ids,
+                input_ids=input_ids[:, :-1],
+                decoder_position_ids=position_ids,
+            )
+            logits = outputs.logits
+            # logits.shape == torch.Size([3, 4, ...])
+
+            outputs_shared_prefix = model(
+                input_ids=input_ids[:1, :-1],
+                decoder_input_ids=input_ids_shared_prefix,
+                decoder_attention_mask=mask_shared_prefix,
+                decoder_position_ids=position_ids_shared_prefix,
+            )
+            logits_shared_prefix = outputs_shared_prefix.logits
+            # logits_shared_prefix.shape == torch.Size([1, 6, ...])
+
+            torch.testing.assert_close(
+                outputs.encoder_last_hidden_state[0], outputs_shared_prefix.encoder_last_hidden_state[0]
+            )
+
+            out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
+            out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
+
+            # comparing softmax-normalized logits:
+            normalized_0 = F.softmax(out_last_tokens, dim=-1)
+            normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
+            torch.testing.assert_close(normalized_0[2], normalized_1[2], rtol=1e-3, atol=1e-4)
+            torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+
+    @unittest.skip(reason="T5Gemma 2 doesn't support flex masking")
+    def test_flex_attention_with_grads(self):
+        pass
+
+    @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+    @unittest.skip(reason="Self- and cross-attention are split after the merged attention")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index e569a0fc7b5c..907651fd7378 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -273,6 +273,8 @@
     "GPTNeoXConfig": ["rotary_emb_base"],
     "Gemma3Config": ["boi_token_index", "eoi_token_index"],
     "Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
+    "T5Gemma2TextConfig": ["tie_word_embeddings"],
+    "T5Gemma2DecoderConfig": ["tie_word_embeddings"],
     "ShieldGemma2Config": [
         "boi_token_index",
         "eoi_token_index",