 # SPDX-License-Identifier: Apache-2.0

 from functools import partial
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -40,7 +40,17 @@ class PatchEmbed(nn.Module):
     """ Image to Patch Embedding
     """

-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
+    def __init__(
+            self,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            multi_conv: bool = False,
+            device=None,
+            dtype=None,
+    ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -51,22 +61,22 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi
         if multi_conv:
             if patch_size[0] == 12:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd),
                 )
             elif patch_size[0] == 16:
                 self.proj = nn.Sequential(
-                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd),
                     nn.ReLU(inplace=True),
-                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
+                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd),
                 )
         else:
-            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd)

     def forward(self, x):
         B, C, H, W = x.shape
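The `dd` dict above follows PyTorch's factory-kwargs convention: collect `device`/`dtype` once, then forward them to every parameter-creating submodule so weights are allocated in the right place from the start. A minimal standalone sketch of the idiom (the `TinyBlock` name is illustrative, not part of this patch):

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Same pattern as PatchEmbed above: gather the factory kwargs once,
    # thread them into each layer that owns parameters.
    def __init__(self, dim: int = 8, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.fc = nn.Linear(dim, dim, **dd)
        self.norm = nn.LayerNorm(dim, **dd)


blk = TinyBlock(device='meta')   # parameters allocated without storage
print(blk.fc.weight.is_meta)     # True
```

Parameter-free modules don't need the kwargs, which is why the `nn.Dropout` and `nn.ReLU` instances are left untouched throughout the diff.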
@@ -82,23 +92,26 @@ def forward(self, x):
 class CrossAttention(nn.Module):
     def __init__(
             self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            attn_drop=0.,
-            proj_drop=0.,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = head_dim ** -0.5

-        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
-        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
+        self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd)
+        self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x):
@@ -124,24 +137,28 @@ class CrossAttentionBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.attn = CrossAttention(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
+            **dd,
         )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -155,20 +172,22 @@ class MultiScaleBlock(nn.Module):

     def __init__(
             self,
-            dim,
-            patches,
-            depth,
-            num_heads,
-            mlp_ratio,
-            qkv_bias=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            dim: Tuple[int, ...],
+            patches: Tuple[int, ...],
+            depth: Tuple[int, ...],
+            num_heads: Tuple[int, ...],
+            mlp_ratio: Tuple[float, ...],
+            qkv_bias: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: Union[List[float], float] = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-
         num_branches = len(dim)
         self.num_branches = num_branches
         # different branch could have different embedding size, the first one is the base
@@ -185,6 +204,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[i],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             if len(tmp) != 0:
                 self.blocks.append(nn.Sequential(*tmp))
@@ -197,7 +217,7 @@ def __init__(
             if dim[d] == dim[(d + 1) % num_branches] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
+                tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)]
             self.projs.append(nn.Sequential(*tmp))

         self.fusion = nn.ModuleList()
@@ -215,6 +235,7 @@ def __init__(
                     attn_drop=attn_drop,
                     drop_path=drop_path[-1],
                     norm_layer=norm_layer,
+                    **dd,
                 ))
             else:
                 tmp = []
@@ -228,6 +249,7 @@ def __init__(
                         attn_drop=attn_drop,
                         drop_path=drop_path[-1],
                         norm_layer=norm_layer,
+                        **dd,
                     ))
                 self.fusion.append(nn.Sequential(*tmp))

@@ -236,8 +258,8 @@ def __init__(
             if dim[(d + 1) % num_branches] == dim[d] and False:
                 tmp = [nn.Identity()]
             else:
-                tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
-                       nn.Linear(dim[(d + 1) % num_branches], dim[d])]
+                tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(),
+                       nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)]
             self.revert_projs.append(nn.Sequential(*tmp))

     def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -293,27 +315,30 @@ class CrossVit(nn.Module):

     def __init__(
             self,
-            img_size=224,
-            img_scale=(1.0, 1.0),
-            patch_size=(8, 16),
-            in_chans=3,
-            num_classes=1000,
-            embed_dim=(192, 384),
-            depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
-            num_heads=(6, 12),
-            mlp_ratio=(2., 2., 4.),
-            multi_conv=False,
-            crop_scale=False,
-            qkv_bias=True,
-            drop_rate=0.,
-            pos_drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            global_pool='token',
+            img_size: int = 224,
+            img_scale: Tuple[float, ...] = (1.0, 1.0),
+            patch_size: Tuple[int, ...] = (8, 16),
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            embed_dim: Tuple[int, ...] = (192, 384),
+            depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
+            num_heads: Tuple[int, ...] = (6, 12),
+            mlp_ratio: Tuple[float, ...] = (2., 2., 4.),
+            multi_conv: bool = False,
+            crop_scale: bool = False,
+            qkv_bias: bool = True,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+            global_pool: str = 'token',
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('token', 'avg')

         self.num_classes = num_classes
@@ -330,8 +355,8 @@ def __init__(

         # hard-coded for torch jit script
         for i in range(self.num_branches):
-            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
-            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd)))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd)))

         for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
             self.patch_embed.append(
@@ -341,6 +366,7 @@ def __init__(
                     in_chans=in_chans,
                     embed_dim=d,
                     multi_conv=multi_conv,
+                    **dd,
                 ))

         self.pos_drop = nn.Dropout(p=pos_drop_rate)
@@ -363,14 +389,15 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr_,
                 norm_layer=norm_layer,
+                **dd,
             )
             dpr_ptr += curr_depth
             self.blocks.append(blk)

-        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
+        self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)])
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.ModuleList([
-            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)])

         for i in range(self.num_branches):
@@ -418,8 +445,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             assert global_pool in ('token', 'avg')
             self.global_pool = global_pool
+        device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None
+        dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None
+        dd = {'device': device, 'dtype': dtype}
         self.head = nn.ModuleList([
-            nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
             for i in range(self.num_branches)
         ])
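Taken together, the patch lets callers pick the parameter device and dtype at construction time. A hedged usage sketch (assuming the class stays importable as `timm.models.crossvit.CrossVit`, and that timm's `trunc_normal_` init runs cleanly on meta tensors, which is the motivation for plumbing like this):

```python
import torch
from timm.models.crossvit import CrossVit  # import path assumed

# Parameters are created directly with the requested factory kwargs, so no
# post-hoc .to() pass (and no transient fp32 copy) is needed.
model = CrossVit(device='cpu', dtype=torch.float64)
assert model.head[0].weight.dtype == torch.float64

# Deferred ("meta") initialization: parameters get no storage and can be
# materialized later, e.g. via load_state_dict(..., assign=True).
meta_model = CrossVit(device='meta')
assert meta_model.head[0].weight.is_meta
```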
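One subtlety in the last hunk: `reset_classifier` takes no factory kwargs of its own; instead it reads device/dtype off the existing head's weight, falling back to `None` when the head is an `nn.Identity`. A small sketch of the behavior this buys, under the same import assumption as above:

```python
import torch
from timm.models.crossvit import CrossVit  # import path assumed

model = CrossVit(device='cpu', dtype=torch.float64)
model.reset_classifier(10)
# The rebuilt heads inherit the old head's device and dtype instead of
# landing on the default device in fp32.
assert all(h.weight.dtype == torch.float64 for h in model.head)
```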