diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 0198cdd33711..ff096422bf13 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1036,6 +1036,8 @@
         title: ColQwen2
       - local: model_doc/data2vec
         title: Data2Vec
+      - local: model_doc/deepseek_ocr
+        title: DeepSeekOCR
       - local: model_doc/deepseek_vl
         title: DeepseekVL
       - local: model_doc/deepseek_vl_hybrid
diff --git a/docs/source/en/model_doc/deepseek_ocr.md b/docs/source/en/model_doc/deepseek_ocr.md
new file mode 100644
index 000000000000..2a9cda22bf53
--- /dev/null
+++ b/docs/source/en/model_doc/deepseek_ocr.md
@@ -0,0 +1,123 @@
+
+
+
+# DeepSeekOCR
+
+## Overview
+
+DeepSeekOCR is a vision-language model designed for optical character recognition (OCR) tasks. It combines dual vision encoders (SAM and CLIP) with a language model, taking an image and a text prompt as input and producing OCR outputs for tasks such as document understanding, grounding, and markdown conversion.
+The model uses a modified [DeepSeek-V2](./deepseek_v2) as its text decoder.
+
+### Usage tips
+
+The example below demonstrates how to perform OCR with grounding on a document image using the [`AutoModel`] class.
+
+```py
+import torch
+from transformers import AutoModel, AutoProcessor
+
+processor = AutoProcessor.from_pretrained("deepseek-ai/deepseek-ocr")
+model = AutoModel.from_pretrained("deepseek-ai/deepseek-ocr", torch_dtype=torch.bfloat16)
+
+conversation = [
+    {
+        "role": "<|User|>",
+        "content": [
+            {"type": "image", "path": "./document.png"},
+            {"type": "text", "text": "<|grounding|>Convert the document to markdown."},
+        ],
+    }
+]
+
+inputs = processor.apply_chat_template(
+    conversation,
+    return_dict=True,
+    tokenize=True,
+    add_generation_prompt=True,
+    return_tensors="pt"
+)
+
+with torch.no_grad():
+    generated = model.generate(**inputs, max_new_tokens=250)
+
+text = processor.batch_decode(generated, skip_special_tokens=False)[0]
+print(text)
+```
+
+## DeepseekOcrConfig
+
+[[autodoc]] DeepseekOcrConfig
+
+## DeepseekOcrVisionConfig
+
+[[autodoc]] DeepseekOcrVisionConfig
+
+## DeepseekOcrSamConfig
+
+[[autodoc]] DeepseekOcrSamConfig
+
+## DeepseekOcrCLIPVisionConfig
+
+[[autodoc]] DeepseekOcrCLIPVisionConfig
+
+## DeepseekOcrProjectorConfig
+
+[[autodoc]] DeepseekOcrProjectorConfig
+
+## DeepseekOcrProcessor
+
+[[autodoc]] DeepseekOcrProcessor
+
+## DeepseekOcrImageProcessorFast
+
+[[autodoc]] DeepseekOcrImageProcessorFast
+
+## DeepseekOcrModelOutputWithPast
+
+[[autodoc]] DeepseekOcrModelOutputWithPast
+
+## DeepseekOcrCausalLMOutputWithPast
+
+[[autodoc]] DeepseekOcrCausalLMOutputWithPast
+
+## DeepseekOcrTextModel
+
+[[autodoc]] DeepseekOcrTextModel
+    - forward
+
+## DeepseekOcrCLIPVisionModel
+
+[[autodoc]] DeepseekOcrCLIPVisionModel
+    - forward
+
+## DeepseekOcrProjector
+
+[[autodoc]] DeepseekOcrProjector
+    - forward
+
+## DeepseekOcrModel
+
+[[autodoc]] DeepseekOcrModel
+    - forward
+
+## DeepseekOcrForConditionalGeneration
+
+[[autodoc]] DeepseekOcrForConditionalGeneration
+    - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 5630063f92ec..38850931e5a3 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -85,6 +85,7 @@
     from .deberta import *
     from .deberta_v2 import *
     from .decision_transformer import *
+    from .deepseek_ocr import *
     from .deepseek_v2 import *
     from .deepseek_v3 import *
     from .deepseek_vl import *
diff --git
a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e2e84a445ef..5bcb5192bb83 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -104,6 +104,7 @@ ("deberta", "DebertaConfig"), ("deberta-v2", "DebertaV2Config"), ("decision_transformer", "DecisionTransformerConfig"), + ("deepseek_ocr", "DeepseekOcrConfig"), ("deepseek_v2", "DeepseekV2Config"), ("deepseek_v3", "DeepseekV3Config"), ("deepseek_vl", "DeepseekVLConfig"), @@ -542,6 +543,7 @@ ("deberta", "DeBERTa"), ("deberta-v2", "DeBERTa-v2"), ("decision_transformer", "Decision Transformer"), + ("deepseek_ocr", "DeepSeek-OCR"), ("deepseek_v2", "DeepSeek-V2"), ("deepseek_v3", "DeepSeek-V3"), ("deepseek_vl", "DeepseekVL"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 197029464efd..eea0eaf15b0e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -110,6 +110,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deberta", "DebertaModel"), ("deberta-v2", "DebertaV2Model"), ("decision_transformer", "DecisionTransformerModel"), + ("deepseek_ocr", "DeepseekOcrForConditionalGeneration"), ("deepseek_v2", "DeepseekV2Model"), ("deepseek_v3", "DeepseekV3Model"), ("deepseek_vl", "DeepseekVLModel"), @@ -1020,6 +1021,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("blip-2", "Blip2ForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"), ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), + ("deepseek_ocr", "DeepseekOcrForConditionalGeneration"), ("deepseek_vl", "DeepseekVLForConditionalGeneration"), ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"), diff --git a/src/transformers/models/deepseek_ocr/__init__.py b/src/transformers/models/deepseek_ocr/__init__.py new file mode 100644 index 000000000000..560b6da667fd --- /dev/null +++ b/src/transformers/models/deepseek_ocr/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
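+# The configuration and modeling submodules below are exposed lazily: they are only imported on first
+# attribute access through `_LazyModule`.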
+ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deepseek_ocr import * + from .modeling_deepseek_ocr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deepseek_ocr/configuration_deepseek_ocr.py b/src/transformers/models/deepseek_ocr/configuration_deepseek_ocr.py new file mode 100644 index 000000000000..e9e8b926acd7 --- /dev/null +++ b/src/transformers/models/deepseek_ocr/configuration_deepseek_ocr.py @@ -0,0 +1,487 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_ocr/modular_deepseek_ocr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_ocr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Deepseek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params +from ..auto import CONFIG_MAPPING, AutoConfig + + +class DeepseekOcrSamConfig(PreTrainedConfig): + model_type = "deepseek_ocr_sam_vision" + base_config_key = "sam_config" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=None, + mlp_ratio=4.0, + output_channels=256, + downsample_channels=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes if global_attn_indexes is not None else [2, 5, 8, 11] + self.mlp_ratio = mlp_ratio + self.output_channels = output_channels + self.downsample_channels = downsample_channels if downsample_channels is not None else [512, 1024] + self.mlp_dim = int(hidden_size * mlp_ratio) + self.out_channels = output_channels + + +class DeepseekOcrCLIPVisionConfig(PreTrainedConfig): + model_type = "deepseek_ocr_clip_vision" + base_config_key = "clip_vision_config" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + projection_dim=768, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=224, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +class DeepseekOcrProjectorConfig(PreTrainedConfig): + model_type = "deepseek_ocr_projector" + base_config_key = "projector_config" + + def __init__( + self, + input_dim=2048, + n_embed=1280, + projector_type="linear", + depth=1, + **kwargs, + ): + super().__init__(**kwargs) + self.input_dim = input_dim + self.n_embed = n_embed + self.projector_type = projector_type + self.depth = depth + + +class DeepseekOcrVisionConfig(PreTrainedConfig): + model_type = "deepseek_ocr_vision" + base_config_key = "vision_config" + sub_configs = { + "sam_config": DeepseekOcrSamConfig, + "clip_config": DeepseekOcrCLIPVisionConfig, + } + + def __init__(self, sam_config=None, clip_config=None, **kwargs): + super().__init__(**kwargs) + + if sam_config is None: + self.sam_config = DeepseekOcrSamConfig() + elif isinstance(sam_config, dict): + 
self.sam_config = DeepseekOcrSamConfig(**sam_config)
+        else:
+            self.sam_config = sam_config
+
+        if clip_config is None:
+            self.clip_config = DeepseekOcrCLIPVisionConfig()
+        elif isinstance(clip_config, dict):
+            self.clip_config = DeepseekOcrCLIPVisionConfig(**clip_config)
+        else:
+            self.clip_config = clip_config
+
+        # Aggregate commonly accessed vision attributes.
+        self.image_size = self.sam_config.image_size
+        self.patch_size = self.sam_config.patch_size
+
+
+class DeepseekOcrTextConfig(PreTrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DeepseekOcrTextModel`]. It is used to instantiate a DeepSeek
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of [deepseek-ai/DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite).
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`DeepseekOcrTextModel`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            The number of key-value heads used to implement Grouped Query Attention (GQA). If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi-Head Attention (MHA). If
+            `num_key_value_heads=1`, the model will use Multi-Query Attention (MQA). Otherwise, GQA is used.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated normal initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon value used by the RMS normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/value attentions (useful for inference optimization).
+        pad_token_id (`int`, *optional*):
+            Padding token ID.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning-of-sequence token ID.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End-of-sequence token ID.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie input and output embeddings.
+        rope_parameters (`RopeParameters`, *optional*):
+            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+            with longer `max_position_embeddings`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value, and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability applied to attention weights.
+        mlp_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias term in the MLP layers.
+        first_k_dense_replace (`int`, *optional*, defaults to 0):
+            Number of dense layers in the shallow layers before switching to MoE layers.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA decomposition for key-value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA decomposition for query projections.
+            Specifically, it determines the dimensionality to which the query (q) vectors are compressed before being expanded back to their original size.
+            It reduces computational overhead while maintaining model performance.
+        n_group (`int`, *optional*):
+            Number of groups for routed experts.
+        n_routed_experts (`int`, *optional*, defaults to 64):
+            Number of routed experts (None indicates a dense model).
+        n_shared_experts (`int`, *optional*, defaults to 2):
+            Number of shared experts (None indicates a dense model).
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            The head dimension of the part of the query-key (QK) projections that does not receive rotary position
+            embeddings ("nope" stands for "no position encoding").
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            The head dimension for QK projections when using RoPE.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for routed experts in MoE models.
+        topk_group (`int`, *optional*):
+            Number of selected groups per token for expert selection.
+        topk_method (`str`, *optional*, defaults to `"greedy"`):
+            The method used for selecting top-k experts in the routed gate mechanism.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            The dimension of value projections in the attention layers.
+        num_experts_per_tok (`int`, *optional*):
+            The number of experts selected per token. If `None`, the model behaves as a dense Transformer.
+        moe_intermediate_size (`int`, *optional*, defaults to 1407):
+            Dimension of the MoE (Mixture of Experts) representations.
+ + ```python + >>> from transformers import DeepseekOcrTextModel, DeepseekOcrTextConfig + >>> # Initializing a DeepSeek-V2 style configuration + >>> configuration = DeepseekOcrTextConfig() + >>> # Accessing the model configuration + >>> model = DeepseekOcrTextModel(configuration) + >>> print(model.config) + ``` + """ + + model_type = "deepseek_ocr_text" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.q_a_proj": "colwise", + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_b_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: Optional[int] = 32000, + hidden_size: Optional[int] = 4096, + intermediate_size: Optional[int] = 11008, + num_hidden_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_key_value_heads: Optional[int] = None, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 2048, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = False, + rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, + attention_bias: Optional[bool] = False, + attention_dropout: Optional[float] = 0.0, + mlp_bias: Optional[bool] = False, + first_k_dense_replace: Optional[int] = 0, + kv_lora_rank: Optional[int] = 512, + q_lora_rank: Optional[int] = 1536, + n_group: Optional[int] = None, + n_routed_experts: Optional[int] = 64, + n_shared_experts: Optional[int] = 2, + qk_nope_head_dim: Optional[int] = 128, + qk_rope_head_dim: Optional[int] = 64, + routed_scaling_factor: Optional[float] = 1.0, + topk_group: Optional[int] = None, + topk_method: Optional[str] = "greedy", + v_head_dim: Optional[int] = 128, + num_experts_per_tok: Optional[int] = None, + moe_intermediate_size: Optional[int] = 1407, + **kwargs, + ): + self.first_k_dense_replace = first_k_dense_replace + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.n_group = n_group + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.routed_scaling_factor = routed_scaling_factor + self.topk_group = topk_group + self.topk_method = topk_method + self.v_head_dim = v_head_dim + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias 
+ self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + self.head_dim = qk_rope_head_dim + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 10000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekOcrConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekOcrForConditionalGeneration`]. It is used to instantiate a + DeepseekOCR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DeepseekOCR + [deepseek-ai/deepseek-ocr](https://huggingface.co/deepseek-ai/deepseek-ocr) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `DeepseekV2Config`): + The config object or dictionary of the text backbone (DeepSeek-V2). + vision_config (`DeepseekOcrVisionConfig` or `dict`, *optional*): + The config object or dictionary of the vision encoders (SAM and CLIP). + projector_config (`DeepseekOcrProjectorConfig` or `dict`, *optional*): + The config object or dictionary of the projector that maps vision features to text embedding space. + candidate_resolutions (`list`, *optional*, defaults to `[[1024, 1024]]`): + List of candidate image resolutions for adaptive image processing. + global_view_pos (`str`, *optional*, defaults to `"head"`): + Position of the global view in the image sequence. + tile_tag (`str`, *optional*, defaults to `"2D"`): + Tag format for image tiles. + image_token_index (`int`, *optional*, defaults to 100015): + The index representing image tokens in the model's token vocabulary. 
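+        image_grid_pinpoints (`list`, *optional*, defaults to `[[1024, 1024]]`):
+            A list of candidate resolutions for processing high resolution images.
+        vision_feature_layer (`int`, *optional*):
+            The index of the layer to select the vision feature from.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+            The feature selection strategy used to select the vision feature from the vision encoder.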
+ + Example: + + ```python + >>> from transformers import DeepseekOcrConfig, DeepseekOcrForConditionalGeneration + + >>> # Initializing a DeepseekOCR configuration + >>> configuration = DeepseekOcrConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DeepseekOcrForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_ocr" + sub_configs = { + "text_config": AutoConfig, + "vision_config": DeepseekOcrVisionConfig, + "projector_config": DeepseekOcrProjectorConfig, + } + + def __init__( + self, + text_config=None, + vision_config=None, + projector_config=None, + candidate_resolutions=None, + global_view_pos="head", + tile_tag="2D", + image_token_index=100015, + image_grid_pinpoints=None, + vision_feature_layer=None, + vision_feature_select_strategy="default", + **kwargs, + ): + if candidate_resolutions is None: + candidate_resolutions = [[1024, 1024]] + + self.candidate_resolutions = candidate_resolutions + self.global_view_pos = global_view_pos + self.tile_tag = tile_tag + self.image_token_index = image_token_index + self.image_token_id = image_token_index + self.image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else [[1024, 1024]] + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + + if text_config is None: + text_config = CONFIG_MAPPING["deepseek_v2"]( + hidden_size=1280, + intermediate_size=6848, + num_hidden_layers=12, + num_attention_heads=10, + num_key_value_heads=10, + moe_intermediate_size=896, + n_routed_experts=64, + n_shared_experts=2, + num_experts_per_tok=6, + first_k_dense_replace=1, + vocab_size=129280, + max_position_embeddings=8192, + use_mla=False, + ) + elif isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "deepseek_v2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + + self.text_config = text_config + + if vision_config is None: + self.vision_config = DeepseekOcrVisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = DeepseekOcrVisionConfig(**vision_config) + else: + self.vision_config = vision_config + + if projector_config is None: + self.projector_config = DeepseekOcrProjectorConfig() + elif isinstance(projector_config, dict): + self.projector_config = DeepseekOcrProjectorConfig(**projector_config) + else: + self.projector_config = projector_config + + self.hidden_size = self.text_config.hidden_size + self.vocab_size = self.text_config.vocab_size + + super().__init__(**kwargs) + + +__all__ = [ + "DeepseekOcrConfig", + "DeepseekOcrVisionConfig", + "DeepseekOcrSamConfig", + "DeepseekOcrCLIPVisionConfig", + "DeepseekOcrProjectorConfig", +] diff --git a/src/transformers/models/deepseek_ocr/convert_deepseek_ocr_weights_to_hf.py b/src/transformers/models/deepseek_ocr/convert_deepseek_ocr_weights_to_hf.py new file mode 100644 index 000000000000..b0a680b35665 --- /dev/null +++ b/src/transformers/models/deepseek_ocr/convert_deepseek_ocr_weights_to_hf.py @@ -0,0 +1,240 @@ +# Copyright 2025 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import re +from pathlib import Path +from textwrap import dedent + +import torch +from safetensors.torch import load_file + +from transformers import ( + AutoTokenizer, + DeepseekOcrConfig, + DeepseekOcrForConditionalGeneration, + DeepseekOcrImageProcessorFast, + DeepseekOcrProcessor, +) + + +CHAT_TEMPLATE = dedent( + """ + {%- for message in messages %} + {%- if message['content'] is string %} +{{ message['content'].rstrip() }} + {%- else %} + {%- set ns = namespace(previous_was_image=False) %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + + {%- set ns.previous_was_image = True %} + {%- elif content['type'] == 'text' %} +{{- ('\n' if ns.previous_was_image else '') + content['text'].rstrip() }} + {%- set ns.previous_was_image = False %} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- if not loop.last %} + + {%- endif %} + {%- endfor %} + """ +).strip() + + +# fmt: off +STATE_DICT_MAPPING = { + r"^model\.sam_model\.patch_embed\.proj\.(weight|bias)": r"model.sam_model.patch_embed.projection.\1", + r"^model\.sam_model\.blocks\.(\d+)\.norm(\d+)\.(weight|bias)": r"model.sam_model.layers.\1.layer_norm\2.\3", + r"^model\.sam_model\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.sam_model.layers.\1.attn.qkv.\2", + r"^model\.sam_model\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.sam_model.layers.\1.attn.proj.\2", + r"^model\.sam_model\.blocks\.(\d+)\.attn\.rel_pos_([hw])": r"model.sam_model.layers.\1.attn.rel_pos_\2", + r"^model\.sam_model\.blocks\.(\d+)\.mlp\.lin(\d+)\.(weight|bias)": r"model.sam_model.layers.\1.mlp.lin\2.\3", + r"^model\.sam_model\.neck\.0\.weight": r"model.sam_model.neck.conv1.weight", + r"^model\.sam_model\.neck\.1\.(weight|bias)": r"model.sam_model.neck.layer_norm1.\1", + r"^model\.sam_model\.neck\.2\.weight": r"model.sam_model.neck.conv2.weight", + r"^model\.sam_model\.neck\.3\.(weight|bias)": r"model.sam_model.neck.layer_norm2.\1", + r"^model\.sam_model\.net_2\.weight": r"model.sam_model.net_2.weight", + r"^model\.sam_model\.net_3\.weight": r"model.sam_model.net_3.weight", + r"^model\.sam_model\.pos_embed": r"model.sam_model.pos_embed", + + r"^model\.vision_model\.embeddings\.class_embedding": r"model.clip_model.vision_model.embeddings.class_embedding", + r"^model\.vision_model\.embeddings\.patch_embedding\.weight": r"model.clip_model.vision_model.embeddings.patch_embedding.weight", + r"^model\.vision_model\.embeddings\.position_embedding\.weight": r"model.clip_model.vision_model.embeddings.position_embedding.weight", + r"^model\.vision_model\.pre_layrnorm\.(weight|bias)": r"model.clip_model.vision_model.pre_layrnorm.\1", + r"^model\.vision_model\.transformer\.layers\.(\d+)\.layer_norm(\d+)\.(weight|bias)": r"model.clip_model.vision_model.encoder.layers.\1.layer_norm\2.\3", + r"^model\.vision_model\.transformer\.layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)": r"model.clip_model.vision_model.encoder.layers.\1.self_attn.qkv_proj.\2", + r"^model\.vision_model\.transformer\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)": 
r"model.clip_model.vision_model.encoder.layers.\1.self_attn.out_proj.\2", + r"^model\.vision_model\.transformer\.layers\.(\d+)\.mlp\.fc(\d+)\.(weight|bias)": r"model.clip_model.vision_model.encoder.layers.\1.mlp.fc\2.\3", + r"^model\.vision_model\.post_layernorm\.(weight|bias)": r"model.clip_model.vision_model.post_layernorm.\1", + + r"^model\.projector\.layers\.(weight|bias)": r"model.multi_modal_projector.layers.\1", + + r"^model\.embed_tokens\.weight": r"model.language_model.embed_tokens.weight", + r"^model\.layers\.(\d+)\.input_layernorm\.weight": r"model.language_model.layers.\1.input_layernorm.weight", + r"^model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight", + r"^model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight", + r"^model\.layers\.(\d+)\.mlp\.(gate|up|down)_proj\.weight": r"model.language_model.layers.\1.mlp.\2_proj.weight", + r"^model\.layers\.(\d+)\.mlp\.(gate|up|down)\.(weight|bias)": r"model.language_model.layers.\1.mlp.\2.\3", + r"^model\.norm\.weight": r"model.language_model.norm.weight", + r"^model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate|up|down)_proj\.(weight|bias)": r"model.language_model.layers.\1.mlp.experts.\2.\3_proj.\4", + r"^model\.layers\.(\d+)\.mlp\.shared_experts\.(\d+)\.(gate|up|down)_proj\.(weight|bias)": r"model.language_model.layers.\1.mlp.shared_experts.\2.\3_proj.\4", + r"^model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate|up|down)\.(weight|bias)": r"model.language_model.layers.\1.mlp.experts.\2.\3.\4", + r"^model\.layers\.(\d+)\.mlp\.shared_experts\.(\d+)\.(gate|up|down)\.(weight|bias)": r"model.language_model.layers.\1.mlp.shared_experts.\2.\3.\4", + r"^model\.layers\.(\d+)\.mlp\.shared_experts\.(gate|up|down)_proj\.(weight|bias)": r"model.language_model.layers.\1.mlp.shared_experts.\2_proj.\3", + r"^model\.layers\.(\d+)\.mlp\.shared_experts\.(gate|up|down)\.(weight|bias)": r"model.language_model.layers.\1.mlp.shared_experts.\2.\3", + + r"^model\.image_newline": r"model.image_newline", + r"^model\.view_seperator": r"model.view_seperator", + + r"^lm_head\.weight": r"lm_head.weight", +} +# fmt: on + + +def map_old_key_to_new(old_key): + for pattern, replacement in STATE_DICT_MAPPING.items(): + new_key, n_replace = re.subn(pattern, replacement, old_key) + if n_replace > 0: + return new_key + + raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).") + + +def split_qkv_weights(key, tensor, num_heads, hidden_size): + if "qkv_proj.weight" in key: + q, k, v = torch.split(tensor, hidden_size, dim=0) + return { + key.replace("qkv_proj.weight", "q_proj.weight"): q, + key.replace("qkv_proj.weight", "k_proj.weight"): k, + key.replace("qkv_proj.weight", "v_proj.weight"): v, + } + elif "qkv_proj.bias" in key: + q, k, v = torch.split(tensor, hidden_size, dim=0) + return { + key.replace("qkv_proj.bias", "q_proj.bias"): q, + key.replace("qkv_proj.bias", "k_proj.bias"): k, + key.replace("qkv_proj.bias", "v_proj.bias"): v, + } + + return {key: tensor} + + +def convert_state_dict(original_state_dict, config): + new_state_dict = {} + + clip_hidden_size = config.vision_config.clip_config.hidden_size + clip_num_heads = config.vision_config.clip_config.num_attention_heads + + for old_key, tensor in original_state_dict.items(): + new_key = map_old_key_to_new(old_key) + + if "qkv_proj" in new_key and "clip_model" in new_key: + split_dict = split_qkv_weights(new_key, tensor, clip_num_heads, clip_hidden_size) + 
new_state_dict.update(split_dict) + else: + new_state_dict[new_key] = tensor + + return new_state_dict + + +def main(): + parser = argparse.ArgumentParser(description="Convert DeepSeek OCR weights to HuggingFace format") + parser.add_argument( + "--original_checkpoint_path", + type=str, + required=True, + help="Path to the original checkpoint file (.safetensors)", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Path where to save the converted model", + ) + + args = parser.parse_args() + + checkpoint_path = Path(args.original_checkpoint_path) + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + config_path = checkpoint_path.parent / "config.json" + + print(f"Loading original checkpoint from {checkpoint_path}") + original_state_dict = load_file(checkpoint_path) + + if config_path.exists(): + print(f"Loading config from {config_path}") + with open(config_path, "r") as f: + config_dict = json.load(f) + if "language_config" in config_dict: + config_dict["text_config"] = config_dict.pop("language_config") + + if "text_config" in config_dict and "head_dim" not in config_dict["text_config"]: + text_config = config_dict["text_config"] + if "hidden_size" in text_config and "num_attention_heads" in text_config: + text_config["head_dim"] = text_config["hidden_size"] // text_config["num_attention_heads"] + + config = DeepseekOcrConfig(**config_dict) + else: + print("Config not found, using default config") + config = DeepseekOcrConfig() + + tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent) + image_token_id = tokenizer.convert_tokens_to_ids("") + if image_token_id is None: + raise ValueError("Tokenizer does not contain the token required for DeepSeek OCR.") + config.image_token_index = image_token_id + config.image_token_id = image_token_id + text_config = getattr(config, "text_config", None) + if text_config is not None and hasattr(text_config, "image_token_id"): + text_config.image_token_id = image_token_id + + print("Converting state dict...") + converted_state_dict = convert_state_dict(original_state_dict, config) + reference_dtype = next(iter(original_state_dict.values())).dtype + + print("Creating model...") + model = DeepseekOcrForConditionalGeneration(config) + model.to(dtype=reference_dtype) + + print("Loading converted state dict into model...") + missing_keys, unexpected_keys = model.load_state_dict(converted_state_dict, strict=True) + + if missing_keys: + print(f"Missing keys: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys: {unexpected_keys}") + + print(f"Saving converted model to {output_path}") + model.save_pretrained(output_path) + config.save_pretrained(output_path) + + print("Creating and saving processor...") + image_processor = DeepseekOcrImageProcessorFast() + processor = DeepseekOcrProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=CHAT_TEMPLATE, + ) + processor.save_pretrained(output_path) + + print("Conversion complete!") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/deepseek_ocr/image_processing_deepseek_ocr_fast.py b/src/transformers/models/deepseek_ocr/image_processing_deepseek_ocr_fast.py new file mode 100644 index 000000000000..ef0df8b75fec --- /dev/null +++ b/src/transformers/models/deepseek_ocr/image_processing_deepseek_ocr_fast.py @@ -0,0 +1,352 @@ +# Copyright 2025 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import numpy as np +import torch + +# TODO protect this import +from PIL import Image, ImageDraw, ImageFont, ImageOps +from torchvision import transforms +from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2.functional import to_pil_image + +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + BatchFeature, + Unpack, +) +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...processing_utils import ImagesKwargs +from ...utils import TensorType, auto_docstring, logging + + +logger = logging.get_logger(__name__) + + +class DeepseekOcrImageProcessorKwargs(ImagesKwargs, total=False): + r""" + patch_size (`int`, *optional*): + The size of the patch. + base_size (`int`, *optional*): + The base size for the global image view. + dynamic_hd (`int`, *optional*): + The maximum number of crops per image. + """ + + patch_size: int + base_size: int + dynamic_hd: int + + +@auto_docstring +class DeepseekOcrImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + size = {"height": 1024, "width": 1024} + base_size = {"height": 1024, "width": 1024} + patch_size = 16 + dynamic_hd = 36 + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = DeepseekOcrImageProcessorKwargs + model_input_names = ["pixel_values", "image_attention_mask", "image_spatial_crop"] + + def __init__(self, **kwargs: Unpack[DeepseekOcrImageProcessorKwargs]): + super().__init__(**kwargs) + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, patch_image_size, max_num=36, min_num=2): + """ + Dynamically preprocess images with aspect ratio handling. 
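+
+        The image is matched against candidate tilings `(i, j)` with `min_num <= i * j <= max_num`, resized to
+        `(patch_image_size * i, patch_image_size * j)`, and cut into `i * j` square crops of side `patch_image_size`.
+        For example, a 2048x1024 image with `patch_image_size=1024` is matched to the `(2, 1)` tiling and yields two
+        1024x1024 crops.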
+ + Returns: + processed_images: list of preprocessed image tensors + target_aspect_ratio: tuple (width_crops, height_crops) + """ + if not isinstance(image, Image.Image): + image = to_pil_image(image) + + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, patch_image_size + ) + + target_width = patch_image_size * target_aspect_ratio[0] + target_height = patch_image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + resized_img = image.resize((target_width, target_height), resample=self.resample) + + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // patch_image_size)) * patch_image_size, + (i // (target_width // patch_image_size)) * patch_image_size, + ((i % (target_width // patch_image_size)) + 1) * patch_image_size, + ((i // (target_width // patch_image_size)) + 1) * patch_image_size, + ) + split_img = resized_img.crop(box) + processed_images.append(split_img) + + return processed_images, target_aspect_ratio + + def pad_to_max_num_crops(self, images, max_crops=5): + """Pad images tensor to max_crops.""" + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + @auto_docstring + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[DeepseekOcrImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + size: SizeDict, + base_size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + patch_size: int, + dynamic_hd: int, + do_rescale: bool, + rescale_factor: Optional[float], + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ): + if not isinstance(size, SizeDict): + size = SizeDict(**size) + if not isinstance(base_size, SizeDict): + base_size = SizeDict(**base_size) + patch_image_size = size.height + base_image_size = base_size.height + downsample_ratio = 4 + + images_transformed = [] + images_spatial_crop = [] + images_tokens = [] + + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize(mean=self.image_mean, std=self.image_std) + mean_fill = tuple(int(x * 255) for x in self.image_mean) + + for image in images: + if not isinstance(image, Image.Image): + image = to_pil_image(image) + if hasattr(ImageOps, "exif_transpose"): + image = ImageOps.exif_transpose(image) + if self.do_convert_rgb and image.mode != "RGB": + image = image.convert("RGB") + + orig_width, orig_height = image.size + # max_dim = max(orig_width, orig_height) + # min_dim = min(orig_width, orig_height) + + if orig_width <= patch_image_size and orig_height <= patch_image_size: + crop_ratio = [1, 1] + images_crop_raw = [] + else: + images_crop_raw, crop_ratio = self.dynamic_preprocess(image, patch_image_size, max_num=dynamic_hd) + + interp_mode = interpolation if interpolation is not None else self.resample + if isinstance(interp_mode, F.InterpolationMode): + 
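+                # map the torchvision `InterpolationMode` to the PIL resampling filter of the same name
+                # (e.g. BICUBIC), since `ImageOps.pad` below expects a PIL resampling method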
interp_mode = getattr(PILImageResampling, interp_mode.name) + global_view = ImageOps.pad(image, (base_image_size, base_image_size), method=interp_mode, color=mean_fill) + if base_image_size != patch_image_size: + global_view = global_view.resize((patch_image_size, patch_image_size), interp_mode) + + global_view = normalize(to_tensor(global_view)).to(torch.bfloat16) + + width_crop_num, height_crop_num = crop_ratio + + if width_crop_num > 1 or height_crop_num > 1: + processed_crops = [] + for crop in images_crop_raw: + crop_tensor = normalize(to_tensor(crop)).to(torch.bfloat16) + processed_crops.append(crop_tensor) + + crops_tensor = torch.stack(processed_crops, dim=0) + else: + processed_crops = [] + crops_tensor = torch.empty(0, dtype=torch.bfloat16) + + num_queries_base = math.ceil((base_image_size // 16) / downsample_ratio) + num_queries = math.ceil((patch_image_size // 16) / downsample_ratio) + + tokenized_image_len = (num_queries_base + 1) * num_queries_base + 1 + if width_crop_num > 1 or height_crop_num > 1: + tokenized_image_len += (num_queries * width_crop_num + 1) * (num_queries * height_crop_num) + + if crops_tensor.numel() > 0: + hd_images = torch.cat([crops_tensor, global_view.unsqueeze(0)], dim=0) + else: + hd_images = global_view.unsqueeze(0) + + max_crops = hd_images.size(0) + hd_images = self.pad_to_max_num_crops(hd_images, max_crops) + + images_transformed.append(hd_images) + images_spatial_crop.append([width_crop_num, height_crop_num]) + images_tokens.append(tokenized_image_len) + + max_crops = max(img.size(0) for img in images_transformed) + images_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in images_transformed] + images_transformed = torch.stack(images_transformed, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + + data = { + "pixel_values": images_transformed, + "image_spatial_crop": images_spatial_crop, + "num_img_tokens": images_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + def extract_coordinates_and_label(self, ref_text, image_width, image_height): + """Extract bounding box coordinates and label from model output.""" + try: + label_type = ref_text[1] + cor_list = eval(ref_text[2]) + except Exception as e: + logger.warning(f"Failed to extract coordinates: {e}") + return None + + return (label_type, cor_list) + + def visualize_results(self, image, ref_texts, output_path): + """ + Visualize results by drawing bounding boxes on the image. 
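+
+        Box coordinates in `ref_texts` are treated as values normalized to the 0-999 range and are rescaled to
+        pixel coordinates of the input image before drawing.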
+ + Args: + image: PIL Image + ref_texts: list of reference texts from model output + output_path: path to save the visualization + + Returns: + PIL Image with bounding boxes drawn + """ + image_width, image_height = image.size + + img_draw = image.copy() + draw = ImageDraw.Draw(img_draw) + + overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0)) + draw2 = ImageDraw.Draw(overlay) + + try: + font = ImageFont.load_default() + except Exception: + font = None + + img_idx = 0 + + for i, ref in enumerate(ref_texts): + try: + result = self.extract_coordinates_and_label(ref, image_width, image_height) + if result: + label_type, points_list = result + + color = ( + np.random.randint(0, 200), + np.random.randint(0, 200), + np.random.randint(0, 255), + ) + color_a = color + (20,) + + for points in points_list: + x1, y1, x2, y2 = points + + x1 = int(x1 / 999 * image_width) + y1 = int(y1 / 999 * image_height) + x2 = int(x2 / 999 * image_width) + y2 = int(y2 / 999 * image_height) + + if label_type == "image": + try: + cropped = image.crop((x1, y1, x2, y2)) + cropped.save(f"{output_path}/images/{img_idx}.jpg") + except Exception as e: + logger.warning(f"Failed to save cropped image: {e}") + img_idx += 1 + + try: + if label_type == "title": + draw.rectangle([x1, y1, x2, y2], outline=color, width=4) + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) + else: + draw.rectangle([x1, y1, x2, y2], outline=color, width=2) + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) + + text_x = x1 + text_y = max(0, y1 - 15) + + if font: + text_bbox = draw.textbbox((0, 0), label_type, font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + draw.rectangle( + [text_x, text_y, text_x + text_width, text_y + text_height], + fill=(255, 255, 255, 30), + ) + draw.text((text_x, text_y), label_type, font=font, fill=color) + except Exception as e: + logger.warning(f"Failed to draw bounding box: {e}") + except Exception as e: + logger.warning(f"Failed to process reference: {e}") + continue + + img_draw.paste(overlay, (0, 0), overlay) + return img_draw + + +__all__ = ["DeepseekOcrImageProcessorFast"] diff --git a/src/transformers/models/deepseek_ocr/modeling_deepseek_ocr.py b/src/transformers/models/deepseek_ocr/modeling_deepseek_ocr.py new file mode 100644 index 000000000000..690330ce99f6 --- /dev/null +++ b/src/transformers/models/deepseek_ocr/modeling_deepseek_ocr.py @@ -0,0 +1,2104 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_ocr/modular_deepseek_ocr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_ocr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 Deepseek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.generic import check_model_inputs +from .configuration_deepseek_ocr import ( + DeepseekOcrCLIPVisionConfig, + DeepseekOcrConfig, + DeepseekOcrTextConfig, + DeepseekOcrVisionConfig, +) + + +logger = logging.get_logger(__name__) + + +class DeepseekOcrPreTrainedModel(PreTrainedModel): + config_class = DeepseekOcrConfig + base_model_prefix = "model" + + +class DeepseekOcrProjector(PreTrainedModel): + """ + Projector that maps concatenated SAM + CLIP features to language model space. + """ + + def __init__(self, config): + super().__init__(config) + self.layers = nn.Linear(config.input_dim, config.n_embed) + + def forward(self, x): + return self.layers(x) + + +class DeepseekOcrLayerNorm(nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). 
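+
+    In this model the `channels_first` variant is used by the SAM vision neck to normalize convolutional feature
+    maps of shape `(batch_size, channels, height, width)`.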
+ """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().forward(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().forward(features) + return features + + +class DeepseekOcrSamVisionNeck(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = DeepseekOcrLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = DeepseekOcrLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Deepseek OCR model outputs with optional image hidden states. + """ +) +class DeepseekOcrModelOutputWithPast(BaseModelOutputWithPast): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states extracted from the visual encoder and projected into the language embedding space. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Deepseek OCR causal language model outputs with image hidden states. + """ +) +class DeepseekOcrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modelling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modelling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states produced by the visual encoder after multimodal projection. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for deepseek_ocr vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. 
+ """ +) +class DeepseekOcrVisionEncoderOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +class DeepseekOcrPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class DeepseekOcrMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class DeepseekOcrVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def get_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). 
+ k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos + + def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + return attn_output, attn_weights + + +class DeepseekOcrVisionSdpaAttention(DeepseekOcrVisionAttention): + """ + Multi-head Attention block with relative position embeddings. + Using SDPA instead of the default attention. + """ + + def __init__(self, config, window_size): + super().__init__(config, window_size) + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + if output_attentions: + logger.warning_once( + "`DeepseekOcrVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_bias = None + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape( + batch_size, self.num_attention_heads, height * width, height * width + ) + attn_bias = decomposed_rel_pos + + query = query.view(batch_size, self.num_attention_heads, height * width, -1) + key = key.view(batch_size, self.num_attention_heads, height * width, -1) + value = value.view(batch_size, self.num_attention_heads, height * width, -1) + + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias) + + attn_output = ( + attn_output.view(batch_size, self.num_attention_heads, height, width, -1) + .permute(0, 2, 3, 1, 4) + .reshape(batch_size, height, width, -1) + ) + + attn_output = self.proj(attn_output) + return attn_output, None + + +DEEPSEEK_OCR_VISION_ATTENTION_CLASSES = { + "eager": DeepseekOcrVisionAttention, + "sdpa": DeepseekOcrVisionSdpaAttention, +} + + +class DeepseekOcrVisionLayer(GradientCheckpointingLayer): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = DEEPSEEK_OCR_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = DeepseekOcrMLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. 
+ padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + return hidden_states + + +class DeepseekOcrVisionNeck(nn.Module): + def __init__(self, config: DeepseekOcrVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = DeepseekOcrLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = DeepseekOcrLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class DeepseekOcrSamVisionEncoder(DeepseekOcrPreTrainedModel): + """ + SAM ViT-B vision encoder with additional neck layers for Deepseek OCR. + Wraps the SAM vision encoder and adds downsampling convolutions. + """ + + _can_record_outputs = {"hidden_states": DeepseekOcrVisionLayer, "attentions": DeepseekOcrVisionAttention} + + def __init__(self, config): + super().__init__(config) + self.config = config + self.image_size = config.image_size + self.patch_embed = DeepseekOcrPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
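+            # The learned table below has shape (1, image_size // patch_size, image_size // patch_size, hidden_size)
+            # and is added directly to the patch embeddings at the start of `forward`, before any attention layers run.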
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = DeepseekOcrVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = DeepseekOcrVisionNeck(config) + + self.gradient_checkpointing = False + out_channels = config.out_channels + downsample_channels = config.downsample_channels + + # TODO move hardcoded values to config + self.net_2 = nn.Conv2d(out_channels, downsample_channels[0], kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + downsample_channels[0], downsample_channels[1], kernel_size=3, stride=2, padding=1, bias=False + ) + + def get_input_embeddings(self): + return self.patch_embed + + @check_model_inputs(tie_last_hidden_states=False) + def forward(self, pixel_values) -> DeepseekOcrVisionEncoderOutput: + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + hidden_states = self.neck(hidden_states) + hidden_states = self.net_2(hidden_states) + hidden_states = self.net_3(hidden_states) + + return hidden_states + + +class DeepseekOcrVisionEmbeddings(nn.Module): + def __init__(self, config: DeepseekOcrVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values, patch_embeds=None, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + + if patch_embeds is None: + patch_embeds = self.patch_embedding(pixel_values) + if patch_embeds.dim() == 4: + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + else: + patch_embeds = patch_embeds + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embeddings = self.position_embedding(self.position_ids) + if position_embeddings.shape[1] != embeddings.shape[1]: + class_pos_embed = position_embeddings[:, :1] + patch_pos_embed = position_embeddings[:, 1:] + src_size = int(math.sqrt(patch_pos_embed.shape[1])) + patch_pos_embed = patch_pos_embed.reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2) + patch_pos_embed = patch_pos_embed.to(torch.float32) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(height, width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, -1) + position_embeddings = torch.cat([class_pos_embed, patch_pos_embed.to(position_embeddings.dtype)], dim=1) + embeddings = embeddings + position_embeddings + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class DeepseekOcrAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def 
__init__(self, config: Union[DeepseekOcrVisionConfig, DeepseekOcrTextConfig]): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + **kwargs, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class DeepseekOcrMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class DeepseekOcrEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = DeepseekOcrAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = DeepseekOcrMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + 
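+        # Pre-norm block: LayerNorm is applied before each sub-layer, and the residual stream is added back
+        # after both the attention and the MLP sub-layers.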
+        return hidden_states
+
+
+class DeepseekOcrCLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`DeepseekOcrEncoderLayer`].
+
+    Args:
+        config: DeepseekOcrCLIPVisionConfig
+    """
+
+    def __init__(self, config: DeepseekOcrCLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([DeepseekOcrEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = False,  # TODO get rid of this when we're done with the fwd pass
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+        """
+        hidden_states = inputs_embeds
+
+        all_hidden_states = [] if output_hidden_states else None
+
+        for layer_module in self.layers:
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+            hidden_states = layer_module(
+                hidden_states,
+                attention_mask,
+                causal_attention_mask,
+                **kwargs,
+            )
+
+        if output_hidden_states:
+            all_hidden_states.append(hidden_states)
+            all_hidden_states = tuple(all_hidden_states)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+        )
+
+
+class DeepseekOcrCLIPVisionTransformer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = DeepseekOcrVisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.encoder = DeepseekOcrCLIPEncoder(config)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPooling:
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        patch_embeds = kwargs.pop("patch_embeds", None)
+        hidden_states = self.embeddings(
+            pixel_values,
+            patch_embeds,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            **kwargs,
+        )
+
+        last_hidden_state = encoder_outputs.last_hidden_state
+        pooled_output = last_hidden_state[:, 0, :]
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The CLIP-style vision model from Deepseek-OCR without any head or projection on top.
+    """
+)
+class DeepseekOcrCLIPVisionModel(DeepseekOcrPreTrainedModel):
+    config: DeepseekOcrCLIPVisionConfig
+    main_input_name = "pixel_values"
+    input_modalities = "image"
+    _no_split_modules = ["DeepseekOcrEncoderLayer"]
+    config_class = DeepseekOcrCLIPVisionConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vision_model = DeepseekOcrCLIPVisionTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        Args:
+            patch_embeds (`torch.FloatTensor`, *optional*):
+                Precomputed patch embeddings derived from the SAM vision encoder. When provided, the transformer will
+                reuse them instead of recomputing embeddings from `pixel_values`.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, DeepseekOcrCLIPVisionModel
+
+        >>> model = DeepseekOcrCLIPVisionModel.from_pretrained("deepseek-ai/deepseek-ocr")
+        >>> processor = AutoProcessor.from_pretrained("deepseek-ai/deepseek-ocr")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled CLS states
+        ```"""
+
+        patch_embeds = kwargs.pop("patch_embeds", None)
+        return self.vision_model(
+            pixel_values=pixel_values,
+            patch_embeds=patch_embeds,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
+
+class DeepseekOcrTextMLP(nn.Module):
+    def __init__(self, config: DeepseekOcrTextConfig, hidden_size=None, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class DeepseekOcrTextExperts(nn.ModuleList):
+    """
+    ModuleList of experts.
+ """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + for _ in range(config.n_routed_experts): + self.append(DeepseekOcrTextMLP(config, intermediate_size=config.moe_intermediate_size)) + + def forward(self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + tokens_per_expert = torch.bincount(topk_idx.view(-1), minlength=self.num_experts) + + flat_indices = topk_idx.view(-1) + sorted_positions = flat_indices.argsort() + original_token_indices = sorted_positions // self.top_k + + sorted_tokens = hidden_states[original_token_indices] + combined_results = torch.empty_like(sorted_tokens) + + boundaries = torch.cumsum(tokens_per_expert, dim=0) + start_indices = torch.cat((torch.tensor([0], device=boundaries.device), boundaries[:-1])) + + for i in range(self.num_experts): + count = tokens_per_expert[i].item() + if count == 0: + continue + + start = start_indices[i].item() + end = boundaries[i].item() + + combined_results[start:end] = self[i](sorted_tokens[start:end]) + + dispatch_buffer = torch.empty_like(combined_results) + dispatch_buffer.scatter_(0, sorted_positions.unsqueeze(-1).expand_as(combined_results), combined_results) + + dispatch_buffer = dispatch_buffer.view(topk_idx.shape[0], self.top_k, -1) + weighted = dispatch_buffer.to(topk_weight.dtype) * topk_weight.unsqueeze(-1) + + return weighted.sum(dim=1).to(hidden_states.dtype) + + +class DeepseekOcrTextMoe(nn.Module): + def __init__(self, config: DeepseekOcrTextConfig): + super().__init__() + self.config = config + self.experts = DeepseekOcrTextExperts(config) + self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekOcrTextMLP(config=config, intermediate_size=intermediate_size) + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.num_group = config.n_group + self.top_k = config.num_experts_per_tok + self.topk_group = config.topk_group + self.norm_topk_prob = getattr(config, "norm_topk_prob", False) + + def route_tokens_to_experts(self, scores): + if self.top_k is None or self.top_k <= 0: + raise ValueError("`num_experts_per_tok` must be a positive integer for MoE routing.") + + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + elif self.topk_method == "group_limited_greedy": + if self.num_group is None or self.topk_group is None: + raise ValueError("`n_group` and `topk_group` must be provided for group_limited_greedy routing.") + group_scores = scores.view(scores.shape[0], self.num_group, -1).max(dim=-1).values + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(scores.shape[0], self.num_group, scores.shape[-1] // self.num_group) + .reshape(scores.shape[0], -1) + ) + masked_scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weight, topk_idx = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False) + else: + raise ValueError(f"Unsupported topk routing method: {self.topk_method}") + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True).clamp_min(1e-20) + topk_weight = topk_weight / 
denominator + + topk_weight = topk_weight * self.routed_scaling_factor + return topk_idx, topk_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) + router_scores = router_logits.softmax(dim=-1, dtype=torch.float32) + router_scores_flat = router_scores.view(-1, router_scores.shape[-1]) + topk_indices, topk_weights = self.route_tokens_to_experts(router_scores_flat) + hidden_states_flat = hidden_states.view(-1, hidden_states.shape[-1]) + expert_output = self.experts(hidden_states_flat, topk_indices, topk_weights) + hidden_states = expert_output.view(*orig_shape) + + if hasattr(self, "shared_experts"): + hidden_states = hidden_states + self.shared_experts(residuals) + + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DeepseekOcrTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekOcrTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DeepseekOcrTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = DeepseekOcrTextAttention(config, layer_idx) + self.mlp = ( + DeepseekOcrTextMoe(config) if layer_idx >= config.first_k_dense_replace else DeepseekOcrTextMLP(config) + ) + + self.input_layernorm = 
DeepseekOcrTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekOcrTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepseekOcrTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DeepseekOcrTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[DeepseekOcrTextConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class DeepseekOcrTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekOcrTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +@auto_docstring +class DeepseekOcrTextPreTrainedModel(PreTrainedModel): + config: DeepseekOcrTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekOcrTextDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekOcrTextDecoderLayer, + "attentions": DeepseekOcrTextAttention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, DeepseekOcrTextMoe): + module.gate.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class DeepseekOcrTextModel(DeepseekOcrTextPreTrainedModel): + config: DeepseekOcrTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekOcrTextDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekOcrTextDecoderLayer, + "attentions": DeepseekOcrTextAttention, + } + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList( + [DeepseekOcrTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekOcrTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekOcrTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + for module in self.layers: + if isinstance(module.mlp, DeepseekOcrTextMoe): + module.mlp.gate.weight.data.normal_(mean=0.0, std=config.initializer_range) + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + 
input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Deepseek-OCR model which consists of dual vision backbones (SAM and CLIP) and a language model, without a
+    language modeling head.
+    """
+)
+class DeepseekOcrModel(DeepseekOcrPreTrainedModel):
+    """
+    Deepseek OCR model with dual vision encoders (SAM + CLIP) and a projector.
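+
+    Image patches are encoded by the SAM branch, re-encoded by the CLIP branch (reusing the SAM patch features),
+    concatenated along the feature dimension, projected into the language embedding space, and scattered into the
+    positions of the image placeholder tokens before the sequence is passed to the text decoder.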
+ """ + + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: DeepseekOcrConfig): + super().__init__(config) + + embed_std = 1 / math.sqrt(config.hidden_size) + self.image_newline = nn.Parameter(torch.randn(config.hidden_size) * embed_std) + + self.vocab_size = config.text_config.vocab_size + self.language_model = DeepseekOcrTextModel._from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.sam_model = DeepseekOcrSamVisionEncoder._from_config(config.vision_config.sam_config) + self.clip_model = DeepseekOcrCLIPVisionModel._from_config(config.vision_config.clip_config) + + self.multi_modal_projector = DeepseekOcrProjector._from_config(config.projector_config) + self.view_seperator = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) # TODO the typo is in the checkpoint + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def pack_image_features( + self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None, + image_spatial_crops=None, + ): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + + Args: + image_features (`list[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) + List of image feature tensor, each contains all the visual feature of all patches. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. + image_newline (`torch.Tensor` of shape `(embed_dim)`) + New line embedding vector. 
+ Returns: + image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) + feature_lens (`list[int]`) + token length of each image in image_features + """ + new_image_features = [] + feature_lens = [] + + for image_idx, features in enumerate(image_features): + crop_shape = None + if image_spatial_crops is not None: + crop_shape = image_spatial_crops[image_idx] + if isinstance(crop_shape, torch.Tensor): + crop_shape = crop_shape.tolist() + width_crop_num = int(crop_shape[0]) if crop_shape is not None else 1 + height_crop_num = int(crop_shape[1]) if crop_shape is not None else 1 + has_local_crops = width_crop_num > 1 or height_crop_num > 1 + + if has_local_crops and features.shape[0] >= width_crop_num * height_crop_num + 1: + valid_patch_count = width_crop_num * height_crop_num + 1 + else: + valid_patch_count = 1 if features.shape[0] > 0 else 0 + has_local_crops = False + + features = features[:valid_patch_count] + if features.shape[0] == 0: + new_image_features.append(features) + feature_lens.append(0) + continue + + global_feature = features[-1] + local_features = features[:-1] if has_local_crops else features[:0] + + processed_parts = [] + + if local_features.numel() > 0: + local_tokens = local_features.shape[1] + local_grid = int(math.isqrt(local_tokens)) + + if local_grid * local_grid == local_tokens: + local_features = local_features.view( + height_crop_num, + width_crop_num, + local_grid, + local_grid, + -1, + ) + local_features = local_features.permute(0, 2, 1, 3, 4).contiguous() + local_features = local_features.view( + height_crop_num * local_grid, + width_crop_num * local_grid, + -1, + ) + if image_newline is not None: + newline = ( + image_newline.unsqueeze(0) + .unsqueeze(0) + .to(local_features.device, dtype=local_features.dtype) + .expand(local_features.shape[0], 1, -1) + ) + local_features = torch.cat((local_features, newline), dim=1) + local_features = local_features.view(-1, local_features.shape[-1]) + else: + local_features = local_features.view(-1, local_features.shape[-1]) + if image_newline is not None: + newline = image_newline.unsqueeze(0).to(local_features.device, dtype=local_features.dtype) + local_features = torch.cat((local_features, newline), dim=0) + + processed_parts.append(local_features) + + global_tokens = global_feature.shape[0] + global_grid = int(math.isqrt(global_tokens)) + + if global_grid * global_grid == global_tokens: + global_features = global_feature.view(global_grid, global_grid, -1) + if image_newline is not None: + newline = ( + image_newline.unsqueeze(0) + .unsqueeze(0) + .to(global_features.device, dtype=global_features.dtype) + .expand(global_grid, 1, -1) + ) + global_features = torch.cat((global_features, newline), dim=1) + global_features = global_features.view(-1, global_features.shape[-1]) + else: + global_features = global_feature + if image_newline is not None: + global_features = torch.cat( + ( + global_features, + image_newline.unsqueeze(0).to(global_features.device, dtype=global_features.dtype), + ), + dim=0, + ) + + processed_parts.append(global_features) + + combined = torch.cat(processed_parts, dim=0) + new_image_features.append(combined) + feature_lens.append(combined.size(0)) + + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens + + def get_image_features( + self, + pixel_values: torch.FloatTensor, # (B, num_patches, 3, H, W) or (sum_patches, 3, H, W) + image_sizes: torch.Tensor, # (num_images, 2) actual (H, W) + image_spatial_crops: 
Optional[torch.Tensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (list[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + if pixel_values.dim() == 5: + image_num_patches = [pv.shape[0] for pv in pixel_values] + pixel_values = pixel_values.view(-1, *pixel_values.shape[2:]) + elif pixel_values.dim() == 4: + image_num_patches = [pixel_values.shape[0]] + else: + raise ValueError(f"pixel_values has shape {pixel_values.shape}, expected 4D or 5D") + + sam_features = self.sam_model(pixel_values) + sam_seq = sam_features.flatten(2).permute(0, 2, 1) + + clip_out = self.clip_model( + pixel_values=pixel_values, + patch_embeds=sam_features, + output_hidden_states=True, + return_dict=True, + interpolate_pos_encoding=True, + ) + + clip_seq = clip_out.last_hidden_state + vision_feature_layer_index = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if vision_feature_layer_index is not None: + if isinstance(vision_feature_layer_index, int): + clip_seq = clip_out.hidden_states[vision_feature_layer_index] + else: + pool = [clip_out.hidden_states[i] for i in vision_feature_layer_index] + clip_seq = torch.cat(pool, dim=-1) + + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + if vision_feature_select_strategy == "default": + clip_seq = clip_seq[:, 1:] + elif vision_feature_select_strategy != "full": + raise ValueError(f"Unexpected vision_feature_select_strategy={vision_feature_select_strategy}") + + fused = torch.cat([clip_seq, sam_seq], dim=-1) + proj = self.multi_modal_projector(fused) + + proj_list = torch.split(proj, image_num_patches, dim=0) + + new_image_features, feature_lens = self.pack_image_features( + image_features=proj_list, + image_sizes=image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + image_spatial_crops=image_spatial_crops, + ) + + new_image_features = [ + torch.cat([pf, self.view_seperator.unsqueeze(0).to(pf.dtype)], dim=0) for pf in new_image_features + ] + feature_lens = feature_lens + 1 # account for view separator + concatenated_features = torch.cat(new_image_features, dim=0) + return concatenated_features, feature_lens + + def get_placeholder_mask(self, input_ids, inputs_embeds, image_token_id): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to 
the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + tok_embed = self.get_input_embeddings()(torch.tensor(image_token_id, device=inputs_embeds.device)) + mask = (inputs_embeds == tok_embed).all(dim=-1) + else: + mask = input_ids == self.config.image_token_id + return mask.unsqueeze(-1).expand_as(inputs_embeds) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_spatial_crop = kwargs.pop("image_spatial_crop", None) + image_sizes = kwargs.pop("image_sizes", None) + image_attention_mask = kwargs.pop("image_attention_mask", None) + # num_img_tokens = kwargs.pop("num_img_tokens", None) + if image_sizes is None and image_spatial_crop is not None: + image_sizes = image_spatial_crop + + image_hidden_states = None + if pixel_values is not None and pixel_values.abs().sum().item() != 0: + if image_sizes is None: + raise ValueError("image_sizes must be provided when pixel_values are passed to the model.") + image_hidden_states, feature_lens = self.get_image_features( + pixel_values, + image_sizes, + image_spatial_crops=image_spatial_crop, + ) + + if image_attention_mask is not None: + token_mask = image_attention_mask.to(inputs_embeds.device) + else: + token_mask = self.get_placeholder_mask( + input_ids, inputs_embeds, self.config.image_token_index + ).squeeze(-1) + + batch_size = token_mask.shape[0] + start_idx = 0 + for batch_idx in range(batch_size): + valid_len = feature_lens[batch_idx].item() + if valid_len == 0: + continue + mask_positions = token_mask[batch_idx].nonzero(as_tuple=True)[0] + if mask_positions.numel() == 0: + continue + if mask_positions.numel() > valid_len: + # deactivate surplus placeholders so they won't interfere with autoregressive decoding + extra_positions = mask_positions[valid_len:] + token_mask[batch_idx, extra_positions] = False + mask_positions = mask_positions[:valid_len] + scatter_mask = torch.zeros_like(token_mask[batch_idx], dtype=torch.bool) + scatter_mask[mask_positions] = True + scatter_mask_expanded = scatter_mask.unsqueeze(-1).expand(-1, inputs_embeds.shape[-1]) + slice_features = image_hidden_states[start_idx : start_idx + valid_len].to(inputs_embeds.dtype) + inputs_embeds[batch_idx] = inputs_embeds[batch_idx].masked_scatter( + scatter_mask_expanded, slice_features + ) + start_idx += valid_len + + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if not isinstance(outputs, BaseModelOutputWithPast): + last_hidden_state = outputs[0] + past = outputs[1] if len(outputs) > 1 else None + hidden = outputs[2] if len(outputs) 
> 2 else None + attn = outputs[3] if len(outputs) > 3 else None + else: + last_hidden_state = outputs.last_hidden_state + past = outputs.past_key_values + hidden = outputs.hidden_states + attn = outputs.attentions + + return DeepseekOcrModelOutputWithPast( + last_hidden_state=last_hidden_state, + past_key_values=past, + hidden_states=hidden, + attentions=attn, + image_hidden_states=image_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The Deepseek-OCR model which consists of two vision backbones and a deepseek language model. + """ +) +class DeepseekOcrForConditionalGeneration(DeepseekOcrPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^image_newline": "model.image_newline", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekOcrModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + return self.model.pack_image_features( + image_features=image_features, + image_sizes=image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=image_newline, + ) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, DeepseekOcrCausalLMOutputWithPast]: + r""" + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. 
+                If `"full"`, the full vision features are used.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, DeepseekOcrForConditionalGeneration
+
+        >>> model = DeepseekOcrForConditionalGeneration.from_pretrained("deepseek-ai/deepseek-ocr")
+        >>> processor = AutoProcessor.from_pretrained("deepseek-ai/deepseek-ocr")
+
+        >>> prompt = "<image>\n<|grounding|>Convert the document to markdown."
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
+        >>> text = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        ```"""
+        image_spatial_crop = kwargs.pop("image_spatial_crop", None)
+        image_sizes = kwargs.pop("image_sizes", None)
+        if image_sizes is None and image_spatial_crop is not None:
+            image_sizes = image_spatial_crop
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            image_spatial_crop=image_spatial_crop,
+            image_sizes=image_sizes,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+            )
+
+        return DeepseekOcrCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        image_sizes=None,
+        image_attention_mask=None,
+        image_spatial_crop=None,
+        num_img_tokens=None,
+        attention_mask=None,
+        cache_position=None,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["image_sizes"] = image_sizes
+            if image_attention_mask is not None:
+                model_inputs["image_attention_mask"] = image_attention_mask
+            if image_spatial_crop is not None:
+                model_inputs["image_spatial_crop"] = image_spatial_crop
+            if num_img_tokens is not None:
model_inputs["num_img_tokens"] = num_img_tokens + + return model_inputs + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = [ + "DeepseekOcrModelOutputWithPast", + "DeepseekOcrCausalLMOutputWithPast", + "DeepseekOcrTextModel", + "DeepseekOcrTextPreTrainedModel", + "DeepseekOcrModel", + "DeepseekOcrForConditionalGeneration", + "DeepseekOcrPreTrainedModel", + "DeepseekOcrProjector", + "DeepseekOcrSamVisionEncoder", + "DeepseekOcrCLIPVisionModel", +] diff --git a/src/transformers/models/deepseek_ocr/modular_deepseek_ocr.py b/src/transformers/models/deepseek_ocr/modular_deepseek_ocr.py new file mode 100644 index 000000000000..94c021c6bb6c --- /dev/null +++ b/src/transformers/models/deepseek_ocr/modular_deepseek_ocr.py @@ -0,0 +1,1190 @@ +# Copyright 2025 Deepseek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...configuration_utils import PreTrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs +from ..auto import CONFIG_MAPPING, AutoConfig +from ..clip.modeling_clip import ( + CLIPEncoder, + CLIPEncoderLayer, + CLIPVisionEmbeddings, + CLIPVisionModel, + CLIPVisionTransformer, +) +from ..deepseek_v2.configuration_deepseek_v2 import DeepseekV2Config +from ..deepseek_v2.modeling_deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV2Model, + DeepseekV2PreTrainedModel, + DeepseekV2RMSNorm, +) +from ..llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding +from ..llava_next.modeling_llava_next import LlavaNextForConditionalGeneration, LlavaNextModel +from ..sam.modeling_sam import SamVisionEncoder, SamVisionNeck + + +logger = logging.get_logger(__name__) + + +class DeepseekOcrSamConfig(PreTrainedConfig): + model_type = "deepseek_ocr_sam_vision" + base_config_key = "sam_config" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=None, + mlp_ratio=4.0, + output_channels=256, + downsample_channels=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes if global_attn_indexes is not None else [2, 5, 8, 11] + self.mlp_ratio = mlp_ratio + self.output_channels = output_channels + self.downsample_channels = downsample_channels if downsample_channels is not None else [512, 1024] + self.mlp_dim = int(hidden_size * mlp_ratio) + self.out_channels = output_channels + + +class DeepseekOcrCLIPVisionConfig(PreTrainedConfig): + model_type = "deepseek_ocr_clip_vision" + base_config_key = "clip_vision_config" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + projection_dim=768, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=224, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = 
num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + +class DeepseekOcrProjectorConfig(PreTrainedConfig): + model_type = "deepseek_ocr_projector" + base_config_key = "projector_config" + + def __init__( + self, + input_dim=2048, + n_embed=1280, + projector_type="linear", + depth=1, + **kwargs, + ): + super().__init__(**kwargs) + self.input_dim = input_dim + self.n_embed = n_embed + self.projector_type = projector_type + self.depth = depth + + +class DeepseekOcrVisionConfig(PreTrainedConfig): + model_type = "deepseek_ocr_vision" + base_config_key = "vision_config" + sub_configs = { + "sam_config": DeepseekOcrSamConfig, + "clip_config": DeepseekOcrCLIPVisionConfig, + } + + def __init__(self, sam_config=None, clip_config=None, **kwargs): + super().__init__(**kwargs) + + if sam_config is None: + self.sam_config = DeepseekOcrSamConfig() + elif isinstance(sam_config, dict): + self.sam_config = DeepseekOcrSamConfig(**sam_config) + else: + self.sam_config = sam_config + + if clip_config is None: + self.clip_config = DeepseekOcrCLIPVisionConfig() + elif isinstance(clip_config, dict): + self.clip_config = DeepseekOcrCLIPVisionConfig(**clip_config) + else: + self.clip_config = clip_config + + # Aggregate commonly accessed vision attributes. + self.image_size = self.sam_config.image_size + self.patch_size = self.sam_config.patch_size + + +class DeepseekOcrTextConfig(DeepseekV2Config): + pass + + +class DeepseekOcrConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekOcrForConditionalGeneration`]. It is used to instantiate a + DeepseekOCR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DeepseekOCR + [deepseek-ai/deepseek-ocr](https://huggingface.co/deepseek-ai/deepseek-ocr) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `DeepseekV2Config`): + The config object or dictionary of the text backbone (DeepSeek-V2). + vision_config (`DeepseekOcrVisionConfig` or `dict`, *optional*): + The config object or dictionary of the vision encoders (SAM and CLIP). + projector_config (`DeepseekOcrProjectorConfig` or `dict`, *optional*): + The config object or dictionary of the projector that maps vision features to text embedding space. + candidate_resolutions (`list`, *optional*, defaults to `[[1024, 1024]]`): + List of candidate image resolutions for adaptive image processing. + global_view_pos (`str`, *optional*, defaults to `"head"`): + Position of the global view in the image sequence. + tile_tag (`str`, *optional*, defaults to `"2D"`): + Tag format for image tiles. + image_token_index (`int`, *optional*, defaults to 100015): + The index representing image tokens in the model's token vocabulary. 
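+        image_grid_pinpoints (`list`, *optional*, defaults to `[[1024, 1024]]`):
+            Candidate resolutions (height, width) used when splitting high-resolution images into local crops.
+        vision_feature_layer (`int` or `list[int]`, *optional*):
+            Index (or indices) of the CLIP hidden states to use as image features. If `None`, the last hidden
+            state is used.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+            Whether to drop the CLS token from the CLIP features (`"default"`) or keep all tokens (`"full"`).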
+ + Example: + + ```python + >>> from transformers import DeepseekOcrConfig, DeepseekOcrForConditionalGeneration + + >>> # Initializing a DeepseekOCR configuration + >>> configuration = DeepseekOcrConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DeepseekOcrForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_ocr" + sub_configs = { + "text_config": AutoConfig, + "vision_config": DeepseekOcrVisionConfig, + "projector_config": DeepseekOcrProjectorConfig, + } + + def __init__( + self, + text_config=None, + vision_config=None, + projector_config=None, + candidate_resolutions=None, + global_view_pos="head", + tile_tag="2D", + image_token_index=100015, + image_grid_pinpoints=None, + vision_feature_layer=None, + vision_feature_select_strategy="default", + **kwargs, + ): + if candidate_resolutions is None: + candidate_resolutions = [[1024, 1024]] + + self.candidate_resolutions = candidate_resolutions + self.global_view_pos = global_view_pos + self.tile_tag = tile_tag + self.image_token_index = image_token_index + self.image_token_id = image_token_index + self.image_grid_pinpoints = image_grid_pinpoints if image_grid_pinpoints is not None else [[1024, 1024]] + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + + if text_config is None: + text_config = CONFIG_MAPPING["deepseek_v2"]( + hidden_size=1280, + intermediate_size=6848, + num_hidden_layers=12, + num_attention_heads=10, + num_key_value_heads=10, + moe_intermediate_size=896, + n_routed_experts=64, + n_shared_experts=2, + num_experts_per_tok=6, + first_k_dense_replace=1, + vocab_size=129280, + max_position_embeddings=8192, + use_mla=False, + ) + elif isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "deepseek_v2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + + self.text_config = text_config + + if vision_config is None: + self.vision_config = DeepseekOcrVisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = DeepseekOcrVisionConfig(**vision_config) + else: + self.vision_config = vision_config + + if projector_config is None: + self.projector_config = DeepseekOcrProjectorConfig() + elif isinstance(projector_config, dict): + self.projector_config = DeepseekOcrProjectorConfig(**projector_config) + else: + self.projector_config = projector_config + + self.hidden_size = self.text_config.hidden_size + self.vocab_size = self.text_config.vocab_size + + super().__init__(**kwargs) + + +class DeepseekOcrPreTrainedModel(PreTrainedModel): + config_class = DeepseekOcrConfig + base_model_prefix = "model" + + +class DeepseekOcrProjector(PreTrainedModel): + """ + Projector that maps concatenated SAM + CLIP features to language model space. + """ + + def __init__(self, config): + super().__init__(config) + self.layers = nn.Linear(config.input_dim, config.n_embed) + + def forward(self, x): + return self.layers(x) + + +class DeepseekOcrSamVisionNeck(SamVisionNeck): + def __init__(self, config): + super().__init__(config) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Deepseek OCR model outputs with optional image hidden states. 
+ """ +) +class DeepseekOcrModelOutputWithPast(BaseModelOutputWithPast): + r""" + image_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states extracted from the visual encoder and projected into the language embedding space. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Deepseek OCR causal language model outputs with image hidden states. + """ +) +class DeepseekOcrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modelling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modelling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states produced by the visual encoder after multimodal projection. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class DeepseekOcrSamVisionEncoder(SamVisionEncoder): + """ + SAM ViT-B vision encoder with additional neck layers for Deepseek OCR. + Wraps the SAM vision encoder and adds downsampling convolutions. 
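+    The features produced by the SAM neck are passed through two extra stride-2 convolutions (`net_2`, `net_3`),
+    reducing their spatial resolution before they are reused as patch embeddings for the CLIP branch and
+    concatenated with the CLIP features as input to the projector.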
+ """ + + def __init__(self, config): + super().__init__(config) + out_channels = config.out_channels + downsample_channels = config.downsample_channels + + # TODO move hardcoded values to config + self.net_2 = nn.Conv2d(out_channels, downsample_channels[0], kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + downsample_channels[0], downsample_channels[1], kernel_size=3, stride=2, padding=1, bias=False + ) + + def forward(self, pixel_values): + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + hidden_states = self.neck(hidden_states) + hidden_states = self.net_2(hidden_states) + hidden_states = self.net_3(hidden_states) + + return hidden_states + + +class DeepseekOcrVisionEmbeddings(CLIPVisionEmbeddings): + def forward(self, pixel_values, patch_embeds=None, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + + if patch_embeds is None: + patch_embeds = self.patch_embedding(pixel_values) + if patch_embeds.dim() == 4: + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + else: + patch_embeds = patch_embeds + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embeddings = self.position_embedding(self.position_ids) + if position_embeddings.shape[1] != embeddings.shape[1]: + class_pos_embed = position_embeddings[:, :1] + patch_pos_embed = position_embeddings[:, 1:] + src_size = int(math.sqrt(patch_pos_embed.shape[1])) + patch_pos_embed = patch_pos_embed.reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2) + patch_pos_embed = patch_pos_embed.to(torch.float32) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(height, width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, -1) + position_embeddings = torch.cat([class_pos_embed, patch_pos_embed.to(position_embeddings.dtype)], dim=1) + embeddings = embeddings + position_embeddings + return embeddings + + +class DeepseekOcrEncoderLayer(CLIPEncoderLayer): + def __init__(self, config): + super().__init__(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + **kwargs, + ) + + +class DeepseekOcrCLIPEncoder(CLIPEncoder): + def __init__(self, config: DeepseekOcrCLIPVisionConfig): + super().__init__(config) + self.layers = nn.ModuleList([DeepseekOcrEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, # TODO get rid of this when we're done with the fwd pass + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + hidden_states = inputs_embeds + + all_hidden_states = [] if output_hidden_states else None + + for layer_module in self.layers: + if output_hidden_states: + all_hidden_states.append(hidden_states) + + hidden_states = layer_module( + hidden_states, + attention_mask, + causal_attention_mask, 
+ **kwargs, + ) + + if output_hidden_states: + all_hidden_states.append(hidden_states) + all_hidden_states = tuple(all_hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +class DeepseekOcrCLIPVisionTransformer(CLIPVisionTransformer): + def __init__(self, config): + super().__init__(config) + embed_dim = config.hidden_size + self.embeddings = DeepseekOcrVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = DeepseekOcrCLIPEncoder(config) + del self.post_layernorm + + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + patch_embeds = kwargs.pop("patch_embeds", None) + hidden_states = self.embeddings( + pixel_values, + patch_embeds, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + **kwargs, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class DeepseekOcrCLIPVisionModel(CLIPVisionModel): + config_class = DeepseekOcrCLIPVisionConfig + + def __init__(self, config): + super().__init__(config) + self.vision_model = DeepseekOcrCLIPVisionTransformer(config) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @check_model_inputs(tie_last_hidden_states=False) + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + interpolate_pos_encoding: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + Args: + patch_embeds (`torch.FloatTensor`, *optional*): + Precomputed patch embeddings derived from the SAM vision encoder. When provided, the transformer will + reuse them instead of recomputing embeddings from `pixel_values`. 
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, DeepseekOcrCLIPVisionModel
+
+        >>> model = DeepseekOcrCLIPVisionModel.from_pretrained("deepseek-ai/deepseek-ocr")
+        >>> processor = AutoProcessor.from_pretrained("deepseek-ai/deepseek-ocr")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled CLS states
+        ```"""
+
+        patch_embeds = kwargs.pop("patch_embeds", None)
+        return self.vision_model(
+            pixel_values=pixel_values,
+            patch_embeds=patch_embeds,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
+
+class DeepseekOcrTextMLP(nn.Module):
+    def __init__(self, config: DeepseekOcrTextConfig, hidden_size=None, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class DeepseekOcrTextExperts(nn.ModuleList):
+    """
+    ModuleList of experts.
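+    Tokens are sorted by their routed expert index so that each expert processes a contiguous slice, and the
+    expert outputs are then scattered back to their original positions and combined using the router weights.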
+ """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + for _ in range(config.n_routed_experts): + self.append(DeepseekOcrTextMLP(config, intermediate_size=config.moe_intermediate_size)) + + def forward(self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: + tokens_per_expert = torch.bincount(topk_idx.view(-1), minlength=self.num_experts) + + flat_indices = topk_idx.view(-1) + sorted_positions = flat_indices.argsort() + original_token_indices = sorted_positions // self.top_k + + sorted_tokens = hidden_states[original_token_indices] + combined_results = torch.empty_like(sorted_tokens) + + boundaries = torch.cumsum(tokens_per_expert, dim=0) + start_indices = torch.cat((torch.tensor([0], device=boundaries.device), boundaries[:-1])) + + for i in range(self.num_experts): + count = tokens_per_expert[i].item() + if count == 0: + continue + + start = start_indices[i].item() + end = boundaries[i].item() + + combined_results[start:end] = self[i](sorted_tokens[start:end]) + + dispatch_buffer = torch.empty_like(combined_results) + dispatch_buffer.scatter_(0, sorted_positions.unsqueeze(-1).expand_as(combined_results), combined_results) + + dispatch_buffer = dispatch_buffer.view(topk_idx.shape[0], self.top_k, -1) + weighted = dispatch_buffer.to(topk_weight.dtype) * topk_weight.unsqueeze(-1) + + return weighted.sum(dim=1).to(hidden_states.dtype) + + +class DeepseekOcrTextMoe(nn.Module): + def __init__(self, config: DeepseekOcrTextConfig): + super().__init__() + self.config = config + self.experts = DeepseekOcrTextExperts(config) + self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekOcrTextMLP(config=config, intermediate_size=intermediate_size) + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.num_group = config.n_group + self.top_k = config.num_experts_per_tok + self.topk_group = config.topk_group + self.norm_topk_prob = getattr(config, "norm_topk_prob", False) + + def route_tokens_to_experts(self, scores): + if self.top_k is None or self.top_k <= 0: + raise ValueError("`num_experts_per_tok` must be a positive integer for MoE routing.") + + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + elif self.topk_method == "group_limited_greedy": + if self.num_group is None or self.topk_group is None: + raise ValueError("`n_group` and `topk_group` must be provided for group_limited_greedy routing.") + group_scores = scores.view(scores.shape[0], self.num_group, -1).max(dim=-1).values + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(scores.shape[0], self.num_group, scores.shape[-1] // self.num_group) + .reshape(scores.shape[0], -1) + ) + masked_scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weight, topk_idx = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False) + else: + raise ValueError(f"Unsupported topk routing method: {self.topk_method}") + + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True).clamp_min(1e-20) + topk_weight = topk_weight / 
denominator + + topk_weight = topk_weight * self.routed_scaling_factor + return topk_idx, topk_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) + router_scores = router_logits.softmax(dim=-1, dtype=torch.float32) + router_scores_flat = router_scores.view(-1, router_scores.shape[-1]) + topk_indices, topk_weights = self.route_tokens_to_experts(router_scores_flat) + hidden_states_flat = hidden_states.view(-1, hidden_states.shape[-1]) + expert_output = self.experts(hidden_states_flat, topk_indices, topk_weights) + hidden_states = expert_output.view(*orig_shape) + + if hasattr(self, "shared_experts"): + hidden_states = hidden_states + self.shared_experts(residuals) + + return hidden_states + + +class DeepseekOcrTextAttention(LlamaAttention): + pass + + +class DeepseekOcrTextDecoderLayer(DeepseekV2DecoderLayer): + def __init__(self, config, layer_idx): + super().__init__(config, layer_idx) + self.self_attn = DeepseekOcrTextAttention(config, layer_idx) + self.mlp = ( + DeepseekOcrTextMoe(config) if layer_idx >= config.first_k_dense_replace else DeepseekOcrTextMLP(config) + ) + + +class DeepseekOcrTextRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class DeepseekOcrTextRMSNorm(DeepseekV2RMSNorm): + pass + + +class DeepseekOcrTextPreTrainedModel(DeepseekV2PreTrainedModel): + pass + + +class DeepseekOcrTextModel(DeepseekV2Model): + config: DeepseekOcrTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekOcrTextDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekOcrTextDecoderLayer, + "attentions": DeepseekOcrTextAttention, + } + + def __init__(self, config): + super().__init__(config) + + self.layers = nn.ModuleList( + [DeepseekOcrTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekOcrTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekOcrTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + for module in self.layers: + if isinstance(module.mlp, DeepseekOcrTextMoe): + module.mlp.gate.weight.data.normal_(mean=0.0, std=config.initializer_range) + + +class DeepseekOcrModel(LlavaNextModel): + """ + Deepseek OCR model with dual vision encoders (SAM + CLIP) and a projector. 
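+    Pixel values are encoded by the SAM encoder, whose features are reused as patch embeddings by the CLIP
+    encoder; both feature sequences are concatenated, projected into the language embedding space, and scattered
+    into the image placeholder positions before being decoded by the DeepSeek-V2 text model.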
+ """ + + def __init__(self, config: DeepseekOcrConfig): + super().__init__(config) + del self.vision_tower + del self.multi_modal_projector + + self.sam_model = DeepseekOcrSamVisionEncoder._from_config(config.vision_config.sam_config) + self.clip_model = DeepseekOcrCLIPVisionModel._from_config(config.vision_config.clip_config) + + self.multi_modal_projector = DeepseekOcrProjector._from_config(config.projector_config) + + self.vocab_size = config.text_config.vocab_size + self.language_model = DeepseekOcrTextModel._from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + embed_std = 1 / math.sqrt(config.hidden_size) + self.image_newline = nn.Parameter(torch.randn(config.hidden_size) * embed_std) + self.view_seperator = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) # TODO the typo is in the checkpoint + + self.post_init() + + def get_placeholder_mask(self, input_ids, inputs_embeds, image_token_id): + if input_ids is None: + tok_embed = self.get_input_embeddings()(torch.tensor(image_token_id, device=inputs_embeds.device)) + mask = (inputs_embeds == tok_embed).all(dim=-1) + else: + mask = input_ids == self.config.image_token_id + return mask.unsqueeze(-1).expand_as(inputs_embeds) + + def pack_image_features( + self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None, + image_spatial_crops=None, + ): + new_image_features = [] + feature_lens = [] + + for image_idx, features in enumerate(image_features): + crop_shape = None + if image_spatial_crops is not None: + crop_shape = image_spatial_crops[image_idx] + if isinstance(crop_shape, torch.Tensor): + crop_shape = crop_shape.tolist() + width_crop_num = int(crop_shape[0]) if crop_shape is not None else 1 + height_crop_num = int(crop_shape[1]) if crop_shape is not None else 1 + has_local_crops = width_crop_num > 1 or height_crop_num > 1 + + if has_local_crops and features.shape[0] >= width_crop_num * height_crop_num + 1: + valid_patch_count = width_crop_num * height_crop_num + 1 + else: + valid_patch_count = 1 if features.shape[0] > 0 else 0 + has_local_crops = False + + features = features[:valid_patch_count] + if features.shape[0] == 0: + new_image_features.append(features) + feature_lens.append(0) + continue + + global_feature = features[-1] + local_features = features[:-1] if has_local_crops else features[:0] + + processed_parts = [] + + if local_features.numel() > 0: + local_tokens = local_features.shape[1] + local_grid = int(math.isqrt(local_tokens)) + + if local_grid * local_grid == local_tokens: + local_features = local_features.view( + height_crop_num, + width_crop_num, + local_grid, + local_grid, + -1, + ) + local_features = local_features.permute(0, 2, 1, 3, 4).contiguous() + local_features = local_features.view( + height_crop_num * local_grid, + width_crop_num * local_grid, + -1, + ) + if image_newline is not None: + newline = ( + image_newline.unsqueeze(0) + .unsqueeze(0) + .to(local_features.device, dtype=local_features.dtype) + .expand(local_features.shape[0], 1, -1) + ) + local_features = torch.cat((local_features, newline), dim=1) + local_features = local_features.view(-1, local_features.shape[-1]) + else: + local_features = local_features.view(-1, local_features.shape[-1]) + if image_newline is not None: + newline = image_newline.unsqueeze(0).to(local_features.device, dtype=local_features.dtype) + local_features = torch.cat((local_features, newline), dim=0) + + processed_parts.append(local_features) + + 
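+            # Arrange the global view as a 2D grid as well, appending the newline embedding at the end of each row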
global_tokens = global_feature.shape[0] + global_grid = int(math.isqrt(global_tokens)) + + if global_grid * global_grid == global_tokens: + global_features = global_feature.view(global_grid, global_grid, -1) + if image_newline is not None: + newline = ( + image_newline.unsqueeze(0) + .unsqueeze(0) + .to(global_features.device, dtype=global_features.dtype) + .expand(global_grid, 1, -1) + ) + global_features = torch.cat((global_features, newline), dim=1) + global_features = global_features.view(-1, global_features.shape[-1]) + else: + global_features = global_feature + if image_newline is not None: + global_features = torch.cat( + ( + global_features, + image_newline.unsqueeze(0).to(global_features.device, dtype=global_features.dtype), + ), + dim=0, + ) + + processed_parts.append(global_features) + + combined = torch.cat(processed_parts, dim=0) + new_image_features.append(combined) + feature_lens.append(combined.size(0)) + + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens + + def get_image_features( + self, + pixel_values: torch.FloatTensor, # (B, num_patches, 3, H, W) or (sum_patches, 3, H, W) + image_sizes: torch.Tensor, # (num_images, 2) actual (H, W) + image_spatial_crops: Optional[torch.Tensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + ): + if pixel_values.dim() == 5: + image_num_patches = [pv.shape[0] for pv in pixel_values] + pixel_values = pixel_values.view(-1, *pixel_values.shape[2:]) + elif pixel_values.dim() == 4: + image_num_patches = [pixel_values.shape[0]] + else: + raise ValueError(f"pixel_values has shape {pixel_values.shape}, expected 4D or 5D") + + sam_features = self.sam_model(pixel_values) + sam_seq = sam_features.flatten(2).permute(0, 2, 1) + + clip_out = self.clip_model( + pixel_values=pixel_values, + patch_embeds=sam_features, + output_hidden_states=True, + return_dict=True, + interpolate_pos_encoding=True, + ) + + clip_seq = clip_out.last_hidden_state + vision_feature_layer_index = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + + if vision_feature_layer_index is not None: + if isinstance(vision_feature_layer_index, int): + clip_seq = clip_out.hidden_states[vision_feature_layer_index] + else: + pool = [clip_out.hidden_states[i] for i in vision_feature_layer_index] + clip_seq = torch.cat(pool, dim=-1) + + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + if vision_feature_select_strategy == "default": + clip_seq = clip_seq[:, 1:] + elif vision_feature_select_strategy != "full": + raise ValueError(f"Unexpected vision_feature_select_strategy={vision_feature_select_strategy}") + + fused = torch.cat([clip_seq, sam_seq], dim=-1) + proj = self.multi_modal_projector(fused) + + proj_list = torch.split(proj, image_num_patches, dim=0) + + new_image_features, feature_lens = self.pack_image_features( + image_features=proj_list, + image_sizes=image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + image_spatial_crops=image_spatial_crops, + ) + + new_image_features = [ + torch.cat([pf, self.view_seperator.unsqueeze(0).to(pf.dtype)], dim=0) for pf in new_image_features + ] + feature_lens = feature_lens + 1 # account for view separator + concatenated_features = 
torch.cat(new_image_features, dim=0) + return concatenated_features, feature_lens + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_spatial_crop = kwargs.pop("image_spatial_crop", None) + image_sizes = kwargs.pop("image_sizes", None) + image_attention_mask = kwargs.pop("image_attention_mask", None) + # num_img_tokens = kwargs.pop("num_img_tokens", None) + if image_sizes is None and image_spatial_crop is not None: + image_sizes = image_spatial_crop + + image_hidden_states = None + if pixel_values is not None and pixel_values.abs().sum().item() != 0: + if image_sizes is None: + raise ValueError("image_sizes must be provided when pixel_values are passed to the model.") + image_hidden_states, feature_lens = self.get_image_features( + pixel_values, + image_sizes, + image_spatial_crops=image_spatial_crop, + ) + + if image_attention_mask is not None: + token_mask = image_attention_mask.to(inputs_embeds.device) + else: + token_mask = self.get_placeholder_mask( + input_ids, inputs_embeds, self.config.image_token_index + ).squeeze(-1) + + batch_size = token_mask.shape[0] + start_idx = 0 + for batch_idx in range(batch_size): + valid_len = feature_lens[batch_idx].item() + if valid_len == 0: + continue + mask_positions = token_mask[batch_idx].nonzero(as_tuple=True)[0] + if mask_positions.numel() == 0: + continue + if mask_positions.numel() > valid_len: + # deactivate surplus placeholders so they won't interfere with autoregressive decoding + extra_positions = mask_positions[valid_len:] + token_mask[batch_idx, extra_positions] = False + mask_positions = mask_positions[:valid_len] + scatter_mask = torch.zeros_like(token_mask[batch_idx], dtype=torch.bool) + scatter_mask[mask_positions] = True + scatter_mask_expanded = scatter_mask.unsqueeze(-1).expand(-1, inputs_embeds.shape[-1]) + slice_features = image_hidden_states[start_idx : start_idx + valid_len].to(inputs_embeds.dtype) + inputs_embeds[batch_idx] = inputs_embeds[batch_idx].masked_scatter( + scatter_mask_expanded, slice_features + ) + start_idx += valid_len + + outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if not isinstance(outputs, BaseModelOutputWithPast): + last_hidden_state = outputs[0] + past = outputs[1] if len(outputs) > 1 else None + hidden = outputs[2] if len(outputs) > 2 else None + attn = outputs[3] if len(outputs) > 3 else None + else: + last_hidden_state = outputs.last_hidden_state + past = outputs.past_key_values + hidden = outputs.hidden_states + attn = outputs.attentions + + return DeepseekOcrModelOutputWithPast( + last_hidden_state=last_hidden_state, + past_key_values=past, + hidden_states=hidden, + attentions=attn, + image_hidden_states=image_hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The Deepseek-OCR model which consists of two vision backbones and a deepseek language model. 
+ """ +) +class DeepseekOcrForConditionalGeneration(LlavaNextForConditionalGeneration): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekOcrModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, DeepseekOcrCausalLMOutputWithPast]: + image_spatial_crop = kwargs.pop("image_spatial_crop", None) + image_sizes = kwargs.pop("image_sizes", None) + if image_sizes is None and image_spatial_crop is not None: + image_sizes = image_spatial_crop + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_spatial_crop=image_spatial_crop, + image_sizes=image_sizes, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return DeepseekOcrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + image_attention_mask=None, + image_spatial_crop=None, + num_img_tokens=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["image_sizes"] = image_sizes + if image_attention_mask is not None: + model_inputs["image_attention_mask"] = image_attention_mask + if image_spatial_crop is not None: + model_inputs["image_spatial_crop"] = image_spatial_crop + if num_img_tokens is not None: + model_inputs["num_img_tokens"] = num_img_tokens + + return model_inputs + + +__all__ = [ + "DeepseekOcrConfig", + "DeepseekOcrVisionConfig", + "DeepseekOcrSamConfig", + "DeepseekOcrCLIPVisionConfig", + "DeepseekOcrProjectorConfig", + "DeepseekOcrModelOutputWithPast", + "DeepseekOcrCausalLMOutputWithPast", + "DeepseekOcrTextModel", + "DeepseekOcrTextPreTrainedModel", + "DeepseekOcrModel", + 
"DeepseekOcrForConditionalGeneration", + "DeepseekOcrPreTrainedModel", + "DeepseekOcrProjector", + "DeepseekOcrSamVisionEncoder", + "DeepseekOcrCLIPVisionModel", +] diff --git a/src/transformers/models/deepseek_ocr/processing_deepseek_ocr.py b/src/transformers/models/deepseek_ocr/processing_deepseek_ocr.py new file mode 100644 index 000000000000..19b3653465fe --- /dev/null +++ b/src/transformers/models/deepseek_ocr/processing_deepseek_ocr.py @@ -0,0 +1,169 @@ +# Copyright 2025 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Union + +import torch + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import TextInput +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DeepseekOcrProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +class DeepseekOcrProcessor(ProcessorMixin): + r""" + Constructs a DeepSeek OCR processor which wraps an image processor and a tokenizer into a single processor. + + [`DeepseekOcrProcessor`] offers all the functionalities of [`DeepseekOcrImageProcessorFast`] and tokenizer. + See the [`~DeepseekOcrProcessor.__call__`] and [`~DeepseekOcrProcessor.decode`] for more information. + + Args: + image_processor (`DeepseekOcrImageProcessorFast`): + The image processor to use for images. + tokenizer (PreTrainedTokenizer): + The tokenizer to use for text. + image_token (`str`, *optional*, defaults to `""`): + The image token to use. + """ + + attributes = ["image_processor", "tokenizer"] + tokenizer_class = "AutoTokenizer" + image_processor_class = "DeepseekOcrImageProcessorFast" + + def __init__( + self, + image_processor, + tokenizer, + image_token="", + **kwargs, + ): + self.image_token = image_token + # TODO this should not be here and handled in conversion script instead + if "chat_template" not in kwargs and getattr(tokenizer, "chat_template", None) is not None: + kwargs["chat_template"] = tokenizer.chat_template + super().__init__(image_processor, tokenizer, **kwargs) + + def __call__( + self, + text: Union[TextInput, list[TextInput]], + images: Optional[ImageInput] = None, + **kwargs: Unpack[ProcessingKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). + + Args: + text (`str`, `list[str]`): + The sequence or batch of sequences to be encoded. + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, etc.): + The image or batch of images to be prepared. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **pixel_values** -- Pixel values to be fed to a model. 
+ - **image_attention_mask** -- Mask for image tokens in the input sequence. + - **image_spatial_crop** -- Spatial crop information for images. + """ + + output_kwargs = self._merge_kwargs(DeepseekOcrProcessorKwargs, self.tokenizer.init_kwargs, **kwargs) + image_kwargs = output_kwargs["images_kwargs"] + + image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {} + + num_img_tokens = image_inputs.pop("num_img_tokens", []) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. Please provide a string, or a list of strings") + + concatenated_prompt = "".join(text) + if concatenated_prompt.count(self.image_token) != len(num_img_tokens): + raise ValueError( + f"Number of image tokens ({concatenated_prompt.count(self.image_token)}) in text " + f"does not match number of images ({len(num_img_tokens)}). " + f"Please add {self.image_token} token for each image." + ) + + image_count_iter = iter(num_img_tokens) + processed_text = [ + re.sub(re.escape(self.image_token), lambda _: self.image_token * next(image_count_iter), t) for t in text + ] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"]) + + image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token) + + input_ids = text_inputs["input_ids"] + if isinstance(input_ids, list): + batch_size = len(input_ids) + else: + batch_size = input_ids.size(0) + + if isinstance(input_ids, list): + image_attention_mask: list[list[bool]] = [] + for ids in input_ids: + mask = [token == image_token_id for token in ids] + image_attention_mask.append(mask) + else: + image_attention_mask = torch.zeros_like(input_ids, dtype=torch.bool) + for batch_idx in range(batch_size): + image_positions = (input_ids[batch_idx] == image_token_id).nonzero(as_tuple=True)[0] + image_attention_mask[batch_idx, image_positions] = True + + data = { + **text_inputs, + **image_inputs, + "image_attention_mask": image_attention_mask, + "num_img_tokens": num_img_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + ["image_attention_mask"])) + + +__all__ = ["DeepseekOcrProcessor"] diff --git a/tests/models/deepseek_ocr/__init__.py b/tests/models/deepseek_ocr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/deepseek_ocr/test_modeling_deepseek_ocr.py b/tests/models/deepseek_ocr/test_modeling_deepseek_ocr.py new file mode 100644 index 000000000000..6d71fcbaa9d2 --- /dev/null +++ b/tests/models/deepseek_ocr/test_modeling_deepseek_ocr.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import tempfile +import unittest + +from transformers import ( + AutoModel, + AutoProcessor, + DeepseekOcrConfig, + DeepseekOcrForConditionalGeneration, + DeepseekOcrModel, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + +class DeepseekOcrModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=25, + num_channels=3, + initializer_range=0.02, + is_training=True, + use_cache=False, + text_config={ + "model_type": "deepseek_v2", + "num_hidden_layers": 2, + "vocab_size": 129280, + "hidden_size": 64, + "intermediate_size": 128, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "max_position_embeddings": 512, + "pad_token_id": 1, + "use_mla": False, + }, + sam_config={ + "num_hidden_layers": 1, + "hidden_size": 32, + "num_attention_heads": 4, + "image_size": 64, + "patch_size": 16, + "hidden_act": "gelu", + "output_channels": 16, + }, + clip_config={ + "num_hidden_layers": 1, + "hidden_size": 32, + "intermediate_size": 64, + "num_attention_heads": 4, + "image_size": 64, + "patch_size": 16, + "hidden_act": "quick_gelu", + "projection_dim": 32, + }, + projector_config={ + "input_dim": 32, + "n_embed": 64, + }, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.num_channels = num_channels + self.initializer_range = initializer_range + self.is_training = is_training + self.use_cache = use_cache + + self.text_config = text_config + self.sam_config = sam_config + self.clip_config = clip_config + self.projector_config = projector_config + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.image_size = sam_config["image_size"] + self.num_image_tokens = 16 + self.pad_token_id = text_config["pad_token_id"] + self.image_token_id = 100015 + + def get_config(self): + vision_config = { + "sam_config": self.sam_config, + "clip_config": self.clip_config, + } + return DeepseekOcrConfig( + text_config=self.text_config, + vision_config=vision_config, + projector_config=self.projector_config, + image_token_index=self.image_token_id, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 1 + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + pixel_values = floats_tensor( + [ + self.batch_size, + self.num_channels, + self.image_size, + self.image_size, + ] + ) + input_ids[input_ids == self.num_image_tokens] = config.text_config.pad_token_id + 
input_ids[:, : self.num_image_tokens] = self.image_token_id + + return config, input_ids, attention_mask, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class DeepseekOcrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (DeepseekOcrModel, DeepseekOcrForConditionalGeneration) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": DeepseekOcrModel, + "image-text-to-text": DeepseekOcrForConditionalGeneration, + } + if is_torch_available() + else {} + ) + _is_composite = True + + def setUp(self): + self.model_tester = DeepseekOcrModelTester(self) + self.config_tester = ConfigTester(self, config_class=DeepseekOcrConfig, has_text_modality=False) + + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + vision_attn = language_attn = "sdpa" if model._supports_sdpa else "eager" + + if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "language_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model_sdpa.language_model.config._attn_implementation == language_attn) + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if any(re.finditer(r"Attention(?!Pool)", class_name)): + 
self.assertTrue(submodule.config._attn_implementation == "eager") + + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if any(re.finditer(r"Attention(?!Pool)", class_name)): + self.assertTrue(submodule.config._attn_implementation == "sdpa") + + +@require_torch +@slow +class DeepseekOcrIntegrationTest(unittest.TestCase): + def setUp(self): + self.model_id = "deepseek_ocr_converted" + + def test_model_text_generation(self): + processor = AutoProcessor.from_pretrained(self.model_id) + model = AutoModel.from_pretrained(self.model_id, torch_dtype=torch.bfloat16) + + conversation = [ + { + "role": "<|User|>", + "content": [ + {"type": "image", "path": "./handwritten_letter_small.png"}, + {"type": "text", "text": "<|grounding|>Convert the document to markdown."}, + ], + } + ] + + inputs = processor.apply_chat_template( + conversation, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt" + ) + inputs = {k: v.to(torch_device) for k, v in inputs.items()} + + with torch.no_grad(): + generated = model.generate(**inputs, max_new_tokens=250) + + text = processor.batch_decode(generated, skip_special_tokens=False)[0] + + self.assertIn("<|grounding|>Convert the document to markdown.", text) + self.assertIn("text", text) + self.assertIn("[[52, 50, 940, 950]]", text) diff --git a/tests/models/deepseek_ocr/test_processing_deepseek_ocr.py b/tests/models/deepseek_ocr/test_processing_deepseek_ocr.py new file mode 100644 index 000000000000..6e63e238fbde --- /dev/null +++ b/tests/models/deepseek_ocr/test_processing_deepseek_ocr.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import tempfile
+import unittest
+
+from transformers import AutoTokenizer, DeepseekOcrProcessor
+from transformers.testing_utils import get_tests_dir, require_torch, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from transformers import DeepseekOcrImageProcessorFast
+
+
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+
+
+@require_vision
+class DeepseekOcrProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = DeepseekOcrProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = DeepseekOcrImageProcessorFast()
+        tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            **processor_kwargs,
+        )
+        processor.save_pretrained(self.tmpdirname)
+        self.image_token = processor.image_token
+
+    @staticmethod
+    def prepare_processor_dict():
+        return {
+            "image_token": "<image>",
+        }
+
+    def get_tokenizer(self, **kwargs):
+        return AutoTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+    def get_image_processor(self, **kwargs):
+        processor = self.processor_class.from_pretrained(self.tmpdirname, **kwargs)
+        return processor.image_processor
+
+    @require_torch
+    def test_image_token_expansion(self):
+        from PIL import Image
+
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        image = Image.new("RGB", (64, 64), color="red")
+
+        text = f"{processor.image_token} Describe this image."
+        inputs = processor(text=text, images=image, return_tensors="pt")
+
+        self.assertIn("input_ids", inputs)
+        self.assertIn("pixel_values", inputs)
+        self.assertIn("image_attention_mask", inputs)
+        self.assertIn("num_img_tokens", inputs)
+
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+        num_image_tokens = (inputs["input_ids"][0] == image_token_id).sum().item()
+
+        self.assertGreater(num_image_tokens, 1)
+
+    @require_torch
+    def test_image_attention_mask_generation(self):
+        import torch
+        from PIL import Image
+
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        image = Image.new("RGB", (64, 64), color="blue")
+
+        text = f"{processor.image_token} What is in this image?"
+ inputs = processor(text=text, images=image, return_tensors="pt") + + image_attention_mask = inputs["image_attention_mask"] + self.assertIsInstance(image_attention_mask, torch.Tensor) + self.assertEqual(image_attention_mask.dtype, torch.bool) + + image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) + expected_mask = inputs["input_ids"] == image_token_id + torch.testing.assert_close(image_attention_mask, expected_mask) + + @require_torch + def test_num_img_tokens_handling(self): + from PIL import Image + + processor = self.processor_class.from_pretrained(self.tmpdirname) + image1 = Image.new("RGB", (64, 64), color="red") + image2 = Image.new("RGB", (128, 128), color="green") + + text = [ + f"{processor.image_token} First image.", + f"{processor.image_token} Second image.", + ] + inputs = processor(text=text, images=[image1, image2], return_tensors="pt") + + self.assertIn("num_img_tokens", inputs) + self.assertIsInstance(inputs["num_img_tokens"], list) + self.assertEqual(len(inputs["num_img_tokens"]), 2) + + @require_torch + def test_processor_with_multiple_images(self): + from PIL import Image + + processor = self.processor_class.from_pretrained(self.tmpdirname) + image1 = Image.new("RGB", (64, 64), color="red") + image2 = Image.new("RGB", (64, 64), color="blue") + + text = f"{processor.image_token}{processor.image_token} Two images here." + inputs = processor(text=text, images=[image1, image2], return_tensors="pt") + + self.assertIn("input_ids", inputs) + self.assertIn("pixel_values", inputs) + self.assertEqual(len(inputs["num_img_tokens"]), 2) + + @require_torch + def test_processor_error_on_token_mismatch(self): + from PIL import Image + + processor = self.processor_class.from_pretrained(self.tmpdirname) + image = Image.new("RGB", (64, 64), color="red") + + text = "No image token here." + + with self.assertRaises(ValueError) as context: + processor(text=text, images=image, return_tensors="pt") + + self.assertIn("does not match", str(context.exception)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 4dbfb8e0bd8a..1989da4ed70e 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -191,6 +191,10 @@ "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "BltGlobalTransformer", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "Florence2VisionBackbone", # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration. + "DeepseekOcrCLIPVisionModel", + "DeepseekOcrModel", + "DeepseekOcrProjector", + "DeepseekOcrTextModel", ] ) @@ -393,6 +397,10 @@ "Qwen3OmniMoeTalkerModel", # Building part of a bigger model "Qwen3OmniMoeThinkerForConditionalGeneration", # Building part of a bigger model "Qwen3OmniMoeThinkerTextModel", # Building part of a bigger model + "DeepseekOcrCLIPVisionModel", # Building part of bigger (tested) model + "DeepseekOcrModel", # Building part of bigger (tested) model + "DeepseekOcrProjector", # Building part of bigger (tested) model + "DeepseekOcrTextModel", # Building part of bigger (tested) model ]
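For reviewers following the processor logic in `processing_deepseek_ocr.py`, the sketch below reproduces the two steps `DeepseekOcrProcessor.__call__` performs around tokenization: expanding each image placeholder to the per-image token count reported by the image processor, and deriving a boolean `image_attention_mask` from the resulting ids. It is a minimal standalone sketch, not the library implementation: the placeholder string, `IMAGE_TOKEN_ID`, `num_img_tokens` values, and the toy tokenizer are illustrative assumptions rather than values read from a real checkpoint.

```py
# Minimal sketch of the placeholder expansion and image-mask construction
# mirrored from DeepseekOcrProcessor.__call__ (illustrative values only).
import re

import torch

IMAGE_TOKEN = "<image>"  # assumed placeholder string
IMAGE_TOKEN_ID = 128815  # hypothetical id; the real tokenizer supplies this
num_img_tokens = [4, 9]  # per-image counts, normally returned by the image processor

text = [f"{IMAGE_TOKEN} First image.", f"{IMAGE_TOKEN} Second image."]

# Every placeholder must have a matching image, otherwise __call__ raises a ValueError.
assert "".join(text).count(IMAGE_TOKEN) == len(num_img_tokens)

# Each placeholder is repeated num_img_tokens[i] times before tokenization.
counts = iter(num_img_tokens)
expanded = [re.sub(re.escape(IMAGE_TOKEN), lambda _: IMAGE_TOKEN * next(counts), t) for t in text]


def toy_tokenize(prompt: str) -> list[int]:
    # Stand-in for the real tokenizer: placeholders map to IMAGE_TOKEN_ID, everything else to 0.
    tokens = re.findall(rf"{re.escape(IMAGE_TOKEN)}|\S", prompt)
    return [IMAGE_TOKEN_ID if tok == IMAGE_TOKEN else 0 for tok in tokens]


ids = [toy_tokenize(t) for t in expanded]
max_len = max(len(seq) for seq in ids)
input_ids = torch.tensor([seq + [0] * (max_len - len(seq)) for seq in ids])

# image_attention_mask marks exactly the positions that hold image tokens.
image_attention_mask = input_ids == IMAGE_TOKEN_ID
print(image_attention_mask.sum(dim=1))  # tensor([4, 9]), one count per image
```

This mirrors the check in `test_image_attention_mask_generation`, which asserts that `image_attention_mask` equals `input_ids == image_token_id`.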