Commit 53caeb0

Add some more dd kwarg updates, crossvit, ghostnet, rdnet, repghost, repvit, selecsls, swiftformer
1 parent c7955eb commit 53caeb0
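
The "dd kwarg" pattern named in the commit message threads optional device and dtype arguments from each model constructor down into its layers and parameters, so a model can be materialized directly on a target device or in a target dtype instead of being built in float32 on CPU and converted afterwards. Below is a minimal, self-contained sketch of the pattern under that assumption; TinyBlock is a made-up illustration, not a timm module.

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Toy module using the same device/dtype plumbing as this commit."""

    def __init__(self, dim: int = 64, device=None, dtype=None):
        # Collect the factory kwargs once and forward them to every layer/parameter.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.norm = nn.LayerNorm(dim, **dd)
        self.fc = nn.Linear(dim, dim, **dd)
        self.pos = nn.Parameter(torch.zeros(1, dim, **dd))

    def forward(self, x):
        return self.fc(self.norm(x)) + self.pos


# Construct directly in bfloat16 without a separate .to() pass;
# device=None / dtype=None keeps the previous default behaviour.
blk = TinyBlock(dim=64, dtype=torch.bfloat16)
print(blk.fc.weight.dtype)  # torch.bfloat16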

File tree

7 files changed: +699, -346 lines changed

timm/models/crossvit.py

Lines changed: 97 additions & 67 deletions
@@ -21,7 +21,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from functools import partial
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -40,7 +40,17 @@ class PatchEmbed(nn.Module):
     """ Image to Patch Embedding
     """

-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
+    def __init__(
+            self,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            multi_conv: bool = False,
+            device=None,
+            dtype=None,
+    ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -51,22 +61,22 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi
         if multi_conv:
             if patch_size[0] == 12:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd),
                 )
             elif patch_size[0] == 16:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd),
                 )
         else:
-            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd)

     def forward(self, x):
         B, C, H, W = x.shape
@@ -82,23 +92,26 @@ def forward(self, x):
 class CrossAttention(nn.Module):
     def __init__(
             self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            attn_drop=0.,
-            proj_drop=0.,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = head_dim ** -0.5

-        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
+        self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
@@ -124,24 +137,28 @@ class CrossAttentionBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = CrossAttention(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -155,20 +172,22 @@ class MultiScaleBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            patches,
-            depth,
-            num_heads,
-            mlp_ratio,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: Tuple[int, ...],
+            patches: Tuple[int, ...],
+            depth: Tuple[int, ...],
+            num_heads: Tuple[int, ...],
+            mlp_ratio: Tuple[float, ...],
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: Union[List[float], float] = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-
         num_branches = len(dim)
         self.num_branches = num_branches
         # different branch could have different embedding size, the first one is the base
@@ -185,6 +204,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[i],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             if len(tmp) != 0:
                 self.blocks.append(nn.Sequential(*tmp))
@@ -197,7 +217,7 @@ def __init__(
             if dim[d] == dim[(d + 1) % num_branches] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
+                tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)]
             self.projs.append(nn.Sequential(*tmp))

         self.fusion = nn.ModuleList()
@@ -215,6 +235,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[-1],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             else:
                 tmp = []
@@ -228,6 +249,7 @@ def __init__(
                         attn_drop=attn_drop,
                         drop_path=drop_path[-1],
                         norm_layer=norm_layer,
+                        **dd,
                     ))
                 self.fusion.append(nn.Sequential(*tmp))

@@ -236,8 +258,8 @@ def __init__(
             if dim[(d + 1) % num_branches] == dim[d] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
-                       nn.Linear(dim[(d + 1) % num_branches], dim[d])]
+                tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(),
+                       nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)]
             self.revert_projs.append(nn.Sequential(*tmp))

     def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -293,27 +315,30 @@ class CrossVit(nn.Module):

     def __init__(
             self,
-            img_size=224,
-            img_scale=(1.0, 1.0),
-            patch_size=(8, 16),
-            in_chans=3,
-            num_classes=1000,
-            embed_dim=(192, 384),
-            depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
-            num_heads=(6, 12),
-            mlp_ratio=(2., 2., 4.),
-            multi_conv=False,
-            crop_scale=False,
-            qkv_bias=True,
-            drop_rate=0.,
-            pos_drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            global_pool='token',
+            img_size: int = 224,
+            img_scale: Tuple[float, ...] = (1.0, 1.0),
+            patch_size: Tuple[int, ...] = (8, 16),
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            embed_dim: Tuple[int, ...] = (192, 384),
+            depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
+            num_heads: Tuple[int, ...] = (6, 12),
+            mlp_ratio: Tuple[float, ...] = (2., 2., 4.),
+            multi_conv: bool = False,
+            crop_scale: bool = False,
+            qkv_bias: bool = True,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+            global_pool: str = 'token',
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('token', 'avg')

         self.num_classes = num_classes
@@ -330,8 +355,8 @@ def __init__(

         # hard-coded for torch jit script
         for i in range(self.num_branches):
-            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
-            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd)))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd)))

         for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
             self.patch_embed.append(
@@ -341,6 +366,7 @@ def __init__(
                     in_chans=in_chans,
                     embed_dim=d,
                     multi_conv=multi_conv,
+                    **dd,
                 ))

         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -363,14 +389,15 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr_,
                 norm_layer=norm_layer,
+                **dd,
             )
             dpr_ptr += curr_depth
             self.blocks.append(blk)

-        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
+        self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)])
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.ModuleList([
-            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)])

         for i in range(self.num_branches):
@@ -418,8 +445,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             assert global_pool in ('token', 'avg')
             self.global_pool = global_pool
+        device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None
+        dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None
+        dd = {'device': device, 'dtype': dtype}
         self.head = nn.ModuleList([
-            nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)
         ])

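The reset_classifier hunk above shows the companion idiom for modules rebuilt after construction: the target device and dtype are recovered from an existing parameter instead of being passed in again, so the new classifier head lands on the same device and in the same dtype as the rest of the model. A standalone sketch of that idiom follows, assuming a plain nn.Linear head; rebuild_head is a hypothetical helper for illustration, not timm API.

import torch
import torch.nn as nn


def rebuild_head(old_head: nn.Module, in_features: int, num_classes: int) -> nn.Module:
    # Recover device/dtype from the existing head (if it has weights), mirroring
    # the reset_classifier change above; fall back to framework defaults otherwise.
    device = old_head.weight.device if hasattr(old_head, 'weight') else None
    dtype = old_head.weight.dtype if hasattr(old_head, 'weight') else None
    dd = {'device': device, 'dtype': dtype}
    return nn.Linear(in_features, num_classes, **dd) if num_classes > 0 else nn.Identity()


head = nn.Linear(384, 1000, dtype=torch.float16)
new_head = rebuild_head(head, 384, 10)
print(new_head.weight.dtype)  # torch.float16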