Commit 5cadf13 (parent: 21b1ae7)

More dd arg conversions: fasternet, gcvit, hgnet, nextvit, starnet, vision_transformer_hybrid/relpos/sam, vitamin. Temp fix for shvit reset_classifier (needs a better approach).
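The change repeated in every hunk below is the dd factory-kwarg pattern: each module's __init__ gains optional device and dtype arguments, packs them into a dict, and forwards that dict to every layer or parameter it creates, so weights can be materialized directly on a target device and in a target dtype. A minimal sketch of the pattern as applied in this commit; ToyBlock and its layer sizes are invented for illustration and are not part of the diff:

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Illustration of the device/dtype ('dd') kwarg threading used throughout this commit.
    def __init__(self, dim: int, hidden: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}  # collect the factory kwargs once
        super().__init__()
        # ...then forward them to every parameter-creating call
        self.fc1 = nn.Linear(dim, hidden, **dd)
        self.norm = nn.LayerNorm(hidden, **dd)
        self.fc2 = nn.Linear(hidden, dim, **dd)
        self.scale = nn.Parameter(torch.ones(dim, **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.norm(self.fc1(x))) * self.scale


# e.g. allocate weights in bfloat16 on the meta device (no real memory used)
block = ToyBlock(64, 128, device='meta', dtype=torch.bfloat16)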

File tree: 10 files changed (+718, -431 lines)


timm/models/fasternet.py (47 additions, 26 deletions)
@@ -16,7 +16,7 @@
 # Licensed under the MIT License.

 from functools import partial
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -33,11 +33,12 @@


 class Partial_conv3(nn.Module):
-    def __init__(self, dim: int, n_div: int, forward: str):
+    def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.dim_conv3 = dim // n_div
         self.dim_untouched = dim - self.dim_conv3
-        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)
+        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd)

         if forward == 'slicing':
             self.forward = self.forward_slicing
@@ -68,25 +69,28 @@ def __init__(
             mlp_ratio: float,
             drop_path: float,
             layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         mlp_hidden_dim = int(dim * mlp_ratio)

         self.mlp = nn.Sequential(*[
-            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
-            norm_layer(mlp_hidden_dim),
+            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd),
+            norm_layer(mlp_hidden_dim, **dd),
             act_layer(),
-            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False),
+            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd),
         ])

-        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type)
+        self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd)

         if layer_scale_init_value > 0:
             self.layer_scale = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+                layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True)
         else:
             self.layer_scale = None

@@ -112,12 +116,15 @@ def __init__(
             mlp_ratio: float,
             drop_path: float,
             layer_scale_init_value: float,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             pconv_fw_type: str = 'split_cat',
             use_merge: bool = True,
             merge_size: Union[int, Tuple[int, int]] = 2,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False
        self.blocks = nn.Sequential(*[
@@ -130,13 +137,15 @@ def __init__(
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
+                **dd,
            )
            for i in range(depth)
        ])
        self.downsample = PatchMerging(
            dim=dim // 2,
            patch_size=merge_size,
            norm_layer=norm_layer,
+            **dd,
        ) if use_merge else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -154,11 +163,14 @@ def __init__(
             in_chans: int,
             embed_dim: int,
             patch_size: Union[int, Tuple[int, int]] = 4,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(embed_dim)
+        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(embed_dim, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.norm(self.proj(x))
@@ -169,11 +181,14 @@ def __init__(
             self,
             dim: int,
             patch_size: Union[int, Tuple[int, int]] = 2,
-            norm_layer: LayerType = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False)
-        self.norm = norm_layer(2 * dim)
+        self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd)
+        self.norm = norm_layer(2 * dim, **dd)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.norm(self.reduction(x))
@@ -196,11 +211,14 @@ def __init__(
             drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
             layer_scale_init_value: float = 0.,
-            act_layer: LayerType = partial(nn.ReLU, inplace=True),
-            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             pconv_fw_type: str = 'split_cat',
+            device=None,
+            dtype=None,
     ):
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert pconv_fw_type in ('split_cat', 'slicing',)
         self.num_classes = num_classes
         self.drop_rate = drop_rate
@@ -214,9 +232,10 @@ def __init__(
             embed_dim=embed_dim,
             patch_size=patch_size,
             norm_layer=norm_layer if patch_norm else nn.Identity,
+            **dd,
         )
         # stochastic depth decay rule
-        dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))
+        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)

         # build layers
         stages_list = []
@@ -227,13 +246,14 @@ def __init__(
                depth=depths[i],
                n_div=n_div,
                mlp_ratio=mlp_ratio,
-                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+                drop_path=dpr[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type,
                use_merge=False if i == 0 else True,
                merge_size=merge_size,
+                **dd,
            )
            stages_list.append(stage)
            self.feature_info += [dict(num_chs=dim, reduction=2**(i+2), module=f'stages.{i}')]
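The other functional change in this file, shown in the two hunks above, is the stochastic-depth schedule: instead of building one flat list of drop-path rates and slicing it per stage (dpr[sum(depths[:i]):sum(depths[:i + 1])]), calculate_drop_path_rates is now asked for a stagewise result and each stage simply takes dpr[i]. A rough sketch of the equivalent computation, for illustration only (this is not timm's implementation):

import torch


def stagewise_drop_path_rates(drop_path_rate, depths):
    # Linearly increasing rates over the total block count, grouped per stage,
    # so the i-th entry is the list of rates for the blocks of stage i.
    rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
    out, start = [], 0
    for d in depths:
        out.append(rates[start:start + d])
        start += d
    return out


# e.g. four stages with these block counts (values chosen for illustration)
dpr = stagewise_drop_path_rates(0.1, (1, 2, 8, 2))
assert len(dpr) == 4 and len(dpr[2]) == 8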
@@ -243,10 +263,10 @@ def __init__(
         self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1))
         self.head_hidden_size = out_chs = feature_dim  # 1280
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False)
+        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd)
         self.act = act_layer()
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes, bias=True) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity()
         self._initialize_weights()

     def _initialize_weights(self):
@@ -285,12 +305,13 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier

-    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
+        dd = {'device': device, 'dtype': dtype}
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()

     def forward_intermediates(
         self,
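With reset_classifier accepting the same factory kwargs, a replacement head can be created directly in the model's device and dtype instead of defaulting to CPU float32. A hedged usage sketch; the model name fasternet_t0 and the 10-class setup are only examples:

import timm
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build a FasterNet model (weights randomly initialized here) and move it to
# the working device/dtype.
model = timm.create_model('fasternet_t0', pretrained=False)
model = model.to(device=device, dtype=torch.bfloat16)

# Swap in a 10-class head; the new Linear is allocated to match the rest of
# the model rather than being created on CPU in float32 and moved afterwards.
model.reset_classifier(10, device=device, dtype=torch.bfloat16)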
