1
1
from collections import OrderedDict
2
2
from functools import partial
3
- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3
+ from typing import Any , Callable , Optional , TypeVar , Union
4
4
5
5
import torch .nn as nn
6
6
from pydantic import BaseModel , root_validator
19
19
TModelCfg = TypeVar ("TModelCfg" , bound = "ModelCfg" )
20
20
21
21
22
- def init_cnn (module : nn .Module ):
22
+ def init_cnn (module : nn .Module ) -> None :
23
23
"Init module - kaiming_normal for Conv2d and 0 for biases."
24
24
if getattr (module , "bias" , None ) is not None :
25
25
nn .init .constant_ (module .bias , 0 ) # type: ignore
@@ -39,7 +39,7 @@ def __init__(
39
39
mid_channels : int ,
40
40
stride : int = 1 ,
41
41
conv_layer = ConvBnAct ,
42
- act_fn : Type [nn .Module ] = nn .ReLU ,
42
+ act_fn : type [nn .Module ] = nn .ReLU ,
43
43
zero_bn : bool = True ,
44
44
bn_1st : bool = True ,
45
45
groups : int = 1 ,
@@ -144,16 +144,17 @@ def __init__(
144
144
self .id_conv = nn .Sequential (OrderedDict (id_layers ))
145
145
else :
146
146
self .id_conv = None
147
- self .act_fn = get_act (act_fn ) # type: ignore
147
+ self .act_fn = get_act (act_fn )
148
148
149
149
def forward (self , x ):
150
150
identity = self .id_conv (x ) if self .id_conv is not None else x
151
151
return self .act_fn (self .convs (x ) + identity )
152
152
153
153
154
154
def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
155
+ """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155
156
len_stem = len (cfg .stem_sizes )
156
- stem : List [tuple [str , nn .Module ]] = [
157
+ stem : list [tuple [str , nn .Module ]] = [
157
158
(
158
159
f"conv_{ i } " ,
159
160
cfg .conv_layer (
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175
176
176
177
177
178
def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
178
- # expansion, in_channels, out_channels, blocks, stride, sa):
179
+ """Create layer (stage)"""
179
180
# if no pool on stem - stride = 2 for first layer block in body
180
181
stride = 1 if cfg .stem_pool and layer_num == 0 else 2
181
182
num_blocks = cfg .layers [layer_num ]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213
214
214
215
215
216
def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
217
+ """Create model body."""
216
218
return nn .Sequential (
217
219
OrderedDict (
218
220
[
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224
226
225
227
226
228
def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
229
+ """Create head."""
227
230
head = [
228
231
("pool" , nn .AdaptiveAvgPool2d (1 )),
229
232
("flat" , nn .Flatten ()),
@@ -238,27 +241,27 @@ class ModelCfg(BaseModel):
238
241
name : Optional [str ] = None
239
242
in_chans : int = 3
240
243
num_classes : int = 1000
241
- block : Type [nn .Module ] = ResBlock
242
- conv_layer : Type [nn .Module ] = ConvBnAct
243
- block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
244
- layers : List [int ] = [2 , 2 , 2 , 2 ]
245
- norm : Type [nn .Module ] = nn .BatchNorm2d
246
- act_fn : Type [nn .Module ] = nn .ReLU
244
+ block : type [nn .Module ] = ResBlock
245
+ conv_layer : type [nn .Module ] = ConvBnAct
246
+ block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
247
+ layers : list [int ] = [2 , 2 , 2 , 2 ]
248
+ norm : type [nn .Module ] = nn .BatchNorm2d
249
+ act_fn : type [nn .Module ] = nn .ReLU
247
250
pool : Callable [[Any ], nn .Module ] = partial (
248
251
nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
249
252
)
250
253
expansion : int = 1
251
254
groups : int = 1
252
255
dw : bool = False
253
256
div_groups : Union [int , None ] = None
254
- sa : Union [bool , int , Type [nn .Module ]] = False
255
- se : Union [bool , int , Type [nn .Module ]] = False
257
+ sa : Union [bool , int , type [nn .Module ]] = False
258
+ se : Union [bool , int , type [nn .Module ]] = False
256
259
se_module : Union [bool , None ] = None
257
260
se_reduction : Union [int , None ] = None
258
261
bn_1st : bool = True
259
262
zero_bn : bool = True
260
263
stem_stride_on : int = 0
261
- stem_sizes : List [int ] = [32 , 32 , 64 ]
264
+ stem_sizes : list [int ] = [32 , 32 , 64 ]
262
265
stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
263
266
nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
264
267
)
@@ -286,7 +289,7 @@ def _get_str_value(self, field: str) -> str:
286
289
def __repr__ (self ) -> str :
287
290
return f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )"
288
291
289
- def __repr_args__ (self ):
292
+ def __repr_args__ (self ) -> list [ tuple [ str , str ]] :
290
293
return [
291
294
(field , str_value )
292
295
for field in self .__fields__
@@ -325,7 +328,8 @@ def body(self):
325
328
def from_cfg (cls , cfg : ModelCfg ):
326
329
return cls (** cfg .dict ())
327
330
328
- def __call__ (self ):
331
+ def __call__ (self ) -> nn .Sequential :
332
+ """Create model."""
329
333
model_name = self .name or self .__class__ .__name__
330
334
named_sequential = type (model_name , (nn .Sequential ,), {})
331
335
model = named_sequential (
@@ -338,13 +342,14 @@ def __call__(self):
338
342
return model
339
343
340
344
def _get_extra_repr (self ) -> str :
345
+ """Repr for changed fields"""
341
346
return " " .join (
342
347
f"{ field } : { self ._get_str_value (field )} ,"
343
348
for field in self .__fields_set__
344
349
if field != "name"
345
- )[:- 1 ]
350
+ )[:- 1 ] # strip last comma.
346
351
347
- def __repr__ (self ):
352
+ def __repr__ (self ) -> str :
348
353
se_repr = self .se .__name__ if self .se else "False" # type: ignore
349
354
model_name = self .name or self .__class__ .__name__
350
355
return (
0 commit comments