Commit 8392ca9

Refactor CLIP and update SD3. (#2316)
* Update SD3 scheduler and the dtype of the text encoders.
* Fix the test.
* Fix torch float16 issues and jax take issue. Add numeric checks for SD3 scheduler and text encoders.
* Fix CLIP test.
* Refactor CLIP models.
* Update CLIP conversion script.
* Update `from_config`.
* Fix tests.
1 parent 1b50fc0 · commit 8392ca9

22 files changed: +678 / -442 lines changed

keras_hub/src/models/clip/clip_backbone.py

Lines changed: 3 additions & 102 deletions
@@ -1,109 +1,10 @@
-import math
-
 from keras import layers
-from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-
-
-class CLIPVisionPooler(layers.Layer):
-    """The vision pooler layer of CLIP.
-
-    `CLIPVisionPooler` will extracts the first token (index `0`) from the
-    sequence of the vision embeddings as the pooled outputs.
-
-    Call arguments:
-        vision_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-    """
-
-    def call(self, vision_embeddings):
-        return vision_embeddings[:, 0, :]
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
-
-
-class CLIPTextPooler(layers.Layer):
-    """The text pooler layer of CLIP.
-
-    `CLIPTextPooler` extracts the text embeddings at the positions of EOS tokens
-    as the pooled outputs.
-
-    Call arguments:
-        text_embeddings: A tensor of shape
-            `(batch_size, sequence_length, hidden_dim)`.
-        token_ids: A tensor of shape `(batch_size, max_tokens)`, used to
-            identify the positions of EOS tokens.
-    """
-
-    def call(self, text_embeddings, token_ids):
-        # `keepdims` is not supported in `keras<=3.1`.
-        eos_index = ops.argmax(token_ids, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        eos_index = ops.expand_dims(eos_index, axis=-1)
-        pooled_outputs = ops.take_along_axis(text_embeddings, eos_index, axis=1)
-        return ops.squeeze(pooled_outputs, axis=1)
-
-    def compute_output_shape(self, input_shape):
-        return (input_shape[0], input_shape[-1])
-
-
-class CLIPHead(layers.Layer):
-    """The head layer of CLIP.
-
-    `CLIPHead` takes `vision_embedding` and `text_embedding` as inputs to
-    compute the corresponding logits. Both embeddings are L2 normalized and used
-    to compute pairwise cosine similarity. The resulting logits are then scaled
-    by a learnable `logit_scale` parameter.
-
-    Call arguments:
-        vision_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-        text_embedding: A tensor of shape `(batch_size, hidden_dim)`.
-    """
-
-    def build(self, input_shape):
-        self.logit_scale = self.add_weight(
-            shape=(),
-            initializer=lambda *a, **kw: math.log(1 / 0.07),
-            trainable=True,
-            dtype=self.variable_dtype,
-            name="logit_scale",
-        )
-
-    def call(self, vision_embedding, text_embedding):
-        normalized_vision_embedding = ops.sqrt(
-            ops.sum(ops.power(vision_embedding, 2), axis=-1, keepdims=True)
-        )
-        normalized_text_embedding = ops.sqrt(
-            ops.sum(ops.power(text_embedding, 2), axis=-1, keepdims=True)
-        )
-        vision_embedding = vision_embedding / normalized_vision_embedding
-        text_embedding = text_embedding / normalized_text_embedding
-        logit_scale = ops.exp(self.logit_scale)
-        text_logits = (
-            ops.matmul(
-                text_embedding,
-                ops.transpose(vision_embedding),
-            )
-            * logit_scale
-        )
-        vision_logits = ops.transpose(text_logits)
-        return vision_logits, text_logits
-
-    def compute_output_shape(
-        self, vision_embedding_shape, text_embedding_shape
-    ):
-        vision_logits_shape = (
-            vision_embedding_shape[0],
-            text_embedding_shape[0],
-        )
-        text_logits_shape = (
-            text_embedding_shape[0],
-            vision_embedding_shape[0],
-        )
-        return vision_logits_shape, text_logits_shape
+from keras_hub.src.models.clip.clip_layers import CLIPHead
+from keras_hub.src.models.clip.clip_layers import CLIPTextPooler
+from keras_hub.src.models.clip.clip_layers import CLIPVisionPooler
 
 
 @keras_hub_export("keras_hub.models.CLIPBackbone")
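For reference, below is a minimal usage sketch (not part of the commit) of how the relocated layers fit together. It assumes `CLIPVisionPooler`, `CLIPTextPooler`, and `CLIPHead` keep the interfaces documented in the removed docstrings and are now importable from `keras_hub.src.models.clip.clip_layers`, as the new imports in the diff indicate; all tensors and token ids are dummy values.

# Minimal usage sketch of the relocated CLIP layers (assumed import path from
# the diff above; dummy shapes and token ids).
import keras
from keras import ops

from keras_hub.src.models.clip.clip_layers import CLIPHead
from keras_hub.src.models.clip.clip_layers import CLIPTextPooler
from keras_hub.src.models.clip.clip_layers import CLIPVisionPooler

batch_size, sequence_length, hidden_dim = 2, 8, 16

# Dummy per-token embeddings from the vision and text encoders.
vision_embeddings = keras.random.uniform(
    (batch_size, sequence_length, hidden_dim)
)
text_embeddings = keras.random.uniform(
    (batch_size, sequence_length, hidden_dim)
)
# Dummy token ids. The text pooler locates the EOS position with `argmax`,
# so the EOS token is assumed to have the largest id in each row.
token_ids = ops.convert_to_tensor(
    [[1, 5, 9, 49407, 0, 0, 0, 0], [1, 7, 49407, 0, 0, 0, 0, 0]]
)

# Pool the first (class) token of the vision sequence and the EOS token of
# the text sequence.
vision_embedding = CLIPVisionPooler()(vision_embeddings)
text_embedding = CLIPTextPooler()(text_embeddings, token_ids)

# L2-normalize both embeddings and scale the pairwise cosine similarities by
# `exp(logit_scale)`, which is initialized to `log(1 / 0.07)`.
vision_logits, text_logits = CLIPHead()(vision_embedding, text_embedding)
print(vision_logits.shape, text_logits.shape)  # (2, 2) (2, 2)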

keras_hub/src/models/clip/clip_encoder_block.py

Lines changed: 0 additions & 111 deletions
This file was deleted.
