Merged
Changes from all commits
21 commits
7560ca5  Adding dd factory_kwargs to modules in timm/layers, initial model WIP… (rwightman, Sep 26, 2025)
90a35c8  Add dd factory kwargs to eva, resnet (rwightman, Sep 27, 2025)
325a6cc  Add dd to other ResNet based models, Res2Net, ResNeSt, SKNet (rwightman, Sep 27, 2025)
b94c221  Add dd factory kwargs to maxxvit and regnet (rwightman, Sep 27, 2025)
ee751ef  Add dd factory kwargs to nfnet and resnetv2 (rwightman, Sep 28, 2025)
10e7020  dd factory kwargs for fastvit, convnext, mambaout (rwightman, Sep 28, 2025)
60db539  Add dd factory kwargs to all EfficientNetBuilder models, MobileNet V1… (rwightman, Sep 28, 2025)
4d19b34  Fix typo for s2d norm (rwightman, Sep 28, 2025)
f15f7c9  Add dd factory kwargs to byobnet, cspnet, davit, edgenext (rwightman, Sep 29, 2025)
4c35b78  Add device/dtype factory kwargs to beit, efficientformer*, efficientv… (rwightman, Sep 29, 2025)
3a85ed4  avg pool should not have been passed dd (rwightman, Sep 29, 2025)
8cbbf39  Fix DarkStage device kwargs (rwightman, Sep 29, 2025)
1e172a0  dd kwargs for naflexvit, needs revisit for nn.Parameters (rwightman, Sep 29, 2025)
a7dc50f  A whack of classic convnets converted with dd factory kwargs. densene… (rwightman, Sep 29, 2025)
068e6d4  Remove **dd from two inception reset_classifier calls (rwightman, Sep 29, 2025)
6a3342c  dd factory kwargs added to a bunch of vit/vit-hybrids. cait, coat, co… (rwightman, Sep 30, 2025)
c7955eb  Add dd factory kwargs to all swin transformers and volo (rwightman, Sep 30, 2025)
53caeb0  Add some more dd kwarg updates, crossvit, ghostnet, rdnet, repghost, … (rwightman, Oct 1, 2025)
21b1ae7  More dd factory kwargs updates. hiera, hieradet_sam2, metaformer, mlp… (rwightman, Oct 1, 2025)
5cadf13  More dd arg conversions. fasternet, gcvit, hgnet, nextvit, starnet, v… (rwightman, Oct 1, 2025)
d3fdea8  Typing, super(), buffer dtype fixes for timm/layers and timm/models (rwightman, Oct 2, 2025)
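Taken together, these commits thread PyTorch's device/dtype factory kwargs through timm's module constructors: each __init__ collects the two arguments into a dd dict and splats it into every parameter-creating submodule. A minimal sketch of the pattern, using a hypothetical MyBlock module (illustration only, not a timm class):

# Sketch of the device/dtype ("dd") factory-kwarg pattern this PR applies.
# MyBlock is a made-up example module, not part of timm.
import torch
import torch.nn as nn


class MyBlock(nn.Module):
    def __init__(self, dim: int, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        # dd is forwarded only to submodules that own parameters or buffers.
        self.fc = nn.Linear(dim, dim, **dd)
        self.norm = nn.LayerNorm(dim, **dd)
        self.drop = nn.Dropout(0.1)  # parameter-free, so no dd


# Parameters are created directly on the requested device/dtype; the 'meta'
# device skips allocation entirely, which is handy for deferred init.
block = MyBlock(64, device='meta', dtype=torch.bfloat16)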
22 changes: 11 additions & 11 deletions timm/layers/activations.py
@@ -19,7 +19,7 @@ def swish(x, inplace: bool = False):
 
 class Swish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Swish, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -37,7 +37,7 @@ class Mish(nn.Module):
     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     """
     def __init__(self, inplace: bool = False):
-        super(Mish, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return mish(x)
@@ -50,7 +50,7 @@ def sigmoid(x, inplace: bool = False):
 # PyTorch has this, but not with a consistent inplace argument interface
 class Sigmoid(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Sigmoid, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -64,7 +64,7 @@ def tanh(x, inplace: bool = False):
 # PyTorch has this, but not with a consistent inplace argument interface
 class Tanh(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(Tanh, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -78,7 +78,7 @@ def hard_swish(x, inplace: bool = False):
 
 class HardSwish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSwish, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -94,7 +94,7 @@ def hard_sigmoid(x, inplace: bool = False):
 
 class HardSigmoid(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSigmoid, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -114,7 +114,7 @@ def hard_mish(x, inplace: bool = False):
 
 class HardMish(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardMish, self).__init__()
+        super().__init__()
         self.inplace = inplace
 
     def forward(self, x):
@@ -125,7 +125,7 @@ class PReLU(nn.PReLU):
     """Applies PReLU (w/ dummy inplace arg)
     """
     def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
-        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+        super().__init__(num_parameters=num_parameters, init=init)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.prelu(input, self.weight)
@@ -139,7 +139,7 @@ class GELU(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
    def __init__(self, inplace: bool = False):
-        super(GELU, self).__init__()
+        super().__init__()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input)
@@ -153,7 +153,7 @@ class GELUTanh(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
     def __init__(self, inplace: bool = False):
-        super(GELUTanh, self).__init__()
+        super().__init__()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input, approximate='tanh')
@@ -167,7 +167,7 @@ class QuickGELU(nn.Module):
     """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
     """
     def __init__(self, inplace: bool = False):
-        super(QuickGELU, self).__init__()
+        super().__init__()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return quick_gelu(input)
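The activation changes above are purely mechanical: every legacy super(ClassName, self).__init__() call becomes the zero-argument form, which resolves the class through the compiler-provided __class__ cell and survives class renames. A standalone illustration of the equivalence (hypothetical classes, not timm code):

import torch.nn as nn


class SwishOld(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishOld, self).__init__()  # legacy Python 2-style call
        self.inplace = inplace


class SwishNew(nn.Module):
    def __init__(self, inplace: bool = False):
        super().__init__()  # equivalent zero-argument form
        self.inplace = inplace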
10 changes: 5 additions & 5 deletions timm/layers/activations_me.py
@@ -49,7 +49,7 @@ def swish_me(x, inplace=False):
 
 class SwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(SwishMe, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return SwishAutoFn.apply(x)
@@ -86,7 +86,7 @@ def mish_me(x, inplace=False):
 
 class MishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(MishMe, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return MishAutoFn.apply(x)
@@ -119,7 +119,7 @@ def hard_sigmoid_me(x, inplace: bool = False):
 
 class HardSigmoidMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSigmoidMe, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return HardSigmoidAutoFn.apply(x)
@@ -161,7 +161,7 @@ def hard_swish_me(x, inplace=False):
 
 class HardSwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardSwishMe, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return HardSwishAutoFn.apply(x)
@@ -199,7 +199,7 @@ def hard_mish_me(x, inplace: bool = False):
 
 class HardMishMe(nn.Module):
     def __init__(self, inplace: bool = False):
-        super(HardMishMe, self).__init__()
+        super().__init__()
 
     def forward(self, x):
         return HardMishAutoFn.apply(x)
14 changes: 7 additions & 7 deletions timm/layers/adaptive_avgmax_pool.py
@@ -57,7 +57,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
 
 class FastAdaptiveAvgPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'):
-        super(FastAdaptiveAvgPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)
 
@@ -67,7 +67,7 @@ def forward(self, x):
 
 class FastAdaptiveMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)
 
@@ -77,7 +77,7 @@ def forward(self, x):
 
 class FastAdaptiveAvgMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveAvgMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim = get_spatial_dim(input_fmt)
 
@@ -89,7 +89,7 @@ def forward(self, x):
 
 class FastAdaptiveCatAvgMaxPool(nn.Module):
     def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
-        super(FastAdaptiveCatAvgMaxPool, self).__init__()
+        super().__init__()
         self.flatten = flatten
         self.dim_reduce = get_spatial_dim(input_fmt)
         if flatten:
@@ -105,7 +105,7 @@ def forward(self, x):
 
 class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self, output_size: _int_tuple_2_t = 1):
-        super(AdaptiveAvgMaxPool2d, self).__init__()
+        super().__init__()
         self.output_size = output_size
 
     def forward(self, x):
@@ -114,7 +114,7 @@ def forward(self, x):
 
 class AdaptiveCatAvgMaxPool2d(nn.Module):
     def __init__(self, output_size: _int_tuple_2_t = 1):
-        super(AdaptiveCatAvgMaxPool2d, self).__init__()
+        super().__init__()
         self.output_size = output_size
 
     def forward(self, x):
@@ -131,7 +131,7 @@ def __init__(
             flatten: bool = False,
             input_fmt: str = 'NCHW',
     ):
-        super(SelectAdaptivePool2d, self).__init__()
+        super().__init__()
         assert input_fmt in ('NCHW', 'NHWC')
         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
         pool_type = pool_type.lower()
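The pooling modules above only pick up the super() cleanup; as the "avg pool should not have been passed dd" commit indicates, pooling layers are parameter-free, so there is nothing for device/dtype kwargs to place. A quick standalone check (not from the PR):

import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(1)       # no weights or buffers to allocate
fc = nn.Linear(16, 8, device='cpu')  # dd-style factory kwargs matter here

print(len(list(pool.parameters())))  # 0 -> nothing for dd to do
print(len(list(fc.parameters())))    # 2 (weight, bias) -> dd applies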
32 changes: 19 additions & 13 deletions timm/layers/attention.py
@@ -36,6 +36,8 @@ def __init__(
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None
     ) -> None:
         """Initialize the Attention module.
 
@@ -50,6 +52,7 @@ def __init__(
             norm_layer: Normalization layer constructor for QK normalization if enabled
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         if qk_norm or scale_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
@@ -58,12 +61,12 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
@@ -122,6 +125,8 @@ def __init__(
             scale_norm: bool = False,
             proj_bias: bool = True,
             rotate_half: bool = False,
+            device=None,
+            dtype=None,
     ):
         """Initialize the Attention module.
 
@@ -140,6 +145,7 @@ def __init__(
             rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
         self.num_heads = num_heads
@@ -153,19 +159,19 @@ def __init__(
         self.rotate_half = rotate_half
 
         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
         else:
             self.qkv = None
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
 
-        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias)
+        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
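With the kwargs in place, the attention block can be materialized directly in a target dtype and on a target device, or on the meta device for allocation-free construction. A usage sketch, assuming the constructor signature shown in the diff above:

import torch
from timm.layers.attention import Attention

# Real allocation in bfloat16 on CPU (swap in 'cuda' where available).
attn = Attention(dim=768, num_heads=12, device='cpu', dtype=torch.bfloat16)

# Meta-device construction creates no storage; useful for building large
# models before loading a checkpoint onto the real device.
attn_meta = Attention(dim=768, num_heads=12, device='meta')

x = torch.randn(1, 197, 768, dtype=torch.bfloat16)
print(attn(x).shape)  # torch.Size([1, 197, 768])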