@AdonaiVera

Add Swin Transformer Backbone

This PR adds SwinTransformerBackbone, based on the paper Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. It is the first step to support Grounding DINO in keras-hub, where Swin is the main image encoder.

Related Work

This idea has been discussed in previous threads:

  • keras-cv issue #2114: a community request to support Grounding DINO, which requires Swin.
  • keras-hub issue #2117: a request to support Swin-UNETR, a model originally designed for 3D medical image segmentation that uses Swin Transformers for feature extraction.

Included in this PR

  • SwinTransformerBackbone model
  • Core layers (PatchEmbedding, WindowAttention, etc.)
  • Unit tests
  • (Presets will be added later)

🧪 Current Status

This PR is still a draft. I’m finishing:

  • Fixing some issues with float16 and float32
  • Preparing a Colab notebook to:
    • Convert pretrained weights from the original repo
    • Compare the outputs with the original model

Let me know if this direction makes sense, or if you have any comments or suggestions. Thanks!

@abheesht17
Collaborator

@AdonaiVera - are you still working on this?

@AdonaiVera
Author

Hi @abheesht17
Yes, I’m creating the Colab notebook to check it with the original one and making some changes in the infrastructure. Any feedback is very welcome! I plan to work on this during the weekend.

@abheesht17
Collaborator

> Hi @abheesht17 Yes, I’m creating the Colab notebook to check it with the original one and making some changes in the infrastructure. Any feedback is very welcome! I plan to work on this during the weekend.

Awesome, thank you! :)

@divyashreepathihalli moved this from Todo to In Progress in KerasHub on Jul 10, 2025
@divyashreepathihalli
Collaborator

/gemini review

Contributor

@gemini-code-assist (bot) left a comment

Code Review

The code introduces a Swin Transformer backbone network, including core layers and unit tests. The changes look good overall, but there are some opportunities to improve efficiency and readability, particularly in the SwinTransformerBackbone and SwinTransformerBlock classes. Addressing these points will enhance the code's maintainability and performance.

Comment on lines 150 to 163
def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]):
    shape = ops.shape(tensor)
    B = shape[0]
    L = shape[1]
    C = shape[2]
    H_float = ops.sqrt(ops.cast(L, x.dtype))
    H = ops.cast(H_float, "int32")
    W = H
    tensor = ops.reshape(tensor, (B, H, W, C))
    return norm_layer(tensor)

x_reshaped = keras.layers.Lambda(reshape_and_norm)(x)
features.append(x_reshaped)

medium

The reshape_and_norm function is defined inside the loop, which means it will be re-defined for each stage. This is inefficient. It should be defined outside the loop to avoid re-definition. Also, the shape information is available, so it's better to use that instead of hardcoding the reshape operation with sqrt and casts, which can be error-prone and less readable. Finally, it's better to return the reshaped tensor directly instead of using a Lambda layer.

Suggested change
- def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]):
-     shape = ops.shape(tensor)
-     B = shape[0]
-     L = shape[1]
-     C = shape[2]
-     H_float = ops.sqrt(ops.cast(L, x.dtype))
-     H = ops.cast(H_float, "int32")
-     W = H
-     tensor = ops.reshape(tensor, (B, H, W, C))
-     return norm_layer(tensor)
- x_reshaped = keras.layers.Lambda(reshape_and_norm)(x)
- features.append(x_reshaped)
+ def reshape_and_norm(tensor, norm_layer):
+     shape = ops.shape(tensor)
+     B = shape[0]
+     L = shape[1]
+     C = shape[2]
+     H = image_shape[0] // self.patch_size
+     W = image_shape[1] // self.patch_size
+     tensor = ops.reshape(tensor, (B, H, W, C))
+     return norm_layer(tensor)
+ # Forward pass
+ features = []
+ for i, stage in enumerate(self.stages):
+     x = stage(x)
+     x_reshaped = reshape_and_norm(x, self.norm_layers[i])
+     features.append(x_reshaped)
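
For readers following the reshape discussion, here is a minimal NumPy sketch of turning a (B, L, C) token sequence back into a (B, H, W, C) feature map from a known grid size. It is not the PR's code: NumPy stands in for keras.ops, and `tokens_to_grid`, `grid_height`, and `grid_width` are hypothetical names chosen for illustration. It also shows why inferring the side length with a square root only works for square grids.

```python
import numpy as np


def tokens_to_grid(tokens, grid_height, grid_width):
    """Reshape (B, L, C) tokens into a (B, H, W, C) feature map.

    grid_height and grid_width are assumed to be known from the input
    resolution and the downsampling applied so far (hypothetical arguments).
    """
    batch, length, channels = tokens.shape
    if length != grid_height * grid_width:
        raise ValueError(
            f"Sequence length {length} does not match grid "
            f"{grid_height}x{grid_width}."
        )
    return tokens.reshape(batch, grid_height, grid_width, channels)


# A non-square grid: 3 rows x 4 columns = 12 tokens, 8 channels each.
tokens = np.arange(2 * 12 * 8, dtype="float32").reshape(2, 12, 8)
feature_map = tokens_to_grid(tokens, grid_height=3, grid_width=4)

# The square-root approach would guess a 3x3 grid here, which cannot hold
# all 12 tokens.
sqrt_side = int(np.sqrt(tokens.shape[1]))
```

Passing the grid size explicitly keeps the reshape valid for non-square inputs, at the cost of tracking the resolution for each stage.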

Comment on lines 439 to 455
img_mask = np.zeros((1, H, W, 1), dtype=np.int32)
cnt = 0
h_slices = [
    (0, H // 2),
    (H // 2, H - self.shift_size),
    (H - self.shift_size, H),
]
w_slices = [
    (0, W // 2),
    (W // 2, W - self.shift_size),
    (W - self.shift_size, W),
]
for h in h_slices:
    for w in w_slices:
        img_mask[:, h[0] : h[1], w[0] : w[1], :] = cnt
        cnt += 1
img_mask = ops.convert_to_tensor(img_mask)

medium

The img_mask is initialized with np.zeros and then converted to a tensor. It's better to initialize it directly as a Keras tensor to avoid unnecessary conversions and potential compatibility issues with different backends. Also, the division by 2 (e.g., H // 2) might lead to incorrect mask creation if H or W is not even. It's better to use ops.cast(H / 2, 'int32') to ensure correct behavior for both even and odd dimensions.

Suggested change
- img_mask = np.zeros((1, H, W, 1), dtype=np.int32)
- cnt = 0
- h_slices = [
-     (0, H // 2),
-     (H // 2, H - self.shift_size),
-     (H - self.shift_size, H),
- ]
- w_slices = [
-     (0, W // 2),
-     (W // 2, W - self.shift_size),
-     (W - self.shift_size, W),
- ]
- for h in h_slices:
-     for w in w_slices:
-         img_mask[:, h[0] : h[1], w[0] : w[1], :] = cnt
-         cnt += 1
- img_mask = ops.convert_to_tensor(img_mask)
+ img_mask = ops.zeros((1, H, W, 1), dtype="int32")
+ cnt = 0
+ h_slices = [
+     (0, ops.cast(H / 2, 'int32')),
+     (ops.cast(H / 2, 'int32'), H - self.shift_size),
+     (H - self.shift_size, H),
+ ]
+ w_slices = [
+     (0, ops.cast(W / 2, 'int32')),
+     (ops.cast(W / 2, 'int32'), W - self.shift_size),
+     (W - self.shift_size, W),
+ ]
+ for h in h_slices:
+     for w in w_slices:
+         img_mask = ops.slice_update(img_mask, [0, h[0], w[0], 0], ops.ones((1, h[1] - h[0], w[1] - w[0], 1), dtype='int32') * cnt)
+         cnt += 1
+ # img_mask = ops.convert_to_tensor(img_mask)
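
As a point of comparison for this thread, the following is a small NumPy sketch of how a shifted-window mask can be built end to end. It is an illustration under stated assumptions, not the PR's implementation: it places the region boundaries at `window_size` and `shift_size` from the bottom and right edges, which is how I understand the reference Swin code to slice the image, and `build_shift_mask` and `window_partition` are hypothetical helper names.

```python
import numpy as np


def build_shift_mask(height, width, window_size, shift_size):
    """Label every position with the id of its region (9 regions in total)."""
    img_mask = np.zeros((1, height, width, 1), dtype=np.int32)
    # Assumed boundaries: measured from the bottom/right edge.
    h_slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    w_slices = h_slices  # same boundaries along the width
    region_id = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = region_id
            region_id += 1
    return img_mask


def window_partition(x, window_size):
    """Split (B, H, W, C) into (num_windows * B, window_size**2, C)."""
    b, h, w, c = x.shape
    x = x.reshape(
        b, h // window_size, window_size, w // window_size, window_size, c
    )
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size * window_size, c)


mask = build_shift_mask(8, 8, window_size=4, shift_size=2)
windows = window_partition(mask, window_size=4)[..., 0]  # (4, 16)
# Positions from different regions must not attend to each other.
attn_mask = np.where(
    windows[:, None, :] != windows[:, :, None], -100.0, 0.0
)
```

In this 8x8 example the top-left window contains a single region, so its mask is all zeros, while the bottom-right window mixes four regions and is partially masked.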

Comment on lines 461 to 464
attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims(
    mask_windows, 2
)
attn_mask = ops.where(attn_mask != 0, -100.0, 0.0)

medium

The attention mask is created using ops.expand_dims and then compared to 0. This can be simplified by directly comparing the original mask_windows to each other and using the result to create the attention mask. This avoids the need for expanding dimensions and improves readability.

            attn_mask = ops.cast(ops.expand_dims(mask_windows, 1) != ops.expand_dims(mask_windows, 2), dtype='float32') * -100.0
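
To check that the two formulations in this comment agree, here is a quick NumPy sketch. NumPy broadcasting stands in for `ops.expand_dims` here, and the `mask_windows` values are made-up region labels for two tiny windows, purely for illustration.

```python
import numpy as np

# Hypothetical region labels: 2 windows with 4 positions each.
mask_windows = np.array([[0, 0, 1, 1], [2, 2, 2, 2]], dtype=np.int32)

# Original form: subtract pairwise labels, then map non-zero to -100.
diff = mask_windows[:, None, :] - mask_windows[:, :, None]
mask_via_where = np.where(diff != 0, -100.0, 0.0)

# Suggested form: compare pairwise labels, cast to float, scale by -100.
not_equal = mask_windows[:, None, :] != mask_windows[:, :, None]
mask_via_cast = not_equal.astype("float32") * -100.0

# Both give -100 where labels differ and zero where they match.
same_result = np.array_equal(mask_via_where, mask_via_cast)
```

Both versions still expand the window labels along two different axes; the difference is only whether the comparison is done by subtraction or by inequality.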

Comment on lines +543 to +544
pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0))
x = ops.pad(x, pad_values)

medium

Padding is applied using ops.pad. It's important to ensure that the padding values are correctly specified for all backends. Consider using a more explicit padding mode like 'CONSTANT' with a value of 0 to avoid potential issues with different padding behaviors across backends.

        x = ops.pad(x, pad_values, mode='CONSTANT', constant_values=0)
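
The effect of these padding values can be sketched in NumPy, used here only as a stand-in for the backend ops. With odd spatial dimensions, one row and one column of zeros are appended at the bottom and right so that both dimensions become even. The input shape below is an arbitrary example.

```python
import numpy as np

# Example feature map with odd height and width: (B, H, W, C) = (1, 5, 7, 3).
x = np.ones((1, 5, 7, 3), dtype="float32")
_, H, W, _ = x.shape

# Pad only at the end of the height and width axes, and only when odd.
pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0))

# np.pad defaults to constant padding with zeros; spelled out for clarity.
padded = np.pad(x, pad_values, mode="constant", constant_values=0)
```

When H and W are already even, `H % 2` and `W % 2` are zero and the array is returned with its shape unchanged.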

@innat

innat commented Jul 26, 2025

@AdonaiVera
You can take a look at this implementation (2d-swin). It reproduces the original implementation (sort of, though there are some limitations, which you might care to fix).
