diff --git a/docs/source/en/main_classes/tokenizer.md b/docs/source/en/main_classes/tokenizer.md
index 2ad7e450404e..83d2ae5df6a7 100644
--- a/docs/source/en/main_classes/tokenizer.md
+++ b/docs/source/en/main_classes/tokenizer.md
@@ -51,6 +51,36 @@ token space (e.g., getting the index of the token comprising a given character o
 to a given token).
 
+
+## Multimodal Tokenizer
+
+Apart from that, each tokenizer can be a "multimodal" tokenizer, which means that the tokenizer will hold all relevant special tokens
+as part of tokenizer attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
+be able to access `tokenizer.image_token_id` to obtain the id of the special image token used as a placeholder.
+
+To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
+have to be modality related and can be anything that the model often needs access to. In the code below, the tokenizer at `output_dir` will have direct access
+to three more special tokens.
+
+```python
+vision_tokenizer = AutoTokenizer.from_pretrained(
+    "llava-hf/llava-1.5-7b-hf",
+    extra_special_tokens={"image_token": "<image>", "boi_token": "<image_start>", "eoi_token": "<image_end>"}
+)
+print(vision_tokenizer.image_token, vision_tokenizer.image_token_id)
+<image> 32000
+```
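+
+If the tokenizer is then saved with `save_pretrained`, the extra special tokens are written out with it and restored on
+reload. A minimal sketch (assuming a writable `output_dir`):
+
+```python
+vision_tokenizer.save_pretrained("output_dir")
+
+reloaded_tokenizer = AutoTokenizer.from_pretrained("output_dir")
+print(reloaded_tokenizer.image_token, reloaded_tokenizer.image_token_id)
+<image> 32000
+```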
 
 ## PreTrainedTokenizer
 
 [[autodoc]] PreTrainedTokenizer
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index cc80f6a19bfb..9e75e6fd3c38 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -443,7 +443,7 @@ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
         return torch.stack(examples, dim=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -477,7 +477,7 @@ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = N
         return tf.stack(examples, axis=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -513,7 +513,7 @@ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
         return np.stack(examples, axis=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -1090,7 +1090,7 @@ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
         ]
         probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels.eq(self.tokenizer.pad_token_id)
             probability_matrix.masked_fill_(padding_mask, value=0.0)
@@ -1131,7 +1131,7 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
         ]
         masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = inputs == self.tokenizer.pad_token_id
             masked_indices = masked_indices & ~padding_mask
@@ -1170,7 +1170,7 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
         ]
         masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels == self.tokenizer.pad_token_id
             masked_indices[padding_mask] = 0
@@ -1251,13 +1251,13 @@ def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
             self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
         ]
         probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels.eq(self.tokenizer.pad_token_id)
             probability_matrix.masked_fill_(padding_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix).bool()
         # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
         attention_mask = (~masked_indices).float()
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
             attention_mask.masked_fill_(attention_padding_mask, value=1.0)
         labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute
@@ -1367,7 +1367,7 @@ def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
             dtype=torch.bool,
         )
         masked_indices.masked_fill_(special_tokens_mask, value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels.eq(self.tokenizer.pad_token_id)
             masked_indices.masked_fill_(padding_mask, value=0.0)
@@ -1471,7 +1471,7 @@ def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
         )
         special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
         masked_indices = masked_indices & ~special_tokens_mask
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels == self.tokenizer.pad_token_id
             masked_indices = masked_indices & ~padding_mask
@@ -1571,7 +1571,7 @@ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
             dtype=bool,
         )
         masked_indices[special_tokens_mask] = 0
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
             padding_mask = labels == self.tokenizer.pad_token_id
             masked_indices[padding_mask] = 0.0
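The collator hunks above switch from the private `_pad_token` attribute to the public `pad_token` accessor, which likewise returns `None` when no pad token is configured. A minimal sketch of the failure mode the guard catches, assuming a checkpoint such as `gpt2` that ships without a pad token:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.pad_token)  # None -> the collators above would raise ValueError

# A common remedy is to reuse the end-of-sequence token for padding:
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token)  # <|endoftext|>
```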
diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py
index fa6a99f71a46..c68523784128 100644
--- a/src/transformers/models/blip_2/processing_blip_2.py
+++ b/src/transformers/models/blip_2/processing_blip_2.py
@@ -74,8 +74,11 @@ class Blip2Processor(ProcessorMixin):
     def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
         tokenizer.return_token_type_ids = False
         self.current_processor = image_processor
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
-        tokenizer.add_tokens([self.image_token], special_tokens=True)
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
         self.num_query_tokens = num_query_tokens
 
         super().__init__(image_processor, tokenizer)
diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py
index 2d699c8f663a..e2a50d1af51b 100644
--- a/src/transformers/models/chameleon/processing_chameleon.py
+++ b/src/transformers/models/chameleon/processing_chameleon.py
@@ -66,9 +66,12 @@ class ChameleonProcessor(ProcessorMixin):
     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
         self.image_seq_length = image_seq_length
-        self.image_token = image_token
-        self.image_start_token = "<racm3:break>"  # fixed tokens for start and end, so can hardcode
-        self.image_end_token = "<eoss>"
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+        self.image_start_token = (
+            tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
+        )  # fixed tokens for start and end, so can hardcode
+        self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
+
         super().__init__(image_processor, tokenizer)
 
     def __call__(
diff --git a/src/transformers/models/gemma/tokenization_gemma.py b/src/transformers/models/gemma/tokenization_gemma.py
index ff0d1d034c22..7138cafbd625 100644
--- a/src/transformers/models/gemma/tokenization_gemma.py
+++ b/src/transformers/models/gemma/tokenization_gemma.py
@@ -138,7 +138,7 @@ def __getstate__(self):
         return state
 
     def __setstate__(self, d):
-        self.__dict__ = d
+        self.__dict__.update(d)
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py
index 3406ab2226e0..ca6e4702d317 100644
--- a/src/transformers/models/idefics/processing_idefics.py
+++ b/src/transformers/models/idefics/processing_idefics.py
@@ -219,7 +219,11 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u
         super().__init__(image_processor, tokenizer)
 
         self.current_processor = self.image_processor
-        self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if hasattr(tokenizer, "image_token")
+            else tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        )
 
         self.default_image_dims = (
             self.image_processor.image_num_channels,
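The processor changes in this patch all follow the same pattern: if the tokenizer already carries the placeholder as an attribute (i.e., it was saved as a multimodal tokenizer with `extra_special_tokens`), reuse it; otherwise fall back to registering the historical hardcoded token. A condensed sketch of that pattern, with `<image>` standing in for each processor's default:

```python
from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

if hasattr(tokenizer, "image_token"):
    # Multimodal tokenizer: the placeholder token is already registered.
    image_token = tokenizer.image_token
else:
    # Legacy tokenizer: register the hardcoded placeholder ourselves.
    image_token = AddedToken("<image>", normalized=False, special=True)
    tokenizer.add_tokens([image_token], special_tokens=True)
```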
diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py
index 9a041257c36b..f99c1bda4745 100644
--- a/src/transformers/models/idefics2/processing_idefics2.py
+++ b/src/transformers/models/idefics2/processing_idefics2.py
@@ -95,16 +95,19 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, cha
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
 
-        self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
+        if not hasattr(tokenizer, "image_token"):
+            self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokens_to_add = {"additional_special_tokens": [self.fake_image_token, self.image_token]}
+            tokenizer.add_special_tokens(tokens_to_add)
+        else:
+            self.fake_image_token = tokenizer.image_boundary_token
+            self.image_token = tokenizer.image_token
+
         self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
+        tokenizer.add_special_tokens({"additional_special_tokens": [self.end_of_utterance_token]})
         self.image_seq_len = image_seq_len
-        tokens_to_add = {
-            "additional_special_tokens": [self.fake_image_token, self.image_token, self.end_of_utterance_token]
-        }
-        tokenizer.add_special_tokens(tokens_to_add)
-
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
     def _extract_images_from_prompts(self, prompts):
diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py
index 05ff9871f4d7..3d48839d376c 100644
--- a/src/transformers/models/instructblip/processing_instructblip.py
+++ b/src/transformers/models/instructblip/processing_instructblip.py
@@ -78,8 +78,11 @@ class InstructBlipProcessor(ProcessorMixin):
     qformer_tokenizer_class = "AutoTokenizer"
 
     def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
-        tokenizer.add_tokens([self.image_token], special_tokens=True)
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
         self.num_query_tokens = num_query_tokens
         super().__init__(image_processor, tokenizer, qformer_tokenizer)
diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
index 3e96d279a42f..1d4e59e26b46 100644
--- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py
@@ -63,8 +63,11 @@ class InstructBlipVideoProcessor(ProcessorMixin):
     qformer_tokenizer_class = "AutoTokenizer"
 
     def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-        self.video_token = AddedToken("