# Licensed under the MIT License.

from functools import partial
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn


class Partial_conv3(nn.Module):
-    def __init__(self, dim: int, n_div: int, forward: str):
+    def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
-        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
+        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd)

        if forward == 'slicing':
            self.forward = self.forward_slicing
@@ -68,25 +69,28 @@ def __init__(
            mlp_ratio: float,
            drop_path: float,
            layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = nn.Sequential(*[
-            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
-            norm_layer(mlp_hidden_dim),
+            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd),
+            norm_layer(mlp_hidden_dim, **dd),
            act_layer(),
-            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False),
+            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd),
        ])

-        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type)
+        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd)

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+                layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True)
        else:
            self.layer_scale = None

@@ -112,12 +116,15 @@ def __init__(
            mlp_ratio: float,
            drop_path: float,
            layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
            use_merge: bool = True,
            merge_size: Union[int, Tuple[int, int]] = 2,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.blocks = nn.Sequential(*[
@@ -130,13 +137,15 @@ def __init__(
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
+                **dd,
            )
            for i in range(depth)
        ])
        self.downsample = PatchMerging(
            dim=dim // 2,
            patch_size=merge_size,
            norm_layer=norm_layer,
+            **dd,
        ) if use_merge else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -154,11 +163,14 @@ def __init__(
            in_chans: int,
            embed_dim: int,
            patch_size: Union[int, Tuple[int, int]] = 4,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(embed_dim)
+        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(embed_dim, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.proj(x))
@@ -169,11 +181,14 @@ def __init__(
            self,
            dim: int,
            patch_size: Union[int, Tuple[int, int]] = 2,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
    ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
-        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(2 * dim)
+        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(2 * dim, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.reduction(x))
@@ -196,11 +211,14 @@ def __init__(
            drop_rate: float = 0.,
            drop_path_rate: float = 0.1,
            layer_scale_init_value: float = 0.,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
    ):
        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
        assert pconv_fw_type in ('split_cat', 'slicing',)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
@@ -214,9 +232,10 @@ def __init__(
            embed_dim=embed_dim,
            patch_size=patch_size,
            norm_layer=norm_layer if patch_norm else nn.Identity,
+            **dd,
        )
        # stochastic depth decay rule
-        dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))
+        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)

        # build layers
        stages_list = []
@@ -227,13 +246,14 @@ def __init__(
                depth=depths[i],
                n_div=n_div,
                mlp_ratio=mlp_ratio,
-                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                drop_path=dpr[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
                use_merge=False if i == 0 else True,
                merge_size=merge_size,
+                **dd,
            )
            stages_list.append(stage)
            self.feature_info += [dict(num_chs=dim, reduction=2 ** (i + 2), module=f'stages.{i}')]
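
The two drop-path hunks above replace the manual per-stage slicing of a flat rate list with a single dpr[i] lookup. A minimal sketch of the equivalence, assuming the stagewise helper follows the usual linear decay rule over all blocks (the function names below are illustrative, not timm's calculate_drop_path_rates itself):

import torch

# Illustrative sketch, not the timm helper: the old code built one flat per-block
# list and sliced it per stage; the new call expects the helper to return the same
# values already grouped per stage, so the stage constructor only needs dpr[i].
def flat_drop_path_rates(drop_path_rate, total_blocks):
    # linear decay from 0 to drop_path_rate across all blocks
    return [x.item() for x in torch.linspace(0, drop_path_rate, total_blocks)]

def stagewise_drop_path_rates(drop_path_rate, depths):
    flat = flat_drop_path_rates(drop_path_rate, sum(depths))
    out, start = [], 0
    for d in depths:
        out.append(flat[start:start + d])
        start += d
    return out

depths = (1, 2, 8, 2)  # example depths for illustration only
flat = flat_drop_path_rates(0.1, sum(depths))
staged = stagewise_drop_path_rates(0.1, depths)
for i in range(len(depths)):
    # the old slice and the new per-stage entry select the same rates
    assert staged[i] == flat[sum(depths[:i]):sum(depths[:i + 1])]
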
@@ -243,10 +263,10 @@ def __init__(
        self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1))
        self.head_hidden_size = out_chs = feature_dim  # 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False)
+        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd)
        self.act = act_layer()
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes, bias=True) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity()
        self._initialize_weights()

    def _initialize_weights(self):
@@ -285,12 +305,13 @@ def set_grad_checkpointing(self, enable=True):
    def get_classifier(self) -> nn.Module:
        return self.classifier

-    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
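
Taken together, the new device / dtype arguments follow PyTorch's factory-kwargs pattern: each submodule creates its parameters directly on the target device and in the target dtype instead of being moved after construction. A minimal usage sketch against the Partial_conv3 signature shown above (the import path is assumed):

import torch
from timm.models.fasternet import Partial_conv3  # module path assumed

# Construct on the meta device in bfloat16: parameters get the requested
# device/dtype at creation time, with no real memory allocated yet.
m = Partial_conv3(dim=64, n_div=4, forward='split_cat', device='meta', dtype=torch.bfloat16)
w = m.partial_conv3.weight
print(w.device, w.dtype)  # meta / torch.bfloat16

The same kwargs also thread through reset_classifier(), so a replacement head can be created on the same device and dtype as the backbone.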