Skip to content

Commit 57f9fc2

Browse files
authored
Merge pull request #86 from ayasyrev/typing
Typing
2 parents 92de2d9 + 41aef34 commit 57f9fc2

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

src/model_constructor/model_constructor.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
3+
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
@@ -19,7 +19,7 @@
1919
TModelCfg = TypeVar("TModelCfg", bound="ModelCfg")
2020

2121

22-
def init_cnn(module: nn.Module):
22+
def init_cnn(module: nn.Module) -> None:
2323
"Init module - kaiming_normal for Conv2d and 0 for biases."
2424
if getattr(module, "bias", None) is not None:
2525
nn.init.constant_(module.bias, 0) # type: ignore
@@ -39,7 +39,7 @@ def __init__(
3939
mid_channels: int,
4040
stride: int = 1,
4141
conv_layer=ConvBnAct,
42-
act_fn: Type[nn.Module] = nn.ReLU,
42+
act_fn: type[nn.Module] = nn.ReLU,
4343
zero_bn: bool = True,
4444
bn_1st: bool = True,
4545
groups: int = 1,
@@ -144,16 +144,17 @@ def __init__(
144144
self.id_conv = nn.Sequential(OrderedDict(id_layers))
145145
else:
146146
self.id_conv = None
147-
self.act_fn = get_act(act_fn) # type: ignore
147+
self.act_fn = get_act(act_fn)
148148

149149
def forward(self, x):
150150
identity = self.id_conv(x) if self.id_conv is not None else x
151151
return self.act_fn(self.convs(x) + identity)
152152

153153

154154
def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
155+
"""Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
155156
len_stem = len(cfg.stem_sizes)
156-
stem: List[tuple[str, nn.Module]] = [
157+
stem: list[tuple[str, nn.Module]] = [
157158
(
158159
f"conv_{i}",
159160
cfg.conv_layer(
@@ -175,7 +176,7 @@ def make_stem(cfg: TModelCfg) -> nn.Sequential: # type: ignore
175176

176177

177178
def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
178-
# expansion, in_channels, out_channels, blocks, stride, sa):
179+
"""Create layer (stage)"""
179180
# if no pool on stem - stride = 2 for first layer block in body
180181
stride = 1 if cfg.stem_pool and layer_num == 0 else 2
181182
num_blocks = cfg.layers[layer_num]
@@ -213,6 +214,7 @@ def make_layer(cfg: TModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
213214

214215

215216
def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
217+
"""Create model body."""
216218
return nn.Sequential(
217219
OrderedDict(
218220
[
@@ -224,6 +226,7 @@ def make_body(cfg: TModelCfg) -> nn.Sequential: # type: ignore
224226

225227

226228
def make_head(cfg: TModelCfg) -> nn.Sequential: # type: ignore
229+
"""Create head."""
227230
head = [
228231
("pool", nn.AdaptiveAvgPool2d(1)),
229232
("flat", nn.Flatten()),
@@ -238,27 +241,27 @@ class ModelCfg(BaseModel):
238241
name: Optional[str] = None
239242
in_chans: int = 3
240243
num_classes: int = 1000
241-
block: Type[nn.Module] = ResBlock
242-
conv_layer: Type[nn.Module] = ConvBnAct
243-
block_sizes: List[int] = [64, 128, 256, 512]
244-
layers: List[int] = [2, 2, 2, 2]
245-
norm: Type[nn.Module] = nn.BatchNorm2d
246-
act_fn: Type[nn.Module] = nn.ReLU
244+
block: type[nn.Module] = ResBlock
245+
conv_layer: type[nn.Module] = ConvBnAct
246+
block_sizes: list[int] = [64, 128, 256, 512]
247+
layers: list[int] = [2, 2, 2, 2]
248+
norm: type[nn.Module] = nn.BatchNorm2d
249+
act_fn: type[nn.Module] = nn.ReLU
247250
pool: Callable[[Any], nn.Module] = partial(
248251
nn.AvgPool2d, kernel_size=2, ceil_mode=True
249252
)
250253
expansion: int = 1
251254
groups: int = 1
252255
dw: bool = False
253256
div_groups: Union[int, None] = None
254-
sa: Union[bool, int, Type[nn.Module]] = False
255-
se: Union[bool, int, Type[nn.Module]] = False
257+
sa: Union[bool, int, type[nn.Module]] = False
258+
se: Union[bool, int, type[nn.Module]] = False
256259
se_module: Union[bool, None] = None
257260
se_reduction: Union[int, None] = None
258261
bn_1st: bool = True
259262
zero_bn: bool = True
260263
stem_stride_on: int = 0
261-
stem_sizes: List[int] = [32, 32, 64]
264+
stem_sizes: list[int] = [32, 32, 64]
262265
stem_pool: Union[Callable[[], nn.Module], None] = partial(
263266
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
264267
)
@@ -286,7 +289,7 @@ def _get_str_value(self, field: str) -> str:
286289
def __repr__(self) -> str:
287290
return f"{self.__repr_name__()}(\n {self.__repr_str__(chr(10) + ' ')})"
288291

289-
def __repr_args__(self):
292+
def __repr_args__(self) -> list[tuple[str, str]]:
290293
return [
291294
(field, str_value)
292295
for field in self.__fields__
@@ -325,7 +328,8 @@ def body(self):
325328
def from_cfg(cls, cfg: ModelCfg):
326329
return cls(**cfg.dict())
327330

328-
def __call__(self):
331+
def __call__(self) -> nn.Sequential:
332+
"""Create model."""
329333
model_name = self.name or self.__class__.__name__
330334
named_sequential = type(model_name, (nn.Sequential,), {})
331335
model = named_sequential(
@@ -338,13 +342,14 @@ def __call__(self):
338342
return model
339343

340344
def _get_extra_repr(self) -> str:
345+
"""Repr for changed fields"""
341346
return " ".join(
342347
f"{field}: {self._get_str_value(field)},"
343348
for field in self.__fields_set__
344349
if field != "name"
345-
)[:-1]
350+
)[:-1] # strip last comma.
346351

347-
def __repr__(self):
352+
def __repr__(self) -> str:
348353
se_repr = self.se.__name__ if self.se else "False" # type: ignore
349354
model_name = self.name or self.__class__.__name__
350355
return (

src/model_constructor/yaresnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Yet another ResNet.
33

44
from collections import OrderedDict
5-
from typing import Callable, List, Type, Union
5+
from typing import Callable, Union
66

77
import torch.nn as nn
88
from torch.nn import Mish
@@ -27,15 +27,15 @@ def __init__(
2727
mid_channels: int,
2828
stride: int = 1,
2929
conv_layer=ConvBnAct,
30-
act_fn: Type[nn.Module] = nn.ReLU,
30+
act_fn: type[nn.Module] = nn.ReLU,
3131
zero_bn: bool = True,
3232
bn_1st: bool = True,
3333
groups: int = 1,
3434
dw: bool = False,
3535
div_groups: Union[None, int] = None,
3636
pool: Union[Callable[[], nn.Module], None] = None,
37-
se: Union[Type[nn.Module], None] = None,
38-
sa: Union[Type[nn.Module], None] = None,
37+
se: Union[type[nn.Module], None] = None,
38+
sa: Union[type[nn.Module], None] = None,
3939
):
4040
super().__init__()
4141
# pool defined at ModelConstructor.
@@ -115,9 +115,9 @@ def __init__(
115115
), # noqa E501
116116
]
117117
if se:
118-
layers.append(("se", se(out_channels)))
118+
layers.append(("se", se(out_channels))) # type: ignore
119119
if sa:
120-
layers.append(("sa", sa(out_channels)))
120+
layers.append(("sa", sa(out_channels))) # type: ignore
121121
self.convs = nn.Sequential(OrderedDict(layers))
122122
if in_channels != out_channels:
123123
self.id_conv = conv_layer(
@@ -143,7 +143,7 @@ class YaResNet34(ModelConstructor):
143143
expansion: int = 1
144144
layers: list[int] = [3, 4, 6, 3]
145145
stem_sizes: list[int] = [3, 32, 64, 64]
146-
act_fn: Type[nn.Module] = Mish
146+
act_fn: type[nn.Module] = Mish
147147

148148

149149
class YaResNet50(YaResNet34):

0 commit comments

Comments
 (0)