Skip to content

Commit 41aef34

Browse files
committed
typing
1 parent e7195aa commit 41aef34

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/model_constructor/model_constructor.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
2020

2121

22-
def init_cnn(module: nn.Module):
22+
def init_cnn(module: nn.Module) -> None:
2323
"Init module - kaiming_normal for Conv2d and 0 for biases."
2424
if getattr(module, "bias", None) is not None:
2525
nn.init.constant_(module.bias, 0) # type: ignore
@@ -144,14 +144,15 @@ def __init__(
144144
self.id_conv = nn.Sequential(OrderedDict(id_layers))
145145
else:
146146
self.id_conv = None
147-
self.act_fn = get_act(act_fn) # type: ignore
147+
self.act_fn = get_act(act_fn)
148148

149149
def forward(self, x):
150150
identity = self.id_conv(x) if self.id_conv is not None else x
151151
return self.act_fn(self.convs(x) + identity)
152152

153153

154154
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
155+
"""Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155156
len_stem = len(cfg.stem_sizes)
156157
stem: list[tuple[str, nn.Module]] = [
157158
(
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175176

176177

177178
def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
178-
# expansion, in_channels, out_channels, blocks, stride, sa):
179+
"""Create layer (stage)"""
179180
# if no pool on stem - stride = 2 for first layer block in body
180181
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
181182
num_blocks = cfg.layers[layer_num]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213214

214215

215216
def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
217+
"""Create model body."""
216218
return nn.Sequential(
217219
OrderedDict(
218220
[
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224226

225227

226228
def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
229+
"""Create head."""
227230
head = [
228231
("pool", nn.AdaptiveAvgPool2d(1)),
229232
("flat", nn.Flatten()),
@@ -326,6 +329,7 @@ def from_cfg(cls, cfg: ModelCfg):
326329
return cls(**cfg.dict())
327330

328331
def __call__(self) -> nn.Sequential:
332+
"""Create model."""
329333
model_name = self.name or self.__class__.__name__
330334
named_sequential = type(model_name, (nn.Sequential,), {})
331335
model = named_sequential(

0 commit comments

Comments
 (0)