
Commit 21b1ae7

More dd factory kwargs updates: hiera, hieradet_sam2, metaformer, mlp_mixer, mobilevit, pnasnet, rexnet, sequencer, shvit model files. Fixed blur_pool dtype/device handling and updated Mlp modules with annotations and a fix.
1 parent: 53caeb0

File tree: 12 files changed, +919 / -533 lines


timm/layers/blur_pool.py

Lines changed: 62 additions & 19 deletions
@@ -7,7 +7,7 @@
 """
 from functools import partial
 from math import comb  # Python 3.8
-from typing import Optional, Type
+from typing import Callable, Optional, Type, Union
 
 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@ def __init__(
             device=None,
             dtype=None
     ) -> None:
-        dd = {'device': device, 'dtype': dtype}
-        super(BlurPool2d, self).__init__()
+        super().__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
@@ -51,12 +50,18 @@ def __init__(
         # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
         coeffs = torch.tensor(
            [comb(filt_size - 1, k) for k in range(filt_size)],
-            **dd,
+            device='cpu',
+            dtype=torch.float32,
        ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
        if channels is not None:
            blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
-        self.register_buffer('filt', blur_filter, persistent=False)
+
+        self.register_buffer(
+            'filt',
+            blur_filter.to(device=device, dtype=dtype),
+            persistent=False,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.pad(x, self.padding, mode=self.pad_mode)
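
For reference (not part of the diff), a quick check of the binomial construction above with the default filt_size=3; this is an illustrative sketch only:

import torch
from math import comb

filt_size = 3
# (0.5 + 0.5x)^2 => C(2, k) / 2^2 for k = 0..2 -> [0.25, 0.5, 0.25]
coeffs = torch.tensor(
    [comb(filt_size - 1, k) for k in range(filt_size)],
    dtype=torch.float32,
) / (2 ** (filt_size - 1))
print(coeffs)                  # tensor([0.2500, 0.5000, 0.2500])

# The outer product is the separable 2D blur kernel; its entries sum to 1.
blur_filter = coeffs[:, None] * coeffs[None, :]
print(blur_filter.sum())       # tensor(1.)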
@@ -69,6 +74,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.conv2d(x, weight, stride=self.stride, groups=channels)
 
 
+def _normalize_aa_layer(aa_layer: LayerType) -> Callable[..., nn.Module]:
+    """Map string shorthands to callables (class or partial)."""
+    if isinstance(aa_layer, str):
+        key = aa_layer.lower().replace('_', '').replace('-', '')
+        if key in ('avg', 'avgpool'):
+            return nn.AvgPool2d
+        if key in ('blur', 'blurpool'):
+            return BlurPool2d
+        if key == 'blurpc':
+            # preconfigure a constant-pad BlurPool2d
+            return partial(BlurPool2d, pad_mode='constant')
+        raise AssertionError(f"Unknown anti-aliasing layer ({aa_layer}).")
+    return aa_layer
+
+
+def _underlying_cls(layer_callable: Callable[..., nn.Module]):
+    """Return the class behind a callable (unwrap partial), else None."""
+    if isinstance(layer_callable, partial):
+        return layer_callable.func
+    return layer_callable if isinstance(layer_callable, type) else None
+
+
+def _is_blurpool(layer_callable: Callable[..., nn.Module]) -> bool:
+    """True if callable is BlurPool2d or a partial of it."""
+    cls = _underlying_cls(layer_callable)
+    try:
+        return issubclass(cls, BlurPool2d)  # cls may be None, protect below
+    except TypeError:
+        return False
+    except Exception:
+        return False
+
+
 def create_aa(
         aa_layer: LayerType,
         channels: Optional[int] = None,
@@ -77,24 +115,29 @@ def create_aa(
         noop: Optional[Type[nn.Module]] = nn.Identity,
         device=None,
         dtype=None,
-) -> nn.Module:
-    """ Anti-aliasing """
+) -> Optional[nn.Module]:
+    """ Anti-aliasing factory that supports strings, classes, and partials. """
     if not aa_layer or not enable:
         return noop() if noop is not None else None
 
-    if isinstance(aa_layer, str):
-        aa_layer = aa_layer.lower().replace('_', '').replace('-', '')
-        if aa_layer == 'avg' or aa_layer == 'avgpool':
-            aa_layer = nn.AvgPool2d
-        elif aa_layer == 'blur' or aa_layer == 'blurpool':
-            aa_layer = partial(BlurPool2d, device=device, dtype=dtype)
-        elif aa_layer == 'blurpc':
-            aa_layer = partial(BlurPool2d, pad_mode='constant', device=device, dtype=dtype)
+    # Resolve strings to callables
+    aa_layer = _normalize_aa_layer(aa_layer)
 
-        else:
-            assert False, f"Unknown anti-aliasing layer ({aa_layer})."
+    # Build kwargs we *intend* to pass
+    call_kwargs = {"channels": channels, "stride": stride}
+
+    # Only add device/dtype for BlurPool2d (or partial of it) and don't override if already provided in the partial.
+    if _is_blurpool(aa_layer):
+        # Check if aa_layer is a partial and already has device/dtype set
+        existing_kw = aa_layer.keywords if isinstance(aa_layer, partial) and aa_layer.keywords else {}
+        if "device" not in existing_kw and device is not None:
+            call_kwargs["device"] = device
+        if "dtype" not in existing_kw and dtype is not None:
+            call_kwargs["dtype"] = dtype
 
+    # Try (channels, stride, [device, dtype]) first; fall back to (stride) only
     try:
-        return aa_layer(channels=channels, stride=stride)
-    except TypeError as e:
+        return aa_layer(**call_kwargs)
+    except TypeError:
+        # Some layers (e.g., AvgPool2d) may not accept 'channels' and need stride passed as kernel
         return aa_layer(stride)
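
A minimal usage sketch of the reworked factory (hedged: it assumes the post-commit blur_pool.py shown above and imports directly from that module):

from functools import partial
import torch
import torch.nn as nn
from timm.layers.blur_pool import BlurPool2d, create_aa

# String shorthand resolves via _normalize_aa_layer; dtype is forwarded
# because the target is BlurPool2d.
aa1 = create_aa('blur', channels=64, stride=2, dtype=torch.bfloat16)

# A partial with pad_mode pre-bound works too; keywords already present in
# the partial are left untouched.
aa2 = create_aa(partial(BlurPool2d, pad_mode='constant'), channels=64, stride=2)

# nn.AvgPool2d is not a BlurPool2d, so device/dtype are skipped; the call with
# channels raises TypeError and the fallback builds nn.AvgPool2d(stride).
aa3 = create_aa('avg', channels=64, stride=2)

# Disabled anti-aliasing returns the no-op module.
assert isinstance(create_aa('blur', channels=64, stride=2, enable=False), nn.Identity)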

timm/layers/mlp.py

Lines changed: 48 additions & 47 deletions
@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import Optional, Type, Union, Tuple
 
 from torch import nn as nn
 
@@ -17,14 +18,14 @@ class Mlp(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.GELU,
-            norm_layer=None,
-            bias=True,
-            drop=0.,
-            use_conv=False,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: Union[float, Tuple[float, float]] = 0.,
+            use_conv: bool = False,
             device=None,
             dtype=None,
     ):
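
As a quick illustration of the annotated signature plus the device/dtype factory kwargs this commit series is threading through (a sketch assuming the post-commit module):

import torch
from timm.layers.mlp import Mlp

mlp = Mlp(
    in_features=192,
    hidden_features=768,     # defaults to in_features when left as None
    bias=(True, False),      # per-layer: fc1 with bias, fc2 without
    drop=(0.1, 0.0),         # per-layer dropout probabilities
    device='cpu',
    dtype=torch.float32,
)
x = torch.randn(8, 196, 192)
print(mlp(x).shape)          # torch.Size([8, 196, 192])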
@@ -61,15 +62,15 @@ class GluMlp(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.Sigmoid,
-            norm_layer=None,
-            bias=True,
-            drop=0.,
-            use_conv=False,
-            gate_last=True,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.Sigmoid,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: Union[float, Tuple[float, float]] = 0.,
+            use_conv: bool = False,
+            gate_last: bool = True,
             device=None,
             dtype=None,
     ):
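
The gate_last flag above controls which half of the doubled hidden projection acts as the gate; a toy, self-contained sketch of that GLU pattern (assumed behaviour, not the module's exact code):

import torch
import torch.nn as nn

x = torch.randn(4, 16)
fc1 = nn.Linear(16, 2 * 32)   # project to twice the hidden width
act = nn.Sigmoid()

h = fc1(x)
x1, x2 = h.chunk(2, dim=-1)
gate_last = True              # gate with the second half (first half when False)
gated = x1 * act(x2) if gate_last else act(x1) * x2
print(gated.shape)            # torch.Size([4, 32])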
@@ -118,14 +119,14 @@ class SwiGLU(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.SiLU,
-            norm_layer=None,
-            bias=True,
-            drop=0.,
-            align_to=0,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.SiLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: Union[float, Tuple[float, float]] = 0.,
+            align_to: int = 0,
             device=None,
             dtype=None,
     ):
@@ -169,14 +170,14 @@ class GatedMlp(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.GELU,
-            norm_layer=None,
-            gate_layer=None,
-            bias=True,
-            drop=0.,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            gate_layer: Optional[Type[nn.Module]] = None,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: Union[float, Tuple[float, float]] = 0.,
             device=None,
             dtype=None,
     ):
@@ -216,13 +217,13 @@ class ConvMlp(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.ReLU,
-            norm_layer=None,
-            bias=True,
-            drop=0.,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: float = 0.,
             device=None,
             dtype=None,
     ):
@@ -254,13 +255,13 @@ class GlobalResponseNormMlp(nn.Module):
     """
     def __init__(
             self,
-            in_features,
-            hidden_features=None,
-            out_features=None,
-            act_layer=nn.GELU,
-            bias=True,
-            drop=0.,
-            use_conv=False,
+            in_features: int,
+            hidden_features: Optional[int] = None,
+            out_features: Optional[int] = None,
+            act_layer: Type[nn.Module] = nn.GELU,
+            bias: Union[bool, Tuple[bool, bool]] = True,
+            drop: Union[float, Tuple[float, float]] = 0.,
+            use_conv: bool = False,
             device=None,
             dtype=None,
     ):

timm/models/byobnet.py

Lines changed: 6 additions & 6 deletions
@@ -31,7 +31,7 @@
 import math
 from dataclasses import dataclass, field, replace
 from functools import partial
-from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
+from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence, Type
 
 import torch
 import torch.nn as nn
@@ -245,11 +245,11 @@ def num_groups(group_size: Optional[int], channels: int) -> int:
 @dataclass
 class LayerFn:
     """Container for layer factory functions."""
-    conv_norm_act: Callable = ConvNormAct
-    norm_act: Callable = BatchNormAct2d
-    act: Callable = nn.ReLU
-    attn: Optional[Callable] = None
-    self_attn: Optional[Callable] = None
+    conv_norm_act: Type[nn.Module] = ConvNormAct
+    norm_act: Type[nn.Module] = BatchNormAct2d
+    act: Type[nn.Module] = nn.ReLU
+    attn: Optional[Type[nn.Module]] = None
+    self_attn: Optional[Type[nn.Module]] = None
 
 
 class DownsampleAvg(nn.Module):
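
For orientation (an illustrative sketch, not part of the commit): LayerFn is the bundle of factories that byobnet blocks call to build their conv/norm/act/attention layers, and its fields are often populated with functools.partial objects, which is why they were previously typed as Callable:

from dataclasses import replace
import torch.nn as nn
from timm.models.byobnet import LayerFn

layers = LayerFn()                            # defaults: ConvNormAct, BatchNormAct2d, nn.ReLU
silu_layers = replace(layers, act=nn.SiLU)    # every block built from this container now uses SiLU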
