vikhyatk commited on
Commit
3967d78
·
verified ·
1 Parent(s): b767709

Upload Moondream

Browse files
config.json CHANGED
@@ -11,5 +11,5 @@
11
  "model_type": "phi"
12
  },
13
  "torch_dtype": "float16",
14
- "transformers_version": "4.36.2"
15
  }
 
11
  "model_type": "phi"
12
  },
13
  "torch_dtype": "float16",
14
+ "transformers_version": "4.44.0"
15
  }
fourier_features.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ class FourierFeatures(nn.Module):
9
+ def __init__(self, in_features, out_features, std=1.0):
10
+ super().__init__()
11
+ assert out_features % 2 == 0
12
+ self.register_buffer(
13
+ "weight", torch.randn([out_features // 2, in_features]) * std
14
+ )
15
+
16
+ def forward(self, input):
17
+ f = 2 * math.pi * input @ self.weight.T
18
+ return torch.cat([f.cos(), f.sin()], dim=-1)
generation_config.json CHANGED
@@ -2,5 +2,5 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
- "transformers_version": "4.36.2"
6
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
+ "transformers_version": "4.44.0"
6
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:927694193ed81f83b9b269c0d1ffa8dc823dec90bce4703a54b22ebd6c9632b6
3
- size 3733912224
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bf7aed8ba4325d23fa7cd348d795a27f3b272682536f08aca4cdd62cde79293
3
+ size 3736040266
modeling_phi.py CHANGED
@@ -13,62 +13,113 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- """ PyTorch Phi model."""
17
-
18
 
 
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
22
- import torch.nn.functional as F
23
  import torch.utils.checkpoint
 
24
  from torch import nn
25
  from torch.nn import CrossEntropyLoss
26
 
27
  from transformers.activations import ACT2FN
28
- from transformers.cache_utils import Cache, DynamicCache
29
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPast,
32
  CausalLMOutputWithPast,
33
  )
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.utils import (
 
 
 
36
  is_flash_attn_2_available,
37
  is_flash_attn_greater_or_equal_2_10,
 
38
  logging,
 
39
  )
40
  from .configuration_moondream import PhiConfig
41
 
42
 
43
- try: # noqa: SIM105
44
- if is_flash_attn_2_available():
45
- from flash_attn import flash_attn_func, flash_attn_varlen_func
46
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
47
- except ImportError:
48
- # Workaround for https://github.com/huggingface/transformers/issues/28459,
49
- # don't move to contextlib.suppress(ImportError)
50
- pass
51
 
52
 
53
  logger = logging.get_logger(__name__)
54
 
 
55
 
56
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
57
- def _get_unpad_data(attention_mask):
58
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
59
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
60
- max_seqlen_in_batch = seqlens_in_batch.max().item()
61
- cu_seqlens = F.pad(
62
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
63
- )
64
- return (
65
- indices,
66
- cu_seqlens,
67
- max_seqlen_in_batch,
68
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
72
  class PhiRotaryEmbedding(nn.Module):
73
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
74
  super().__init__()
@@ -77,7 +128,11 @@ class PhiRotaryEmbedding(nn.Module):
77
  self.max_position_embeddings = max_position_embeddings
78
  self.base = base
79
  inv_freq = 1.0 / (
80
- self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
 
 
 
 
81
  )
82
  self.register_buffer("inv_freq", inv_freq, persistent=False)
83
 
@@ -91,8 +146,8 @@ class PhiRotaryEmbedding(nn.Module):
91
  def _set_cos_sin_cache(self, seq_len, device, dtype):
92
  self.max_seq_len_cached = seq_len
93
  t = torch.arange(
94
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
95
- )
96
 
97
  freqs = torch.outer(t, self.inv_freq)
98
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -111,7 +166,7 @@ class PhiRotaryEmbedding(nn.Module):
111
  )
112
 
113
 
114
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
115
  class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
116
  """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
117
 
@@ -129,8 +184,8 @@ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
129
  def _set_cos_sin_cache(self, seq_len, device, dtype):
130
  self.max_seq_len_cached = seq_len
131
  t = torch.arange(
132
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
133
- )
134
  t = t / self.scaling_factor
135
 
136
  freqs = torch.outer(t, self.inv_freq)
@@ -140,7 +195,7 @@ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
140
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
141
 
142
 
143
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
144
  class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
145
  """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
146
 
@@ -164,13 +219,17 @@ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
164
  - (self.scaling_factor - 1)
165
  ) ** (self.dim / (self.dim - 2))
166
  inv_freq = 1.0 / (
167
- base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
 
 
 
 
168
  )
169
  self.register_buffer("inv_freq", inv_freq, persistent=False)
170
 
171
  t = torch.arange(
172
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
173
- )
174
 
175
  freqs = torch.outer(t, self.inv_freq)
176
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -187,7 +246,7 @@ def rotate_half(x):
187
  return torch.cat((-x2, x1), dim=-1)
188
 
189
 
190
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
191
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
192
  """Applies Rotary Position Embedding to the query and key tensors.
193
 
@@ -256,8 +315,8 @@ class PhiAttention(nn.Module):
256
  self.layer_idx = layer_idx
257
  if layer_idx is None:
258
  logger.warning_once(
259
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
260
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
261
  "when creating this class."
262
  )
263
 
@@ -322,6 +381,7 @@ class PhiAttention(nn.Module):
322
  past_key_value: Optional[Cache] = None,
323
  output_attentions: bool = False,
324
  use_cache: bool = False,
 
325
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
326
  bsz, q_len, _ = hidden_states.size()
327
 
@@ -373,6 +433,7 @@ class PhiAttention(nn.Module):
373
  "sin": sin,
374
  "cos": cos,
375
  "partial_rotation_size": self.rotary_emb.dim,
 
376
  }
377
  key_states, value_states = past_key_value.update(
378
  key_states, value_states, self.layer_idx, cache_kwargs
@@ -381,10 +442,37 @@ class PhiAttention(nn.Module):
381
  key_states = repeat_kv(key_states, self.num_key_value_groups)
382
  value_states = repeat_kv(value_states, self.num_key_value_groups)
383
 
384
- attn_output = torch.nn.functional.scaled_dot_product_attention(
385
- query_states, key_states, value_states, attn_mask=attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  )
387
 
 
 
 
 
 
 
 
 
388
  attn_output = attn_output.transpose(1, 2).contiguous()
389
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
390
 
@@ -420,6 +508,7 @@ class PhiFlashAttention2(PhiAttention):
420
  past_key_value: Optional[Cache] = None,
421
  output_attentions: bool = False,
422
  use_cache: bool = False,
 
423
  **kwargs,
424
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
425
  # PhiFlashAttention2 attention does not support output_attentions
@@ -473,6 +562,7 @@ class PhiFlashAttention2(PhiAttention):
473
  "sin": sin,
474
  "cos": cos,
475
  "partial_rotation_size": self.rotary_emb.dim,
 
476
  }
477
  key_states, value_states = past_key_value.update(
478
  key_states, value_states, self.layer_idx, cache_kwargs
@@ -511,14 +601,17 @@ class PhiFlashAttention2(PhiAttention):
511
  key_states = key_states.to(target_dtype)
512
  value_states = value_states.to(target_dtype)
513
 
514
- attn_output = self._flash_attention_forward(
515
  query_states,
516
  key_states,
517
  value_states,
518
  attention_mask,
519
  q_len,
 
520
  dropout=attn_dropout,
521
  softmax_scale=None,
 
 
522
  )
523
 
524
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -529,137 +622,148 @@ class PhiFlashAttention2(PhiAttention):
529
 
530
  return attn_output, attn_weights, past_key_value
531
 
532
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
533
- def _flash_attention_forward(
534
- self,
535
- query_states,
536
- key_states,
537
- value_states,
538
- attention_mask,
539
- query_length,
540
- dropout=0.0,
541
- softmax_scale=None,
542
- ):
543
- """
544
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
545
- first unpad the input, then computes the attention scores and pad the final attention scores.
546
-
547
- Args:
548
- query_states (`torch.Tensor`):
549
- Input query states to be passed to Flash Attention API
550
- key_states (`torch.Tensor`):
551
- Input key states to be passed to Flash Attention API
552
- value_states (`torch.Tensor`):
553
- Input value states to be passed to Flash Attention API
554
- attention_mask (`torch.Tensor`):
555
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
556
- position of padding tokens and 1 for the position of non-padding tokens.
557
- dropout (`int`, *optional*):
558
- Attention dropout
559
- softmax_scale (`float`, *optional*):
560
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
561
- """
562
- if not self._flash_attn_uses_top_left_mask:
563
- causal = self.is_causal
564
- else:
565
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
566
- causal = self.is_causal and query_length != 1
567
 
568
- # Contains at least one padding token in the sequence
569
- if attention_mask is not None:
570
- batch_size = query_states.shape[0]
571
- (
572
- query_states,
573
- key_states,
574
- value_states,
575
- indices_q,
576
- cu_seq_lens,
577
- max_seq_lens,
578
- ) = self._upad_input(
579
- query_states, key_states, value_states, attention_mask, query_length
580
- )
581
 
582
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
583
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
584
-
585
- attn_output_unpad = flash_attn_varlen_func(
586
- query_states,
587
- key_states,
588
- value_states,
589
- cu_seqlens_q=cu_seqlens_q,
590
- cu_seqlens_k=cu_seqlens_k,
591
- max_seqlen_q=max_seqlen_in_batch_q,
592
- max_seqlen_k=max_seqlen_in_batch_k,
593
- dropout_p=dropout,
594
- softmax_scale=softmax_scale,
595
- causal=causal,
596
- )
597
 
598
- attn_output = pad_input(
599
- attn_output_unpad, indices_q, batch_size, query_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  )
601
- else:
602
- attn_output = flash_attn_func(
603
- query_states,
604
- key_states,
605
- value_states,
606
- dropout,
607
- softmax_scale=softmax_scale,
608
- causal=causal,
609
  )
610
 
611
- return attn_output
612
 
613
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
614
- def _upad_input(
615
- self, query_layer, key_layer, value_layer, attention_mask, query_length
616
- ):
617
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
618
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
619
 
620
- key_layer = index_first_axis(
621
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
622
- indices_k,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
  )
624
- value_layer = index_first_axis(
625
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
626
- indices_k,
627
  )
628
- if query_length == kv_seq_len:
629
- query_layer = index_first_axis(
630
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
631
- indices_k,
632
- )
633
- cu_seqlens_q = cu_seqlens_k
634
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
635
- indices_q = indices_k
636
- elif query_length == 1:
637
- max_seqlen_in_batch_q = 1
638
- cu_seqlens_q = torch.arange(
639
- batch_size + 1, dtype=torch.int32, device=query_layer.device
640
- ) # There is a memcpy here, that is very bad.
641
- indices_q = cu_seqlens_q[:-1]
642
- query_layer = query_layer.squeeze(1)
643
- else:
644
- # The -q_len: slice assumes left padding.
645
- attention_mask = attention_mask[:, -query_length:]
646
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
647
- query_layer, attention_mask
648
  )
649
 
650
- return (
651
- query_layer,
652
- key_layer,
653
- value_layer,
654
- indices_q,
655
- (cu_seqlens_q, cu_seqlens_k),
656
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
  )
658
 
 
 
 
 
 
 
 
659
 
660
  PHI_ATTENTION_CLASSES = {
661
  "eager": PhiAttention,
662
  "flash_attention_2": PhiFlashAttention2,
 
663
  }
664
 
665
 
@@ -681,6 +785,8 @@ class PhiDecoderLayer(nn.Module):
681
  output_attentions: Optional[bool] = False,
682
  use_cache: Optional[bool] = False,
683
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
 
 
684
  ) -> Tuple[
685
  torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
686
  ]:
@@ -700,6 +806,11 @@ class PhiDecoderLayer(nn.Module):
700
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
701
  (see `past_key_values`).
702
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
 
 
 
 
 
703
  """
704
 
705
  residual = hidden_states
@@ -714,6 +825,7 @@ class PhiDecoderLayer(nn.Module):
714
  past_key_value=past_key_value,
715
  output_attentions=output_attentions,
716
  use_cache=use_cache,
 
717
  )
718
  attn_outputs = self.resid_dropout(attn_outputs)
719
 
@@ -730,6 +842,27 @@ class PhiDecoderLayer(nn.Module):
730
  return outputs
731
 
732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  class PhiPreTrainedModel(PreTrainedModel):
734
  config_class = PhiConfig
735
  base_model_prefix = "model"
@@ -737,6 +870,7 @@ class PhiPreTrainedModel(PreTrainedModel):
737
  _no_split_modules = ["PhiDecoderLayer"]
738
  _skip_keys_device_placement = "past_key_values"
739
  _supports_flash_attn_2 = True
 
740
  _supports_cache_class = True
741
 
742
  def _init_weights(self, module):
@@ -761,7 +895,84 @@ class Embedding(nn.Module):
761
  def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
762
  return self.wte(input_ids)
763
 
764
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  class PhiModel(PhiPreTrainedModel):
766
  """
767
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
@@ -783,7 +994,9 @@ class PhiModel(PhiPreTrainedModel):
783
  for layer_idx in range(config.num_hidden_layers)
784
  ]
785
  )
 
786
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
 
787
 
788
  self.gradient_checkpointing = False
789
  # Initialize weights and apply final processing
@@ -795,6 +1008,7 @@ class PhiModel(PhiPreTrainedModel):
795
  def set_input_embeddings(self, value):
796
  self.embd.wte = value
797
 
 
798
  def forward(
799
  self,
800
  input_ids: torch.LongTensor = None,
@@ -806,6 +1020,7 @@ class PhiModel(PhiPreTrainedModel):
806
  output_attentions: Optional[bool] = None,
807
  output_hidden_states: Optional[bool] = None,
808
  return_dict: Optional[bool] = None,
 
809
  ) -> Union[Tuple, BaseModelOutputWithPast]:
810
  output_attentions = (
811
  output_attentions
@@ -823,19 +1038,10 @@ class PhiModel(PhiPreTrainedModel):
823
  return_dict if return_dict is not None else self.config.use_return_dict
824
  )
825
 
826
- # retrieve input_ids and inputs_embeds
827
- if input_ids is not None and inputs_embeds is not None:
828
  raise ValueError(
829
- "You cannot specify both input_ids and inputs_embeds at the same time"
830
  )
831
- elif input_ids is not None:
832
- batch_size, seq_length = input_ids.shape[:2]
833
- elif inputs_embeds is not None:
834
- batch_size, seq_length = inputs_embeds.shape[:2]
835
- else:
836
- raise ValueError("You have to specify either input_ids or inputs_embeds")
837
-
838
- past_key_values_length = 0
839
 
840
  if self.gradient_checkpointing and self.training:
841
  if use_cache:
@@ -844,43 +1050,37 @@ class PhiModel(PhiPreTrainedModel):
844
  )
845
  use_cache = False
846
 
847
- if use_cache:
848
- use_legacy_cache = not isinstance(past_key_values, Cache)
849
- if use_legacy_cache:
850
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
851
- past_key_values_length = past_key_values.get_usable_length(seq_length)
852
-
853
- if position_ids is None:
854
- device = input_ids.device if input_ids is not None else inputs_embeds.device
855
- position_ids = torch.arange(
856
- past_key_values_length,
857
- seq_length + past_key_values_length,
858
- dtype=torch.long,
859
- device=device,
860
  )
861
- position_ids = position_ids.unsqueeze(0)
862
 
863
  if inputs_embeds is None:
864
  inputs_embeds = self.embd(input_ids)
865
 
866
- inputs_embeds = self.embed_dropout(inputs_embeds)
867
-
868
- # Attention mask.
869
- if self._use_flash_attention_2:
870
- # 2d mask is passed through the layers
871
- attention_mask = (
872
- attention_mask
873
- if (attention_mask is not None and 0 in attention_mask)
874
- else None
875
  )
876
- else:
877
- # 4d mask is passed through the layers
878
- attention_mask = _prepare_4d_causal_attention_mask(
879
- attention_mask,
880
- (batch_size, seq_length),
881
- inputs_embeds,
882
- past_key_values_length,
883
  )
 
 
 
 
 
 
 
 
 
 
884
 
885
  hidden_states = inputs_embeds
886
 
@@ -897,19 +1097,22 @@ class PhiModel(PhiPreTrainedModel):
897
  layer_outputs = self._gradient_checkpointing_func(
898
  decoder_layer.__call__,
899
  hidden_states,
900
- attention_mask,
901
  position_ids,
902
- past_key_values,
903
  output_attentions,
 
 
 
904
  )
905
  else:
906
  layer_outputs = decoder_layer(
907
  hidden_states,
908
- attention_mask=attention_mask,
909
  position_ids=position_ids,
910
  past_key_value=past_key_values,
911
  output_attentions=output_attentions,
912
  use_cache=use_cache,
 
913
  )
914
 
915
  hidden_states = layer_outputs[0]
@@ -944,6 +1147,86 @@ class PhiModel(PhiPreTrainedModel):
944
  attentions=all_self_attns,
945
  )
946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
947
 
948
  class CausalLMHead(nn.Module):
949
  """Causal Language Modeling head. Simplified version."""
@@ -958,7 +1241,6 @@ class CausalLMHead(nn.Module):
958
 
959
 
960
  class PhiForCausalLM(PhiPreTrainedModel):
961
- _tied_weights_keys = ["lm_head.linear.weight"]
962
 
963
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
964
  def __init__(self, config):
@@ -976,7 +1258,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
976
 
977
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
978
  def set_input_embeddings(self, value):
979
- self.model.embd.wte = value
980
 
981
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
982
  def get_output_embeddings(self):
@@ -994,6 +1276,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
994
  def get_decoder(self):
995
  return self.model
996
 
 
 
 
 
997
  def forward(
998
  self,
999
  input_ids: torch.LongTensor = None,
@@ -1006,6 +1292,8 @@ class PhiForCausalLM(PhiPreTrainedModel):
1006
  output_attentions: Optional[bool] = None,
1007
  output_hidden_states: Optional[bool] = None,
1008
  return_dict: Optional[bool] = None,
 
 
1009
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1010
  r"""
1011
  Args:
@@ -1014,6 +1302,11 @@ class PhiForCausalLM(PhiPreTrainedModel):
1014
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1015
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1016
 
 
 
 
 
 
1017
  Returns:
1018
 
1019
  Example:
@@ -1058,13 +1351,16 @@ class PhiForCausalLM(PhiPreTrainedModel):
1058
  output_attentions=output_attentions,
1059
  output_hidden_states=output_hidden_states,
1060
  return_dict=return_dict,
 
1061
  )
1062
 
1063
  hidden_states = outputs[0]
1064
- logits = self.lm_head(hidden_states)
1065
 
1066
  loss = None
1067
  if labels is not None:
 
 
1068
  # Shift so that tokens < n predict n
1069
  shift_logits = logits[..., :-1, :].contiguous()
1070
  shift_labels = labels[..., 1:].contiguous()
@@ -1095,41 +1391,23 @@ class PhiForCausalLM(PhiPreTrainedModel):
1095
  past_key_values=None,
1096
  attention_mask=None,
1097
  inputs_embeds=None,
 
 
 
 
1098
  **kwargs,
1099
  ):
 
 
 
1100
  if past_key_values is not None:
1101
- if isinstance(past_key_values, Cache):
1102
- cache_length = past_key_values.get_seq_length()
1103
- past_length = past_key_values.seen_tokens
1104
- max_cache_length = past_key_values.get_max_length()
1105
- else:
1106
- cache_length = past_length = past_key_values[0][0].shape[2]
1107
- max_cache_length = None
1108
-
1109
- # Keep only the unprocessed tokens:
1110
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1111
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1112
- # input)
1113
- if (
1114
- attention_mask is not None
1115
- and attention_mask.shape[1] > input_ids.shape[1]
1116
- ):
1117
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1118
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1119
- # input_ids based on the past_length.
1120
- elif past_length < input_ids.shape[1]:
1121
- input_ids = input_ids[:, past_length:]
1122
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1123
-
1124
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1125
- if (
1126
- max_cache_length is not None
1127
- and attention_mask is not None
1128
- and cache_length + input_ids.shape[1] > max_cache_length
1129
- ):
1130
- attention_mask = attention_mask[:, -max_cache_length:]
1131
 
1132
- position_ids = kwargs.get("position_ids", None)
1133
  if attention_mask is not None and position_ids is None:
1134
  # create position_ids on the fly for batch generation
1135
  position_ids = attention_mask.long().cumsum(-1) - 1
@@ -1137,31 +1415,49 @@ class PhiForCausalLM(PhiPreTrainedModel):
1137
  if past_key_values:
1138
  position_ids = position_ids[:, -input_ids.shape[1] :]
1139
 
 
 
 
1140
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1141
- if inputs_embeds is not None and (input_ids is None or input_ids.shape[1] == 0):
1142
- model_inputs = {"inputs_embeds": inputs_embeds}
1143
  else:
1144
- model_inputs = {"input_ids": input_ids}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1145
 
1146
  model_inputs.update(
1147
  {
1148
  "position_ids": position_ids,
 
1149
  "past_key_values": past_key_values,
1150
- "use_cache": kwargs.get("use_cache"),
1151
  "attention_mask": attention_mask,
 
1152
  }
1153
  )
1154
  return model_inputs
1155
-
1156
- @staticmethod
1157
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1158
- def _reorder_cache(past_key_values, beam_idx):
1159
- reordered_past = ()
1160
- for layer_past in past_key_values:
1161
- reordered_past += (
1162
- tuple(
1163
- past_state.index_select(0, beam_idx.to(past_state.device))
1164
- for past_state in layer_past
1165
- ),
1166
- )
1167
- return reordered_past
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
+ """PyTorch Phi model."""
 
17
 
18
+ import math
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
 
22
  import torch.utils.checkpoint
23
+ from packaging import version
24
  from torch import nn
25
  from torch.nn import CrossEntropyLoss
26
 
27
  from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPast,
32
  CausalLMOutputWithPast,
33
  )
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.utils import (
36
+ add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ get_torch_version,
39
  is_flash_attn_2_available,
40
  is_flash_attn_greater_or_equal_2_10,
41
+ is_torchdynamo_compiling,
42
  logging,
43
+ replace_return_docstrings,
44
  )
45
  from .configuration_moondream import PhiConfig
46
 
47
 
48
+ if is_flash_attn_2_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 
 
 
 
50
 
51
 
52
  logger = logging.get_logger(__name__)
53
 
54
+ _CONFIG_FOR_DOC = "PhiConfig"
55
 
56
+
57
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
58
+ def _prepare_4d_causal_attention_mask_with_cache_position(
59
+ attention_mask: torch.Tensor,
60
+ sequence_length: int,
61
+ target_length: int,
62
+ dtype: torch.dtype,
63
+ device: torch.device,
64
+ min_dtype: float,
65
+ cache_position: torch.Tensor,
66
+ batch_size: int,
67
+ ):
68
+ """
69
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
70
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
71
+
72
+ Args:
73
+ attention_mask (`torch.Tensor`):
74
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
75
+ sequence_length (`int`):
76
+ The sequence length being processed.
77
+ target_length (`int`):
78
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
79
+ dtype (`torch.dtype`):
80
+ The dtype to use for the 4D attention mask.
81
+ device (`torch.device`):
82
+ The device to plcae the 4D attention mask on.
83
+ min_dtype (`float`):
84
+ The minimum value representable with the dtype `dtype`.
85
+ cache_position (`torch.Tensor`):
86
+ Indices depicting the position of the input sequence tokens in the sequence.
87
+ batch_size (`torch.Tensor`):
88
+ Batch size.
89
+ """
90
+ if attention_mask is not None and attention_mask.dim() == 4:
91
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
92
+ causal_mask = attention_mask
93
+ else:
94
+ causal_mask = torch.full(
95
+ (sequence_length, target_length),
96
+ fill_value=min_dtype,
97
+ dtype=dtype,
98
+ device=device,
99
+ )
100
+ if sequence_length != 1:
101
+ causal_mask = torch.triu(causal_mask, diagonal=1)
102
+ causal_mask *= torch.arange(
103
+ target_length, device=device
104
+ ) > cache_position.reshape(-1, 1)
105
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
106
+ if attention_mask is not None:
107
+ causal_mask = (
108
+ causal_mask.clone()
109
+ ) # copy to contiguous memory for in-place edit
110
+ mask_length = attention_mask.shape[-1]
111
+ padding_mask = (
112
+ causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
113
+ )
114
+ padding_mask = padding_mask == 0
115
+ causal_mask[:, :, :, :mask_length] = causal_mask[
116
+ :, :, :, :mask_length
117
+ ].masked_fill(padding_mask, min_dtype)
118
+
119
+ return causal_mask
120
 
121
 
122
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
123
  class PhiRotaryEmbedding(nn.Module):
124
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
  super().__init__()
 
128
  self.max_position_embeddings = max_position_embeddings
129
  self.base = base
130
  inv_freq = 1.0 / (
131
+ self.base
132
+ ** (
133
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
134
+ / self.dim
135
+ )
136
  )
137
  self.register_buffer("inv_freq", inv_freq, persistent=False)
138
 
 
146
  def _set_cos_sin_cache(self, seq_len, device, dtype):
147
  self.max_seq_len_cached = seq_len
148
  t = torch.arange(
149
+ self.max_seq_len_cached, device=device, dtype=torch.int64
150
+ ).type_as(self.inv_freq)
151
 
152
  freqs = torch.outer(t, self.inv_freq)
153
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
166
  )
167
 
168
 
169
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
170
  class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
171
  """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
 
 
184
  def _set_cos_sin_cache(self, seq_len, device, dtype):
185
  self.max_seq_len_cached = seq_len
186
  t = torch.arange(
187
+ self.max_seq_len_cached, device=device, dtype=torch.int64
188
+ ).type_as(self.inv_freq)
189
  t = t / self.scaling_factor
190
 
191
  freqs = torch.outer(t, self.inv_freq)
 
195
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
 
197
 
198
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
199
  class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
200
  """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
201
 
 
219
  - (self.scaling_factor - 1)
220
  ) ** (self.dim / (self.dim - 2))
221
  inv_freq = 1.0 / (
222
+ base
223
+ ** (
224
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
225
+ / self.dim
226
+ )
227
  )
228
  self.register_buffer("inv_freq", inv_freq, persistent=False)
229
 
230
  t = torch.arange(
231
+ self.max_seq_len_cached, device=device, dtype=torch.int64
232
+ ).type_as(self.inv_freq)
233
 
234
  freqs = torch.outer(t, self.inv_freq)
235
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
246
  return torch.cat((-x2, x1), dim=-1)
247
 
248
 
249
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
250
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
251
  """Applies Rotary Position Embedding to the query and key tensors.
252
 
 
315
  self.layer_idx = layer_idx
316
  if layer_idx is None:
317
  logger.warning_once(
318
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
319
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
320
  "when creating this class."
321
  )
322
 
 
381
  past_key_value: Optional[Cache] = None,
382
  output_attentions: bool = False,
383
  use_cache: bool = False,
384
+ cache_position: Optional[torch.LongTensor] = None,
385
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
  bsz, q_len, _ = hidden_states.size()
387
 
 
433
  "sin": sin,
434
  "cos": cos,
435
  "partial_rotation_size": self.rotary_emb.dim,
436
+ "cache_position": cache_position,
437
  }
438
  key_states, value_states = past_key_value.update(
439
  key_states, value_states, self.layer_idx, cache_kwargs
 
442
  key_states = repeat_kv(key_states, self.num_key_value_groups)
443
  value_states = repeat_kv(value_states, self.num_key_value_groups)
444
 
445
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
446
+ attn_weights = torch.matmul(
447
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
448
+ ) / math.sqrt(self.head_dim)
449
+
450
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
451
+ raise ValueError(
452
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
453
+ f" {attn_weights.size()}"
454
+ )
455
+
456
+ if attention_mask is not None:
457
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
458
+ attn_weights += causal_mask
459
+
460
+ # upcast attention to fp32
461
+ attn_weights = nn.functional.softmax(
462
+ attn_weights, dim=-1, dtype=torch.float32
463
+ ).to(value_states.dtype)
464
+ attn_weights = nn.functional.dropout(
465
+ attn_weights, p=self.attention_dropout, training=self.training
466
  )
467
 
468
+ attn_output = torch.matmul(attn_weights, value_states)
469
+
470
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
+ raise ValueError(
472
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
+ f" {attn_output.size()}"
474
+ )
475
+
476
  attn_output = attn_output.transpose(1, 2).contiguous()
477
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
478
 
 
508
  past_key_value: Optional[Cache] = None,
509
  output_attentions: bool = False,
510
  use_cache: bool = False,
511
+ cache_position: Optional[torch.LongTensor] = None,
512
  **kwargs,
513
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
514
  # PhiFlashAttention2 attention does not support output_attentions
 
562
  "sin": sin,
563
  "cos": cos,
564
  "partial_rotation_size": self.rotary_emb.dim,
565
+ "cache_position": cache_position,
566
  }
567
  key_states, value_states = past_key_value.update(
568
  key_states, value_states, self.layer_idx, cache_kwargs
 
601
  key_states = key_states.to(target_dtype)
602
  value_states = value_states.to(target_dtype)
603
 
604
+ attn_output = _flash_attention_forward(
605
  query_states,
606
  key_states,
607
  value_states,
608
  attention_mask,
609
  q_len,
610
+ position_ids=position_ids,
611
  dropout=attn_dropout,
612
  softmax_scale=None,
613
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
614
+ is_causal=self.is_causal,
615
  )
616
 
617
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
 
622
 
623
  return attn_output, attn_weights, past_key_value
624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
+ class PhiSdpaAttention(PhiAttention):
627
+ def __init__(self, *args, **kwargs):
628
+ super().__init__(*args, **kwargs)
629
+ self.require_contiguous_qkv = version.parse(
630
+ get_torch_version()
631
+ ) < version.parse("2.2.0")
 
 
 
 
 
 
 
632
 
633
+ """
634
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
635
+ `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
636
+ SDPA API.
637
+ """
 
 
 
 
 
 
 
 
 
 
638
 
639
+ # Adapted from PhiAttention.forward
640
+ def forward(
641
+ self,
642
+ hidden_states: torch.Tensor,
643
+ attention_mask: Optional[torch.Tensor] = None,
644
+ position_ids: Optional[torch.LongTensor] = None,
645
+ past_key_value: Optional[Cache] = None,
646
+ output_attentions: bool = False,
647
+ use_cache: bool = False,
648
+ cache_position: Optional[torch.LongTensor] = None,
649
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
650
+ if output_attentions:
651
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
652
+ logger.warning_once(
653
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
654
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
655
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
656
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
657
  )
658
+ return super().forward(
659
+ hidden_states=hidden_states,
660
+ attention_mask=attention_mask,
661
+ position_ids=position_ids,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
 
665
  )
666
 
667
+ bsz, q_len, _ = hidden_states.size()
668
 
669
+ query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
670
+ 3, dim=-1
671
+ )
 
 
 
672
 
673
+ query_states = query_states.view(
674
+ bsz, q_len, self.num_heads, self.head_dim
675
+ ).transpose(1, 2)
676
+ key_states = key_states.view(
677
+ bsz, q_len, self.num_key_value_heads, self.head_dim
678
+ ).transpose(1, 2)
679
+ value_states = value_states.view(
680
+ bsz, q_len, self.num_key_value_heads, self.head_dim
681
+ ).transpose(1, 2)
682
+
683
+ kv_seq_len = key_states.shape[-2]
684
+ if past_key_value is not None:
685
+ if self.layer_idx is None:
686
+ raise ValueError(
687
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
688
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
689
+ "with a layer index."
690
+ )
691
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
+
694
+ # Partial rotary embedding
695
+ query_rot, query_pass = (
696
+ query_states[..., : self.rotary_emb.dim],
697
+ query_states[..., self.rotary_emb.dim :],
698
  )
699
+ key_rot, key_pass = (
700
+ key_states[..., : self.rotary_emb.dim],
701
+ key_states[..., self.rotary_emb.dim :],
702
  )
703
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
+ query_rot, key_rot = apply_rotary_pos_emb(
705
+ query_rot, key_rot, cos, sin, position_ids
706
+ )
707
+
708
+ # [batch_size, seq_length, num_heads, head_dim]
709
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
710
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
711
+
712
+ if past_key_value is not None:
713
+ cache_kwargs = {
714
+ "sin": sin,
715
+ "cos": cos,
716
+ "partial_rotation_size": self.rotary_emb.dim,
717
+ "cache_position": cache_position,
718
+ }
719
+ key_states, value_states = past_key_value.update(
720
+ key_states, value_states, self.layer_idx, cache_kwargs
 
 
721
  )
722
 
723
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
724
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
725
+
726
+ causal_mask = attention_mask
727
+ if attention_mask is not None:
728
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
+
730
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
733
+ if (
734
+ self.require_contiguous_qkv
735
+ and query_states.device.type == "cuda"
736
+ and attention_mask is not None
737
+ ):
738
+ query_states = query_states.contiguous()
739
+ key_states = key_states.contiguous()
740
+ value_states = value_states.contiguous()
741
+
742
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
+ is_causal = True if causal_mask is None and q_len > 1 else False
745
+
746
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
747
+ query_states,
748
+ key_states,
749
+ value_states,
750
+ attn_mask=causal_mask,
751
+ dropout_p=self.attention_dropout if self.training else 0.0,
752
+ is_causal=is_causal,
753
  )
754
 
755
+ attn_output = attn_output.transpose(1, 2).contiguous()
756
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
+
758
+ attn_output = self.out_proj(attn_output)
759
+
760
+ return attn_output, None, past_key_value
761
+
762
 
763
  PHI_ATTENTION_CLASSES = {
764
  "eager": PhiAttention,
765
  "flash_attention_2": PhiFlashAttention2,
766
+ "sdpa": PhiSdpaAttention,
767
  }
768
 
769
 
 
785
  output_attentions: Optional[bool] = False,
786
  use_cache: Optional[bool] = False,
787
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
+ cache_position: Optional[torch.LongTensor] = None,
789
+ **kwargs,
790
  ) -> Tuple[
791
  torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
  ]:
 
806
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
  (see `past_key_values`).
808
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
+ Indices depicting the position of the input sequence tokens in the sequence
811
+ kwargs (`dict`, *optional*):
812
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
813
+ into the model
814
  """
815
 
816
  residual = hidden_states
 
825
  past_key_value=past_key_value,
826
  output_attentions=output_attentions,
827
  use_cache=use_cache,
828
+ cache_position=cache_position,
829
  )
830
  attn_outputs = self.resid_dropout(attn_outputs)
831
 
 
842
  return outputs
843
 
844
 
845
+ PHI_START_DOCSTRING = r"""
846
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
848
+ etc.)
849
+
850
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
852
+ and behavior.
853
+
854
+ Parameters:
855
+ config ([`PhiConfig`]):
856
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
857
+ load the weights associated with the model, only the configuration. Check out the
858
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
+ """
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
+ PHI_START_DOCSTRING,
865
+ )
866
  class PhiPreTrainedModel(PreTrainedModel):
867
  config_class = PhiConfig
868
  base_model_prefix = "model"
 
870
  _no_split_modules = ["PhiDecoderLayer"]
871
  _skip_keys_device_placement = "past_key_values"
872
  _supports_flash_attn_2 = True
873
+ _supports_sdpa = True
874
  _supports_cache_class = True
875
 
876
  def _init_weights(self, module):
 
895
  def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
  return self.wte(input_ids)
897
 
898
+ PHI_INPUTS_DOCSTRING = r"""
899
+ Args:
900
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
+ it.
903
+
904
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
+ [`PreTrainedTokenizer.__call__`] for details.
906
+
907
+ [What are input IDs?](../glossary#input-ids)
908
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
+
911
+ - 1 for tokens that are **not masked**,
912
+ - 0 for tokens that are **masked**.
913
+
914
+ [What are attention masks?](../glossary#attention-mask)
915
+
916
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
+ [`PreTrainedTokenizer.__call__`] for details.
918
+
919
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
+ `past_key_values`).
921
+
922
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
+ information on the default strategy.
925
+
926
+ - 1 indicates the head is **not masked**,
927
+ - 0 indicates the head is **masked**.
928
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
+ config.n_positions - 1]`.
931
+
932
+ [What are position IDs?](../glossary#position-ids)
933
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
936
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
+
938
+ Two formats are allowed:
939
+ - a [`~cache_utils.Cache`] instance;
940
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
+ cache format.
943
+
944
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
+ legacy cache format will be returned.
946
+
947
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
+ of shape `(batch_size, sequence_length)`.
950
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
+ model's internal embedding lookup matrix.
954
+ use_cache (`bool`, *optional*):
955
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
+ `past_key_values`).
957
+ output_attentions (`bool`, *optional*):
958
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
+ tensors for more detail.
960
+ output_hidden_states (`bool`, *optional*):
961
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
+ more detail.
963
+ return_dict (`bool`, *optional*):
964
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
+ the complete sequence length.
969
+ """
970
+
971
+
972
+ @add_start_docstrings(
973
+ "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
+ PHI_START_DOCSTRING,
975
+ )
976
  class PhiModel(PhiPreTrainedModel):
977
  """
978
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
 
994
  for layer_idx in range(config.num_hidden_layers)
995
  ]
996
  )
997
+
998
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
+ self._use_sdpa = config._attn_implementation == "sdpa"
1000
 
1001
  self.gradient_checkpointing = False
1002
  # Initialize weights and apply final processing
 
1008
  def set_input_embeddings(self, value):
1009
  self.embd.wte = value
1010
 
1011
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
  def forward(
1013
  self,
1014
  input_ids: torch.LongTensor = None,
 
1020
  output_attentions: Optional[bool] = None,
1021
  output_hidden_states: Optional[bool] = None,
1022
  return_dict: Optional[bool] = None,
1023
+ cache_position: Optional[torch.LongTensor] = None,
1024
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
  output_attentions = (
1026
  output_attentions
 
1038
  return_dict if return_dict is not None else self.config.use_return_dict
1039
  )
1040
 
1041
+ if (input_ids is None) ^ (inputs_embeds is not None):
 
1042
  raise ValueError(
1043
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
  )
 
 
 
 
 
 
 
 
1045
 
1046
  if self.gradient_checkpointing and self.training:
1047
  if use_cache:
 
1050
  )
1051
  use_cache = False
1052
 
1053
+ use_legacy_cache = False
1054
+ if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
+ use_legacy_cache = True
1056
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
+ logger.warning_once(
1058
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
 
 
 
 
 
 
1060
  )
 
1061
 
1062
  if inputs_embeds is None:
1063
  inputs_embeds = self.embd(input_ids)
1064
 
1065
+ if cache_position is None:
1066
+ past_seen_tokens = (
1067
+ past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
 
 
 
 
1068
  )
1069
+ cache_position = torch.arange(
1070
+ past_seen_tokens,
1071
+ past_seen_tokens + inputs_embeds.shape[1],
1072
+ device=inputs_embeds.device,
 
 
 
1073
  )
1074
+ if position_ids is None:
1075
+ position_ids = cache_position.unsqueeze(0)
1076
+
1077
+ causal_mask = self._update_causal_mask(
1078
+ attention_mask,
1079
+ inputs_embeds,
1080
+ cache_position,
1081
+ past_key_values,
1082
+ output_attentions,
1083
+ )
1084
 
1085
  hidden_states = inputs_embeds
1086
 
 
1097
  layer_outputs = self._gradient_checkpointing_func(
1098
  decoder_layer.__call__,
1099
  hidden_states,
1100
+ causal_mask,
1101
  position_ids,
 
1102
  output_attentions,
1103
+ use_cache,
1104
+ past_key_values,
1105
+ cache_position,
1106
  )
1107
  else:
1108
  layer_outputs = decoder_layer(
1109
  hidden_states,
1110
+ attention_mask=causal_mask,
1111
  position_ids=position_ids,
1112
  past_key_value=past_key_values,
1113
  output_attentions=output_attentions,
1114
  use_cache=use_cache,
1115
+ cache_position=cache_position,
1116
  )
1117
 
1118
  hidden_states = layer_outputs[0]
 
1147
  attentions=all_self_attns,
1148
  )
1149
 
1150
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
+ def _update_causal_mask(
1152
+ self,
1153
+ attention_mask: torch.Tensor,
1154
+ input_tensor: torch.Tensor,
1155
+ cache_position: torch.Tensor,
1156
+ past_key_values: Cache,
1157
+ output_attentions: bool,
1158
+ ):
1159
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
+
1164
+ if self.config._attn_implementation == "flash_attention_2":
1165
+ if attention_mask is not None and 0.0 in attention_mask:
1166
+ return attention_mask
1167
+ return None
1168
+
1169
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
+ # to infer the attention mask.
1172
+ past_seen_tokens = (
1173
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1174
+ )
1175
+ using_static_cache = isinstance(past_key_values, StaticCache)
1176
+
1177
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
+ if (
1179
+ self.config._attn_implementation == "sdpa"
1180
+ and not using_static_cache
1181
+ and not output_attentions
1182
+ ):
1183
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
+ attention_mask,
1185
+ inputs_embeds=input_tensor,
1186
+ past_key_values_length=past_seen_tokens,
1187
+ is_training=self.training,
1188
+ ):
1189
+ return None
1190
+
1191
+ dtype, device = input_tensor.dtype, input_tensor.device
1192
+ min_dtype = torch.finfo(dtype).min
1193
+ sequence_length = input_tensor.shape[1]
1194
+ if using_static_cache:
1195
+ target_length = past_key_values.get_max_length()
1196
+ else:
1197
+ target_length = (
1198
+ attention_mask.shape[-1]
1199
+ if isinstance(attention_mask, torch.Tensor)
1200
+ else past_seen_tokens + sequence_length + 1
1201
+ )
1202
+
1203
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1204
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
+ attention_mask,
1206
+ sequence_length=sequence_length,
1207
+ target_length=target_length,
1208
+ dtype=dtype,
1209
+ device=device,
1210
+ min_dtype=min_dtype,
1211
+ cache_position=cache_position,
1212
+ batch_size=input_tensor.shape[0],
1213
+ )
1214
+
1215
+ if (
1216
+ self.config._attn_implementation == "sdpa"
1217
+ and attention_mask is not None
1218
+ and attention_mask.device.type == "cuda"
1219
+ and not output_attentions
1220
+ ):
1221
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1224
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1225
+ causal_mask, min_dtype
1226
+ )
1227
+
1228
+ return causal_mask
1229
+
1230
 
1231
  class CausalLMHead(nn.Module):
1232
  """Causal Language Modeling head. Simplified version."""
 
1241
 
1242
 
1243
  class PhiForCausalLM(PhiPreTrainedModel):
 
1244
 
1245
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
  def __init__(self, config):
 
1258
 
1259
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
  def set_input_embeddings(self, value):
1261
+ self.transformer.embd.wte = value
1262
 
1263
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
  def get_output_embeddings(self):
 
1276
  def get_decoder(self):
1277
  return self.model
1278
 
1279
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
+ @replace_return_docstrings(
1281
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
+ )
1283
  def forward(
1284
  self,
1285
  input_ids: torch.LongTensor = None,
 
1292
  output_attentions: Optional[bool] = None,
1293
  output_hidden_states: Optional[bool] = None,
1294
  return_dict: Optional[bool] = None,
1295
+ cache_position: Optional[torch.LongTensor] = None,
1296
+ num_logits_to_keep: int = 0,
1297
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
  r"""
1299
  Args:
 
1302
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
 
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
  Returns:
1311
 
1312
  Example:
 
1351
  output_attentions=output_attentions,
1352
  output_hidden_states=output_hidden_states,
1353
  return_dict=return_dict,
1354
+ cache_position=cache_position,
1355
  )
1356
 
1357
  hidden_states = outputs[0]
1358
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
 
1360
  loss = None
1361
  if labels is not None:
1362
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
+ logits = logits.float()
1364
  # Shift so that tokens < n predict n
1365
  shift_logits = logits[..., :-1, :].contiguous()
1366
  shift_labels = labels[..., 1:].contiguous()
 
1391
  past_key_values=None,
1392
  attention_mask=None,
1393
  inputs_embeds=None,
1394
+ cache_position=None,
1395
+ position_ids=None,
1396
+ use_cache=True,
1397
+ num_logits_to_keep=0,
1398
  **kwargs,
1399
  ):
1400
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
  if past_key_values is not None:
1404
+ if inputs_embeds is not None: # Exception 1
1405
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1406
+ elif (
1407
+ input_ids.shape[1] != cache_position.shape[0]
1408
+ ): # Default case (the "else", a no op, is Exception 2)
1409
+ input_ids = input_ids[:, cache_position]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1410
 
 
1411
  if attention_mask is not None and position_ids is None:
1412
  # create position_ids on the fly for batch generation
1413
  position_ids = attention_mask.long().cumsum(-1) - 1
 
1415
  if past_key_values:
1416
  position_ids = position_ids[:, -input_ids.shape[1] :]
1417
 
1418
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1419
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
+
1421
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
+ if inputs_embeds is not None and cache_position[0] == 0:
1423
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
  else:
1425
+ # The clone here is for the same reason as for `position_ids`.
1426
+ model_inputs = {
1427
+ "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
+ "inputs_embeds": None,
1429
+ }
1430
+
1431
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
+ if model_inputs["inputs_embeds"] is not None:
1433
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
+ device = model_inputs["inputs_embeds"].device
1435
+ else:
1436
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1437
+ device = model_inputs["input_ids"].device
1438
+
1439
+ dtype = self.lm_head.weight.dtype
1440
+ min_dtype = torch.finfo(dtype).min
1441
+
1442
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
+ attention_mask,
1444
+ sequence_length=sequence_length,
1445
+ target_length=past_key_values.get_max_length(),
1446
+ dtype=dtype,
1447
+ device=device,
1448
+ min_dtype=min_dtype,
1449
+ cache_position=cache_position,
1450
+ batch_size=batch_size,
1451
+ )
1452
 
1453
  model_inputs.update(
1454
  {
1455
  "position_ids": position_ids,
1456
+ "cache_position": cache_position,
1457
  "past_key_values": past_key_values,
1458
+ "use_cache": use_cache,
1459
  "attention_mask": attention_mask,
1460
+ "num_logits_to_keep": num_logits_to_keep,
1461
  }
1462
  )
1463
  return model_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
moondream.py CHANGED
@@ -1,10 +1,14 @@
1
  import torch
2
- from .vision_encoder import VisionEncoder
3
- from .configuration_moondream import MoondreamConfig
4
  from transformers import PreTrainedModel
 
5
 
6
- from .modeling_phi import PhiForCausalLM
7
  from .configuration_moondream import PhiConfig
 
 
 
 
8
 
9
  class Moondream(PreTrainedModel):
10
  config_class = MoondreamConfig
@@ -15,6 +19,7 @@ class Moondream(PreTrainedModel):
15
  self.vision_encoder = VisionEncoder(
16
  use_flash_attn=config._attn_implementation == "flash_attention_2"
17
  )
 
18
 
19
  if type(config.text_config) == dict:
20
  phi_config = PhiConfig(
@@ -80,12 +85,55 @@ class Moondream(PreTrainedModel):
80
 
81
  with torch.no_grad():
82
  inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
 
83
  output_ids = self.text_model.generate(
84
- inputs_embeds=inputs_embeds, **generate_config
 
 
85
  )
86
 
87
  return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def answer_question(
90
  self,
91
  image_embeds,
@@ -93,6 +141,7 @@ class Moondream(PreTrainedModel):
93
  tokenizer,
94
  chat_history="",
95
  result_queue=None,
 
96
  **kwargs,
97
  ):
98
  prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
@@ -100,7 +149,7 @@ class Moondream(PreTrainedModel):
100
  image_embeds,
101
  prompt,
102
  tokenizer=tokenizer,
103
- max_new_tokens=512,
104
  **kwargs,
105
  )[0]
106
  cleaned_answer = answer.strip()
@@ -176,3 +225,6 @@ class Moondream(PreTrainedModel):
176
  x.strip()
177
  for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
178
  ]
 
 
 
 
1
  import torch
2
+
3
+ from typing import List, Union, Literal, Optional
4
  from transformers import PreTrainedModel
5
+ from PIL import Image
6
 
 
7
  from .configuration_moondream import PhiConfig
8
+ from .configuration_moondream import MoondreamConfig
9
+ from .vision_encoder import VisionEncoder
10
+ from .region_model import RegionModel
11
+ from .modeling_phi import PhiForCausalLM
12
 
13
  class Moondream(PreTrainedModel):
14
  config_class = MoondreamConfig
 
19
  self.vision_encoder = VisionEncoder(
20
  use_flash_attn=config._attn_implementation == "flash_attention_2"
21
  )
22
+ self.region_model = RegionModel()
23
 
24
  if type(config.text_config) == dict:
25
  phi_config = PhiConfig(
 
85
 
86
  with torch.no_grad():
87
  inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
88
+ attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
89
  output_ids = self.text_model.generate(
90
+ inputs_embeds=inputs_embeds,
91
+ attention_mask=attention_mask,
92
+ **generate_config,
93
  )
94
 
95
  return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
96
 
97
+ # Note: Not ready for use yet, intended for September release.
98
+ def caption(
99
+ self,
100
+ images: List[Image.Image],
101
+ tokenizer,
102
+ length: Optional[Literal["short"]] = None,
103
+ **kwargs,
104
+ ):
105
+ image_embeds = self.encode_image(images)
106
+
107
+ templated_prompts = [
108
+ f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:" for _ in images
109
+ ]
110
+ inputs_embeds = torch.stack([
111
+ self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
112
+ for prompt, image_embed in zip(templated_prompts, image_embeds)
113
+ ])
114
+ attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
115
+
116
+ generate_config = {
117
+ "eos_token_id": tokenizer.eos_token_id,
118
+ "bos_token_id": tokenizer.bos_token_id,
119
+ "pad_token_id": tokenizer.bos_token_id,
120
+ "repetition_penalty": 1.2,
121
+ "max_new_tokens": 512,
122
+ **kwargs,
123
+ }
124
+
125
+ with torch.no_grad():
126
+ output_ids = self.text_model.generate(
127
+ inputs_embeds=inputs_embeds,
128
+ attention_mask=attention_mask,
129
+ **generate_config,
130
+ )
131
+
132
+ return [
133
+ x.strip()
134
+ for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
135
+ ]
136
+
137
  def answer_question(
138
  self,
139
  image_embeds,
 
141
  tokenizer,
142
  chat_history="",
143
  result_queue=None,
144
+ max_new_tokens=256,
145
  **kwargs,
146
  ):
147
  prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
 
149
  image_embeds,
150
  prompt,
151
  tokenizer=tokenizer,
152
+ max_new_tokens=max_new_tokens,
153
  **kwargs,
154
  )[0]
155
  cleaned_answer = answer.strip()
 
225
  x.strip()
226
  for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
227
  ]
228
+
229
+ def detect(self, image: Image.Image, query: str, tokenizer):
230
+ pass
region_model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .fourier_features import FourierFeatures
4
+
5
+ class RegionModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ self.position_features = FourierFeatures(2, 256)
10
+ self.position_encoder = nn.Linear(256, 2048)
11
+ self.size_features = FourierFeatures(2, 256)
12
+ self.size_encoder = nn.Linear(256, 2048)
13
+
14
+ self.position_decoder = nn.Linear(2048, 2)
15
+ self.size_decoder = nn.Linear(2048, 2)
16
+ self.confidence_decoder = nn.Linear(2048, 1)
17
+
18
+ def encode_position(self, position):
19
+ return self.position_encoder(self.position_features(position))
20
+
21
+ def encode_size(self, size):
22
+ return self.size_encoder(self.size_features(size))
23
+
24
+ def decode_position(self, x):
25
+ return self.position_decoder(x)
26
+
27
+ def decode_size(self, x):
28
+ return self.size_decoder(x)
29
+
30
+ def decode_confidence(self, x):
31
+ return self.confidence_decoder(x)
32
+
33
+ def encode(self, position, size):
34
+ return torch.stack(
35
+ [self.encode_position(position), self.encode_size(size)], dim=0
36
+ )
37
+
38
+ def decode(self, position_logits, size_logits):
39
+ return (
40
+ self.decode_position(position_logits),
41
+ self.decode_size(size_logits),
42
+ self.decode_confidence(size_logits),
43
+ )