19
19
TModelCfg = TypeVar ("TModelCfg" , bound = "ModelCfg" )
20
20
21
21
22
- def init_cnn (module : nn .Module ):
22
+ def init_cnn (module : nn .Module ) -> None :
23
23
"Init module - kaiming_normal for Conv2d and 0 for biases."
24
24
if getattr (module , "bias" , None ) is not None :
25
25
nn .init .constant_ (module .bias , 0 ) # type: ignore
@@ -144,14 +144,15 @@ def __init__(
144
144
self .id_conv = nn .Sequential (OrderedDict (id_layers ))
145
145
else :
146
146
self .id_conv = None
147
- self .act_fn = get_act (act_fn ) # type: ignore
147
+ self .act_fn = get_act (act_fn )
148
148
149
149
def forward (self , x ):
150
150
identity = self .id_conv (x ) if self .id_conv is not None else x
151
151
return self .act_fn (self .convs (x ) + identity )
152
152
153
153
154
154
def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
155
+ """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155
156
len_stem = len (cfg .stem_sizes )
156
157
stem : list [tuple [str , nn .Module ]] = [
157
158
(
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175
176
176
177
177
178
def make_layer (cfg : TModelCfg , layer_num : int ) -> nn .Sequential : # type: ignore
178
- # expansion, in_channels, out_channels, blocks, stride, sa):
179
+ """Create layer (stage)"""
179
180
# if no pool on stem - stride = 2 for first layer block in body
180
181
stride = 1 if cfg .stem_pool and layer_num == 0 else 2
181
182
num_blocks = cfg .layers [layer_num ]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213
214
214
215
215
216
def make_body (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
217
+ """Create model body."""
216
218
return nn .Sequential (
217
219
OrderedDict (
218
220
[
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224
226
225
227
226
228
def make_head (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
229
+ """Create head."""
227
230
head = [
228
231
("pool" , nn .AdaptiveAvgPool2d (1 )),
229
232
("flat" , nn .Flatten ()),
@@ -326,6 +329,7 @@ def from_cfg(cls, cfg: ModelCfg):
326
329
return cls (** cfg .dict ())
327
330
328
331
def __call__ (self ) -> nn .Sequential :
332
+ """Create model."""
329
333
model_name = self .name or self .__class__ .__name__
330
334
named_sequential = type (model_name , (nn .Sequential ,), {})
331
335
model = named_sequential (
0 commit comments