-import math
-
 from keras import layers
-from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-
-
-class CLIPVisionPooler(layers.Layer):
-    """The vision pooler layer of CLIP.
-
-    `CLIPVisionPooler` extracts the first token (index `0`) from the
-    sequence of the vision embeddings as the pooled outputs.
-
-    Call arguments:
-        vision_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-    """
-
-    def call(self, vision_embeddings):
-        return vision_embeddings[:, 0, :]
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
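For reference, a minimal sketch of the pooling this removed layer performs; the shapes are dummies assumed for illustration and are not part of the diff:

```python
import numpy as np
from keras import ops

# Dummy vision embeddings: (batch_size, sequence_length, hidden_dim).
vision_embeddings = ops.convert_to_tensor(
    np.random.rand(2, 5, 8).astype("float32")
)
# CLIPVisionPooler keeps only the first (class-token) position.
pooled = vision_embeddings[:, 0, :]
print(tuple(pooled.shape))  # (2, 8)
```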
-
-
-class CLIPTextPooler(layers.Layer):
-    """The text pooler layer of CLIP.
-
-    `CLIPTextPooler` extracts the text embeddings at the positions of EOS
-    tokens as the pooled outputs.
-
-    Call arguments:
-        text_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
-            identify the positions of EOS tokens.
-    """
-
-    def call(self, text_embeddings, token_ids):
-        # `keepdims` is not supported in `keras<=3.1`.
-        eos_index = ops.argmax(token_ids, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
-        return ops.squeeze(pooled_outputs, axis=1)
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
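A sketch of the EOS gather above, with toy token ids assumed for illustration. `argmax` over the raw token ids locates the EOS position because CLIP's `<|endoftext|>` token carries the highest id in its vocabulary (49407):

```python
import numpy as np
from keras import ops

# Toy ids; 49407 stands in for CLIP's `<|endoftext|>` token.
token_ids = ops.convert_to_tensor([[12, 7, 49407, 0, 0]])
text_embeddings = ops.convert_to_tensor(
    np.random.rand(1, 5, 4).astype("float32")
)
eos_index = ops.argmax(token_ids, axis=-1)       # position 2
eos_index = ops.expand_dims(eos_index, axis=-1)
eos_index = ops.expand_dims(eos_index, axis=-1)  # shape (1, 1, 1)
pooled = ops.take_along_axis(text_embeddings, eos_index, axis=1)
pooled = ops.squeeze(pooled, axis=1)             # shape (1, 4)
```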
-
-
-class CLIPHead(layers.Layer):
-    """The head layer of CLIP.
-
-    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
-    compute the corresponding logits. Both embeddings are L2 normalized and
-    used to compute pairwise cosine similarity. The resulting logits are
-    then scaled by a learnable `logit_scale` parameter.
-
-    Call arguments:
-        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-    """
-
-    def build(self, input_shape):
-        self.logit_scale = self.add_weight(
-            shape=(),
-            initializer=lambda *a, **kw: math.log(1 / 0.07),
-            trainable=True,
-            dtype=self.variable_dtype,
-            name="logit_scale",
-        )
-
-    def call(self, vision_embedding, text_embedding):
-        normalized_vision_embedding = ops.sqrt(
-            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
-        )
-        normalized_text_embedding = ops.sqrt(
-            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
-        )
-        vision_embedding = vision_embedding / normalized_vision_embedding
-        text_embedding = text_embedding / normalized_text_embedding
-        logit_scale = ops.exp(self.logit_scale)
-        text_logits = (
-            ops.matmul(
-                text_embedding,
-                ops.transpose(vision_embedding),
-            )
-            * logit_scale
-        )
-        vision_logits = ops.transpose(text_logits)
-        return vision_logits, text_logits
-
-    def compute_output_shape(
-        self, vision_embedding_shape, text_embedding_shape
-    ):
-        vision_logits_shape = (
-            vision_embedding_shape[0],
-            text_embedding_shape[0],
-        )
-        text_logits_shape = (
-            text_embedding_shape[0],
-            vision_embedding_shape[0],
-        )
-        return vision_logits_shape, text_logits_shape
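And a sketch of the head's logit computation using plain ops; the batch sizes are assumptions for illustration. Note that `logit_scale` is stored in log space, so `exp` recovers CLIP's initial temperature of `1 / 0.07`:

```python
import math

import numpy as np
from keras import ops

vision_embedding = ops.convert_to_tensor(
    np.random.rand(2, 4).astype("float32")
)
text_embedding = ops.convert_to_tensor(
    np.random.rand(3, 4).astype("float32")
)

# L2-normalize so the matmul below computes pairwise cosine similarity.
vision_embedding = vision_embedding / ops.sqrt(
    ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
)
text_embedding = text_embedding / ops.sqrt(
    ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
)

logit_scale = math.exp(math.log(1 / 0.07))  # == 1 / 0.07
text_logits = (
    ops.matmul(text_embedding, ops.transpose(vision_embedding)) * logit_scale
)
vision_logits = ops.transpose(text_logits)  # (2, 3): one row per image
```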
+from keras_hub.src.models.clip.clip_layers import CLIPHead
+from keras_hub.src.models.clip.clip_layers import CLIPTextPooler
+from keras_hub.src.models.clip.clip_layers import CLIPVisionPooler
 
 
 @keras_hub_export("keras_hub.models.CLIPBackbone")