Commit d3fdea8

Typing, super(), buffer dtype fixes for timm/layers and timm/models
1 parent: 5cadf13


70 files changed: +643, -481 lines
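Most of the hunks below are mechanical: Python 2-style super(Class, self).__init__() calls become the zero-argument super().__init__() form, a few signatures gain type annotations, and device/dtype factory kwargs are forwarded to submodules that previously ignored them. As orientation, here is a minimal before/after sketch of the super() change; the class name is illustrative, not taken from the diff.

import torch.nn as nn

# Before: Python 2-style super(), repeats the class name.
class GateOld(nn.Module):
    def __init__(self, inplace: bool = False):
        super(GateOld, self).__init__()
        self.inplace = inplace

# After: zero-argument super(), equivalent on Python 3 and robust to class renames.
class GateNew(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace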

timm/layers/activations.py

Lines changed: 11 additions & 11 deletions
@@ -19,7 +19,7 @@ def swish(x, inplace: bool = False):

 class Swish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Swish, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -37,7 +37,7 @@ class Mish(nn.Module):
     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     """
     def __init__(self, inplace: bool = False):
-        super(Mish, self).__init__()
+        super().__init__()

     def forward(self, x):
         return mish(x)
@@ -50,7 +50,7 @@ def sigmoid(x, inplace: bool = False):
 # PyTorch has this, but not with a consistent inplace argument interface
 class Sigmoid(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Sigmoid, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -64,7 +64,7 @@ def tanh(x, inplace: bool = False):
 # PyTorch has this, but not with a consistent inplace argument interface
 class Tanh(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Tanh, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -78,7 +78,7 @@ def hard_swish(x, inplace: bool = False):

 class HardSwish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSwish, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -94,7 +94,7 @@ def hard_sigmoid(x, inplace: bool = False):

 class HardSigmoid(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSigmoid, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -114,7 +114,7 @@ def hard_mish(x, inplace: bool = False):

 class HardMish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardMish, self).__init__()
+        super().__init__()
         self.inplace = inplace

     def forward(self, x):
@@ -125,7 +125,7 @@ class PReLU(nn.PReLU):
     """Applies PReLU (w/ dummy inplace arg)
     """
     def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
-        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+        super().__init__(num_parameters=num_parameters, init=init)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.prelu(input, self.weight)
@@ -139,7 +139,7 @@ class GELU(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
     def __init__(self, inplace: bool = False):
-        super(GELU, self).__init__()
+        super().__init__()

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input)
@@ -153,7 +153,7 @@ class GELUTanh(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
     def __init__(self, inplace: bool = False):
-        super(GELUTanh, self).__init__()
+        super().__init__()

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input, approximate='tanh')
@@ -167,7 +167,7 @@ class QuickGELU(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
     def __init__(self, inplace: bool = False):
-        super(QuickGELU, self).__init__()
+        super().__init__()

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return quick_gelu(input)
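All of the activation classes above are thin wrappers whose purpose, as the in-diff comments note, is a consistent inplace argument so activations can be swapped via config even when the underlying op ignores the flag. A minimal sketch of that wrapper pattern after the change; the forward body shown here is an assumption, since the hunks above only display the constructors.

import torch
import torch.nn as nn

class SigmoidLike(nn.Module):
    """Activation wrapper with a uniform inplace argument (name is illustrative)."""
    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Honor inplace where the op supports it; otherwise it is a dummy flag.
        return x.sigmoid_() if self.inplace else x.sigmoid()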

timm/layers/activations_me.py

Lines changed: 5 additions & 5 deletions
@@ -49,7 +49,7 @@ def swish_me(x, inplace=False):

 class SwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(SwishMe, self).__init__()
+        super().__init__()

     def forward(self, x):
         return SwishAutoFn.apply(x)
@@ -86,7 +86,7 @@ def mish_me(x, inplace=False):

 class MishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(MishMe, self).__init__()
+        super().__init__()

     def forward(self, x):
         return MishAutoFn.apply(x)
@@ -119,7 +119,7 @@ def hard_sigmoid_me(x, inplace: bool = False):

 class HardSigmoidMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSigmoidMe, self).__init__()
+        super().__init__()

     def forward(self, x):
         return HardSigmoidAutoFn.apply(x)
@@ -161,7 +161,7 @@ def hard_swish_me(x, inplace=False):

 class HardSwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSwishMe, self).__init__()
+        super().__init__()

     def forward(self, x):
         return HardSwishAutoFn.apply(x)
@@ -199,7 +199,7 @@ def hard_mish_me(x, inplace: bool = False):

 class HardMishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardMishMe, self).__init__()
+        super().__init__()

     def forward(self, x):
         return HardMishAutoFn.apply(x)
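The *Me ("memory-efficient") variants dispatch to custom torch.autograd.Function implementations such as SwishAutoFn.apply(x); only their super() calls change here. For context, a hedged sketch of what such an autograd.Function typically looks like, showing the general pattern rather than the exact timm implementation:

import torch

class SwishFnSketch(torch.autograd.Function):
    """Memory-efficient swish: save only the input, recompute sigmoid in backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        s = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * s * (1 + x * (1 - s))

# usage: y = SwishFnSketch.apply(torch.randn(4, requires_grad=True))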

timm/layers/adaptive_avgmax_pool.py

Lines changed: 7 additions & 7 deletions
@@ -57,7 +57,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):

 class FastAdaptiveAvgPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'):
-        super(FastAdaptiveAvgPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)

@@ -67,7 +67,7 @@ def forward(self, x):

 class FastAdaptiveMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)

@@ -77,7 +77,7 @@ def forward(self, x):

 class FastAdaptiveAvgMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveAvgMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)

@@ -89,7 +89,7 @@ def forward(self, x):

 class FastAdaptiveCatAvgMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveCatAvgMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim_reduce = get_spatial_dim(input_fmt)
         if flatten:
@@ -105,7 +105,7 @@ def forward(self, x):

 class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self, output_size: _int_tuple_2_t = 1):
-        super(AdaptiveAvgMaxPool2d, self).__init__()
+        super().__init__()
         self.output_size = output_size

     def forward(self, x):
@@ -114,7 +114,7 @@ def forward(self, x):

 class AdaptiveCatAvgMaxPool2d(nn.Module):
     def __init__(self, output_size: _int_tuple_2_t = 1):
-        super(AdaptiveCatAvgMaxPool2d, self).__init__()
+        super().__init__()
         self.output_size = output_size

     def forward(self, x):
@@ -131,7 +131,7 @@ def __init__(
             flatten: bool = False,
             input_fmt: str = 'NCHW',
     ):
-        super(SelectAdaptivePool2d, self).__init__()
+        super().__init__()
         assert input_fmt in ('NCHW', 'NHWC')
         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
         pool_type = pool_type.lower()
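The FastAdaptive*Pool classes only receive the super() cleanup; their behavior of reducing over the spatial dimensions returned by get_spatial_dim(input_fmt), optionally flattening, is untouched. For NCHW input the average variant amounts to a mean over dims (2, 3), as this small sketch shows; the exact forward is not part of the hunks above.

import torch

x = torch.randn(2, 64, 7, 7)                # NCHW feature map

pooled = x.mean(dim=(2, 3), keepdim=True)   # fast adaptive avg pool to 1x1 -> (2, 64, 1, 1)
flat = x.mean(dim=(2, 3))                   # flatten=True behavior -> (2, 64)
print(pooled.shape, flat.shape)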

timm/layers/cbam.py

Lines changed: 8 additions & 8 deletions
@@ -34,7 +34,7 @@ def __init__(
             dtype=None,
     ):
         dd = {'device': device, 'dtype': dtype}
-        super(ChannelAttn, self).__init__()
+        super().__init__()
         if not rd_channels:
             rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
         self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias, **dd)
@@ -63,7 +63,7 @@ def __init__(
             device=None,
             dtype=None
     ):
-        super(LightChannelAttn, self).__init__(
+        super().__init__(
             channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias, device=device, dtype=dtype)

     def forward(self, x):
@@ -82,8 +82,8 @@ def __init__(
             device=None,
             dtype=None,
     ):
-        super(SpatialAttn, self).__init__()
-        self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
+        super().__init__()
+        self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False, device=device, dtype=dtype)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
@@ -102,8 +102,8 @@ def __init__(
             device=None,
             dtype=None,
     ):
-        super(LightSpatialAttn, self).__init__()
-        self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
+        super().__init__()
+        self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False, device=device, dtype=dtype)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
@@ -127,7 +127,7 @@ def __init__(
             dtype=None,
     ):
         dd = {'device': device, 'dtype': dtype}
-        super(CbamModule, self).__init__()
+        super().__init__()
         self.channel = ChannelAttn(
             channels,
             rd_ratio=rd_ratio,
@@ -161,7 +161,7 @@ def __init__(
             dtype=None,
     ):
         dd = {'device': device, 'dtype': dtype}
-        super(LightCbamModule, self).__init__()
+        super().__init__()
         self.channel = LightChannelAttn(
             channels,
             rd_ratio=rd_ratio,

timm/layers/classifier.py

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ def __init__(
             pool_type: Global pooling type, pooling disabled if empty string ('').
             drop_rate: Pre-classifier dropout rate.
         """
-        super(ClassifierHead, self).__init__()
+        super().__init__()
         self.in_features = in_features
         self.use_conv = use_conv
         self.input_fmt = input_fmt
@@ -258,7 +258,7 @@ def __init__(
         norm_layer = get_norm_layer(norm_layer)
         act_layer = get_act_layer(act_layer)

-        self.norm = norm_layer(in_features)
+        self.norm = norm_layer(in_features, **dd)
         if hidden_size:
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(in_features, hidden_size, **dd)),
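The substantive fixes in cbam.py and classifier.py are the device/dtype ones hinted at in the commit title: SpatialAttn and LightSpatialAttn now forward device/dtype into ConvNormAct, and the norm_layer in classifier.py receives the **dd factory kwargs, so submodules are created on the requested device and dtype rather than the defaults. A sketch of that dd-dict pattern in isolation; the module and layer names are illustrative, not from the diff.

import torch
import torch.nn as nn

class Head(nn.Module):
    """Collect the factory kwargs once in 'dd' and pass them to every
    parameter-holding submodule, so nothing silently falls back to the
    default device/dtype."""
    def __init__(self, in_features: int, num_classes: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.norm = nn.LayerNorm(in_features, **dd)
        self.fc = nn.Linear(in_features, num_classes, **dd)

    def forward(self, x):
        return self.fc(self.norm(x))

# e.g. build directly in half precision on the meta device:
head = Head(768, 1000, device='meta', dtype=torch.float16)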

timm/layers/cond_conv2d.py

Lines changed: 12 additions & 10 deletions
@@ -8,6 +8,8 @@

 import math
 from functools import partial
+from typing import Union, Tuple
+
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
@@ -43,20 +45,20 @@ class CondConv2d(nn.Module):

     def __init__(
             self,
-            in_channels,
-            out_channels,
-            kernel_size=3,
-            stride=1,
-            padding='',
-            dilation=1,
-            groups=1,
-            bias=False,
-            num_experts=4,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: Union[int, Tuple[int, int]] = 3,
+            stride: Union[int, Tuple[int, int]] = 1,
+            padding: Union[int, Tuple[int, int], str] = '',
+            dilation: Union[int, Tuple[int, int]] = 1,
+            groups: int = 1,
+            bias: bool = False,
+            num_experts: int = 4,
             device=None,
             dtype=None,
     ):
         dd = {'device': device, 'dtype': dtype}
-        super(CondConv2d, self).__init__()
+        super().__init__()

         self.in_channels = in_channels
         self.out_channels = out_channels
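The cond_conv2d.py change is additive typing: the CondConv2d constructor gains explicit annotations, with Union[int, Tuple[int, int]] for arguments that accept either a single int or a 2-tuple, plus the matching typing import. A sketch of the same annotation style, using a hypothetical alias so the Union is only spelled once; the alias and class below are illustrative, not part of the diff.

from typing import Tuple, Union

import torch.nn as nn

# Hypothetical alias (not in the diff) to avoid repeating the Union spelling.
_int_or_pair = Union[int, Tuple[int, int]]

class TypedConvStub(nn.Module):
    """Illustrates the annotation style added to CondConv2d; body is a stub."""
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _int_or_pair = 3,
            stride: _int_or_pair = 1,
            padding: Union[int, Tuple[int, int], str] = '',
            dilation: _int_or_pair = 1,
            groups: int = 1,
            bias: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels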
