diff --git a/timm/layers/activations.py b/timm/layers/activations.py index a863e6964b..af4f76a8fd 100644 --- a/timm/layers/activations.py +++ b/timm/layers/activations.py @@ -19,7 +19,7 @@ def swish(x, inplace: bool = False): class Swish(nn.Module): def __init__(self, inplace: bool = False): - super(Swish, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -37,7 +37,7 @@ class Mish(nn.Module): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ def __init__(self, inplace: bool = False): - super(Mish, self).__init__() + super().__init__() def forward(self, x): return mish(x) @@ -50,7 +50,7 @@ def sigmoid(x, inplace: bool = False): # PyTorch has this, but not with a consistent inplace argument interface class Sigmoid(nn.Module): def __init__(self, inplace: bool = False): - super(Sigmoid, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -64,7 +64,7 @@ def tanh(x, inplace: bool = False): # PyTorch has this, but not with a consistent inplace argument interface class Tanh(nn.Module): def __init__(self, inplace: bool = False): - super(Tanh, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -78,7 +78,7 @@ def hard_swish(x, inplace: bool = False): class HardSwish(nn.Module): def __init__(self, inplace: bool = False): - super(HardSwish, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -94,7 +94,7 @@ def hard_sigmoid(x, inplace: bool = False): class HardSigmoid(nn.Module): def __init__(self, inplace: bool = False): - super(HardSigmoid, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -114,7 +114,7 @@ def hard_mish(x, inplace: bool = False): class HardMish(nn.Module): def __init__(self, inplace: bool = False): - super(HardMish, self).__init__() + super().__init__() self.inplace = inplace def forward(self, x): @@ -125,7 +125,7 @@ class PReLU(nn.PReLU): """Applies PReLU (w/ dummy inplace arg) """ def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: - super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + super().__init__(num_parameters=num_parameters, init=init) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.prelu(input, self.weight) @@ -139,7 +139,7 @@ class GELU(nn.Module): """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) """ def __init__(self, inplace: bool = False): - super(GELU, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: return F.gelu(input) @@ -153,7 +153,7 @@ class GELUTanh(nn.Module): """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) """ def __init__(self, inplace: bool = False): - super(GELUTanh, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: return F.gelu(input, approximate='tanh') @@ -167,7 +167,7 @@ class QuickGELU(nn.Module): """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) """ def __init__(self, inplace: bool = False): - super(QuickGELU, self).__init__() + super().__init__() def forward(self, input: torch.Tensor) -> torch.Tensor: return quick_gelu(input) diff --git a/timm/layers/activations_me.py b/timm/layers/activations_me.py index b0ddd5cb0d..6050511813 100644 --- a/timm/layers/activations_me.py +++ b/timm/layers/activations_me.py @@ -49,7 +49,7 @@ def swish_me(x, inplace=False): class SwishMe(nn.Module): 
def __init__(self, inplace: bool = False): - super(SwishMe, self).__init__() + super().__init__() def forward(self, x): return SwishAutoFn.apply(x) @@ -86,7 +86,7 @@ def mish_me(x, inplace=False): class MishMe(nn.Module): def __init__(self, inplace: bool = False): - super(MishMe, self).__init__() + super().__init__() def forward(self, x): return MishAutoFn.apply(x) @@ -119,7 +119,7 @@ def hard_sigmoid_me(x, inplace: bool = False): class HardSigmoidMe(nn.Module): def __init__(self, inplace: bool = False): - super(HardSigmoidMe, self).__init__() + super().__init__() def forward(self, x): return HardSigmoidAutoFn.apply(x) @@ -161,7 +161,7 @@ def hard_swish_me(x, inplace=False): class HardSwishMe(nn.Module): def __init__(self, inplace: bool = False): - super(HardSwishMe, self).__init__() + super().__init__() def forward(self, x): return HardSwishAutoFn.apply(x) @@ -199,7 +199,7 @@ def hard_mish_me(x, inplace: bool = False): class HardMishMe(nn.Module): def __init__(self, inplace: bool = False): - super(HardMishMe, self).__init__() + super().__init__() def forward(self, x): return HardMishAutoFn.apply(x) diff --git a/timm/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py index d0dd58d986..8b03963b7b 100644 --- a/timm/layers/adaptive_avgmax_pool.py +++ b/timm/layers/adaptive_avgmax_pool.py @@ -57,7 +57,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1): class FastAdaptiveAvgPool(nn.Module): def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'): - super(FastAdaptiveAvgPool, self).__init__() + super().__init__() self.flatten = flatten self.dim = get_spatial_dim(input_fmt) @@ -67,7 +67,7 @@ def forward(self, x): class FastAdaptiveMaxPool(nn.Module): def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): - super(FastAdaptiveMaxPool, self).__init__() + super().__init__() self.flatten = flatten self.dim = get_spatial_dim(input_fmt) @@ -77,7 +77,7 @@ def forward(self, x): class FastAdaptiveAvgMaxPool(nn.Module): def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): - super(FastAdaptiveAvgMaxPool, self).__init__() + super().__init__() self.flatten = flatten self.dim = get_spatial_dim(input_fmt) @@ -89,7 +89,7 @@ def forward(self, x): class FastAdaptiveCatAvgMaxPool(nn.Module): def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): - super(FastAdaptiveCatAvgMaxPool, self).__init__() + super().__init__() self.flatten = flatten self.dim_reduce = get_spatial_dim(input_fmt) if flatten: @@ -105,7 +105,7 @@ def forward(self, x): class AdaptiveAvgMaxPool2d(nn.Module): def __init__(self, output_size: _int_tuple_2_t = 1): - super(AdaptiveAvgMaxPool2d, self).__init__() + super().__init__() self.output_size = output_size def forward(self, x): @@ -114,7 +114,7 @@ def forward(self, x): class AdaptiveCatAvgMaxPool2d(nn.Module): def __init__(self, output_size: _int_tuple_2_t = 1): - super(AdaptiveCatAvgMaxPool2d, self).__init__() + super().__init__() self.output_size = output_size def forward(self, x): @@ -131,7 +131,7 @@ def __init__( flatten: bool = False, input_fmt: str = 'NCHW', ): - super(SelectAdaptivePool2d, self).__init__() + super().__init__() assert input_fmt in ('NCHW', 'NHWC') self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing pool_type = pool_type.lower() diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 3fbbec342a..21329107fb 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -36,6 +36,8 @@ def __init__( 
attn_drop: float = 0., proj_drop: float = 0., norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None ) -> None: """Initialize the Attention module. @@ -50,6 +52,7 @@ def __init__( norm_layer: Normalization layer constructor for QK normalization if enabled """ super().__init__() + dd = {'device': device, 'dtype': dtype} assert dim % num_heads == 0, 'dim should be divisible by num_heads' if qk_norm or scale_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' @@ -58,12 +61,12 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(dim) if scale_norm else nn.Identity() - self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -122,6 +125,8 @@ def __init__( scale_norm: bool = False, proj_bias: bool = True, rotate_half: bool = False, + device=None, + dtype=None, ): """Initialize the Attention module. @@ -140,6 +145,7 @@ def __init__( rotate_half: Use 'half' ROPE layout instead of default 'interleaved' """ super().__init__() + dd = {'device': device, 'dtype': dtype} if scale_norm or qk_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' self.num_heads = num_heads @@ -153,19 +159,19 @@ def __init__( self.rotate_half = rotate_half if qkv_fused: - self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd) self.q_proj = self.k_proj = self.v_proj = None else: self.qkv = None - self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) - self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) - self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) + self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) + self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) - self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity() - self.proj = nn.Linear(attn_dim, dim, bias=proj_bias) + self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index 6a542828bc..d454374a68 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -33,8 +33,11 @@ def __init__( value_dim: int = 64, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): """Initializer.""" + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim self.num_heads = num_heads 
@@ -42,13 +45,22 @@ def __init__( self.value_dim = value_dim self.scale = key_dim ** -0.5 - self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim])) - self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim])) - self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim])) + self.query_proj = nn.Parameter(torch.empty((self.num_heads, self.key_dim, dim), **dd)) + self.key_proj = nn.Parameter(torch.empty((dim, self.key_dim), **dd)) + self.value_proj = nn.Parameter(torch.empty((dim, self.value_dim), **dd)) self.attn_drop = nn.Dropout(attn_drop) - self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim])) + self.out_proj = nn.Parameter(torch.empty((dim_out, self.num_heads, self.value_dim), **dd)) self.proj_drop = nn.Dropout(proj_drop) + self.reset_parameters() + + def reset_parameters(self): + scale = self.key_proj.shape[0] ** -0.5 + nn.init.normal_(self.query_proj, std=scale) + nn.init.normal_(self.key_proj, std=scale) + nn.init.normal_(self.value_proj, std=scale) + nn.init.normal_(self.out_proj, std=self.out_proj.shape[0] ** -0.5) + def _reshape_input(self, t): """Reshapes a tensor to three dimensions, keeping the first and last.""" s = t.shape @@ -108,6 +120,8 @@ def __init__( proj_drop: float = 0., norm_layer: Type[nn.Module] = nn.BatchNorm2d, use_bias: bool = False, + device=None, + dtype=None, ): """Initializer. @@ -119,6 +133,7 @@ def __init__( kv_stride: Key and value stride size. dw_kernel_size: Spatial dimension of the depthwise kernel. """ + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim self.num_heads = num_heads @@ -136,19 +151,20 @@ def __init__( # FIXME dilation if padding == 'same': self.query.add_module('down_pool', create_pool2d( - 'avg', - kernel_size=self.query_strides, - padding='same', + 'avg', + kernel_size=self.query_strides, + padding='same', )) else: # no pad if not 'same' as kern=stride=even self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=query_strides)) - self.query.add_module('norm', norm_layer(dim)) + self.query.add_module('norm', norm_layer(dim, **dd)) self.query.add_module('proj', create_conv2d( dim, self.num_heads * self.key_dim, kernel_size=1, bias=use_bias, + **dd, )) self.key = nn.Sequential() @@ -161,14 +177,16 @@ def __init__( dilation=dilation, padding=padding, depthwise=True, + **dd, )) - self.key.add_module('norm', norm_layer(dim)) + self.key.add_module('norm', norm_layer(dim, **dd)) self.key.add_module('proj', create_conv2d( dim, self.key_dim, kernel_size=1, padding=padding, bias=use_bias, + **dd, )) self.value = nn.Sequential() @@ -181,29 +199,37 @@ def __init__( dilation=dilation, padding=padding, depthwise=True, + **dd, )) - self.value.add_module('norm', norm_layer(dim)) + self.value.add_module('norm', norm_layer(dim, **dd)) self.value.add_module('proj', create_conv2d( dim, self.value_dim, kernel_size=1, bias=use_bias, + **dd, )) self.attn_drop = nn.Dropout(attn_drop) self.output = nn.Sequential() if self.has_query_strides: - self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False)) + self.output.add_module('upsample', nn.Upsample( + scale_factor=self.query_strides, + mode='bilinear', + align_corners=False + )) self.output.add_module('proj', create_conv2d( self.value_dim * self.num_heads, dim_out, kernel_size=1, bias=use_bias, + **dd, )) - self.output.add_module('drop', nn.Dropout(proj_drop)) + self.output.add_module('drop', nn.Dropout(proj_drop)) self.einsum = False + 
self.init_weights() def init_weights(self): # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer @@ -304,8 +330,11 @@ def __init__( expand_first: bool = False, head_first: bool = False, attn_drop: float = 0., - proj_drop: float = 0. + proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim dim_attn = dim_out if expand_first else dim @@ -314,9 +343,9 @@ def __init__( self.head_first = head_first self.fused_attn = use_fused_attn() - self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, attn_mask: Optional[torch.Tensor] = None): diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py index f464c8a3d1..75ba0e6532 100644 --- a/timm/layers/attention_pool.py +++ b/timm/layers/attention_pool.py @@ -32,7 +32,10 @@ def __init__( norm_layer: Optional[Type[nn.Module]] = None, act_layer: Optional[Type[nn.Module]] = nn.GELU, drop: float = 0.0, + device = None, + dtype = None ): + dd = {'device': device, 'dtype': dtype} super().__init__() embed_dim = embed_dim or in_features out_features = out_features or in_features @@ -46,28 +49,28 @@ def __init__( if pos_embed == 'abs': assert feat_size is not None - self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features)) + self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features, **dd)) else: self.pos_embed = None self.latent_dim = latent_dim or embed_dim self.latent_len = latent_len - self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) + self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim, **dd)) - self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) - self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **dd) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias, **dd) if qk_norm: qk_norm_layer = norm_layer or nn.LayerNorm - self.q_norm = qk_norm_layer(self.head_dim) - self.k_norm = qk_norm_layer(self.head_dim) + self.q_norm = qk_norm_layer(self.head_dim, **dd) + self.k_norm = qk_norm_layer(self.head_dim, **dd) else: self.q_norm = nn.Identity() self.k_norm = nn.Identity() - self.proj = nn.Linear(embed_dim, embed_dim) + self.proj = nn.Linear(embed_dim, embed_dim, **dd) self.proj_drop = nn.Dropout(drop) - self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() - self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer) + self.norm = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity() + self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer, **dd) self.init_weights() diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index 7fc3b962eb..cc26aecdf4 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -44,7 +44,10 @@ def __init__( pool_type: str = 'token', class_token: bool = False, drop_rate: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features @@ -64,20 +67,20 @@ def __init__( self.fused_attn = use_fused_attn() if class_token: - self.cls_token = 
nn.Parameter(torch.zeros(1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd)) else: self.cls_token = None if qkv_separate: - self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) - self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) - self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) + self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) + self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) self.qkv = None else: - self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd) self.drop = nn.Dropout(drop_rate) - self.proj = nn.Linear(embed_dim, self.out_features) - self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size) + self.proj = nn.Linear(embed_dim, self.out_features, **dd) + self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size, **dd) def init_weights(self, zero_init_last: bool = False): if self.qkv is None: @@ -171,7 +174,10 @@ def __init__( pool_type: str = 'token', class_token: bool = False, drop_rate: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features @@ -192,21 +198,21 @@ def __init__( self.fused_attn = use_fused_attn() if class_token: - self.cls_token = nn.Parameter(torch.zeros(1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd)) else: self.cls_token = None if qkv_separate: - self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) - self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) - self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) + self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) + self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd) self.qkv = None else: self.q = self.k = self.v = None - self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd) self.drop = nn.Dropout(drop_rate) - self.proj = nn.Linear(embed_dim, self.out_features) - self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) + self.proj = nn.Linear(embed_dim, self.out_features, **dd) + self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features, **dd)) self.init_weights() diff --git a/timm/layers/blur_pool.py b/timm/layers/blur_pool.py index b7302d1acb..ec8dcf52c8 100644 --- a/timm/layers/blur_pool.py +++ b/timm/layers/blur_pool.py @@ -7,7 +7,7 @@ """ from functools import partial from math import comb # Python 3.8 -from typing import Optional, Type +from typing import Callable, Optional, Type, Union import torch import torch.nn as nn @@ -36,8 +36,10 @@ def __init__( filt_size: int = 3, stride: int = 2, pad_mode: str = 'reflect', + device=None, + dtype=None ) -> None: - super(BlurPool2d, self).__init__() + super().__init__() assert filt_size > 1 self.channels = channels self.filt_size = filt_size @@ -48,12 +50,18 @@ def __init__( # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N coeffs = torch.tensor( [comb(filt_size - 1, k) for k in range(filt_size)], + device='cpu', dtype=torch.float32, ) / (2 ** (filt_size - 1)) # normalise so coefficients sum to 1 blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :] if channels is not None: 
blur_filter = blur_filter.repeat(self.channels, 1, 1, 1) - self.register_buffer('filt', blur_filter, persistent=False) + + self.register_buffer( + 'filt', + blur_filter.to(device=device, dtype=dtype), + persistent=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.pad(x, self.padding, mode=self.pad_mode) @@ -66,30 +74,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.conv2d(x, weight, stride=self.stride, groups=channels) +def _normalize_aa_layer(aa_layer: LayerType) -> Callable[..., nn.Module]: + """Map string shorthands to callables (class or partial).""" + if isinstance(aa_layer, str): + key = aa_layer.lower().replace('_', '').replace('-', '') + if key in ('avg', 'avgpool'): + return nn.AvgPool2d + if key in ('blur', 'blurpool'): + return BlurPool2d + if key == 'blurpc': + # preconfigure a constant-pad BlurPool2d + return partial(BlurPool2d, pad_mode='constant') + raise AssertionError(f"Unknown anti-aliasing layer ({aa_layer}).") + return aa_layer + + +def _underlying_cls(layer_callable: Callable[..., nn.Module]): + """Return the class behind a callable (unwrap partial), else None.""" + if isinstance(layer_callable, partial): + return layer_callable.func + return layer_callable if isinstance(layer_callable, type) else None + + +def _is_blurpool(layer_callable: Callable[..., nn.Module]) -> bool: + """True if callable is BlurPool2d or a partial of it.""" + cls = _underlying_cls(layer_callable) + try: + return issubclass(cls, BlurPool2d) # cls may be None, protect below + except TypeError: + return False + except Exception: + return False + + def create_aa( aa_layer: LayerType, channels: Optional[int] = None, stride: int = 2, enable: bool = True, - noop: Optional[Type[nn.Module]] = nn.Identity -) -> nn.Module: - """ Anti-aliasing """ + noop: Optional[Type[nn.Module]] = nn.Identity, + device=None, + dtype=None, +) -> Optional[nn.Module]: + """ Anti-aliasing factory that supports strings, classes, and partials. """ if not aa_layer or not enable: return noop() if noop is not None else None - if isinstance(aa_layer, str): - aa_layer = aa_layer.lower().replace('_', '').replace('-', '') - if aa_layer == 'avg' or aa_layer == 'avgpool': - aa_layer = nn.AvgPool2d - elif aa_layer == 'blur' or aa_layer == 'blurpool': - aa_layer = BlurPool2d - elif aa_layer == 'blurpc': - aa_layer = partial(BlurPool2d, pad_mode='constant') + # Resolve strings to callables + aa_layer = _normalize_aa_layer(aa_layer) - else: - assert False, f"Unknown anti-aliasing layer ({aa_layer})." + # Build kwargs we *intend* to pass + call_kwargs = {"channels": channels, "stride": stride} + + # Only add device/dtype for BlurPool2d (or partial of it) and don't override if already provided in the partial. 
+ if _is_blurpool(aa_layer): + # Check if aa_layer is a partial and already has device/dtype set + existing_kw = aa_layer.keywords if isinstance(aa_layer, partial) and aa_layer.keywords else {} + if "device" not in existing_kw and device is not None: + call_kwargs["device"] = device + if "dtype" not in existing_kw and dtype is not None: + call_kwargs["dtype"] = dtype + # Try (channels, stride, [device, dtype]) first; fall back to (stride) only try: - return aa_layer(channels=channels, stride=stride) - except TypeError as e: + return aa_layer(**call_kwargs) + except TypeError: + # Some layers (e.g., AvgPool2d) may not accept 'channels' and need stride passed as kernel return aa_layer(stride) diff --git a/timm/layers/bottleneck_attn.py b/timm/layers/bottleneck_attn.py index c3db464e5a..ad2227d450 100644 --- a/timm/layers/bottleneck_attn.py +++ b/timm/layers/bottleneck_attn.py @@ -14,7 +14,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ -from typing import List +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -58,12 +58,28 @@ class PosEmbedRel(nn.Module): As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 """ - def __init__(self, feat_size, dim_head, scale): + def __init__( + self, + feat_size: Tuple[int, int], + dim_head: int, + scale: float, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.height, self.width = to_2tuple(feat_size) self.dim_head = dim_head - self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) - self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) + self.scale = scale + + self.height_rel = nn.Parameter(torch.empty(self.height * 2 - 1, dim_head, **dd)) + self.width_rel = nn.Parameter(torch.empty(self.width * 2 - 1, dim_head, **dd)) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.normal_(self.height_rel, std=self.scale) + torch.nn.init.normal_(self.width_rel, std=self.scale) def forward(self, q): B, HW, _ = q.shape @@ -104,8 +120,20 @@ class BottleneckAttn(nn.Module): scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( - self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, - qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): + self, + dim: int, + dim_out: Optional[int] = None, + feat_size: Optional[Tuple[int, int]] = None, + stride: int = 1, + num_heads: int = 4, + dim_head: Optional[int] = None, + qk_ratio: float = 1.0, + qkv_bias: bool = False, + scale_pos_embed: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim @@ -118,10 +146,10 @@ def __init__( self.scale = self.dim_head_qk ** -0.5 self.scale_pos_embed = scale_pos_embed - self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) + self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias, **dd) # NOTE I'm only supporting relative pos embedding for now - self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale, **dd) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() diff --git a/timm/layers/cbam.py 
b/timm/layers/cbam.py index 576a8306d9..2bbaf5beae 100644 --- a/timm/layers/cbam.py +++ b/timm/layers/cbam.py @@ -7,6 +7,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Optional, Tuple, Type, Union + import torch from torch import nn as nn import torch.nn.functional as F @@ -20,14 +22,24 @@ class ChannelAttn(nn.Module): """ Original CBAM channel attention module, currently avg + max pool variant only. """ def __init__( - self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, - act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): - super(ChannelAttn, self).__init__() + self, + channels: int, + rd_ratio: float = 1. / 16, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + mlp_bias=False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) - self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias, **dd) self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias, **dd) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -40,10 +52,19 @@ class LightChannelAttn(ChannelAttn): """An experimental 'lightweight' that sums avg + max pool first """ def __init__( - self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, - act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): - super(LightChannelAttn, self).__init__( - channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + self, + channels: int, + rd_ratio: float = 1./16, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + mlp_bias: bool = False, + device=None, + dtype=None + ): + super().__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias, device=device, dtype=dtype) def forward(self, x): x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) @@ -54,9 +75,15 @@ def forward(self, x): class SpatialAttn(nn.Module): """ Original CBAM spatial attention module """ - def __init__(self, kernel_size=7, gate_layer='sigmoid'): - super(SpatialAttn, self).__init__() - self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) + def __init__( + self, + kernel_size: int = 7, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + super().__init__() + self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False, device=device, dtype=dtype) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -68,9 +95,15 @@ def forward(self, x): class LightSpatialAttn(nn.Module): """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 
""" - def __init__(self, kernel_size=7, gate_layer='sigmoid'): - super(LightSpatialAttn, self).__init__() - self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) + def __init__( + self, + kernel_size: int = 7, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + super().__init__() + self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False, device=device, dtype=dtype) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -81,13 +114,31 @@ def forward(self, x): class CbamModule(nn.Module): def __init__( - self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, - spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): - super(CbamModule, self).__init__() + self, + channels: int, + rd_ratio: float = 1./16, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + spatial_kernel_size: int = 7, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + mlp_bias: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.channel = ChannelAttn( - channels, rd_ratio=rd_ratio, rd_channels=rd_channels, - rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) - self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + channels, + rd_ratio=rd_ratio, + rd_channels=rd_channels, + rd_divisor=rd_divisor, + act_layer=act_layer, + gate_layer=gate_layer, + mlp_bias=mlp_bias, + **dd, + ) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer, **dd) def forward(self, x): x = self.channel(x) @@ -97,13 +148,31 @@ def forward(self, x): class LightCbamModule(nn.Module): def __init__( - self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, - spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): - super(LightCbamModule, self).__init__() + self, + channels: int, + rd_ratio: float = 1./16, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + spatial_kernel_size: int = 7, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + mlp_bias: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.channel = LightChannelAttn( - channels, rd_ratio=rd_ratio, rd_channels=rd_channels, - rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) - self.spatial = LightSpatialAttn(spatial_kernel_size) + channels, + rd_ratio=rd_ratio, + rd_channels=rd_channels, + rd_divisor=rd_divisor, + act_layer=act_layer, + gate_layer=gate_layer, + mlp_bias=mlp_bias, + **dd, + ) + self.spatial = LightSpatialAttn(spatial_kernel_size, **dd) def forward(self, x): x = self.channel(x) diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 5e425fe6c8..b68a0bb3af 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -34,13 +34,13 @@ def _create_pool( return global_pool, num_pooled_features -def _create_fc(num_features, num_classes, use_conv=False): +def _create_fc(num_features, num_classes, use_conv=False, device=None, dtype=None): if num_classes <= 0: fc = nn.Identity() # pass-through (no classifier) elif use_conv: - fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + fc = nn.Conv2d(num_features, num_classes, 1, bias=True, device=device, dtype=dtype) else: - fc = nn.Linear(num_features, num_classes, bias=True) + fc = nn.Linear(num_features, num_classes, bias=True, device=device, dtype=dtype) 
return fc @@ -51,6 +51,8 @@ def create_classifier( use_conv: bool = False, input_fmt: str = 'NCHW', drop_rate: Optional[float] = None, + device=None, + dtype=None, ): global_pool, num_pooled_features = _create_pool( num_features, @@ -63,6 +65,8 @@ def create_classifier( num_pooled_features, num_classes, use_conv=use_conv, + device=device, + dtype=dtype, ) if drop_rate is not None: dropout = nn.Dropout(drop_rate) @@ -81,6 +85,8 @@ def __init__( drop_rate: float = 0., use_conv: bool = False, input_fmt: str = 'NCHW', + device=None, + dtype=None, ): """ Args: @@ -89,7 +95,7 @@ def __init__( pool_type: Global pooling type, pooling disabled if empty string (''). drop_rate: Pre-classifier dropout rate. """ - super(ClassifierHead, self).__init__() + super().__init__() self.in_features = in_features self.use_conv = use_conv self.input_fmt = input_fmt @@ -100,6 +106,8 @@ def __init__( pool_type, use_conv=use_conv, input_fmt=input_fmt, + device=device, + dtype=dtype, ) self.global_pool = global_pool self.drop = nn.Dropout(drop_rate) @@ -107,6 +115,7 @@ def __init__( self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() def reset(self, num_classes: int, pool_type: Optional[str] = None): + # FIXME get current device/dtype for reset? if pool_type is not None and pool_type != self.global_pool.pool_type: self.global_pool, self.fc = create_classifier( self.in_features, @@ -145,6 +154,8 @@ def __init__( drop_rate: float = 0., norm_layer: Union[str, Callable] = 'layernorm2d', act_layer: Union[str, Callable] = 'tanh', + device=None, + dtype=None ): """ Args: @@ -156,6 +167,7 @@ def __init__( norm_layer: Normalization layer type. act_layer: MLP activation layer type (only used if hidden_size is not None). """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.hidden_size = hidden_size @@ -166,20 +178,21 @@ def __init__( linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - self.norm = norm_layer(in_features) + self.norm = norm_layer(in_features, **dd) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', linear_layer(in_features, hidden_size)), + ('fc', linear_layer(in_features, hidden_size, **dd)), ('act', act_layer()), ])) self.num_features = hidden_size else: self.pre_logits = nn.Identity() self.drop = nn.Dropout(drop_rate) - self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = linear_layer(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() def reset(self, num_classes: int, pool_type: Optional[str] = None): + # FIXME handle device/dtype on reset if pool_type is not None: self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() @@ -220,6 +233,8 @@ def __init__( norm_layer: Union[str, Callable] = 'layernorm', act_layer: Union[str, Callable] = 'gelu', input_fmt: str = 'NHWC', + device=None, + dtype=None, ): """ Args: @@ -231,6 +246,7 @@ def __init__( norm_layer: Normalization layer type. act_layer: MLP activation layer type (only used if hidden_size is not None). 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features self.hidden_size = hidden_size @@ -242,19 +258,20 @@ def __init__( norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) - self.norm = norm_layer(in_features) + self.norm = norm_layer(in_features, **dd) if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(in_features, hidden_size)), + ('fc', nn.Linear(in_features, hidden_size, **dd)), ('act', act_layer()), ])) self.num_features = hidden_size else: self.pre_logits = nn.Identity() self.drop = nn.Dropout(drop_rate) - self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False): + # FIXME extract dd on reset if pool_type is not None: self.pool_type = pool_type if reset_other: diff --git a/timm/layers/cond_conv2d.py b/timm/layers/cond_conv2d.py index bdeb8666f0..8540674bc5 100644 --- a/timm/layers/cond_conv2d.py +++ b/timm/layers/cond_conv2d.py @@ -8,6 +8,8 @@ import math from functools import partial +from typing import Union, Tuple + import torch from torch import nn as nn from torch.nn import functional as F @@ -41,9 +43,22 @@ class CondConv2d(nn.Module): """ __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): - super(CondConv2d, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int], str] = '', + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = False, + num_experts: int = 4, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -61,11 +76,11 @@ def __init__(self, in_channels, out_channels, kernel_size=3, weight_num_param = 1 for wd in self.weight_shape: weight_num_param *= wd - self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + self.weight = torch.nn.Parameter(torch.empty(self.num_experts, weight_num_param, **dd)) if bias: self.bias_shape = (self.out_channels,) - self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + self.bias = torch.nn.Parameter(torch.empty(self.num_experts, self.out_channels, **dd)) else: self.register_parameter('bias', None) diff --git a/timm/layers/conv2d_same.py b/timm/layers/conv2d_same.py index f3d18495cc..0647004aaa 100644 --- a/timm/layers/conv2d_same.py +++ b/timm/layers/conv2d_same.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from ._fx import register_notrace_module from .config import is_exportable, is_scriptable @@ -35,24 +35,39 @@ class Conv2dSame(nn.Conv2d): def __init__( self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int], str] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__( in_channels, out_channels, kernel_size, - 
stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, - stride, 0, dilation, groups, bias, + stride, + 0, # padding + dilation, + groups, + bias, + device=device, + dtype=dtype, ) def forward(self, x): return conv2d_same( - x, self.weight, self.bias, - self.stride, self.padding, self.dilation, self.groups, + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, ) @@ -65,18 +80,28 @@ class Conv2dSameExport(nn.Conv2d): # pylint: disable=unused-argument def __init__( self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int], str] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__( in_channels, out_channels, kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): - super(Conv2dSameExport, self).__init__( - in_channels, out_channels, kernel_size, - stride, 0, dilation, groups, bias, + stride, + 0, # padding + dilation, + groups, + bias, + device=device, + dtype=dtype, ) self.pad = None self.pad_input_size = (0, 0) @@ -90,8 +115,13 @@ def forward(self, x): x = self.pad(x) return F.conv2d( - x, self.weight, self.bias, - self.stride, self.padding, self.dilation, self.groups, + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, ) diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index 64edf54a3c..503e1df9c2 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -32,10 +32,13 @@ def __init__( conv_kwargs: Optional[Dict[str, Any]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, act_kwargs: Optional[Dict[str, Any]] = None, + device=None, + dtype=None, ): - super(ConvNormAct, self).__init__() - conv_kwargs = conv_kwargs or {} - norm_kwargs = norm_kwargs or {} + dd = {'device': device, 'dtype': dtype} + super().__init__() + conv_kwargs = {**dd, **(conv_kwargs or {})} + norm_kwargs = {**dd, **(norm_kwargs or {})} act_kwargs = act_kwargs or {} use_aa = aa_layer is not None and stride > 1 @@ -69,7 +72,14 @@ def __init__( norm_kwargs['drop_layer'] = drop_layer self.bn.add_module('drop', drop_layer()) - self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None) + self.aa = create_aa( + aa_layer, + out_channels, + stride=stride, + enable=use_aa, + noop=None, + **dd, + ) @property def in_channels(self): diff --git a/timm/layers/drop.py b/timm/layers/drop.py index a2e59dcfa0..73a6e1dea9 100644 --- a/timm/layers/drop.py +++ b/timm/layers/drop.py @@ -129,7 +129,7 @@ def __init__( inplace: bool = False, batchwise: bool = False, fast: bool = True): - super(DropBlock2d, self).__init__() + super().__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale self.block_size = block_size @@ -173,7 +173,7 @@ class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
""" def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): - super(DropPath, self).__init__() + super().__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep diff --git a/timm/layers/eca.py b/timm/layers/eca.py index e29be6ac3c..cf42fa3360 100644 --- a/timm/layers/eca.py +++ b/timm/layers/eca.py @@ -33,11 +33,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from typing import Optional, Tuple, Type, Union import math + from torch import nn import torch.nn.functional as F - from .create_act import create_act_layer from .helpers import make_divisible @@ -58,9 +59,22 @@ class EcaModule(nn.Module): gate_layer: gating non-linearity to use """ def __init__( - self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', - rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): - super(EcaModule, self).__init__() + self, + channels: Optional[int] = None, + kernel_size: int = 3, + gamma: float = 2, + beta: float = 1, + act_layer: Optional[Type[nn.Module]] = None, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + rd_ratio: float = 1/8, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + use_mlp: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) @@ -72,11 +86,11 @@ def __init__( if rd_channels is None: rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) act_layer = act_layer or nn.ReLU - self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True, **dd) self.act = create_act_layer(act_layer) - self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True, **dd) else: - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False, **dd) self.act = None self.conv2 = None self.gate = create_act_layer(gate_layer) @@ -118,8 +132,19 @@ class CecaModule(nn.Module): gate_layer: gating non-linearity to use """ - def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): - super(CecaModule, self).__init__() + def __init__( + self, + channels: Optional[int] = None, + kernel_size: int = 3, + gamma: float = 2, + beta: float = 1, + act_layer: Optional[nn.Module] = None, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) @@ -130,7 +155,7 @@ def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None # see https://github.com/pytorch/pytorch/pull/17240 # implement manual circular padding self.padding = (kernel_size - 1) // 2 - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act, **dd) self.gate = create_act_layer(gate_layer) def forward(self, x): diff --git a/timm/layers/evo_norm.py b/timm/layers/evo_norm.py index ea77620712..0fca95f054 100644 --- a/timm/layers/evo_norm.py +++ 
b/timm/layers/evo_norm.py @@ -23,7 +23,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from typing import Sequence, Union +from typing import Optional, Sequence, Type, Union import torch import torch.nn as nn @@ -97,15 +97,26 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5): class EvoNorm2dB0(nn.Module): - def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_): + def __init__( + self, + num_features: int, + apply_act: bool = True, + momentum: float = 0.1, + eps: float = 1e-3, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None - self.register_buffer('running_var', torch.ones(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.v = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): @@ -136,14 +147,25 @@ def forward(self, x): class EvoNorm2dB1(nn.Module): - def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + def __init__( + self, + num_features: int, + apply_act: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.register_buffer('running_var', torch.ones(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): @@ -171,14 +193,25 @@ def forward(self, x): class EvoNorm2dB2(nn.Module): - def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + def __init__( + self, + num_features: int, + apply_act: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.register_buffer('running_var', torch.ones(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): @@ -206,7 +239,18 @@ def forward(self, x): class EvoNorm2dS0(nn.Module): - def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + def __init__( + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + eps: float = 1e-5, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.apply_act = apply_act # apply activation 
(non-linearity) if group_size: @@ -215,9 +259,10 @@ def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps else: self.groups = groups self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.v = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None + self.reset_parameters() def reset_parameters(self): @@ -237,9 +282,26 @@ def forward(self, x): class EvoNorm2dS0a(EvoNorm2dS0): - def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_): + def __init__( + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + eps: float = 1e-3, + device=None, + dtype=None, + **_ + ): super().__init__( - num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps) + num_features, + groups=groups, + group_size=group_size, + apply_act=apply_act, + eps=eps, + device=device, + dtype=dtype, + ) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') @@ -255,8 +317,18 @@ def forward(self, x): class EvoNorm2dS1(nn.Module): def __init__( - self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=None, eps=1e-5, **_): + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + act_layer: Optional[Type[nn.Module]] = None, + eps: float = 1e-5, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) @@ -271,8 +343,9 @@ def __init__( self.groups = groups self.eps = eps self.pre_act_norm = False - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): @@ -290,10 +363,27 @@ def forward(self, x): class EvoNorm2dS1a(EvoNorm2dS1): def __init__( - self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=None, eps=1e-3, **_): + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + act_layer: Optional[Type[nn.Module]] = None, + eps: float = 1e-3, + device=None, + dtype=None, + **_ + ): super().__init__( - num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + num_features, + groups=groups, + group_size=group_size, + apply_act=apply_act, + act_layer=act_layer, + eps=eps, + device=device, + dtype=dtype, + ) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') @@ -305,8 +395,18 @@ def forward(self, x): class EvoNorm2dS2(nn.Module): def __init__( - self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=None, eps=1e-5, **_): + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + act_layer: Optional[Type[nn.Module]] = None, + eps: float = 1e-5, + device=None, + dtype=None, + **_ + ): + dd = {'device': device, 'dtype': dtype} super().__init__() act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) @@ -320,8 +420,9 @@ def __init__( 
else: self.groups = groups self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): @@ -339,10 +440,27 @@ def forward(self, x): class EvoNorm2dS2a(EvoNorm2dS2): def __init__( - self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=None, eps=1e-3, **_): + self, + num_features: int, + groups: int = 32, + group_size: Optional[int] = None, + apply_act: bool = True, + act_layer: Optional[Type[nn.Module]] = None, + eps: float = 1e-3, + device=None, + dtype=None, + **_ + ): super().__init__( - num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + num_features, + groups=groups, + group_size=group_size, + apply_act=apply_act, + act_layer=act_layer, + eps=eps, + device=device, + dtype=dtype, + ) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') diff --git a/timm/layers/filter_response_norm.py b/timm/layers/filter_response_norm.py index a66a1cd493..1d4898d890 100644 --- a/timm/layers/filter_response_norm.py +++ b/timm/layers/filter_response_norm.py @@ -4,6 +4,8 @@ Hacked together by / Copyright 2021 Ross Wightman """ +from typing import Optional, Type + import torch import torch.nn as nn @@ -17,14 +19,25 @@ def inv_instance_rms(x, eps: float = 1e-5): class FilterResponseNormTlu2d(nn.Module): - def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): - super(FilterResponseNormTlu2d, self).__init__() + def __init__( + self, + num_features: int, + apply_act: bool = True, + eps: float = 1e-5, + rms: bool = True, + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.rms = rms self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.tau = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None + self.reset_parameters() def reset_parameters(self): @@ -43,16 +56,29 @@ def forward(self, x): class FilterResponseNormAct2d(nn.Module): - def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): - super(FilterResponseNormAct2d, self).__init__() + def __init__( + self, + num_features: int, + apply_act: bool = True, + act_layer: Type[nn.Module] = nn.ReLU, + inplace: Optional[bool] = None, + rms: bool = True, + eps: float = 1e-5, + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if act_layer is not None and apply_act: self.act = create_act_layer(act_layer, inplace=inplace) else: self.act = nn.Identity() self.rms = rms self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight = nn.Parameter(torch.empty(num_features, **dd)) + self.bias = nn.Parameter(torch.empty(num_features, **dd)) + self.reset_parameters() def reset_parameters(self): diff --git a/timm/layers/gather_excite.py b/timm/layers/gather_excite.py index 2d60dc961e..36105c9ca4 100644 --- a/timm/layers/gather_excite.py +++ 
b/timm/layers/gather_excite.py @@ -11,6 +11,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ +from typing import Optional, Tuple, Type, Union import math from torch import nn as nn @@ -26,10 +27,24 @@ class GatherExcite(nn.Module): """ Gather-Excite Attention Module """ def __init__( - self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, - rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): - super(GatherExcite, self).__init__() + self, + channels: int, + feat_size: Optional[Tuple[int, int]] = None, + extra_params: bool = False, + extent: int = 0, + use_mlp: bool = True, + rd_ratio: float = 1./16, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + add_maxpool: bool = False, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.add_maxpool = add_maxpool act_layer = get_act_layer(act_layer) self.extent = extent @@ -38,18 +53,18 @@ def __init__( if extent == 0: assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' self.gather.add_module( - 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True, **dd)) if norm_layer: - self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels, **dd)) else: assert extent % 2 == 0 num_conv = int(math.log2(extent)) for i in range(num_conv): self.gather.add_module( f'conv{i + 1}', - create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True, **dd)) if norm_layer: - self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels, **dd)) if i != num_conv - 1: self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) else: @@ -64,7 +79,7 @@ def __init__( if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
- self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer, **dd) if use_mlp else nn.Identity() self.gate = create_act_layer(gate_layer) def forward(self, x): diff --git a/timm/layers/global_context.py b/timm/layers/global_context.py index de7fb5c15f..16d4ee184b 100644 --- a/timm/layers/global_context.py +++ b/timm/layers/global_context.py @@ -7,6 +7,8 @@ Hacked together by / Copyright 2021 Ross Wightman """ +from typing import Optional, Tuple, Type, Union + from torch import nn as nn import torch.nn.functional as F @@ -18,26 +20,41 @@ class GlobalContext(nn.Module): - def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, - rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): - super(GlobalContext, self).__init__() + def __init__( + self, + channels: int, + use_attn: bool = True, + fuse_add: bool = False, + fuse_scale: bool = True, + init_last_zero: bool = False, + rd_ratio: float = 1./8, + rd_channels: Optional[int] = None, + rd_divisor: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() act_layer = get_act_layer(act_layer) - self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True, **dd) if use_attn else None if rd_channels is None: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) if fuse_add: - self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd) else: self.mlp_add = None if fuse_scale: - self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd) else: self.mlp_scale = None self.gate = create_act_layer(gate_layer) self.init_last_zero = init_last_zero + self.reset_parameters() def reset_parameters(self): diff --git a/timm/layers/grn.py b/timm/layers/grn.py index ae71e013fc..90503c94e5 100644 --- a/timm/layers/grn.py +++ b/timm/layers/grn.py @@ -18,7 +18,15 @@ class GlobalResponseNorm(nn.Module): """ Global Response Normalization layer """ - def __init__(self, dim, eps=1e-6, channels_last=True): + def __init__( + self, + dim: int, + eps: float = 1e-6, + channels_last: bool = True, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.eps = eps if channels_last: @@ -30,8 +38,8 @@ def __init__(self, dim, eps=1e-6, channels_last=True): self.channel_dim = 1 self.wb_shape = (1, -1, 1, 1) - self.weight = nn.Parameter(torch.zeros(dim)) - self.bias = nn.Parameter(torch.zeros(dim)) + self.weight = nn.Parameter(torch.zeros(dim, **dd)) + self.bias = nn.Parameter(torch.zeros(dim, **dd)) def forward(self, x): x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) diff --git a/timm/layers/halo_attn.py b/timm/layers/halo_attn.py index f2ac64f85e..3f75401bde 100644 --- a/timm/layers/halo_attn.py +++ b/timm/layers/halo_attn.py @@ -16,7 +16,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ -from typing import List +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -64,19 +64,36 @@ class
PosEmbedRel(nn.Module): Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 """ - def __init__(self, block_size, win_size, dim_head, scale): + def __init__( + self, + block_size: int, + win_size: int, + dim_head: int, + scale: float, + device=None, + dtype=None, + ): """ Args: - block_size (int): block size - win_size (int): neighbourhood window size - dim_head (int): attention head dim - scale (float): scale factor (for init) + block_size: block size + win_size: neighbourhood window size + dim_head: attention head dim + scale: scale factor (for init) """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.block_size = block_size self.dim_head = dim_head - self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) - self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.scale = scale + + self.height_rel = nn.Parameter(torch.empty(win_size * 2 - 1, dim_head, **dd)) + self.width_rel = nn.Parameter(torch.empty(win_size * 2 - 1, dim_head, **dd)) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.normal_(self.height_rel, std=self.scale) + torch.nn.init.normal_(self.width_rel, std=self.scale) def forward(self, q): B, BB, HW, _ = q.shape @@ -123,8 +140,23 @@ class HaloAttn(nn.Module): scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( - self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, - qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): + self, + dim: int, + dim_out: Optional[int] = None, + feat_size: Optional[Tuple[int, int]] = None, + stride: int = 1, + num_heads: int = 8, + dim_head: Optional[int] = None, + block_size: int = 8, + halo_size: int = 3, + qk_ratio: float = 1.0, + qkv_bias: bool = False, + avg_down: bool = False, + scale_pos_embed: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 @@ -149,11 +181,16 @@ def __init__( # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # data in unfolded block form. I haven't wrapped my head around how that'd look. 
- self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias) - self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) + self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias, **dd) + self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias, **dd) self.pos_embed = PosEmbedRel( - block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) + block_size=self.block_size_ds, + win_size=self.win_size, + dim_head=self.dim_head_qk, + scale=self.scale, + **dd, + ) self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity() diff --git a/timm/layers/hybrid_embed.py b/timm/layers/hybrid_embed.py index de57a2e9da..e914f883b8 100644 --- a/timm/layers/hybrid_embed.py +++ b/timm/layers/hybrid_embed.py @@ -40,7 +40,10 @@ def __init__( output_fmt: Optional[str] = None, strict_img_size: bool = True, dynamic_img_pad: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert isinstance(backbone, nn.Module) self.backbone = backbone @@ -58,6 +61,7 @@ def __init__( patch_size=patch_size, feature_size=feature_size, feature_ratio=feature_ratio, + **dd, ) if output_fmt is not None: @@ -79,6 +83,7 @@ def __init__( kernel_size=patch_size, stride=patch_size, bias=bias, + **dd, ) else: assert self.feature_dim == embed_dim, \ @@ -92,6 +97,8 @@ def _init_backbone( feature_size: Optional[Union[int, Tuple[int, int]]] = None, feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, feature_dim: Optional[int] = None, + device=None, + dtype=None, ): img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -101,7 +108,8 @@ def _init_backbone( training = self.backbone.training if training: self.backbone.eval() - o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1])) + # FIXME what if meta device? + o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1], device=device, dtype=dtype)) if isinstance(o, (list, tuple)): o = o[-1] # last feature if backbone outputs list/tuple of features feature_size = o.shape[-2:] @@ -142,6 +150,8 @@ def set_input_size( kernel_size=new_patch_size, stride=new_patch_size, bias=self.proj.bias is not None, + device=self.proj.weight.device, + dtype=self.proj.weight.dtype, ) new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)) if self.proj.bias is not None: @@ -165,6 +175,7 @@ def set_input_size( feature_size=feature_size, feature_ratio=feature_ratio, feature_dim=feature_dim, + # FIXME device/dtype? ) def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: @@ -225,6 +236,8 @@ def __init__( embed_dim: int = 768, bias=True, proj=True, + device=None, + dtype=None, ): super().__init__( backbone=backbone, @@ -236,6 +249,8 @@ def __init__( embed_dim=embed_dim, bias=bias, proj=proj, + device=device, + dtype=dtype, ) @torch.jit.ignore diff --git a/timm/layers/inplace_abn.py b/timm/layers/inplace_abn.py index 74fefef88a..1705eea9b6 100644 --- a/timm/layers/inplace_abn.py +++ b/timm/layers/inplace_abn.py @@ -40,9 +40,18 @@ class InplaceAbn(nn.Module): Negative slope for the `leaky_relu` activation.
""" - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, - act_layer="leaky_relu", act_param=0.01, drop_layer=None): - super(InplaceAbn, self).__init__() + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + apply_act=True, + act_layer="leaky_relu", + act_param=0.01, + drop_layer=None, + ): + super().__init__() self.num_features = num_features self.affine = affine self.eps = eps diff --git a/timm/layers/lambda_layer.py b/timm/layers/lambda_layer.py index 9192e266e6..e9ef45f602 100644 --- a/timm/layers/lambda_layer.py +++ b/timm/layers/lambda_layer.py @@ -20,6 +20,8 @@ Hacked together by / Copyright 2021 Ross Wightman """ +from typing import Optional, Tuple + import torch from torch import nn import torch.nn.functional as F @@ -29,9 +31,12 @@ from .weight_init import trunc_normal_ -def rel_pos_indices(size): +def rel_pos_indices(size, device=None): size = to_2tuple(size) - pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) + pos = torch.stack(ndgrid( + torch.arange(size[0], device=device, dtype=torch.long), + torch.arange(size[1], device=device, dtype=torch.long), + )).flatten(1) rel_pos = pos[:, None, :] - pos[:, :, None] rel_pos[0] += size[0] - 1 rel_pos[1] += size[1] - 1 @@ -55,19 +60,31 @@ class LambdaLayer(nn.Module): * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set Args: - dim (int): input dimension to the module - dim_out (int): output dimension of the module, same as dim if not set - feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W - stride (int): output stride of the module, avg pool used if stride == 2 - num_heads (int): parallel attention heads. - dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set - r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) - qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) - qkv_bias (bool): add bias to q, k, and v projections + dim: input dimension to the module + dim_out: output dimension of the module, same as dim if not set + feat_size: size of input feature_map for relative pos variant H, W + stride: output stride of the module, avg pool used if stride == 2 + num_heads: parallel attention heads. + dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set + r: local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) + qk_ratio: ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) + qkv_bias: add bias to q, k, and v projections """ def __init__( - self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, - qk_ratio=1.0, qkv_bias=False): + self, + dim: int, + dim_out: Optional[int] = None, + feat_size: Optional[Tuple[int, int]] = None, + stride: int = 1, + num_heads: int = 4, + dim_head: int = 16, + r: int = 9, + qk_ratio: float = 1.0, + qkv_bias: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0, ' should be divided by num_heads' @@ -78,13 +95,16 @@ def __init__( self.qkv = nn.Conv2d( dim, num_heads * self.dim_qk + self.dim_qk + self.dim_v, - kernel_size=1, bias=qkv_bias) - self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) - self.norm_v = nn.BatchNorm2d(self.dim_v) + kernel_size=1, + bias=qkv_bias, + **dd, + ) + self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk, **dd) + self.norm_v = nn.BatchNorm2d(self.dim_v, **dd) if r is not None: # local lambda convolution for pos - self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) + self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0), **dd) self.pos_emb = None self.rel_pos_indices = None else: @@ -93,8 +113,12 @@ def __init__( feat_size = to_2tuple(feat_size) rel_size = [2 * s - 1 for s in feat_size] self.conv_lambda = None - self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) - self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) + self.pos_emb = nn.Parameter(torch.empty(rel_size[0], rel_size[1], self.dim_qk, **dd)) + self.register_buffer( + 'rel_pos_indices', + rel_pos_indices(feat_size, device=device), + persistent=False, + ) self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() diff --git a/timm/layers/layer_scale.py b/timm/layers/layer_scale.py index 08566b2bd1..123073bcd1 100644 --- a/timm/layers/layer_scale.py +++ b/timm/layers/layer_scale.py @@ -10,10 +10,18 @@ def __init__( dim: int, init_values: float = 1e-5, inplace: bool = False, + device=None, + dtype=None, ) -> None: super().__init__() self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.constant_(self.gamma, self.init_values) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma @@ -27,10 +35,18 @@ def __init__( dim: int, init_values: float = 1e-5, inplace: bool = False, + device=None, + dtype=None, ): super().__init__() self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.constant_(self.gamma, self.init_values) def forward(self, x): gamma = self.gamma.view(1, -1, 1, 1) diff --git a/timm/layers/median_pool.py b/timm/layers/median_pool.py index 40bd71a7a3..3b83589295 100644 --- a/timm/layers/median_pool.py +++ b/timm/layers/median_pool.py @@ -16,7 +16,7 @@ class MedianPool2d(nn.Module): same: override padding and enforce same padding, boolean """ def __init__(self, kernel_size=3, stride=1, padding=0, same=False): - super(MedianPool2d, self).__init__() + super().__init__() self.k = to_2tuple(kernel_size) self.stride = to_2tuple(stride) self.padding = to_4tuple(padding) #
convert to l, r, t, b diff --git a/timm/layers/mixed_conv2d.py b/timm/layers/mixed_conv2d.py index fa0ce565c0..6f17510a3b 100644 --- a/timm/layers/mixed_conv2d.py +++ b/timm/layers/mixed_conv2d.py @@ -4,6 +4,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ +from typing import List, Union import torch from torch import nn as nn @@ -23,9 +24,18 @@ class MixedConv2d(nn.ModuleDict): Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py """ - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, List[int]] = 3, + stride: int = 1, + padding: str = '', + dilation: int = 1, + depthwise: bool = False, + **kwargs + ): + super().__init__() kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] num_groups = len(kernel_size) @@ -39,8 +49,15 @@ def __init__(self, in_channels, out_channels, kernel_size=3, self.add_module( str(idx), create_conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + in_ch, + out_ch, + k, + stride=stride, + padding=padding, + dilation=dilation, + groups=conv_groups, + **kwargs, + ) ) self.splits = in_splits diff --git a/timm/layers/ml_decoder.py b/timm/layers/ml_decoder.py index cd7d506207..9f2f9902b3 100644 --- a/timm/layers/ml_decoder.py +++ b/timm/layers/ml_decoder.py @@ -35,7 +35,7 @@ def add_ml_decoder_head(model): class TransformerDecoderLayerOptimal(nn.Module): def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", layer_norm_eps=1e-5) -> None: - super(TransformerDecoderLayerOptimal, self).__init__() + super().__init__() self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.dropout = nn.Dropout(dropout) self.dropout1 = nn.Dropout(dropout) @@ -89,7 +89,7 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None class MLDecoder(nn.Module): def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): - super(MLDecoder, self).__init__() + super().__init__() embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups if embed_len_decoder > num_classes: embed_len_decoder = num_classes diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py index a8b1cc0d2e..9523ee1631 100644 --- a/timm/layers/mlp.py +++ b/timm/layers/mlp.py @@ -3,6 +3,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ from functools import partial +from typing import Optional, Type, Union, Tuple from torch import nn as nn @@ -17,15 +18,18 @@ class Mlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - bias=True, - drop=0., - use_conv=False, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Optional[Type[nn.Module]] = None, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: Union[float, Tuple[float, float]] = 0., + use_conv: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -33,11 +37,11 @@ def __init__( drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) 
if use_conv else nn.Linear - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1], **dd) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -58,16 +62,19 @@ class GluMlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.Sigmoid, - norm_layer=None, - bias=True, - drop=0., - use_conv=False, - gate_last=True, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.Sigmoid, + norm_layer: Optional[Type[nn.Module]] = None, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: Union[float, Tuple[float, float]] = 0., + use_conv: bool = False, + gate_last: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -78,11 +85,11 @@ def __init__( self.chunk_dim = 1 if use_conv else -1 self.gate_last = gate_last # use second half of width for gate - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1]) + self.norm = norm_layer(hidden_features // 2, **dd) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1], **dd) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): @@ -112,15 +119,18 @@ class SwiGLU(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.SiLU, - norm_layer=None, - bias=True, - drop=0., - align_to=0, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.SiLU, + norm_layer: Optional[Type[nn.Module]] = None, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: Union[float, Tuple[float, float]] = 0., + align_to: int = 0, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -130,12 +140,12 @@ def __init__( if align_to: hidden_features = hidden_features + (-hidden_features % align_to) - self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0]) - self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0], **dd) + self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0], **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity() + self.fc2 
= nn.Linear(hidden_features, out_features, bias=bias[1], **dd) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): @@ -160,32 +170,35 @@ class GatedMlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - gate_layer=None, - bias=True, - drop=0., + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Optional[Type[nn.Module]] = None, + gate_layer: Optional[Type[nn.Module]] = None, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: Union[float, Tuple[float, float]] = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0], **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if gate_layer is not None: assert hidden_features % 2 == 0 - self.gate = gate_layer(hidden_features) + self.gate = gate_layer(hidden_features, **dd) hidden_features = hidden_features // 2 # FIXME base reduction on gate property? else: self.gate = nn.Identity() - self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1], **dd) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -204,24 +217,27 @@ class ConvMlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.ReLU, - norm_layer=None, - bias=True, - drop=0., + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) - self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) - self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0], **dd) + self.norm = norm_layer(hidden_features, **dd) if norm_layer else nn.Identity() self.act = act_layer() self.drop = nn.Dropout(drop) - self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1], **dd) def forward(self, x): x = self.fc1(x) @@ -239,14 +255,17 @@ class GlobalResponseNormMlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - bias=True, - drop=0., - use_conv=False, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + bias: Union[bool, Tuple[bool, bool]] = True, + drop: Union[float, Tuple[float, float]] = 0., + use_conv: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} 
super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -254,11 +273,11 @@ def __init__( drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv) - self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv, **dd) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1], **dd) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): diff --git a/timm/layers/non_local_attn.py b/timm/layers/non_local_attn.py index 71fe208290..000eac61e7 100644 --- a/timm/layers/non_local_attn.py +++ b/timm/layers/non_local_attn.py @@ -4,6 +4,8 @@ - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification """ +from typing import Optional, Type + import torch from torch import nn from torch.nn import functional as F @@ -21,16 +23,27 @@ class NonLocalAttn(nn.Module): Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. """ - def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): - super(NonLocalAttn, self).__init__() + def __init__( + self, + in_channels, + use_scale=True, + rd_ratio=1/8, + rd_channels=None, + rd_divisor=8, + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if rd_channels is None: rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) self.scale = in_channels ** -0.5 if use_scale else 1.0 - self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) - self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) - self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) - self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) - self.norm = nn.BatchNorm2d(in_channels) + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True, **dd) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True, **dd) + self.norm = nn.BatchNorm2d(in_channels, **dd) self.reset_parameters() def forward(self, x): @@ -73,13 +86,22 @@ def reset_parameters(self): @register_notrace_module class BilinearAttnTransform(nn.Module): - def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(BilinearAttnTransform, self).__init__() - - self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) - self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) - self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) - self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + def __init__( + self, + in_channels: int, + 
block_size: int, + groups: int, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer, **dd) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1), **dd) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size), **dd) + self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd) self.block_size = block_size self.groups = groups self.in_channels = in_channels @@ -129,14 +151,34 @@ class BatNonLocalAttn(nn.Module): """ def __init__( - self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, - drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + self, + in_channels: int, + block_size: int = 7, + groups: int = 2, + rd_ratio: float = 0.25, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + drop_rate: float = 0.2, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() if rd_channels is None: rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) - self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) - self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) - self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd) + self.ba = BilinearAttnTransform( + rd_channels, + block_size, + groups, + act_layer=act_layer, + norm_layer=norm_layer, + **dd, + ) + self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer, **dd) self.dropout = nn.Dropout2d(p=drop_rate) def forward(self, x): diff --git a/timm/layers/norm.py b/timm/layers/norm.py index ec082da2ef..cca8eecfe4 100644 --- a/timm/layers/norm.py +++ b/timm/layers/norm.py @@ -216,7 +216,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -228,7 +228,7 @@ def __init__( self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -264,7 +264,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -275,7 +275,7 @@ def __init__( self.elementwise_affine = affine if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -312,7 +312,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = 
{'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -324,7 +324,7 @@ def __init__( self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -364,7 +364,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -375,7 +375,7 @@ def __init__( self.elementwise_affine = affine if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -408,7 +408,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -420,7 +420,7 @@ def __init__( self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -454,7 +454,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -465,7 +465,7 @@ def __init__( self.elementwise_affine = affine if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -498,7 +498,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -510,7 +510,7 @@ def __init__( self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) @@ -546,7 +546,7 @@ def __init__( device=None, dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + dd = {'device': device, 'dtype': dtype} super().__init__() normalized_shape = channels if isinstance(normalized_shape, numbers.Integral): @@ -557,7 +557,7 @@ def __init__( self.elementwise_affine = affine if self.elementwise_affine: - self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = nn.Parameter(torch.empty(self.normalized_shape, **dd)) else: self.register_parameter('weight', None) diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index d362a95079..cd993bebb7 100644 --- 
a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -78,7 +78,7 @@ def __init__( ): try: factory_kwargs = {'device': device, 'dtype': dtype} - super(BatchNormAct2d, self).__init__( + super().__init__( num_features, eps=eps, momentum=momentum, @@ -88,7 +88,7 @@ def __init__( ) except TypeError: # NOTE for backwards compat with old PyTorch w/o factory device/dtype support - super(BatchNormAct2d, self).__init__( + super().__init__( num_features, eps=eps, momentum=momentum, @@ -218,21 +218,24 @@ class FrozenBatchNormAct2d(torch.nn.Module): """ def __init__( - self, - num_features: int, - eps: float = 1e-5, - apply_act: bool = True, - act_layer: LayerType = nn.ReLU, - act_kwargs: Dict[str, Any] = None, - inplace: bool = True, - drop_layer: Optional[Type[nn.Module]] = None, + self, + num_features: int, + eps: float = 1e-5, + apply_act: bool = True, + act_layer: LayerType = nn.ReLU, + act_kwargs: Dict[str, Any] = None, + inplace: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.eps = eps - self.register_buffer("weight", torch.ones(num_features)) - self.register_buffer("bias", torch.zeros(num_features)) - self.register_buffer("running_mean", torch.zeros(num_features)) - self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer("weight", torch.ones(num_features, **dd)) + self.register_buffer("bias", torch.zeros(num_features, **dd)) + self.register_buffer("running_mean", torch.zeros(num_features, **dd)) + self.register_buffer("running_var", torch.ones(num_features, **dd)) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) @@ -380,12 +383,16 @@ def __init__( act_kwargs: Dict[str, Any] = None, inplace: bool = True, drop_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): - super(GroupNormAct, self).__init__( + super().__init__( _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine, + device=device, + dtype=dtype, ) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) @@ -415,8 +422,10 @@ def __init__( act_kwargs: Dict[str, Any] = None, inplace: bool = True, drop_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): - super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine) + super().__init__(1, num_channels, eps=eps, affine=affine, device=device, dtype=dtype) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) @@ -447,7 +456,7 @@ def __init__( drop_layer: Optional[Type[nn.Module]] = None, **kwargs, ): - super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) + super().__init__(normalization_shape, eps=eps, elementwise_affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) @@ -633,8 +642,9 @@ def __init__( act_kwargs: Dict[str, Any] = None, inplace: bool = True, drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super().__init__(channels=num_channels, eps=eps, affine=affine) + super().__init__(channels=num_channels, eps=eps, affine=affine, 
**kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() @@ -666,8 +676,9 @@ def __init__( act_kwargs: Dict[str, Any] = None, inplace: bool = True, drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, ): - super().__init__(channels=num_channels, eps=eps, affine=affine) + super().__init__(channels=num_channels, eps=eps, affine=affine, **kwargs) self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 28d5067b82..0d60e694e3 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -31,7 +31,7 @@ class PatchEmbed(nn.Module): def __init__( self, - img_size: Union[int, Tuple[int, int]] = 224, + img_size: Optional[Union[int, Tuple[int, int]]] = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, @@ -41,7 +41,10 @@ def __init__( bias: bool = True, strict_img_size: bool = True, dynamic_img_pad: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.patch_size = to_2tuple(patch_size) self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) @@ -56,8 +59,8 @@ def __init__( self.strict_img_size = strict_img_size self.dynamic_img_pad = dynamic_img_pad - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **dd) + self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity() def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): assert self.patch_size @@ -84,6 +87,8 @@ def set_input_size( kernel_size=new_patch_size, stride=new_patch_size, bias=self.proj.bias is not None, + device=self.proj.weight.device, + dtype=self.proj.weight.dtype, ) new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)) if self.proj.bias is not None: @@ -144,7 +149,7 @@ class PatchEmbedWithSize(PatchEmbed): def __init__( self, - img_size: Optional[int] = 224, + img_size: Optional[Union[int, Tuple[int, int]]] = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, @@ -152,6 +157,8 @@ def __init__( flatten: bool = True, output_fmt: Optional[str] = None, bias: bool = True, + device=None, + dtype=None, ): super().__init__( img_size=img_size, @@ -162,6 +169,8 @@ def __init__( flatten=flatten, output_fmt=output_fmt, bias=bias, + device=device, + dtype=dtype, ) def forward(self, x) -> Tuple[torch.Tensor, List[int]]: @@ -255,12 +264,12 @@ def resample_kernel(kernel): def _compute_resize_matrix( - old_size: Tuple[int, int], - new_size: Tuple[int, int], - interpolation: str, - antialias: bool, - device: torch.device, - dtype: torch.dtype = DTYPE_INTERMEDIATE + old_size: Tuple[int, int], + new_size: Tuple[int, int], + interpolation: str, + antialias: bool, + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE ) -> torch.Tensor: """Computes the resize matrix basis vectors and interpolates them to new_size.""" old_h, old_w = old_size @@ -282,11 +291,11 @@ def _compute_resize_matrix( def _apply_resampling( - patch_embed: torch.Tensor, - pinv_matrix: torch.Tensor, - new_size_tuple: Tuple[int, int], - orig_dtype: 
torch.dtype, - intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE + patch_embed: torch.Tensor, + pinv_matrix: torch.Tensor, + new_size_tuple: Tuple[int, int], + orig_dtype: torch.dtype, + intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE ) -> torch.Tensor: """ Simplified resampling w/o vmap use. As proposed by https://github.com/stas-sl @@ -335,10 +344,10 @@ class PatchEmbedResamplerFixedOrigSize(nn.Module): caching the pseudoinverse matrix based on the target size. """ def __init__( - self, - orig_size: Tuple[int, int], - interpolation: str = 'bicubic', - antialias: bool = True + self, + orig_size: Tuple[int, int], + interpolation: str = 'bicubic', + antialias: bool = True ): """ Args: @@ -356,10 +365,10 @@ def __init__( self._pinv_cache_map: Dict[Tuple[int, int], str] = {} def _get_or_create_pinv_matrix( - self, - new_size: Tuple[int, int], - device: torch.device, - dtype: torch.dtype = DTYPE_INTERMEDIATE + self, + new_size: Tuple[int, int], + device: torch.device, + dtype: torch.dtype = DTYPE_INTERMEDIATE ) -> torch.Tensor: """Retrieves the cached pinv matrix or computes and caches it for the given new_size.""" cache_key = new_size @@ -438,12 +447,12 @@ class PatchEmbedInterpolator(nn.Module): """ def __init__( - self, - base_patch_size: Tuple[int, int], - in_chans: int = 3, - embed_dim: int = 768, - interpolation: str = 'bicubic', - antialias: bool = True, + self, + base_patch_size: Tuple[int, int], + in_chans: int = 3, + embed_dim: int = 768, + interpolation: str = 'bicubic', + antialias: bool = True, ): super().__init__() self.base_patch_size = base_patch_size @@ -453,9 +462,9 @@ def __init__( self.antialias = antialias def resample_linear_weight( - self, - weight: torch.Tensor, - target_patch_size: Tuple[int, int], + self, + weight: torch.Tensor, + target_patch_size: Tuple[int, int], ) -> torch.Tensor: """Resample linear patch embedding weights for a new patch size. @@ -495,9 +504,9 @@ def resample_linear_weight( return weight_resampled def resample_conv_weight( - self, - weight: torch.Tensor, - target_patch_size: Tuple[int, int], + self, + weight: torch.Tensor, + target_patch_size: Tuple[int, int], ) -> torch.Tensor: """Resample conv2d patch embedding weights for a new patch size. @@ -523,12 +532,12 @@ def resample_conv_weight( return weight_resampled def forward( - self, - patches: torch.Tensor, - proj_weight: torch.Tensor, - proj_bias: Optional[torch.Tensor] = None, - patch_size: Optional[Tuple[int, int]] = None, - is_linear: bool = True, + self, + patches: torch.Tensor, + proj_weight: torch.Tensor, + proj_bias: Optional[torch.Tensor] = None, + patch_size: Optional[Tuple[int, int]] = None, + is_linear: bool = True, ) -> torch.Tensor: """Apply patch embedding with dynamic weight resampling. 
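Note on the recurring pattern in the layers above: constructors collect the new factory kwargs into dd = {'device': device, 'dtype': dtype}, allocate parameters with torch.empty(..., **dd), and move value initialization into reset_parameters(). The snippet below is a minimal, self-contained sketch of why that combination is useful; ToyScale is a hypothetical example module (not a timm class) and only illustrates the deferred meta-device init flow under those assumptions.

import torch
import torch.nn as nn


class ToyScale(nn.Module):
    # Hypothetical module mirroring the dd + torch.empty + reset_parameters() pattern in this diff.
    def __init__(self, dim: int, init_value: float = 1e-5, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.init_value = init_value
        self.scale = nn.Parameter(torch.empty(dim, **dd))
        self.reset_parameters()

    def reset_parameters(self):
        # constant_ runs without error on meta tensors (there is no data to fill),
        # so construction on device='meta' stays cheap and init is deferred.
        nn.init.constant_(self.scale, self.init_value)

    def forward(self, x):
        return x * self.scale


# Build on the meta device (no real memory, no init work), then materialize and re-init.
m = ToyScale(768, device='meta')
m = m.to_empty(device='cpu')   # allocate uninitialized storage on the target device
m.reset_parameters()           # restore the intended init values
print(float(m.scale[0]))       # 1e-05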
diff --git a/timm/layers/pool2d_same.py b/timm/layers/pool2d_same.py index 9e9f3046ba..d7513c9427 100644 --- a/timm/layers/pool2d_same.py +++ b/timm/layers/pool2d_same.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Union from ._fx import register_notrace_module from .helpers import to_2tuple @@ -29,10 +29,17 @@ def avg_pool2d_same( class AvgPool2dSame(nn.AvgPool2d): """ Tensorflow like 'SAME' wrapper for 2D average pooling """ - def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Union[int, Tuple[int, int], str] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ): kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) - super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + super().__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) def forward(self, x): x = pad_same(x, self.kernel_size, self.stride) @@ -56,11 +63,18 @@ def max_pool2d_same( class MaxPool2dSame(nn.MaxPool2d): """ Tensorflow like 'SAME' wrapper for 2D max pooling """ - def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): + def __init__( + self, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Union[int, Tuple[int, int], str] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + ceil_mode: bool = False, + ): kernel_size = to_2tuple(kernel_size) stride = to_2tuple(stride) dilation = to_2tuple(dilation) - super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) + super().__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) def forward(self, x): x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 4fcb111e99..c18969b9e3 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -22,12 +22,16 @@ def gen_relative_position_index( q_size: Tuple[int, int], k_size: Optional[Tuple[int, int]] = None, class_token: bool = False, + device=None, ) -> torch.Tensor: # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window assert k_size is None, 'Different q & k sizes not currently supported' # FIXME - coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1) # 2, Wh, Ww + coords = torch.stack(ndgrid( + torch.arange(q_size[0], device=device), + torch.arange(q_size[1], device=device), + )).flatten(1) # 2, Wh, Ww relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0 @@ -140,9 +144,9 @@ def resize_rel_pos_bias_table_levit( position_bias_table.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode=interpolation, - antialias=antialias) - relative_position_bias_table_resized = \ - relative_position_bias_table_resized.view(nH2, L2).permute(1, 0) + antialias=antialias, + ) + relative_position_bias_table_resized = relative_position_bias_table_resized.view(nH2, L2).permute(1, 0) relative_position_bias_table_resized.to(orig_dtype) return 
relative_position_bias_table_resized else: @@ -270,7 +274,15 @@ class RelPosBias(nn.Module): Adapted from Swin-V1 relative position bias impl, modularized. """ - def __init__(self, window_size, num_heads, prefix_tokens=0): + def __init__( + self, + window_size: Tuple[int, int], + num_heads: int, + prefix_tokens: int = 0, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert prefix_tokens <= 1 self.window_size = window_size @@ -278,10 +290,10 @@ def __init__(self, window_size, num_heads, prefix_tokens=0): self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens - self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.relative_position_bias_table = nn.Parameter(torch.empty(num_relative_distance, num_heads, **dd)) self.register_buffer( "relative_position_index", - gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1), + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0, device=device).view(-1), persistent=False, ) @@ -304,11 +316,13 @@ def gen_relative_log_coords( win_size: Tuple[int, int], pretrained_win_size: Tuple[int, int] = (0, 0), mode='swin', + device=None, + dtype=None, ): assert mode in ('swin', 'cr') # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well - relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32) - relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32) + relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], device=device).to(torch.float32) + relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], device=device).to(torch.float32) relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w)) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 if mode == 'swin': @@ -326,7 +340,7 @@ def gen_relative_log_coords( relative_coords_table = torch.sign(relative_coords_table) * torch.log( 1.0 + relative_coords_table.abs()) - return relative_coords_table + return relative_coords_table.to(dtype) class RelPosMlp(nn.Module): @@ -337,13 +351,16 @@ class RelPosMlp(nn.Module): """ def __init__( self, - window_size, - num_heads=8, - hidden_dim=128, - prefix_tokens=0, - mode='cr', - pretrained_window_size=(0, 0) + window_size: Tuple[int, int], + num_heads: int = 8, + hidden_dim: int = 128, + prefix_tokens: int = 0, + mode: str = 'cr', + pretrained_window_size: Tuple[int, int] = (0, 0), + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.window_size = window_size self.window_area = self.window_size[0] * self.window_size[1] @@ -365,19 +382,22 @@ def __init__( out_features=num_heads, act_layer=nn.ReLU, bias=mlp_bias, - drop=(0.125, 0.) 
+ drop=(0.125, 0.), + **dd, ) self.register_buffer( "relative_position_index", - gen_relative_position_index(window_size).view(-1), - persistent=False) + gen_relative_position_index(window_size, device=device).view(-1), + persistent=False, + ) # get relative_coords_table self.register_buffer( "rel_coords_log", - gen_relative_log_coords(window_size, pretrained_window_size, mode=mode), - persistent=False) + gen_relative_log_coords(window_size, pretrained_window_size, mode=mode, **dd), + persistent=False, + ) def get_bias(self) -> torch.Tensor: relative_position_bias = self.mlp(self.rel_coords_log) @@ -399,6 +419,8 @@ def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): def generate_lookup_tensor( length: int, max_relative_position: Optional[int] = None, + device=None, + dtype=None, ): """Generate a one_hot lookup tensor to reindex embeddings along one dimension. @@ -415,7 +437,7 @@ def generate_lookup_tensor( max_relative_position = length - 1 # Return the cached lookup tensor, otherwise compute it and cache it. vocab_size = 2 * max_relative_position + 1 - ret = torch.zeros(length, length, vocab_size) + ret = torch.zeros(length, length, vocab_size, device=device, dtype=dtype) for i in range(length): for x in range(length): v = x - i + max_relative_position @@ -459,7 +481,15 @@ class RelPosBiasTf(nn.Module): Adapted from: https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py """ - def __init__(self, window_size, num_heads, prefix_tokens=0): + def __init__( + self, + window_size: Tuple[int, int], + num_heads: int, + prefix_tokens: int = 0, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert prefix_tokens <= 1 self.window_size = window_size @@ -469,9 +499,9 @@ def __init__(self, window_size, num_heads, prefix_tokens=0): vocab_height = 2 * window_size[0] - 1 vocab_width = 2 * window_size[1] - 1 self.bias_shape = (self.num_heads, vocab_height, vocab_width) - self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape)) - self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False) - self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False) + self.relative_position_bias_table = nn.Parameter(torch.empty(self.bias_shape, **dd)) + self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0], **dd), persistent=False) + self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1], **dd), persistent=False) self.init_weights() def init_weights(self): diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index fd8bb1416e..f01d0bca24 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -43,8 +43,8 @@ def build_sincos2d_pos_embed( temperature: float = 10000., reverse_coord: bool = False, interleave_sin_cos: bool = False, + device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None ) -> torch.Tensor: """ @@ -96,8 +96,8 @@ def build_fourier_pos_embed( ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., grid_indexing: str = 'ij', - dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ) -> List[torch.Tensor]: """ @@ -164,7 +164,7 @@ def build_fourier_pos_embed( grid = grid.unsqueeze(-1) pos = grid * bands - pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype) + 
pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype) out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] return out @@ -177,6 +177,8 @@ def __init__( num_bands: int = 64, concat_grid=True, keep_spatial=False, + device=None, + dtype=None, ): super().__init__() self.max_res = max_res @@ -185,7 +187,7 @@ def __init__( self.keep_spatial = keep_spatial self.register_buffer( 'bands', - pixel_freq_bands(max_res, num_bands), + pixel_freq_bands(max_res, num_bands).to(device=device, dtype=dtype), persistent=False, ) @@ -228,7 +230,8 @@ def rope_rotate_half(x: torch.Tensor) -> torch.Tensor: def apply_rot_embed( x: torch.Tensor, - emb: torch.Tensor, + sin_emb: torch.Tensor, + cos_emb: torch.Tensor, half: bool = False, ) -> torch.Tensor: # x: [..., D], eg [x0, x1, x2, x3, x4, x5] @@ -246,7 +249,8 @@ def apply_rot_embed( def apply_rot_embed_list( x: List[torch.Tensor], - emb: torch.Tensor, + sin_emb: torch.Tensor, + cos_emb: torch.Tensor, half: bool = False ) -> List[torch.Tensor]: if isinstance(x, torch.Tensor): @@ -331,8 +335,8 @@ def build_rotary_pos_embed( ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., grid_indexing: str = 'ij', - dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, ): """ @@ -347,8 +351,8 @@ def build_rotary_pos_embed( ref_feat_shape: Reference feature shape for resize / fine-tune. grid_offset: Constant offset to add to grid for non-pixel freq. grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') - dtype: Output dtype. device: Output device. + dtype: Output dtype. Returns: @@ -398,6 +402,8 @@ def __init__( ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., grid_indexing: str = 'ij', + device=None, + dtype=None, ): super().__init__() self.dim = dim @@ -426,14 +432,14 @@ def __init__( ) self.register_buffer( 'bands', - bands, + bands.to(device=device, dtype=dtype), persistent=False, ) self.pos_embed_sin = None self.pos_embed_cos = None else: # cache full sin/cos embeddings if shape provided up front - emb_sin, emb_cos = self._get_pos_embed_values(feat_shape) + emb_sin, emb_cos = self._get_pos_embed_values(feat_shape, device=device, dtype=dtype) self.bands = None self.register_buffer( 'pos_embed_sin', @@ -446,7 +452,7 @@ def __init__( persistent=False, ) - def _get_pos_embed_values(self, feat_shape: List[int]): + def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): emb_sin, emb_cos = build_rotary_pos_embed( feat_shape=feat_shape, dim=self.dim, @@ -457,6 +463,8 @@ def _get_pos_embed_values(self, feat_shape: List[int]): ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, grid_indexing=self.grid_indexing, + device=device, + dtype=dtype, ) return emb_sin, emb_cos @@ -465,9 +473,11 @@ def update_feat_shape(self, feat_shape: List[int]): # only update if feat_shape was set and different from previous value assert self.pos_embed_sin is not None assert self.pos_embed_cos is not None - emb_sin, emb_cos = self._get_pos_embed_values(feat_shape) - self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype) - self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype) + self.pos_embed_sin, self.pos_embed_cos = self._get_pos_embed_values( + feat_shape, + device=self.pos_embed_sin.device, + dtype=self.pos_embed_sin.dtype, + ) self.feat_shape = feat_shape def get_embed(self, shape: Optional[List[int]] = None): @@ -502,15 +512,17 @@ class 
RotaryEmbeddingCat(nn.Module): def __init__( self, - dim, - max_res=224, - temperature=10000, - in_pixels=True, + dim: int, + max_res: int = 224, + temperature: float = 10000, + in_pixels: bool = True, linear_bands: bool = False, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., grid_indexing: str = 'ij', + device=None, + dtype=None, ): super().__init__() self.dim = dim @@ -539,7 +551,7 @@ def __init__( ) self.register_buffer( 'bands', - bands, + bands.to(device=device, dtype=dtype), persistent=False, ) self.pos_embed = None @@ -548,11 +560,11 @@ def __init__( self.bands = None self.register_buffer( 'pos_embed', - self._get_pos_embed_values(feat_shape=feat_shape), + self._get_pos_embed_values(feat_shape=feat_shape, device=device, dtype=dtype), persistent=False, ) - def _get_pos_embed_values(self, feat_shape: List[int]): + def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): embeds = build_rotary_pos_embed( feat_shape=feat_shape, dim=self.dim, @@ -563,6 +575,8 @@ def _get_pos_embed_values(self, feat_shape: List[int]): ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, grid_indexing=self.grid_indexing, + device=device, + dtype=dtype, ) return torch.cat(embeds, -1) @@ -570,7 +584,8 @@ def update_feat_shape(self, feat_shape: List[int]): if self.feat_shape is not None and feat_shape != self.feat_shape: # only update if feat_shape was set and different from previous value assert self.pos_embed is not None - self.pos_embed = self._get_pos_embed_values(feat_shape).to( + self.pos_embed = self._get_pos_embed_values( + feat_shape, device=self.pos_embed.device, dtype=self.pos_embed.dtype, ) @@ -697,12 +712,12 @@ def get_mixed_grid( if grid_indexing == 'xy': shape = swap_shape_xy(shape) x_pos, y_pos = torch.meshgrid( - torch.arange(shape[0], dtype=dtype, device=device), - torch.arange(shape[1], dtype=dtype, device=device), + torch.arange(shape[0], device=device, dtype=torch.float32), + torch.arange(shape[1], device=device, dtype=torch.float32), indexing=grid_indexing, ) - t_x = x_pos.flatten() - t_y = y_pos.flatten() + t_x = x_pos.to(dtype).flatten() + t_y = y_pos.to(dtype).flatten() return t_x, t_y @@ -741,6 +756,8 @@ def __init__( temperature: float = 10.0, feat_shape: Optional[List[int]] = None, grid_indexing: str = 'xy', + device=None, + dtype=None, ): """Initialize rotary embeddings. @@ -769,6 +786,8 @@ def __init__( num_heads, temperature=temperature, rotate=True, + device=device, + dtype=dtype, ) # (2, depth, num_heads, head_dim//2) self.freqs = nn.Parameter(freqs) @@ -784,7 +803,7 @@ def _get_grid_values(self, feat_shape: Optional[List[int]]): t_x, t_y = get_mixed_grid( feat_shape, grid_indexing=self.grid_indexing, - device=self.freqs.device + device=self.freqs.device, ) return t_x, t_y @@ -900,8 +919,8 @@ def make_coords_dinov3( Returns: coords with shape (HW, 2) in [-1, 1]. 
""" # 0.5-centered indices with optional offset - coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + grid_offset - coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + grid_offset + coords_h = torch.arange(0.5, height, device=device, dtype=torch.float32) + grid_offset + coords_w = torch.arange(0.5, width, device=device, dtype=torch.float32) + grid_offset # Normalization denominators if normalize_coords == "max": @@ -921,6 +940,8 @@ def make_coords_dinov3( # Normalize to [0, 1] coords_h = coords_h / h_denom coords_w = coords_w / w_denom + coords_h = coords_h.to(dtype) + coords_w = coords_w.to(dtype) # Create grid then map to [-1, 1] if grid_indexing == "xy": @@ -956,6 +977,8 @@ def __init__( shift_coords: Optional[float] = None, jitter_coords: Optional[float] = None, # interpreted as factor J >= 1 rescale_coords: Optional[float] = None, # interpreted as factor R >= 1 + device=None, + dtype=None, ): super().__init__() @@ -981,7 +1004,7 @@ def __init__( self.grid_indexing = grid_indexing # Precompute periods - periods = self._compute_periods() + periods = self._compute_periods(device=device, dtype=dtype) self.register_buffer("periods", periods, persistent=False) if feat_shape is not None: @@ -995,18 +1018,17 @@ def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = to dim = self.dim // 4 if self.min_period is not None and self.max_period is not None: - exponents = torch.linspace(0, 1, dim, dtype=torch.float32) + exponents = torch.linspace(0, 1, dim, device='cpu', dtype=torch.float32) periods = self.min_period * ((self.max_period / self.min_period) ** exponents) else: if self.temperature is None: raise ValueError("Provide either min/max periods or `temperature`.") - exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2) + exponents = 2.0 * torch.arange(dim, device='cpu', dtype=torch.float32) / (self.dim // 2) periods = self.temperature ** exponents # NOTE: The original dinv3 model weights have periods downcast to bfloat16 in persistent buffers, # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default - - return periods + return periods.to(device=device, dtype=dtype) def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: """Apply shift/jitter/rescale train time augmentations.""" @@ -1042,7 +1064,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: return coords - def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Return sin/cos embeddings with either 'half' or 'interleaved' layout.""" # coords: (HW, 2); periods: (dim) dim = self.dim // 4 @@ -1066,13 +1088,17 @@ def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tenso cos = torch.cos(angles) return sin, cos - def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor: + def _create_embed( + self, + feat_shape: List[int], + no_aug: bool = False, + ) -> torch.Tensor: H, W = feat_shape coords = make_coords_dinov3( H, W, normalize_coords=self.normalize_coords, grid_indexing=self.grid_indexing, - grid_offset=self.grid_offset + grid_offset=self.grid_offset, ) # (HW, 2) if not no_aug: coords = self._apply_coord_augs(coords) @@ -1081,7 +1107,8 @@ def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Te return rope_embed def _cache_embed(self, feat_shape: List[int]): - rope_embed 
= self._create_embed(feat_shape, no_aug=True) # create non-augmented embeds for cache + # create non-augmented embeds for cache + rope_embed = self._create_embed(feat_shape, no_aug=True) self.register_buffer("pos_embed_cached", rope_embed, persistent=False) self.feat_shape = feat_shape diff --git a/timm/layers/selective_kernel.py b/timm/layers/selective_kernel.py index ec8ee6ce27..d09c9fa286 100644 --- a/timm/layers/selective_kernel.py +++ b/timm/layers/selective_kernel.py @@ -4,6 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ +from typing import List, Optional, Tuple, Type, Union + import torch from torch import nn as nn @@ -20,18 +22,28 @@ def _kernel_valid(k): class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__( + self, + channels: int, + num_paths: int = 2, + attn_channels: int = 32, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + ): """ Selective Kernel Attention Module Selective Kernel attention mechanism factored out into its own module. """ - super(SelectiveKernelAttn, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_paths = num_paths - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False, **dd) + self.bn = norm_layer(attn_channels, **dd) self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False, **dd) def forward(self, x): _assert(x.shape[1] == self.num_paths, '') @@ -48,9 +60,26 @@ def forward(self, x): class SelectiveKernel(nn.Module): - def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, - rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: Optional[Union[int, List[int]]] = None, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + rd_ratio: float = 1./16, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + keep_3x3: bool = True, + split_input: bool = True, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module]= nn.BatchNorm2d, + aa_layer: Optional[Type[nn.Module]] = None, + drop_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): """ Selective Kernel Convolution Module As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. @@ -61,22 +90,23 @@ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, d a noteworthy increase in performance over similar param count models without this attention layer. 
-Ross W Args: - in_channels (int): module input (feature) channel count - out_channels (int): module output (feature) channel count - kernel_size (int, list): kernel size for each convolution branch - stride (int): stride for convolutions - dilation (int): dilation for module as a whole, impacts dilation of each branch - groups (int): number of groups for each branch - rd_ratio (int, float): reduction factor for attention features - keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations - split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + in_channels: module input (feature) channel count + out_channels: module output (feature) channel count + kernel_size: kernel size for each convolution branch + stride: stride for convolutions + dilation: dilation for module as a whole, impacts dilation of each branch + groups: number of groups for each branch + rd_ratio: reduction factor for attention features + keep_3x3: keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input: split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count - act_layer (nn.Module): activation layer to use - norm_layer (nn.Module): batchnorm/norm layer to use - aa_layer (nn.Module): anti-aliasing module - drop_layer (nn.Module): spatial drop module in convs (drop block, etc) + act_layer: activation layer to use + norm_layer: batchnorm/norm layer to use + aa_layer: anti-aliasing module + drop_layer: spatial drop module in convs (drop block, etc) """ - super(SelectiveKernel, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() out_channels = out_channels or in_channels kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation _kernel_valid(kernel_size) @@ -98,13 +128,13 @@ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, d conv_kwargs = dict( stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, - aa_layer=aa_layer, drop_layer=drop_layer) + aa_layer=aa_layer, drop_layer=drop_layer, **dd) self.paths = nn.ModuleList([ ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) - self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels, **dd) def forward(self, x): if self.split_input: diff --git a/timm/layers/separable_conv.py b/timm/layers/separable_conv.py index c081e02bc4..bc771a9821 100644 --- a/timm/layers/separable_conv.py +++ b/timm/layers/separable_conv.py @@ -5,6 +5,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Optional, Type, Union + from torch import nn as nn from .create_conv2d import create_conv2d @@ -14,21 +16,50 @@ class SeparableConvNormAct(nn.Module): """ Separable Conv w/ trailing Norm and Activation """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, - channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, - apply_act=True, drop_layer=None): - super(SeparableConvNormAct, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + padding: str = '', + bias: bool = False, + channel_multiplier: float = 1.0, + pw_kernel_size: int = 1, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + apply_act: bool = True, + drop_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv_dw = create_conv2d( - in_channels, int(in_channels * channel_multiplier), kernel_size, - stride=stride, dilation=dilation, padding=padding, depthwise=True) + in_channels, + int(in_channels * channel_multiplier), + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + depthwise=True, + **dd, + ) self.conv_pw = create_conv2d( - int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + int(in_channels * channel_multiplier), + out_channels, + pw_kernel_size, + padding=padding, + bias=bias, + **dd, + ) norm_act_layer = get_norm_act_layer(norm_layer, act_layer) norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} - self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs, **dd) @property def in_channels(self): @@ -51,16 +82,42 @@ def forward(self, x): class SeparableConv2d(nn.Module): """ Separable Conv """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, - channel_multiplier=1.0, pw_kernel_size=1): - super(SeparableConv2d, self).__init__() + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + padding='', + bias=False, + channel_multiplier=1.0, + pw_kernel_size=1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv_dw = create_conv2d( - in_channels, 
int(in_channels * channel_multiplier), kernel_size, - stride=stride, dilation=dilation, padding=padding, depthwise=True) + in_channels, + int(in_channels * channel_multiplier), + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + depthwise=True, + **dd, + ) self.conv_pw = create_conv2d( - int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + int(in_channels * channel_multiplier), + out_channels, + pw_kernel_size, + padding=padding, + bias=bias, + **dd, + ) @property def in_channels(self): diff --git a/timm/layers/space_to_depth.py b/timm/layers/space_to_depth.py index 452681544f..6289c58977 100644 --- a/timm/layers/space_to_depth.py +++ b/timm/layers/space_to_depth.py @@ -5,7 +5,7 @@ class SpaceToDepth(nn.Module): bs: torch.jit.Final[int] - def __init__(self, block_size=4): + def __init__(self, block_size: int = 4): super().__init__() assert block_size == 4 self.bs = block_size diff --git a/timm/layers/split_attn.py b/timm/layers/split_attn.py index ac54f8988a..d702492e19 100644 --- a/timm/layers/split_attn.py +++ b/timm/layers/split_attn.py @@ -6,6 +6,8 @@ Modified for torchscript compat, performance, and consistency with timm by Ross Wightman """ +from typing import Optional, Type, Union + import torch import torch.nn.functional as F from torch import nn @@ -14,8 +16,8 @@ class RadixSoftmax(nn.Module): - def __init__(self, radix, cardinality): - super(RadixSoftmax, self).__init__() + def __init__(self, radix: int, cardinality: int): + super().__init__() self.radix = radix self.cardinality = cardinality @@ -33,10 +35,27 @@ def forward(self, x): class SplitAttn(nn.Module): """Split-Attention (aka Splat) """ - def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, - dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, - act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): - super(SplitAttn, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + radix: int = 2, + rd_ratio: float = 0.25, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + drop_layer: Optional[Type[nn.Module]] = None, + **kwargs, + ): + dd = {'device': kwargs.pop('device', None), 'dtype': kwargs.pop('dtype', None)} + super().__init__() out_channels = out_channels or in_channels self.radix = radix mid_chs = out_channels * radix @@ -47,15 +66,24 @@ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padd padding = kernel_size // 2 if padding is None else padding self.conv = nn.Conv2d( - in_channels, mid_chs, kernel_size, stride, padding, dilation, - groups=groups * radix, bias=bias, **kwargs) - self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + in_channels, + mid_chs, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + **kwargs, + **dd, + ) + self.bn0 = norm_layer(mid_chs, **dd) if norm_layer else nn.Identity() self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act0 = act_layer(inplace=True) - self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) - self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups, **dd) + self.bn1 = 
norm_layer(attn_chs, **dd) if norm_layer else nn.Identity() self.act1 = act_layer(inplace=True) - self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups, **dd) self.rsoftmax = RadixSoftmax(radix, groups) def forward(self, x): diff --git a/timm/layers/split_batchnorm.py b/timm/layers/split_batchnorm.py index 830781b335..31e6d356f0 100644 --- a/timm/layers/split_batchnorm.py +++ b/timm/layers/split_batchnorm.py @@ -17,13 +17,25 @@ class SplitBatchNorm2d(torch.nn.BatchNorm2d): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, - track_running_stats=True, num_splits=2): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + num_splits=2, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__(num_features, eps, momentum, affine, track_running_stats) assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' self.num_splits = num_splits self.aux_bn = nn.ModuleList([ - nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, **dd) + for _ in range(num_splits - 1) + ]) def forward(self, input: torch.Tensor): if self.training: # aux BN only relevant while training diff --git a/timm/layers/squeeze_excite.py b/timm/layers/squeeze_excite.py index 4fe568fe8f..ed584dacfd 100644 --- a/timm/layers/squeeze_excite.py +++ b/timm/layers/squeeze_excite.py @@ -10,6 +10,8 @@ Hacked together by / Copyright 2021 Ross Wightman """ +from typing import Optional, Tuple, Type, Union + from torch import nn as nn from .create_act import create_act_layer @@ -26,16 +28,28 @@ class SEModule(nn.Module): * customizable activation, normalization, and gate layer """ def __init__( - self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, - bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): - super(SEModule, self).__init__() + self, + channels: int, + rd_ratio: float = 1. / 16, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + add_maxpool: bool = False, + bias: bool = True, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.add_maxpool = add_maxpool if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
- self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) - self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias, **dd) + self.bn = norm_layer(rd_channels, **dd) if norm_layer else nn.Identity() self.act = create_act_layer(act_layer, inplace=True) - self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias, **dd) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -56,10 +70,19 @@ class EffectiveSEModule(nn.Module): """ 'Effective Squeeze-Excitation From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 """ - def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): - super(EffectiveSEModule, self).__init__() + def __init__( + self, + channels: int, + add_maxpool: bool = False, + gate_layer: Union[str, Type[nn.Module]] = 'hard_sigmoid', + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.add_maxpool = add_maxpool - self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0, device=device, dtype=dtype) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -84,14 +107,24 @@ class SqueezeExciteCl(nn.Module): * customizable activation, normalization, and gate layer """ def __init__( - self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, - bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'): + self, + channels: int, + rd_ratio: float = 1. / 16, + rd_channels: Optional[int] = None, + rd_divisor: int = 8, + bias: bool = True, + act_layer: Type[nn.Module] = nn.ReLU, + gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() if not rd_channels: rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) - self.fc1 = nn.Linear(channels, rd_channels, bias=bias) + self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **dd) self.act = create_act_layer(act_layer, inplace=True) - self.fc2 = nn.Linear(rd_channels, channels, bias=bias) + self.fc2 = nn.Linear(rd_channels, channels, bias=bias, **dd) self.gate = create_act_layer(gate_layer) def forward(self, x): diff --git a/timm/layers/std_conv.py b/timm/layers/std_conv.py index 287c60d516..c8d5e6506d 100644 --- a/timm/layers/std_conv.py +++ b/timm/layers/std_conv.py @@ -16,6 +16,8 @@ Hacked together by / copyright Ross Wightman, 2021. 
""" +from typing import Optional, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -31,19 +33,35 @@ class StdConv2d(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, padding=None, - dilation=1, groups=1, bias=False, eps=1e-6): + self, + in_channel: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): if padding is None: padding = get_padding(kernel_size, stride, dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) + padding=padding, dilation=dilation, groups=groups, bias=bias, device=device, dtype=dtype) self.eps = eps def forward(self, x): weight = F.batch_norm( - self.weight.reshape(1, self.out_channels, -1), None, None, - training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + self.weight.reshape(1, self.out_channels, -1), + None, # running_mean + None, # running_var + training=True, + momentum=0., + eps=self.eps, + ).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -56,12 +74,23 @@ class StdConv2dSame(nn.Conv2d): https://arxiv.org/abs/1903.10520v2 """ def __init__( - self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', - dilation=1, groups=1, bias=False, eps=1e-6): + self, + in_channel: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: str = 'SAME', + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) super().__init__( in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, - groups=groups, bias=bias) + groups=groups, bias=bias, device=device, dtype=dtype) self.same_pad = is_dynamic self.eps = eps @@ -69,8 +98,13 @@ def forward(self, x): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.reshape(1, self.out_channels, -1), None, None, - training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + self.weight.reshape(1, self.out_channels, -1), + None, # running_mean + None, # running_var + training=True, + momentum=0., + eps=self.eps, + ).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -85,22 +119,53 @@ class ScaledStdConv2d(nn.Conv2d): """ def __init__( - self, in_channels, out_channels, kernel_size, stride=1, padding=None, - dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int], str]] = None, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + gamma: float = 1.0, + eps: float = 1e-6, + gain_init: float = 1.0, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} if padding is None: padding = get_padding(kernel_size, stride, dilation) 
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
-            groups=groups, bias=bias)
-        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+            groups=groups, bias=bias, **dd)
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
         self.eps = eps
+        self.gain_init = gain_init
+
+        self.gain = nn.Parameter(torch.empty((self.out_channels, 1, 1, 1), **dd))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        # gain does not exist yet on the first call, triggered from within nn.Conv2d.__init__
+        if hasattr(self, 'gain'):
+            torch.nn.init.constant_(self.gain, self.gain_init)
+        # always reset the parent conv parameters as well
+        super().reset_parameters()
+        # the explicit call at the end of __init__ initializes gain once it exists

     def forward(self, x):
         weight = F.batch_norm(
-            self.weight.reshape(1, self.out_channels, -1), None, None,
+            self.weight.reshape(1, self.out_channels, -1),
+            None,  # running_mean
+            None,  # running_var
             weight=(self.gain * self.scale).view(-1),
-            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+            training=True,
+            momentum=0.,
+            eps=self.eps,
+        ).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
@@ -115,22 +180,53 @@ class ScaledStdConv2dSame(nn.Conv2d):
     """
     def __init__(
-            self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
-            dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: Union[int, Tuple[int, int]],
+            stride: Union[int, Tuple[int, int]] = 1,
+            padding: str = 'SAME',
+            dilation: Union[int, Tuple[int, int]] = 1,
+            groups: int = 1,
+            bias: bool = True,
+            gamma: float = 1.0,
+            eps: float = 1e-6,
+            gain_init: float = 1.0,
+            device=None,
+            dtype=None,
+    ):
+        dd = {'device': device, 'dtype': dtype}
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
-            groups=groups, bias=bias)
-        self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+            groups=groups, bias=bias, **dd)
         self.scale = gamma * self.weight[0].numel() ** -0.5
         self.same_pad = is_dynamic
         self.eps = eps
+        self.gain_init = gain_init
+
+        self.gain = nn.Parameter(torch.empty((self.out_channels, 1, 1, 1), **dd))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        # gain does not exist yet on the first call, triggered from within nn.Conv2d.__init__
+        if hasattr(self, 'gain'):
+            torch.nn.init.constant_(self.gain, self.gain_init)
+        # always reset the parent conv parameters as well
+        super().reset_parameters()
+        # the explicit call at the end of __init__ initializes gain once it exists

     def forward(self, x):
         if self.same_pad:
             x = pad_same(x, self.kernel_size, self.stride, self.dilation)
         weight = F.batch_norm(
-            self.weight.reshape(1, self.out_channels, -1), None, None,
+            self.weight.reshape(1, self.out_channels, -1),
+            None,  # running_mean
+            None,  # running_var
             weight=(self.gain * self.scale).view(-1),
-            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+            training=True,
+            momentum=0.,
+            eps=self.eps,
+        ).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
diff --git a/timm/layers/test_time_pool.py b/timm/layers/test_time_pool.py
index 5826d8c966..dc7200623a 100644
--- a/timm/layers/test_time_pool.py
+++ b/timm/layers/test_time_pool.py
@@ -15,7 +15,7 @@ class
TestTimePoolHead(nn.Module): def __init__(self, base, original_pool=7): - super(TestTimePoolHead, self).__init__() + super().__init__() self.base = base self.original_pool = original_pool base_fc = self.base.get_classifier() diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index 0903b9e595..6edb90d50f 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,14 +2,26 @@ Hacked together by / Copyright 2019, Ross Wightman """ -from typing import Callable, Dict, Optional, Type +from typing import Callable, Dict, Optional, Type, Union import torch import torch.nn as nn from torch.nn import functional as F -from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\ - ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d +from timm.layers import ( + create_conv2d, + DropPath, + make_divisible, + create_act_layer, + create_aa, + to_2tuple, + LayerType, + ConvNormAct, + get_norm_act_layer, + MultiQueryAttention2d, + Attention2d, + LayerScale2d, +) __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', @@ -49,15 +61,18 @@ def __init__( gate_layer: LayerType = nn.Sigmoid, force_act_layer: Optional[LayerType] = None, rd_round_fn: Optional[Callable] = None, + device=None, + dtype=None, ): - super(SqueezeExcite, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() if rd_channels is None: rd_round_fn = rd_round_fn or round rd_channels = rd_round_fn(in_chs * rd_ratio) act_layer = force_act_layer or act_layer - self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True) + self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True, **dd) self.act1 = create_act_layer(act_layer, inplace=True) - self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True) + self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True, **dd) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -79,25 +94,34 @@ def __init__( stride: int = 1, dilation: int = 1, group_size: int = 0, - pad_type: str = '', + pad_type: Union[int, str] = '', skip: bool = False, - act_layer: LayerType = nn.ReLU, + act_layer: Optional[LayerType] = nn.ReLU, norm_layer: LayerType = nn.BatchNorm2d, aa_layer: Optional[LayerType] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(ConvBnAct, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) groups = num_groups(group_size, in_chs) self.has_skip = skip and stride == 1 and in_chs == out_chs use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation self.conv = create_conv2d( - in_chs, out_chs, kernel_size, + in_chs, + out_chs, + kernel_size, stride=1 if use_aa else stride, - dilation=dilation, groups=groups, padding=pad_type) - self.bn1 = norm_act_layer(out_chs, inplace=True) - self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) + dilation=dilation, + groups=groups, + padding=pad_type, + **dd, + ) + self.bn1 = norm_act_layer(out_chs, inplace=True, **dd) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): @@ -139,8 +163,11 @@ def __init__( aa_layer: Optional[LayerType] = None, se_layer: Optional[ModuleType] = None, drop_path_rate: float = 0., + 
device=None, + dtype=None, ): - super(DepthwiseSeparableConv, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv @@ -149,8 +176,8 @@ def __init__( # Space to depth if s2d == 1: sd_chs = int(in_chs * 4) - self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') - self.bn_s2d = norm_act_layer(sd_chs) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same', **dd) + self.bn_s2d = norm_act_layer(sd_chs, **dd) dw_kernel_size = (dw_kernel_size + 1) // 2 dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs @@ -163,17 +190,23 @@ def __init__( groups = num_groups(group_size, in_chs) self.conv_dw = create_conv2d( - in_chs, in_chs, dw_kernel_size, + in_chs, + in_chs, + dw_kernel_size, stride=1 if use_aa else stride, - dilation=dilation, padding=dw_pad_type, groups=groups) - self.bn1 = norm_act_layer(in_chs, inplace=True) - self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) + dilation=dilation, + padding=dw_pad_type, + groups=groups, + **dd, + ) + self.bn1 = norm_act_layer(in_chs, inplace=True, **dd) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa, **dd) # Squeeze-and-excitation - self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() + self.se = se_layer(in_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity() - self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act) + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type, **dd) + self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): @@ -228,8 +261,11 @@ def __init__( se_layer: Optional[ModuleType] = None, conv_kwargs: Optional[Dict] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(InvertedResidual, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) conv_kwargs = conv_kwargs or {} self.has_skip = (in_chs == out_chs and stride == 1) and not noskip @@ -238,8 +274,8 @@ def __init__( # Space to depth if s2d == 1: sd_chs = int(in_chs * 4) - self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') - self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same', **dd) + self.bn_s2d = norm_act_layer(sd_chs, **dd) dw_kernel_size = (dw_kernel_size + 1) // 2 dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs @@ -253,23 +289,30 @@ def __init__( groups = num_groups(group_size, mid_chs) # Point-wise expansion - self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) - self.bn1 = norm_act_layer(mid_chs, inplace=True) + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs, **dd) + self.bn1 = norm_act_layer(mid_chs, inplace=True, **dd) # Depth-wise convolution self.conv_dw = create_conv2d( - mid_chs, mid_chs, dw_kernel_size, + mid_chs, + mid_chs, + dw_kernel_size, stride=1 if use_aa else stride, - 
dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) - self.bn2 = norm_act_layer(mid_chs, inplace=True) - self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) + dilation=dilation, + groups=groups, + padding=dw_pad_type, + **conv_kwargs, + **dd, + ) + self.bn2 = norm_act_layer(mid_chs, inplace=True, **dd) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa, **dd) # Squeeze-and-excitation - self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity() # Point-wise linear projection - self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) - self.bn3 = norm_act_layer(out_chs, apply_act=False) + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs, **dd) + self.bn3 = norm_act_layer(out_chs, apply_act=False, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): @@ -296,17 +339,6 @@ def forward(self, x): return x -class LayerScale2d(nn.Module): - def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - class UniversalInvertedResidual(nn.Module): """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB) @@ -334,8 +366,11 @@ def __init__( conv_kwargs: Optional[Dict] = None, drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = 1e-5, + device=None, + dtype=None, ): - super(UniversalInvertedResidual, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_kwargs = conv_kwargs or {} self.has_skip = (in_chs == out_chs and stride == 1) and not noskip if stride > 1: @@ -356,6 +391,7 @@ def __init__( norm_layer=norm_layer, aa_layer=aa_layer, **conv_kwargs, + **dd, ) else: self.dw_start = nn.Identity() @@ -368,6 +404,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, **conv_kwargs, + **dd, ) # Middle depth-wise convolution @@ -383,13 +420,14 @@ def __init__( norm_layer=norm_layer, aa_layer=aa_layer, **conv_kwargs, + **dd, ) else: # keeping mid as identity so it can be hooked more easily for features self.dw_mid = nn.Identity() # Squeeze-and-excitation - self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity() # Point-wise linear projection self.pw_proj = ConvNormAct( @@ -399,6 +437,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, **conv_kwargs, + **dd, ) if dw_kernel_size_end: @@ -416,12 +455,13 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, **conv_kwargs, + **dd, ) else: self.dw_end = nn.Identity() if layer_scale_init_value is not None: - self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value, **dd) else: self.layer_scale = nn.Identity() self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() @@ -478,8 +518,11 @@ def __init__( layer_scale_init_value: Optional[float] = 1e-5, use_bias: bool = False, use_cpe: bool = False, + device=None, + dtype=None, ): - super(MobileAttention, self).__init__() + dd = {'device': device, 'dtype': 
dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.query_strides = to_2tuple(query_strides) @@ -498,11 +541,12 @@ def __init__( dilation=dilation, depthwise=True, bias=True, + **dd, ) else: self.conv_cpe_dw = None - self.norm = norm_act_layer(in_chs, apply_act=False) + self.norm = norm_act_layer(in_chs, apply_act=False, **dd) if num_heads is None: assert in_chs % key_dim == 0 @@ -524,6 +568,7 @@ def __init__( proj_drop=proj_drop, norm_layer=norm_layer, # use_bias=use_bias, # why not here if used w/ mhsa? + **dd, ) else: self.attn = Attention2d( @@ -533,10 +578,11 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, bias=use_bias, + **dd, ) if layer_scale_init_value is not None: - self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value, **dd) else: self.layer_scale = nn.Identity() @@ -585,11 +631,13 @@ def __init__( se_layer: Optional[ModuleType] = None, num_experts: int = 0, drop_path_rate: float = 0., + device=None, + dtype=None, ): - + dd = {'device': device, 'dtype': dtype} self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) - super(CondConvResidual, self).__init__( + super().__init__( in_chs, out_chs, dw_kernel_size=dw_kernel_size, @@ -607,8 +655,9 @@ def __init__( se_layer=se_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate, + **dd, ) - self.routing_fn = nn.Linear(in_chs, self.num_experts) + self.routing_fn = nn.Linear(in_chs, self.num_experts, **dd) def forward(self, x): shortcut = x @@ -656,8 +705,11 @@ def __init__( aa_layer: Optional[LayerType] = None, se_layer: Optional[ModuleType] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(EdgeResidual, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) if force_in_chs > 0: mid_chs = make_divisible(force_in_chs * exp_ratio) @@ -669,19 +721,25 @@ def __init__( # Expansion convolution self.conv_exp = create_conv2d( - in_chs, mid_chs, exp_kernel_size, + in_chs, + mid_chs, + exp_kernel_size, stride=1 if use_aa else stride, - dilation=dilation, groups=groups, padding=pad_type) - self.bn1 = norm_act_layer(mid_chs, inplace=True) + dilation=dilation, + groups=groups, + padding=pad_type, + **dd, + ) + self.bn1 = norm_act_layer(mid_chs, inplace=True, **dd) - self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa, **dd) # Squeeze-and-excitation - self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + self.se = se_layer(mid_chs, act_layer=act_layer, **dd) if se_layer else nn.Identity() # Point-wise linear projection - self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_act_layer(out_chs, apply_act=False) + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **dd) + self.bn2 = norm_act_layer(out_chs, apply_act=False, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 24e47e8477..c0a1ea42ea 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -335,6 +335,8 @@ def __init__( drop_path_rate: float = 0., 
layer_scale_init_value: Optional[float] = None, feature_location: str = '', + device=None, + dtype=None, ): self.output_stride = output_stride self.pad_type = pad_type @@ -357,6 +359,7 @@ def __init__( feature_location = 'expansion' self.feature_location = feature_location assert feature_location in ('bottleneck', 'expansion', '') + self.dd = {'device': device, 'dtype': dtype} # device/dtype factory kwargs self.verbose = _DEBUG_BUILDER # state updated during build, consumed by model @@ -398,6 +401,8 @@ def _make_block(self, ba, block_idx, block_count): else: ba['se_layer'] = self.se_layer + ba.update(self.dd) # device/type factory kwargs + if bt == 'ir': _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba) diff --git a/timm/models/_features.py b/timm/models/_features.py index fec947af74..3814869175 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -260,7 +260,7 @@ def __init__( first element e.g. `x[0]` flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules) """ - super(FeatureDictNet, self).__init__() + super().__init__() self.feature_info = _get_feature_info(model, out_indices) self.output_fmt = Format(output_fmt) self.concat = feature_concat diff --git a/timm/models/_prune.py b/timm/models/_prune.py index 370b911f46..80ccc03a4d 100644 --- a/timm/models/_prune.py +++ b/timm/models/_prune.py @@ -62,6 +62,11 @@ def adapt_model_from_string(parent_module, model_string): if shape[0] != '': state_dict[key] = [int(i) for i in shape] + # Extract device and dtype from the parent module + device = next(parent_module.parameters()).device + dtype = next(parent_module.parameters()).dtype + dd = {'device': device, 'dtype': dtype} + new_module = deepcopy(parent_module) for n, m in parent_module.named_modules(): old_module = extract_layer(parent_module, n) @@ -78,27 +83,48 @@ def adapt_model_from_string(parent_module, model_string): in_channels = out_channels g = in_channels new_conv = conv( - in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, - bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, - groups=g, stride=old_module.stride) + in_channels=in_channels, + out_channels=out_channels, + kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, + padding=old_module.padding, + dilation=old_module.dilation, + groups=g, + stride=old_module.stride, + **dd, + ) set_layer(new_module, n, new_conv) elif isinstance(old_module, BatchNormAct2d): new_bn = BatchNormAct2d( - state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) + state_dict[n + '.weight'][0], + eps=old_module.eps, + momentum=old_module.momentum, + affine=old_module.affine, + track_running_stats=True, + **dd, + ) new_bn.drop = old_module.drop new_bn.act = old_module.act set_layer(new_module, n, new_bn) elif isinstance(old_module, nn.BatchNorm2d): new_bn = nn.BatchNorm2d( - num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, - affine=old_module.affine, track_running_stats=True) + num_features=state_dict[n + '.weight'][0], + eps=old_module.eps, + momentum=old_module.momentum, + affine=old_module.affine, + track_running_stats=True, + **dd, + ) set_layer(new_module, n, new_bn) elif isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is 
actually the FC classifier layer and not a diff Linear layer? num_features = state_dict[n + '.weight'][1] new_fc = Linear( - in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + in_features=num_features, + out_features=old_module.out_features, + bias=old_module.bias is not None, + **dd, + ) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features: diff --git a/timm/models/beit.py b/timm/models/beit.py index 04ef1ffa75..6c241edf35 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -39,15 +39,27 @@ # --------------------------------------------------------' import math -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, calculate_drop_path_rates, trunc_normal_, use_fused_attn -from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid +from timm.layers import ( + PatchEmbed, + Mlp, + SwiGLU, + LayerNorm, + DropPath, + calculate_drop_path_rates, + trunc_normal_, + use_fused_attn, + resample_patch_embed, + resample_abs_pos_embed, + resize_rel_pos_bias_table, + ndgrid, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -57,7 +69,7 @@ __all__ = ['Beit'] -def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: +def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor: """Generate relative position index for window-based attention. Creates a lookup table for relative position indices between all pairs of positions @@ -74,14 +86,17 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window window_area = window_size[0] * window_size[1] - coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1]))) # 2, Wh, Ww + coords = torch.stack(ndgrid( + torch.arange(window_size[0], device=device, dtype=torch.long), + torch.arange(window_size[1], device=device, dtype=torch.long), + )) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 @@ -107,6 +122,8 @@ def __init__( proj_drop: float = 0., window_size: Optional[Tuple[int, int]] = None, attn_head_dim: Optional[int] = None, + device=None, + dtype=None, ): """Initialize attention module. 
@@ -120,6 +137,7 @@ def __init__( window_size: Window size for relative position bias. If None, no relative position bias. attn_head_dim: Dimension per attention head. If None, uses dim // num_heads. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -130,11 +148,11 @@ def __init__( self.fused_attn = use_fused_attn() self.qkv_bias_separate = qkv_bias_separate - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd) if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.q_bias = nn.Parameter(torch.zeros(all_head_dim, **dd)) + self.register_buffer('k_bias', torch.zeros(all_head_dim, **dd), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim, **dd)) else: self.q_bias = None self.k_bias = None @@ -144,15 +162,19 @@ def __init__( self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( - torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False) + torch.zeros(self.num_relative_distance, num_heads, **dd)) # 2*Wh-1 * 2*Ww-1, nH + self.register_buffer( + "relative_position_index", + gen_relative_position_index(window_size, device=device), + persistent=False, + ) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(all_head_dim, dim) + self.proj = nn.Linear(all_head_dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def _get_rel_pos_bias(self) -> torch.Tensor: @@ -245,10 +267,12 @@ def __init__( attn_drop: float = 0., drop_path: float = 0., init_values: Optional[float] = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm, window_size: Optional[Tuple[int, int]] = None, attn_head_dim: Optional[int] = None, + device=None, + dtype=None, ): """Initialize transformer block. @@ -268,8 +292,9 @@ def __init__( window_size: Window size for relative position bias in attention. attn_head_dim: Dimension per attention head. """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, @@ -278,17 +303,19 @@ def __init__( proj_drop=proj_drop, window_size=window_size, attn_head_dim=attn_head_dim, + **dd, ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) if swiglu_mlp: self.mlp = SwiGLU( in_features=dim, hidden_features=int(dim * mlp_ratio), norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + **dd, ) else: self.mlp = Mlp( @@ -297,12 +324,13 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() if init_values: - self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) - self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) else: self.gamma_1, self.gamma_2 = None, None @@ -332,18 +360,19 @@ class RelativePositionBias(nn.Module): within a window, including special handling for cls token. """ - def __init__(self, window_size: Tuple[int, int], num_heads: int): + def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None): """Initialize relative position bias module. Args: window_size: Height and width of the attention window. num_heads: Number of attention heads. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.window_size = window_size self.window_area = window_size[0] * window_size[1] num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 - self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads, **dd)) # trunc_normal_(self.relative_position_bias_table, std=.02) self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) @@ -385,12 +414,14 @@ def __init__( proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., - norm_layer: Callable = LayerNorm, + norm_layer: Type[nn.Module] = LayerNorm, init_values: Optional[float] = None, use_abs_pos_emb: bool = True, use_rel_pos_bias: bool = False, use_shared_rel_pos_bias: bool = False, head_init_scale: float = 0.001, + device=None, + dtype=None, ): """Initialize BEiT model. @@ -419,6 +450,7 @@ def __init__( use_shared_rel_pos_bias: If True, share relative position bias across layers. head_init_scale: Scale factor for head initialization. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_classes = num_classes self.global_pool = global_pool @@ -431,19 +463,21 @@ def __init__( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + **dd, ) num_patches = self.patch_embed.num_patches r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None self.pos_drop = nn.Dropout(p=pos_drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBias( window_size=self.patch_embed.grid_size, num_heads=num_heads, + **dd, ) else: self.rel_pos_bias = None @@ -463,16 +497,17 @@ def __init__( norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None, + **dd, ) for i in range(depth)]) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] use_fc_norm = self.global_pool == 'avg' - self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd) + self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) if self.pos_embed is not None: diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index e90f2cc9b6..60a443d3d8 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -31,16 +31,29 @@ import math from dataclasses import dataclass, field, replace from functools import partial -from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence, Type import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import ( - ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a, - AttentionPool2d, RotAttentionPool2d, DropPath, calculate_drop_path_rates, AvgPool2dSame, - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, + ClassifierHead, + NormMlpClassifierHead, + ConvNormAct, + BatchNormAct2d, + EvoNorm2dS0a, + AttentionPool2d, + RotAttentionPool2d, + DropPath, + calculate_drop_path_rates, + AvgPool2dSame, + create_conv2d, + get_act_layer, + get_norm_act_layer, + get_attn, + make_divisible, + to_2tuple, ) from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -232,11 +245,11 @@ def num_groups(group_size: Optional[int], channels: int) -> int: @dataclass class LayerFn: """Container for layer factory functions.""" - conv_norm_act: Callable = ConvNormAct - norm_act: Callable = BatchNormAct2d - act: Callable = nn.ReLU - attn: Optional[Callable] = None - self_attn: Optional[Callable] = None + conv_norm_act: Type[nn.Module] = ConvNormAct + norm_act: Type[nn.Module] 
= BatchNormAct2d + act: Type[nn.Module] = nn.ReLU + attn: Optional[Type[nn.Module]] = None + self_attn: Optional[Type[nn.Module]] = None class DownsampleAvg(nn.Module): @@ -253,6 +266,8 @@ def __init__( dilation: int = 1, apply_act: bool = False, layers: Optional[LayerFn] = None, + device=None, + dtype=None, ): """Initialize DownsampleAvg. @@ -264,7 +279,8 @@ def __init__( apply_act: Whether to apply activation. layers: Layer factory functions. """ - super(DownsampleAvg, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: @@ -272,7 +288,7 @@ def __init__( self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) else: self.pool = nn.Identity() - self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act) + self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -340,24 +356,39 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(BasicBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, + downsample, + in_chs, + out_chs, + stride=stride, + dilation=dilation, + apply_act=False, + layers=layers, + **dd, ) - self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) + self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], **dd) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( - mid_chs, out_chs, kernel_size, - dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False, + mid_chs, + out_chs, + kernel_size, + dilation=dilation[1], + groups=groups, + drop_layer=drop_block, + apply_act=False, + **dd, ) - self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -401,30 +432,51 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(BottleneckBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, + downsample, + in_chs, + out_chs, + stride=stride, + dilation=dilation, + apply_act=False, + layers=layers, + **dd, ) - self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1, **dd) self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, + mid_chs, + mid_chs, + kernel_size, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + **dd, ) if extra_conv: self.conv2b_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups) + mid_chs, + mid_chs, + kernel_size, + dilation=dilation[1], + groups=groups, + **dd, + ) else: self.conv2b_kxk = nn.Identity() - self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) - self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) - self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs, **dd) + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False, **dd) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -475,24 +527,40 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(DarkBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, + downsample, + in_chs, + out_chs, + stride=stride, + dilation=dilation, + apply_act=False, + layers=layers, + **dd, ) - self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1, **dd) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs, **dd) self.conv2_kxk = layers.conv_norm_act( - mid_chs, out_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, + mid_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + apply_act=False, + **dd, ) - self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -540,23 +608,38 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(EdgeBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) groups = num_groups(group_size, mid_chs) self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, + downsample, + in_chs, + out_chs, + stride=stride, + dilation=dilation, + apply_act=False, + layers=layers, + **dd, ) self.conv1_kxk = layers.conv_norm_act( - in_chs, mid_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, + in_chs, + mid_chs, + kernel_size, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + **dd, ) - self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) - self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) - self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs, **dd) + self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False, **dd) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -598,9 +681,12 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., - inference_mode: bool = False + inference_mode: bool = False, + device=None, + dtype=None, ): - super(RepVggBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.groups = groups = num_groups(group_size, in_chs) layers = layers or LayerFn() @@ -613,19 +699,35 @@ def __init__( dilation=dilation, groups=groups, bias=True, + **dd, ) else: self.reparam_conv = None use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] - self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.identity = layers.norm_act(out_chs, apply_act=False, **dd) if use_ident else None self.conv_kxk = layers.conv_norm_act( - in_chs, out_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False, + in_chs, + out_chs, + kernel_size, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + apply_act=False, + **dd, + ) + self.conv_1x1 = layers.conv_norm_act( + in_chs, + out_chs, + 1, + stride=stride, + groups=groups, + apply_act=False, + **dd, ) - self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs, **dd) self.act = layers.act(inplace=True) def init_weights(self, zero_init_last: bool = False): @@ -767,10 +869,13 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ) -> None: """ Construct a MobileOneBlock module. """ - super(MobileOneBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_conv_branches = num_conv_branches self.groups = groups = num_groups(group_size, in_chs) layers = layers or LayerFn() @@ -783,31 +888,45 @@ def __init__( stride=stride, dilation=dilation, groups=groups, - bias=True) + bias=True, + **dd, + ) else: self.reparam_conv = None # Re-parameterizable skip connection use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] - self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.identity = layers.norm_act(out_chs, apply_act=False, **dd) if use_ident else None # Re-parameterizable conv branches convs = [] for _ in range(self.num_conv_branches): convs.append(layers.conv_norm_act( - in_chs, out_chs, kernel_size=kernel_size, - stride=stride, groups=groups, apply_act=False)) + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + groups=groups, + apply_act=False, + **dd, + )) self.conv_kxk = nn.ModuleList(convs) # Re-parameterizable scale branch self.conv_scale = None if kernel_size > 1: self.conv_scale = layers.conv_norm_act( - in_chs, out_chs, kernel_size=1, - stride=stride, groups=groups, apply_act=False) + in_chs, + out_chs, + kernel_size=1, + stride=stride, + groups=groups, + apply_act=False, + **dd, + ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
and use_ident else nn.Identity() - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs, **dd) self.act = layers.act(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -953,31 +1072,46 @@ def __init__( layers: LayerFn = None, drop_block: Callable = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(SelfAttnBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() assert layers is not None mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, + downsample, + in_chs, + out_chs, + stride=stride, + dilation=dilation, + apply_act=False, + layers=layers, + **dd, ) - self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1, **dd) if extra_conv: self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, - stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, + mid_chs, + mid_chs, + kernel_size, + stride=stride, + dilation=dilation[0], + groups=groups, + drop_layer=drop_block, + **dd, ) stride = 1 # striding done via conv if enabled else: self.conv2_kxk = nn.Identity() opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) # FIXME need to dilate self attn to have dilated network support, moop moop - self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) - self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() - self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs, **dd) + self.post_attn = layers.norm_act(mid_chs, **dd) if post_attn_na else nn.Identity() + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -1035,7 +1169,10 @@ def __init__( num_act: Optional[int] = None, chs_decay: float = 0.5, layers: LayerFn = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert stride in (2, 4) layers = layers or LayerFn() @@ -1066,7 +1203,7 @@ def __init__( if i > 0 and s > 1: last_feat_idx = i - 1 self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0)) - self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) + self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s, **dd)) prev_chs = ch curr_stride *= s prev_feat = conv_name @@ -1107,38 +1244,41 @@ def create_byob_stem( pool_type: str = '', feat_prefix: str = 'stem', layers: LayerFn = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} layers = layers or LayerFn() assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3') if 'quad' in stem_type: # based on NFNet stem, stack of 4 3x3 convs num_act = 2 if 'quad2' in stem_type else None - stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers) + stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers, **dd) elif 'tiered' in stem_type: # 3x3 stack of 3 convs as in my ResNet-T - stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers) + stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers, **dd) elif 'deep' in stem_type: # 3x3 stack of 3 convs as in ResNet-D - stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers) + stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers, **dd) elif 'rep' in stem_type: - stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers) + stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers, **dd) elif 'one' in stem_type: - stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers) + stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers, **dd) elif '7x7' in stem_type: # 7x7 stem conv as in ResNet if pool_type: - stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers) + stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers, **dd) else: - stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2) + stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2, **dd) else: if isinstance(out_chs, (tuple, list)): - stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers) + stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers, **dd) else: # 3x3 stem conv as in RegNet is the default if pool_type: - stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers) + stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers, **dd) else: - stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2) + stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2, **dd) if isinstance(stem, Stem): feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info] @@ -1207,6 +1347,8 @@ def create_byob_stages( feat_size: Optional[int] = None, layers: Optional[LayerFn] = None, block_kwargs_fn: Optional[Callable] = update_block_kwargs, + device=None, + dtype=None, ): layers = layers or LayerFn() feature_info = [] @@ -1244,6 +1386,8 @@ def create_byob_stages( 
downsample=cfg.downsample, drop_path_rate=dpr[stage_idx][block_idx], layers=layers, + device=device, + dtype=dtype, ) if block_cfg.type in ('self_attn',): # add feat_size arg for blocks that support/need it @@ -1295,6 +1439,8 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., zero_init_last: bool = True, + device=None, + dtype=None, **kwargs, ): """ @@ -1311,6 +1457,7 @@ def __init__( **kwargs: Extra kwargs overlayed onto cfg. """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False @@ -1333,6 +1480,7 @@ def __init__( stem_type=cfg.stem_type, pool_type=cfg.stem_pool, layers=stem_layers, + **dd, ) self.feature_info.extend(stem_feat[:-1]) feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) @@ -1344,6 +1492,7 @@ def __init__( stem_feat[-1], layers=stage_layers, feat_size=feat_size, + **dd, ) self.feature_info.extend(stage_feat[:-1]) reduction = stage_feat[-1]['reduction'] @@ -1351,7 +1500,7 @@ def __init__( prev_chs = stage_feat[-1]['num_chs'] if cfg.num_features: self.num_features = int(round(cfg.width_factor * cfg.num_features)) - self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1) + self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1, **dd) else: self.num_features = prev_chs self.final_conv = nn.Identity() @@ -1372,6 +1521,7 @@ def __init__( norm_layer=cfg.norm_layer, act_layer=cfg.act_layer, drop_rate=self.drop_rate, + **dd, ) self.head_hidden_size = self.head.hidden_size elif cfg.head_type == 'attn_abs': @@ -1386,6 +1536,7 @@ def __init__( pool_type=global_pool, drop_rate=self.drop_rate, qkv_separate=True, + **dd, ) self.head_hidden_size = self.head.embed_dim elif cfg.head_type == 'attn_rot': @@ -1400,6 +1551,7 @@ def __init__( pool_type=global_pool, drop_rate=self.drop_rate, qkv_separate=True, + **dd, ) self.head_hidden_size = self.head.embed_dim else: @@ -1411,6 +1563,7 @@ def __init__( num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) self.global_pool = global_pool diff --git a/timm/models/cait.py b/timm/models/cait.py index 2c500ec3dd..b484ef4d8c 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -9,7 +9,7 @@ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
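In the byobnet.py changes above, the same kwarg pair is also threaded through the free-standing builder helpers (stem and stage creation) and into the classifier heads, so the factory kwargs reach every leaf layer however deep the construction goes. A rough sketch of that flow under the same assumptions; make_stem and TinyNet are hypothetical stand-ins, not timm code:

import torch.nn as nn

def make_stem(in_chs: int, out_chs: int, device=None, dtype=None) -> nn.Sequential:
    # free-standing builders accept the same kwargs and hand them to the layers they create
    dd = {'device': device, 'dtype': dtype}
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs, 3, stride=2, padding=1, bias=False, **dd),
        nn.BatchNorm2d(out_chs, **dd),
        nn.ReLU(inplace=True),
    )

class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 10, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        self.stem = make_stem(3, 32, **dd)            # builder receives dd ...
        self.head = nn.Linear(32, num_classes, **dd)  # ... and so does every leaf layer

model = TinyNet(device='meta')  # the whole hierarchy materializes on the chosen device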
from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -29,18 +29,28 @@ class ClassAttn(nn.Module): # with slight modifications to do CA fused_attn: torch.jit.Final[bool] - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.k = nn.Linear(dim, dim, bias=qkv_bias) - self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -73,39 +83,44 @@ class LayerScaleBlockClassAttn(nn.Module): # with slight modifications to add CA and LayerScale def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - attn_block=ClassAttn, - mlp_block=Mlp, - init_values=1e-4, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + attn_block: Type[nn.Module] = ClassAttn, + mlp_block: Type[nn.Module] = Mlp, + init_values: float = 1e-4, + device=None, + dtype=None, ): super().__init__() - self.norm1 = norm_layer(dim) + dd = {'device': device, 'dtype': dtype} + self.norm1 = norm_layer(dim, **dd) self.attn = attn_block( dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = mlp_block( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) - self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) - self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) def forward(self, x, x_cls): u = torch.cat((x_cls, x), dim=1) @@ -117,8 +132,18 @@ def forward(self, x, x_cls): class TalkingHeadAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_heads = num_heads @@ -126,13 +151,13 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.) 
self.scale = head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) - self.proj_l = nn.Linear(num_heads, num_heads) - self.proj_w = nn.Linear(num_heads, num_heads) + self.proj_l = nn.Linear(num_heads, num_heads, **dd) + self.proj_w = nn.Linear(num_heads, num_heads, **dd) self.proj_drop = nn.Dropout(proj_drop) @@ -161,39 +186,44 @@ class LayerScaleBlock(nn.Module): # with slight modifications to add layerScale def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - attn_block=TalkingHeadAttn, - mlp_block=Mlp, - init_values=1e-4, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + attn_block: Type[nn.Module] = TalkingHeadAttn, + mlp_block: Type[nn.Module] = Mlp, + init_values: float = 1e-4, + device=None, + dtype=None, ): super().__init__() - self.norm1 = norm_layer(dim) + dd = {'device': device, 'dtype': dtype} + self.norm1 = norm_layer(dim, **dd) self.attn = attn_block( dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = mlp_block( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) - self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) - self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) def forward(self, x): x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) @@ -206,35 +236,38 @@ class Cait(nn.Module): # with slight modifications to adapt to our cait models def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - block_layers=LayerScaleBlock, - block_layers_token=LayerScaleBlockClassAttn, - patch_layer=PatchEmbed, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - attn_block=TalkingHeadAttn, - mlp_block=Mlp, - init_values=1e-4, - attn_block_token_only=ClassAttn, - mlp_block_token_only=Mlp, - depth_token_only=2, - mlp_ratio_token_only=4.0 + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + block_layers: Type[nn.Module] = LayerScaleBlock, + block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn, + patch_layer: Type[nn.Module] = PatchEmbed, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + attn_block: Type[nn.Module] = TalkingHeadAttn, + 
mlp_block: Type[nn.Module] = Mlp, + init_values: float = 1e-4, + attn_block_token_only: Type[nn.Module] = ClassAttn, + mlp_block_token_only: Type[nn.Module] = Mlp, + depth_token_only: int = 2, + mlp_ratio_token_only: float = 4.0, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes @@ -247,12 +280,13 @@ def __init__( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + **dd, ) num_patches = self.patch_embed.num_patches r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd)) self.pos_drop = nn.Dropout(p=pos_drop_rate) dpr = [drop_path_rate for i in range(depth)] @@ -269,6 +303,7 @@ def __init__( attn_block=attn_block, mlp_block=mlp_block, init_values=init_values, + **dd, ) for i in range(depth)]) self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)] @@ -282,12 +317,13 @@ def __init__( attn_block=attn_block_token_only, mlp_block=mlp_block_token_only, init_values=init_values, + **dd, ) for _ in range(depth_token_only)]) - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) diff --git a/timm/models/coat.py b/timm/models/coat.py index 3fa4f69666..770cf18fb7 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,7 @@ Modified from timm/models/vision_transformer.py """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -23,7 +23,14 @@ class ConvRelPosEnc(nn.Module): """ Convolutional relative position encoding. """ - def __init__(self, head_chs, num_heads, window): + def __init__( + self, + head_chs: int, + num_heads: int, + window: Union[int, dict], + device=None, + dtype=None, + ): """ Initialization. Ch: Channels per head. @@ -35,6 +42,7 @@ def __init__(self, head_chs, num_heads, window): e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) It will apply different window size to the attention head splits. """ + dd = {'device': device, 'dtype': dtype} super().__init__() if isinstance(window, int): @@ -60,6 +68,7 @@ def __init__(self, head_chs, num_heads, window): padding=(padding_size, padding_size), dilation=(dilation, dilation), groups=cur_head_split * head_chs, + **dd, ) self.conv_list.append(cur_conv) self.head_splits.append(cur_head_split) @@ -91,21 +100,24 @@ class FactorAttnConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. 
""" def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0., - shared_crpe=None, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + shared_crpe: Optional[Any] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) # Shared convolutional relative position encoding. @@ -141,9 +153,16 @@ class ConvPosEnc(nn.Module): """ Convolutional Position Encoding. Note: This module is similar to the conditional position encoding in CPVT. """ - def __init__(self, dim, k=3): - super(ConvPosEnc, self).__init__() - self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) + def __init__( + self, + dim: int, + k: int = 3, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim, **dd) def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape @@ -169,24 +188,27 @@ class SerialBlock(nn.Module): Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - shared_cpe=None, - shared_crpe=None, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + shared_cpe: Optional[Any] = None, + shared_crpe: Optional[Any] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() # Conv-Attention. self.cpe = shared_cpe - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.factoratt_crpe = FactorAttnConvRelPosEnc( dim, num_heads=num_heads, @@ -194,17 +216,19 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, shared_crpe=shared_crpe, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) def forward(self, x, size: Tuple[int, int]): @@ -226,23 +250,28 @@ class ParallelBlock(nn.Module): """ Parallel block class. """ def __init__( self, - dims, - num_heads, - mlp_ratios=[], - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - shared_crpes=None, + dims: List[int], + num_heads: int, + mlp_ratios: List[float] = None, + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + shared_crpes: Optional[List[Any]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() + if mlp_ratios is None: + mlp_ratios = [] # Conv-Attention. 
- self.norm12 = norm_layer(dims[1]) - self.norm13 = norm_layer(dims[2]) - self.norm14 = norm_layer(dims[3]) + self.norm12 = norm_layer(dims[1], **dd) + self.norm13 = norm_layer(dims[2], **dd) + self.norm14 = norm_layer(dims[3], **dd) self.factoratt_crpe2 = FactorAttnConvRelPosEnc( dims[1], num_heads=num_heads, @@ -250,6 +279,7 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, shared_crpe=shared_crpes[1], + **dd, ) self.factoratt_crpe3 = FactorAttnConvRelPosEnc( dims[2], @@ -258,6 +288,7 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, shared_crpe=shared_crpes[2], + **dd, ) self.factoratt_crpe4 = FactorAttnConvRelPosEnc( dims[3], @@ -266,13 +297,14 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, shared_crpe=shared_crpes[3], + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. - self.norm22 = norm_layer(dims[1]) - self.norm23 = norm_layer(dims[2]) - self.norm24 = norm_layer(dims[3]) + self.norm22 = norm_layer(dims[1], **dd) + self.norm23 = norm_layer(dims[2], **dd) + self.norm24 = norm_layer(dims[3], **dd) # In parallel block, we assume dimensions are the same and share the linear transformation. assert dims[1] == dims[2] == dims[3] assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] @@ -282,6 +314,7 @@ def __init__( hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) def upsample(self, x, factor: float, size: Tuple[int, int]): @@ -354,27 +387,30 @@ class CoaT(nn.Module): """ CoaT class. """ def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dims=(64, 128, 320, 512), - serial_depths=(3, 4, 6, 3), - parallel_depth=0, - num_heads=8, - mlp_ratios=(4, 4, 4, 4), - qkv_bias=True, - drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=LayerNorm, - return_interm_layers=False, - out_features=None, - crpe_window=None, - global_pool='token', + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + embed_dims: Tuple[int, int, int, int] = (64, 128, 320, 512), + serial_depths: Tuple[int, int, int, int] = (3, 4, 6, 3), + parallel_depth: int = 0, + num_heads: int = 8, + mlp_ratios: Tuple[float, float, float, float] = (4, 4, 4, 4), + qkv_bias: bool = True, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = LayerNorm, + return_interm_layers: bool = False, + out_features: Optional[List[str]] = None, + crpe_window: Optional[dict] = None, + global_pool: str = 'token', + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('token', 'avg') crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers @@ -388,34 +424,34 @@ def __init__( img_size = to_2tuple(img_size) self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, - embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) + embed_dim=embed_dims[0], norm_layer=nn.LayerNorm, **dd) self.patch_embed2 = PatchEmbed( img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], - embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) + embed_dim=embed_dims[1], norm_layer=nn.LayerNorm, **dd) self.patch_embed3 = PatchEmbed( img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], - embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) + embed_dim=embed_dims[2], norm_layer=nn.LayerNorm, **dd) self.patch_embed4 = 
PatchEmbed( img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], - embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) + embed_dim=embed_dims[3], norm_layer=nn.LayerNorm, **dd) # Class tokens. - self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) - self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1])) - self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2])) - self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) + self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0], **dd)) + self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1], **dd)) + self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2], **dd)) + self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3], **dd)) # Convolutional position encodings. - self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3) - self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3) - self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3) - self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3) + self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3, **dd) + self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3, **dd) + self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3, **dd) + self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3, **dd) # Convolutional relative position encodings. - self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window) - self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window) - self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window) - self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window) + self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window, **dd) + self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window, **dd) + self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window, **dd) + self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window, **dd) dpr = drop_path_rate skwargs = dict( @@ -435,6 +471,7 @@ def __init__( shared_cpe=self.cpe1, shared_crpe=self.crpe1, **skwargs, + **dd, ) for _ in range(serial_depths[0])] ) @@ -447,6 +484,7 @@ def __init__( shared_cpe=self.cpe2, shared_crpe=self.crpe2, **skwargs, + **dd, ) for _ in range(serial_depths[1])] ) @@ -459,6 +497,7 @@ def __init__( shared_cpe=self.cpe3, shared_crpe=self.crpe3, **skwargs, + **dd, ) for _ in range(serial_depths[2])] ) @@ -471,6 +510,7 @@ def __init__( shared_cpe=self.cpe4, shared_crpe=self.crpe4, **skwargs, + **dd, ) for _ in range(serial_depths[3])] ) @@ -484,6 +524,7 @@ def __init__( mlp_ratios=mlp_ratios, shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4), **skwargs, + **dd, ) for _ in range(parallel_depth)] ) @@ -493,23 +534,23 @@ def __init__( # Classification head(s). if not self.return_interm_layers: if self.parallel_blocks is not None: - self.norm2 = norm_layer(embed_dims[1]) - self.norm3 = norm_layer(embed_dims[2]) + self.norm2 = norm_layer(embed_dims[1], **dd) + self.norm3 = norm_layer(embed_dims[2], **dd) else: self.norm2 = self.norm3 = None - self.norm4 = norm_layer(embed_dims[3]) + self.norm4 = norm_layer(embed_dims[3], **dd) if self.parallel_depth > 0: # CoaT series: Aggregate features of last three scales for classification. 
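For the aggregation step referenced in the comment above and defined just below: the non-Lite CoaT head concatenates the class tokens of the last three scales along a new channel axis and mixes them with a kernel-size-1 Conv1d, which amounts to a learned per-channel weighted sum. A standalone sketch of the shape flow with toy tensors (the model's actual forward lives elsewhere in coat.py):

import torch
import torch.nn as nn

B, C = 2, 512                                   # batch size, shared embed dim of the last 3 scales
cls2, cls3, cls4 = (torch.randn(B, 1, C) for _ in range(3))

aggregate = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
stacked = torch.cat([cls2, cls3, cls4], dim=1)  # [B, 3, C]
fused = aggregate(stacked).squeeze(dim=1)       # [B, C]: learned per-channel weighted sum
print(fused.shape)                              # torch.Size([2, 512])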
assert embed_dims[1] == embed_dims[2] == embed_dims[3] - self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) + self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() else: # CoaT-Lite series: Use feature of last scale for classification. self.aggregate = None self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() # Initialize weights. trunc_normal_(self.cls_token1, std=.02) diff --git a/timm/models/convit.py b/timm/models/convit.py index 3dd8adfd23..ece9ad0895 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -21,7 +21,7 @@ '''These modules are adapted from those of timm, see https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ''' -from typing import Optional +from typing import Optional, Union, Type, Any import torch import torch.nn as nn @@ -40,13 +40,16 @@ class GPSA(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0., - locality_strength=1., + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + locality_strength: float = 1., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.dim = dim @@ -54,15 +57,15 @@ def __init__( self.scale = head_dim ** -0.5 self.locality_strength = locality_strength - self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.pos_proj = nn.Linear(3, num_heads) + self.proj = nn.Linear(dim, dim, **dd) + self.pos_proj = nn.Linear(3, num_heads, **dd) self.proj_drop = nn.Dropout(proj_drop) - self.gating_param = nn.Parameter(torch.ones(self.num_heads)) - self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None + self.gating_param = nn.Parameter(torch.ones(self.num_heads, **dd)) + self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3, **dd) # silly torchscript hack, won't work with None def forward(self, x): B, N, C = x.shape @@ -117,7 +120,10 @@ def local_init(self): def get_rel_indices(self, num_patches: int) -> torch.Tensor: img_size = int(num_patches ** .5) rel_indices = torch.zeros(1, num_patches, num_patches, 3) - ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + ind = ( + torch.arange(img_size, dtype=torch.float32).view(1, -1) + - torch.arange(img_size, dtype=torch.float32).view(-1, 1) + ) indx = ind.repeat(img_size, img_size) indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) indd = indx ** 2 + indy ** 2 @@ -125,26 +131,30 @@ def get_rel_indices(self, num_patches: int) -> torch.Tensor: rel_indices[:, :, :, 1] = indy.unsqueeze(0) rel_indices[:, :, :, 0] = indx.unsqueeze(0) device = self.qk.weight.device - return rel_indices.to(device) + dtype = self.qk.weight.dtype + return rel_indices.to(device=device, dtype=dtype) class 
MHSA(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0., + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def get_attention_map(self, x, return_map=False): @@ -155,12 +165,15 @@ def get_attention_map(self, x, return_map=False): attn_map = attn_map.softmax(dim=-1).mean(0) img_size = int(N ** .5) - ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + ind = ( + torch.arange(img_size, dtype=torch.float32).view(1, -1) + - torch.arange(img_size, dtype=torch.float32).view(-1, 1) + ) indx = ind.repeat(img_size, img_size) indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) indd = indx ** 2 + indy ** 2 distances = indd ** .5 - distances = distances.to(x.device) + distances = distances.to(attn_map.device, attn_map.dtype) dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N if return_map: @@ -187,20 +200,23 @@ class Block(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=LayerNorm, - use_gpsa=True, - locality_strength=1., + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm, + use_gpsa: bool = True, + locality_strength: float = 1., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.use_gpsa = use_gpsa if self.use_gpsa: self.attn = GPSA( @@ -210,6 +226,7 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, locality_strength=locality_strength, + **dd, ) else: self.attn = MHSA( @@ -218,15 +235,17 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) def forward(self, x): @@ -241,28 +260,31 @@ class ConVit(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=False, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - hybrid_backbone=None, - norm_layer=LayerNorm, - local_up_to_layer=3, - locality_strength=1., - use_pos_embed=True, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + hybrid_backbone: Optional[Any] = None, + norm_layer: Type[nn.Module] = LayerNorm, + local_up_to_layer: int = 3, + locality_strength: float = 1., + use_pos_embed: bool = True, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg', 'token') embed_dim *= num_heads self.num_classes = num_classes @@ -274,22 +296,28 @@ def __init__( if hybrid_backbone is not None: self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim, + **dd, + ) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + **dd, ) num_patches = self.patch_embed.num_patches self.num_patches = num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) self.pos_drop = nn.Dropout(p=pos_drop_rate) if self.use_pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd)) trunc_normal_(self.pos_embed, std=.02) dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule @@ -305,13 +333,14 @@ def __init__( norm_layer=norm_layer, use_gpsa=i < local_up_to_layer, locality_strength=locality_strength, + **dd, ) for i in range(depth)]) - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) # Classifier head self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() trunc_normal_(self.cls_token, std=.02) self.apply(self._init_weights) diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index de986e3baf..55978f67f8 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -1,7 +1,7 @@ """ ConvMixer """ -from typing import Optional +from typing import Optional, Type import torch import torch.nn as nn @@ -16,7 +16,7 @@ class Residual(nn.Module): - def __init__(self, fn): + def __init__(self, fn: nn.Module): super().__init__() self.fn = fn @@ -27,42 +27,45 @@ def forward(self, x): class ConvMixer(nn.Module): def __init__( 
self, - dim, - depth, - kernel_size=9, - patch_size=7, - in_chans=3, - num_classes=1000, - global_pool='avg', - drop_rate=0., - act_layer=nn.GELU, + dim: int, + depth: int, + kernel_size: int = 9, + patch_size: int = 7, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, **kwargs, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.num_features = self.head_hidden_size = dim self.grad_checkpointing = False self.stem = nn.Sequential( - nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), + nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size, **dd), act_layer(), - nn.BatchNorm2d(dim) + nn.BatchNorm2d(dim, **dd) ) self.blocks = nn.Sequential( *[nn.Sequential( Residual(nn.Sequential( - nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), + nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same", **dd), act_layer(), - nn.BatchNorm2d(dim) + nn.BatchNorm2d(dim, **dd) )), - nn.Conv2d(dim, dim, kernel_size=1), + nn.Conv2d(dim, dim, kernel_size=1, **dd), act_layer(), - nn.BatchNorm2d(dim) + nn.BatchNorm2d(dim, **dd) ) for i in range(depth)] ) self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity() @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index c236571058..c801f7c887 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -44,10 +44,27 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \ - LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple -from timm.layers import SimpleNorm2d, SimpleNorm -from timm.layers import NormMlpClassifierHead, ClassifierHead +from timm.layers import ( + trunc_normal_, + AvgPool2dSame, + DropPath, + calculate_drop_path_rates, + Mlp, + GlobalResponseNormMlp, + LayerNorm2d, + LayerNorm, + RmsNorm2d, + RmsNorm, + SimpleNorm2d, + SimpleNorm, + create_conv2d, + get_act_layer, + get_norm_layer, + make_divisible, + to_ntuple, + NormMlpClassifierHead, + ClassifierHead, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq @@ -59,7 +76,15 @@ class Downsample(nn.Module): """Downsample module for ConvNeXt.""" - def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1) -> None: + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: int = 1, + device=None, + dtype=None, + ) -> None: """Initialize Downsample module. Args: @@ -68,6 +93,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1 stride: Stride for downsampling. dilation: Dilation rate. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: @@ -77,7 +103,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1 self.pool = nn.Identity() if in_chs != out_chs: - self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd) else: self.conv = nn.Identity() @@ -115,6 +141,8 @@ def __init__( act_layer: Union[str, Callable] = 'gelu', norm_layer: Optional[Callable] = None, drop_path: float = 0., + device=None, + dtype=None, ): """ @@ -133,6 +161,7 @@ def __init__( norm_layer: Normalization layer (defaults to LN if not specified). drop_path: Stochastic depth probability. """ + dd = {'device': device, 'dtype': dtype} super().__init__() out_chs = out_chs or in_chs dilation = to_ntuple(2)(dilation) @@ -149,12 +178,18 @@ def __init__( dilation=dilation[0], depthwise=True, bias=conv_bias, + **dd, + ) + self.norm = norm_layer(out_chs, **dd) + self.mlp = mlp_layer( + out_chs, + int(mlp_ratio * out_chs), + act_layer=act_layer, + **dd, ) - self.norm = norm_layer(out_chs) - self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) - self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None + self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0]) + self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd) else: self.shortcut = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -196,7 +231,9 @@ def __init__( use_grn: bool = False, act_layer: Union[str, Callable] = 'gelu', norm_layer: Optional[Callable] = None, - norm_layer_cl: Optional[Callable] = None + norm_layer_cl: Optional[Callable] = None, + device=None, + dtype=None, ) -> None: """Initialize ConvNeXt stage. @@ -216,6 +253,7 @@ def __init__( norm_layer: Normalization layer. norm_layer_cl: Normalization layer for channels last. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False @@ -223,7 +261,7 @@ def __init__( ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used self.downsample = nn.Sequential( - norm_layer(in_chs), + norm_layer(in_chs, **dd), create_conv2d( in_chs, out_chs, @@ -232,6 +270,7 @@ def __init__( dilation=dilation[0], padding=pad, bias=conv_bias, + **dd, ), ) in_chs = out_chs @@ -253,6 +292,7 @@ def __init__( use_grn=use_grn, act_layer=act_layer, norm_layer=norm_layer if conv_mlp else norm_layer_cl, + **dd, )) in_chs = out_chs self.blocks = nn.Sequential(*stage_blocks) @@ -324,6 +364,8 @@ def __init__( norm_eps: Optional[float] = None, drop_rate: float = 0., drop_path_rate: float = 0., + device=None, + dtype=None, ): """ Args: @@ -349,6 +391,7 @@ def __init__( drop_path_rate: Stochastic depth drop rate. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps) @@ -362,17 +405,17 @@ def __init__( if stem_type == 'patch': # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), - norm_layer(dims[0]), + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd), + norm_layer(dims[0], **dd), ) stem_stride = patch_size else: mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0] self.stem = nn.Sequential(*filter(None, [ - nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias), + nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd), act_layer() if 'act' in stem_type else None, - nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias), - norm_layer(dims[0]), + nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd), + norm_layer(dims[0], **dd), ])) stem_stride = 4 @@ -406,6 +449,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, + **dd, )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 @@ -417,12 +461,13 @@ def __init__( # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) if head_norm_first: assert not head_hidden_size - self.norm_pre = norm_layer(self.num_features) + self.norm_pre = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) else: self.norm_pre = nn.Identity() @@ -434,6 +479,7 @@ def __init__( drop_rate=self.drop_rate, norm_layer=norm_layer, act_layer='gelu', + **dd, ) self.head_hidden_size = self.head.num_features named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index c2f32b754a..a19489b358 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -21,7 +21,7 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -40,7 +40,17 @@ class PatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + multi_conv: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -51,22 +61,22 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi if multi_conv: if patch_size[0] == 12: self.proj = nn.Sequential( - nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd), nn.ReLU(inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd), 
nn.ReLU(inplace=True), - nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd), ) elif patch_size[0] == 16: self.proj = nn.Sequential( - nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd), nn.ReLU(inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd), nn.ReLU(inplace=True), - nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd), ) else: - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd) def forward(self, x): B, C, H, W = x.shape @@ -82,23 +92,26 @@ def forward(self, x): class CrossAttention(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0., + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = head_dim ** -0.5 - self.wq = nn.Linear(dim, dim, bias=qkv_bias) - self.wk = nn.Linear(dim, dim, bias=qkv_bias) - self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -124,24 +137,28 @@ class CrossAttentionBlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = CrossAttention( dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -155,20 +172,22 @@ class MultiScaleBlock(nn.Module): def __init__( self, - dim, - patches, - depth, - num_heads, - mlp_ratio, - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: Tuple[int, ...], + patches: Tuple[int, ...], + depth: Tuple[int, ...], + num_heads: Tuple[int, ...], + mlp_ratio: Tuple[float, ...], + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - num_branches = len(dim) self.num_branches = num_branches # different branch could have different embedding size, the first one is the base @@ -185,6 +204,7 @@ def __init__( attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer, + **dd, )) if len(tmp) != 0: self.blocks.append(nn.Sequential(*tmp)) @@ -197,7 +217,7 @@ def __init__( if dim[d] == dim[(d + 1) % num_branches] and False: tmp = [nn.Identity()] else: - tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])] + tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)] self.projs.append(nn.Sequential(*tmp)) self.fusion = nn.ModuleList() @@ -215,6 +235,7 @@ def __init__( attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, + **dd, )) else: tmp = [] @@ -228,6 +249,7 @@ def __init__( attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, + **dd, )) self.fusion.append(nn.Sequential(*tmp)) @@ -236,8 +258,8 @@ def __init__( if dim[(d + 1) % num_branches] == dim[d] and False: tmp = [nn.Identity()] else: - tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(), - nn.Linear(dim[(d + 1) % num_branches], dim[d])] + tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(), + nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)] self.revert_projs.append(nn.Sequential(*tmp)) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: @@ -293,27 +315,30 @@ class CrossVit(nn.Module): def __init__( self, - img_size=224, - img_scale=(1.0, 1.0), - patch_size=(8, 16), - in_chans=3, - num_classes=1000, - embed_dim=(192, 384), - depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), - num_heads=(6, 12), - mlp_ratio=(2., 2., 4.), - multi_conv=False, - crop_scale=False, - qkv_bias=True, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), - global_pool='token', + img_size: int = 224, + img_scale: Tuple[float, ...] = (1.0, 1.0), + patch_size: Tuple[int, ...] = (8, 16), + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: Tuple[int, ...] = (192, 384), + depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)), + num_heads: Tuple[int, ...] = (6, 12), + mlp_ratio: Tuple[float, ...] 
= (2., 2., 4.), + multi_conv: bool = False, + crop_scale: bool = False, + qkv_bias: bool = True, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + global_pool: str = 'token', + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('token', 'avg') self.num_classes = num_classes @@ -330,8 +355,8 @@ def __init__( # hard-coded for torch jit script for i in range(self.num_branches): - setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i]))) - setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i]))) + setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd))) + setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd))) for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim): self.patch_embed.append( @@ -341,6 +366,7 @@ def __init__( in_chans=in_chans, embed_dim=d, multi_conv=multi_conv, + **dd, )) self.pos_drop = nn.Dropout(p=pos_drop_rate) @@ -363,14 +389,15 @@ def __init__( attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer, + **dd, ) dpr_ptr += curr_depth self.blocks.append(blk) - self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) + self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)]) self.head_drop = nn.Dropout(drop_rate) self.head = nn.ModuleList([ - nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() + nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) for i in range(self.num_branches): @@ -418,8 +445,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('token', 'avg') self.global_pool = global_pool + device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None + dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None + dd = {'device': device, 'dtype': dtype} self.head = nn.ModuleList([ - nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() + nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity() for i in range(self.num_branches) ]) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 5dfebc56a9..6de3b354ee 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -14,7 +14,7 @@ """ from dataclasses import dataclass, asdict, replace from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -138,31 +138,41 @@ class BottleneckBlock(nn.Module): def __init__( self, - in_chs, - out_chs, - dilation=1, - bottle_ratio=0.25, - groups=1, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_last=False, - attn_layer=None, - drop_block=None, - drop_path=0. 
+ in_chs: int, + out_chs: int, + dilation: int = 1, + bottle_ratio: float = 0.25, + groups: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_last: bool = False, + attn_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: float = 0., + device=None, + dtype=None, ): - super(BottleneckBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) attn_last = attn_layer is not None and attn_last attn_first = attn_layer is not None and not attn_last - self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs, **dd) self.conv2 = ConvNormAct( - mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, - drop_layer=drop_block, **ckwargs) - self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if attn_first else nn.Identity() - self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) - self.attn3 = attn_layer(out_chs, act_layer=act_layer) if attn_last else nn.Identity() + mid_chs, + mid_chs, + kernel_size=3, + dilation=dilation, + groups=groups, + drop_layer=drop_block, + **ckwargs, + **dd, + ) + self.attn2 = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_first else nn.Identity() + self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs, **dd) + self.attn3 = attn_layer(out_chs, act_layer=act_layer, **dd) if attn_last else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() self.act3 = create_act_layer(act_layer) @@ -189,26 +199,36 @@ class DarkBlock(nn.Module): def __init__( self, - in_chs, - out_chs, - dilation=1, - bottle_ratio=0.5, - groups=1, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - drop_block=None, - drop_path=0. + in_chs: int, + out_chs: int, + dilation: int = 1, + bottle_ratio: float = 0.5, + groups: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: float = 0., + device=None, + dtype=None, ): - super(DarkBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) - self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs) - self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity() + self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs, **dd) + self.attn = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_layer is not None else nn.Identity() self.conv2 = ConvNormAct( - mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, - drop_layer=drop_block, **ckwargs) + mid_chs, + out_chs, + kernel_size=3, + dilation=dilation, + groups=groups, + drop_layer=drop_block, + **ckwargs, + **dd, + ) self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): @@ -229,26 +249,36 @@ class EdgeBlock(nn.Module): def __init__( self, - in_chs, - out_chs, - dilation=1, - bottle_ratio=0.5, - groups=1, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - drop_block=None, - drop_path=0. 
+ in_chs: int, + out_chs: int, + dilation: int = 1, + bottle_ratio: float = 0.5, + groups: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: float = 0., + device=None, + dtype=None, ): - super(EdgeBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() mid_chs = int(round(out_chs * bottle_ratio)) ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) self.conv1 = ConvNormAct( - in_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, - drop_layer=drop_block, **ckwargs) - self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity() - self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs) + in_chs, + mid_chs, + kernel_size=3, + dilation=dilation, + groups=groups, + drop_layer=drop_block, + **ckwargs, + **dd, + ) + self.attn = attn_layer(mid_chs, act_layer=act_layer, **dd) if attn_layer is not None else nn.Identity() + self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs, **dd) self.drop_path = DropPath(drop_path) if drop_path else nn.Identity() def zero_init_last(self): @@ -267,24 +297,27 @@ class CrossStage(nn.Module): """Cross Stage.""" def __init__( self, - in_chs, - out_chs, - stride, - dilation, - depth, - block_ratio=1., - bottle_ratio=1., - expand_ratio=1., - groups=1, - first_dilation=None, - avg_down=False, - down_growth=False, - cross_linear=False, - block_dpr=None, - block_fn=BottleneckBlock, + in_chs: int, + out_chs: int, + stride: int, + dilation: int, + depth: int, + block_ratio: float = 1., + bottle_ratio: float = 1., + expand_ratio: float = 1., + groups: int = 1, + first_dilation: Optional[int] = None, + avg_down: bool = False, + down_growth: bool = False, + cross_linear: bool = False, + block_dpr: Optional[List[float]] = None, + block_fn: Type[nn.Module] = BottleneckBlock, + device=None, + dtype=None, **block_kwargs, ): - super(CrossStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) @@ -296,12 +329,20 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd) ) else: self.conv_down = ConvNormAct( - in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=aa_layer, **conv_kwargs) + in_chs, + down_chs, + kernel_size=3, + stride=stride, + dilation=first_dilation, + groups=groups, + aa_layer=aa_layer, + **conv_kwargs, + **dd, + ) prev_chs = down_chs else: self.conv_down = nn.Identity() @@ -310,7 +351,14 @@ def __init__( # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, # there is also special case for the first stage for some of the model that results in uneven split # across the two paths. I did it this way for simplicity for now. 
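# A minimal sketch (illustrative, not code from this patch) of the factory-kwargs
# pattern the diff applies throughout: accept optional `device`/`dtype`, gather them
# into a `dd` dict, and forward `**dd` to every parameter- or buffer-creating child
# so tensors are materialized directly on the requested device and dtype.
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, padding=1, **dd)  # conv weight/bias use dd
        self.norm = nn.BatchNorm2d(dim, **dd)                 # norm params and buffers too
        self.gamma = nn.Parameter(torch.ones(dim, **dd))      # as do explicit parameters

    def forward(self, x):
        return self.norm(self.conv(x)) * self.gamma.view(1, -1, 1, 1)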
- self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + self.conv_exp = ConvNormAct( + prev_chs, + exp_chs, + kernel_size=1, + apply_act=not cross_linear, + **conv_kwargs, + **dd, + ) prev_chs = exp_chs // 2 # output of conv_exp is always split in two self.blocks = nn.Sequential() @@ -323,12 +371,13 @@ def __init__( groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., **block_kwargs, + **dd, )) prev_chs = block_out_chs # transition convs - self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs) - self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs, **dd) + self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd) def forward(self, x): x = self.conv_down(x) @@ -346,24 +395,27 @@ class CrossStage3(nn.Module): """ def __init__( self, - in_chs, - out_chs, - stride, - dilation, - depth, - block_ratio=1., - bottle_ratio=1., - expand_ratio=1., - groups=1, - first_dilation=None, - avg_down=False, - down_growth=False, - cross_linear=False, - block_dpr=None, - block_fn=BottleneckBlock, + in_chs: int, + out_chs: int, + stride: int, + dilation: int, + depth: int, + block_ratio: float = 1., + bottle_ratio: float = 1., + expand_ratio: float = 1., + groups: int = 1, + first_dilation: Optional[int] = None, + avg_down: bool = False, + down_growth: bool = False, + cross_linear: bool = False, + block_dpr: Optional[List[float]] = None, + block_fn: Type[nn.Module] = BottleneckBlock, + device=None, + dtype=None, **block_kwargs, ): - super(CrossStage3, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() first_dilation = first_dilation or dilation down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels self.expand_chs = exp_chs = int(round(out_chs * expand_ratio)) @@ -375,19 +427,34 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd) ) else: self.conv_down = ConvNormAct( - in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=aa_layer, **conv_kwargs) + in_chs, + down_chs, + kernel_size=3, + stride=stride, + dilation=first_dilation, + groups=groups, + aa_layer=aa_layer, + **conv_kwargs, + **dd, + ) prev_chs = down_chs else: self.conv_down = None prev_chs = in_chs # expansion conv - self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + self.conv_exp = ConvNormAct( + prev_chs, + exp_chs, + kernel_size=1, + apply_act=not cross_linear, + **conv_kwargs, + **dd, + ) prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage self.blocks = nn.Sequential() @@ -400,11 +467,12 @@ def __init__( groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., **block_kwargs, + **dd, )) prev_chs = block_out_chs # transition convs - self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs, **dd) def forward(self, x): x = self.conv_down(x) @@ -420,21 +488,24 @@ class DarkStage(nn.Module): def __init__( self, - 
in_chs, - out_chs, - stride, - dilation, - depth, - block_ratio=1., - bottle_ratio=1., - groups=1, - first_dilation=None, - avg_down=False, - block_fn=BottleneckBlock, - block_dpr=None, + in_chs: int, + out_chs: int, + stride: int, + dilation: int, + depth: int, + block_ratio: float = 1., + bottle_ratio: float = 1., + groups: int = 1, + first_dilation: Optional[int] = None, + avg_down: bool = False, + block_fn: Type[nn.Module] = BottleneckBlock, + block_dpr: Optional[List[float]] = None, + device=None, + dtype=None, **block_kwargs, ): - super(DarkStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() first_dilation = first_dilation or dilation conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) aa_layer = block_kwargs.pop('aa_layer', None) @@ -442,12 +513,20 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs, **dd) ) else: self.conv_down = ConvNormAct( - in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, - aa_layer=aa_layer, **conv_kwargs) + in_chs, + out_chs, + kernel_size=3, + stride=stride, + dilation=first_dilation, + groups=groups, + aa_layer=aa_layer, + **conv_kwargs, + **dd, + ) prev_chs = out_chs block_out_chs = int(round(out_chs * block_ratio)) @@ -460,7 +539,8 @@ def __init__( bottle_ratio=bottle_ratio, groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., - **block_kwargs + **block_kwargs, + **dd, )) prev_chs = block_out_chs @@ -471,16 +551,19 @@ def forward(self, x): def create_csp_stem( - in_chans=3, - out_chs=32, - kernel_size=3, - stride=2, - pool='', - padding='', - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - aa_layer=None, + in_chans: int = 3, + out_chs: int = 32, + kernel_size: int = 3, + stride: int = 2, + pool: str = '', + padding: str = '', + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + aa_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} stem = nn.Sequential() feature_info = [] if not isinstance(out_chs, (tuple, list)): @@ -503,6 +586,7 @@ def create_csp_stem( padding=padding if i == 0 else '', act_layer=act_layer, norm_layer=norm_layer, + **dd, )) stem_stride *= conv_stride prev_chs = chs @@ -513,7 +597,7 @@ def create_csp_stem( feature_info.append(prev_feat) if aa_layer is not None: stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) - stem.add_module('aa', aa_layer(channels=prev_chs, stride=2)) + stem.add_module('aa', aa_layer(channels=prev_chs, stride=2, **dd)) pool_name = 'aa' else: stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) @@ -565,7 +649,10 @@ def create_csp_stages( drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} cfg_dict = asdict(cfg.stages) num_stages = len(cfg.stages.depth) cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \ @@ -605,6 +692,7 @@ def create_csp_stages( aa_layer=cfg.aa_layer, attn_layer=attn_fn, # will be passed through stage as block_kwargs **block_kwargs, + **dd, )] prev_chs = stage_args['out_chs'] prev_feat = dict(num_chs=prev_chs, reduction=net_stride, 
module=f'stages.{stage_idx}') @@ -626,13 +714,15 @@ class CspNet(nn.Module): def __init__( self, cfg: CspModelCfg, - in_chans=3, - num_classes=1000, - output_stride=32, - global_pool='avg', - drop_rate=0., - drop_path_rate=0., - zero_init_last=True, + in_chans: int = 3, + num_classes: int = 1000, + output_stride: int = 32, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0., + zero_init_last: bool = True, + device=None, + dtype=None, **kwargs, ): """ @@ -648,6 +738,7 @@ def __init__( kwargs (dict): Extra kwargs overlayed onto cfg """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) @@ -661,7 +752,7 @@ def __init__( self.feature_info = [] # Construct the stem - self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args) + self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args, **dd) self.feature_info.extend(stem_feat_info[:-1]) # Construct the stages @@ -670,6 +761,7 @@ def __init__( drop_path_rate=drop_path_rate, output_stride=output_stride, stem_feat=stem_feat_info[-1], + **dd, ) prev_chs = stage_feat_info[-1]['num_chs'] self.feature_info.extend(stage_feat_info) @@ -681,6 +773,7 @@ def __init__( num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) diff --git a/timm/models/davit.py b/timm/models/davit.py index 87d90b062f..f030a43a07 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,7 +12,7 @@ # All rights reserved. # This source code is licensed under the MIT license from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -32,8 +32,16 @@ class ConvPosEnc(nn.Module): - def __init__(self, dim: int, k: int = 3, act: bool = False): - super(ConvPosEnc, self).__init__() + def __init__( + self, + dim: int, + k: int = 3, + act: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.proj = nn.Conv2d( dim, @@ -42,6 +50,7 @@ def __init__(self, dim: int, k: int = 3, act: bool = False): stride=1, padding=k // 2, groups=dim, + **dd, ) self.act = nn.GELU() if act else nn.Identity() @@ -58,11 +67,14 @@ class Stem(nn.Module): def __init__( self, - in_chs=3, - out_chs=96, - stride=4, - norm_layer=LayerNorm2d, + in_chs: int = 3, + out_chs: int = 96, + stride: int = 4, + norm_layer: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() stride = to_2tuple(stride) self.stride = stride @@ -75,8 +87,9 @@ def __init__( kernel_size=7, stride=stride, padding=3, + **dd, ) - self.norm = norm_layer(out_chs) + self.norm = norm_layer(out_chs, **dd) def forward(self, x: Tensor): B, C, H, W = x.shape @@ -91,16 +104,19 @@ def forward(self, x: Tensor): class Downsample(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size=3, - norm_layer=LayerNorm2d, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + norm_layer: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.in_chs = in_chs self.out_chs = out_chs - self.norm = norm_layer(in_chs) + self.norm = norm_layer(in_chs, **dd) self.even_k = kernel_size % 2 == 0 self.conv = nn.Conv2d( in_chs, @@ -108,6 +124,7 @@ def __init__( 
kernel_size=kernel_size, stride=2, padding=0 if self.even_k else kernel_size // 2, + **dd, ) def forward(self, x: Tensor): @@ -124,14 +141,23 @@ def forward(self, x: Tensor): class ChannelAttentionV2(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + dynamic_scale: bool = True, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.groups = num_heads self.head_dim = dim // num_heads self.dynamic_scale = dynamic_scale - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.proj = nn.Linear(dim, dim, **dd) def forward(self, x): B, N, C = x.shape @@ -155,14 +181,22 @@ def forward(self, x): class ChannelAttention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.proj = nn.Linear(dim, dim, **dd) def forward(self, x: Tensor): B, N, C = x.shape @@ -183,37 +217,42 @@ class ChannelBlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - ffn=True, - cpe_act=False, - v2=False, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ffn: bool = True, + cpe_act: bool = False, + v2: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd) self.ffn = ffn - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) attn_layer = ChannelAttentionV2 if v2 else ChannelAttention self.attn = attn_layer( dim, num_heads=num_heads, qkv_bias=qkv_bias, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd) if self.ffn: - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() else: @@ -282,7 +321,16 @@ class WindowAttention(nn.Module): """ fused_attn: torch.jit.Final[bool] - def __init__(self, dim, window_size, num_heads, qkv_bias=True): + def __init__( + self, + dim: int, + window_size: Tuple[int, int], + num_heads: int, + qkv_bias: bool = True, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.window_size = window_size @@ -291,8 +339,8 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True): self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.proj = nn.Linear(dim, dim, **dd) self.softmax = nn.Softmax(dim=-1) @@ -330,17 +378,20 @@ class SpatialBlock(nn.Module): def __init__( self, - dim, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - ffn=True, - cpe_act=False, + dim: int, + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ffn: bool = True, + cpe_act: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.ffn = ffn @@ -348,24 +399,26 @@ def __init__( self.window_size = to_2tuple(window_size) self.mlp_ratio = mlp_ratio - self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) - self.norm1 = norm_layer(dim) + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd) + self.norm1 = norm_layer(dim, **dd) self.attn = WindowAttention( dim, self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd) if self.ffn: - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() else: @@ -416,31 +469,34 @@ def forward(self, x: Tensor): class DaVitStage(nn.Module): def __init__( self, - in_chs, - out_chs, - depth=1, - downsample=True, - attn_types=('spatial', 'channel'), - num_heads=3, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path_rates=(0, 0), - norm_layer=LayerNorm2d, - norm_layer_cl=nn.LayerNorm, - ffn=True, - cpe_act=False, - down_kernel_size=2, - named_blocks=False, - channel_attn_v2=False, + in_chs: int, + out_chs: int, + depth:int = 1, + downsample: bool = True, + attn_types: Tuple[str, ...] = ('spatial', 'channel'), + num_heads: int = 3, + window_size: int = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_path_rates: Tuple[float, ...] 
= (0, 0), + norm_layer: Type[nn.Module] = LayerNorm2d, + norm_layer_cl: Type[nn.Module] = nn.LayerNorm, + ffn: bool = True, + cpe_act: bool = False, + down_kernel_size: int = 2, + named_blocks: bool = False, + channel_attn_v2: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False # downsample embedding layer at the beginning of each stage if downsample: - self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer) + self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer, **dd) else: self.downsample = nn.Identity() @@ -467,6 +523,7 @@ def __init__( ffn=ffn, cpe_act=cpe_act, window_size=window_size, + **dd, ))) elif attn_type == 'channel': dual_attention_block.append(('channel_block', ChannelBlock( @@ -479,6 +536,7 @@ def __init__( ffn=ffn, cpe_act=cpe_act, v2=channel_attn_v2, + **dd, ))) if named_blocks: stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block))) @@ -519,29 +577,32 @@ class DaVit(nn.Module): def __init__( self, - in_chans=3, - depths=(1, 1, 3, 1), - embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), - window_size=7, - mlp_ratio=4, - qkv_bias=True, - norm_layer='layernorm2d', - norm_layer_cl='layernorm', - norm_eps=1e-5, - attn_types=('spatial', 'channel'), - ffn=True, - cpe_act=False, - down_kernel_size=2, - channel_attn_v2=False, - named_blocks=False, - drop_rate=0., - drop_path_rate=0., - num_classes=1000, - global_pool='avg', - head_norm_first=False, + in_chans: int = 3, + depths: Tuple[int, ...] = (1, 1, 3, 1), + embed_dims: Tuple[int, ...] = (96, 192, 384, 768), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + window_size: int = 7, + mlp_ratio: float = 4, + qkv_bias: bool = True, + norm_layer: str = 'layernorm2d', + norm_layer_cl: str = 'layernorm', + norm_eps: float = 1e-5, + attn_types: Tuple[str, ...] 
= ('spatial', 'channel'), + ffn: bool = True, + cpe_act: bool = False, + down_kernel_size: int = 2, + channel_attn_v2: bool = False, + named_blocks: bool = False, + drop_rate: float = 0., + drop_path_rate: float = 0., + num_classes: int = 1000, + global_pool: str = 'avg', + head_norm_first: bool = False, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} num_stages = len(embed_dims) assert num_stages == len(num_heads) == len(depths) norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) @@ -552,7 +613,7 @@ def __init__( self.grad_checkpointing = False self.feature_info = [] - self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer) + self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer, **dd) in_chs = embed_dims[0] dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) @@ -577,6 +638,7 @@ def __init__( down_kernel_size=down_kernel_size, channel_attn_v2=channel_attn_v2, named_blocks=named_blocks, + **dd, ) in_chs = out_chs stages.append(stage) @@ -588,12 +650,13 @@ def __init__( # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt # FIXME generalize this structure to ClassifierHead if head_norm_first: - self.norm_pre = norm_layer(self.num_features) + self.norm_pre = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) else: self.norm_pre = nn.Identity() @@ -603,6 +666,7 @@ def __init__( pool_type=global_pool, drop_rate=self.drop_rate, norm_layer=norm_layer, + **dd, ) self.apply(self._init_weights) diff --git a/timm/models/deit.py b/timm/models/deit.py index 0072013bf6..ed8a6dec1d 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -11,7 +11,7 @@ # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
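# A usage sketch of what the added kwargs enable (assumed example, not part of the
# patch), using the ConvMixer signature shown earlier in this diff: parameters can be
# created directly on a target device/dtype, or on the 'meta' device to defer
# allocation until weights are loaded or materialized.
import torch
from timm.models.convmixer import ConvMixer

meta_model = ConvMixer(dim=256, depth=8, num_classes=10, device='meta')
assert all(p.device.type == 'meta' for p in meta_model.parameters())

# Weights created in bf16 directly, rather than allocated in fp32 and cast afterwards.
bf16_model = ConvMixer(dim=256, depth=8, num_classes=10, dtype=torch.bfloat16)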
from functools import partial -from typing import Optional +from typing import Optional, Type import torch from torch import nn as nn @@ -36,12 +36,13 @@ def __init__(self, *args, **kwargs): weight_init = kwargs.pop('weight_init', '') super().__init__(*args, **kwargs, weight_init='skip') assert self.global_pool in ('token',) + dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)} self.num_prefix_tokens = 2 - self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, **dd)) self.pos_embed = nn.Parameter( - torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim)) - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim, **dd)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token self.init_weights(weight_init) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1a9f9887a5..a09da46d23 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -4,7 +4,7 @@ """ import re from collections import OrderedDict -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -31,9 +31,11 @@ def __init__( num_input_features: int, growth_rate: int, bn_size: int, - norm_layer: type = BatchNormAct2d, + norm_layer: Type[nn.Module] = BatchNormAct2d, drop_rate: float = 0., grad_checkpointing: bool = False, + device=None, + dtype=None, ) -> None: """Initialize DenseLayer. @@ -45,13 +47,14 @@ def __init__( drop_rate: Dropout rate. grad_checkpointing: Use gradient checkpointing. """ - super(DenseLayer, self).__init__() - self.add_module('norm1', norm_layer(num_input_features)), + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.add_module('norm1', norm_layer(num_input_features, **dd)), self.add_module('conv1', nn.Conv2d( - num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), - self.add_module('norm2', norm_layer(bn_size * growth_rate)), + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, **dd)), + self.add_module('norm2', norm_layer(bn_size * growth_rate, **dd)), self.add_module('conv2', nn.Conv2d( - bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, **dd)), self.drop_rate = float(drop_rate) self.grad_checkpointing = grad_checkpointing @@ -129,9 +132,11 @@ def __init__( num_input_features: int, bn_size: int, growth_rate: int, - norm_layer: type = BatchNormAct2d, + norm_layer: Type[nn.Module] = BatchNormAct2d, drop_rate: float = 0., grad_checkpointing: bool = False, + device=None, + dtype=None, ) -> None: """Initialize DenseBlock. @@ -144,7 +149,8 @@ def __init__( drop_rate: Dropout rate. grad_checkpointing: Use gradient checkpointing. 
""" - super(DenseBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() for i in range(num_layers): layer = DenseLayer( num_input_features + i * growth_rate, @@ -153,6 +159,7 @@ def __init__( norm_layer=norm_layer, drop_rate=drop_rate, grad_checkpointing=grad_checkpointing, + **dd, ) self.add_module('denselayer%d' % (i + 1), layer) @@ -182,8 +189,10 @@ def __init__( self, num_input_features: int, num_output_features: int, - norm_layer: type = BatchNormAct2d, - aa_layer: Optional[type] = None, + norm_layer: Type[nn.Module] = BatchNormAct2d, + aa_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ) -> None: """Initialize DenseTransition. @@ -193,12 +202,13 @@ def __init__( norm_layer: Normalization layer class. aa_layer: Anti-aliasing layer class. """ - super(DenseTransition, self).__init__() - self.add_module('norm', norm_layer(num_input_features)) + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.add_module('norm', norm_layer(num_input_features, **dd)) self.add_module('conv', nn.Conv2d( - num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, **dd)) if aa_layer is not None: - self.add_module('pool', aa_layer(num_output_features, stride=2)) + self.add_module('pool', aa_layer(num_output_features, stride=2, **dd)) else: self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) @@ -231,11 +241,13 @@ def __init__( stem_type: str = '', act_layer: str = 'relu', norm_layer: str = 'batchnorm2d', - aa_layer: Optional[type] = None, + aa_layer: Optional[Type[nn.Module]] = None, drop_rate: float = 0., proj_drop_rate: float = 0., memory_efficient: bool = False, aa_stem_only: bool = True, + device=None, + dtype=None, ) -> None: """Initialize DenseNet. @@ -255,8 +267,9 @@ def __init__( memory_efficient: If True, uses checkpointing for memory efficiency. aa_stem_only: Apply anti-aliasing only to stem. 
""" + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes - super(DenseNet, self).__init__() + super().__init__() norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) # Stem @@ -267,25 +280,25 @@ def __init__( else: stem_pool = nn.Sequential(*[ nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - aa_layer(channels=num_init_features, stride=2)]) + aa_layer(channels=num_init_features, stride=2, **dd)]) if deep_stem: stem_chs_1 = stem_chs_2 = growth_rate if 'tiered' in stem_type: stem_chs_1 = 3 * (growth_rate // 4) stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), - ('norm0', norm_layer(stem_chs_1)), - ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), - ('norm1', norm_layer(stem_chs_2)), - ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), - ('norm2', norm_layer(num_init_features)), + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False, **dd)), + ('norm0', norm_layer(stem_chs_1, **dd)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False, **dd)), + ('norm1', norm_layer(stem_chs_2, **dd)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False, **dd)), + ('norm2', norm_layer(num_init_features, **dd)), ('pool0', stem_pool), ])) else: self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), - ('norm0', norm_layer(num_init_features)), + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, **dd)), + ('norm0', norm_layer(num_init_features, **dd)), ('pool0', stem_pool), ])) self.feature_info = [ @@ -303,6 +316,7 @@ def __init__( norm_layer=norm_layer, drop_rate=proj_drop_rate, grad_checkpointing=memory_efficient, + **dd, ) module_name = f'denseblock{(i + 1)}' self.features.add_module(module_name, block) @@ -317,12 +331,13 @@ def __init__( num_output_features=num_features // 2, norm_layer=norm_layer, aa_layer=transition_aa_layer, + **dd, ) self.features.add_module(f'transition{i + 1}', trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', norm_layer(num_features)) + self.features.add_module('norm5', norm_layer(num_features, **dd)) self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] self.num_features = self.head_hidden_size = num_features @@ -332,6 +347,7 @@ def __init__( self.num_features, self.num_classes, pool_type=global_pool, + **dd, ) self.global_pool = global_pool self.head_drop = nn.Dropout(drop_rate) diff --git a/timm/models/dla.py b/timm/models/dla.py index 197060e4e6..2763686dbf 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -6,7 +6,7 @@ Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 """ import math -from typing import List, Optional +from typing import List, Optional, Tuple, Type import torch import torch.nn as nn @@ -22,17 +22,41 @@ class DlaBasic(nn.Module): """DLA Basic""" - def __init__(self, inplanes, planes, stride=1, dilation=1, **_): - super(DlaBasic, self).__init__() + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + device=None, + dtype=None, + **_, + ): + dd = {'device': device, 'dtype': dtype} + 
super().__init__() self.conv1 = nn.Conv2d( - inplanes, planes, kernel_size=3, - stride=stride, padding=dilation, bias=False, dilation=dilation) - self.bn1 = nn.BatchNorm2d(planes) + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + **dd, + ) + self.bn1 = nn.BatchNorm2d(planes, **dd) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, - stride=1, padding=dilation, bias=False, dilation=dilation) - self.bn2 = nn.BatchNorm2d(planes) + planes, + planes, + kernel_size=3, + stride=1, + padding=dilation, + bias=False, + dilation=dilation, + **dd, + ) + self.bn2 = nn.BatchNorm2d(planes, **dd) self.stride = stride def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): @@ -56,20 +80,39 @@ class DlaBottleneck(nn.Module): """DLA/DLA-X Bottleneck""" expansion = 2 - def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64): - super(DlaBottleneck, self).__init__() + def __init__( + self, + inplanes: int, + outplanes: int, + stride: int = 1, + dilation: int = 1, + cardinality: int = 1, + base_width: int = 64, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.stride = stride mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) mid_planes = mid_planes // self.expansion - self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(mid_planes) + self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(mid_planes, **dd) self.conv2 = nn.Conv2d( - mid_planes, mid_planes, kernel_size=3, - stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality) - self.bn2 = nn.BatchNorm2d(mid_planes) - self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(outplanes) + mid_planes, + mid_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + groups=cardinality, + **dd, + ) + self.bn2 = nn.BatchNorm2d(mid_planes, **dd) + self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False, **dd) + self.bn3 = nn.BatchNorm2d(outplanes, **dd) self.relu = nn.ReLU(inplace=True) def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): @@ -99,31 +142,51 @@ class DlaBottle2neck(nn.Module): """ expansion = 2 - def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4): - super(DlaBottle2neck, self).__init__() + def __init__( + self, + inplanes: int, + outplanes: int, + stride: int = 1, + dilation: int = 1, + scale: int = 4, + cardinality: int = 8, + base_width: int = 4, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.is_first = stride > 1 self.scale = scale mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) mid_planes = mid_planes // self.expansion self.width = mid_planes - self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(mid_planes * scale) + self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(mid_planes * scale, **dd) num_scale_convs = max(1, scale - 1) convs = [] bns = [] for _ in range(num_scale_convs): convs.append(nn.Conv2d( - mid_planes, mid_planes, kernel_size=3, - stride=stride, 
padding=dilation, dilation=dilation, groups=cardinality, bias=False)) - bns.append(nn.BatchNorm2d(mid_planes)) + mid_planes, + mid_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + **dd, + )) + bns.append(nn.BatchNorm2d(mid_planes, **dd)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None - self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(outplanes) + self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False, **dd) + self.bn3 = nn.BatchNorm2d(outplanes, **dd) self.relu = nn.ReLU(inplace=True) def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None): @@ -163,11 +226,27 @@ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional class DlaRoot(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, shortcut): - super(DlaRoot, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + shortcut: bool, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv = nn.Conv2d( - in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2) - self.bn = nn.BatchNorm2d(out_channels) + in_channels, + out_channels, + 1, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2, + **dd, + ) + self.bn = nn.BatchNorm2d(out_channels, **dd) self.relu = nn.ReLU(inplace=True) self.shortcut = shortcut @@ -184,27 +263,30 @@ def forward(self, x_children: List[torch.Tensor]): class DlaTree(nn.Module): def __init__( self, - levels, - block, - in_channels, - out_channels, - stride=1, - dilation=1, - cardinality=1, - base_width=64, - level_root=False, - root_dim=0, - root_kernel_size=1, - root_shortcut=False, + levels: int, + block: Type[nn.Module], + in_channels: int, + out_channels: int, + stride: int = 1, + dilation: int = 1, + cardinality: int = 1, + base_width: int = 64, + level_root: bool = False, + root_dim: int = 0, + root_kernel_size: int = 1, + root_shortcut: bool = False, + device=None, + dtype=None, ): - super(DlaTree, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() if root_dim == 0: root_dim = 2 * out_channels if level_root: root_dim += in_channels self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity() self.project = nn.Identity() - cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) + cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width, **dd) if levels == 1: self.tree1 = block(in_channels, out_channels, stride, **cargs) self.tree2 = block(out_channels, out_channels, 1, **cargs) @@ -213,9 +295,9 @@ def __init__( # used, I've moved the project layer here to avoid wasted params but old checkpoints will # need strict=False while loading. 
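# The `device=None, dtype=None` defaults are deliberately no-ops: PyTorch layers treat
# None factory kwargs the same as omitting them, so callers that never pass device or
# dtype see identical behavior. A small equivalence check (illustrative only):
import torch.nn as nn

dd = {'device': None, 'dtype': None}
a = nn.Linear(8, 8, **dd)
b = nn.Linear(8, 8)
assert a.weight.dtype == b.weight.dtype and a.weight.device == b.weight.device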
self.project = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), - nn.BatchNorm2d(out_channels)) - self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut) + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False, **dd), + nn.BatchNorm2d(out_channels, **dd)) + self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut, **dd) else: cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut)) self.tree1 = DlaTree( @@ -260,19 +342,22 @@ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional class DLA(nn.Module): def __init__( self, - levels, - channels, - output_stride=32, - num_classes=1000, - in_chans=3, - global_pool='avg', - cardinality=1, - base_width=64, - block=DlaBottle2neck, - shortcut_root=False, - drop_rate=0.0, + levels: Tuple[int, ...], + channels: Tuple[int, ...], + output_stride: int = 32, + num_classes: int = 1000, + in_chans: int = 3, + global_pool: str = 'avg', + cardinality: int = 1, + base_width: int = 64, + block: Type[nn.Module] = DlaBottle2neck, + shortcut_root: bool = False, + drop_rate: float = 0.0, + device=None, + dtype=None, ): - super(DLA, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.channels = channels self.num_classes = num_classes self.cardinality = cardinality @@ -280,13 +365,13 @@ def __init__( assert output_stride == 32 # FIXME support dilation self.base_layer = nn.Sequential( - nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), - nn.BatchNorm2d(channels[0]), + nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False, **dd), + nn.BatchNorm2d(channels[0], **dd), nn.ReLU(inplace=True), ) - self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) - self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) - cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0], **dd) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2, **dd) + cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root, **dd) self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs) self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs) @@ -307,6 +392,7 @@ def __init__( pool_type=global_pool, use_conv=True, drop_rate=drop_rate, + **dd, ) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() @@ -318,15 +404,22 @@ def __init__( m.weight.data.fill_(1) m.bias.data.zero_() - def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + def _make_conv_level(self, inplanes: int, planes: int, convs: int, stride: int = 1, dilation: int = 1, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} modules = [] for i in range(convs): modules.extend([ nn.Conv2d( - inplanes, planes, kernel_size=3, + inplanes, + planes, + kernel_size=3, stride=stride if i == 0 else 1, - padding=dilation, bias=False, dilation=dilation), - nn.BatchNorm2d(planes), + padding=dilation, + bias=False, + dilation=dilation, + **dd, + ), + nn.BatchNorm2d(planes, **dd), nn.ReLU(inplace=True)]) inplanes = planes return nn.Sequential(*modules) diff --git a/timm/models/dpn.py 
b/timm/models/dpn.py index c03e5fe1a1..759658b75d 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -8,7 +8,7 @@ """ from collections import OrderedDict from functools import partial -from typing import Tuple +from typing import Tuple, Type, Optional import torch import torch.nn as nn @@ -23,9 +23,16 @@ class CatBnAct(nn.Module): - def __init__(self, in_chs, norm_layer=BatchNormAct2d): - super(CatBnAct, self).__init__() - self.bn = norm_layer(in_chs, eps=0.001) + def __init__( + self, + in_chs: int, + norm_layer: Type[nn.Module] = BatchNormAct2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.bn = norm_layer(in_chs, eps=0.001, **dd) @torch.jit._overload_method # noqa: F811 def forward(self, x): @@ -44,10 +51,21 @@ def forward(self, x): class BnActConv2d(nn.Module): - def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d): - super(BnActConv2d, self).__init__() - self.bn = norm_layer(in_chs, eps=0.001) - self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups) + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + groups: int = 1, + norm_layer: Type[nn.Module] = BatchNormAct2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.bn = norm_layer(in_chs, eps=0.001, **dd) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups, **dd) def forward(self, x): return self.conv(self.bn(x)) @@ -56,16 +74,19 @@ def forward(self, x): class DualPathBlock(nn.Module): def __init__( self, - in_chs, - num_1x1_a, - num_3x3_b, - num_1x1_c, - inc, - groups, - block_type='normal', - b=False, + in_chs: int, + num_1x1_a: int, + num_3x3_b: int, + num_1x1_c: int, + inc: int, + groups: int, + block_type: str = 'normal', + b: bool = False, + device=None, + dtype=None, ): - super(DualPathBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_1x1_c = num_1x1_c self.inc = inc self.b = b @@ -86,20 +107,20 @@ def __init__( # Using different member names here to allow easier parameter key matching for conversion if self.key_stride == 2: self.c1x1_w_s2 = BnActConv2d( - in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2) + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2, **dd) else: self.c1x1_w_s1 = BnActConv2d( - in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1, **dd) - self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) + self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1, **dd) self.c3x3_b = BnActConv2d( - in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups) + in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups, **dd) if b: - self.c1x1_c = CatBnAct(in_chs=num_3x3_b) - self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1) - self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1) + self.c1x1_c = CatBnAct(in_chs=num_3x3_b, **dd) + self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1, **dd) + self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1, **dd) else: - self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) + self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, 
stride=1, **dd) self.c1x1_c1 = None self.c1x1_c2 = None @@ -150,23 +171,26 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: class DPN(nn.Module): def __init__( self, - k_sec=(3, 4, 20, 3), - inc_sec=(16, 32, 24, 128), - k_r=96, - groups=32, - num_classes=1000, - in_chans=3, - output_stride=32, - global_pool='avg', - small=False, - num_init_features=64, - b=False, - drop_rate=0., - norm_layer='batchnorm2d', - act_layer='relu', - fc_act_layer='elu', + k_sec: Tuple[int, ...] = (3, 4, 20, 3), + inc_sec: Tuple[int, ...] = (16, 32, 24, 128), + k_r: int = 96, + groups: int = 32, + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + global_pool: str = 'avg', + small: bool = False, + num_init_features: int = 64, + b: bool = False, + drop_rate: float = 0., + norm_layer: str = 'batchnorm2d', + act_layer: str = 'relu', + fc_act_layer: str = 'elu', + device=None, + dtype=None, ): - super(DPN, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.b = b @@ -179,7 +203,13 @@ def __init__( # conv1 blocks['conv1_1'] = ConvNormAct( - in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) + in_chans, + num_init_features, + kernel_size=3 if small else 7, + stride=2, + norm_layer=norm_layer, + **dd, + ) blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] @@ -187,10 +217,10 @@ def __init__( bw = 64 * bw_factor inc = inc_sec[0] r = (k_r * bw) // (64 * bw_factor) - blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b) + blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b, **dd) in_chs = bw + 3 * inc for i in range(2, k_sec[0] + 1): - blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd) in_chs += inc self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] @@ -198,10 +228,10 @@ def __init__( bw = 128 * bw_factor inc = inc_sec[1] r = (k_r * bw) // (64 * bw_factor) - blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd) in_chs = bw + 3 * inc for i in range(2, k_sec[1] + 1): - blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd) in_chs += inc self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] @@ -209,10 +239,10 @@ def __init__( bw = 256 * bw_factor inc = inc_sec[2] r = (k_r * bw) // (64 * bw_factor) - blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd) in_chs = bw + 3 * inc for i in range(2, k_sec[2] + 1): - blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd) in_chs += inc self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] @@ -220,21 +250,26 @@ def __init__( bw = 512 * bw_factor inc = inc_sec[3] r = (k_r * bw) // (64 * bw_factor) - blocks['conv5_1'] = DualPathBlock(in_chs, 
r, r, bw, inc, groups, 'down', b) + blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd) in_chs = bw + 3 * inc for i in range(2, k_sec[3] + 1): - blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd) in_chs += inc self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] - blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) + blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer, **dd) self.num_features = self.head_hidden_size = in_chs self.features = nn.Sequential(blocks) # Using 1x1 conv for the FC layer to allow the extra pooling scheme self.global_pool, self.classifier = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.num_features, + self.num_classes, + pool_type=global_pool, + use_conv=True, + **dd, + ) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() @torch.jit.ignore diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index c63d485d77..a83ea6ddf4 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -9,15 +9,24 @@ """ import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_tf_, DropPath, calculate_drop_path_rates, LayerNorm2d, Mlp, create_conv2d, \ - NormMlpClassifierHead, ClassifierHead +from timm.layers import ( + DropPath, + calculate_drop_path_rates, + LayerNorm2d, + Mlp, + create_conv2d, + NormMlpClassifierHead, + ClassifierHead, + trunc_normal_tf_, +) + from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_module @@ -29,9 +38,17 @@ @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method class PositionalEncodingFourier(nn.Module): - def __init__(self, hidden_dim=32, dim=768, temperature=10000): + def __init__( + self, + hidden_dim: int = 32, + dim: int = 768, + temperature: float = 10000., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd) self.scale = 2 * math.pi self.temperature = temperature self.hidden_dim = hidden_dim @@ -67,25 +84,41 @@ def forward(self, shape: Tuple[int, int, int]): class ConvBlock(nn.Module): def __init__( self, - dim, - dim_out=None, - kernel_size=7, - stride=1, - conv_bias=True, - expand_ratio=4, - ls_init_value=1e-6, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, drop_path=0., + dim: int, + dim_out: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + conv_bias: bool = True, + expand_ratio: float = 4, + ls_init_value: float = 1e-6, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim self.shortcut_after_dw = stride > 1 or dim != dim_out self.conv_dw = create_conv2d( - dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) - 
self.norm = norm_layer(dim_out) - self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + dim, + dim_out, + kernel_size=kernel_size, + stride=stride, + depthwise=True, + bias=conv_bias, + **dd, + ) + self.norm = norm_layer(dim_out, **dd) + self.mlp = Mlp( + dim_out, + int(expand_ratio * dim_out), + act_layer=act_layer, + **dd, + ) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out, **dd)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -108,19 +141,22 @@ def forward(self, x): class CrossCovarianceAttn(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - attn_drop=0., - proj_drop=0. + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd)) - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -147,20 +183,23 @@ def no_weight_decay(self): class SplitTransposeBlock(nn.Module): def __init__( self, - dim, - num_scales=1, - num_heads=8, - expand_ratio=4, - use_pos_emb=True, - conv_bias=True, - qkv_bias=True, - ls_init_value=1e-6, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - drop_path=0., - attn_drop=0., - proj_drop=0. 
+ dim: int, + num_scales: int = 1, + num_heads: int = 8, + expand_ratio: float = 4, + use_pos_emb: bool = True, + conv_bias: bool = True, + qkv_bias: bool = True, + ls_init_value: float = 1e-6, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + drop_path: float = 0., + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) self.width = width @@ -168,20 +207,31 @@ def __init__( convs = [] for i in range(self.num_scales): - convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias, **dd)) self.convs = nn.ModuleList(convs) self.pos_embd = None if use_pos_emb: - self.pos_embd = PositionalEncodingFourier(dim=dim) - self.norm_xca = norm_layer(dim) - self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.pos_embd = PositionalEncodingFourier(dim=dim, **dd) + self.norm_xca = norm_layer(dim, **dd) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None self.xca = CrossCovarianceAttn( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + **dd, + ) - self.norm = norm_layer(dim, eps=1e-6) - self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.norm = norm_layer(dim, eps=1e-6, **dd) + self.mlp = Mlp( + dim, + int(expand_ratio * dim), + act_layer=act_layer, + **dd, + ) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): @@ -223,24 +273,27 @@ def forward(self, x): class EdgeNeXtStage(nn.Module): def __init__( self, - in_chs, - out_chs, - stride=2, - depth=2, - num_global_blocks=1, - num_heads=4, - scales=2, - kernel_size=7, - expand_ratio=4, - use_pos_emb=False, - downsample_block=False, - conv_bias=True, - ls_init_value=1.0, - drop_path_rates=None, - norm_layer=LayerNorm2d, - norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU + in_chs: int, + out_chs: int, + stride: int = 2, + depth: int = 2, + num_global_blocks: int = 1, + num_heads: int = 4, + scales: int = 2, + kernel_size: int = 7, + expand_ratio: float = 4, + use_pos_emb: bool = False, + downsample_block: bool = False, + conv_bias: bool = True, + ls_init_value: float = 1.0, + drop_path_rates: Optional[List[float]] = None, + norm_layer: Type[nn.Module] = LayerNorm2d, + norm_layer_cl: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False @@ -248,8 +301,8 @@ def __init__( self.downsample = nn.Identity() else: self.downsample = nn.Sequential( - norm_layer(in_chs), - nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + norm_layer(in_chs, **dd), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias, **dd) ) in_chs = out_chs @@ -268,6 +321,7 @@ def __init__( drop_path=drop_path_rates[i], norm_layer=norm_layer_cl, act_layer=act_layer, + **dd, ) ) else: @@ -283,6 +337,7 @@ def __init__( drop_path=drop_path_rates[i], norm_layer=norm_layer_cl, act_layer=act_layer, + **dd, ) ) in_chs = out_chs @@ -300,28 +355,31 @@ def forward(self, x): class EdgeNeXt(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - dims=(24, 48, 88, 168), - depths=(3, 3, 9, 3), - global_block_counts=(0, 1, 1, 1), - kernel_sizes=(3, 5, 7, 9), - heads=(8, 8, 8, 8), - d2_scales=(2, 2, 3, 4), - use_pos_emb=(False, True, False, False), - ls_init_value=1e-6, - head_init_scale=1., - expand_ratio=4, - downsample_block=False, - conv_bias=True, - stem_type='patch', - head_norm_first=False, - act_layer=nn.GELU, - drop_path_rate=0., - drop_rate=0., + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + dims: Tuple[int, ...] = (24, 48, 88, 168), + depths: Tuple[int, ...] = (3, 3, 9, 3), + global_block_counts: Tuple[int, ...] = (0, 1, 1, 1), + kernel_sizes: Tuple[int, ...] = (3, 5, 7, 9), + heads: Tuple[int, ...] = (8, 8, 8, 8), + d2_scales: Tuple[int, ...] = (2, 2, 3, 4), + use_pos_emb: Tuple[bool, ...]
= (False, True, False, False), + ls_init_value: float = 1e-6, + head_init_scale: float = 1., + expand_ratio: float = 4, + downsample_block: bool = False, + conv_bias: bool = True, + stem_type: str = 'patch', + head_norm_first: bool = False, + act_layer: Type[nn.Module] = nn.GELU, + drop_path_rate: float = 0., + drop_rate: float = 0., + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.global_pool = global_pool self.drop_rate = drop_rate @@ -332,13 +390,13 @@ def __init__( assert stem_type in ('patch', 'overlap') if stem_type == 'patch': self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), - norm_layer(dims[0]), + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias, **dd,), + norm_layer(dims[0], **dd), ) else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), - norm_layer(dims[0]), + nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias, **dd), + norm_layer(dims[0], **dd), ) curr_stride = 4 @@ -367,6 +425,7 @@ def __init__( norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, act_layer=act_layer, + **dd, )) # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 in_chs = dims[i] @@ -376,12 +435,13 @@ def __init__( self.num_features = self.head_hidden_size = dims[-1] if head_norm_first: - self.norm_pre = norm_layer(self.num_features) + self.norm_pre = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) else: self.norm_pre = nn.Identity() @@ -391,6 +451,7 @@ def __init__( pool_type=global_pool, drop_rate=self.drop_rate, norm_layer=norm_layer, + **dd, ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 75e5615e00..7e140da581 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -12,13 +12,22 @@ Modifications and timm support by / Copyright 2022, Ross Wightman """ -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, Mlp, ndgrid +from timm.layers import ( + DropPath, + LayerScale, + LayerScale2d, + Mlp, + calculate_drop_path_rates, + trunc_normal_, + to_2tuple, + ndgrid, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -45,12 +54,15 @@ class Attention(torch.nn.Module): def __init__( self, - dim=384, - key_dim=32, - num_heads=8, - attn_ratio=4, - resolution=7 + dim: int = 384, + key_dim: int = 32, + num_heads: int = 8, + attn_ratio: float = 4, + resolution: int = 7, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 @@ -60,14 +72,17 @@ def __init__( self.val_attn_dim = self.val_dim * num_heads self.attn_ratio = attn_ratio - self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim) - self.proj = nn.Linear(self.val_attn_dim, dim) + self.qkv = nn.Linear(dim, self.key_attn_dim * 2 + self.val_attn_dim, **dd) + self.proj = nn.Linear(self.val_attn_dim, dim, **dd) resolution 
= to_2tuple(resolution) - pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) + pos = torch.stack(ndgrid( + torch.arange(resolution[0], device=device, dtype=torch.long), + torch.arange(resolution[1], device=device, dtype=torch.long) + )).flatten(1) rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd)) self.register_buffer('attention_bias_idxs', rel_pos) self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat) @@ -102,15 +117,24 @@ def forward(self, x): # x (B,N,C) class Stem4(nn.Sequential): - def __init__(self, in_chs, out_chs, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = 4 - self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1)) - self.add_module('norm1', norm_layer(out_chs // 2)) + self.add_module('conv1', nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1, **dd)) + self.add_module('norm1', norm_layer(out_chs // 2, **dd)) self.add_module('act1', act_layer()) - self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1)) - self.add_module('norm2', norm_layer(out_chs)) + self.add_module('conv2', nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1, **dd)) + self.add_module('norm2', norm_layer(out_chs, **dd)) self.add_module('act2', act_layer()) @@ -121,12 +145,23 @@ class Downsample(nn.Module): Output: tensor in shape [B, C, H/stride, W/stride] """ - def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, padding=None, norm_layer=nn.BatchNorm2d): + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + stride: int = 2, + padding: Optional[int] = None, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() if padding is None: padding = kernel_size // 2 - self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding) - self.norm = norm_layer(out_chs) + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding, **dd) + self.norm = norm_layer(out_chs, **dd) def forward(self, x): x = self.conv(x) @@ -150,7 +185,7 @@ class Pooling(nn.Module): --pool_size: pooling size """ - def __init__(self, pool_size=3): + def __init__(self, pool_size: int = 3): super().__init__() self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) @@ -166,21 +201,24 @@ class ConvMlpWithNorm(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - drop=0. 
+ in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Conv2d(in_features, hidden_features, 1) - self.norm1 = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc1 = nn.Conv2d(in_features, hidden_features, 1, **dd) + self.norm1 = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity() self.act = act_layer() - self.fc2 = nn.Conv2d(hidden_features, out_features, 1) - self.norm2 = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1, **dd) + self.norm2 = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity() self.drop = nn.Dropout(drop) def forward(self, x): @@ -194,76 +232,63 @@ def forward(self, x): return x -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class MetaBlock1d(nn.Module): def __init__( self, - dim, - mlp_ratio=4., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - proj_drop=0., - drop_path=0., - layer_scale_init_value=1e-5 + dim: int, + mlp_ratio: float = 4., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + proj_drop: float = 0., + drop_path: float = 0., + layer_scale_init_value: float = 1e-5, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) - self.token_mixer = Attention(dim) - self.norm2 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) + self.token_mixer = Attention(dim, **dd) + self.ls1 = LayerScale(dim, layer_scale_init_value, **dd) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ls1 = LayerScale(dim, layer_scale_init_value) - self.ls2 = LayerScale(dim, layer_scale_init_value) + self.ls2 = LayerScale(dim, layer_scale_init_value, **dd) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): - x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x)))) - x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + x = x + self.drop_path1(self.ls1(self.token_mixer(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x -class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - class MetaBlock2d(nn.Module): def __init__( self, - dim, - pool_size=3, - mlp_ratio=4., - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - proj_drop=0., - drop_path=0., - layer_scale_init_value=1e-5 + dim: int, + pool_size: int = 3, + mlp_ratio: float = 4., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + proj_drop: float = 0., + drop_path: float = 0., + layer_scale_init_value: float = 1e-5, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.token_mixer = Pooling(pool_size=pool_size) - self.ls1 = LayerScale2d(dim, layer_scale_init_value) + self.ls1 = LayerScale2d(dim, layer_scale_init_value, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = ConvMlpWithNorm( @@ -272,8 +297,9 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, drop=proj_drop, + **dd, ) - self.ls2 = LayerScale2d(dim, layer_scale_init_value) + self.ls2 = LayerScale2d(dim, layer_scale_init_value, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -286,25 +312,28 @@ class EfficientFormerStage(nn.Module): def __init__( self, - dim, - dim_out, - depth, - downsample=True, - num_vit=1, - pool_size=3, - mlp_ratio=4., - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - norm_layer_cl=nn.LayerNorm, - proj_drop=.0, - drop_path=0., - layer_scale_init_value=1e-5, + dim: int, + dim_out: int, + depth: int, + downsample: bool = True, + num_vit: int = 1, + pool_size: int = 3, + mlp_ratio: float = 4., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + norm_layer_cl: Type[nn.Module] = nn.LayerNorm, + proj_drop: float = .0, + drop_path: float = 0., + layer_scale_init_value: float = 1e-5, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False if downsample: - self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer) + self.downsample = Downsample(in_chs=dim, out_chs=dim_out, norm_layer=norm_layer, **dd) dim = dim_out else: assert dim == dim_out @@ -326,6 +355,7 @@ def __init__( proj_drop=proj_drop, drop_path=drop_path[block_idx], layer_scale_init_value=layer_scale_init_value, + **dd, )) else: blocks.append( @@ -338,6 +368,7 @@ def __init__( proj_drop=proj_drop, drop_path=drop_path[block_idx], layer_scale_init_value=layer_scale_init_value, + **dd, )) if num_vit and num_vit == remain_idx: blocks.append(Flat()) @@ -357,29 +388,32 @@ class EfficientFormer(nn.Module): def __init__( self, - depths, - embed_dims=None, - in_chans=3, - num_classes=1000, - global_pool='avg', - downsamples=None, - num_vit=0, - mlp_ratios=4, - pool_size=3, - layer_scale_init_value=1e-5, - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - norm_layer_cl=nn.LayerNorm, - drop_rate=0., - proj_drop_rate=0., - drop_path_rate=0., + depths:
Tuple[int, ...] = (3, 2, 6, 4), + embed_dims: Tuple[int, ...] = (48, 96, 224, 448), + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + downsamples: Optional[Tuple[bool, ...]] = None, + num_vit: int = 0, + mlp_ratios: float = 4, + pool_size: int = 3, + layer_scale_init_value: float = 1e-5, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + norm_layer_cl: Type[nn.Module] = nn.LayerNorm, + drop_rate: float = 0., + proj_drop_rate: float = 0., + drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.global_pool = global_pool - self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer) + self.stem = Stem4(in_chans, embed_dims[0], norm_layer=norm_layer, **dd) prev_dim = embed_dims[0] # stochastic depth decay rule @@ -404,6 +438,7 @@ def __init__( proj_drop=proj_drop_rate, drop_path=dpr[i], layer_scale_init_value=layer_scale_init_value, + **dd, ) prev_dim = embed_dims[i] stages.append(stage) @@ -412,11 +447,11 @@ def __init__( # Classifier head self.num_features = self.head_hidden_size = embed_dims[-1] - self.norm = norm_layer_cl(self.num_features) + self.norm = norm_layer_cl(self.num_features, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() # assuming model is always distilled (valid for current checkpoints, will split def if that changes) - self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token self.apply(self._init_weights) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index 5bc2112f20..92fbdaf14a 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -16,14 +16,26 @@ """ import math from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct -from timm.layers import DropPath, calculate_drop_path_rates, trunc_normal_, to_2tuple, to_ntuple, ndgrid +from timm.layers import ( + create_conv2d, + create_norm_layer, + get_act_layer, + get_norm_layer, + ConvNormAct, + LayerScale2d, + DropPath, + calculate_drop_path_rates, + trunc_normal_, + to_2tuple, + to_ntuple, + ndgrid, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -57,19 +69,22 @@ class ConvNorm(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding='', - dilation=1, - groups=1, - bias=True, - norm_layer='batchnorm2d', - norm_kwargs=None, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Union[int, str] = '', + dilation: int = 1, + groups: int = 1, + bias: bool = True, + norm_layer: str = 'batchnorm2d', + norm_kwargs: Optional[Dict] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} norm_kwargs = 
norm_kwargs or {} - super(ConvNorm, self).__init__() + super().__init__() self.conv = create_conv2d( in_channels, out_channels, @@ -79,8 +94,9 @@ def __init__( dilation=dilation, groups=groups, bias=bias, + **dd, ) - self.bn = create_norm_layer(norm_layer, out_channels, **norm_kwargs) + self.bn = create_norm_layer(norm_layer, out_channels, **norm_kwargs, **dd) def forward(self, x): x = self.conv(x) @@ -93,14 +109,17 @@ class Attention2d(torch.nn.Module): def __init__( self, - dim=384, - key_dim=32, - num_heads=8, - attn_ratio=4, - resolution=7, - act_layer=nn.GELU, - stride=None, + dim: int = 384, + key_dim: int = 32, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: Union[int, Tuple[int, int]] = 7, + act_layer: Type[nn.Module] = nn.GELU, + stride: Optional[int] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 @@ -109,7 +128,7 @@ def __init__( resolution = to_2tuple(resolution) if stride is not None: resolution = tuple([math.ceil(r / stride) for r in resolution]) - self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim) + self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim, **dd) self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') else: self.stride_conv = None @@ -122,21 +141,24 @@ def __init__( self.attn_ratio = attn_ratio kh = self.key_dim * self.num_heads - self.q = ConvNorm(dim, kh) - self.k = ConvNorm(dim, kh) - self.v = ConvNorm(dim, self.dh) - self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh) - self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1) - self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1) + self.q = ConvNorm(dim, kh, **dd) + self.k = ConvNorm(dim, kh, **dd) + self.v = ConvNorm(dim, self.dh, **dd) + self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh, **dd) + self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd) + self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd) self.act = act_layer() - self.proj = ConvNorm(self.dh, dim, 1) + self.proj = ConvNorm(self.dh, dim, 1, **dd) - pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1) + pos = torch.stack(ndgrid( + torch.arange(self.resolution[0], device=device, dtype=torch.long), + torch.arange(self.resolution[1], device=device, dtype=torch.long), + )).flatten(1) rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1] - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N)) - self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos), persistent=False) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N, **dd)) + self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat) @torch.no_grad() @@ -182,11 +204,18 @@ def forward(self, x): class LocalGlobalQuery(torch.nn.Module): - def __init__(self, in_dim, out_dim): + def __init__( + self, + in_dim: int, + out_dim: int, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.pool = nn.AvgPool2d(1, 2, 0) - self.local = nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim) - self.proj = ConvNorm(in_dim, out_dim, 1) + self.local = 
nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim, **dd) + self.proj = ConvNorm(in_dim, out_dim, 1, **dd) def forward(self, x): local_q = self.local(x) @@ -201,14 +230,17 @@ class Attention2dDownsample(torch.nn.Module): def __init__( self, - dim=384, - key_dim=16, - num_heads=8, - attn_ratio=4, - resolution=7, - out_dim=None, - act_layer=nn.GELU, + dim: int = 384, + key_dim: int = 16, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: Union[int, Tuple[int, int]] = 7, + out_dim: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads @@ -225,19 +257,22 @@ def __init__( self.out_dim = out_dim or dim kh = self.key_dim * self.num_heads - self.q = LocalGlobalQuery(dim, kh) - self.k = ConvNorm(dim, kh, 1) - self.v = ConvNorm(dim, self.dh, 1) - self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh) + self.q = LocalGlobalQuery(dim, kh, **dd) + self.k = ConvNorm(dim, kh, 1, **dd) + self.v = ConvNorm(dim, self.dh, 1, **dd) + self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh, **dd) self.act = act_layer() - self.proj = ConvNorm(self.dh, self.out_dim, 1) + self.proj = ConvNorm(self.dh, self.out_dim, 1, **dd) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N)) - k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N, **dd)) + k_pos = torch.stack(ndgrid( + torch.arange(self.resolution[0], device=device, dtype=torch.long), + torch.arange(self.resolution[1], device=device, dtype=torch.long), + )).flatten(1) q_pos = torch.stack(ndgrid( - torch.arange(0, self.resolution[0], step=2), - torch.arange(0, self.resolution[1], step=2) + torch.arange(0, self.resolution[0], step=2, device=device, dtype=torch.long), + torch.arange(0, self.resolution[1], step=2, device=device, dtype=torch.long), )).flatten(1) rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1] @@ -282,16 +317,19 @@ def forward(self, x): class Downsample(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size=3, - stride=2, - padding=1, - resolution=7, - use_attn=False, - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, + in_chs: int, + out_chs: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + stride: Union[int, Tuple[int, int]] = 2, + padding: Union[int, Tuple[int, int]] = 1, + resolution: Union[int, Tuple[int, int]] = 7, + use_attn: bool = False, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() kernel_size = to_2tuple(kernel_size) @@ -305,6 +343,7 @@ def __init__( stride=stride, padding=padding, norm_layer=norm_layer, + **dd, ) if use_attn: @@ -313,6 +352,7 @@ def __init__( out_dim=out_chs, resolution=resolution, act_layer=act_layer, + **dd, ) else: self.attn = None @@ -332,28 +372,44 @@ class ConvMlpWithNorm(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - drop=0., - mid_conv=False, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + drop: float = 0., 
+ mid_conv: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = ConvNormAct( - in_features, hidden_features, 1, - bias=True, norm_layer=norm_layer, act_layer=act_layer) + in_features, + hidden_features, + 1, + bias=True, + norm_layer=norm_layer, + act_layer=act_layer, + **dd, + ) if mid_conv: self.mid = ConvNormAct( - hidden_features, hidden_features, 3, - groups=hidden_features, bias=True, norm_layer=norm_layer, act_layer=act_layer) + hidden_features, + hidden_features, + 3, + groups=hidden_features, + bias=True, + norm_layer=norm_layer, + act_layer=act_layer, + **dd, + ) else: self.mid = nn.Identity() self.drop1 = nn.Dropout(drop) - self.fc2 = ConvNorm(hidden_features, out_features, 1, norm_layer=norm_layer) + self.fc2 = ConvNorm(hidden_features, out_features, 1, norm_layer=norm_layer, **dd) self.drop2 = nn.Dropout(drop) def forward(self, x): @@ -365,31 +421,23 @@ def forward(self, x): return x -class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - class EfficientFormerV2Block(nn.Module): def __init__( self, - dim, - mlp_ratio=4., - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - proj_drop=0., - drop_path=0., - layer_scale_init_value=1e-5, - resolution=7, - stride=None, - use_attn=True, + dim: int, + mlp_ratio: float = 4., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + proj_drop: float = 0., + drop_path: float = 0., + layer_scale_init_value: Optional[float] = 1e-5, + resolution: Union[int, Tuple[int, int]] = 7, + stride: Optional[int] = None, + use_attn: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() if use_attn: @@ -398,9 +446,10 @@ def __init__( resolution=resolution, act_layer=act_layer, stride=stride, + **dd, ) self.ls1 = LayerScale2d( - dim, layer_scale_init_value) if layer_scale_init_value is not None else nn.Identity() + dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() else: self.token_mixer = None @@ -414,9 +463,10 @@ def __init__( norm_layer=norm_layer, drop=proj_drop, mid_conv=True, + **dd, ) self.ls2 = LayerScale2d( - dim, layer_scale_init_value) if layer_scale_init_value is not None else nn.Identity() + dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): @@ -427,16 +477,38 @@ def forward(self, x): class Stem4(nn.Sequential): - def __init__(self, in_chs, out_chs, act_layer=nn.GELU, norm_layer=nn.BatchNorm2d): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = 4 self.conv1 = ConvNormAct( - in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1, bias=True, - norm_layer=norm_layer, act_layer=act_layer + in_chs, + out_chs // 2, + kernel_size=3, + stride=2, padding=1, + bias=True, + norm_layer=norm_layer, + act_layer=act_layer, + **dd, ) self.conv2 = ConvNormAct( - out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1, bias=True, - norm_layer=norm_layer, act_layer=act_layer + out_chs // 2, + out_chs, + kernel_size=3, + stride=2, + padding=1, + bias=True, + norm_layer=norm_layer, + act_layer=act_layer, + **dd, ) @@ -444,23 +516,25 @@ class EfficientFormerV2Stage(nn.Module): def __init__( self, - dim, - dim_out, - depth, - resolution=7, - downsample=True, - block_stride=None, - downsample_use_attn=False, - block_use_attn=False, - num_vit=1, - mlp_ratio=4., - proj_drop=.0, - drop_path=0., - layer_scale_init_value=1e-5, - act_layer=nn.GELU, - norm_layer=nn.BatchNorm2d, - + dim: int, + dim_out: int, + depth: int, + resolution: Union[int, Tuple[int, int]] = 7, + downsample: bool = True, + block_stride: Optional[int] = None, + downsample_use_attn: bool = False, + block_use_attn: bool = False, + num_vit: int = 1, + mlp_ratio: Union[float, Tuple[float, ...]] = 4., + proj_drop: float = .0, + drop_path: Union[float, List[float]] = 0., + layer_scale_init_value: Optional[float] = 1e-5, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False mlp_ratio = to_ntuple(depth)(mlp_ratio) @@ -474,6 +548,7 @@ def __init__( resolution=resolution, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) dim = dim_out resolution = tuple([math.ceil(r / 2) for r in resolution]) @@ -495,6 +570,7 @@ def __init__( layer_scale_init_value=layer_scale_init_value, act_layer=act_layer, norm_layer=norm_layer, + **dd, ) blocks += [b] self.blocks = nn.Sequential(*blocks) @@ -511,25 +587,28 @@ def forward(self, x): class EfficientFormerV2(nn.Module): def __init__( self, - depths, - in_chans=3, - img_size=224, - global_pool='avg', - embed_dims=None, - downsamples=None, - mlp_ratios=4, - norm_layer='batchnorm2d', - norm_eps=1e-5, - act_layer='gelu', - num_classes=1000, - drop_rate=0., - proj_drop_rate=0., - drop_path_rate=0., - layer_scale_init_value=1e-5, - num_vit=0, - distillation=True, + depths: Tuple[int, ...], + in_chans: int = 3, + img_size: Union[int, Tuple[int, int]] = 224, + global_pool: str = 'avg', + embed_dims: Optional[Tuple[int, ...]] = None, + downsamples: Optional[Tuple[bool, ...]] = None, + mlp_ratios: Union[float, Tuple[float, ...], Tuple[Tuple[float, ...], ...]] = 4, + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + act_layer: str = 'gelu', + num_classes: int = 1000, + drop_rate: float = 0., + proj_drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = 1e-5, + num_vit: int = 0, + distillation: bool = True, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in 
('avg', '') self.num_classes = num_classes self.global_pool = global_pool @@ -538,7 +617,7 @@ def __init__( norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) act_layer = get_act_layer(act_layer) - self.stem = Stem4(in_chans, embed_dims[0], act_layer=act_layer, norm_layer=norm_layer) + self.stem = Stem4(in_chans, embed_dims[0], act_layer=act_layer, norm_layer=norm_layer, **dd) prev_dim = embed_dims[0] stride = 4 @@ -565,6 +644,7 @@ def __init__( layer_scale_init_value=layer_scale_init_value, act_layer=act_layer, norm_layer=norm_layer, + **dd, ) if downsamples[i]: stride *= 2 @@ -575,12 +655,12 @@ def __init__( # Classifier head self.num_features = self.head_hidden_size = embed_dims[-1] - self.norm = norm_layer(embed_dims[-1]) + self.norm = norm_layer(embed_dims[-1], **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity() self.dist = distillation if self.dist: - self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity() else: self.head_dist = None diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index d6140de9f3..edc568335e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -96,7 +96,9 @@ def __init__( round_chs_fn: Callable = round_channels, drop_rate: float = 0., drop_path_rate: float = 0., - global_pool: str = 'avg' + global_pool: str = 'avg', + device=None, + dtype=None, ) -> None: """Initialize EfficientNet model. @@ -119,7 +121,8 @@ def __init__( drop_path_rate: Drop path rate for stochastic depth. global_pool: Global pooling type. 
""" - super(EfficientNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -131,8 +134,8 @@ def __init__( # Stem if not fix_stem: stem_size = round_chs_fn(stem_size) - self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type) - self.bn1 = norm_act_layer(stem_size, inplace=True) + self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type, **dd) + self.bn1 = norm_act_layer(stem_size, inplace=True, **dd) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -144,6 +147,7 @@ def __init__( aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -152,8 +156,8 @@ def __init__( # Head + Pooling if num_features > 0: - self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type) - self.bn2 = norm_act_layer(num_features, inplace=True) + self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type, **dd) + self.bn2 = norm_act_layer(num_features, inplace=True, **dd) self.num_features = self.head_hidden_size = num_features else: self.conv_head = nn.Identity() @@ -161,7 +165,11 @@ def __init__( self.num_features = self.head_hidden_size = head_chs self.global_pool, self.classifier = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool) + self.num_features, + self.num_classes, + pool_type=global_pool, + **dd, + ) efficientnet_init_weights(self) @@ -366,8 +374,11 @@ def __init__( round_chs_fn: Callable = round_channels, drop_rate: float = 0., drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(EfficientNetFeatures, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -378,8 +389,8 @@ def __init__( # Stem if not fix_stem: stem_size = round_chs_fn(stem_size) - self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type) - self.bn1 = norm_act_layer(stem_size, inplace=True) + self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type, **dd) + self.bn1 = norm_act_layer(stem_size, inplace=True, **dd) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -392,6 +403,7 @@ def __init__( se_layer=se_layer, drop_path_rate=drop_path_rate, feature_location=feature_location, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index 8b35a04c87..3cae78d897 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -7,7 +7,7 @@ """ __all__ = ['EfficientVit', 'EfficientVitLarge'] -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union from functools import partial import torch @@ -48,19 +48,22 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, . 
class ConvNormAct(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size=3, - stride=1, - dilation=1, - groups=1, - bias=False, - dropout=0., - norm_layer=nn.BatchNorm2d, - act_layer=nn.ReLU, + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + dropout: float = 0., + norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d, + act_layer: Optional[Type[nn.Module]] = nn.ReLU, + device=None, + dtype=None, ): - super(ConvNormAct, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.dropout = nn.Dropout(dropout, inplace=False) self.conv = create_conv2d( in_channels, @@ -70,8 +73,9 @@ def __init__( dilation=dilation, groups=groups, bias=bias, + **dd, ) - self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity() + self.norm = norm_layer(num_features=out_channels, **dd) if norm_layer else nn.Identity() self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity() def forward(self, x): @@ -84,16 +88,19 @@ def forward(self, x): class DSConv(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size=3, - stride=1, - use_bias=False, - norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), - act_layer=(nn.ReLU6, None), + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + use_bias: Union[bool, Tuple[bool, bool]] = False, + norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d, + act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None), + device=None, + dtype=None, ): - super(DSConv, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) @@ -107,6 +114,7 @@ def __init__( norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], + **dd, ) self.point_conv = ConvNormAct( in_channels, @@ -115,6 +123,7 @@ def __init__( norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], + **dd, ) def forward(self, x): @@ -126,17 +135,20 @@ def forward(self, x): class ConvBlock(nn.Module): def __init__( self, - in_channels: int, - out_channels: int, - kernel_size=3, - stride=1, - mid_channels=None, - expand_ratio=1, - use_bias=False, - norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), - act_layer=(nn.ReLU6, None), + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + mid_channels: Optional[int] = None, + expand_ratio: float = 1, + use_bias: Union[bool, Tuple[bool, bool]] = False, + norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d, + act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None), + device=None, + dtype=None, ): - super(ConvBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) @@ -150,6 +162,7 @@ def __init__( norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], + **dd, ) self.conv2 = ConvNormAct( mid_channels, @@ -159,6 +172,7 @@ def __init__( norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], + **dd, ) def forward(self, x): @@ -169,18 +183,21 @@ def forward(self, x): class MBConv(nn.Module): def __init__( - self, - in_channels: 
int, - out_channels: int, - kernel_size=3, - stride=1, - mid_channels=None, - expand_ratio=6, - use_bias=False, - norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), - act_layer=(nn.ReLU6, nn.ReLU6, None), + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + mid_channels: Optional[int] = None, + expand_ratio: float = 6, + use_bias: Union[bool, Tuple[bool, ...]] = False, + norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d, + act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, nn.ReLU6, None), + device=None, + dtype=None, ): - super(MBConv, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() use_bias = val2tuple(use_bias, 3) norm_layer = val2tuple(norm_layer, 3) act_layer = val2tuple(act_layer, 3) @@ -194,6 +211,7 @@ def __init__( norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], + **dd, ) self.depth_conv = ConvNormAct( mid_channels, @@ -204,6 +222,7 @@ def __init__( norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], + **dd, ) self.point_conv = ConvNormAct( mid_channels, @@ -212,6 +231,7 @@ def __init__( norm_layer=norm_layer[2], act_layer=act_layer[2], bias=use_bias[2], + **dd, ) def forward(self, x): @@ -223,19 +243,22 @@ def forward(self, x): class FusedMBConv(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size=3, - stride=1, - mid_channels=None, - expand_ratio=6, - groups=1, - use_bias=False, - norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), - act_layer=(nn.ReLU6, None), + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + mid_channels: Optional[int] = None, + expand_ratio: float = 6, + groups: int = 1, + use_bias: Union[bool, Tuple[bool, ...]] = False, + norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d, + act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None), + device=None, + dtype=None, ): - super(FusedMBConv, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() use_bias = val2tuple(use_bias, 2) norm_layer = val2tuple(norm_layer, 2) act_layer = val2tuple(act_layer, 2) @@ -250,6 +273,7 @@ def __init__( norm_layer=norm_layer[0], act_layer=act_layer[0], bias=use_bias[0], + **dd, ) self.point_conv = ConvNormAct( mid_channels, @@ -258,6 +282,7 @@ def __init__( norm_layer=norm_layer[1], act_layer=act_layer[1], bias=use_bias[1], + **dd, ) def forward(self, x): @@ -270,20 +295,23 @@ class LiteMLA(nn.Module): """Lightweight multi-scale linear attention""" def __init__( - self, - in_channels: int, - out_channels: int, - heads: int or None = None, - heads_ratio: float = 1.0, - dim=8, - use_bias=False, - norm_layer=(None, nn.BatchNorm2d), - act_layer=(None, None), - kernel_func=nn.ReLU, - scales=(5,), - eps=1e-5, + self, + in_channels: int, + out_channels: int, + heads: Optional[int] = None, + heads_ratio: float = 1.0, + dim: int = 8, + use_bias: Union[bool, Tuple[bool, ...]] = False, + norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, nn.BatchNorm2d), + act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, None), + kernel_func: Type[nn.Module] = nn.ReLU, + scales: Tuple[int, ...] 
= (5,), + eps: float = 1e-5, + device=None, + dtype=None, ): - super(LiteMLA, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.eps = eps heads = heads or int(in_channels // dim * heads_ratio) total_dim = heads * dim @@ -299,6 +327,7 @@ def __init__( bias=use_bias[0], norm_layer=norm_layer[0], act_layer=act_layer[0], + **dd, ) self.aggreg = nn.ModuleList([ nn.Sequential( @@ -309,8 +338,9 @@ def __init__( padding=get_same_padding(scale), groups=3 * total_dim, bias=use_bias[0], + **dd, ), - nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0], **dd), ) for scale in scales ]) @@ -323,6 +353,7 @@ def __init__( bias=use_bias[1], norm_layer=norm_layer[1], act_layer=act_layer[1], + **dd, ) def _attn(self, q, k, v): @@ -367,15 +398,18 @@ def forward(self, x): class EfficientVitBlock(nn.Module): def __init__( - self, - in_channels, - heads_ratio=1.0, - head_dim=32, - expand_ratio=4, - norm_layer=nn.BatchNorm2d, - act_layer=nn.Hardswish, + self, + in_channels: int, + heads_ratio: float = 1.0, + head_dim: int = 32, + expand_ratio: float = 4, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.Hardswish, + device=None, + dtype=None, ): - super(EfficientVitBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.context_module = ResidualBlock( LiteMLA( in_channels=in_channels, @@ -383,6 +417,7 @@ def __init__( heads_ratio=heads_ratio, dim=head_dim, norm_layer=(None, norm_layer), + **dd, ), nn.Identity(), ) @@ -394,6 +429,7 @@ def __init__( use_bias=(True, True, False), norm_layer=(None, None, norm_layer), act_layer=(act_layer, act_layer, None), + **dd, ), nn.Identity(), ) @@ -406,12 +442,12 @@ def forward(self, x): class ResidualBlock(nn.Module): def __init__( - self, - main: Optional[nn.Module], - shortcut: Optional[nn.Module] = None, - pre_norm: Optional[nn.Module] = None, + self, + main: Optional[nn.Module], + shortcut: Optional[nn.Module] = None, + pre_norm: Optional[nn.Module] = None, ): - super(ResidualBlock, self).__init__() + super().__init__() self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() self.main = main self.shortcut = shortcut @@ -432,7 +468,10 @@ def build_local_block( act_layer: str, fewer_norm: bool = False, block_type: str = "default", + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} assert block_type in ["default", "large", "fused"] if expand_ratio == 1: if block_type == "default": @@ -443,6 +482,7 @@ def build_local_block( use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), + **dd, ) else: block = ConvBlock( @@ -452,6 +492,7 @@ def build_local_block( use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), + **dd, ) else: if block_type == "default": @@ -463,6 +504,7 @@ def build_local_block( use_bias=(True, True, False) if fewer_norm else False, norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, act_layer, None), + **dd, ) else: block = FusedMBConv( @@ -473,20 +515,37 @@ def build_local_block( use_bias=(True, False) if fewer_norm else False, norm_layer=(None, norm_layer) if fewer_norm else norm_layer, act_layer=(act_layer, None), + **dd, ) return block class Stem(nn.Sequential): - def __init__(self, in_chs, out_chs, depth, 
norm_layer, act_layer, block_type='default'): + def __init__( + self, + in_chs: int, + out_chs: int, + depth: int, + norm_layer: Type[nn.Module], + act_layer: Type[nn.Module], + block_type: str = 'default', + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.stride = 2 self.add_module( 'in_conv', ConvNormAct( - in_chs, out_chs, - kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, + in_chs, + out_chs, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + act_layer=act_layer, + **dd, ) ) stem_block = 0 @@ -500,6 +559,7 @@ def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer, block_type='de norm_layer=norm_layer, act_layer=act_layer, block_type=block_type, + **dd, ), nn.Identity(), )) @@ -509,16 +569,19 @@ def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer, block_type='de class EfficientVitStage(nn.Module): def __init__( self, - in_chs, - out_chs, - depth, - norm_layer, - act_layer, - expand_ratio, - head_dim, - vit_stage=False, + in_chs: int, + out_chs: int, + depth: int, + norm_layer: Type[nn.Module], + act_layer: Type[nn.Module], + expand_ratio: float, + head_dim: int, + vit_stage: bool = False, + device=None, + dtype=None, ): - super(EfficientVitStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() blocks = [ResidualBlock( build_local_block( in_channels=in_chs, @@ -528,6 +591,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, fewer_norm=vit_stage, + **dd, ), None, )] @@ -543,6 +607,7 @@ def __init__( expand_ratio=expand_ratio, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) ) else: @@ -555,7 +620,8 @@ def __init__( stride=1, expand_ratio=expand_ratio, norm_layer=norm_layer, - act_layer=act_layer + act_layer=act_layer, + **dd, ), nn.Identity(), )) @@ -569,16 +635,19 @@ def forward(self, x): class EfficientVitLargeStage(nn.Module): def __init__( self, - in_chs, - out_chs, - depth, - norm_layer, - act_layer, - head_dim, - vit_stage=False, - fewer_norm=False, + in_chs: int, + out_chs: int, + depth: int, + norm_layer: Type[nn.Module], + act_layer: Type[nn.Module], + head_dim: int, + vit_stage: bool = False, + fewer_norm: bool = False, + device=None, + dtype=None, ): - super(EfficientVitLargeStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() blocks = [ResidualBlock( build_local_block( in_channels=in_chs, @@ -589,6 +658,7 @@ def __init__( act_layer=act_layer, fewer_norm=vit_stage or fewer_norm, block_type='default' if fewer_norm else 'fused', + **dd, ), None, )] @@ -604,6 +674,7 @@ def __init__( expand_ratio=6, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) ) else: @@ -619,6 +690,7 @@ def __init__( act_layer=act_layer, fewer_norm=fewer_norm, block_type='default' if fewer_norm else 'fused', + **dd, ), nn.Identity(), )) @@ -631,29 +703,32 @@ def forward(self, x): class ClassifierHead(nn.Module): def __init__( - self, - in_channels: int, - widths: List[int], - num_classes: int = 1000, - dropout: float = 0., - norm_layer=nn.BatchNorm2d, - act_layer=nn.Hardswish, - pool_type: str = 'avg', - norm_eps: float = 1e-5, + self, + in_channels: int, + widths: List[int], + num_classes: int = 1000, + dropout: float = 0., + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Optional[Type[nn.Module]] = nn.Hardswish, + pool_type: str = 'avg', + norm_eps: float = 1e-5, + device=None, + dtype=None, ): - super(ClassifierHead, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.widths = 
widths self.num_features = widths[-1] assert pool_type, 'Cannot disable pooling' - self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer) + self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer, **dd) self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) self.classifier = nn.Sequential( - nn.Linear(widths[0], widths[1], bias=False), - nn.LayerNorm(widths[1], eps=norm_eps), + nn.Linear(widths[0], widths[1], bias=False, **dd), + nn.LayerNorm(widths[1], eps=norm_eps, **dd), act_layer(inplace=True) if act_layer is not None else nn.Identity(), nn.Dropout(dropout, inplace=False), - nn.Linear(widths[1], num_classes, bias=True) if num_classes > 0 else nn.Identity(), + nn.Linear(widths[1], num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity(), ) def reset(self, num_classes: int, pool_type: Optional[str] = None): @@ -681,26 +756,29 @@ def forward(self, x, pre_logits: bool = False): class EfficientVit(nn.Module): def __init__( - self, - in_chans=3, - widths=(), - depths=(), - head_dim=32, - expand_ratio=4, - norm_layer=nn.BatchNorm2d, - act_layer=nn.Hardswish, - global_pool='avg', - head_widths=(), - drop_rate=0.0, - num_classes=1000, + self, + in_chans: int = 3, + widths: Tuple[int, ...] = (), + depths: Tuple[int, ...] = (), + head_dim: int = 32, + expand_ratio: float = 4, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.Hardswish, + global_pool: str = 'avg', + head_widths: Tuple[int, ...] = (), + drop_rate: float = 0.0, + num_classes: int = 1000, + device=None, + dtype=None, ): - super(EfficientVit, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.grad_checkpointing = False self.global_pool = global_pool self.num_classes = num_classes # input stem - self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer) + self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, **dd) stride = self.stem.stride # stages @@ -717,6 +795,7 @@ def __init__( expand_ratio=expand_ratio, head_dim=head_dim, vit_stage=i >= 2, + **dd, )) stride *= 2 in_channels = w @@ -729,6 +808,7 @@ def __init__( num_classes=num_classes, dropout=drop_rate, pool_type=self.global_pool, + **dd, ) self.head_hidden_size = self.head.num_features @@ -835,19 +915,22 @@ def forward(self, x): class EfficientVitLarge(nn.Module): def __init__( self, - in_chans=3, - widths=(), - depths=(), - head_dim=32, - norm_layer=nn.BatchNorm2d, - act_layer=GELUTanh, - global_pool='avg', - head_widths=(), - drop_rate=0.0, - num_classes=1000, - norm_eps=1e-7, + in_chans: int = 3, + widths: Tuple[int, ...] = (), + depths: Tuple[int, ...] = (), + head_dim: int = 32, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = GELUTanh, + global_pool: str = 'avg', + head_widths: Tuple[int, ...] 
= (), + drop_rate: float = 0.0, + num_classes: int = 1000, + norm_eps: float = 1e-7, + device=None, + dtype=None, ): - super(EfficientVitLarge, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.grad_checkpointing = False self.global_pool = global_pool self.num_classes = num_classes @@ -855,7 +938,7 @@ def __init__( norm_layer = partial(norm_layer, eps=self.norm_eps) # input stem - self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large') + self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large', **dd) stride = self.stem.stride # stages @@ -872,6 +955,7 @@ def __init__( head_dim=head_dim, vit_stage=i >= 3, fewer_norm=i >= 2, + **dd, )) stride *= 2 in_channels = w @@ -886,6 +970,7 @@ def __init__( pool_type=self.global_pool, act_layer=act_layer, norm_eps=self.norm_eps, + **dd, ) self.head_hidden_size = self.head.num_features diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 80c5d99995..497e984726 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -9,7 +9,7 @@ __all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -23,12 +23,24 @@ class ConvNorm(torch.nn.Sequential): - def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + def __init__( + self, + in_chs: int, + out_chs: int, + ks: int = 1, + stride: int = 1, + pad: int = 0, + dilation: int = 1, + groups: int = 1, + bn_weight_init: float = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) - self.bn = nn.BatchNorm2d(out_chs) + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd) + self.bn = nn.BatchNorm2d(out_chs, **dd) torch.nn.init.constant_(self.bn.weight, bn_weight_init) - torch.nn.init.constant_(self.bn.bias, 0) @torch.no_grad() def fuse(self): @@ -46,11 +58,21 @@ def fuse(self): class NormLinear(torch.nn.Sequential): - def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + std: float = 0.02, + drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.bn = nn.BatchNorm1d(in_features) + self.bn = nn.BatchNorm1d(in_features, **dd) self.drop = nn.Dropout(drop) - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.linear = nn.Linear(in_features, out_features, bias=bias, **dd) trunc_normal_(self.linear.weight, std=std) if self.linear.bias is not None: @@ -74,14 +96,21 @@ def fuse(self): class PatchMerging(torch.nn.Module): - def __init__(self, dim, out_dim): + def __init__( + self, + dim: int, + out_dim: int, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() hid_dim = int(dim * 4) - self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0) + self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0, **dd) self.act = torch.nn.ReLU() - self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) - self.se = SqueezeExcite(hid_dim, .25) - self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0) + self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd) + self.se = 
SqueezeExcite(hid_dim, .25, **dd) + self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0, **dd) def forward(self, x): x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) @@ -89,7 +118,7 @@ def forward(self, x): class ResidualDrop(torch.nn.Module): - def __init__(self, m, drop=0.): + def __init__(self, m: nn.Module, drop: float = 0.): super().__init__() self.m = m self.drop = drop @@ -103,11 +132,18 @@ def forward(self, x): class ConvMlp(torch.nn.Module): - def __init__(self, ed, h): + def __init__( + self, + ed: int, + h: int, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.pw1 = ConvNorm(ed, h) + self.pw1 = ConvNorm(ed, h, **dd) self.act = torch.nn.ReLU() - self.pw2 = ConvNorm(h, ed, bn_weight_init=0) + self.pw2 = ConvNorm(h, ed, bn_weight_init=0, **dd) def forward(self, x): x = self.pw2(self.act(self.pw1(x))) @@ -129,13 +165,16 @@ class CascadedGroupAttention(torch.nn.Module): """ def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=14, - kernels=(5, 5, 5, 5), + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: int = 14, + kernels: Tuple[int, ...] = (5, 5, 5, 5), + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 @@ -146,13 +185,13 @@ def __init__( qkvs = [] dws = [] for i in range(num_heads): - qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim)) - dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim)) + qkvs.append(ConvNorm(dim // num_heads, self.key_dim * 2 + self.val_dim, **dd)) + dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, **dd)) self.qkvs = torch.nn.ModuleList(qkvs) self.dws = torch.nn.ModuleList(dws) self.proj = torch.nn.Sequential( torch.nn.ReLU(), - ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0) + ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0, **dd) ) points = list(itertools.product(range(resolution), range(resolution))) @@ -165,8 +204,12 @@ def __init__( if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets), **dd)) + self.register_buffer( + 'attention_bias_idxs', + torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), + persistent=False, + ) self.attention_bias_cache = {} @torch.no_grad() @@ -222,14 +265,17 @@ class LocalWindowAttention(torch.nn.Module): """ def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=14, - window_resolution=7, - kernels=(5, 5, 5, 5), + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: int = 14, + window_resolution: int = 7, + kernels: Tuple[int, ...] 
= (5, 5, 5, 5), + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.num_heads = num_heads @@ -242,6 +288,7 @@ def __init__( attn_ratio=attn_ratio, resolution=window_resolution, kernels=kernels, + **dd, ) def forward(self, x): @@ -287,18 +334,21 @@ class EfficientVitBlock(torch.nn.Module): """ def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=14, - window_resolution=7, - kernels=[5, 5, 5, 5], + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: int = 14, + window_resolution: int = 7, + kernels: List[int] = [5, 5, 5, 5], + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) - self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd)) + self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd)) self.mixer = ResidualDrop( LocalWindowAttention( @@ -307,11 +357,12 @@ def __init__( resolution=resolution, window_resolution=window_resolution, kernels=kernels, - ) + **dd, + ), ) - self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) - self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2))) + self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd)) + self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd)) def forward(self, x): return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) @@ -320,17 +371,20 @@ def forward(self, x): class EfficientVitStage(torch.nn.Module): def __init__( self, - in_dim, - out_dim, - key_dim, - downsample=('', 1), - num_heads=8, - attn_ratio=4, - resolution=14, - window_resolution=7, - kernels=[5, 5, 5, 5], - depth=1, + in_dim: int, + out_dim: int, + key_dim: int, + downsample: Tuple[str, int] = ('', 1), + num_heads: int = 8, + attn_ratio: int = 4, + resolution: int = 14, + window_resolution: int = 7, + kernels: List[int] = [5, 5, 5, 5], + depth: int = 1, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() if downsample[0] == 'subsample': self.resolution = (resolution - 1) // downsample[1] + 1 @@ -338,16 +392,16 @@ def __init__( down_blocks.append(( 'res1', torch.nn.Sequential( - ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)), - ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))), + ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim, **dd)), + ResidualDrop(ConvMlp(in_dim, int(in_dim * 2), **dd)), ) )) - down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim))) + down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim, **dd))) down_blocks.append(( 'res2', torch.nn.Sequential( - ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)), - ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))), + ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim, **dd)), + ResidualDrop(ConvMlp(out_dim, int(out_dim * 2), **dd)), ) )) self.downsample = nn.Sequential(OrderedDict(down_blocks)) @@ -358,7 +412,16 @@ def __init__( blocks = [] for d in range(depth): - blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels)) + blocks.append(EfficientVitBlock( + out_dim, + key_dim, + num_heads, + attn_ratio, + self.resolution, + window_resolution, + kernels, + **dd, + )) self.blocks = nn.Sequential(*blocks) def forward(self, x): @@ -368,41 +431,51 @@ 
def forward(self, x): class PatchEmbedding(torch.nn.Sequential): - def __init__(self, in_chans, dim): + def __init__( + self, + in_chans: int, + dim: int, + device=None, + dtype=None, + ): super().__init__() - self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1)) + dd = {'device': device, 'dtype': dtype} + self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1, **dd)) self.add_module('relu1', torch.nn.ReLU()) - self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1, **dd)) self.add_module('relu2', torch.nn.ReLU()) - self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1, **dd)) self.add_module('relu3', torch.nn.ReLU()) - self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1)) + self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1, **dd)) self.patch_size = 16 class EfficientVitMsra(nn.Module): def __init__( self, - img_size=224, - in_chans=3, - num_classes=1000, - embed_dim=(64, 128, 192), - key_dim=(16, 16, 16), - depth=(1, 2, 3), - num_heads=(4, 4, 4), - window_size=(7, 7, 7), - kernels=(5, 5, 5, 5), - down_ops=(('', 1), ('subsample', 2), ('subsample', 2)), - global_pool='avg', - drop_rate=0., + img_size: int = 224, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: Tuple[int, ...] = (64, 128, 192), + key_dim: Tuple[int, ...] = (16, 16, 16), + depth: Tuple[int, ...] = (1, 2, 3), + num_heads: Tuple[int, ...] = (4, 4, 4), + window_size: Tuple[int, ...] = (7, 7, 7), + kernels: Tuple[int, ...] = (5, 5, 5, 5), + down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)), + global_pool: str = 'avg', + drop_rate: float = 0., + device=None, + dtype=None, ): - super(EfficientVitMsra, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.grad_checkpointing = False self.num_classes = num_classes self.drop_rate = drop_rate # Patch embedding - self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) + self.patch_embed = PatchEmbedding(in_chans, embed_dim[0], **dd) stride = self.patch_embed.patch_size resolution = img_size // self.patch_embed.patch_size attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] @@ -424,6 +497,7 @@ def __init__( window_resolution=wd, kernels=kernels, depth=dpth, + **dd, ) pre_ed = ed if do[0] == 'subsample' and i != 0: @@ -440,7 +514,7 @@ def __init__( self.global_pool = nn.Identity() self.num_features = self.head_hidden_size = embed_dim[-1] self.head = NormLinear( - self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() + self.num_features, num_classes, drop=self.drop_rate, **dd) if num_classes > 0 else torch.nn.Identity() @torch.jit.ignore def no_weight_decay(self): diff --git a/timm/models/eva.py b/timm/models/eva.py index f2da2ba55e..f63058af59 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -121,6 +121,8 @@ def __init__( qk_norm: bool = False, scale_norm: bool = True, rotate_half: bool = False, + device=None, + dtype=None, ): """ Args: @@ -139,6 +141,7 @@ def __init__( scale_norm: Enable normalization (scaling) of attention output with norm_layer rotate_half: Use half rotation layout instead of interleaved """ + dd = {'device': device, 'dtype': dtype} super().__init__() if scale_norm or qk_norm: assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True' @@ -154,25 +157,25 @@ def __init__( self.rotate_half = 
rotate_half if qkv_fused: - self.qkv = nn.Linear(dim, attn_dim * 3, bias=False) + self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd) self.q_proj = self.k_proj = self.v_proj = None if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(attn_dim)) - self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False) - self.v_bias = nn.Parameter(torch.zeros(attn_dim)) + self.q_bias = nn.Parameter(torch.zeros(attn_dim, **dd)) + self.register_buffer('k_bias', torch.zeros(attn_dim, **dd), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(attn_dim, **dd)) else: self.q_bias = self.k_bias = self.v_bias = None else: - self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) - self.k_proj = nn.Linear(dim, attn_dim, bias=False) - self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias) + self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) + self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd) + self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd) self.qkv = None self.q_bias = self.k_bias = self.v_bias = None - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity() - self.proj = nn.Linear(attn_dim, dim) + self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity() + self.proj = nn.Linear(attn_dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward( @@ -263,6 +266,8 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, attn_head_dim: Optional[int] = None, + device=None, + dtype=None, **kwargs, ): """ Initialize the EVA transformer block. @@ -286,8 +291,10 @@ def __init__( norm_layer: Normalization layer constructor attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + + self.norm1 = norm_layer(dim, **dd) attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention self.attn = attn_cls( dim, @@ -301,11 +308,12 @@ def __init__( norm_layer=norm_layer, scale_norm=scale_attn_inner, rotate_half=rotate_half, + **dd, ) - self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None + self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) hidden_features = int(dim * mlp_ratio) if swiglu_mlp: if scale_mlp or swiglu_align_to: @@ -316,6 +324,7 @@ def __init__( norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, align_to=swiglu_align_to, + **dd, ) else: # w/o any extra norm, an impl with packed weights is used @@ -326,6 +335,7 @@ def __init__( act_layer=nn.SiLU, gate_last=False, drop=proj_drop, + **dd, ) else: self.mlp = Mlp( @@ -334,8 +344,9 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + **dd, ) - self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward( @@ -376,6 +387,8 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm, attn_head_dim: Optional[int] = None, + device=None, + dtype=None, ): """ Initialize the post-norm EVA transformer block. @@ -398,7 +411,9 @@ def __init__( norm_layer: Normalization layer constructor attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads) """ + dd = {'device': device, 'dtype': dtype} super().__init__() + attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention self.attn = attn_cls( dim, @@ -412,8 +427,9 @@ def __init__( norm_layer=norm_layer, scale_norm=scale_attn_inner, rotate_half=rotate_half, + **dd, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() hidden_features = int(dim * mlp_ratio) @@ -426,6 +442,7 @@ def __init__( norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, align_to=swiglu_align_to, + **dd, ) else: # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP @@ -436,6 +453,7 @@ def __init__( act_layer=nn.SiLU, gate_last=False, drop=proj_drop, + **dd, ) else: self.mlp = Mlp( @@ -444,8 +462,9 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward( @@ -513,6 +532,8 @@ def __init__( dynamic_img_pad: bool = False, ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None, head_init_scale: float = 0.001, + device=None, + dtype=None, ): """Initialize the EVA Vision Transformer model. @@ -562,6 +583,7 @@ def __init__( head_init_scale: Initialization scale for classification head weights """ super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') self.num_classes = num_classes self.global_pool = global_pool @@ -594,16 +616,17 @@ def __init__( dynamic_img_pad=dynamic_img_pad, bias=not use_pre_transformer_norm, **embed_args, + **dd, ) num_patches = self.patch_embed.num_patches r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None + self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None + self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None self.cls_embed = class_token and self.reg_token is None num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens - self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None + self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens) @@ -621,6 +644,7 @@ def __init__( feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, temperature=rope_temperature, grid_indexing=rope_grid_indexing, + **dd, ) if rope_type == 'mixed': rope_kwargs.update(dict(depth=depth)) @@ -636,7 +660,7 @@ def __init__( else: self.rope = None - self.norm_pre = 
norm_layer(embed_dim) if activate_pre_norm else nn.Identity() + self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity() dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock @@ -659,12 +683,13 @@ def __init__( drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, + **dd, ) for i in range(depth)]) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] - self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity() + self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity() if global_pool == 'map': self.attn_pool = AttentionPoolLatent( @@ -673,13 +698,17 @@ def __init__( mlp_ratio=attn_pool_mlp_ratio or mlp_ratio, norm_layer=norm_layer, act_layer=nn.GELU, + **dd, ) else: self.attn_pool = None - self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity() + self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() + self.init_weights(head_init_scale=head_init_scale) + + def init_weights(self, head_init_scale=None): self.apply(self._init_weights) if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) @@ -687,9 +716,8 @@ def __init__( trunc_normal_(self.cls_token, std=.02) if self.reg_token is not None: trunc_normal_(self.reg_token, std=.02) - self.fix_init_weight() - if isinstance(self.head, nn.Linear): + if head_init_scale and isinstance(self.head, nn.Linear): trunc_normal_(self.head.weight, std=.02) self.head.weight.data.mul_(head_init_scale) self.head.bias.data.mul_(head_init_scale) diff --git a/timm/models/fasternet.py b/timm/models/fasternet.py index 747cdf6b39..f956f70ac4 100644 --- a/timm/models/fasternet.py +++ b/timm/models/fasternet.py @@ -16,7 +16,7 @@ # Licensed under the MIT License. 
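Every file touched in this diff applies the same construction pattern: the module gains optional `device`/`dtype` arguments, collects them into a `dd` dict, and forwards them to each layer, parameter, and buffer it creates, while `super(X, self).__init__()` is simplified to `super().__init__()`. The following is a minimal sketch of that pattern, not part of the diff; `ToyBlock` is a hypothetical module used only to show how the factory kwargs propagate, and the meta-device call at the end is one likely use (deferred weight materialization).

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, in_chs: int, out_chs: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # factory kwargs are forwarded to every parameter/buffer-creating layer
        self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1, bias=False, **dd)
        self.bn = nn.BatchNorm2d(out_chs, **dd)
        # directly created parameters get the same treatment (cf. LayerScale2d, attention biases)
        self.gamma = nn.Parameter(torch.ones(out_chs, 1, 1, **dd))

    def forward(self, x):
        return self.bn(self.conv(x)) * self.gamma


# parameters can be materialized on a target device/dtype at construction time,
# e.g. on the meta device so no real memory is allocated until the model is moved/initialized
block = ToyBlock(8, 16, device='meta')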
from functools import partial -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn @@ -33,11 +33,12 @@ class Partial_conv3(nn.Module): - def __init__(self, dim: int, n_div: int, forward: str): + def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim_conv3 = dim // n_div self.dim_untouched = dim - self.dim_conv3 - self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False) + self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd) if forward == 'slicing': self.forward = self.forward_slicing @@ -68,25 +69,28 @@ def __init__( mlp_ratio: float, drop_path: float, layer_scale_init_value: float, - act_layer: LayerType = partial(nn.ReLU, inplace=True), - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True), + norm_layer: Type[nn.Module] = nn.BatchNorm2d, pconv_fw_type: str = 'split_cat', + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = nn.Sequential(*[ - nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False), - norm_layer(mlp_hidden_dim), + nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd), + norm_layer(mlp_hidden_dim, **dd), act_layer(), - nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False), + nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd), ]) - self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type) + self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd) if layer_scale_init_value > 0: self.layer_scale = nn.Parameter( - layer_scale_init_value * torch.ones((dim)), requires_grad=True) + layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True) else: self.layer_scale = None @@ -112,12 +116,15 @@ def __init__( mlp_ratio: float, drop_path: float, layer_scale_init_value: float, - act_layer: LayerType = partial(nn.ReLU, inplace=True), - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True), + norm_layer: Type[nn.Module] = nn.BatchNorm2d, pconv_fw_type: str = 'split_cat', use_merge: bool = True, merge_size: Union[int, Tuple[int, int]] = 2, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False self.blocks = nn.Sequential(*[ @@ -130,6 +137,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, pconv_fw_type=pconv_fw_type, + **dd, ) for i in range(depth) ]) @@ -137,6 +145,7 @@ def __init__( dim=dim // 2, patch_size=merge_size, norm_layer=norm_layer, + **dd, ) if use_merge else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -154,11 +163,14 @@ def __init__( in_chans: int, embed_dim: int, patch_size: Union[int, Tuple[int, int]] = 4, - norm_layer: LayerType = nn.BatchNorm2d, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False) - self.norm = norm_layer(embed_dim) + self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd) + self.norm = norm_layer(embed_dim, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.norm(self.proj(x)) @@ -169,11 +181,14 @@ def __init__( self, dim: int, patch_size: Union[int, Tuple[int, int]] = 2, - 
norm_layer: LayerType = nn.BatchNorm2d, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False) - self.norm = norm_layer(2 * dim) + self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd) + self.norm = norm_layer(2 * dim, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.norm(self.reduction(x)) @@ -196,11 +211,14 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0.1, layer_scale_init_value: float = 0., - act_layer: LayerType = partial(nn.ReLU, inplace=True), - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True), + norm_layer: Type[nn.Module] = nn.BatchNorm2d, pconv_fw_type: str = 'split_cat', + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert pconv_fw_type in ('split_cat', 'slicing',) self.num_classes = num_classes self.drop_rate = drop_rate @@ -214,9 +232,10 @@ def __init__( embed_dim=embed_dim, patch_size=patch_size, norm_layer=norm_layer if patch_norm else nn.Identity, + **dd, ) # stochastic depth decay rule - dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) + dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) # build layers stages_list = [] @@ -227,13 +246,14 @@ def __init__( depth=depths[i], n_div=n_div, mlp_ratio=mlp_ratio, - drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + drop_path=dpr[i], layer_scale_init_value=layer_scale_init_value, norm_layer=norm_layer, act_layer=act_layer, pconv_fw_type=pconv_fw_type, use_merge=False if i == 0 else True, merge_size=merge_size, + **dd, ) stages_list.append(stage) self.feature_info += [dict(num_chs=dim, reduction=2**(i+2), module=f'stages.{i}')] @@ -243,10 +263,10 @@ def __init__( self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1)) self.head_hidden_size = out_chs = feature_dim # 1280 self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd) self.act = act_layer() self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(out_chs, num_classes, bias=True) if num_classes > 0 else nn.Identity() + self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity() self._initialize_weights() def _initialize_weights(self): @@ -285,12 +305,13 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 
be47d378af..bef5ec8b61 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -14,7 +14,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import ( - DropPath, calculate_drop_path_rates, + DropPath, + calculate_drop_path_rates, trunc_normal_, create_conv2d, ConvNormAct, @@ -63,6 +64,8 @@ def __init__( use_scale_branch: bool = True, num_conv_branches: int = 1, act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ) -> None: """Construct a MobileOneBlock module. @@ -79,7 +82,8 @@ def __init__( use_scale_branch: Whether to use scale branch. Default: ``True`` num_conv_branches: Number of linear conv branches. """ - super(MobileOneBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.inference_mode = inference_mode self.groups = num_groups(group_size, in_chs) self.stride = stride @@ -90,7 +94,7 @@ def __init__( self.num_conv_branches = num_conv_branches # Check if SE-ReLU is requested - self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity() + self.se = SqueezeExcite(out_chs, rd_divisor=1, **dd) if use_se else nn.Identity() if inference_mode: self.reparam_conv = create_conv2d( @@ -101,13 +105,14 @@ def __init__( dilation=dilation, groups=self.groups, bias=True, + **dd, ) else: # Re-parameterizable skip connection self.reparam_conv = None self.identity = ( - nn.BatchNorm2d(num_features=in_chs) + nn.BatchNorm2d(num_features=in_chs, **dd) if out_chs == in_chs and stride == 1 else None ) @@ -122,6 +127,7 @@ def __init__( stride=self.stride, groups=self.groups, apply_act=False, + **dd, ) for _ in range(self.num_conv_branches) ]) else: @@ -136,7 +142,8 @@ def __init__( kernel_size=1, stride=self.stride, groups=self.groups, - apply_act=False + apply_act=False, + **dd, ) self.act = act_layer() if use_act else nn.Identity() @@ -237,7 +244,8 @@ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: return kernel_final, bias_final def _fuse_bn_tensor( - self, branch: Union[nn.Sequential, nn.BatchNorm2d] + self, + branch: Union[nn.Sequential, nn.BatchNorm2d] ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with preceding conv layer. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 @@ -300,6 +308,8 @@ def __init__( use_se: bool = False, act_layer: Optional[nn.Module] = None, inference_mode: bool = False, + device=None, + dtype=None, ) -> None: """Construct a ReparamLargeKernelConv module. @@ -313,7 +323,8 @@ def __init__( act_layer: Activation module. Default: ``nn.GELU`` inference_mode: If True, instantiates model in inference mode. 
Default: ``False`` """ - super(ReparamLargeKernelConv, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.stride = stride self.groups = num_groups(group_size, in_chs) self.in_chs = in_chs @@ -330,6 +341,7 @@ def __init__( dilation=1, groups=self.groups, bias=True, + **dd, ) else: self.reparam_conv = None @@ -340,6 +352,7 @@ def __init__( stride=self.stride, groups=self.groups, apply_act=False, + **dd, ) if small_kernel is not None: assert ( @@ -352,8 +365,9 @@ def __init__( stride=self.stride, groups=self.groups, apply_act=False, + **dd, ) - self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity() + self.se = SqueezeExcite(out_chs, rd_ratio=0.25, **dd) if use_se else nn.Identity() # FIXME output of this act was not used in original impl, likely due to bug self.act = act_layer() if act_layer is not None else nn.Identity() @@ -409,7 +423,8 @@ def reparameterize(self) -> None: @staticmethod def _fuse_bn( - conv: nn.Conv2d, bn: nn.BatchNorm2d + conv: nn.Conv2d, + bn: nn.BatchNorm2d ) -> Tuple[torch.Tensor, torch.Tensor]: """Method to fuse batchnorm layer with conv layer. @@ -437,6 +452,8 @@ def convolutional_stem( act_layer: Type[nn.Module] = nn.GELU, inference_mode: bool = False, use_scale_branch: bool = True, + device=None, + dtype=None, ) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. @@ -448,6 +465,7 @@ def convolutional_stem( Returns: nn.Sequential object with stem elements. """ + dd = {'device': device, 'dtype': dtype} return nn.Sequential( MobileOneBlock( in_chs=in_chs, @@ -457,6 +475,7 @@ def convolutional_stem( act_layer=act_layer, inference_mode=inference_mode, use_scale_branch=use_scale_branch, + **dd, ), MobileOneBlock( in_chs=out_chs, @@ -467,6 +486,7 @@ def convolutional_stem( act_layer=act_layer, inference_mode=inference_mode, use_scale_branch=use_scale_branch, + **dd, ), MobileOneBlock( in_chs=out_chs, @@ -476,6 +496,7 @@ def convolutional_stem( act_layer=act_layer, inference_mode=inference_mode, use_scale_branch=use_scale_branch, + **dd, ), ) @@ -495,6 +516,8 @@ def __init__( qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, + device=None, + dtype=None, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. @@ -505,6 +528,7 @@ def __init__( attn_drop: Dropout rate for attention tensor. proj_drop: Dropout rate for projection tensor. """ + dd = {'device': device, 'dtype': dtype} super().__init__() assert dim % head_dim == 0, "dim should be divisible by head_dim" self.head_dim = head_dim @@ -512,9 +536,9 @@ def __init__( self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -561,6 +585,8 @@ def __init__( lkc_use_act: bool = False, use_se: bool = False, inference_mode: bool = False, + device=None, + dtype=None, ) -> None: """Build patch embedding layer. @@ -571,6 +597,7 @@ def __init__( embed_dim: Number of embedding dimensions. inference_mode: Flag to instantiate model in inference mode. 
Default: ``False`` """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.proj = nn.Sequential( ReparamLargeKernelConv( @@ -583,6 +610,7 @@ def __init__( use_se=use_se, act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act inference_mode=inference_mode, + **dd, ), MobileOneBlock( in_chs=embed_dim, @@ -592,6 +620,7 @@ def __init__( use_se=False, act_layer=act_layer, inference_mode=inference_mode, + **dd, ) ) @@ -601,10 +630,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + device=None, + dtype=None, + ): super().__init__() self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1)) + self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1, device=device, dtype=dtype)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma @@ -619,10 +655,12 @@ class RepMixer(nn.Module): def __init__( self, - dim, - kernel_size=3, - layer_scale_init_value=1e-5, + dim: int, + kernel_size: int = 3, + layer_scale_init_value: Optional[float] = 1e-5, inference_mode: bool = False, + device=None, + dtype=None, ): """Build RepMixer Module. @@ -632,6 +670,7 @@ def __init__( layer_scale_init_value: Initial value for layer scale. Default: 1e-5 inference_mode: If True, instantiates model in inference mode. Default: ``False`` """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.kernel_size = kernel_size @@ -646,6 +685,7 @@ def __init__( padding=self.kernel_size // 2, groups=self.dim, bias=True, + **dd, ) else: self.reparam_conv = None @@ -657,6 +697,7 @@ def __init__( use_act=False, use_scale_branch=False, num_conv_branches=0, + **dd, ) self.mixer = MobileOneBlock( dim, @@ -664,9 +705,10 @@ def __init__( kernel_size, group_size=1, use_act=False, + **dd, ) if layer_scale_init_value is not None: - self.layer_scale = LayerScale2d(dim, layer_scale_init_value) + self.layer_scale = LayerScale2d(dim, layer_scale_init_value, **dd) else: self.layer_scale = nn.Identity() @@ -732,6 +774,8 @@ def __init__( out_chs: Optional[int] = None, act_layer: Type[nn.Module] = nn.GELU, drop: float = 0.0, + device=None, + dtype=None, ) -> None: """Build convolutional FFN module. @@ -742,6 +786,7 @@ def __init__( act_layer: Activation layer. Default: ``GELU`` drop: Dropout rate. Default: ``0.0``. """ + dd = {'device': device, 'dtype': dtype} super().__init__() out_chs = out_chs or in_chs hidden_channels = hidden_channels or in_chs @@ -751,10 +796,11 @@ def __init__( kernel_size=7, groups=in_chs, apply_act=False, + **dd, ) - self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) + self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1, **dd) self.act = act_layer() - self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) + self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1, **dd) self.drop = nn.Dropout(drop) self.apply(self._init_weights) @@ -788,7 +834,9 @@ def __init__( dim: int, dim_out: Optional[int] = None, spatial_shape: Union[int, Tuple[int, int]] = (7, 7), - inference_mode=False, + inference_mode: bool = False, + device=None, + dtype=None, ) -> None: """Build reparameterizable conditional positional encoding @@ -798,7 +846,8 @@ def __init__( spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. 
Default: ``False`` """ - super(RepConditionalPosEnc, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( @@ -824,6 +873,7 @@ def __init__( padding=spatial_shape[0] // 2, groups=self.groups, bias=True, + **dd, ) else: self.reparam_conv = None @@ -835,6 +885,7 @@ def __init__( int(spatial_shape[0] // 2), groups=self.groups, bias=True, + **dd, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -907,6 +958,8 @@ def __init__( drop_path: float = 0.0, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, + device=None, + dtype=None, ): """Build RepMixer Block. @@ -920,7 +973,7 @@ def __init__( layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ - + dd = {'device': device, 'dtype': dtype} super().__init__() self.token_mixer = RepMixer( @@ -928,6 +981,7 @@ def __init__( kernel_size=kernel_size, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, + **dd, ) self.mlp = ConvMlp( @@ -935,9 +989,10 @@ def __init__( hidden_channels=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) if layer_scale_init_value is not None: - self.layer_scale = LayerScale2d(dim, layer_scale_init_value) + self.layer_scale = LayerScale2d(dim, layer_scale_init_value, **dd) else: self.layer_scale = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() @@ -964,6 +1019,8 @@ def __init__( proj_drop: float = 0.0, drop_path: float = 0.0, layer_scale_init_value: float = 1e-5, + device=None, + dtype=None, ): """Build Attention Block. @@ -976,13 +1033,13 @@ def __init__( drop_path: Drop path rate. Default: 0.0 layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 """ - + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm = norm_layer(dim) - self.token_mixer = Attention(dim=dim) + self.norm = norm_layer(dim, **dd) + self.token_mixer = Attention(dim=dim, **dd) if layer_scale_init_value is not None: - self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value) + self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value, **dd) else: self.layer_scale_1 = nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() @@ -992,9 +1049,10 @@ def __init__( hidden_channels=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) if layer_scale_init_value is not None: - self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value) + self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value, **dd) else: self.layer_scale_2 = nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() @@ -1022,10 +1080,12 @@ def __init__( act_layer: Type[nn.Module] = nn.GELU, norm_layer: Type[nn.Module] = nn.BatchNorm2d, proj_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, + drop_path_rate: Union[List[float], float] = 0.0, layer_scale_init_value: Optional[float] = 1e-5, - lkc_use_act=False, - inference_mode=False, + lkc_use_act: bool = False, + inference_mode: bool = False, + device=None, + dtype=None, ): """FastViT stage. @@ -1043,6 +1103,7 @@ def __init__( inference_mode: Flag to instantiate block in inference mode. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} self.grad_checkpointing = False if downsample: @@ -1055,13 +1116,14 @@ def __init__( act_layer=act_layer, lkc_use_act=lkc_use_act, inference_mode=inference_mode, + **dd, ) else: assert dim == dim_out self.downsample = nn.Identity() if pos_emb_layer is not None: - self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) + self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode, **dd) else: self.pos_emb = nn.Identity() @@ -1077,6 +1139,7 @@ def __init__( drop_path=drop_path_rate[block_idx], layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, + **dd, )) elif token_mixer_type == "attention": blocks.append(AttentionBlock( @@ -1087,6 +1150,7 @@ def __init__( proj_drop=proj_drop_rate, drop_path=drop_path_rate[block_idx], layer_scale_init_value=layer_scale_init_value, + **dd, )) else: raise ValueError( @@ -1137,8 +1201,11 @@ def __init__( norm_layer: Type[nn.Module] = nn.BatchNorm2d, act_layer: Type[nn.Module] = nn.GELU, inference_mode: bool = False, + device=None, + dtype=None, ) -> None: super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = 0 if fork_feat else num_classes self.fork_feat = fork_feat self.global_pool = global_pool @@ -1151,6 +1218,7 @@ def __init__( act_layer, inference_mode, use_scale_branch=stem_use_scale_branch, + **dd, ) # Build the main stages of the network architecture @@ -1179,6 +1247,7 @@ def __init__( layer_scale_init_value=layer_scale_init_value, lkc_use_act=lkc_use_act, inference_mode=inference_mode, + **dd, ) stages.append(stage) prev_dim = embed_dims[i] @@ -1202,7 +1271,7 @@ def __init__( """ layer = nn.Identity() else: - layer = norm_layer(embed_dims[i_emb]) + layer = norm_layer(embed_dims[i_emb], **dd) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) else: @@ -1218,12 +1287,14 @@ def __init__( use_se=True, act_layer=act_layer, num_conv_branches=1, + **dd, ) self.head = ClassifierHead( final_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) self.apply(self._init_weights) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 775b45dc2f..46cb9d3789 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -18,13 +18,22 @@ # Written by Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead, calculate_drop_path_rates +from timm.layers import ( + Mlp, + DropPath, + LayerNorm2d, + LayerScale2d, + trunc_normal_, + ClassifierHead, + NormMlpClassifierHead, + calculate_drop_path_rates, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint @@ -37,15 +46,18 @@ class FocalModulation(nn.Module): def __init__( self, dim: int, - focal_window, + focal_window: int, focal_level: int, focal_factor: int = 2, bias: bool = True, use_post_norm: bool = False, normalize_modulator: bool = False, proj_drop: float = 0., - norm_layer: Callable = LayerNorm2d, + norm_layer: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim 
@@ -56,11 +68,11 @@ def __init__( self.normalize_modulator = normalize_modulator self.input_split = [dim, dim, self.focal_level + 1] - self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias) - self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias, **dd) + self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias, **dd) self.act = nn.GELU() - self.proj = nn.Conv2d(dim, dim, kernel_size=1) + self.proj = nn.Conv2d(dim, dim, kernel_size=1, **dd) self.proj_drop = nn.Dropout(proj_drop) self.focal_layers = nn.ModuleList() @@ -68,11 +80,11 @@ def __init__( for k in range(self.focal_level): kernel_size = self.focal_factor * k + self.focal_window self.focal_layers.append(nn.Sequential( - nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False), + nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False, **dd), nn.GELU(), )) self.kernel_sizes.append(kernel_size) - self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity() + self.norm = norm_layer(dim, **dd) if self.use_post_norm else nn.Identity() def forward(self, x): # pre linear projection @@ -101,17 +113,6 @@ def forward(self, x): return x_out -class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - class FocalNetBlock(nn.Module): """ Focal Modulation Network Block. """ @@ -125,11 +126,13 @@ def __init__( use_post_norm: bool = False, use_post_norm_in_modulation: bool = False, normalize_modulator: bool = False, - layerscale_value: float = 1e-4, + layerscale_value: Optional[float] = 1e-4, proj_drop: float = 0., drop_path: float = 0., - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): """ Args: @@ -145,6 +148,7 @@ def __init__( act_layer: Activation layer. norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio @@ -153,7 +157,7 @@ def __init__( self.focal_level = focal_level self.use_post_norm = use_post_norm - self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity() + self.norm1 = norm_layer(dim, **dd) if not use_post_norm else nn.Identity() self.modulation = FocalModulation( dim, focal_window=focal_window, @@ -162,21 +166,23 @@ def __init__( normalize_modulator=normalize_modulator, proj_drop=proj_drop, norm_layer=norm_layer, + **dd, ) - self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity() - self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity() + self.norm1_post = norm_layer(dim, **dd) if use_post_norm else nn.Identity() + self.ls1 = LayerScale2d(dim, layerscale_value, **dd) if layerscale_value is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
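The local `LayerScale2d` definition deleted above is replaced by the shared `timm.layers.LayerScale2d`, which accepts the same factory kwargs as everything else in this PR. For reference, a sketch of the behaviour expected of that layer, essentially the removed class plus device/dtype support (the shared implementation may differ in minor details such as inplace handling):

import torch
import torch.nn as nn

class LayerScale2dSketch(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False, device=None, dtype=None):
        super().__init__()
        self.inplace = inplace
        # Learnable per-channel gain, created directly with the factory kwargs.
        self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        gamma = self.gamma.view(1, -1, 1, 1)  # broadcast over N, H, W of an NCHW tensor
        return x.mul_(gamma) if self.inplace else x * gamma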
else nn.Identity() - self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity() + self.norm2 = norm_layer(dim, **dd) if not use_post_norm else nn.Identity() self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, use_conv=True, + **dd, ) - self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity() - self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity() + self.norm2_post = norm_layer(dim, **dd) if use_post_norm else nn.Identity() + self.ls2 = LayerScale2d(dim, layerscale_value, **dd) if layerscale_value is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -211,10 +217,12 @@ def __init__( use_post_norm: bool = False, use_post_norm_in_modulation: bool = False, normalize_modulator: bool = False, - layerscale_value: float = 1e-4, + layerscale_value: Optional[float] = 1e-4, proj_drop: float = 0., - drop_path: float = 0., - norm_layer: Callable = LayerNorm2d, + drop_path: Union[float, List[float]] = 0., + norm_layer: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): """ Args: @@ -233,6 +241,7 @@ def __init__( drop_path: Stochastic depth rate. norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.depth = depth @@ -245,6 +254,7 @@ def __init__( stride=2, overlap=use_overlap_down, norm_layer=norm_layer, + **dd, ) else: self.downsample = nn.Identity() @@ -263,6 +273,7 @@ def __init__( proj_drop=proj_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, + **dd, ) for i in range(depth)]) @@ -288,7 +299,9 @@ def __init__( out_chs: int, stride: int = 4, overlap: bool = False, - norm_layer: Optional[Callable] = None, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): """ @@ -299,6 +312,7 @@ def __init__( overlap: Use overlapping convolutions if True. norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = stride padding = 0 @@ -309,8 +323,8 @@ def __init__( kernel_size, padding = 7, 2 elif stride == 2: kernel_size, padding = 3, 1 - self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding) - self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity() + self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding, **dd) + self.norm = norm_layer(out_chs, **dd) if norm_layer is not None else nn.Identity() def forward(self, x): x = self.proj(x) @@ -339,10 +353,12 @@ def __init__( head_hidden_size: Optional[int] = None, head_init_scale: float = 1.0, layerscale_value: Optional[float] = None, - drop_rate: bool = 0., - proj_drop_rate: bool = 0., - drop_path_rate: bool = 0.1, - norm_layer: Callable = partial(LayerNorm2d, eps=1e-5), + drop_rate: float = 0., + proj_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: Type[nn.Module] = partial(LayerNorm2d, eps=1e-5), + device=None, + dtype=None, ): """ Args: @@ -361,7 +377,7 @@ def __init__( norm_layer: Normalization layer. 
""" super().__init__() - + dd = {'device': device, 'dtype': dtype} self.num_layers = len(depths) embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)] @@ -375,6 +391,7 @@ def __init__( out_chs=embed_dim[0], overlap=use_overlap_down, norm_layer=norm_layer, + **dd, ) in_dim = embed_dim[0] @@ -398,6 +415,7 @@ def __init__( proj_drop=proj_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, + **dd, ) in_dim = out_dim layers += [layer] @@ -415,14 +433,16 @@ def __init__( pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, + **dd, ) else: - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, - drop_rate=drop_rate + drop_rate=drop_rate, + **dd, ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index b7569a63b5..2533ed949c 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -21,14 +21,27 @@ """ import math from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ - get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert +from timm.layers import ( + DropPath, + calculate_drop_path_rates, + to_2tuple, + to_ntuple, + Mlp, + ClassifierHead, + LayerNorm2d, + LayerScale, + get_attn, + get_act_layer, + get_norm_layer, + RelPosBias, + _assert, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function @@ -43,15 +56,18 @@ class MbConvBlock(nn.Module): """ def __init__( self, - in_chs, - out_chs=None, - expand_ratio=1.0, - attn_layer='se', - bias=False, - act_layer=nn.GELU, + in_chs: int, + out_chs: Optional[int] = None, + expand_ratio: float = 1.0, + attn_layer: str = 'se', + bias: bool = False, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - attn_kwargs = dict(act_layer=act_layer) + attn_kwargs = dict(act_layer=act_layer, **dd) if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca': attn_kwargs['rd_ratio'] = 0.25 attn_kwargs['bias'] = False @@ -59,10 +75,10 @@ def __init__( out_chs = out_chs or in_chs mid_chs = int(expand_ratio * in_chs) - self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias) + self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias, **dd) self.act = act_layer() self.se = attn_layer(mid_chs, **attn_kwargs) - self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias) + self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias, **dd) def forward(self, x): shortcut = x @@ -77,27 +93,30 @@ def forward(self, x): class Downsample2d(nn.Module): def __init__( self, - dim, - dim_out=None, - reduction='conv', - act_layer=nn.GELU, - norm_layer=LayerNorm2d, # NOTE in NCHW + dim: int, + dim_out: Optional[int] = None, + reduction: str = 'conv', + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm2d, # NOTE in NCHW + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim - 
self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity() - self.conv_block = MbConvBlock(dim, act_layer=act_layer) + self.norm1 = norm_layer(dim, **dd) if norm_layer is not None else nn.Identity() + self.conv_block = MbConvBlock(dim, act_layer=act_layer, **dd) assert reduction in ('conv', 'max', 'avg') if reduction == 'conv': - self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False) + self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False, **dd) elif reduction == 'max': assert dim == dim_out self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: assert dim == dim_out self.reduction = nn.AvgPool2d(kernel_size=2) - self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity() + self.norm2 = norm_layer(dim_out, **dd) if norm_layer is not None else nn.Identity() def forward(self, x): x = self.norm1(x) @@ -110,11 +129,14 @@ def forward(self, x): class FeatureBlock(nn.Module): def __init__( self, - dim, - levels=0, - reduction='max', - act_layer=nn.GELU, + dim: int, + levels: int = 0, + reduction: str = 'max', + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() reductions = levels levels = max(1, levels) @@ -124,7 +146,7 @@ def __init__( pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1) self.blocks = nn.Sequential() for i in range(levels): - self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer)) + self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer, **dd)) if reductions: self.blocks.add_module(f'pool{i+1}', pool_fn()) reductions -= 1 @@ -138,12 +160,15 @@ def __init__( self, in_chs: int = 3, out_chs: int = 96, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm2d, # NOTE stem in NCHW + device=None, + dtype=None, ): super().__init__() - self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1) - self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer) + dd = {'device': device, 'dtype': dtype} + self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1, **dd) + self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer, **dd) def forward(self, x): x = self.conv1(x) @@ -162,7 +187,10 @@ def __init__( qkv_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() window_size = to_2tuple(window_size) self.window_size = window_size @@ -171,13 +199,13 @@ def __init__( self.scale = self.head_dim ** -0.5 self.use_global = use_global - self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads) + self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads, **dd) if self.use_global: - self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd) else: - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, q_global: Optional[torch.Tensor] = None): @@ -223,16 +251,6 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i return x -class LayerScale(nn.Module): - def __init__(self, dim, 
init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class GlobalContextVitBlock(nn.Module): def __init__( self, @@ -248,16 +266,19 @@ def __init__( attn_drop: float = 0., drop_path: float = 0., attn_layer: Callable = WindowAttentionGlobal, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() feat_size = to_2tuple(feat_size) window_size = to_2tuple(window_size) self.window_size = window_size self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1])) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = attn_layer( dim, num_heads=num_heads, @@ -266,13 +287,14 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) - self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.ls1 = LayerScale(dim, layer_scale, **dd) if layer_scale is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) - self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity() + self.norm2 = norm_layer(dim, **dd) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, **dd) + self.ls2 = LayerScale(dim, layer_scale, **dd) if layer_scale is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def _window_attn(self, x, q_global: Optional[torch.Tensor] = None): @@ -292,7 +314,7 @@ def forward(self, x, q_global: Optional[torch.Tensor] = None): class GlobalContextVitStage(nn.Module): def __init__( self, - dim, + dim: int, depth: int, num_heads: int, feat_size: Tuple[int, int], @@ -306,16 +328,20 @@ def __init__( proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - norm_layer_cl: Callable = LayerNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + norm_layer_cl: Type[nn.Module] = LayerNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() if downsample: self.downsample = Downsample2d( dim=dim, dim_out=dim * 2, norm_layer=norm_layer, + **dd, ) dim = dim * 2 feat_size = (feat_size[0] // 2, feat_size[1] // 2) @@ -325,8 +351,8 @@ def __init__( window_size = to_2tuple(window_size) feat_levels = int(math.log2(min(feat_size) / min(window_size))) - self.global_block = FeatureBlock(dim, feat_levels) - self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity() + self.global_block = FeatureBlock(dim, feat_levels, **dd) + self.global_norm = norm_layer_cl(dim, **dd) if global_norm else nn.Identity() self.blocks = nn.ModuleList([ GlobalContextVitBlock( @@ -343,10 +369,11 @@ def __init__( drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, act_layer=act_layer, norm_layer=norm_layer_cl, + **dd, ) for i in range(depth) ]) - self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity() + self.norm = norm_layer_cl(dim, **dd) if stage_norm else nn.Identity() self.dim = dim self.feat_size = feat_size self.grad_checkpointing = False @@ -375,9 +402,9 @@ def __init__( in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', - img_size: Tuple[int, int] = 224, + img_size: Union[int, Tuple[int, int]] = 224, window_ratio: Tuple[int, ...] = (32, 32, 16, 32), - window_size: Tuple[int, ...] = None, + window_size: Optional[Union[int, Tuple[int, ...]]] = None, embed_dim: int = 64, depths: Tuple[int, ...] = (3, 4, 19, 5), num_heads: Tuple[int, ...] 
= (2, 4, 8, 16), @@ -388,13 +415,16 @@ def __init__( proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., - weight_init='', + weight_init: str = '', act_layer: str = 'gelu', norm_layer: str = 'layernorm2d', norm_layer_cl: str = 'layernorm', norm_eps: float = 1e-5, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = get_act_layer(act_layer) norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) @@ -416,7 +446,8 @@ def __init__( in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, - norm_layer=norm_layer + norm_layer=norm_layer, + **dd, ) dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) @@ -441,12 +472,13 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, norm_layer_cl=norm_layer_cl, + **dd, )) self.feature_info += [dict(num_chs=stages[-1].dim, reduction=2**(i+2), module=f'stages.{i}')] self.stages = nn.Sequential(*stages) # Classifier head - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd) if weight_init: named_apply(partial(self._init_weights, scheme=weight_init), self) @@ -490,11 +522,12 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes if global_pool is None: global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, **dd) def forward_intermediates( self, diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 126d638f43..53d988ac96 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -11,15 +11,16 @@ """ import math from functools import partial -from typing import Any, Callable, Dict, List, Set, Optional, Tuple, Union +from typing import Any, Dict, List, Set, Optional, Tuple, Union, Type import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import SelectAdaptivePool2d, Linear, make_divisible, LayerType +from timm.layers import SelectAdaptivePool2d, Linear, make_divisible from timm.utils.model import reparameterize_model + from ._builder import build_model_with_cfg from ._efficientnet_blocks import SqueezeExcite, ConvBnAct from ._features import feature_take_indices @@ -41,22 +42,25 @@ def __init__( ratio: int = 2, dw_size: int = 3, stride: int = 1, - act_layer: LayerType = nn.ReLU, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): - super(GhostModule, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.out_chs = out_chs init_chs = math.ceil(out_chs / ratio) new_chs = init_chs * (ratio - 1) self.primary_conv = nn.Sequential( - nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), - nn.BatchNorm2d(init_chs), + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd), + 
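The `calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)` call above replaces the hand-rolled stochastic-depth schedule. A rough sketch of what such a stagewise schedule conventionally looks like, a linear ramp over all blocks split back into per-stage lists (illustrative only; the timm helper's exact behaviour may differ):

import torch

def stagewise_drop_path_rates(drop_path_rate, depths):
    total = sum(depths)
    ramp = [x.item() for x in torch.linspace(0, drop_path_rate, total)]
    out, i = [], 0
    for d in depths:
        out.append(ramp[i:i + d])  # one list of per-block rates per stage
        i += d
    return out

# e.g. the GCViT default depths=(3, 4, 19, 5)
rates = stagewise_drop_path_rates(0.2, (3, 4, 19, 5))
assert len(rates) == 4 and len(rates[2]) == 19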
nn.BatchNorm2d(init_chs, **dd), act_layer(inplace=True), ) self.cheap_operation = nn.Sequential( - nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False), - nn.BatchNorm2d(new_chs), + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False, **dd), + nn.BatchNorm2d(new_chs, **dd), act_layer(inplace=True), ) @@ -76,30 +80,33 @@ def __init__( ratio: int = 2, dw_size: int = 3, stride: int = 1, - act_layer: LayerType = nn.ReLU, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.gate_fn = nn.Sigmoid() self.out_chs = out_chs init_chs = math.ceil(out_chs / ratio) new_chs = init_chs * (ratio - 1) self.primary_conv = nn.Sequential( - nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), - nn.BatchNorm2d(init_chs), + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd), + nn.BatchNorm2d(init_chs, **dd), act_layer(inplace=True), ) self.cheap_operation = nn.Sequential( - nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False), - nn.BatchNorm2d(new_chs), + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False, **dd), + nn.BatchNorm2d(new_chs, **dd), act_layer(inplace=True), ) self.short_conv = nn.Sequential( - nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False), - nn.BatchNorm2d(out_chs), - nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False), - nn.BatchNorm2d(out_chs), - nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False), - nn.BatchNorm2d(out_chs), + nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), + nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), + nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -120,10 +127,13 @@ def __init__( ratio: int = 2, dw_size: int = 3, stride: int = 1, - act_layer: LayerType = nn.ReLU, + act_layer: Type[nn.Module] = nn.ReLU, mode: str = 'original', + device=None, + dtype=None, ): - super(GhostModuleV3, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.gate_fn = nn.Sigmoid() self.out_chs = out_chs init_chs = math.ceil(out_chs / ratio) @@ -137,28 +147,43 @@ def __init__( self.primary_rpr_skip = None self.primary_rpr_scale = None - self.primary_rpr_conv = nn.ModuleList( - [ConvBnAct(in_chs, init_chs, kernel_size, stride, pad_type=kernel_size // 2, \ - act_layer=None) for _ in range(self.num_conv_branches)] - ) + self.primary_rpr_conv = nn.ModuleList([ + ConvBnAct( + in_chs, + init_chs, + kernel_size, + stride, + pad_type=kernel_size // 2, + act_layer=None, + **dd, + ) for _ in range(self.num_conv_branches) + ]) # Re-parameterizable scale branch self.primary_activation = act_layer(inplace=True) - self.cheap_rpr_skip = nn.BatchNorm2d(init_chs) - self.cheap_rpr_conv = nn.ModuleList( - [ConvBnAct(init_chs, new_chs, dw_size, 1, pad_type=dw_size // 2, group_size=1, \ - act_layer=None) for _ in range(self.num_conv_branches)] - ) + self.cheap_rpr_skip = nn.BatchNorm2d(init_chs, **dd) + self.cheap_rpr_conv = nn.ModuleList([ + ConvBnAct( + init_chs, + new_chs, + dw_size, + 
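The GhostModule constructors above all set up the same two branches: a primary conv producing `init_chs` intrinsic features and a cheap depthwise conv deriving `new_chs` "ghost" features from them. A conceptual sketch of how those branches combine, following the GhostNet paper; the forward logic is not part of this hunk, so this is an assumption about shape bookkeeping, not timm's exact code:

import math
import torch
import torch.nn as nn

def ghost_features(x, primary_conv, cheap_operation, out_chs):
    x1 = primary_conv(x)           # intrinsic features
    x2 = cheap_operation(x1)       # cheap depthwise "ghost" features
    out = torch.cat([x1, x2], dim=1)
    return out[:, :out_chs]        # trim when ratio doesn't divide out_chs evenly

in_chs, out_chs, ratio = 16, 24, 2
init_chs = math.ceil(out_chs / ratio)
new_chs = init_chs * (ratio - 1)
primary = nn.Sequential(
    nn.Conv2d(in_chs, init_chs, 1, bias=False), nn.BatchNorm2d(init_chs), nn.ReLU(inplace=True))
cheap = nn.Sequential(
    nn.Conv2d(init_chs, new_chs, 3, padding=1, groups=init_chs, bias=False),
    nn.BatchNorm2d(new_chs), nn.ReLU(inplace=True))
y = ghost_features(torch.randn(1, in_chs, 8, 8), primary, cheap, out_chs)
assert y.shape == (1, out_chs, 8, 8)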
1, + pad_type=dw_size // 2, + group_size=1, + act_layer=None, + **dd, + ) for _ in range(self.num_conv_branches) + ]) # Re-parameterizable scale branch - self.cheap_rpr_scale = ConvBnAct(init_chs, new_chs, 1, 1, pad_type=0, group_size=1, act_layer=None) + self.cheap_rpr_scale = ConvBnAct(init_chs, new_chs, 1, 1, pad_type=0, group_size=1, act_layer=None, **dd) self.cheap_activation = act_layer(inplace=True) self.short_conv = nn.Sequential( - nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False), - nn.BatchNorm2d(out_chs), - nn.Conv2d(out_chs, out_chs, kernel_size=(1,5), stride=1, padding=(0,2), groups=out_chs, bias=False), - nn.BatchNorm2d(out_chs), - nn.Conv2d(out_chs, out_chs, kernel_size=(5,1), stride=1, padding=(2,0), groups=out_chs, bias=False), - nn.BatchNorm2d(out_chs), + nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), + nn.Conv2d(out_chs, out_chs, kernel_size=(1,5), stride=1, padding=(0,2), groups=out_chs, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), + nn.Conv2d(out_chs, out_chs, kernel_size=(5,1), stride=1, padding=(2,0), groups=out_chs, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), ) if self.mode in ['shortcut'] else nn.Identity() self.in_channels = init_chs @@ -254,9 +279,7 @@ def _fuse_bn_tensor(self, branch): device=branch.weight.device ) for i in range(self.in_channels): - kernel_value[i, i % input_dim, - self.kernel_size // 2, - self.kernel_size // 2] = 1 + kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean @@ -341,35 +364,45 @@ def __init__( out_chs: int, dw_kernel_size: int = 3, stride: int = 1, - act_layer: Callable = nn.ReLU, + act_layer: Type[nn.Module] = nn.ReLU, se_ratio: float = 0., mode: str = 'original', + device=None, + dtype=None, ): - super(GhostBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() has_se = se_ratio is not None and se_ratio > 0. 
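The `_fuse_bn_tensor` cleanups in this file touch the standard conv+BN folding used by the re-parameterizable branches (identity branches are first given an identity kernel, as in the `kernel_value[...] = 1` lines above). For reference, the arithmetic being applied, as a generic sketch rather than the exact timm helper:

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # Fold BN into the conv: y = gamma * (conv(x) - mean) / std + beta
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                            # per output channel
    kernel = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * scale
    if conv.bias is not None:
        bias = bias + conv.bias * scale
    return kernel, bias

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
kernel, bias = fuse_conv_bn(conv, bn)
fused = nn.Conv2d(8, 8, 3, padding=1)
with torch.no_grad():
    fused.weight.copy_(kernel)
    fused.bias.copy_(bias)
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)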
self.stride = stride # Point-wise expansion if mode == 'original': - self.ghost1 = GhostModule(in_chs, mid_chs, act_layer=act_layer) + self.ghost1 = GhostModule(in_chs, mid_chs, act_layer=act_layer, **dd) else: - self.ghost1 = GhostModuleV2(in_chs, mid_chs, act_layer=act_layer) + self.ghost1 = GhostModuleV2(in_chs, mid_chs, act_layer=act_layer, **dd) # Depth-wise convolution if self.stride > 1: self.conv_dw = nn.Conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) - self.bn_dw = nn.BatchNorm2d(mid_chs) + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size-1)//2, + groups=mid_chs, + bias=False, + **dd, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs, **dd) else: self.conv_dw = None self.bn_dw = None # Squeeze-and-excitation - self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else None # Point-wise linear projection - self.ghost2 = GhostModule(mid_chs, out_chs, act_layer=nn.Identity) + self.ghost2 = GhostModule(mid_chs, out_chs, act_layer=nn.Identity, **dd) # shortcut if in_chs == out_chs and self.stride == 1: @@ -377,11 +410,18 @@ def __init__( else: self.shortcut = nn.Sequential( nn.Conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), - nn.BatchNorm2d(in_chs), - nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(out_chs), + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size-1)//2, + groups=in_chs, + bias=False, + **dd, + ), + nn.BatchNorm2d(in_chs, **dd), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -416,11 +456,14 @@ def __init__( out_chs: int, dw_kernel_size: int = 3, stride: int = 1, - act_layer: LayerType = nn.ReLU, + act_layer: Type[nn.Module] = nn.ReLU, se_ratio: float = 0., mode: str = 'original', + device=None, + dtype=None, ): - super(GhostBottleneckV3, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() has_se = se_ratio is not None and se_ratio > 0. 
self.stride = stride @@ -431,16 +474,23 @@ def __init__( self.bn_dw = nn.Identity() # Point-wise expansion - self.ghost1 = GhostModuleV3(in_chs, mid_chs, act_layer=act_layer, mode=mode) + self.ghost1 = GhostModuleV3(in_chs, mid_chs, act_layer=act_layer, mode=mode, **dd) # Depth-wise convolution if self.stride > 1: - self.dw_rpr_conv = nn.ModuleList( - [ConvBnAct(mid_chs, mid_chs, dw_kernel_size, stride, pad_type=(dw_kernel_size - 1) // 2, - group_size=1, act_layer=None) for _ in range(self.num_conv_branches)] - ) + self.dw_rpr_conv = nn.ModuleList([ConvBnAct( + mid_chs, + mid_chs, + dw_kernel_size, + stride, + pad_type=(dw_kernel_size - 1) // 2, + group_size=1, + act_layer=None, + **dd, + ) for _ in range(self.num_conv_branches) + ]) # Re-parameterizable scale branch - self.dw_rpr_scale = ConvBnAct(mid_chs, mid_chs, 1, 2, pad_type=0, group_size=1, act_layer=None) + self.dw_rpr_scale = ConvBnAct(mid_chs, mid_chs, 1, 2, pad_type=0, group_size=1, act_layer=None, **dd) self.kernel_size = dw_kernel_size self.in_channels = mid_chs else: @@ -449,10 +499,10 @@ def __init__( self.dw_rpr_skip = None # Squeeze-and-excitation - self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else nn.Identity() + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else nn.Identity() # Point-wise linear projection - self.ghost2 = GhostModuleV3(mid_chs, out_chs, act_layer=nn.Identity, mode='original') + self.ghost2 = GhostModuleV3(mid_chs, out_chs, act_layer=nn.Identity, mode='original', **dd) # shortcut if in_chs == out_chs and self.stride == 1: @@ -460,11 +510,18 @@ def __init__( else: self.shortcut = nn.Sequential( nn.Conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), - nn.BatchNorm2d(in_chs), - nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(out_chs), + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size-1)//2, + groups=in_chs, + bias=False, + **dd, + ), + nn.BatchNorm2d(in_chs, **dd), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -535,9 +592,7 @@ def _fuse_bn_tensor(self, branch): device=branch.weight.device ) for i in range(self.in_channels): - kernel_value[i, i % input_dim, - self.kernel_size // 2, - self.kernel_size // 2] = 1 + kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1 self.id_tensor = kernel_value kernel = self.id_tensor running_mean = branch.running_mean @@ -586,7 +641,7 @@ def reparameterize(self): class GhostNet(nn.Module): def __init__( self, - cfgs, + cfgs: List[List[List[Union[int, float]]]], num_classes: int = 1000, width: float = 1.0, in_chans: int = 3, @@ -594,8 +649,11 @@ def __init__( global_pool: str = 'avg', drop_rate: float = 0.2, version: str = 'v1', + device=None, + dtype=None, ): - super(GhostNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' self.cfgs = cfgs @@ -607,9 +665,9 @@ def __init__( # building first layer stem_chs = make_divisible(16 * width, 4) - self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False) + self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False, **dd) self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem')) - self.bn1 = nn.BatchNorm2d(stem_chs) + self.bn1 = 
nn.BatchNorm2d(stem_chs, **dd) self.act1 = nn.ReLU(inplace=True) prev_chs = stem_chs @@ -624,7 +682,7 @@ def __init__( for k, exp_size, c, se_ratio, s in cfg: out_chs = make_divisible(c * width, 4) mid_chs = make_divisible(exp_size * width, 4) - layer_kwargs = {} + layer_kwargs = dict(**dd) if version == 'v2' and layer_idx > 1: layer_kwargs['mode'] = 'attn' if version == 'v3' and layer_idx > 1: @@ -640,7 +698,7 @@ def __init__( stage_idx += 1 out_chs = make_divisible(exp_size * width, 4) - stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) + stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1, **dd))) self.pool_dim = prev_chs = out_chs self.blocks = nn.Sequential(*stages) @@ -649,10 +707,10 @@ def __init__( self.num_features = prev_chs self.head_hidden_size = out_chs = 1280 self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True, **dd) self.act2 = nn.ReLU(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity() # FIXME init @@ -684,7 +742,10 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear( + self.head_hidden_size, num_classes, + device=self.conv_head.weight.device, dtype=self.conv_head.weight.dtype + ) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index d6eee3f76b..927b2c4dc6 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -6,7 +6,7 @@ PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py """ -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -25,12 +25,15 @@ class LearnableAffineBlock(nn.Module): def __init__( self, - scale_value=1.0, - bias_value=0.0 + scale_value: float = 1.0, + bias_value: float = 0.0, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True) - self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True) + self.scale = nn.Parameter(torch.tensor([scale_value], **dd), requires_grad=True) + self.bias = nn.Parameter(torch.tensor([bias_value], **dd), requires_grad=True) def forward(self, x): return self.scale * x + self.bias @@ -39,15 +42,18 @@ def forward(self, x): class ConvBNAct(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size, - stride=1, - groups=1, - padding='', - use_act=True, - use_lab=False + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + padding: str = '', + use_act: bool = True, + use_lab: bool = False, + device=None, + 
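`GhostNet.reset_classifier` above shows the companion pattern for layers rebuilt after construction: since no device/dtype args were stored, it borrows them from an existing parameter (`self.conv_head.weight`) so the new classifier matches the rest of the model. A generic sketch of that idea (illustrative helper, not the timm method itself):

import torch
import torch.nn as nn

def rebuild_head(model: nn.Module, in_features: int, num_classes: int) -> nn.Module:
    ref = next(model.parameters())  # any existing parameter serves as the device/dtype reference
    if num_classes > 0:
        return nn.Linear(in_features, num_classes, device=ref.device, dtype=ref.dtype)
    return nn.Identity()

backbone = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32)).to(torch.float32)
head = rebuild_head(backbone, 32, 10)
assert head.weight.dtype == next(backbone.parameters()).dtype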
dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.use_act = use_act self.use_lab = use_lab @@ -58,14 +64,15 @@ def __init__( stride=stride, padding=padding, groups=groups, + **dd, ) - self.bn = nn.BatchNorm2d(out_chs) + self.bn = nn.BatchNorm2d(out_chs, **dd) if self.use_act: self.act = nn.ReLU() else: self.act = nn.Identity() if self.use_act and self.use_lab: - self.lab = LearnableAffineBlock() + self.lab = LearnableAffineBlock(**dd) else: self.lab = nn.Identity() @@ -80,12 +87,15 @@ def forward(self, x): class LightConvBNAct(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size, - groups=1, - use_lab=False + in_chs: int, + out_chs: int, + kernel_size: int, + groups: int = 1, + use_lab: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.conv1 = ConvBNAct( in_chs, @@ -93,6 +103,7 @@ def __init__( kernel_size=1, use_act=False, use_lab=use_lab, + **dd, ) self.conv2 = ConvBNAct( out_chs, @@ -101,6 +112,7 @@ def __init__( groups=out_chs, use_act=True, use_lab=use_lab, + **dd, ) def forward(self, x): @@ -110,7 +122,8 @@ def forward(self, x): class EseModule(nn.Module): - def __init__(self, chs): + def __init__(self, chs: int, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} super().__init__() self.conv = nn.Conv2d( chs, @@ -118,6 +131,7 @@ def __init__(self, chs): kernel_size=1, stride=1, padding=0, + **dd, ) self.sigmoid = nn.Sigmoid() @@ -131,14 +145,16 @@ def forward(self, x): class StemV1(nn.Module): # for PP-HGNet - def __init__(self, stem_chs): + def __init__(self, stem_chs: List[int], device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stem = nn.Sequential(*[ ConvBNAct( stem_chs[i], stem_chs[i + 1], kernel_size=3, - stride=2 if i == 0 else 1) for i in range( + stride=2 if i == 0 else 1, + **dd) for i in range( len(stem_chs) - 1) ]) self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -151,7 +167,16 @@ def forward(self, x): class StemV2(nn.Module): # for PP-HGNetv2 - def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): + def __init__( + self, + in_chs: int, + mid_chs: int, + out_chs: int, + use_lab: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stem1 = ConvBNAct( in_chs, @@ -159,6 +184,7 @@ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): kernel_size=3, stride=2, use_lab=use_lab, + **dd, ) self.stem2a = ConvBNAct( mid_chs, @@ -166,6 +192,7 @@ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): kernel_size=2, stride=1, use_lab=use_lab, + **dd, ) self.stem2b = ConvBNAct( mid_chs // 2, @@ -173,6 +200,7 @@ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): kernel_size=2, stride=1, use_lab=use_lab, + **dd, ) self.stem3 = ConvBNAct( mid_chs * 2, @@ -180,6 +208,7 @@ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): kernel_size=3, stride=2, use_lab=use_lab, + **dd, ) self.stem4 = ConvBNAct( mid_chs, @@ -187,6 +216,7 @@ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): kernel_size=1, stride=1, use_lab=use_lab, + **dd, ) self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True) @@ -206,17 +236,20 @@ def forward(self, x): class HighPerfGpuBlock(nn.Module): def __init__( self, - in_chs, - mid_chs, - out_chs, - layer_num, - kernel_size=3, - residual=False, - light_block=False, - use_lab=False, - agg='ese', - drop_path=0., + in_chs: int, + mid_chs: int, + out_chs: int, + layer_num: int, + 
kernel_size: int = 3, + residual: bool = False, + light_block: bool = False, + use_lab: bool = False, + agg: str = 'ese', + drop_path: Union[List[float], float] = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.residual = residual @@ -229,6 +262,7 @@ def __init__( mid_chs, kernel_size=kernel_size, use_lab=use_lab, + **dd, ) ) else: @@ -239,6 +273,7 @@ def __init__( kernel_size=kernel_size, stride=1, use_lab=use_lab, + **dd, ) ) @@ -251,6 +286,7 @@ def __init__( kernel_size=1, stride=1, use_lab=use_lab, + **dd, ) aggregation_excitation_conv = ConvBNAct( out_chs // 2, @@ -258,6 +294,7 @@ def __init__( kernel_size=1, stride=1, use_lab=use_lab, + **dd, ) self.aggregation = nn.Sequential( aggregation_squeeze_conv, @@ -270,8 +307,9 @@ def __init__( kernel_size=1, stride=1, use_lab=use_lab, + **dd, ) - att = EseModule(out_chs) + att = EseModule(out_chs, **dd) self.aggregation = nn.Sequential( aggregation_conv, att, @@ -295,19 +333,22 @@ def forward(self, x): class HighPerfGpuStage(nn.Module): def __init__( self, - in_chs, - mid_chs, - out_chs, - block_num, - layer_num, - downsample=True, - stride=2, - light_block=False, - kernel_size=3, - use_lab=False, - agg='ese', - drop_path=0., + in_chs: int, + mid_chs: int, + out_chs: int, + block_num: int, + layer_num: int, + downsample: bool = True, + stride: int = 2, + light_block: bool = False, + kernel_size: int = 3, + use_lab: bool = False, + agg: str = 'ese', + drop_path: Union[List[float], float] = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.downsample = downsample if downsample: @@ -319,6 +360,7 @@ def __init__( groups=in_chs, use_act=False, use_lab=use_lab, + **dd, ) else: self.downsample = nn.Identity() @@ -337,6 +379,7 @@ def __init__( use_lab=use_lab, agg=agg, drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path, + **dd, ) ) self.blocks = nn.Sequential(*blocks_list) @@ -359,9 +402,12 @@ def __init__( pool_type: str = 'avg', drop_rate: float = 0., hidden_size: Optional[int] = 2048, - use_lab: bool = False + use_lab: bool = False, + device=None, + dtype=None, ): - super(ClassifierHead, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_features = in_features if pool_type is not None: if not pool_type: @@ -377,10 +423,11 @@ def __init__( stride=1, padding=0, bias=False, + **dd, ) act = nn.ReLU() if use_lab: - lab = LearnableAffineBlock() + lab = LearnableAffineBlock(**dd) self.last_conv = nn.Sequential(last_conv, act, lab) else: self.last_conv = nn.Sequential(last_conv, act) @@ -389,16 +436,17 @@ def __init__( self.dropout = nn.Dropout(drop_rate) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled - self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() - def reset(self, num_classes: int, pool_type: Optional[str] = None): + def reset(self, num_classes: int, pool_type: Optional[str] = None, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} if pool_type is not None: if not pool_type: assert num_classes == 0, 'Classifier head must be removed if pooling is disabled' self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled - self.fc = nn.Linear(self.num_features, num_classes) if num_classes 
> 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) @@ -423,9 +471,12 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., use_lab: bool = False, + device=None, + dtype=None, **kwargs, ): - super(HighPerfGpuNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} stem_type = cfg["stem_type"] stem_chs = cfg["stem_chs"] stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]] @@ -439,9 +490,11 @@ def __init__( in_chs=in_chans, mid_chs=stem_chs[0], out_chs=stem_chs[1], - use_lab=use_lab) + use_lab=use_lab, + **dd, + ) else: - self.stem = StemV1([in_chans] + stem_chs) + self.stem = StemV1([in_chans] + stem_chs, **dd) current_stride = 4 @@ -463,6 +516,7 @@ def __init__( use_lab=use_lab, agg='ese' if stem_type == 'v1' else 'se', drop_path=dpr[i], + **dd, )] self.num_features = out_chs if downsample: @@ -476,7 +530,8 @@ def __init__( pool_type=global_pool, drop_rate=drop_rate, hidden_size=head_hidden_size, - use_lab=use_lab + use_lab=use_lab, + **dd, ) self.head_hidden_size = self.head.num_features @@ -505,9 +560,9 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None): self.num_classes = num_classes - self.head.reset(num_classes, global_pool) + self.head.reset(num_classes, global_pool, device=device, dtype=dtype) def forward_intermediates( self, diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 0e9e86441a..81342a2fef 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -31,8 +31,19 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, calculate_drop_path_rates, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \ - _assert, get_norm_layer, to_2tuple, init_weight_vit, init_weight_jax +from timm.layers import ( + DropPath, + calculate_drop_path_rates, + Mlp, + LayerScale, + ClNormMlpClassifierHead, + use_fused_attn, + _assert, + get_norm_layer, + to_2tuple, + init_weight_vit, + init_weight_jax, +) from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg @@ -258,6 +269,8 @@ def __init__( q_stride: int = 1, window_size: int = 0, use_mask_unit_attn: bool = False, + device=None, + dtype=None, ): """ Args: @@ -267,8 +280,8 @@ def __init__( - window_size: The current (flattened) size of a mask unit *after* pooling (if any). - use_mask_unit_attn: Use Mask Unit or Global Attention. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() - self.dim = dim self.dim_out = dim_out self.heads = heads @@ -277,8 +290,8 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, 3 * dim_out) - self.proj = nn.Linear(dim_out, dim_out) + self.qkv = nn.Linear(dim, 3 * dim_out, **dd) + self.proj = nn.Linear(dim_out, dim_out, **dd) self.window_size = window_size self.use_mask_unit_attn = use_mask_unit_attn @@ -322,16 +335,19 @@ def __init__( window_size: int = 0, use_expand_proj: bool = True, use_mask_unit_attn: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.dim_out = dim_out - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) if dim != dim_out: self.do_expand = True if use_expand_proj: - self.proj = nn.Linear(dim, dim_out) + self.proj = nn.Linear(dim, dim_out, **dd) else: assert dim_out == dim * 2 self.proj = None @@ -344,14 +360,15 @@ def __init__( heads, q_stride, window_size, - use_mask_unit_attn + use_mask_unit_attn, + **dd ) - self.ls1 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity() + self.ls1 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity() - self.norm2 = norm_layer(dim_out) - self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) - self.ls2 = LayerScale(dim_out, init_values=init_values) if init_values is not None else nn.Identity() + self.norm2 = norm_layer(dim_out, **dd) + self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer, **dd) + self.ls2 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -386,9 +403,11 @@ def __init__( stride: Tuple[int, ...], padding: Tuple[int, ...], reshape: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - # Support any number of spatial dimensions self.spatial_dims = len(kernel) self.reshape = reshape @@ -398,6 +417,7 @@ def __init__( kernel_size=kernel, stride=stride, padding=padding, + **dd, ) def forward( @@ -442,15 +462,18 @@ def __init__( init_values: Optional[float] = None, fix_init: bool = True, weight_init: str = '', - norm_layer: Union[str, nn.Module] = "LayerNorm", + norm_layer: Union[str, Type[nn.Module]] = "LayerNorm", drop_rate: float = 0.0, patch_drop_rate: float = 0.0, head_init_scale: float = 0.001, sep_pos_embed: bool = False, abs_win_pos_embed: bool = False, global_pos_size: Tuple[int, int] = (14, 14), + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.grad_checkpointing = False norm_layer = get_norm_layer(norm_layer) @@ -475,6 +498,7 @@ def __init__( patch_kernel, patch_stride, patch_padding, + **dd, ) self.pos_embed: Optional[nn.Parameter] = None @@ -483,18 +507,18 @@ def __init__( self.pos_embed_temporal: Optional[nn.Parameter] = None if sep_pos_embed: self.pos_embed_spatial = nn.Parameter( - torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim) + torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim, **dd) ) self.pos_embed_temporal = nn.Parameter( - torch.zeros(1, self.tokens_spatial_shape[0], 
embed_dim) + torch.zeros(1, self.tokens_spatial_shape[0], embed_dim, **dd) ) else: if abs_win_pos_embed: # absolute win, params NCHW to make tile & interpolate more natural before add & reshape - self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size)) - self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size)) + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size, **dd)) + self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size, **dd)) else: - self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim, **dd)) # Setup roll and reroll modules self.unroll = Unroll( @@ -544,6 +568,7 @@ def __init__( window_size=flat_mu_size, use_expand_proj=use_expand_proj, use_mask_unit_attn=use_mask_unit_attn, + **dd, ) embed_dim = dim_out if i in self.stage_ends: @@ -559,6 +584,7 @@ def __init__( drop_rate=drop_rate, norm_layer=norm_layer, input_fmt='NLC', + **dd, ) # Initialize everything diff --git a/timm/models/hieradet_sam2.py b/timm/models/hieradet_sam2.py index 9e7b4634ef..33bdcbf994 100644 --- a/timm/models/hieradet_sam2.py +++ b/timm/models/hieradet_sam2.py @@ -1,15 +1,27 @@ import math from copy import deepcopy from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, ClNormMlpClassifierHead, LayerScale, \ - get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + calculate_drop_path_rates, + ClNormMlpClassifierHead, + LayerScale, + get_norm_layer, + get_act_layer, + init_weight_jax, + init_weight_vit, + to_2tuple, + use_fused_attn, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -61,12 +73,15 @@ class MultiScaleAttention(nn.Module): fused_attn: torch.jit.Final[bool] def __init__( - self, - dim: int, - dim_out: int, - num_heads: int, - q_pool: nn.Module = None, + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.dim_out = dim_out @@ -76,8 +91,8 @@ def __init__( self.fused_attn = use_fused_attn() self.q_pool = q_pool - self.qkv = nn.Linear(dim, dim_out * 3) - self.proj = nn.Linear(dim_out, dim_out) + self.qkv = nn.Linear(dim, dim_out * 3, **dd) + self.proj = nn.Linear(dim_out, dim_out, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape @@ -116,18 +131,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MultiScaleBlock(nn.Module): def __init__( - self, - dim: int, - dim_out: int, - num_heads: int, - mlp_ratio: float = 4.0, - q_stride: Optional[Tuple[int, int]] = None, - norm_layer: Union[nn.Module, str] = "LayerNorm", - act_layer: Union[nn.Module, str] = "GELU", - window_size: int = 0, - init_values: Optional[float] = None, - drop_path: float = 0.0, + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + q_stride: Optional[Tuple[int, int]] = None, + norm_layer: Union[Type[nn.Module], str] = "LayerNorm", + act_layer: Union[Type[nn.Module], str] = "GELU", + window_size: int = 0, + init_values: Optional[float] 
= None, + drop_path: float = 0.0, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) @@ -138,7 +156,7 @@ def __init__( self.q_stride = q_stride if dim != dim_out: - self.proj = nn.Linear(dim, dim_out) + self.proj = nn.Linear(dim, dim_out, **dd) else: self.proj = nn.Identity() self.pool = None @@ -150,23 +168,25 @@ def __init__( ceil_mode=False, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = MultiScaleAttention( dim, dim_out, num_heads=num_heads, q_pool=deepcopy(self.pool), + **dd, ) - self.ls1 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity() + self.ls1 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim_out) + self.norm2 = norm_layer(dim_out, **dd) self.mlp = Mlp( dim_out, int(dim_out * mlp_ratio), act_layer=act_layer, + **dd, ) - self.ls2 = LayerScale(dim_out, init_values) if init_values is not None else nn.Identity() + self.ls2 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -214,23 +234,31 @@ class HieraPatchEmbed(nn.Module): def __init__( self, - kernel_size: Tuple[int, ...] = (7, 7), - stride: Tuple[int, ...] = (4, 4), - padding: Tuple[int, ...] = (3, 3), + kernel_size: Union[int, Tuple[int, int]] = (7, 7), + stride: Union[int, Tuple[int, int]] = (4, 4), + padding: Union[str, int, Tuple[int, int]] = (3, 3), in_chans: int = 3, embed_dim: int = 768, + device=None, + dtype=None, ): """ Args: - kernel_size (Tuple): kernel size of the projection layer. - stride (Tuple): stride of the projection layer. - padding (Tuple): padding size of the projection layer. - in_chans (int): Number of input image channels. - embed_dim (int): embed_dim (int): Patch embedding dimension. + kernel_size: kernel size of the projection layer. + stride: stride of the projection layer. + padding: padding size of the projection layer. + in_chans: Number of input image channels. + embed_dim: Patch embedding dimension. """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **dd, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -252,10 +280,10 @@ def __init__( global_pool: str = 'avg', embed_dim: int = 96, # initial embed dim num_heads: int = 1, # initial number of heads - patch_kernel: Tuple[int, ...] = (7, 7), - patch_stride: Tuple[int, ...] = (4, 4), - patch_padding: Tuple[int, ...] = (3, 3), - patch_size: Optional[Tuple[int, ...]] = None, + patch_kernel: Tuple[int, int] = (7, 7), + patch_stride: Tuple[int, int] = (4, 4), + patch_padding: Tuple[int, int] = (3, 3), + patch_size: Optional[Tuple[int, int]] = None, q_pool: int = 3, # number of q_pool stages q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages stages: Tuple[int, ...] 
= (2, 3, 16, 3), # blocks per stage @@ -281,10 +309,13 @@ def __init__( head_init_scale: float = 0.001, drop_rate: float = 0.0, drop_path_rate: float = 0.0, # stochastic depth - norm_layer: Union[nn.Module, str] = "LayerNorm", - act_layer: Union[nn.Module, str] = "GELU", + norm_layer: Union[Type[nn.Module], str] = "LayerNorm", + act_layer: Union[Type[nn.Module], str] = "GELU", + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} norm_layer = get_norm_layer(norm_layer) act_layer = get_act_layer(act_layer) assert len(stages) == len(window_spec) @@ -308,6 +339,7 @@ def __init__( embed_dim=embed_dim, output_fmt='NHWC', dynamic_img_pad=True, + **dd, ) else: self.patch_embed = HieraPatchEmbed( @@ -316,14 +348,15 @@ def __init__( padding=patch_padding, in_chans=in_chans, embed_dim=embed_dim, + **dd, ) # Which blocks have global att? self.global_att_blocks = global_att_blocks # Windowed positional embedding (https://arxiv.org/abs/2311.05613) self.global_pos_size = global_pos_size - self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size)) - self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size, **dd)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0], **dd)) dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule cur_stage = 0 @@ -354,6 +387,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, init_values=init_values, + **dd, ) embed_dim = dim_out @@ -369,6 +403,7 @@ def __init__( pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, + **dd, ) # Initialize everything diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 92ee3511cf..0a7734cd94 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -9,7 +9,7 @@ Modified by Ke Sun (sunk@mail.ustc.edu.cn) """ import logging -from typing import List +from typing import Dict, List, Type, Optional, Tuple import torch import torch.nn as nn @@ -357,15 +357,18 @@ class HighResolutionModule(nn.Module): def __init__( self, - num_branches, - block_types, - num_blocks, - num_in_chs, - num_channels, - fuse_method, - multi_scale_output=True, + num_branches: int, + block_types: Type[nn.Module], + num_blocks: Tuple[int, ...], + num_in_chs: List[int], + num_channels: Tuple[int, ...], + fuse_method: str, + multi_scale_output: bool = True, + device=None, + dtype=None, ): - super(HighResolutionModule, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self._check_branches( num_branches, block_types, @@ -385,8 +388,9 @@ def __init__( block_types, num_blocks, num_channels, + **dd, ) - self.fuse_layers = self._make_fuse_layers() + self.fuse_layers = self._make_fuse_layers(**dd) self.fuse_act = nn.ReLU(False) def _check_branches(self, num_branches, block_types, num_blocks, num_in_chs, num_channels): @@ -401,31 +405,39 @@ def _check_branches(self, num_branches, block_types, num_blocks, num_in_chs, num _logger.error(error_msg) raise ValueError(error_msg) - def _make_one_branch(self, branch_index, block_type, num_blocks, num_channels, stride=1): + def _make_one_branch(self, branch_index, block_type, num_blocks, num_channels, stride=1, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} downsample = None if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block_type.expansion: downsample = 
nn.Sequential( nn.Conv2d( - self.num_in_chs[branch_index], num_channels[branch_index] * block_type.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(num_channels[branch_index] * block_type.expansion, momentum=_BN_MOMENTUM), + self.num_in_chs[branch_index], + num_channels[branch_index] * block_type.expansion, + kernel_size=1, + stride=stride, + bias=False, + **dd, + ), + nn.BatchNorm2d(num_channels[branch_index] * block_type.expansion, momentum=_BN_MOMENTUM, **dd), ) - layers = [block_type(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)] + layers = [block_type(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample, **dd)] self.num_in_chs[branch_index] = num_channels[branch_index] * block_type.expansion for i in range(1, num_blocks[branch_index]): - layers.append(block_type(self.num_in_chs[branch_index], num_channels[branch_index])) + layers.append(block_type(self.num_in_chs[branch_index], num_channels[branch_index], **dd)) return nn.Sequential(*layers) - def _make_branches(self, num_branches, block_type, num_blocks, num_channels): + def _make_branches(self, num_branches, block_type, num_blocks, num_channels, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} branches = [] for i in range(num_branches): - branches.append(self._make_one_branch(i, block_type, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block_type, num_blocks, num_channels, **dd)) return nn.ModuleList(branches) - def _make_fuse_layers(self): + def _make_fuse_layers(self, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} if self.num_branches == 1: return nn.Identity() @@ -437,8 +449,8 @@ def _make_fuse_layers(self): for j in range(num_branches): if j > i: fuse_layer.append(nn.Sequential( - nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False), - nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM), + nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False, **dd), + nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM, **dd), nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) elif j == i: fuse_layer.append(nn.Identity()) @@ -448,14 +460,14 @@ def _make_fuse_layers(self): if k == i - j - 1: num_out_chs_conv3x3 = num_in_chs[i] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM) + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False, **dd), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM, **dd) )) else: num_out_chs_conv3x3 = num_in_chs[j] conv3x3s.append(nn.Sequential( - nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM), + nn.Conv2d(num_in_chs[j], num_out_chs_conv3x3, 3, 2, 1, bias=False, **dd), + nn.BatchNorm2d(num_out_chs_conv3x3, momentum=_BN_MOMENTUM, **dd), nn.ReLU(False) )) fuse_layer.append(nn.Sequential(*conv3x3s)) @@ -488,7 +500,7 @@ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: class SequentialList(nn.Sequential): def __init__(self, *args): - super(SequentialList, self).__init__(*args) + super().__init__(*args) @torch.jit._overload_method # noqa: F811 def forward(self, x): @@ -522,55 +534,58 @@ class HighResolutionNet(nn.Module): def __init__( self, - cfg, - in_chans=3, - num_classes=1000, - output_stride=32, - global_pool='avg', - drop_rate=0.0, - head='classification', + cfg: Dict, + in_chans: int = 3, + num_classes: int = 1000, + 
output_stride: int = 32, + global_pool: str = 'avg', + drop_rate: float = 0.0, + head: str = 'classification', + device=None, + dtype=None, **kwargs, ): - super(HighResolutionNet, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_classes = num_classes assert output_stride == 32 # FIXME support dilation cfg.update(**kwargs) stem_width = cfg['stem_width'] - self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM, **dd) self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False, **dd) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM, **dd) self.act2 = nn.ReLU(inplace=True) self.stage1_cfg = cfg['stage1'] num_channels = self.stage1_cfg['num_channels'][0] block_type = block_types_dict[self.stage1_cfg['block_type']] num_blocks = self.stage1_cfg['num_blocks'][0] - self.layer1 = self._make_layer(block_type, 64, num_channels, num_blocks) + self.layer1 = self._make_layer(block_type, 64, num_channels, num_blocks, **dd) stage1_out_channel = block_type.expansion * num_channels self.stage2_cfg = cfg['stage2'] num_channels = self.stage2_cfg['num_channels'] block_type = block_types_dict[self.stage2_cfg['block_type']] num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] - self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) - self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels, **dd) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels, **dd) self.stage3_cfg = cfg['stage3'] num_channels = self.stage3_cfg['num_channels'] block_type = block_types_dict[self.stage3_cfg['block_type']] num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] - self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels, **dd) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels, **dd) self.stage4_cfg = cfg['stage4'] num_channels = self.stage4_cfg['num_channels'] block_type = block_types_dict[self.stage4_cfg['block_type']] num_channels = [num_channels[i] * block_type.expansion for i in range(len(num_channels))] - self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels, **dd) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True, **dd) self.head = head self.head_channels = None # set if _make_head called @@ -581,17 +596,19 @@ def __init__( self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head( pre_stage_channels, conv_bias=head_conv_bias, + **dd, ) self.global_pool, self.head_drop, 
self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) else: if head == 'incre': self.num_features = self.head_hidden_size = 2048 - self.incre_modules, _, _ = self._make_head(pre_stage_channels, incre_only=True) + self.incre_modules, _, _ = self._make_head(pre_stage_channels, incre_only=True, **dd) else: self.num_features = self.head_hidden_size = 256 self.incre_modules = None @@ -609,7 +626,8 @@ def __init__( self.init_weights() - def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): + def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} head_block_type = Bottleneck self.head_channels = [32, 64, 128, 256] @@ -617,7 +635,7 @@ def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): # from C, 2C, 4C, 8C to 128, 256, 512, 1024 incre_modules = [] for i, channels in enumerate(pre_stage_channels): - incre_modules.append(self._make_layer(head_block_type, channels, self.head_channels[i], 1, stride=1)) + incre_modules.append(self._make_layer(head_block_type, channels, self.head_channels[i], 1, stride=1, **dd)) incre_modules = nn.ModuleList(incre_modules) if incre_only: return incre_modules, None, None @@ -629,9 +647,15 @@ def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): out_channels = self.head_channels[i + 1] * head_block_type.expansion downsamp_module = nn.Sequential( nn.Conv2d( - in_channels=in_channels, out_channels=out_channels, - kernel_size=3, stride=2, padding=1, bias=conv_bias), - nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=conv_bias, + **dd, + ), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM, **dd), nn.ReLU(inplace=True) ) downsamp_modules.append(downsamp_module) @@ -639,15 +663,22 @@ def _make_head(self, pre_stage_channels, incre_only=False, conv_bias=True): final_layer = nn.Sequential( nn.Conv2d( - in_channels=self.head_channels[3] * head_block_type.expansion, out_channels=self.num_features, - kernel_size=1, stride=1, padding=0, bias=conv_bias), - nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), + in_channels=self.head_channels[3] * head_block_type.expansion, + out_channels=self.num_features, + kernel_size=1, + stride=1, + padding=0, + bias=conv_bias, + **dd, + ), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM, **dd), nn.ReLU(inplace=True) ) return incre_modules, downsamp_modules, final_layer - def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) @@ -656,8 +687,8 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer) if i < num_branches_pre: if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append(nn.Sequential( - nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), - nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False, **dd), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM, **dd), nn.ReLU(inplace=True))) else: 
transition_layers.append(nn.Identity()) @@ -667,29 +698,30 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer) _in_chs = num_channels_pre_layer[-1] _out_chs = num_channels_cur_layer[i] if j == i - num_branches_pre else _in_chs conv3x3s.append(nn.Sequential( - nn.Conv2d(_in_chs, _out_chs, 3, 2, 1, bias=False), - nn.BatchNorm2d(_out_chs, momentum=_BN_MOMENTUM), + nn.Conv2d(_in_chs, _out_chs, 3, 2, 1, bias=False, **dd), + nn.BatchNorm2d(_out_chs, momentum=_BN_MOMENTUM, **dd), nn.ReLU(inplace=True))) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) - def _make_layer(self, block_type, inplanes, planes, block_types, stride=1): + def _make_layer(self, block_type, inplanes, planes, block_types, stride=1, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} downsample = None if stride != 1 or inplanes != planes * block_type.expansion: downsample = nn.Sequential( - nn.Conv2d(inplanes, planes * block_type.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block_type.expansion, momentum=_BN_MOMENTUM), + nn.Conv2d(inplanes, planes * block_type.expansion, kernel_size=1, stride=stride, bias=False, **dd), + nn.BatchNorm2d(planes * block_type.expansion, momentum=_BN_MOMENTUM, **dd), ) - layers = [block_type(inplanes, planes, stride, downsample)] + layers = [block_type(inplanes, planes, stride, downsample, **dd)] inplanes = planes * block_type.expansion for i in range(1, block_types): - layers.append(block_type(inplanes, planes)) + layers.append(block_type(inplanes, planes, **dd)) return nn.Sequential(*layers) - def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): + def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True, device=None, dtype=None): num_modules = layer_config['num_modules'] num_branches = layer_config['num_branches'] num_blocks = layer_config['num_blocks'] @@ -702,8 +734,16 @@ def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True): # multi_scale_output is only used last module reset_multi_scale_output = multi_scale_output or i < num_modules - 1 modules.append(HighResolutionModule( - num_branches, block_type, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output) - ) + num_branches, + block_type, + num_blocks, + num_in_chs, + num_channels, + fuse_method, + reset_multi_scale_output, + device=device, + dtype=dtype, + )) num_in_chs = modules[-1].get_num_in_chs() return SequentialList(*modules), num_in_chs @@ -817,7 +857,7 @@ def __init__( **kwargs, ): assert feature_location in ('incre', '') - super(HighResolutionNetFeatures, self).__init__( + super().__init__( cfg, in_chans=in_chans, num_classes=num_classes, diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index 326e081732..a2a2097458 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -4,7 +4,7 @@ """ from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -25,26 +25,28 @@ class InceptionDWConv2d(nn.Module): def __init__( self, - in_chs, - square_kernel_size=3, - band_kernel_size=11, - branch_ratio=0.125, - dilation=1, + in_chs: int, + square_kernel_size: int = 3, + band_kernel_size: int = 11, + branch_ratio: float = 0.125, + dilation: int = 1, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - gc = int(in_chs * branch_ratio) # channel 
numbers of a convolution branch square_padding = get_padding(square_kernel_size, dilation=dilation) band_padding = get_padding(band_kernel_size, dilation=dilation) self.dwconv_hw = nn.Conv2d( gc, gc, square_kernel_size, - padding=square_padding, dilation=dilation, groups=gc) + padding=square_padding, dilation=dilation, groups=gc, **dd) self.dwconv_w = nn.Conv2d( gc, gc, (1, band_kernel_size), - padding=(0, band_padding), dilation=(1, dilation), groups=gc) + padding=(0, band_padding), dilation=(1, dilation), groups=gc, **dd) self.dwconv_h = nn.Conv2d( gc, gc, (band_kernel_size, 1), - padding=(band_padding, 0), dilation=(dilation, 1), groups=gc) + padding=(band_padding, 0), dilation=(dilation, 1), groups=gc, **dd) self.split_indexes = (in_chs - 3 * gc, gc, gc, gc) def forward(self, x): @@ -65,24 +67,27 @@ class ConvMlp(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.ReLU, - norm_layer=None, - bias=True, - drop=0., + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + bias: bool = True, + drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) - self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) - self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0], **dd) + self.norm = norm_layer(hidden_features, **dd) if norm_layer else nn.Identity() self.act = act_layer() self.drop = nn.Dropout(drop) - self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1], **dd) def forward(self, x): x = self.fc1(x) @@ -99,15 +104,18 @@ class MlpClassifierHead(nn.Module): def __init__( self, - in_features, - num_classes=1000, - pool_type='avg', - mlp_ratio=3, - act_layer=nn.GELU, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - drop=0., - bias=True + in_features: int, + num_classes: int = 1000, + pool_type: str = 'avg', + mlp_ratio: float = 3, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + drop: float = 0., + bias: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.use_conv = False self.in_features = in_features @@ -116,10 +124,10 @@ def __init__( assert pool_type, 'Cannot disable pooling' self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) - self.fc1 = nn.Linear(in_features * self.global_pool.feat_mult(), hidden_features, bias=bias) + self.fc1 = nn.Linear(in_features * self.global_pool.feat_mult(), hidden_features, bias=bias, **dd) self.act = act_layer() - self.norm = norm_layer(hidden_features) - self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.norm = norm_layer(hidden_features, **dd) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias, **dd) self.drop = nn.Dropout(drop) def reset(self, num_classes: int, pool_type: Optional[str] = None): @@ -148,22 +156,24 @@ class MetaNeXtBlock(nn.Module): def __init__( self, - dim, - dilation=1, - token_mixer=InceptionDWConv2d, - norm_layer=nn.BatchNorm2d, - mlp_layer=ConvMlp, - mlp_ratio=4, - act_layer=nn.GELU, - 
ls_init_value=1e-6, - drop_path=0., - + dim: int, + dilation: int = 1, + token_mixer: Type[nn.Module] = InceptionDWConv2d, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + mlp_layer: Type[nn.Module] = ConvMlp, + mlp_ratio: float = 4, + act_layer: Type[nn.Module] = nn.GELU, + ls_init_value: float = 1e-6, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.token_mixer = token_mixer(dim, dilation=dilation) - self.norm = norm_layer(dim) - self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None + self.token_mixer = token_mixer(dim, dilation=dilation, **dd) + self.norm = norm_layer(dim, **dd) + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer, **dd) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -180,29 +190,33 @@ def forward(self, x): class MetaNeXtStage(nn.Module): def __init__( self, - in_chs, - out_chs, - stride=2, - depth=2, - dilation=(1, 1), - drop_path_rates=None, - ls_init_value=1.0, - token_mixer=InceptionDWConv2d, - act_layer=nn.GELU, - norm_layer=None, - mlp_ratio=4, + in_chs: int, + out_chs: int, + stride: int = 2, + depth: int = 2, + dilation: Tuple[int, int] = (1, 1), + drop_path_rates: Optional[List[float]] = None, + ls_init_value: float = 1.0, + token_mixer: Type[nn.Module] = InceptionDWConv2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Optional[Type[nn.Module]] = None, + mlp_ratio: float = 4, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False if stride > 1 or dilation[0] != dilation[1]: self.downsample = nn.Sequential( - norm_layer(in_chs), + norm_layer(in_chs, **dd), nn.Conv2d( in_chs, out_chs, kernel_size=2, stride=stride, dilation=dilation[0], + **dd, ), ) else: @@ -220,6 +234,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, mlp_ratio=mlp_ratio, + **dd, )) self.blocks = nn.Sequential(*stage_blocks) @@ -252,22 +267,24 @@ class MetaNeXt(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - output_stride=32, - depths=(3, 3, 9, 3), - dims=(96, 192, 384, 768), - token_mixers=InceptionDWConv2d, - norm_layer=nn.BatchNorm2d, - act_layer=nn.GELU, - mlp_ratios=(4, 4, 4, 3), - drop_rate=0., - drop_path_rate=0., - ls_init_value=1e-6, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + output_stride: int = 32, + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] 
= (96, 192, 384, 768), + token_mixers: Union[Type[nn.Module], List[Type[nn.Module]]] = InceptionDWConv2d, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + mlp_ratios: Union[int, Tuple[int, ...]] = (4, 4, 4, 3), + drop_rate: float = 0., + drop_path_rate: float = 0., + ls_init_value: float = 1e-6, + device=None, + dtype=None, ): super().__init__() - + dd = {'device': device, 'dtype': dtype} num_stage = len(depths) if not isinstance(token_mixers, (list, tuple)): token_mixers = [token_mixers] * num_stage @@ -279,8 +296,8 @@ def __init__( self.feature_info = [] self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - norm_layer(dims[0]) + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, **dd), + norm_layer(dims[0], **dd) ) dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) @@ -309,11 +326,12 @@ def __init__( token_mixer=token_mixers[i], norm_layer=norm_layer, mlp_ratio=mlp_ratios[i], + **dd, )) prev_chs = out_chs self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.num_features = prev_chs - self.head = MlpClassifierHead(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate) + self.head = MlpClassifierHead(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate, **dd) self.head_hidden_size = self.head.num_features self.apply(self._init_weights) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index d691be7a8f..0e27efebeb 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -3,6 +3,7 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ from functools import partial +from typing import Type, Optional import torch import torch.nn as nn @@ -16,26 +17,32 @@ class Mixed_5b(nn.Module): - def __init__(self, conv_block=None): - super(Mixed_5b, self).__init__() + def __init__( + self, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch0 = conv_block(192, 96, kernel_size=1, stride=1) + self.branch0 = conv_block(192, 96, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(192, 48, kernel_size=1, stride=1), - conv_block(48, 64, kernel_size=5, stride=1, padding=2) + conv_block(192, 48, kernel_size=1, stride=1, **dd), + conv_block(48, 64, kernel_size=5, stride=1, padding=2, **dd) ) self.branch2 = nn.Sequential( - conv_block(192, 64, kernel_size=1, stride=1), - conv_block(64, 96, kernel_size=3, stride=1, padding=1), - conv_block(96, 96, kernel_size=3, stride=1, padding=1) + conv_block(192, 64, kernel_size=1, stride=1, **dd), + conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd), + conv_block(96, 96, kernel_size=3, stride=1, padding=1, **dd) ) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - conv_block(192, 64, kernel_size=1, stride=1) + conv_block(192, 64, kernel_size=1, stride=1, **dd) ) def forward(self, x): @@ -48,25 +55,32 @@ def forward(self, x): class Block35(nn.Module): - def __init__(self, scale=1.0, conv_block=None): - super(Block35, self).__init__() + def __init__( + self, + scale: float = 1.0, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.scale = scale conv_block = conv_block or 
ConvNormAct - self.branch0 = conv_block(320, 32, kernel_size=1, stride=1) + self.branch0 = conv_block(320, 32, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(320, 32, kernel_size=1, stride=1), - conv_block(32, 32, kernel_size=3, stride=1, padding=1) + conv_block(320, 32, kernel_size=1, stride=1, **dd), + conv_block(32, 32, kernel_size=3, stride=1, padding=1, **dd) ) self.branch2 = nn.Sequential( - conv_block(320, 32, kernel_size=1, stride=1), - conv_block(32, 48, kernel_size=3, stride=1, padding=1), - conv_block(48, 64, kernel_size=3, stride=1, padding=1) + conv_block(320, 32, kernel_size=1, stride=1, **dd), + conv_block(32, 48, kernel_size=3, stride=1, padding=1, **dd), + conv_block(48, 64, kernel_size=3, stride=1, padding=1, **dd) ) - self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1, **dd) self.act = nn.ReLU() def forward(self, x): @@ -81,16 +95,22 @@ def forward(self, x): class Mixed_6a(nn.Module): - def __init__(self, conv_block=None): - super(Mixed_6a, self).__init__() + def __init__( + self, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch0 = conv_block(320, 384, kernel_size=3, stride=2) + self.branch0 = conv_block(320, 384, kernel_size=3, stride=2, **dd) self.branch1 = nn.Sequential( - conv_block(320, 256, kernel_size=1, stride=1), - conv_block(256, 256, kernel_size=3, stride=1, padding=1), - conv_block(256, 384, kernel_size=3, stride=2) + conv_block(320, 256, kernel_size=1, stride=1, **dd), + conv_block(256, 256, kernel_size=3, stride=1, padding=1, **dd), + conv_block(256, 384, kernel_size=3, stride=2, **dd) ) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -104,20 +124,27 @@ def forward(self, x): class Block17(nn.Module): - def __init__(self, scale=1.0, conv_block=None): - super(Block17, self).__init__() + def __init__( + self, + scale: float = 1.0, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.scale = scale conv_block = conv_block or ConvNormAct - self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1) + self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(1088, 128, kernel_size=1, stride=1), - conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), - conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + conv_block(1088, 128, kernel_size=1, stride=1, **dd), + conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd), + conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd) ) - self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1, **dd) self.act = nn.ReLU() def forward(self, x): @@ -131,24 +158,30 @@ def forward(self, x): class Mixed_7a(nn.Module): - def __init__(self, conv_block=None): - super(Mixed_7a, self).__init__() + def __init__( + self, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct self.branch0 = nn.Sequential( - conv_block(1088, 256, kernel_size=1, stride=1), - conv_block(256, 384, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1, **dd), + conv_block(256, 384, kernel_size=3, stride=2, 
**dd) ) self.branch1 = nn.Sequential( - conv_block(1088, 256, kernel_size=1, stride=1), - conv_block(256, 288, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1, **dd), + conv_block(256, 288, kernel_size=3, stride=2, **dd) ) self.branch2 = nn.Sequential( - conv_block(1088, 256, kernel_size=1, stride=1), - conv_block(256, 288, kernel_size=3, stride=1, padding=1), - conv_block(288, 320, kernel_size=3, stride=2) + conv_block(1088, 256, kernel_size=1, stride=1, **dd), + conv_block(256, 288, kernel_size=3, stride=1, padding=1, **dd), + conv_block(288, 320, kernel_size=3, stride=2, **dd) ) self.branch3 = nn.MaxPool2d(3, stride=2) @@ -164,20 +197,28 @@ def forward(self, x): class Block8(nn.Module): - def __init__(self, scale=1.0, no_relu=False, conv_block=None): - super(Block8, self).__init__() + def __init__( + self, + scale: float = 1.0, + no_relu: bool = False, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.scale = scale conv_block = conv_block or ConvNormAct - self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1) + self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(2080, 192, kernel_size=1, stride=1), - conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), - conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + conv_block(2080, 192, kernel_size=1, stride=1, **dd), + conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd), + conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd) ) - self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1, **dd) self.relu = None if no_relu else nn.ReLU() def forward(self, x): @@ -194,16 +235,19 @@ def forward(self, x): class InceptionResnetV2(nn.Module): def __init__( self, - num_classes=1000, - in_chans=3, - drop_rate=0., - output_stride=32, - global_pool='avg', - norm_layer='batchnorm2d', - norm_eps=1e-3, - act_layer='relu', - ): - super(InceptionResnetV2, self).__init__() + num_classes: int = 1000, + in_chans: int = 3, + drop_rate: float = 0., + output_stride: int = 32, + global_pool: str = 'avg', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-3, + act_layer: str = 'relu', + device=None, + dtype=None, + ) -> None: + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.num_features = self.head_hidden_size = 1536 assert output_stride == 32 @@ -216,34 +260,39 @@ def __init__( act_kwargs=dict(inplace=True), ) - self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2) - self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1) - self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1) + self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2, **dd) + self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1, **dd) + self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1, **dd) self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] self.maxpool_3a = nn.MaxPool2d(3, stride=2) - self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1) - self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1) + self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1, **dd) + self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1, **dd) self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] 
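# --- Editorial aside, not part of the patch. The recurring change throughout this diff is to
# collect the factory kwargs once, dd = {'device': device, 'dtype': dtype}, and forward them to
# every child layer so parameters are created directly with the requested device/dtype instead of
# being moved afterwards. A minimal self-contained sketch of the same idiom; the class and
# variable names below are illustrative only, not from timm:
import torch
import torch.nn as nn

class TinyConvBlock(nn.Module):
    def __init__(self, in_chs: int, out_chs: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # both children receive the same factory kwargs
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=3, padding=1, bias=False, **dd)
        self.bn = nn.BatchNorm2d(out_chs, **dd)

    def forward(self, x):
        return self.bn(self.conv(x))

blk = TinyConvBlock(3, 8, dtype=torch.float16)
assert blk.conv.weight.dtype == torch.float16  # no post-hoc .to() pass needed
assert blk.bn.weight.dtype == torch.float16
# --- end editorial aside ---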
self.maxpool_5a = nn.MaxPool2d(3, stride=2) - self.mixed_5b = Mixed_5b(conv_block=conv_block) - self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block) for _ in range(10)]) + self.mixed_5b = Mixed_5b(conv_block=conv_block, **dd) + self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block, **dd) for _ in range(10)]) self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] - self.mixed_6a = Mixed_6a(conv_block=conv_block) - self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block) for _ in range(20)]) + self.mixed_6a = Mixed_6a(conv_block=conv_block, **dd) + self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block, **dd) for _ in range(20)]) self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] - self.mixed_7a = Mixed_7a(conv_block=conv_block) - self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block) for _ in range(9)]) + self.mixed_7a = Mixed_7a(conv_block=conv_block, **dd) + self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block, **dd) for _ in range(9)]) - self.block8 = Block8(no_relu=True, conv_block=conv_block) - self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1) + self.block8 = Block8(no_relu=True, conv_block=conv_block, **dd) + self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1, **dd) self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] self.global_pool, self.head_drop, self.classif = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + **dd, + ) @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index a55521c3de..fd073c72fd 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -4,6 +4,7 @@ Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE """ from functools import partial +from typing import Optional, Type import torch import torch.nn as nn @@ -21,19 +22,27 @@ class InceptionA(nn.Module): - def __init__(self, in_channels, pool_features, conv_block=None): - super(InceptionA, self).__init__() + def __init__( + self, + in_channels: int, + pool_features: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1, **dd) - self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) - self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1, **dd) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2, **dd) - self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) - self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) - self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1, **dd) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1, **dd) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1, **dd) - self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1, 
**dd) def _forward(self, x): branch1x1 = self.branch1x1(x) @@ -58,14 +67,21 @@ def forward(self, x): class InceptionB(nn.Module): - def __init__(self, in_channels, conv_block=None): - super(InceptionB, self).__init__() + def __init__( + self, + in_channels: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2, **dd) - self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) - self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) - self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1, **dd) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1, **dd) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2, **dd) def _forward(self, x): branch3x3 = self.branch3x3(x) @@ -86,23 +102,31 @@ def forward(self, x): class InceptionC(nn.Module): - def __init__(self, in_channels, channels_7x7, conv_block=None): - super(InceptionC, self).__init__() + def __init__( + self, + in_channels: int, + channels_7x7: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1, **dd) c7 = channels_7x7 - self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) - self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) - self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1, **dd) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3), **dd) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0), **dd) - self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) - self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) - self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) - self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) - self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1, **dd) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0), **dd) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3), **dd) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0), **dd) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3), **dd) - self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.branch_pool = conv_block(in_channels, 192, kernel_size=1, **dd) def _forward(self, x): branch1x1 = self.branch1x1(x) @@ -130,16 +154,23 @@ def forward(self, x): class InceptionD(nn.Module): - def __init__(self, in_channels, conv_block=None): - super(InceptionD, self).__init__() + def __init__( + self, + in_channels: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) - 
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1, **dd) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2, **dd) - self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) - self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) - self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) - self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1, **dd) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3), **dd) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0), **dd) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2, **dd) def _forward(self, x): branch3x3 = self.branch3x3_1(x) @@ -161,21 +192,28 @@ def forward(self, x): class InceptionE(nn.Module): - def __init__(self, in_channels, conv_block=None): - super(InceptionE, self).__init__() + def __init__( + self, + in_channels: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1, **dd) - self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) - self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) - self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1, **dd) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1), **dd) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0), **dd) - self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) - self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) - self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) - self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1, **dd) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1, **dd) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1), **dd) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0), **dd) - self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.branch_pool = conv_block(in_channels, 192, kernel_size=1, **dd) def _forward(self, x): branch1x1 = self.branch1x1(x) @@ -208,13 +246,21 @@ def forward(self, x): class InceptionAux(nn.Module): - def __init__(self, in_channels, num_classes, conv_block=None): - super(InceptionAux, self).__init__() + def __init__( + self, + in_channels: int, + num_classes: int, + conv_block: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() conv_block = conv_block or ConvNormAct - self.conv0 = conv_block(in_channels, 128, kernel_size=1) - self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv0 = conv_block(in_channels, 128, kernel_size=1, **dd) + self.conv1 = conv_block(128, 768, kernel_size=5, **dd) self.conv1.stddev = 0.01 - self.fc = Linear(768, num_classes) + self.fc = Linear(768, num_classes, **dd) self.fc.stddev = 0.001 def forward(self, x): @@ -242,16 +288,19 @@ class InceptionV3(nn.Module): def 
__init__( self, - num_classes=1000, - in_chans=3, - drop_rate=0., - global_pool='avg', - aux_logits=False, - norm_layer='batchnorm2d', - norm_eps=1e-3, - act_layer='relu', + num_classes: int = 1000, + in_chans: int = 3, + drop_rate: float = 0., + global_pool: str = 'avg', + aux_logits: bool = False, + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-3, + act_layer: str = 'relu', + device=None, + dtype=None, ): - super(InceptionV3, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.aux_logits = aux_logits conv_block = partial( @@ -263,28 +312,28 @@ def __init__( act_kwargs=dict(inplace=True), ) - self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2) - self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) - self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2, **dd) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3, **dd) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1, **dd) self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) - self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) - self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1, **dd) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3, **dd) self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) - self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block) - self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block) - self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block) - self.Mixed_6a = InceptionB(288, conv_block=conv_block) - self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block) - self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block) - self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block) - self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block) + self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block, **dd) + self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block, **dd) + self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block, **dd) + self.Mixed_6a = InceptionB(288, conv_block=conv_block, **dd) + self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block, **dd) + self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd) + self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block, **dd) + self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block, **dd) if aux_logits: - self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block) + self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block, **dd) else: self.AuxLogits = None - self.Mixed_7a = InceptionD(768, conv_block=conv_block) - self.Mixed_7b = InceptionE(1280, conv_block=conv_block) - self.Mixed_7c = InceptionE(2048, conv_block=conv_block) + self.Mixed_7a = InceptionD(768, conv_block=conv_block, **dd) + self.Mixed_7b = InceptionE(1280, conv_block=conv_block, **dd) + self.Mixed_7c = InceptionE(2048, conv_block=conv_block, **dd) self.feature_info = [ dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), @@ -299,6 +348,7 @@ def __init__( self.num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) for m in self.modules(): diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 
cadbf95284..9c06e1620a 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -3,7 +3,7 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) """ from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -18,10 +18,16 @@ class Mixed3a(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(Mixed3a, self).__init__() + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.maxpool = nn.MaxPool2d(3, stride=2) - self.conv = conv_block(64, 96, kernel_size=3, stride=2) + self.conv = conv_block(64, 96, kernel_size=3, stride=2, **dd) def forward(self, x): x0 = self.maxpool(x) @@ -31,19 +37,25 @@ def forward(self, x): class Mixed4a(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(Mixed4a, self).__init__() + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.branch0 = nn.Sequential( - conv_block(160, 64, kernel_size=1, stride=1), - conv_block(64, 96, kernel_size=3, stride=1) + conv_block(160, 64, kernel_size=1, stride=1, **dd), + conv_block(64, 96, kernel_size=3, stride=1, **dd) ) self.branch1 = nn.Sequential( - conv_block(160, 64, kernel_size=1, stride=1), - conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), - conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), - conv_block(64, 96, kernel_size=(3, 3), stride=1) + conv_block(160, 64, kernel_size=1, stride=1, **dd), + conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd), + conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd), + conv_block(64, 96, kernel_size=(3, 3), stride=1, **dd) ) def forward(self, x): @@ -54,9 +66,15 @@ def forward(self, x): class Mixed5a(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(Mixed5a, self).__init__() - self.conv = conv_block(192, 192, kernel_size=3, stride=2) + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv = conv_block(192, 192, kernel_size=3, stride=2, **dd) self.maxpool = nn.MaxPool2d(3, stride=2) def forward(self, x): @@ -67,24 +85,30 @@ def forward(self, x): class InceptionA(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(InceptionA, self).__init__() - self.branch0 = conv_block(384, 96, kernel_size=1, stride=1) + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.branch0 = conv_block(384, 96, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(384, 64, kernel_size=1, stride=1), - conv_block(64, 96, kernel_size=3, stride=1, padding=1) + conv_block(384, 64, kernel_size=1, stride=1, **dd), + conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd) ) self.branch2 = nn.Sequential( - conv_block(384, 64, kernel_size=1, stride=1), - conv_block(64, 96, kernel_size=3, stride=1, padding=1), - conv_block(96, 96, kernel_size=3, stride=1, padding=1) + conv_block(384, 64, kernel_size=1, stride=1, **dd), + conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd), + conv_block(96, 96, kernel_size=3, stride=1, padding=1, **dd) ) 
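# --- Editorial aside, not part of the patch. One payoff of threading device/dtype through every
# constructor is that layers can be built on the 'meta' device: shapes and dtypes are recorded but
# no real weight storage is allocated, which helps when inspecting large models or deferring weight
# loading. A minimal sketch with a plain torch layer (channel sizes chosen to mirror the branches
# above, purely for illustration); whether a full timm model constructs cleanly on 'meta' depends
# on its weight-init code:
import torch
import torch.nn as nn

conv = nn.Conv2d(384, 96, kernel_size=1, device='meta')  # structure only, no storage allocated
print(conv.weight.device, tuple(conv.weight.shape))      # meta (96, 384, 1, 1)
# --- end editorial aside ---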
self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - conv_block(384, 96, kernel_size=1, stride=1) + conv_block(384, 96, kernel_size=1, stride=1, **dd) ) def forward(self, x): @@ -97,14 +121,20 @@ def forward(self, x): class ReductionA(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(ReductionA, self).__init__() - self.branch0 = conv_block(384, 384, kernel_size=3, stride=2) + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.branch0 = conv_block(384, 384, kernel_size=3, stride=2, **dd) self.branch1 = nn.Sequential( - conv_block(384, 192, kernel_size=1, stride=1), - conv_block(192, 224, kernel_size=3, stride=1, padding=1), - conv_block(224, 256, kernel_size=3, stride=2) + conv_block(384, 192, kernel_size=1, stride=1, **dd), + conv_block(192, 224, kernel_size=3, stride=1, padding=1, **dd), + conv_block(224, 256, kernel_size=3, stride=2, **dd) ) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -118,27 +148,33 @@ def forward(self, x): class InceptionB(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(InceptionB, self).__init__() - self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1) + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1, **dd) self.branch1 = nn.Sequential( - conv_block(1024, 192, kernel_size=1, stride=1), - conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), - conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)) + conv_block(1024, 192, kernel_size=1, stride=1, **dd), + conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd), + conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd) ) self.branch2 = nn.Sequential( - conv_block(1024, 192, kernel_size=1, stride=1), - conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), - conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), - conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), - conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)) + conv_block(1024, 192, kernel_size=1, stride=1, **dd), + conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd), + conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd), + conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd), + conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd) ) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - conv_block(1024, 128, kernel_size=1, stride=1) + conv_block(1024, 128, kernel_size=1, stride=1, **dd) ) def forward(self, x): @@ -151,19 +187,25 @@ def forward(self, x): class ReductionB(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(ReductionB, self).__init__() + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.branch0 = nn.Sequential( - conv_block(1024, 192, kernel_size=1, stride=1), - conv_block(192, 192, kernel_size=3, stride=2) + conv_block(1024, 192, kernel_size=1, stride=1, **dd), + conv_block(192, 192, kernel_size=3, stride=2, **dd) ) self.branch1 = nn.Sequential( - conv_block(1024, 256, kernel_size=1, stride=1), - 
conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), - conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), - conv_block(320, 320, kernel_size=3, stride=2) + conv_block(1024, 256, kernel_size=1, stride=1, **dd), + conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd), + conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd), + conv_block(320, 320, kernel_size=3, stride=2, **dd) ) self.branch2 = nn.MaxPool2d(3, stride=2) @@ -177,24 +219,30 @@ def forward(self, x): class InceptionC(nn.Module): - def __init__(self, conv_block=ConvNormAct): - super(InceptionC, self).__init__() + def __init__( + self, + conv_block: Type[nn.Module] = ConvNormAct, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() - self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1) + self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1, **dd) - self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1) - self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) - self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd) + self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd) + self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd) - self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1) - self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) - self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) - self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) - self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd) + self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd) + self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd) + self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd) + self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), - conv_block(1536, 256, kernel_size=1, stride=1) + conv_block(1536, 256, kernel_size=1, stride=1, **dd) ) def forward(self, x): @@ -221,16 +269,19 @@ def forward(self, x): class InceptionV4(nn.Module): def __init__( self, - num_classes=1000, - in_chans=3, - output_stride=32, - drop_rate=0., - global_pool='avg', - norm_layer='batchnorm2d', - norm_eps=1e-3, - act_layer='relu', - ): - super(InceptionV4, self).__init__() + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + drop_rate: float = 0., + global_pool: str = 'avg', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-3, + act_layer: str = 'relu', + device=None, + dtype=None, + ) -> None: + dd = {'device': device, 'dtype': dtype} + super().__init__() assert output_stride == 32 self.num_classes = num_classes self.num_features = self.head_hidden_size = 1536 @@ -244,18 +295,18 @@ def __init__( ) features = [ - conv_block(in_chans, 32, kernel_size=3, stride=2), - conv_block(32, 32, kernel_size=3, stride=1), - conv_block(32, 64, kernel_size=3, stride=1, padding=1), - Mixed3a(conv_block), - Mixed4a(conv_block), - Mixed5a(conv_block), + conv_block(in_chans, 32, 
kernel_size=3, stride=2, **dd), + conv_block(32, 32, kernel_size=3, stride=1, **dd), + conv_block(32, 64, kernel_size=3, stride=1, padding=1, **dd), + Mixed3a(conv_block, **dd), + Mixed4a(conv_block, **dd), + Mixed5a(conv_block, **dd), ] - features += [InceptionA(conv_block) for _ in range(4)] - features += [ReductionA(conv_block)] # Mixed6a - features += [InceptionB(conv_block) for _ in range(7)] - features += [ReductionB(conv_block)] # Mixed7a - features += [InceptionC(conv_block) for _ in range(3)] + features += [InceptionA(conv_block, **dd) for _ in range(4)] + features += [ReductionA(conv_block, **dd)] # Mixed6a + features += [InceptionB(conv_block, **dd) for _ in range(7)] + features += [ReductionB(conv_block, **dd)] # Mixed7a + features += [InceptionC(conv_block, **dd) for _ in range(3)] self.features = nn.Sequential(*features) self.feature_info = [ dict(num_chs=64, reduction=2, module='features.2'), @@ -265,7 +316,12 @@ def __init__( dict(num_chs=1536, reduction=32, module='features.21'), ] self.global_pool, self.head_drop, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + **dd, + ) @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/models/levit.py b/timm/models/levit.py index a4c9ce628a..affb48db8d 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -25,7 +25,7 @@ # Copyright 2020 Ross Wightman, Apache-2.0 License from collections import OrderedDict from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -42,10 +42,22 @@ class ConvNorm(nn.Module): def __init__( - self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1): + self, + in_chs: int, + out_chs: int, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bn_weight_init: float = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False) - self.bn = nn.BatchNorm2d(out_chs) + self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False, **dd) + self.bn = nn.BatchNorm2d(out_chs, **dd) nn.init.constant_(self.bn.weight, bn_weight_init) @@ -67,10 +79,18 @@ def forward(self, x): class LinearNorm(nn.Module): - def __init__(self, in_features, out_features, bn_weight_init=1): + def __init__( + self, + in_features: int, + out_features: int, + bn_weight_init: float = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.linear = nn.Linear(in_features, out_features, bias=False) - self.bn = nn.BatchNorm1d(out_features) + self.linear = nn.Linear(in_features, out_features, bias=False, **dd) + self.bn = nn.BatchNorm1d(out_features, **dd) nn.init.constant_(self.bn.weight, bn_weight_init) @@ -91,11 +111,21 @@ def forward(self, x): class NormLinear(nn.Module): - def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + std: float = 0.02, + drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.bn = nn.BatchNorm1d(in_features) + self.bn = 
nn.BatchNorm1d(in_features, **dd) self.drop = nn.Dropout(drop) - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.linear = nn.Linear(in_features, out_features, bias=bias, **dd) trunc_normal_(self.linear.weight, std=std) if self.linear.bias is not None: @@ -121,33 +151,56 @@ def forward(self, x): class Stem8(nn.Sequential): - def __init__(self, in_chs, out_chs, act_layer): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = 8 - self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1)) + self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1, **dd)) self.add_module('act1', act_layer()) - self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1)) + self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd)) self.add_module('act2', act_layer()) - self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1)) + self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd)) class Stem16(nn.Sequential): - def __init__(self, in_chs, out_chs, act_layer): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = 16 - self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1)) + self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1, **dd)) self.add_module('act1', act_layer()) - self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1)) + self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1, **dd)) self.add_module('act2', act_layer()) - self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1)) + self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd)) self.add_module('act3', act_layer()) - self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1)) + self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd)) class Downsample(nn.Module): - def __init__(self, stride, resolution, use_pool=False): + def __init__( + self, + stride: int, + resolution: Union[int, Tuple[int, int]], + use_pool: bool = False, + device=None, + dtype=None, + ): super().__init__() self.stride = stride self.resolution = to_2tuple(resolution) @@ -168,14 +221,17 @@ class Attention(nn.Module): def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4., - resolution=14, - use_conv=False, - act_layer=nn.SiLU, + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: float = 4., + resolution: Union[int, Tuple[int, int]] = 14, + use_conv: bool = False, + act_layer: Type[nn.Module] = nn.SiLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() ln_layer = ConvNorm if use_conv else LinearNorm resolution = to_2tuple(resolution) @@ -188,14 +244,17 @@ def __init__( self.val_dim = int(attn_ratio * key_dim) self.val_attn_dim = int(attn_ratio * key_dim) * num_heads - self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2) + self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, **dd) self.proj = nn.Sequential(OrderedDict([ ('act', act_layer()), - ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0)) + 
('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0, **dd)) ])) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) - pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd)) + pos = torch.stack(ndgrid( + torch.arange(resolution[0], device=device, dtype=torch.long), + torch.arange(resolution[1], device=device, dtype=torch.long), + )).flatten(1) rel_pos = (pos[..., :, None] - pos[..., None, :]).abs() rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) @@ -247,17 +306,20 @@ class AttentionDownsample(nn.Module): def __init__( self, - in_dim, - out_dim, - key_dim, - num_heads=8, - attn_ratio=2.0, - stride=2, - resolution=14, - use_conv=False, - use_pool=False, - act_layer=nn.SiLU, + in_dim: int, + out_dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: float = 2.0, + stride: int = 2, + resolution: Union[int, Tuple[int, int]] = 14, + use_conv: bool = False, + use_pool: bool = False, + act_layer: Type[nn.Module] = nn.SiLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() resolution = to_2tuple(resolution) @@ -278,23 +340,26 @@ def __init__( kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False) else: ln_layer = LinearNorm - sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool) + sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool, **dd) - self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim) + self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, **dd) self.q = nn.Sequential(OrderedDict([ ('down', sub_layer(stride=stride)), - ('ln', ln_layer(in_dim, self.key_attn_dim)) + ('ln', ln_layer(in_dim, self.key_attn_dim, **dd)) ])) self.proj = nn.Sequential(OrderedDict([ ('act', act_layer()), - ('ln', ln_layer(self.val_attn_dim, out_dim)) + ('ln', ln_layer(self.val_attn_dim, out_dim, **dd)) ])) - self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) - k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd)) + k_pos = torch.stack(ndgrid( + torch.arange(resolution[0], device=device, dtype=torch.long), + torch.arange(resolution[1], device=device, dtype=torch.long), + )).flatten(1) q_pos = torch.stack(ndgrid( - torch.arange(0, resolution[0], step=stride), - torch.arange(0, resolution[1], step=stride) + torch.arange(0, resolution[0], step=stride, device=device, dtype=torch.long), + torch.arange(0, resolution[1], step=stride, device=device, dtype=torch.long), )).flatten(1) rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] @@ -348,22 +413,25 @@ class LevitMlp(nn.Module): """ def __init__( self, - in_features, - hidden_features=None, - out_features=None, - use_conv=False, - act_layer=nn.SiLU, - drop=0. 
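# Illustrative sketch, not part of the patch: the attention_bias_idxs buffers built in the
# Attention / AttentionDownsample hunks above map every query/key pair to a flattened
# (|dy|, |dx|) offset so the per-head bias table can be gathered for every pair; the
# torch.arange calls now carry device= and dtype=torch.long so that buffer is created
# where the model lives. The helper below assumes timm's ndgrid behaves like
# torch.meshgrid(indexing='ij'); check timm.layers before relying on that equivalence.
import torch

def rel_pos_index(h: int, w: int, device=None) -> torch.Tensor:
    pos = torch.stack(torch.meshgrid(
        torch.arange(h, device=device, dtype=torch.long),
        torch.arange(w, device=device, dtype=torch.long),
        indexing='ij',
    )).flatten(1)                                      # (2, h*w) absolute coords
    rel = (pos[:, :, None] - pos[:, None, :]).abs()    # (2, h*w, h*w) |dy|, |dx|
    return rel[0] * w + rel[1]                         # flattened offset per q/k pair

idx = rel_pos_index(4, 4)
print(idx.shape, int(idx.max()))                       # torch.Size([16, 16]) 15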
+ in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + use_conv: bool = False, + act_layer: Type[nn.Module] = nn.SiLU, + drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features ln_layer = ConvNorm if use_conv else LinearNorm - self.ln1 = ln_layer(in_features, hidden_features) + self.ln1 = ln_layer(in_features, hidden_features, **dd) self.act = act_layer() self.drop = nn.Dropout(drop) - self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0) + self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0, **dd) def forward(self, x): x = self.ln1(x) @@ -376,19 +444,22 @@ def forward(self, x): class LevitDownsample(nn.Module): def __init__( self, - in_dim, - out_dim, - key_dim, - num_heads=8, - attn_ratio=4., - mlp_ratio=2., - act_layer=nn.SiLU, - attn_act_layer=None, - resolution=14, - use_conv=False, - use_pool=False, - drop_path=0., + in_dim: int, + out_dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: float = 4., + mlp_ratio: float = 2., + act_layer: Type[nn.Module] = nn.SiLU, + attn_act_layer: Optional[Type[nn.Module]] = None, + resolution: Union[int, Tuple[int, int]] = 14, + use_conv: bool = False, + use_pool: bool = False, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() attn_act_layer = attn_act_layer or act_layer @@ -402,13 +473,15 @@ def __init__( resolution=resolution, use_conv=use_conv, use_pool=use_pool, + **dd, ) self.mlp = LevitMlp( out_dim, int(out_dim * mlp_ratio), use_conv=use_conv, - act_layer=act_layer + act_layer=act_layer, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -421,17 +494,20 @@ def forward(self, x): class LevitBlock(nn.Module): def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4., - mlp_ratio=2., - resolution=14, - use_conv=False, - act_layer=nn.SiLU, - attn_act_layer=None, - drop_path=0., + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: float = 4., + mlp_ratio: float = 2., + resolution: Union[int, Tuple[int, int]] = 14, + use_conv: bool = False, + act_layer: Type[nn.Module] = nn.SiLU, + attn_act_layer: Optional[Type[nn.Module]] = None, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() attn_act_layer = attn_act_layer or act_layer @@ -443,6 +519,7 @@ def __init__( resolution=resolution, use_conv=use_conv, act_layer=attn_act_layer, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -450,7 +527,8 @@ def __init__( dim, int(dim * mlp_ratio), use_conv=use_conv, - act_layer=act_layer + act_layer=act_layer, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -463,20 +541,23 @@ def forward(self, x): class LevitStage(nn.Module): def __init__( self, - in_dim, - out_dim, - key_dim, - depth=4, - num_heads=8, - attn_ratio=4.0, - mlp_ratio=4.0, - act_layer=nn.SiLU, - attn_act_layer=None, - resolution=14, - downsample='', - use_conv=False, - drop_path=0., + in_dim: int, + out_dim: int, + key_dim: int, + depth: int = 4, + num_heads: int = 8, + attn_ratio: float = 4.0, + mlp_ratio: float = 4.0, + act_layer: Type[nn.Module] = nn.SiLU, + attn_act_layer: Optional[Type[nn.Module]] = None, + resolution: Union[int, Tuple[int, int]] = 14, + downsample: str = '', + use_conv: bool = False, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() resolution = to_2tuple(resolution) @@ -493,6 +574,7 @@ def __init__( resolution=resolution, use_conv=use_conv, drop_path=drop_path, + **dd, ) resolution = [(r - 1) // 2 + 1 for r in resolution] else: @@ -512,6 +594,7 @@ def __init__( resolution=resolution, use_conv=use_conv, drop_path=drop_path, + **dd, )] self.blocks = nn.Sequential(*blocks) @@ -530,26 +613,30 @@ class Levit(nn.Module): def __init__( self, - img_size=224, - in_chans=3, - num_classes=1000, - embed_dim=(192,), - key_dim=64, - depth=(12,), - num_heads=(3,), - attn_ratio=2., - mlp_ratio=2., - stem_backbone=None, - stem_stride=None, - stem_type='s16', - down_op='subsample', - act_layer='hard_swish', - attn_act_layer=None, - use_conv=False, - global_pool='avg', - drop_rate=0., - drop_path_rate=0.): + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: Tuple[int, ...] = (192,), + key_dim: int = 64, + depth: Tuple[int, ...] = (12,), + num_heads: Union[int, Tuple[int, ...]] = (3,), + attn_ratio: Union[float, Tuple[float, ...]] = 2., + mlp_ratio: Union[float, Tuple[float, ...]] = 2., + stem_backbone: Optional[nn.Module] = None, + stem_stride: Optional[int] = None, + stem_type: str = 's16', + down_op: str = 'subsample', + act_layer: str = 'hard_swish', + attn_act_layer: Optional[str] = None, + use_conv: bool = False, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0., + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = get_act_layer(act_layer) attn_act_layer = get_act_layer(attn_act_layer or act_layer) self.use_conv = use_conv @@ -574,9 +661,9 @@ def __init__( else: assert stem_type in ('s16', 's8') if stem_type == 's16': - self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer) + self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer, **dd) else: - self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer) + self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer, **dd) stride = self.stem.stride resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))]) @@ -597,7 +684,8 @@ def __init__( resolution=resolution, use_conv=use_conv, downsample=down_op if stage_stride == 2 else '', - drop_path=drop_path_rate + drop_path=drop_path_rate, + **dd, )] stride *= stage_stride resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution]) @@ -606,7 +694,7 @@ def __init__( self.stages = nn.Sequential(*stages) # Classifier head - self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if num_classes > 0 else nn.Identity() + self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate, **dd) if num_classes > 0 else nn.Identity() @torch.jit.ignore def no_weight_decay(self): @@ -726,7 
+814,8 @@ def forward(self, x): class LevitDistilled(Levit): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity() + dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)} + self.head_dist = NormLinear(self.num_features, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token @torch.jit.ignore diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index c33c8d75e8..991781ad4b 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -6,7 +6,7 @@ InceptionNeXt (https://github.com/sail-sg/inceptionnext) """ from collections import OrderedDict -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch from torch import nn @@ -26,30 +26,35 @@ class Stem(nn.Module): def __init__( self, - in_chs=3, - out_chs=96, + in_chs: int = 3, + out_chs: int = 96, mid_norm: bool = True, - act_layer=nn.GELU, - norm_layer=LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.conv1 = nn.Conv2d( in_chs, out_chs // 2, kernel_size=3, stride=2, - padding=1 + padding=1, + **dd, ) - self.norm1 = norm_layer(out_chs // 2) if mid_norm else None + self.norm1 = norm_layer(out_chs // 2, **dd) if mid_norm else None self.act = act_layer() self.conv2 = nn.Conv2d( out_chs // 2, out_chs, kernel_size=3, stride=2, - padding=1 + padding=1, + **dd, ) - self.norm2 = norm_layer(out_chs) + self.norm2 = norm_layer(out_chs, **dd) def forward(self, x): x = self.conv1(x) @@ -68,18 +73,22 @@ class DownsampleNormFirst(nn.Module): def __init__( self, - in_chs=96, - out_chs=198, - norm_layer=LayerNorm, + in_chs: int = 96, + out_chs: int = 198, + norm_layer: Type[nn.Module] = LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm = norm_layer(in_chs) + self.norm = norm_layer(in_chs, **dd) self.conv = nn.Conv2d( in_chs, out_chs, kernel_size=3, stride=2, - padding=1 + padding=1, + **dd, ) def forward(self, x): @@ -94,19 +103,23 @@ class Downsample(nn.Module): def __init__( self, - in_chs=96, - out_chs=198, - norm_layer=LayerNorm, + in_chs: int = 96, + out_chs: int = 198, + norm_layer: Type[nn.Module] = LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.conv = nn.Conv2d( in_chs, out_chs, kernel_size=3, stride=2, - padding=1 + padding=1, + **dd, ) - self.norm = norm_layer(out_chs) + self.norm = norm_layer(out_chs, **dd) def forward(self, x): x = x.permute(0, 3, 1, 2) @@ -122,15 +135,18 @@ class MlpHead(nn.Module): def __init__( self, - in_features, - num_classes=1000, - pool_type='avg', - act_layer=nn.GELU, - mlp_ratio=4, - norm_layer=LayerNorm, - drop_rate=0., - bias=True, + in_features: int, + num_classes: int = 1000, + pool_type: str = 'avg', + act_layer: Type[nn.Module] = nn.GELU, + mlp_ratio: Optional[int] = 4, + norm_layer: Type[nn.Module] = LayerNorm, + drop_rate: float = 0., + bias: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() if mlp_ratio is not None: hidden_size = int(mlp_ratio * in_features) @@ -140,19 +156,19 @@ def __init__( self.in_features = in_features self.hidden_size = hidden_size or in_features - 
self.norm = norm_layer(in_features) + self.norm = norm_layer(in_features, **dd) if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(in_features, hidden_size)), + ('fc', nn.Linear(in_features, hidden_size, **dd)), ('act', act_layer()), - ('norm', norm_layer(hidden_size)) + ('norm', norm_layer(hidden_size, **dd)) ])) self.num_features = hidden_size else: self.num_features = in_features self.pre_logits = nn.Identity() - self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, bias=bias, **dd) if num_classes > 0 else nn.Identity() self.head_dropout = nn.Dropout(drop_rate) def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False): @@ -187,20 +203,23 @@ class GatedConvBlock(nn.Module): def __init__( self, - dim, - expansion_ratio=8 / 3, - kernel_size=7, - conv_ratio=1.0, - ls_init_value=None, - norm_layer=LayerNorm, - act_layer=nn.GELU, - drop_path=0., + dim: int, + expansion_ratio: float = 8 / 3, + kernel_size: int = 7, + conv_ratio: float = 1.0, + ls_init_value: Optional[float] = None, + norm_layer: Type[nn.Module] = LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + drop_path: float = 0., + device=None, + dtype=None, **kwargs ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm = norm_layer(dim) + self.norm = norm_layer(dim, **dd) hidden = int(expansion_ratio * dim) - self.fc1 = nn.Linear(dim, hidden * 2) + self.fc1 = nn.Linear(dim, hidden * 2, **dd) self.act = act_layer() conv_channels = int(conv_ratio * dim) self.split_indices = (hidden, hidden - conv_channels, conv_channels) @@ -209,10 +228,11 @@ def __init__( conv_channels, kernel_size=kernel_size, padding=kernel_size // 2, - groups=conv_channels + groups=conv_channels, + **dd, ) - self.fc2 = nn.Linear(hidden, dim) - self.ls = LayerScale(dim) if ls_init_value is not None else nn.Identity() + self.fc2 = nn.Linear(hidden, dim, **dd) + self.ls = LayerScale(dim, **dd) if ls_init_value is not None else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): @@ -233,26 +253,29 @@ class MambaOutStage(nn.Module): def __init__( self, - dim, + dim: int, dim_out: Optional[int] = None, depth: int = 4, - expansion_ratio=8 / 3, - kernel_size=7, - conv_ratio=1.0, + expansion_ratio: float = 8 / 3, + kernel_size: int = 7, + conv_ratio: float = 1.0, downsample: str = '', ls_init_value: Optional[float] = None, - norm_layer=LayerNorm, - act_layer=nn.GELU, - drop_path=0., + norm_layer: Type[nn.Module] = LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim self.grad_checkpointing = False if downsample == 'conv': - self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer) + self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer, **dd) elif downsample == 'conv_nf': - self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer) + self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer, **dd) else: assert dim == dim_out self.downsample = nn.Identity() @@ -267,6 +290,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path[j] if isinstance(drop_path, (list, tuple)) else drop_path, + **dd, ) for j in range(depth) ]) @@ -299,24 +323,27 @@ class MambaOut(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - depths=(3, 3, 9, 3), - dims=(96, 192, 384, 576), - norm_layer=LayerNorm, - act_layer=nn.GELU, - conv_ratio=1.0, - expansion_ratio=8/3, - kernel_size=7, - stem_mid_norm=True, - ls_init_value=None, - downsample='conv', - drop_path_rate=0., - drop_rate=0., - head_fn='default', + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] 
= (96, 192, 384, 576), + norm_layer: Type[nn.Module] = LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + conv_ratio: float = 1.0, + expansion_ratio: float = 8/3, + kernel_size: int = 7, + stem_mid_norm: bool = True, + ls_init_value: Optional[float] = None, + downsample: str = 'conv', + drop_path_rate: float = 0., + drop_rate: float = 0., + head_fn: str = 'default', + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.output_fmt = 'NHWC' @@ -336,6 +363,7 @@ def __init__( mid_norm=stem_mid_norm, act_layer=act_layer, norm_layer=norm_layer, + **dd, ) prev_dim = dims[0] dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) @@ -358,6 +386,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, drop_path=dp_rates[i], + **dd, ) self.stages.append(stage) prev_dim = dim @@ -373,6 +402,7 @@ def __init__( pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, + **dd, ) else: # more typical norm -> pool -> fc -> act -> fc @@ -383,6 +413,7 @@ def __init__( pool_type=global_pool, norm_layer=norm_layer, drop_rate=drop_rate, + **dd, ) self.num_features = prev_dim self.head_hidden_size = self.head.num_features diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3be91f3f94..7cd332f925 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -45,10 +45,33 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, ConvMlp, DropPath, calculate_drop_path_rates, LayerNorm, ClassifierHead, NormMlpClassifierHead -from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d -from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert -from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table +from timm.layers import ( + Mlp, + ConvMlp, + DropPath, + calculate_drop_path_rates, + LayerNorm, + LayerScale, + LayerScale2d, + ClassifierHead, + NormMlpClassifierHead, + create_attn, + get_act_layer, + get_norm_layer, + get_norm_act_layer, + create_conv2d, + create_pool2d, + trunc_normal_tf_, + to_2tuple, + extend_tuple, + make_divisible, + _assert, + RelPosMlp, + RelPosBias, + RelPosBiasTf, + use_fused_attn, + resize_rel_pos_bias_table, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function @@ -157,7 +180,9 @@ def __init__( head_first: bool = True, rel_pos_cls: Optional[Callable] = None, attn_drop: float = 0., - proj_drop: float = 0. + proj_drop: float = 0., + device=None, + dtype=None, ): """ Args: @@ -171,6 +196,7 @@ def __init__( attn_drop: Attention dropout rate. proj_drop: Projection dropout rate. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim dim_attn = dim_out if expand_first else dim @@ -180,10 +206,10 @@ def __init__( self.scale = dim_head ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) - self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -238,7 +264,9 @@ def __init__( head_first: bool = True, rel_pos_cls: Optional[Callable] = None, attn_drop: float = 0., - proj_drop: float = 0. + proj_drop: float = 0., + device=None, + dtype=None, ): """ Args: @@ -252,6 +280,7 @@ def __init__( attn_drop: Attention dropout rate. proj_drop: Projection dropout rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() dim_out = dim_out or dim dim_attn = dim_out if expand_first and dim_out > dim else dim @@ -262,10 +291,10 @@ def __init__( self.scale = dim_head ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) - self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias, **dd) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads, **dd) if rel_pos_cls else None self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -306,44 +335,6 @@ def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None return x -class LayerScale(nn.Module): - """Per-channel scaling layer.""" - - def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): - """ - Args: - dim: Number of channels. - init_values: Initial scaling value. - inplace: Whether to perform inplace operations. - """ - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - gamma = self.gamma - return x.mul_(gamma) if self.inplace else x * gamma - - -class LayerScale2d(nn.Module): - """Per-channel scaling layer for 2D tensors.""" - - def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): - """ - Args: - dim: Number of channels. - init_values: Initial scaling value. - inplace: Whether to perform inplace operations. - """ - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - class Downsample2d(nn.Module): """A downsample pooling module supporting several maxpool and avgpool modes. 
@@ -360,6 +351,8 @@ def __init__( pool_type: str = 'avg2', padding: str = '', bias: bool = True, + device=None, + dtype=None, ): """ Args: @@ -382,7 +375,7 @@ def __init__( self.pool = create_pool2d('avg', 2, padding=padding or 0) if dim != dim_out: - self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, device=device, dtype=dtype) else: self.expand = nn.Identity() @@ -436,6 +429,8 @@ def __init__( rel_pos_cls: Optional[Callable] = None, cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -446,20 +441,21 @@ def __init__( cfg: Transformer block configuration. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) act_layer = get_act_layer(cfg.act_layer) if stride == 2: - self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias, **dd) self.norm1 = nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ('norm', norm_layer(dim, **dd)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type, **dd)), ])) else: assert dim == dim_out self.shortcut = nn.Identity() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention2d( dim, @@ -469,18 +465,21 @@ def __init__( bias=cfg.attn_bias, rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop + proj_drop=cfg.proj_drop, + **dd, ) - self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim_out) + self.norm2 = norm_layer(dim_out, **dd) self.mlp = ConvMlp( in_features=dim_out, hidden_features=int(dim_out * cfg.expand_ratio), act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + drop=cfg.proj_drop, + **dd, + ) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def init_weights(self, scheme: str = '') -> None: @@ -536,7 +535,9 @@ def __init__( stride: int = 1, dilation: Tuple[int, int] = (1, 1), cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - drop_path: float = 0. + drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -547,14 +548,15 @@ def __init__( cfg: Convolution block configuration. drop_path: Drop path rate. 
""" - super(MbConvBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) groups = num_groups(cfg.group_size, mid_chs) if stride == 2: self.shortcut = Downsample2d( - in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding) + in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding, **dd) else: self.shortcut = nn.Identity() @@ -570,17 +572,24 @@ def __init__( else: stride_2, dilation_2 = stride, dilation[0] - self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act, **dd) if stride_pool > 1: - self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding) + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding, **dd) else: self.down = nn.Identity() - self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) - self.norm1 = norm_act_layer(mid_chs) + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1, **dd) + self.norm1 = norm_act_layer(mid_chs, **dd) self.conv2_kxk = create_conv2d( - mid_chs, mid_chs, cfg.kernel_size, - stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding) + mid_chs, + mid_chs, + cfg.kernel_size, + stride=stride_2, + dilation=dilation_2, + groups=groups, + padding=cfg.padding, + **dd, + ) attn_kwargs = {} if isinstance(cfg.attn_layer, str): @@ -590,15 +599,15 @@ def __init__( # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) if cfg.attn_early: - self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) - self.norm2 = norm_act_layer(mid_chs) + self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd) + self.norm2 = norm_act_layer(mid_chs, **dd) self.se = None else: self.se_early = None - self.norm2 = norm_act_layer(mid_chs) - self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + self.norm2 = norm_act_layer(mid_chs, **dd) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs, **dd) - self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def init_weights(self, scheme: str = '') -> None: @@ -639,7 +648,9 @@ def __init__( dilation: Tuple[int, int] = (1, 1), cfg: MaxxVitConvCfg = MaxxVitConvCfg(), conv_mlp: bool = True, - drop_path: float = 0. + drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -652,6 +663,7 @@ def __init__( conv_mlp: Whether to use convolutional MLP. drop_path: Drop path rate. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() out_chs = out_chs or in_chs act_layer = get_act_layer(cfg.act_layer) @@ -665,9 +677,9 @@ def __init__( self.use_conv_mlp = conv_mlp if stride == 2: - self.shortcut = Downsample2d(in_chs, out_chs) + self.shortcut = Downsample2d(in_chs, out_chs, **dd) elif in_chs != out_chs: - self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias, **dd) else: self.shortcut = nn.Identity() @@ -680,19 +692,32 @@ def __init__( stride_dw = stride if stride_pool == 2: - self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, **dd) else: self.down = nn.Identity() self.conv_dw = create_conv2d( - in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], - depthwise=True, bias=cfg.output_bias) - self.norm = norm_layer(out_chs) - self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride_dw, + dilation=dilation[1], + depthwise=True, + bias=cfg.output_bias, + **dd, + ) + self.norm = norm_layer(out_chs, **dd) + self.mlp = mlp_layer( + out_chs, + int(cfg.expand_ratio * out_chs), + bias=cfg.output_bias, + act_layer=act_layer, + **dd, + ) if conv_mlp: - self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.ls = LayerScale2d(out_chs, cfg.init_values, **dd) if cfg.init_values else nn.Identity() else: - self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.ls = LayerScale(out_chs, cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -778,7 +803,10 @@ def __init__( partition_type: str = 'block', cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last act_layer = get_act_layer(cfg.act_layer) @@ -787,7 +815,7 @@ def __init__( self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = AttentionCl( dim, dim, @@ -797,17 +825,20 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, + **dd, ) - self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.ls1 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * cfg.expand_ratio), act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + drop=cfg.proj_drop, + **dd, + ) + self.ls2 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def _partition_attn(self, x): @@ -842,6 +873,8 @@ def __init__( dim: int, cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -849,6 +882,7 @@ def __init__( cfg: Transformer block configuration. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() assert dim % 2 == 0 norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last @@ -858,7 +892,7 @@ def __init__( self.partition_size = to_2tuple(cfg.window_size) rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn_block = AttentionCl( dim, dim // 2, @@ -868,6 +902,7 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, + **dd, ) self.attn_grid = AttentionCl( dim, @@ -878,18 +913,21 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, + **dd, ) - self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.ls1 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * cfg.expand_ratio), out_features=dim, act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + drop=cfg.proj_drop, + **dd, + ) + self.ls2 = LayerScale(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: @@ -963,6 +1001,8 @@ def __init__( partition_type: str = 'block', cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -971,6 +1011,7 @@ def __init__( cfg: Transformer block configuration. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last act_layer = get_act_layer(cfg.act_layer) @@ -979,7 +1020,7 @@ def __init__( self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention2d( dim, dim, @@ -989,17 +1030,20 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, + **dd, ) - self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = ConvMlp( in_features=dim, hidden_features=int(dim * cfg.expand_ratio), act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + drop=cfg.proj_drop, + **dd, + ) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values, **dd) if cfg.init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: @@ -1034,6 +1078,8 @@ def __init__( conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): """Initialize MaxxVitBlock. @@ -1045,13 +1091,14 @@ def __init__( transformer_cfg: Configuration for transformer blocks. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.nchw_attn = transformer_cfg.use_nchw_attn conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock - self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd) - attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path, **dd) partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) @@ -1091,6 +1138,8 @@ def __init__( conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), drop_path: float = 0., + device=None, + dtype=None, ): """ Args: @@ -1102,16 +1151,17 @@ def __init__( transformer_cfg: Transformer block configuration. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock if num_conv > 1: - convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] - convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path, **dd)] * (num_conv - 1) self.conv = nn.Sequential(*convs) else: - self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) - self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path, **dd) + self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path, **dd) def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_transformer, scheme=scheme), self.attn) @@ -1139,6 +1189,8 @@ def __init__( transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), drop_path: Union[float, List[float]] = 0., + device=None, + dtype=None, ): """ Args: @@ -1152,6 +1204,7 @@ def __init__( conv_cfg: Convolution block configuration. drop_path: Drop path rate(s). 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False @@ -1168,6 +1221,7 @@ def __init__( stride=block_stride, cfg=conv_cfg, drop_path=drop_path[i], + **dd, )] elif t == 'T': rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) @@ -1178,6 +1232,7 @@ def __init__( rel_pos_cls=rel_pos_cls, cfg=transformer_cfg, drop_path=drop_path[i], + **dd, )] elif t == 'M': blocks += [MaxxVitBlock( @@ -1187,6 +1242,7 @@ def __init__( conv_cfg=conv_cfg, transformer_cfg=transformer_cfg, drop_path=drop_path[i], + **dd, )] elif t == 'PM': blocks += [ParallelMaxxVitBlock( @@ -1196,6 +1252,7 @@ def __init__( conv_cfg=conv_cfg, transformer_cfg=transformer_cfg, drop_path=drop_path[i], + **dd, )] in_chs = out_chs self.blocks = nn.Sequential(*blocks) @@ -1221,6 +1278,8 @@ def __init__( act_layer: str = 'gelu', norm_layer: str = 'batchnorm2d', norm_eps: float = 1e-5, + device=None, + dtype=None, ): """ Args: @@ -1233,6 +1292,7 @@ def __init__( norm_layer: Normalization layer. norm_eps: Normalization epsilon. """ + dd = {'device': device, 'dtype': dtype} super().__init__() if not isinstance(out_chs, (list, tuple)): out_chs = to_2tuple(out_chs) @@ -1241,9 +1301,9 @@ def __init__( self.out_chs = out_chs[-1] self.stride = 2 - self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias) - self.norm1 = norm_act_layer(out_chs[0]) - self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias) + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias, **dd) + self.norm1 = norm_act_layer(out_chs[0], **dd) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias, **dd) def init_weights(self, scheme: str = '') -> None: named_apply(partial(_init_conv, scheme=scheme), self) @@ -1301,6 +1361,8 @@ def __init__( global_pool: str = 'avg', drop_rate: float = 0., drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs: Any, ): """ @@ -1315,6 +1377,7 @@ def __init__( **kwargs: Additional keyword arguments to overlay on config. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} img_size = to_2tuple(img_size) if kwargs: cfg = _overlay_kwargs(cfg, **kwargs) @@ -1334,6 +1397,7 @@ def __init__( act_layer=cfg.conv_cfg.act_layer, norm_layer=cfg.conv_cfg.norm_layer, norm_eps=cfg.conv_cfg.norm_eps, + **dd, ) stride = self.stem.stride self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')] @@ -1357,6 +1421,7 @@ def __init__( transformer_cfg=transformer_cfg, feat_size=feat_size, drop_path=dpr[i], + **dd, )] stride *= stage_stride in_chs = out_chs @@ -1374,12 +1439,19 @@ def __init__( pool_type=global_pool, drop_rate=drop_rate, norm_layer=final_norm_layer, + **dd, ) else: # standard classifier head w/ norm, pooling, fc classifier self.head_hidden_size = self.num_features - self.norm = final_norm_layer(self.num_features) - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.norm = final_norm_layer(self.num_features, **dd) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + **dd, + ) # Weight init (default PyTorch init works well for AdamW if scheme not set) assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 0a46624976..a157124694 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -28,7 +28,7 @@ from collections import OrderedDict from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -37,8 +37,17 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, calculate_drop_path_rates, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \ - use_fused_attn +from timm.layers import ( + trunc_normal_, + DropPath, + calculate_drop_path_rates, + SelectAdaptivePool2d, + GroupNorm1, + LayerNorm, + LayerNorm2d, + Mlp, + use_fused_attn, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint, checkpoint_seq @@ -55,19 +64,23 @@ class Stem(nn.Module): def __init__( self, - in_channels, - out_channels, - norm_layer=None, + in_channels: int, + out_channels: int, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=7, stride=4, - padding=2 + padding=2, + **dd, ) - self.norm = norm_layer(out_channels) if norm_layer else nn.Identity() + self.norm = norm_layer(out_channels, **dd) if norm_layer else nn.Identity() def forward(self, x): x = self.conv(x) @@ -82,21 +95,25 @@ class Downsampling(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - norm_layer=None, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm = norm_layer(in_channels) if norm_layer else nn.Identity() + self.norm = norm_layer(in_channels, **dd) if norm_layer else nn.Identity() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, - padding=padding + padding=padding, + **dd ) def 
forward(self, x): @@ -110,10 +127,19 @@ class Scale(nn.Module): Scale vector by element multiplications. """ - def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True): + def __init__( + self, + dim: int, + init_value: float = 1.0, + trainable: bool = True, + use_nchw: bool = True, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.shape = (dim, 1, 1) if use_nchw else (dim,) - self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) + self.scale = nn.Parameter(init_value * torch.ones(dim, **dd), requires_grad=trainable) def forward(self, x): return x * self.scale.view(self.shape) @@ -124,7 +150,7 @@ class SquaredReLU(nn.Module): Squared ReLU: https://arxiv.org/abs/2109.08668 """ - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super().__init__() self.relu = nn.ReLU(inplace=inplace) @@ -139,18 +165,21 @@ class StarReLU(nn.Module): def __init__( self, - scale_value=1.0, - bias_value=0.0, - scale_learnable=True, - bias_learnable=True, - mode=None, - inplace=False + scale_value: float = 1.0, + bias_value: float = 0.0, + scale_learnable: bool = True, + bias_learnable: bool = True, + mode: Optional[str] = None, + inplace: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.inplace = inplace self.relu = nn.ReLU(inplace=inplace) - self.scale = nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable) - self.bias = nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable) + self.scale = nn.Parameter(scale_value * torch.ones(1, **dd), requires_grad=scale_learnable) + self.bias = nn.Parameter(bias_value * torch.ones(1, **dd), requires_grad=bias_learnable) def forward(self, x): return self.scale * self.relu(x) ** 2 + self.bias @@ -165,15 +194,18 @@ class Attention(nn.Module): def __init__( self, - dim, - head_dim=32, - num_heads=None, - qkv_bias=False, - attn_drop=0., - proj_drop=0., - proj_bias=False, + dim: int, + head_dim: int = 32, + num_heads: Optional[int] = None, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + proj_bias: bool = False, + device=None, + dtype=None, **kwargs ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.head_dim = head_dim @@ -186,9 +218,9 @@ def __init__( self.attention_dim = self.num_heads * self.head_dim - self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) + self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -217,21 +249,21 @@ def forward(self, x): # used a custom norm with a weight term but no bias term. 
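# Illustrative sketch, not part of the patch: StarReLU above owns two learnable scalars,
# so creating them with **dd keeps them on the same device/dtype as the rest of the model
# (e.g. under device='meta' nothing is materialized). Quick check of the
# s * relu(x)**2 + b form with the default scale=1, bias=0:
import torch
import torch.nn as nn

class TinyStarReLU(nn.Module):
    def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0,
                 device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.scale = nn.Parameter(scale_value * torch.ones(1, **dd))
        self.bias = nn.Parameter(bias_value * torch.ones(1, **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.relu(x) ** 2 + self.bias

print(TinyStarReLU()(torch.tensor([-2.0, 3.0])))       # tensor([0., 9.], grad_fn=<AddBackward0>)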
class GroupNorm1NoBias(GroupNorm1): - def __init__(self, num_channels, **kwargs): + def __init__(self, num_channels: int, **kwargs): super().__init__(num_channels, **kwargs) self.eps = kwargs.get('eps', 1e-6) self.bias = None class LayerNorm2dNoBias(LayerNorm2d): - def __init__(self, num_channels, **kwargs): + def __init__(self, num_channels: int, **kwargs): super().__init__(num_channels, **kwargs) self.eps = kwargs.get('eps', 1e-6) self.bias = None class LayerNormNoBias(nn.LayerNorm): - def __init__(self, num_channels, **kwargs): + def __init__(self, num_channels: int, **kwargs): super().__init__(num_channels, **kwargs) self.eps = kwargs.get('eps', 1e-6) self.bias = None @@ -244,24 +276,33 @@ class SepConv(nn.Module): def __init__( self, - dim, - expansion_ratio=2, - act1_layer=StarReLU, - act2_layer=nn.Identity, - bias=False, - kernel_size=7, - padding=3, + dim: int, + expansion_ratio: float = 2, + act1_layer: Type[nn.Module] = StarReLU, + act2_layer: Type[nn.Module] = nn.Identity, + bias: bool = False, + kernel_size: int = 7, + padding: int = 3, + device=None, + dtype=None, **kwargs ): + dd = {'device': device, 'dtype': dtype} super().__init__() mid_channels = int(expansion_ratio * dim) - self.pwconv1 = nn.Conv2d(dim, mid_channels, kernel_size=1, bias=bias) - self.act1 = act1_layer() + self.pwconv1 = nn.Conv2d(dim, mid_channels, kernel_size=1, bias=bias, **dd) + self.act1 = act1_layer(**dd) if issubclass(act1_layer, StarReLU) else act1_layer() self.dwconv = nn.Conv2d( - mid_channels, mid_channels, kernel_size=kernel_size, - padding=padding, groups=mid_channels, bias=bias) # depthwise conv - self.act2 = act2_layer() - self.pwconv2 = nn.Conv2d(mid_channels, dim, kernel_size=1, bias=bias) + mid_channels, + mid_channels, + kernel_size=kernel_size, + padding=padding, + groups=mid_channels, + bias=bias, + **dd, + ) # depthwise conv + self.act2 = act2_layer(**dd) if issubclass(act2_layer, StarReLU) else act2_layer() + self.pwconv2 = nn.Conv2d(mid_channels, dim, kernel_size=1, bias=bias, **dd) def forward(self, x): x = self.pwconv1(x) @@ -277,10 +318,9 @@ class Pooling(nn.Module): Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 """ - def __init__(self, pool_size=3, **kwargs): + def __init__(self, pool_size: int = 3, **kwargs): super().__init__() - self.pool = nn.AvgPool2d( - pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) def forward(self, x): y = self.pool(x) @@ -293,20 +333,23 @@ class MlpHead(nn.Module): def __init__( self, - dim, - num_classes=1000, - mlp_ratio=4, - act_layer=SquaredReLU, - norm_layer=LayerNorm, - drop_rate=0., - bias=True + dim: int, + num_classes: int = 1000, + mlp_ratio: float = 4, + act_layer: Type[nn.Module] = SquaredReLU, + norm_layer: Type[nn.Module] = LayerNorm, + drop_rate: float = 0., + bias: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() hidden_features = int(mlp_ratio * dim) - self.fc1 = nn.Linear(dim, hidden_features, bias=bias) + self.fc1 = nn.Linear(dim, hidden_features, bias=bias, **dd) self.act = act_layer() - self.norm = norm_layer(hidden_features) - self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.norm = norm_layer(hidden_features, **dd) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias, **dd) self.head_drop = nn.Dropout(drop_rate) def forward(self, x): @@ -325,29 +368,32 @@ class MetaFormerBlock(nn.Module): def __init__( 
self, - dim, - token_mixer=Pooling, - mlp_act=StarReLU, - mlp_bias=False, - norm_layer=LayerNorm2d, - proj_drop=0., - drop_path=0., - use_nchw=True, - layer_scale_init_value=None, - res_scale_init_value=None, + dim: int, + token_mixer: Type[nn.Module] = Pooling, + mlp_act: Type[nn.Module] = StarReLU, + mlp_bias: bool = False, + norm_layer: Type[nn.Module] = LayerNorm2d, + proj_drop: float = 0., + drop_path: float = 0., + use_nchw: bool = True, + layer_scale_init_value: Optional[float] = None, + res_scale_init_value: Optional[float] = None, + device=None, + dtype=None, **kwargs ): + dd = {'device': device, 'dtype': dtype} super().__init__() - ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw) - rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw) + ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw, **dd) + rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw, **dd) - self.norm1 = norm_layer(dim) - self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs) + self.norm1 = norm_layer(dim, **dd) + self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **dd, **kwargs) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity() self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( dim, int(4 * dim), @@ -355,6 +401,7 @@ def __init__( bias=mlp_bias, drop=proj_drop, use_conv=use_nchw, + **dd ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity() @@ -380,22 +427,24 @@ class MetaFormerStage(nn.Module): def __init__( self, - in_chs, - out_chs, - depth=2, - token_mixer=nn.Identity, - mlp_act=StarReLU, - mlp_bias=False, - downsample_norm=LayerNorm2d, - norm_layer=LayerNorm2d, - proj_drop=0., - dp_rates=[0.] * 2, - layer_scale_init_value=None, - res_scale_init_value=None, + in_chs: int, + out_chs: int, + depth: int = 2, + token_mixer: Type[nn.Module] = nn.Identity, + mlp_act: Type[nn.Module] = StarReLU, + mlp_bias: bool = False, + downsample_norm: Optional[Type[nn.Module]] = LayerNorm2d, + norm_layer: Type[nn.Module] = LayerNorm2d, + proj_drop: float = 0., + dp_rates: List[float] = [0.] 
* 2, + layer_scale_init_value: Optional[float] = None, + res_scale_init_value: Optional[float] = None, + device=None, + dtype=None, **kwargs, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.grad_checkpointing = False self.use_nchw = not issubclass(token_mixer, Attention) @@ -407,6 +456,7 @@ def __init__( stride=2, padding=1, norm_layer=downsample_norm, + **dd, ) self.blocks = nn.Sequential(*[MetaFormerBlock( @@ -420,6 +470,7 @@ def __init__( layer_scale_init_value=layer_scale_init_value, res_scale_init_value=res_scale_init_value, use_nchw=self.use_nchw, + **dd, **kwargs, ) for i in range(depth)]) @@ -473,26 +524,33 @@ class MetaFormer(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - depths=(2, 2, 6, 2), - dims=(64, 128, 320, 512), - token_mixers=Pooling, - mlp_act=StarReLU, - mlp_bias=False, - drop_path_rate=0., - proj_drop_rate=0., - drop_rate=0.0, - layer_scale_init_values=None, - res_scale_init_values=(None, None, 1.0, 1.0), - downsample_norm=LayerNorm2dNoBias, - norm_layers=LayerNorm2dNoBias, - output_norm=LayerNorm2d, - use_mlp_head=True, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + depths: Tuple[int, ...] = (2, 2, 6, 2), + dims: Tuple[int, ...] = (64, 128, 320, 512), + token_mixers: Union[Type[nn.Module], List[Type[nn.Module]]] = Pooling, + mlp_act: Type[nn.Module] = StarReLU, + mlp_bias: bool = False, + drop_path_rate: float = 0., + proj_drop_rate: float = 0., + drop_rate: float = 0.0, + layer_scale_init_values: Optional[Union[float, List[float]]] = None, + res_scale_init_values: Union[Tuple[Optional[float], ...], List[Optional[float]]] = (None, None, 1.0, 1.0), + downsample_norm: Optional[Type[nn.Module]] = LayerNorm2dNoBias, + norm_layers: Union[Type[nn.Module], List[Type[nn.Module]]] = LayerNorm2dNoBias, + output_norm: Type[nn.Module] = LayerNorm2d, + use_mlp_head: bool = True, + device=None, + dtype=None, **kwargs, ): super().__init__() + dd = {'device': device, 'dtype': dtype} + # Bind dd kwargs to activation layers that need them + if mlp_act in (StarReLU,): + mlp_act = partial(mlp_act, **dd) + self.num_classes = num_classes self.num_features = dims[-1] self.drop_rate = drop_rate @@ -519,7 +577,8 @@ def __init__( self.stem = Stem( in_chans, dims[0], - norm_layer=downsample_norm + norm_layer=downsample_norm, + **dd, ) stages = [] @@ -539,6 +598,7 @@ def __init__( res_scale_init_value=res_scale_init_values[i], downsample_norm=downsample_norm, norm_layer=norm_layers[i], + **dd, **kwargs, )] prev_dim = dims[i] @@ -550,17 +610,17 @@ def __init__( if num_classes > 0: if self.use_mlp_head: # FIXME not actually returning mlp hidden state right now as pre-logits. 
- final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate) + final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd) self.head_hidden_size = self.num_features else: - final = nn.Linear(self.num_features, num_classes) + final = nn.Linear(self.num_features, num_classes, **dd) self.head_hidden_size = self.num_features else: final = nn.Identity() self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', output_norm(self.num_features)), + ('norm', output_norm(self.num_features, **dd)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(drop_rate) if self.use_mlp_head else nn.Identity()), ('fc', final) @@ -584,16 +644,17 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() if num_classes > 0: if self.use_mlp_head: - final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate) + final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd) else: - final = nn.Linear(self.num_features, num_classes) + final = nn.Linear(self.num_features, num_classes, **dd) else: final = nn.Identity() self.head.fc = final diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 838e0e0117..a6db3dd7bc 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -40,13 +40,14 @@ """ import math from functools import partial -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Type, Union, Tuple import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple + from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint, checkpoint_seq @@ -65,11 +66,13 @@ def __init__( dim: int, seq_len: int, mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0), - mlp_layer: type = Mlp, - norm_layer: type = partial(nn.LayerNorm, eps=1e-6), - act_layer: type = nn.GELU, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, drop: float = 0., drop_path: float = 0., + device=None, + dtype=None, ) -> None: """Initialize MixerBlock. @@ -83,13 +86,14 @@ def __init__( drop: Dropout rate. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] - self.norm1 = norm_layer(dim) - self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) + self.norm1 = norm_layer(dim, **dd) + self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim, **dd) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass.""" @@ -101,15 +105,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Affine(nn.Module): """Affine transformation layer.""" - def __init__(self, dim: int) -> None: + def __init__(self, dim: int, device=None, dtype=None) -> None: """Initialize Affine layer. Args: dim: Dimension of features. """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.alpha = nn.Parameter(torch.ones((1, 1, dim))) - self.beta = nn.Parameter(torch.zeros((1, 1, dim))) + self.alpha = nn.Parameter(torch.ones((1, 1, dim), **dd)) + self.beta = nn.Parameter(torch.zeros((1, 1, dim), **dd)) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply affine transformation.""" @@ -126,12 +131,14 @@ def __init__( dim: int, seq_len: int, mlp_ratio: float = 4, - mlp_layer: type = Mlp, - norm_layer: type = Affine, - act_layer: type = nn.GELU, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = Affine, + act_layer: Type[nn.Module] = nn.GELU, init_values: float = 1e-4, drop: float = 0., drop_path: float = 0., + device=None, + dtype=None, ) -> None: """Initialize ResBlock. @@ -146,15 +153,16 @@ def __init__( drop: Dropout rate. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() channel_dim = int(dim * mlp_ratio) - self.norm1 = norm_layer(dim) - self.linear_tokens = nn.Linear(seq_len, seq_len) + self.norm1 = norm_layer(dim, **dd) + self.linear_tokens = nn.Linear(seq_len, seq_len, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop) - self.ls1 = nn.Parameter(init_values * torch.ones(dim)) - self.ls2 = nn.Parameter(init_values * torch.ones(dim)) + self.norm2 = norm_layer(dim, **dd) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop, **dd) + self.ls1 = nn.Parameter(init_values * torch.ones(dim, **dd)) + self.ls2 = nn.Parameter(init_values * torch.ones(dim, **dd)) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass.""" @@ -168,7 +176,14 @@ class SpatialGatingUnit(nn.Module): Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ - def __init__(self, dim: int, seq_len: int, norm_layer: type = nn.LayerNorm) -> None: + def __init__( + self, + dim: int, + seq_len: int, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, + ) -> None: """Initialize Spatial Gating Unit. Args: @@ -176,10 +191,11 @@ def __init__(self, dim: int, seq_len: int, norm_layer: type = nn.LayerNorm) -> N seq_len: Sequence length. norm_layer: Normalization layer. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() gate_dim = dim // 2 - self.norm = norm_layer(gate_dim) - self.proj = nn.Linear(seq_len, seq_len) + self.norm = norm_layer(gate_dim, **dd) + self.proj = nn.Linear(seq_len, seq_len, **dd) def init_weights(self) -> None: """Initialize weights for projection gate.""" @@ -205,11 +221,13 @@ def __init__( dim: int, seq_len: int, mlp_ratio: float = 4, - mlp_layer: type = GatedMlp, - norm_layer: type = partial(nn.LayerNorm, eps=1e-6), - act_layer: type = nn.GELU, + mlp_layer: Type[nn.Module] = GatedMlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, drop: float = 0., drop_path: float = 0., + device=None, + dtype=None, ) -> None: """Initialize SpatialGatingBlock. @@ -223,11 +241,19 @@ def __init__( drop: Dropout rate. drop_path: Drop path rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() channel_dim = int(dim * mlp_ratio) - self.norm = norm_layer(dim) - sgu = partial(SpatialGatingUnit, seq_len=seq_len) - self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop) + self.norm = norm_layer(dim, **dd) + sgu = partial(SpatialGatingUnit, seq_len=seq_len, **dd) + self.mlp_channels = mlp_layer( + dim, + channel_dim, + act_layer=act_layer, + gate_layer=sgu, + drop=drop, + **dd, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -251,16 +277,18 @@ def __init__( num_blocks: int = 8, embed_dim: int = 512, mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0), - block_layer: type = MixerBlock, - mlp_layer: type = Mlp, - norm_layer: type = partial(nn.LayerNorm, eps=1e-6), - act_layer: type = nn.GELU, + block_layer: Type[nn.Module] = MixerBlock, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, drop_rate: float = 0., proj_drop_rate: float = 0., drop_path_rate: float = 0., nlhb: bool = False, stem_norm: bool = False, global_pool: str = 'avg', + device=None, + dtype=None, ) -> None: """Initialize MLP-Mixer. @@ -284,6 +312,7 @@ def __init__( global_pool: Global pooling type. """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models @@ -295,6 +324,7 @@ def __init__( in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None, + **dd, ) reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size # FIXME drop_path (stochastic depth scaling rule or all the same?) 
@@ -308,13 +338,14 @@ def __init__( act_layer=act_layer, drop=proj_drop_rate, drop_path=drop_path_rate, + **dd, ) for _ in range(num_blocks)]) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)] - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity() self.init_weights(nlhb=nlhb) @@ -368,7 +399,8 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) if global_pool is not None: assert global_pool in ('', 'avg') self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + device, dtype = (self.head.weight.device, self.head.weight.dtype) if hasattr(self.head, 'weight') else (None, None) + self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index eb87bb38d8..031cdbcf31 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -63,6 +63,8 @@ def __init__( drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', + device=None, + dtype=None, ): """Initialize MobileNetV3. @@ -87,7 +89,8 @@ def __init__( layer_scale_init_value: Enable layer scale on compatible blocks if not None. global_pool: Type of pooling to use for global pooling features of the FC head. """ - super(MobileNetV3, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -99,8 +102,8 @@ def __init__( # Stem if not fix_stem: stem_size = round_chs_fn(stem_size) - self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_act_layer(stem_size, inplace=True) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type, **dd) + self.bn1 = norm_act_layer(stem_size, inplace=True, **dd) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -114,6 +117,7 @@ def __init__( se_layer=se_layer, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -126,16 +130,30 @@ def __init__( num_pooled_chs = self.num_features * self.global_pool.feat_mult() if head_norm: # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer - self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type) # never bias - self.norm_head = norm_act_layer(self.head_hidden_size) + self.conv_head = create_conv2d( + num_pooled_chs, + self.head_hidden_size, + 1, + padding=pad_type, + bias=False, # never a bias + **dd, + ) + self.norm_head = norm_act_layer(self.head_hidden_size, **dd) + self.act2 = nn.Identity() else: # mobilenet-v3 and others only have an activation after final PW conv - self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type, bias=head_bias) + self.conv_head = create_conv2d( + num_pooled_chs, + self.head_hidden_size, + 1, + padding=pad_type, + bias=head_bias, + **dd, + ) + self.norm_head = nn.Identity()
self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -351,6 +369,8 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, + device=None, + dtype=None, ): """Initialize MobileNetV3Features. @@ -373,7 +393,8 @@ def __init__( drop_path_rate: Stochastic depth rate. layer_scale_init_value: Enable layer scale on compatible blocks if not None. """ - super(MobileNetV3Features, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d se_layer = se_layer or SqueezeExcite @@ -383,8 +404,8 @@ def __init__( # Stem if not fix_stem: stem_size = round_chs_fn(stem_size) - self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type, **dd) + self.bn1 = norm_layer(stem_size, **dd) self.act1 = act_layer(inplace=True) # Middle stages (IR/ER/DS Blocks) @@ -400,6 +421,7 @@ def __init__( drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, feature_location=feature_location, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py index 8a3e132d1e..f22dc76a3b 100644 --- a/timm/models/mobilenetv5.py +++ b/timm/models/mobilenetv5.py @@ -60,7 +60,10 @@ def __init__( noskip: bool = True, act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs self.out_channels = out_chs @@ -81,9 +84,10 @@ def __init__( norm_layer=norm_layer, noskip=self.noskip, layer_scale_init_value=self.layer_scale_init_value, + **dd, ) - self.norm = norm_layer(self.out_channels) + self.norm = norm_layer(self.out_channels, **dd) def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: # Inputs list of [B, C, H, W] tensors @@ -146,6 +150,8 @@ def __init__( drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', + device=None, + dtype=None, ): """ Args: @@ -169,6 +175,7 @@ def __init__( global_pool: Type of pooling to use for global pooling features of the FC head. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or _GELU norm_layer = get_norm_layer(norm_layer) or RmsNorm2d norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -191,6 +198,7 @@ def __init__( bias=stem_bias, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) # Middle stages (IR/ER/DS Blocks) @@ -205,6 +213,7 @@ def __init__( se_layer=se_layer, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -224,6 +233,7 @@ def __init__( output_resolution=self.msfa_output_resolution, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = None @@ -235,11 +245,11 @@ def __init__( self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) num_pooled_chs = self.num_features * self.global_pool.feat_mult() # mobilenet-v4 style post-pooling PW conv is followed by a norm+act layer - self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type) - self.norm_head = norm_act_layer(self.head_hidden_size) + self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type, **dd) + self.norm_head = norm_act_layer(self.head_hidden_size, **dd) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -426,8 +436,11 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} act_layer = act_layer or _GELU norm_layer = get_norm_layer(norm_layer) or RmsNorm2d se_layer = se_layer or SqueezeExcite @@ -447,6 +460,7 @@ def __init__( bias=stem_bias, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) builder = EfficientNetBuilder( @@ -460,6 +474,7 @@ def __init__( se_layer=se_layer, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, + **dd, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -477,6 +492,7 @@ def __init__( output_resolution=self.msfa_output_resolution, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) efficientnet_init_weights(self) diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 9c84871e6d..93c9d4e306 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -14,7 +14,7 @@ # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
# import math -from typing import Callable, Tuple, Optional +from typing import Callable, Tuple, Optional, Type import torch import torch.nn.functional as F @@ -185,20 +185,28 @@ def __init__( no_fusion: bool = False, drop_path_rate: float = 0., layers: LayerFn = None, - transformer_norm_layer: Callable = nn.LayerNorm, + transformer_norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, **kwargs, # eat unused args ): - super(MobileVitBlock, self).__init__() - + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) out_chs = out_chs or in_chs transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) self.conv_kxk = layers.conv_norm_act( - in_chs, in_chs, kernel_size=kernel_size, - stride=stride, groups=groups, dilation=dilation[0]) - self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + in_chs, + in_chs, + kernel_size=kernel_size, + stride=stride, + groups=groups, + dilation=dilation[0], + **dd, + ) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False, **dd) self.transformer = nn.Sequential(*[ TransformerBlock( @@ -211,17 +219,18 @@ def __init__( drop_path=drop_path_rate, act_layer=layers.act, norm_layer=transformer_norm_layer, + **dd, ) for _ in range(transformer_depth) ]) - self.norm = transformer_norm_layer(transformer_dim) + self.norm = transformer_norm_layer(transformer_dim, **dd) - self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1) + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, **dd) if no_fusion: self.conv_fusion = None else: - self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1) + self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1, **dd) self.patch_size = to_2tuple(patch_size) self.patch_area = self.patch_size[0] * self.patch_size[1] @@ -290,12 +299,15 @@ class LinearSelfAttention(nn.Module): """ def __init__( - self, - embed_dim: int, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - bias: bool = True, + self, + embed_dim: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + bias: bool = True, + device=None, + dtype=None, ) -> None: + dd = {'device': device, 'dtype': dtype} super().__init__() self.embed_dim = embed_dim @@ -304,6 +316,7 @@ def __init__( out_channels=1 + (2 * embed_dim), bias=bias, kernel_size=1, + **dd, ) self.attn_drop = nn.Dropout(attn_drop) self.out_proj = nn.Conv2d( @@ -311,6 +324,7 @@ def __init__( out_channels=embed_dim, bias=bias, kernel_size=1, + **dd, ) self.out_drop = nn.Dropout(proj_drop) @@ -405,29 +419,33 @@ class LinearTransformerBlock(nn.Module): """ def __init__( - self, - embed_dim: int, - mlp_ratio: float = 2.0, - drop: float = 0.0, - attn_drop: float = 0.0, - drop_path: float = 0.0, - act_layer=None, - norm_layer=None, + self, + embed_dim: int, + mlp_ratio: float = 2.0, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: Optional[Type[nn.Module]] = None, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ) -> None: + dd = {'device': device, 'dtype': dtype} super().__init__() act_layer = act_layer or nn.SiLU norm_layer = norm_layer or GroupNorm1 - self.norm1 = norm_layer(embed_dim) - self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(embed_dim, **dd) + self.attn = 
LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop, **dd) self.drop_path1 = DropPath(drop_path) - self.norm2 = norm_layer(embed_dim) + self.norm2 = norm_layer(embed_dim, **dd) self.mlp = ConvMlp( in_features=embed_dim, hidden_features=int(embed_dim * mlp_ratio), act_layer=act_layer, - drop=drop) + drop=drop, + **dd) self.drop_path2 = DropPath(drop_path) def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -453,34 +471,43 @@ class MobileVitV2Block(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: Optional[int] = None, - kernel_size: int = 3, - bottle_ratio: float = 1.0, - group_size: Optional[int] = 1, - dilation: Tuple[int, int] = (1, 1), - mlp_ratio: float = 2.0, - transformer_dim: Optional[int] = None, - transformer_depth: int = 2, - patch_size: int = 8, - attn_drop: float = 0., - drop: int = 0., - drop_path_rate: float = 0., - layers: LayerFn = None, - transformer_norm_layer: Callable = GroupNorm1, - **kwargs, # eat unused args + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 3, + bottle_ratio: float = 1.0, + group_size: Optional[int] = 1, + dilation: Tuple[int, int] = (1, 1), + mlp_ratio: float = 2.0, + transformer_dim: Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + attn_drop: float = 0., + drop: int = 0., + drop_path_rate: float = 0., + layers: LayerFn = None, + transformer_norm_layer: Type[nn.Module] = GroupNorm1, + device=None, + dtype=None, + **kwargs, # eat unused args ): - super(MobileVitV2Block, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) out_chs = out_chs or in_chs transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs) self.conv_kxk = layers.conv_norm_act( - in_chs, in_chs, kernel_size=kernel_size, - stride=1, groups=groups, dilation=dilation[0]) - self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False) + in_chs, + in_chs, + kernel_size=kernel_size, + stride=1, + groups=groups, + dilation=dilation[0], + **dd, + ) + self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False, **dd) self.transformer = nn.Sequential(*[ LinearTransformerBlock( @@ -490,13 +517,14 @@ def __init__( drop=drop, drop_path=drop_path_rate, act_layer=layers.act, - norm_layer=transformer_norm_layer + norm_layer=transformer_norm_layer, + **dd, ) for _ in range(transformer_depth) ]) - self.norm = transformer_norm_layer(transformer_dim) + self.norm = transformer_norm_layer(transformer_dim, **dd) - self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False) + self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False, **dd) self.patch_size = to_2tuple(patch_size) self.patch_area = self.patch_size[0] * self.patch_size[1] diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 514e8733f7..6b46e55970 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -17,7 +17,7 @@ from collections import OrderedDict from dataclasses import dataclass from functools import partial, reduce -from typing import Union, List, Tuple, Optional +from typing import Union, List, Tuple, Optional, Any, Type import torch from torch import nn @@ -93,13 +93,16 @@ class PatchEmbed(nn.Module): def __init__( self, - dim_in=3, - dim_out=768, - kernel=(7, 7), - stride=(4, 4), - padding=(3, 3), + dim_in: int = 3, + dim_out: int = 768, + 
kernel: Tuple[int, int] = (7, 7), + stride: Tuple[int, int] = (4, 4), + padding: Tuple[int, int] = (3, 3), + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.proj = nn.Conv2d( dim_in, @@ -107,6 +110,7 @@ def __init__( kernel_size=kernel, stride=stride, padding=padding, + **dd, ) def forward(self, x) -> Tuple[torch.Tensor, List[int]]: @@ -165,15 +169,15 @@ def cal_rel_pos_type( q_h_ratio = max(k_h / q_h, 1.0) k_h_ratio = max(q_h / k_h, 1.0) dist_h = ( - torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio - - torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio + torch.arange(q_h, device=q.device, dtype=torch.long).unsqueeze(-1) * q_h_ratio - + torch.arange(k_h, device=q.device, dtype=torch.long).unsqueeze(0) * k_h_ratio ) dist_h += (k_h - 1) * k_h_ratio q_w_ratio = max(k_w / q_w, 1.0) k_w_ratio = max(q_w / k_w, 1.0) dist_w = ( - torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio - - torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio + torch.arange(q_w, device=q.device, dtype=torch.long).unsqueeze(-1) * q_w_ratio - + torch.arange(k_w, device=q.device, dtype=torch.long).unsqueeze(0) * k_w_ratio ) dist_w += (k_w - 1) * k_w_ratio @@ -198,21 +202,24 @@ def cal_rel_pos_type( class MultiScaleAttentionPoolFirst(nn.Module): def __init__( self, - dim, - dim_out, - feat_size, - num_heads=8, - qkv_bias=True, - mode="conv", - kernel_q=(1, 1), - kernel_kv=(1, 1), - stride_q=(1, 1), - stride_kv=(1, 1), - has_cls_token=True, - rel_pos_type='spatial', - residual_pooling=True, - norm_layer=nn.LayerNorm, + dim: int, + dim_out: int, + feat_size: Tuple[int, int], + num_heads: int = 8, + qkv_bias: bool = True, + mode: str = "conv", + kernel_q: Tuple[int, int] = (1, 1), + kernel_kv: Tuple[int, int] = (1, 1), + stride_q: Tuple[int, int] = (1, 1), + stride_kv: Tuple[int, int] = (1, 1), + has_cls_token: bool = True, + rel_pos_type: str = 'spatial', + residual_pooling: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.dim_out = dim_out @@ -222,10 +229,10 @@ def __init__( padding_q = tuple([int(q // 2) for q in kernel_q]) padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) - self.q = nn.Linear(dim, dim_out, bias=qkv_bias) - self.k = nn.Linear(dim, dim_out, bias=qkv_bias) - self.v = nn.Linear(dim, dim_out, bias=qkv_bias) - self.proj = nn.Linear(dim_out, dim_out) + self.q = nn.Linear(dim, dim_out, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim_out, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, dim_out, bias=qkv_bias, **dd) + self.proj = nn.Linear(dim_out, dim_out, **dd) # Skip pooling with kernel and stride size of (1, 1, 1). 
if prod(kernel_q) == 1 and prod(stride_q) == 1: @@ -254,8 +261,9 @@ def __init__( padding=padding_q, groups=dim_conv, bias=False, + **dd, ) - self.norm_q = norm_layer(dim_conv) + self.norm_q = norm_layer(dim_conv, **dd) if kernel_kv: self.pool_k = nn.Conv2d( dim_conv, @@ -265,8 +273,9 @@ def __init__( padding=padding_kv, groups=dim_conv, bias=False, + **dd, ) - self.norm_k = norm_layer(dim_conv) + self.norm_k = norm_layer(dim_conv, **dd) self.pool_v = nn.Conv2d( dim_conv, dim_conv, @@ -275,8 +284,9 @@ def __init__( padding=padding_kv, groups=dim_conv, bias=False, + **dd, ) - self.norm_v = norm_layer(dim_conv) + self.norm_v = norm_layer(dim_conv, **dd) else: raise NotImplementedError(f"Unsupported model {mode}") @@ -289,8 +299,8 @@ def __init__( kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size rel_sp_dim = 2 * max(q_size, kv_size) - 1 - self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd)) trunc_normal_tf_(self.rel_pos_h, std=0.02) trunc_normal_tf_(self.rel_pos_w, std=0.02) @@ -368,21 +378,24 @@ def forward(self, x, feat_size: List[int]): class MultiScaleAttention(nn.Module): def __init__( self, - dim, - dim_out, - feat_size, - num_heads=8, - qkv_bias=True, - mode="conv", - kernel_q=(1, 1), - kernel_kv=(1, 1), - stride_q=(1, 1), - stride_kv=(1, 1), - has_cls_token=True, - rel_pos_type='spatial', - residual_pooling=True, - norm_layer=nn.LayerNorm, + dim: int, + dim_out: int, + feat_size: Tuple[int, int], + num_heads: int = 8, + qkv_bias: bool = True, + mode: str = "conv", + kernel_q: Tuple[int, int] = (1, 1), + kernel_kv: Tuple[int, int] = (1, 1), + stride_q: Tuple[int, int] = (1, 1), + stride_kv: Tuple[int, int] = (1, 1), + has_cls_token: bool = True, + rel_pos_type: str = 'spatial', + residual_pooling: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.dim_out = dim_out @@ -392,8 +405,8 @@ def __init__( padding_q = tuple([int(q // 2) for q in kernel_q]) padding_kv = tuple([int(kv // 2) for kv in kernel_kv]) - self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) - self.proj = nn.Linear(dim_out, dim_out) + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias, **dd) + self.proj = nn.Linear(dim_out, dim_out, **dd) # Skip pooling with kernel and stride size of (1, 1, 1). 
if prod(kernel_q) == 1 and prod(stride_q) == 1: @@ -422,8 +435,9 @@ def __init__( padding=padding_q, groups=dim_conv, bias=False, + **dd, ) - self.norm_q = norm_layer(dim_conv) + self.norm_q = norm_layer(dim_conv, **dd) if kernel_kv: self.pool_k = nn.Conv2d( dim_conv, @@ -433,8 +447,9 @@ def __init__( padding=padding_kv, groups=dim_conv, bias=False, + **dd, ) - self.norm_k = norm_layer(dim_conv) + self.norm_k = norm_layer(dim_conv, **dd) self.pool_v = nn.Conv2d( dim_conv, dim_conv, @@ -443,8 +458,9 @@ def __init__( padding=padding_kv, groups=dim_conv, bias=False, + **dd, ) - self.norm_v = norm_layer(dim_conv) + self.norm_v = norm_layer(dim_conv, **dd) else: raise NotImplementedError(f"Unsupported model {mode}") @@ -457,8 +473,8 @@ def __init__( kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size rel_sp_dim = 2 * max(q_size, kv_size) - 1 - self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd)) trunc_normal_tf_(self.rel_pos_h, std=0.02) trunc_normal_tf_(self.rel_pos_w, std=0.02) @@ -521,34 +537,37 @@ def forward(self, x, feat_size: List[int]): class MultiScaleBlock(nn.Module): def __init__( self, - dim, - dim_out, - num_heads, - feat_size, - mlp_ratio=4.0, - qkv_bias=True, - drop_path=0.0, - norm_layer=nn.LayerNorm, - kernel_q=(1, 1), - kernel_kv=(1, 1), - stride_q=(1, 1), - stride_kv=(1, 1), - mode="conv", - has_cls_token=True, - expand_attn=False, - pool_first=False, - rel_pos_type='spatial', - residual_pooling=True, + dim: int, + dim_out: int, + num_heads: int, + feat_size: Tuple[int, int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_path: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + kernel_q: Tuple[int, int] = (1, 1), + kernel_kv: Tuple[int, int] = (1, 1), + stride_q: Tuple[int, int] = (1, 1), + stride_kv: Tuple[int, int] = (1, 1), + mode: str = "conv", + has_cls_token: bool = True, + expand_attn: bool = False, + pool_first: bool = False, + rel_pos_type: str = 'spatial', + residual_pooling: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() proj_needed = dim != dim_out self.dim = dim self.dim_out = dim_out self.has_cls_token = has_cls_token - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) - self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None + self.shortcut_proj_attn = nn.Linear(dim, dim_out, **dd) if proj_needed and expand_attn else None if stride_q and prod(stride_q) > 1: kernel_skip = [s + 1 if s > 1 else s for s in stride_q] stride_skip = stride_q @@ -574,16 +593,18 @@ def __init__( mode=mode, rel_pos_type=rel_pos_type, residual_pooling=residual_pooling, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(att_dim) + self.norm2 = norm_layer(att_dim, **dd) mlp_dim_out = dim_out - self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None + self.shortcut_proj_mlp = nn.Linear(dim, dim_out, **dd) if proj_needed and not expand_attn else None self.mlp = Mlp( in_features=att_dim, hidden_features=int(att_dim * mlp_ratio), out_features=mlp_dim_out, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() @@ -621,26 +642,29 @@ class MultiScaleVitStage(nn.Module): def __init__( self, 
- dim, - dim_out, - depth, - num_heads, - feat_size, - mlp_ratio=4.0, - qkv_bias=True, - mode="conv", - kernel_q=(1, 1), - kernel_kv=(1, 1), - stride_q=(1, 1), - stride_kv=(1, 1), - has_cls_token=True, - expand_attn=False, - pool_first=False, - rel_pos_type='spatial', - residual_pooling=True, - norm_layer=nn.LayerNorm, - drop_path=0.0, + dim: int, + dim_out: int, + depth: int, + num_heads: int, + feat_size: Tuple[int, int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + kernel_q: Tuple[int, int] = (1, 1), + kernel_kv: Tuple[int, int] = (1, 1), + stride_q: Tuple[int, int] = (1, 1), + stride_kv: Tuple[int, int] = (1, 1), + mode: str = "conv", + has_cls_token: bool = True, + expand_attn: bool = False, + pool_first: bool = False, + rel_pos_type: str = 'spatial', + residual_pooling: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + drop_path: Union[float, List[float]] = 0.0, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False @@ -670,6 +694,7 @@ def __init__( expand_attn=expand_attn, norm_layer=norm_layer, drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path, + **dd, ) dim = out_dims[i] self.blocks.append(attention_block) @@ -709,8 +734,11 @@ def __init__( num_classes: int = 1000, drop_path_rate: float = 0., drop_rate: float = 0., + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} img_size = to_2tuple(img_size) norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) self.num_classes = num_classes @@ -728,12 +756,13 @@ def __init__( kernel=cfg.patch_kernel, stride=cfg.patch_stride, padding=cfg.patch_padding, + **dd, ) patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1]) num_patches = prod(patch_dims) if cfg.use_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) self.num_prefix_tokens = 1 pos_embed_dim = num_patches + 1 else: @@ -742,7 +771,7 @@ def __init__( pos_embed_dim = num_patches if cfg.use_abs_pos: - self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim, **dd)) else: self.pos_embed = None @@ -777,6 +806,7 @@ def __init__( residual_pooling=cfg.residual_pooling, norm_layer=norm_layer, drop_path=dpr[i], + **dd, ) curr_stride *= max(cfg.stride_q[i]) self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)] @@ -785,10 +815,10 @@ def __init__( self.stages.append(stage) self.num_features = self.head_hidden_size = embed_dim - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) self.head = nn.Sequential(OrderedDict([ ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ('fc', nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()) ])) if self.pos_embed is not None: @@ -829,9 +859,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool + device = self.head.fc.weight.device if hasattr(self.head.fc, 'weight') else None + dtype = self.head.fc.weight.dtype if hasattr(self.head.fc, 'weight') else None self.head = nn.Sequential(OrderedDict([ ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 
else nn.Identity()) + ('fc', nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()) ])) def forward_intermediates( diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 67f658b3c4..4e8e691b61 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -203,8 +203,8 @@ def __init__( unique_sizes: List[Tuple[int, int]], batch_size: int, seq_len: int, - dtype: torch.dtype, device: torch.device, + dtype: torch.dtype, ): self.rope = rope_module self.size_to_indices = size_to_indices @@ -362,6 +362,8 @@ def __init__( norm_layer: Optional[Type[nn.Module]] = None, pos_drop_rate: float = 0., enable_patch_interpolator: bool = False, + device=None, + dtype=None, ) -> None: """Initialize NaFlexEmbeds module. @@ -385,6 +387,7 @@ def __init__( pos_drop_rate: Dropout rate for position embeddings. enable_patch_interpolator: Enable dynamic patch size support. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.has_class_token = class_token self.num_reg_tokens = reg_tokens @@ -402,8 +405,8 @@ def __init__( self.num_prefix_tokens += reg_tokens # Create class and register tokens - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None # Calculate grid size and number of patches self.default_img_size: Optional[Tuple[int, int]] = None @@ -425,7 +428,7 @@ def __init__( "`norm_layer` must be given when input_norm_layer=True" input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None) self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None - self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias) + self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias, **dd) self.flatten = False self.is_linear = True else: @@ -433,7 +436,12 @@ def __init__( assert not input_norm_layer self.norm_input = None self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=proj_bias + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=proj_bias, + **dd, ) self.flatten = True self.is_linear = False @@ -470,12 +478,12 @@ def __init__( assert self.pos_embed_grid_size is not None h, w = self.pos_embed_grid_size self.pos_embed_type = 'factorized' - self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim) * .02) - self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim) * .02) + self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim, **dd) * .02) + self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim, **dd) * .02) else: assert self.pos_embed_grid_size is not None h, w = self.pos_embed_grid_size - self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) + self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim, **dd) * .02) self.pos_embed_type = 'learned' # Dropout layer @@ -623,7 +631,7 @@ def _apply_learned_naflex_pos_embed_grid_sample( padding_mode='border', ).to(dtype=x.dtype) # (B, C, H_out, W_out) - bi = torch.arange(B, device=device).unsqueeze(1) + bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1) x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+=' def _apply_learned_pos_embed( @@ -768,7 +776,7 @@ def _interp1d(table: 
torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x) pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y) - bi = torch.arange(B, device=device).unsqueeze(1) + bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1) x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]] def _apply_factorized_pos_embed( @@ -1080,6 +1088,8 @@ def __init__( in_chans: int = 3, num_classes: int = 1000, img_size: Optional[Union[int, Tuple[int, int]]] = None, + device=None, + dtype=None, **kwargs, ) -> None: """Initialize NaFlexVit model. @@ -1092,6 +1102,7 @@ def __init__( **kwargs: Additional config parameters to override cfg values. """ super().__init__() + dd = {'device': device, 'dtype': dtype} # Initialize config cfg = cfg or NaFlexVitCfg() @@ -1141,8 +1152,9 @@ def __init__( proj_norm_layer=embed_norm_layer, pos_drop_rate=cfg.pos_drop_rate, enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False), + **dd, ) - self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity() + self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity() # ROPE position embeddings at model level self.rope: Optional[nn.Module] = None @@ -1157,6 +1169,7 @@ def __init__( temperature=cfg.rope_temperature, feat_shape=None, # Dynamic shapes for NaFlex grid_indexing=cfg.rope_grid_indexing, + **dd, ) self.rope_is_mixed = True elif cfg.rope_type == 'axial': @@ -1168,6 +1181,7 @@ def __init__( ref_feat_shape=cfg.rope_ref_feat_shape, grid_offset=cfg.rope_grid_offset, grid_indexing=cfg.rope_grid_indexing, + **dd, ) self.rope_is_mixed = False else: @@ -1200,6 +1214,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, + **dd, ) for i in range(cfg.depth) ]) @@ -1211,7 +1226,7 @@ def __init__( for i in range(cfg.depth) ] - self.norm = norm_layer(cfg.embed_dim) if cfg.final_norm and not cfg.fc_norm else nn.Identity() + self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity() # Classifier Head if cfg.global_pool == 'map': @@ -1221,6 +1236,7 @@ def __init__( mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) else: self.attn_pool = None @@ -1229,9 +1245,9 @@ def __init__( fc_norm = cfg.fc_norm if fc_norm is None: fc_norm = cfg.global_pool == 'avg' - self.fc_norm = norm_layer(cfg.embed_dim) if cfg.final_norm and fc_norm else nn.Identity() + self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity() self.head_drop = nn.Dropout(cfg.drop_rate) - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() if cfg.weight_init != 'skip': self.init_weights(cfg.weight_init) @@ -1921,6 +1937,8 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: hf_hub_id='timm/', ), 'naflexvit_base_patch16_map.untrained': _cfg(), + 'naflexvit_so150m2_patch16_reg1_gap.untrained': _cfg(), + 'naflexvit_so150m2_patch16_reg1_map.untrained': _cfg(), # SigLIP-2 NaFlex vit encoder weights 'naflexvit_base_patch16_siglip.v2_webli': _cfg( diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 8edee45c86..af17f44af4 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -3,6 +3,7 @@ https://github.com/Cadene/pretrained-models.pytorch """ from functools import partial +from typing import 
Optional, Type import torch import torch.nn as nn @@ -17,12 +18,22 @@ class ActConvBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): - super(ActConvBn, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.act = nn.ReLU() self.conv = create_conv2d( - in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, **dd) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, **dd) def forward(self, x): x = self.act(x) @@ -33,13 +44,34 @@ def forward(self, x): class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): - super(SeparableConv2d, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.depthwise_conv2d = create_conv2d( - in_channels, in_channels, kernel_size=kernel_size, - stride=stride, padding=padding, groups=in_channels) + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=in_channels, + **dd, + ) self.pointwise_conv2d = create_conv2d( - in_channels, out_channels, kernel_size=1, padding=0) + in_channels, + out_channels, + kernel_size=1, + padding=0, + **dd, + ) def forward(self, x): x = self.depthwise_conv2d(x) @@ -49,17 +81,28 @@ def forward(self, x): class BranchSeparables(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False): - super(BranchSeparables, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + pad_type: str = '', + stem_cell: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() middle_channels = out_channels if stem_cell else in_channels self.act_1 = nn.ReLU() self.separable_1 = SeparableConv2d( - in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type) - self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1) + in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type, **dd) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1, **dd) self.act_2 = nn.ReLU(inplace=True) self.separable_2 = SeparableConv2d( - middle_channels, out_channels, kernel_size, stride=1, padding=pad_type) - self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + middle_channels, out_channels, kernel_size, stride=1, padding=pad_type, **dd) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, **dd) def forward(self, x): x = self.act_1(x) @@ -72,24 +115,37 @@ def forward(self, x): class CellStem0(nn.Module): - def __init__(self, stem_size, num_channels=42, pad_type=''): - super(CellStem0, self).__init__() + def __init__( + self, + stem_size: int, + num_channels: int = 42, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_channels = num_channels self.stem_size = stem_size - self.conv_1x1 = ActConvBn(self.stem_size, 
self.num_channels, 1, stride=1) + self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1, **dd) - self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) - self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + self.comb_iter_0_left = BranchSeparables( + self.num_channels, self.num_channels, 5, 2, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables( + self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True, **dd) self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) - self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + self.comb_iter_1_right = BranchSeparables( + self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) - self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True) + self.comb_iter_2_right = BranchSeparables( + self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True, **dd) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables( + self.num_channels, self.num_channels, 3, 1, pad_type, **dd) self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x): @@ -120,36 +176,44 @@ def forward(self, x): class CellStem1(nn.Module): - def __init__(self, stem_size, num_channels, pad_type=''): - super(CellStem1, self).__init__() + def __init__( + self, + stem_size: int, + num_channels: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.num_channels = num_channels self.stem_size = stem_size - self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1) + self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1, **dd) self.act = nn.ReLU() self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False, **dd)) self.path_2 = nn.Sequential() self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False, **dd)) - self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1) + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, **dd) - self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) - self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type, **dd) self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) - 
self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) - self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type, **dd) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type, **dd) self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x_conv0, x_stem_0): @@ -188,34 +252,44 @@ def forward(self, x_conv0, x_stem_0): class FirstCell(nn.Module): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): - super(FirstCell, self).__init__() - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1) + def __init__( + self, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, **dd) self.act = nn.ReLU() self.path_1 = nn.Sequential() self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False, **dd)) self.path_2 = nn.Sequential() self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) - self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False, **dd)) - self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1) + self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1, **dd) - self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) - self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd) - self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) - self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, 
pad_type, **dd) def forward(self, x, x_prev): x_relu = self.act(x_prev) @@ -248,23 +322,33 @@ def forward(self, x, x_prev): class NormalCell(nn.Module): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): - super(NormalCell, self).__init__() - self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + def __init__( + self, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd) - self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) - self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type, **dd) - self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type) - self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type, **dd) + self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd) def forward(self, x, x_prev): x_left = self.conv_prev_1x1(x_prev) @@ -294,23 +378,33 @@ def forward(self, x, x_prev): class ReductionCell0(nn.Module): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): - super(ReductionCell0, self).__init__() - self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + def __init__( + self, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd) - self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) - self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd) self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) - self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, 
out_chs_right, 7, 2, pad_type, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) - self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd) self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x, x_prev): @@ -342,23 +436,33 @@ def forward(self, x, x_prev): class ReductionCell1(nn.Module): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): - super(ReductionCell1, self).__init__() - self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + def __init__( + self, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type, **dd) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type, **dd) - self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) - self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd) self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) - self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type, **dd) self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) - self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type, **dd) self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) - self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type, **dd) self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) def forward(self, x, x_prev): @@ -393,17 +497,20 @@ class NASNetALarge(nn.Module): def __init__( self, - num_classes=1000, - in_chans=3, - stem_size=96, - channel_multiplier=2, - num_features=4032, - output_stride=32, - drop_rate=0., - global_pool='avg', - pad_type='same', + num_classes: int = 1000, + in_chans: int = 3, + stem_size: int = 96, + channel_multiplier: int = 2, + num_features: int = 4032, + output_stride: int = 32, + drop_rate: float = 0., + global_pool: str = 'avg', + pad_type: str = 'same', + device=None, + dtype=None, ): - super(NASNetALarge, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.stem_size = stem_size self.num_features = self.head_hidden_size = num_features @@ -414,76 +521,83 @@ def 
__init__( # 24 is default value for the architecture self.conv0 = ConvNormAct( - in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + in_channels=in_chans, + out_channels=self.stem_size, + kernel_size=3, + padding=0, + stride=2, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), + apply_act=False, + **dd, + ) self.cell_stem_0 = CellStem0( - self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) + self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type, **dd) self.cell_stem_1 = CellStem1( - self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type) + self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type, **dd) self.cell_0 = FirstCell( in_chs_left=channels, out_chs_left=channels // 2, - in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.cell_1 = NormalCell( in_chs_left=2 * channels, out_chs_left=channels, - in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.cell_2 = NormalCell( in_chs_left=6 * channels, out_chs_left=channels, - in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.cell_3 = NormalCell( in_chs_left=6 * channels, out_chs_left=channels, - in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.cell_4 = NormalCell( in_chs_left=6 * channels, out_chs_left=channels, - in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.cell_5 = NormalCell( in_chs_left=6 * channels, out_chs_left=channels, - in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type, **dd) self.reduction_cell_0 = ReductionCell0( in_chs_left=6 * channels, out_chs_left=2 * channels, - in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_6 = FirstCell( in_chs_left=6 * channels, out_chs_left=channels, - in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_7 = NormalCell( in_chs_left=8 * channels, out_chs_left=2 * channels, - in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_8 = NormalCell( in_chs_left=12 * channels, out_chs_left=2 * channels, - in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_9 = NormalCell( in_chs_left=12 * channels, out_chs_left=2 * channels, - in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_10 = NormalCell( in_chs_left=12 * channels, out_chs_left=2 * channels, - in_chs_right=12 * channels, out_chs_right=2 * channels, 
pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.cell_11 = NormalCell( in_chs_left=12 * channels, out_chs_left=2 * channels, - in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type, **dd) self.reduction_cell_1 = ReductionCell1( in_chs_left=12 * channels, out_chs_left=4 * channels, - in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_12 = FirstCell( in_chs_left=12 * channels, out_chs_left=2 * channels, - in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_13 = NormalCell( in_chs_left=16 * channels, out_chs_left=4 * channels, - in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_14 = NormalCell( in_chs_left=24 * channels, out_chs_left=4 * channels, - in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_15 = NormalCell( in_chs_left=24 * channels, out_chs_left=4 * channels, - in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_16 = NormalCell( in_chs_left=24 * channels, out_chs_left=4 * channels, - in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.cell_17 = NormalCell( in_chs_left=24 * channels, out_chs_left=4 * channels, - in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type, **dd) self.act = nn.ReLU(inplace=True) self.feature_info = [ dict(num_chs=96, reduction=2, module='conv0'), @@ -494,7 +608,12 @@ def __init__( ] self.global_pool, self.head_drop, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.num_features, + self.num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + **dd, + ) @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/models/nest.py b/timm/models/nest.py index 120550001f..62aab6f023 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -19,15 +19,27 @@ import logging import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn.functional as F from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, create_classifier, trunc_normal_, _assert -from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + calculate_drop_path_rates, + create_classifier, + trunc_normal_, + _assert, + create_conv2d, + create_pool2d, + to_ntuple, + use_fused_attn, + LayerNorm, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function @@ -46,16 +58,26 @@ class 
Attention(nn.Module): """ fused_attn: torch.jit.Final[bool] - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias) + self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -91,47 +113,62 @@ class TransformerLayer(nn.Module): """ def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim, **dd) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop, + **dd, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): y = self.norm1(x) - x = x + self.drop_path(self.attn(y)) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(y)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) return x class ConvPool(nn.Module): - def __init__(self, in_channels, out_channels, norm_layer, pad_type=''): + def __init__( + self, + in_channels: int, + out_channels: int, + norm_layer: Type[nn.Module], + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True) - self.norm = norm_layer(out_channels) + self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True, **dd) + self.norm = norm_layer(out_channels, **dd) self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=pad_type) def forward(self, x): @@ -183,30 +220,33 @@ class NestLevel(nn.Module): """ def __init__( self, - num_blocks, - block_size, - seq_length, - num_heads, - depth, - embed_dim, - prev_embed_dim=None, - mlp_ratio=4., - qkv_bias=True, - proj_drop=0., - attn_drop=0., - drop_path=[], - norm_layer=None, - act_layer=None, - pad_type='', + num_blocks: int, + block_size: int, + seq_length: int, + num_heads: int, + depth: int, + embed_dim: int, + prev_embed_dim: Optional[int] = None, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Optional[List[float]] = None, + norm_layer: Optional[Type[nn.Module]] = None, + act_layer: Optional[Type[nn.Module]] = None, + pad_type: str = '', + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.block_size = block_size self.grad_checkpointing = False - self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim, **dd)) if prev_embed_dim is not None: - self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type) + self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type, **dd) else: self.pool = nn.Identity() @@ -221,9 +261,10 @@ def __init__( qkv_bias=qkv_bias, proj_drop=proj_drop, attn_drop=attn_drop, - drop_path=drop_path[i], + drop_path=drop_path[i] if drop_path else 0., norm_layer=norm_layer, act_layer=act_layer, + **dd, ) for i in range(depth)]) @@ -253,25 +294,27 @@ class Nest(nn.Module): def __init__( self, - img_size=224, - in_chans=3, - patch_size=4, - num_levels=3, - embed_dims=(128, 256, 512), - num_heads=(4, 8, 16), - depths=(2, 2, 20), - num_classes=1000, - mlp_ratio=4., - qkv_bias=True, - drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.5, - norm_layer=None, - act_layer=None, - pad_type='', - weight_init='', - global_pool='avg', + img_size: int = 224, + in_chans: int = 3, + patch_size: int = 4, + num_levels: int = 3, + embed_dims: Tuple[int, ...] = (128, 256, 512), + num_heads: Tuple[int, ...] = (4, 8, 16), + depths: Tuple[int, ...]
= (2, 2, 20), + num_classes: int = 1000, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.5, + norm_layer: Optional[Type[nn.Module]] = None, + act_layer: Optional[Type[nn.Module]] = None, + pad_type: str = '', + weight_init: str = '', + global_pool: str = 'avg', + device=None, + dtype=None, ): """ Args: @@ -301,7 +344,7 @@ def __init__( - https://github.com/google-research/nested-transformer/issues/2 """ super().__init__() - + dd = {'device': device, 'dtype': dtype} for param_name in ['embed_dims', 'num_heads', 'depths']: param_value = locals()[param_name] if isinstance(param_value, collections.abc.Sequence): @@ -324,7 +367,7 @@ def __init__( self.patch_size = patch_size # Number of blocks at each level - self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist() + self.num_blocks = (4 ** torch.arange(num_levels, device='cpu', dtype=torch.long)).flip(0).tolist() assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \ 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`' @@ -340,6 +383,7 @@ def __init__( in_chans=in_chans, embed_dim=embed_dims[0], flatten=False, + **dd, ) self.num_patches = self.patch_embed.num_patches self.seq_length = self.num_patches // self.num_blocks[0] @@ -367,6 +411,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, pad_type=pad_type, + **dd, )) self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')] prev_dim = dim @@ -374,10 +419,10 @@ def __init__( self.levels = nn.Sequential(*levels) # Final normalization layer - self.norm = norm_layer(embed_dims[-1]) + self.norm = norm_layer(embed_dims[-1], **dd) # Classifier - global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd) self.global_pool = global_pool self.head_drop = nn.Dropout(drop_rate) self.head = head diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 250161c510..fbaf4e515a 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -6,7 +6,7 @@ """ # Copyright (c) ByteDance Inc. All rights reserved. 
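The nasnet.py and nest.py hunks above, and the nextvit.py, nfnet.py, pit.py, pnasnet.py, and pvt_v2.py hunks that follow, all apply the same refactor: each module gains optional device/dtype arguments and forwards them to its submodules as factory kwargs through a small dd dict. A minimal sketch of that idiom outside of timm (TinyBlock below is hypothetical, not a timm module):

import torch.nn as nn

class TinyBlock(nn.Module):
    # Hypothetical module showing the dd factory-kwarg idiom used throughout this patch.
    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Parameters are created directly on the requested device and in the requested dtype.
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)

    def forward(self, x):
        return self.proj(self.norm(x))

For example, TinyBlock(64, device='meta') builds its parameters without allocating real storage, which is one payoff of threading the kwargs through every submodule.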
from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn.functional as F @@ -73,19 +73,29 @@ def merge_pre_bn(module, pre_bn_1, pre_bn_2=None): class ConvNormAct(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size=3, - stride=1, - groups=1, - norm_layer=nn.BatchNorm2d, - act_layer=nn.ReLU, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): - super(ConvNormAct, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv = nn.Conv2d( - in_chs, out_chs, kernel_size=kernel_size, stride=stride, - padding=1, groups=groups, bias=False) - self.norm = norm_layer(out_chs) + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + padding=1, + groups=groups, + bias=False, + **dd, + ) + self.norm = norm_layer(out_chs, **dd) self.act = act_layer() def forward(self, x): @@ -106,22 +116,25 @@ def _make_divisible(v, divisor, min_value=None): class PatchEmbed(nn.Module): - def __init__(self, - in_chs, - out_chs, - stride=1, - norm_layer = nn.BatchNorm2d, + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): - super(PatchEmbed, self).__init__() - + dd = {'device': device, 'dtype': dtype} + super().__init__() if stride == 2: self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False) - self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) - self.norm = norm_layer(out_chs) + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False, **dd) + self.norm = norm_layer(out_chs, **dd) elif in_chs != out_chs: self.pool = nn.Identity() - self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) - self.norm = norm_layer(out_chs) + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False, **dd) + self.norm = norm_layer(out_chs, **dd) else: self.pool = nn.Identity() self.conv = nn.Identity() @@ -136,15 +149,30 @@ class ConvAttention(nn.Module): Multi-Head Convolutional Attention """ - def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU): - super(ConvAttention, self).__init__() + def __init__( + self, + out_chs: int, + head_dim: int, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.group_conv3x3 = nn.Conv2d( - out_chs, out_chs, - kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False + out_chs, + out_chs, + kernel_size=3, + stride=1, + padding=1, + groups=out_chs // head_dim, + bias=False, + **dd, ) - self.norm = norm_layer(out_chs) + self.norm = norm_layer(out_chs, **dd) self.act = act_layer() - self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False) + self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False, **dd) def forward(self, x): out = self.group_conv3x3(x) @@ -160,37 +188,42 @@ class NextConvBlock(nn.Module): def __init__( self, - in_chs, - out_chs, - stride=1, - drop_path=0., - drop=0., - head_dim=32, - mlp_ratio=3., - norm_layer=nn.BatchNorm2d, - act_layer=nn.ReLU + in_chs: int, + out_chs: int, + stride: int = 1, + drop_path: float = 0., + drop: float = 0., + head_dim: int = 
32, + mlp_ratio: float = 3., + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): - super(NextConvBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.in_chs = in_chs self.out_chs = out_chs assert out_chs % head_dim == 0 - self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer) + self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer, **dd) self.mhca = ConvAttention( out_chs, head_dim, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) self.attn_drop_path = DropPath(drop_path) - self.norm = norm_layer(out_chs) + self.norm = norm_layer(out_chs, **dd) self.mlp = ConvMlp( out_chs, hidden_features=int(out_chs * mlp_ratio), drop=drop, bias=True, act_layer=act_layer, + **dd, ) self.mlp_drop_path = DropPath(drop_path) self.is_fused = False @@ -219,15 +252,18 @@ class EfficientAttention(nn.Module): def __init__( self, - dim, - out_dim=None, - head_dim=32, - qkv_bias=True, - attn_drop=0., - proj_drop=0., - sr_ratio=1, - norm_layer=nn.BatchNorm1d, + dim: int, + out_dim: Optional[int] = None, + head_dim: int = 32, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + sr_ratio: int = 1, + norm_layer: Type[nn.Module] = nn.BatchNorm1d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.out_dim = out_dim if out_dim is not None else dim @@ -236,10 +272,10 @@ def __init__( self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.q = nn.Linear(dim, self.dim, bias=qkv_bias) - self.k = nn.Linear(dim, self.dim, bias=qkv_bias) - self.v = nn.Linear(dim, self.dim, bias=qkv_bias) - self.proj = nn.Linear(self.dim, self.out_dim) + self.q = nn.Linear(dim, self.dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, self.dim, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, self.dim, bias=qkv_bias, **dd) + self.proj = nn.Linear(self.dim, self.out_dim, **dd) self.attn_drop = nn.Dropout(attn_drop) self.proj_drop = nn.Dropout(proj_drop) @@ -247,7 +283,7 @@ def __init__( self.N_ratio = sr_ratio ** 2 if sr_ratio > 1: self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio) - self.norm = norm_layer(dim) + self.norm = norm_layer(dim, **dd) else: self.sr = None self.norm = None @@ -288,20 +324,23 @@ class NextTransformerBlock(nn.Module): def __init__( self, - in_chs, - out_chs, - drop_path, - stride=1, - sr_ratio=1, - mlp_ratio=2, - head_dim=32, - mix_block_ratio=0.75, - attn_drop=0., - drop=0., - norm_layer=nn.BatchNorm2d, - act_layer=nn.ReLU, + in_chs: int, + out_chs: int, + drop_path: float, + stride: int = 1, + sr_ratio: int = 1, + mlp_ratio: float = 2, + head_dim: int = 32, + mix_block_ratio: float = 0.75, + attn_drop: float = 0., + drop: float = 0., + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): - super(NextTransformerBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.in_chs = in_chs self.out_chs = out_chs self.mix_block_ratio = mix_block_ratio @@ -309,32 +348,41 @@ def __init__( self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32) self.mhca_out_chs = out_chs - self.mhsa_out_chs - self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride) - self.norm1 = norm_layer(self.mhsa_out_chs) + self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride, **dd) + self.norm1 = norm_layer(self.mhsa_out_chs, **dd) self.e_mhsa = 
EfficientAttention( self.mhsa_out_chs, head_dim=head_dim, sr_ratio=sr_ratio, attn_drop=attn_drop, proj_drop=drop, + **dd, ) self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio) - self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer) + self.projection = PatchEmbed( + self.mhsa_out_chs, + self.mhca_out_chs, + stride=1, + norm_layer=norm_layer, + **dd, + ) self.mhca = ConvAttention( self.mhca_out_chs, head_dim=head_dim, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio)) - self.norm2 = norm_layer(out_chs) + self.norm2 = norm_layer(out_chs, **dd) self.mlp = ConvMlp( out_chs, hidden_features=int(out_chs * mlp_ratio), act_layer=act_layer, drop=drop, + **dd, ) self.mlp_drop_path = DropPath(drop_path) self.is_fused = False @@ -378,19 +426,22 @@ class NextStage(nn.Module): def __init__( self, - in_chs, - block_chs, - block_types, - stride=2, - sr_ratio=1, - mix_block_ratio=1.0, - drop=0., - attn_drop=0., - drop_path=0., - head_dim=32, - norm_layer=nn.BatchNorm2d, - act_layer=nn.ReLU, + in_chs: int, + block_chs: List[int], + block_types: List[Type[nn.Module]], + stride: int = 2, + sr_ratio: int = 1, + mix_block_ratio: float = 1.0, + drop: float = 0., + attn_drop: float = 0., + drop_path: Union[float, List[float], Tuple[float, ...]] = 0., + head_dim: int = 32, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False @@ -410,6 +461,7 @@ def __init__( head_dim=head_dim, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) blocks.append(layer) elif block_type is NextTransformerBlock: @@ -425,6 +477,7 @@ def __init__( drop=drop, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) blocks.append(layer) in_chs = out_chs @@ -446,22 +499,25 @@ def forward(self, x): class NextViT(nn.Module): def __init__( self, - in_chans, - num_classes=1000, - global_pool='avg', - stem_chs=(64, 32, 64), - depths=(3, 4, 10, 3), - strides=(1, 2, 2, 2), - sr_ratios=(8, 4, 2, 1), - drop_path_rate=0.1, - attn_drop_rate=0., - drop_rate=0., - head_dim=32, - mix_block_ratio=0.75, - norm_layer=nn.BatchNorm2d, - act_layer=None, + in_chans: int, + num_classes: int = 1000, + global_pool: str = 'avg', + stem_chs: Tuple[int, ...] = (64, 32, 64), + depths: Tuple[int, ...] = (3, 4, 10, 3), + strides: Tuple[int, ...] = (1, 2, 2, 2), + sr_ratios: Tuple[int, ...] 
= (8, 4, 2, 1), + drop_path_rate: float = 0.1, + attn_drop_rate: float = 0., + drop_rate: float = 0., + head_dim: int = 32, + mix_block_ratio: float = 0.75, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): - super(NextViT, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.grad_checkpointing = False self.num_classes = num_classes norm_layer = get_norm_layer(norm_layer) @@ -490,10 +546,14 @@ def __init__( [NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]] self.stem = nn.Sequential( - ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), - ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), - ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), - ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct( + in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, **dd), + ConvNormAct( + stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer, **dd), + ConvNormAct( + stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer, **dd), + ConvNormAct( + stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer, **dd), ) in_chs = out_chs = stem_chs[-1] stages = [] @@ -513,14 +573,15 @@ def __init__( drop_path=dpr[stage_idx], norm_layer=norm_layer, act_layer=act_layer, + **dd, ) in_chs = out_chs = self.stage_out_chs[stage_idx][-1] stages += [stage] idx += depths[stage_idx] self.num_features = self.head_hidden_size = out_chs self.stages = nn.Sequential(*stages) - self.norm = norm_layer(out_chs) - self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes) + self.norm = norm_layer(out_chs, **dd) + self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes, **dd) self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))] self._initialize_weights() diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 309ee8c5fe..16974d9fb3 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -115,6 +115,8 @@ def __init__( dilation: int = 1, first_dilation: Optional[int] = None, conv_layer: Callable = ScaledStdConv2d, + device=None, + dtype=None, ): """Initialize DownsampleAvg. @@ -126,14 +128,14 @@ def __init__( first_dilation: First dilation rate (unused). conv_layer: Convolution layer type. """ - super(DownsampleAvg, self).__init__() + super().__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) else: self.pool = nn.Identity() - self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + self.conv = conv_layer(in_chs, out_chs, 1, stride=1, device=device, dtype=dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -172,6 +174,8 @@ def __init__( act_layer: Optional[Callable] = None, conv_layer: Callable = ScaledStdConv2d, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize NormFreeBlock. @@ -195,6 +199,7 @@ def __init__( conv_layer: Convolution layer type. 
drop_path_rate: Stochastic depth drop rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs @@ -215,32 +220,33 @@ def __init__( dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer, + **dd, ) else: self.downsample = None self.act1 = act_layer() - self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd) self.act2 = act_layer(inplace=True) - self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd) if extra_conv: self.act2b = act_layer(inplace=True) - self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups, **dd) else: self.act2b = None self.conv2b = None if reg and attn_layer is not None: - self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 + self.attn = attn_layer(mid_chs, **dd) # RegNet blocks apply attn btw conv2 & 3 else: self.attn = None self.act3 = act_layer() - self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.) + self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0., **dd) if not reg and attn_layer is not None: - self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + self.attn_last = attn_layer(out_chs, **dd) # ResNet blocks apply attn after conv3 else: self.attn_last = None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() - self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None + self.skipinit_gain = nn.Parameter(torch.tensor(0., **dd)) if skipinit else None def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -283,6 +289,8 @@ def create_stem( conv_layer: Optional[Callable] = None, act_layer: Optional[Callable] = None, preact_feature: bool = True, + device=None, + dtype=None, ) -> Tuple[nn.Sequential, int, Dict[str, Any]]: """Create stem module for NFNet models. @@ -297,6 +305,7 @@ def create_stem( Returns: Tuple of (stem_module, stem_stride, stem_feature_info). 
""" + dd = {'device': device, 'dtype': dtype} stem_stride = 2 stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv') stem = OrderedDict() @@ -318,16 +327,16 @@ def create_stem( stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2') last_idx = len(stem_chs) - 1 for i, (c, s) in enumerate(zip(stem_chs, strides)): - stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s, **dd) if i != last_idx: stem[f'act{i + 2}'] = act_layer(inplace=True) in_chs = c elif '3x3' in stem_type: # 3x3 stem conv as in RegNet - stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2, **dd) else: # 7x7 stem conv as in ResNet - stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2, **dd) if 'pool' in stem_type: stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) @@ -387,6 +396,8 @@ def __init__( output_stride: int = 32, drop_rate: float = 0., drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs: Any, ): """ @@ -401,6 +412,7 @@ def __init__( **kwargs: Extra kwargs overlayed onto cfg. """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False @@ -423,6 +435,7 @@ def __init__( cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer, + **dd, ) self.feature_info = [stem_feat] @@ -462,6 +475,7 @@ def __init__( act_layer=act_layer, conv_layer=conv_layer, drop_path_rate=drop_path_rates[stage_idx][block_idx], + **dd, )] if block_idx == 0: expected_var = 1. # expected var is reset after first block of each stage @@ -475,7 +489,7 @@ def __init__( if cfg.num_features: # The paper NFRegNet models have an EfficientNet-like final head convolution. 
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) - self.final_conv = conv_layer(prev_chs, self.num_features, 1) + self.final_conv = conv_layer(prev_chs, self.num_features, 1, **dd) self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv') else: self.num_features = prev_chs @@ -488,6 +502,7 @@ def __init__( num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) for n, m in self.named_modules(): diff --git a/timm/models/pit.py b/timm/models/pit.py index c8b1965a6b..dca27f598f 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -14,7 +14,7 @@ import math import re from functools import partial -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union, Type, Any import torch from torch import nn @@ -32,9 +32,6 @@ class SequentialTuple(nn.Sequential): """ This module exists to work around torchscript typing issues list -> list""" - def __init__(self, *args): - super(SequentialTuple, self).__init__(*args) - def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: for module in self: x = module(x) @@ -44,21 +41,24 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, t class Transformer(nn.Module): def __init__( self, - base_dim, - depth, - heads, - mlp_ratio, - pool=None, - proj_drop=.0, - attn_drop=.0, - drop_path_prob=None, - norm_layer=None, + base_dim: int, + depth: int, + heads: int, + mlp_ratio: float, + pool: Optional[Any] = None, + proj_drop: float = .0, + attn_drop: float = .0, + drop_path_prob: Optional[List[float]] = None, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): - super(Transformer, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() embed_dim = base_dim * heads self.pool = pool - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity() self.blocks = nn.Sequential(*[ Block( dim=embed_dim, @@ -68,7 +68,8 @@ def __init__( proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path_prob[i], - norm_layer=partial(nn.LayerNorm, eps=1e-6) + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **dd, ) for i in range(depth)]) @@ -93,8 +94,17 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, t class Pooling(nn.Module): - def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'): - super(Pooling, self).__init__() + def __init__( + self, + in_feature: int, + out_feature: int, + stride: int, + padding_mode: str = 'zeros', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv = nn.Conv2d( in_feature, @@ -104,8 +114,9 @@ def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'): stride=stride, padding_mode=padding_mode, groups=in_feature, + **dd, ) - self.fc = nn.Linear(in_feature, out_feature) + self.fc = nn.Linear(in_feature, out_feature, **dd) def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: x = self.conv(x) @@ -116,14 +127,17 @@ def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: class ConvEmbedding(nn.Module): def __init__( self, - in_channels, - out_channels, + in_channels: int, + out_channels: int, img_size: int = 224, patch_size: int = 16, stride: int = 8, padding: int = 0, + device=None, + dtype=None, ): - super(ConvEmbedding, self).__init__() + dd = {'device': device, 
'dtype': dtype} + super().__init__() padding = padding self.img_size = to_2tuple(img_size) self.patch_size = to_2tuple(patch_size) @@ -132,8 +146,14 @@ def __init__( self.grid_size = (self.height, self.width) self.conv = nn.Conv2d( - in_channels, out_channels, kernel_size=patch_size, - stride=stride, padding=padding, bias=True) + in_channels, + out_channels, + kernel_size=patch_size, + stride=stride, + padding=padding, + bias=True, + **dd, + ) def forward(self, x): x = self.conv(x) @@ -156,17 +176,20 @@ def __init__( depth: Sequence[int] = (2, 6, 4), heads: Sequence[int] = (2, 4, 8), mlp_ratio: float = 4, - num_classes=1000, - in_chans=3, - global_pool='token', - distilled=False, - drop_rate=0., - pos_drop_drate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., + num_classes: int = 1000, + in_chans: int = 3, + global_pool: str = 'token', + distilled: bool = False, + drop_rate: float = 0., + pos_drop_drate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + device=None, + dtype=None, ): - super(PoolingVisionTransformer, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('token',) self.base_dims = base_dims @@ -177,9 +200,9 @@ def __init__( self.num_tokens = 2 if distilled else 1 self.feature_info = [] - self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride) - self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width)) - self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim)) + self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride, **dd) + self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width, **dd)) + self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim, **dd)) self.pos_drop = nn.Dropout(p=pos_drop_drate) transformers = [] @@ -194,6 +217,7 @@ def __init__( prev_dim, embed_dim, stride=2, + **dd, ) transformers += [Transformer( base_dims[i], @@ -204,20 +228,21 @@ def __init__( proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path_prob=dpr[i], + **dd, )] prev_dim = embed_dim self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')] self.transformers = SequentialTuple(*transformers) - self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) + self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6, **dd) self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # Classifier head self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() self.head_dist = None if distilled: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token trunc_normal_(self.pos_embed, std=.02) @@ -251,9 +276,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else 
None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() if self.head_dist is not None: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 7f33aaeabb..f8a16bd9e5 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -7,6 +7,7 @@ """ from collections import OrderedDict from functools import partial +from typing import Type import torch import torch.nn as nn @@ -20,13 +21,34 @@ class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): - super(SeparableConv2d, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.depthwise_conv2d = create_conv2d( - in_channels, in_channels, kernel_size=kernel_size, - stride=stride, padding=padding, groups=in_channels) + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=in_channels, + **dd, + ) self.pointwise_conv2d = create_conv2d( - in_channels, out_channels, kernel_size=1, padding=padding) + in_channels, + out_channels, + kernel_size=1, + padding=padding, + **dd, + ) def forward(self, x): x = self.depthwise_conv2d(x) @@ -36,17 +58,40 @@ def forward(self, x): class BranchSeparables(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''): - super(BranchSeparables, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + stem_cell: bool = False, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() middle_channels = out_channels if stem_cell else in_channels self.act_1 = nn.ReLU() self.separable_1 = SeparableConv2d( - in_channels, middle_channels, kernel_size, stride=stride, padding=padding) - self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) + in_channels, + middle_channels, + kernel_size, + stride=stride, + padding=padding, + **dd, + ) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, **dd) self.act_2 = nn.ReLU() self.separable_2 = SeparableConv2d( - middle_channels, out_channels, kernel_size, stride=1, padding=padding) - self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) + middle_channels, + out_channels, + kernel_size, + stride=1, + padding=padding, + **dd, + ) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, **dd) def forward(self, x): x = self.act_1(x) @@ -60,12 +105,28 @@ def forward(self, x): class ActConvBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): - super(ActConvBn, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.act = nn.ReLU() self.conv = create_conv2d( - in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.bn = 
nn.BatchNorm2d(out_channels, eps=0.001) + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + **dd, + ) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, **dd) def forward(self, x): x = self.act(x) @@ -76,19 +137,27 @@ def forward(self, x): class FactorizedReduction(nn.Module): - def __init__(self, in_channels, out_channels, padding=''): - super(FactorizedReduction, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + padding: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.act = nn.ReLU() self.path_1 = nn.Sequential(OrderedDict([ ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), - ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding, **dd)), ])) self.path_2 = nn.Sequential(OrderedDict([ ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), - ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding, **dd)), ])) - self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001, **dd) def forward(self, x): x = self.act(x) @@ -130,35 +199,45 @@ def cell_forward(self, x_left, x_right): class CellStem0(CellBase): - def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): - super(CellStem0, self).__init__() - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + def __init__( + self, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type, **dd) self.comb_iter_0_left = BranchSeparables( - in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type) + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type, **dd) self.comb_iter_0_right = nn.Sequential(OrderedDict([ ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)), - ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)), - ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type, **dd)), + ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001, **dd)), ])) self.comb_iter_1_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type, **dd) self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type) self.comb_iter_2_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type, **dd) self.comb_iter_2_right = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type, **dd) self.comb_iter_3_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=3, 
padding=pad_type, **dd) self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type) self.comb_iter_4_left = BranchSeparables( - in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type) + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type, **dd) self.comb_iter_4_right = ActConvBn( - out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type, **dd) def forward(self, x_left): x_right = self.conv_1x1(x_left) @@ -170,16 +249,18 @@ class Cell(CellBase): def __init__( self, - in_chs_left, - out_chs_left, - in_chs_right, - out_chs_right, - pad_type='', - is_reduction=False, - match_prev_layer_dims=False, + in_chs_left: int, + out_chs_left: int, + in_chs_right: int, + out_chs_right: int, + pad_type: str = '', + is_reduction: bool = False, + match_prev_layer_dims: bool = False, + device=None, + dtype=None, ): - super(Cell, self).__init__() - + dd = {'device': device, 'dtype': dtype} + super().__init__() # If `is_reduction` is set to `True` stride 2 is used for # convolution and pooling layers to reduce the spatial size of # the output of a cell approximately by a factor of 2. @@ -190,32 +271,32 @@ def __init__( # of the left input of a cell approximately by a factor of 2. self.match_prev_layer_dimensions = match_prev_layer_dims if match_prev_layer_dims: - self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type) + self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type, **dd) else: - self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type) - self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type, **dd) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type, **dd) self.comb_iter_0_left = BranchSeparables( - out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type) + out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type, **dd) self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_1_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type, **dd) self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_2_left = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type, **dd) self.comb_iter_2_right = BranchSeparables( - out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type, **dd) - self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) + self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3, **dd) self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type) self.comb_iter_4_left = BranchSeparables( - out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type) + out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type, **dd) if is_reduction: self.comb_iter_4_right = ActConvBn( - out_chs_right, out_chs_right, 
kernel_size=1, stride=stride, padding=pad_type) + out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type, **dd) else: self.comb_iter_4_right = None @@ -229,59 +310,62 @@ def forward(self, x_left, x_right): class PNASNet5Large(nn.Module): def __init__( self, - num_classes=1000, - in_chans=3, - output_stride=32, - drop_rate=0., - global_pool='avg', - pad_type='', + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + drop_rate: float = 0., + global_pool: str = 'avg', + pad_type: str = '', + device=None, + dtype=None, ): - super(PNASNet5Large, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.num_features = self.head_hidden_size = 4320 assert output_stride == 32 self.conv_0 = ConvNormAct( in_chans, 96, kernel_size=3, stride=2, padding=0, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False, **dd) self.cell_stem_0 = CellStem0( - in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) + in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type, **dd) self.cell_stem_1 = Cell( in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type, - match_prev_layer_dims=True, is_reduction=True) + match_prev_layer_dims=True, is_reduction=True, **dd) self.cell_0 = Cell( in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type, - match_prev_layer_dims=True) + match_prev_layer_dims=True, **dd) self.cell_1 = Cell( - in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd) self.cell_2 = Cell( - in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd) self.cell_3 = Cell( - in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type, **dd) self.cell_4 = Cell( in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type, - is_reduction=True) + is_reduction=True, **dd) self.cell_5 = Cell( in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, - match_prev_layer_dims=True) + match_prev_layer_dims=True, **dd) self.cell_6 = Cell( - in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd) self.cell_7 = Cell( - in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, **dd) self.cell_8 = Cell( in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type, - is_reduction=True) + is_reduction=True, **dd) self.cell_9 = Cell( in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, - match_prev_layer_dims=True) + match_prev_layer_dims=True, **dd) self.cell_10 = Cell( - in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + in_chs_left=4320, out_chs_left=864, 
in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd) self.cell_11 = Cell( - in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, **dd) self.act = nn.ReLU() self.feature_info = [ dict(num_chs=96, reduction=2, module='conv_0'), @@ -292,7 +376,7 @@ def __init__( ] self.global_pool, self.head_drop, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd) @torch.jit.ignore def group_matcher(self, coarse=False): @@ -306,10 +390,11 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.last_linear - def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.global_pool, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool) + self.num_features, self.num_classes, pool_type=global_pool, **dd) def forward_features(self, x): x_conv_0 = self.conv_0(x) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index ac8c90492a..1a00700022 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -16,7 +16,7 @@ """ import math -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -35,21 +35,24 @@ class MlpWithDepthwiseConv(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0., - extra_relu=False, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + drop: float = 0., + extra_relu: bool = False, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, **dd) self.relu = nn.ReLU() if extra_relu else nn.Identity() - self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features) + self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features, **dd) self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, **dd) self.drop = nn.Dropout(drop) def forward(self, x, feat_size: List[int]): @@ -71,15 +74,18 @@ class Attention(nn.Module): def __init__( self, - dim, - num_heads=8, - sr_ratio=1, - linear_attn=False, - qkv_bias=True, - attn_drop=0., - proj_drop=0. + dim: int, + num_heads: int = 8, + sr_ratio: int = 1, + linear_attn: bool = False, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
self.dim = dim @@ -88,25 +94,25 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) if not linear_attn: self.pool = None if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, **dd) + self.norm = nn.LayerNorm(dim, **dd) else: self.sr = None self.norm = None self.act = None else: self.pool = nn.AdaptiveAvgPool2d(7) - self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) - self.norm = nn.LayerNorm(dim) + self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, **dd) + self.norm = nn.LayerNorm(dim, **dd) self.act = nn.GELU() def forward(self, x, feat_size: List[int]): @@ -149,20 +155,23 @@ class Block(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - sr_ratio=1, - linear_attn=False, - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + sr_ratio: int = 1, + linear_attn: bool = False, + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm, + device=None, + dtype=None, ): super().__init__() - self.norm1 = norm_layer(dim) + dd = {'device': device, 'dtype': dtype} + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, @@ -171,16 +180,18 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = MlpWithDepthwiseConv( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, extra_relu=linear_attn, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -194,15 +205,24 @@ def forward(self, x, feat_size: List[int]): class OverlapPatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 7, + stride: int = 4, + in_chans: int = 3, + embed_dim: int = 768, + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} patch_size = to_2tuple(patch_size) assert max(patch_size) > stride, "Set larger patch_size than stride" self.patch_size = patch_size self.proj = nn.Conv2d( in_chans, embed_dim, patch_size, - stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) + stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2), **dd) + self.norm = nn.LayerNorm(embed_dim, **dd) def forward(self, x): x = self.proj(x) @@ -227,8 +247,11 @@ def __init__( attn_drop: float = 0., drop_path: Union[List[float], float] = 0.0, norm_layer: Callable = LayerNorm, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.grad_checkpointing = False if downsample: @@ -237,6 +260,7 @@ def __init__( stride=2, in_chans=dim, embed_dim=dim_out, + **dd, ) else: assert dim == dim_out @@ -253,9 +277,10 @@ def __init__( attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, + **dd, ) for i in range(depth)]) - self.norm = norm_layer(dim_out) + self.norm = norm_layer(dim_out, **dd) def forward(self, x): # x is either B, C, H, W (if downsample) or B, H, W, C if not @@ -278,23 +303,26 @@ def forward(self, x): class PyramidVisionTransformerV2(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - depths=(3, 4, 6, 3), - embed_dims=(64, 128, 256, 512), - num_heads=(1, 2, 4, 8), - sr_ratios=(8, 4, 2, 1), - mlp_ratios=(8., 8., 4., 4.), - qkv_bias=True, - linear=False, - drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=LayerNorm, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + depths: Tuple[int, ...] = (3, 4, 6, 3), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + num_heads: Tuple[int, ...] = (1, 2, 4, 8), + sr_ratios: Tuple[int, ...] = (8, 4, 2, 1), + mlp_ratios: Tuple[float, ...] 
= (8., 8., 4., 4.), + qkv_bias: bool = True, + linear: bool = False, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = LayerNorm, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes assert global_pool in ('avg', '') self.global_pool = global_pool @@ -311,6 +339,7 @@ def __init__( stride=4, in_chans=in_chans, embed_dim=embed_dims[0], + **dd, ) dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) @@ -332,6 +361,7 @@ def __init__( attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + **dd, )] prev_dim = embed_dims[i] cur += depths[i] @@ -341,7 +371,7 @@ def __init__( # classification head self.num_features = self.head_hidden_size = embed_dims[-1] self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -385,7 +415,9 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('avg', '') self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index a5961571c0..2f86d93dcc 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -5,7 +5,7 @@ """ from functools import partial -from typing import List, Optional, Tuple, Union, Callable +from typing import List, Optional, Tuple, Union, Callable, Type import torch import torch.nn as nn @@ -22,14 +22,24 @@ class Block(nn.Module): - def __init__(self, in_chs, inter_chs, out_chs, norm_layer, act_layer): + def __init__( + self, + in_chs: int, + inter_chs: int, + out_chs: int, + norm_layer: Type[nn.Module], + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.layers = nn.Sequential( - nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3), - norm_layer(in_chs), - nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0), + nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd), + norm_layer(in_chs, **dd), + nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd), act_layer(), - nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0), + nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd), ) def forward(self, x): @@ -37,15 +47,25 @@ def forward(self, x): class BlockESE(nn.Module): - def __init__(self, in_chs, inter_chs, out_chs, norm_layer, act_layer): + def __init__( + self, + in_chs: int, + inter_chs: int, + out_chs: int, + norm_layer: Type[nn.Module], + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.layers = nn.Sequential( - nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3), - norm_layer(in_chs), - nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0), + 
nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd), + norm_layer(in_chs, **dd), + nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd), act_layer(), - nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0), - EffectiveSEModule(out_chs), + nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd), + EffectiveSEModule(out_chs, **dd), ) def forward(self, x): @@ -74,9 +94,12 @@ def __init__( block_idx: int = 0, block_type: str = "Block", ls_init_value: float = 1e-6, - norm_layer: str = "layernorm2d", - act_layer: str = "gelu", + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.drop_rate = drop_rate self.drop_path_rate = drop_path_rate @@ -84,7 +107,7 @@ def __init__( self.block_idx = block_idx self.growth_rate = growth_rate - self.gamma = nn.Parameter(ls_init_value * torch.ones(growth_rate)) if ls_init_value > 0 else None + self.gamma = nn.Parameter(ls_init_value * torch.ones(growth_rate, **dd)) if ls_init_value > 0 else None growth_rate = int(growth_rate) inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8 @@ -96,6 +119,7 @@ def __init__( out_chs=growth_rate, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: @@ -110,7 +134,17 @@ def forward(self, x: List[torch.Tensor]) -> torch.Tensor: class DenseStage(nn.Sequential): - def __init__(self, num_block, num_input_features, drop_path_rates, growth_rate, **kwargs): + def __init__( + self, + num_block: int, + num_input_features: int, + drop_path_rates: List[float], + growth_rate: int, + device=None, + dtype=None, + **kwargs, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() for i in range(num_block): layer = DenseBlock( @@ -118,6 +152,7 @@ def __init__(self, num_block, num_input_features, drop_path_rates, growth_rate, growth_rate=growth_rate, drop_path_rate=drop_path_rates[i], block_idx=i, + **dd, **kwargs, ) num_input_features += growth_rate @@ -156,6 +191,8 @@ def __init__( norm_eps: Optional[float] = None, drop_rate: float = 0.0, # timm option [--drop: dropout ratio] drop_path_rate: float = 0.0, # timm option [--drop-path: drop-path ratio] + device=None, + dtype=None, ): """ Args: @@ -181,6 +218,7 @@ def __init__( drop_path_rate: Stochastic depth drop rate. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} assert len(growth_rates) == len(num_blocks_list) == len(is_downsample_block) act_layer = get_act_layer(act_layer) norm_layer = get_norm_layer(norm_layer) @@ -195,16 +233,16 @@ def __init__( if stem_type == 'patch': # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, num_init_features, kernel_size=patch_size, stride=patch_size, bias=conv_bias), - norm_layer(num_init_features), + nn.Conv2d(in_chans, num_init_features, kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd), + norm_layer(num_init_features, **dd), ) stem_stride = patch_size else: mid_chs = make_divisible(num_init_features // 2) if 'tiered' in stem_type else num_init_features self.stem = nn.Sequential( - nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias), - nn.Conv2d(mid_chs, num_init_features, kernel_size=3, stride=2, padding=1, bias=conv_bias), - norm_layer(num_init_features), + nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd), + nn.Conv2d(mid_chs, num_init_features, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd), + norm_layer(num_init_features, **dd), ) stem_stride = 4 @@ -225,10 +263,15 @@ def __init__( curr_stride *= 2 k_size = stride = 2 - dense_stage_layers.append(norm_layer(num_features)) - dense_stage_layers.append( - nn.Conv2d(num_features, compressed_num_features, kernel_size=k_size, stride=stride, padding=0) - ) + dense_stage_layers.append(norm_layer(num_features, **dd)) + dense_stage_layers.append(nn.Conv2d( + num_features, + compressed_num_features, + kernel_size=k_size, + stride=stride, + padding=0, + **dd, + )) num_features = compressed_num_features stage = DenseStage( @@ -242,6 +285,7 @@ def __init__( block_type=block_type[i], norm_layer=norm_layer, act_layer=act_layer, + **dd, ) dense_stage_layers.append(stage) num_features += num_blocks_list[i] * growth_rates[i] @@ -262,12 +306,13 @@ def __init__( # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default RDNet ordering (pretrained NV weights) if head_norm_first: - self.norm_pre = norm_layer(self.num_features) + self.norm_pre = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, + **dd, ) else: self.norm_pre = nn.Identity() @@ -277,6 +322,7 @@ def __init__( pool_type=global_pool, drop_rate=self.drop_rate, norm_layer=norm_layer, + **dd, ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index e42bf44641..b062eb2c2c 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -26,7 +26,7 @@ import math from dataclasses import dataclass, replace from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Type import torch import torch.nn as nn @@ -142,8 +142,10 @@ def downsample_conv( kernel_size: int = 1, stride: int = 1, dilation: int = 1, - norm_layer: Optional[Callable] = None, + norm_layer: Optional[Type[nn.Module]] = None, preact: bool = False, + device=None, + dtype=None, ) -> nn.Module: """Create convolutional downsampling module. @@ -159,6 +161,7 @@ def downsample_conv( Returns: Downsampling module. 
""" + dd = {'device': device, 'dtype': dtype} norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size dilation = dilation if kernel_size > 1 else 1 @@ -169,6 +172,7 @@ def downsample_conv( kernel_size, stride=stride, dilation=dilation, + **dd, ) else: return ConvNormAct( @@ -179,6 +183,7 @@ def downsample_conv( dilation=dilation, norm_layer=norm_layer, apply_act=False, + **dd, ) @@ -188,8 +193,10 @@ def downsample_avg( kernel_size: int = 1, stride: int = 1, dilation: int = 1, - norm_layer: Optional[Callable] = None, + norm_layer: Optional[Type[nn.Module]] = None, preact: bool = False, + device=None, + dtype=None, ) -> nn.Sequential: """Create average pool downsampling module. @@ -207,6 +214,7 @@ def downsample_avg( Returns: Sequential downsampling module. """ + dd = {'device': device, 'dtype': dtype} norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 pool = nn.Identity() @@ -214,9 +222,9 @@ def downsample_avg( avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) if preact: - conv = create_conv2d(in_chs, out_chs, 1, stride=1) + conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd) else: - conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False) + conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False, **dd) return nn.Sequential(*[pool, conv]) @@ -227,8 +235,10 @@ def create_shortcut( kernel_size: int, stride: int, dilation: Tuple[int, int] = (1, 1), - norm_layer: Optional[Callable] = None, + norm_layer: Optional[Type[nn.Module]] = None, preact: bool = False, + device=None, + dtype=None, ) -> Optional[nn.Module]: """Create shortcut connection for residual blocks. @@ -245,9 +255,10 @@ def create_shortcut( Returns: Shortcut module or None. """ + dd = {'device': device, 'dtype': dtype} assert downsample_type in ('avg', 'conv1x1', '', None) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) + dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact, **dd) if not downsample_type: return None # no shortcut, no downsample elif downsample_type == 'avg': @@ -276,10 +287,12 @@ def __init__( se_ratio: float = 0.25, downsample: str = 'conv1x1', linear_out: bool = False, - act_layer: Callable = nn.ReLU, - norm_layer: Callable = nn.BatchNorm2d, - drop_block=None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + drop_block: Optional[Type[nn.Module]] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize RegNet Bottleneck block. @@ -298,13 +311,14 @@ def __init__( drop_block: Drop block layer. drop_path_rate: Stochastic depth drop rate. 
""" - super(Bottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() act_layer = get_act_layer(act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) groups = bottleneck_chs // group_size cargs = dict(act_layer=act_layer, norm_layer=norm_layer) - self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) + self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs, **dd) self.conv2 = ConvNormAct( bottleneck_chs, bottleneck_chs, @@ -314,13 +328,14 @@ def __init__( groups=groups, drop_layer=drop_block, **cargs, + **dd, ) if se_ratio: se_channels = int(round(in_chs * se_ratio)) - self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer, **dd) else: self.se = nn.Identity() - self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs) + self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs, **dd) self.act3 = nn.Identity() if linear_out else act_layer() self.downsample = create_shortcut( downsample, @@ -330,6 +345,7 @@ def __init__( stride=stride, dilation=dilation, norm_layer=norm_layer, + **dd, ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() @@ -376,10 +392,12 @@ def __init__( se_ratio: float = 0.25, downsample: str = 'conv1x1', linear_out: bool = False, - act_layer: Callable = nn.ReLU, - norm_layer: Callable = nn.BatchNorm2d, - drop_block=None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + drop_block: Optional[Type[nn.Module]] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize pre-activation RegNet Bottleneck block. @@ -398,14 +416,15 @@ def __init__( drop_block: Drop block layer. drop_path_rate: Stochastic depth drop rate. 
""" - super(PreBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) groups = bottleneck_chs // group_size - self.norm1 = norm_act_layer(in_chs) - self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1) - self.norm2 = norm_act_layer(bottleneck_chs) + self.norm1 = norm_act_layer(in_chs, **dd) + self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1, **dd) + self.norm2 = norm_act_layer(bottleneck_chs, **dd) self.conv2 = create_conv2d( bottleneck_chs, bottleneck_chs, @@ -413,14 +432,15 @@ def __init__( stride=stride, dilation=dilation[0], groups=groups, + **dd, ) if se_ratio: se_channels = int(round(in_chs * se_ratio)) - self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer, **dd) else: self.se = nn.Identity() - self.norm3 = norm_act_layer(bottleneck_chs) - self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1) + self.norm3 = norm_act_layer(bottleneck_chs, **dd) + self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1, **dd) self.downsample = create_shortcut( downsample, in_chs, @@ -429,6 +449,7 @@ def __init__( stride=stride, dilation=dilation, preact=True, + **dd, ) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() @@ -474,7 +495,7 @@ def __init__( stride: int, dilation: int, drop_path_rates: Optional[List[float]] = None, - block_fn: Callable = Bottleneck, + block_fn: Type[nn.Module] = Bottleneck, **block_kwargs, ): """Initialize RegNet stage. @@ -489,7 +510,7 @@ def __init__( block_fn: Block class to use. **block_kwargs: Additional block arguments. """ - super(RegStage, self).__init__() + super().__init__() self.grad_checkpointing = False first_dilation = 1 if dilation in (1, 2) else 2 @@ -546,6 +567,8 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., zero_init_last: bool = True, + device=None, + dtype=None, **kwargs, ): """Initialize RegNet model. @@ -562,6 +585,7 @@ def __init__( kwargs: Extra kwargs overlayed onto cfg. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) @@ -571,9 +595,9 @@ def __init__( stem_width = cfg.stem_width na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) if cfg.preact: - self.stem = create_conv2d(in_chans, stem_width, 3, stride=2) + self.stem = create_conv2d(in_chans, stem_width, 3, stride=2, **dd) else: - self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args) + self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args, **dd) self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] # Construct the stages @@ -595,6 +619,7 @@ def __init__( block_fn=block_fn, **stage_args, **common_args, + **dd, ) ) prev_width = stage_args['out_chs'] @@ -603,7 +628,7 @@ def __init__( # Construct the head if cfg.num_features: - self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args) + self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args, **dd) self.num_features = cfg.num_features else: final_act = cfg.linear_out or cfg.preact @@ -615,6 +640,7 @@ def __init__( num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) diff --git a/timm/models/repghost.py b/timm/models/repghost.py index c5a7d93a4f..779d27e037 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -6,7 +6,7 @@ """ import copy from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -29,22 +29,25 @@ class RepGhostModule(nn.Module): def __init__( self, - in_chs, - out_chs, - kernel_size=1, - dw_size=3, - stride=1, - relu=True, - reparam=True, + in_chs: int, + out_chs: int, + kernel_size: int = 1, + dw_size: int = 3, + stride: int = 1, + relu: bool = True, + reparam: bool = True, + device=None, + dtype=None, ): - super(RepGhostModule, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.out_chs = out_chs init_chs = out_chs new_chs = out_chs self.primary_conv = nn.Sequential( - nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False), - nn.BatchNorm2d(init_chs), + nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd), + nn.BatchNorm2d(init_chs, **dd), nn.ReLU(inplace=True) if relu else nn.Identity(), ) @@ -52,14 +55,14 @@ def __init__( fusion_bn = [] if reparam: fusion_conv.append(nn.Identity()) - fusion_bn.append(nn.BatchNorm2d(init_chs)) + fusion_bn.append(nn.BatchNorm2d(init_chs, **dd)) self.fusion_conv = nn.Sequential(*fusion_conv) self.fusion_bn = nn.Sequential(*fusion_bn) self.cheap_operation = nn.Sequential( - nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False), - nn.BatchNorm2d(new_chs), + nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False, **dd), + nn.BatchNorm2d(new_chs, **dd), # nn.ReLU(inplace=True) if relu else nn.Identity(), ) self.relu = nn.ReLU(inplace=False) if relu else nn.Identity() @@ -113,6 +116,7 @@ def switch_to_deploy(self): if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0: return kernel, bias = self.get_equivalent_kernel_bias() + dd = {'device': kernel.device, 'dtype': kernel.dtype} self.cheap_operation = nn.Conv2d( in_channels=self.cheap_operation[0].in_channels, out_channels=self.cheap_operation[0].out_channels, @@ 
-120,7 +124,8 @@ def switch_to_deploy(self): padding=self.cheap_operation[0].padding, dilation=self.cheap_operation[0].dilation, groups=self.cheap_operation[0].groups, - bias=True) + bias=True, + **dd) self.cheap_operation.weight.data = kernel self.cheap_operation.bias.data = bias self.__delattr__('fusion_conv') @@ -137,37 +142,47 @@ class RepGhostBottleneck(nn.Module): def __init__( self, - in_chs, - mid_chs, - out_chs, - dw_kernel_size=3, - stride=1, - act_layer=nn.ReLU, - se_ratio=0., - reparam=True, + in_chs: int, + mid_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + act_layer: Type[nn.Module] = nn.ReLU, + se_ratio: float = 0., + reparam: bool = True, + device=None, + dtype=None, ): - super(RepGhostBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() has_se = se_ratio is not None and se_ratio > 0. self.stride = stride # Point-wise expansion - self.ghost1 = RepGhostModule(in_chs, mid_chs, relu=True, reparam=reparam) + self.ghost1 = RepGhostModule(in_chs, mid_chs, relu=True, reparam=reparam, **dd) # Depth-wise convolution if self.stride > 1: self.conv_dw = nn.Conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) - self.bn_dw = nn.BatchNorm2d(mid_chs) + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size-1)//2, + groups=mid_chs, + bias=False, + **dd, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs, **dd) else: self.conv_dw = None self.bn_dw = None # Squeeze-and-excitation - self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else None # Point-wise linear projection - self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam) + self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam, **dd) # shortcut if in_chs == out_chs and self.stride == 1: @@ -175,11 +190,18 @@ def __init__( else: self.shortcut = nn.Sequential( nn.Conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), - nn.BatchNorm2d(in_chs), - nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(out_chs), + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size-1)//2, + groups=in_chs, + bias=False, + **dd, + ), + nn.BatchNorm2d(in_chs, **dd), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), ) def forward(self, x): @@ -207,16 +229,19 @@ def forward(self, x): class RepGhostNet(nn.Module): def __init__( self, - cfgs, - num_classes=1000, - width=1.0, - in_chans=3, - output_stride=32, - global_pool='avg', - drop_rate=0.2, - reparam=True, + cfgs: List[List[List]], + num_classes: int = 1000, + width: float = 1.0, + in_chans: int = 3, + output_stride: int = 32, + global_pool: str = 'avg', + drop_rate: float = 0.2, + reparam: bool = True, + device=None, + dtype=None, ): - super(RepGhostNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' self.cfgs = cfgs @@ -227,9 +252,9 @@ def __init__( # building first layer stem_chs = make_divisible(16 * width, 4) - self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False) + self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False, **dd) self.feature_info.append(dict(num_chs=stem_chs, reduction=2, 
module=f'conv_stem')) - self.bn1 = nn.BatchNorm2d(stem_chs) + self.bn1 = nn.BatchNorm2d(stem_chs, **dd) self.act1 = nn.ReLU(inplace=True) prev_chs = stem_chs @@ -244,7 +269,7 @@ def __init__( for k, exp_size, c, se_ratio, s in cfg: out_chs = make_divisible(c * width, 4) mid_chs = make_divisible(exp_size * width, 4) - layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, reparam=reparam)) + layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, reparam=reparam, **dd)) prev_chs = out_chs if s > 1: net_stride *= 2 @@ -254,7 +279,7 @@ def __init__( stage_idx += 1 out_chs = make_divisible(exp_size * width * 2, 4) - stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) + stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1, **dd))) self.pool_dim = prev_chs = out_chs self.blocks = nn.Sequential(*stages) @@ -263,10 +288,10 @@ def __init__( self.num_features = prev_chs self.head_hidden_size = out_chs = 1280 self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True, **dd) self.act2 = nn.ReLU(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity() @torch.jit.ignore def group_matcher(self, coarse=False): @@ -293,7 +318,13 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): # NOTE: cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + if num_classes > 0: + device = self.classifier.weight.device if hasattr(self.classifier, 'weight') else None + dtype = self.classifier.weight.dtype if hasattr(self.classifier, 'weight') else None + dd = {'device': device, 'dtype': dtype} + self.classifier = Linear(self.head_hidden_size, num_classes, **dd) + else: + self.classifier = nn.Identity() def forward_intermediates( self, diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 3641d6f70c..66294e5582 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -14,7 +14,7 @@ Adapted from official impl at https://github.com/jameslahm/RepViT """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -30,10 +30,23 @@ class ConvNorm(nn.Sequential): - def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + def __init__( + self, + in_dim: int, + out_dim: int, + ks: int = 1, + stride: int = 1, + pad: int = 0, + dilation: int = 1, + groups: int = 1, + bn_weight_init: float = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False)) - self.add_module('bn', nn.BatchNorm2d(out_dim)) + self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False, **dd)) + self.add_module('bn', nn.BatchNorm2d(out_dim, **dd)) nn.init.constant_(self.bn.weight, bn_weight_init) nn.init.constant_(self.bn.bias, 0) 
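# --- Illustrative sketch, not part of the patch -------------------------------
# The recurring pattern in this diff: each constructor accepts optional
# device/dtype "factory kwargs", collects them in `dd`, and forwards them to
# every parameter-creating submodule so the module materializes directly on the
# target device/dtype. A minimal self-contained version of that pattern using
# only stock PyTorch (`TinyConvNorm` is a hypothetical stand-in, not a timm API):
import torch
import torch.nn as nn

class TinyConvNorm(nn.Sequential):
    def __init__(self, in_dim: int, out_dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Both submodules receive the same factory kwargs, so no post-hoc .to() is needed.
        self.add_module('c', nn.Conv2d(in_dim, out_dim, 3, padding=1, bias=False, **dd))
        self.add_module('bn', nn.BatchNorm2d(out_dim, **dd))

m = TinyConvNorm(16, 32, dtype=torch.bfloat16)   # parameters created directly in bf16
assert m.c.weight.dtype == torch.bfloat16
m_meta = TinyConvNorm(16, 32, device='meta')     # shape-only tensors for deferred init
assert m_meta.bn.weight.is_meta
# -------------------------------------------------------------------------------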
@@ -59,10 +72,19 @@ def fuse(self): class NormLinear(nn.Sequential): - def __init__(self, in_dim, out_dim, bias=True, std=0.02): + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + std: float = 0.02, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.add_module('bn', nn.BatchNorm1d(in_dim)) - self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias)) + self.add_module('bn', nn.BatchNorm1d(in_dim, **dd)) + self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias, **dd)) trunc_normal_(self.l.weight, std=std) if bias: nn.init.constant_(self.l.bias, 0) @@ -84,16 +106,24 @@ def fuse(self): class RepVggDw(nn.Module): - def __init__(self, ed, kernel_size, legacy=False): + def __init__( + self, + ed: int, + kernel_size: int, + legacy: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed) + self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed, **dd) if legacy: - self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed) + self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed, **dd) # Make torchscript happy. self.bn = nn.Identity() else: - self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed) - self.bn = nn.BatchNorm2d(ed) + self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed, **dd) + self.bn = nn.BatchNorm2d(ed, **dd) self.dim = ed self.legacy = legacy @@ -137,23 +167,41 @@ def fuse(self): class RepVitMlp(nn.Module): - def __init__(self, in_dim, hidden_dim, act_layer): + def __init__( + self, + in_dim: int, + hidden_dim: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0) + self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0, **dd) self.act = act_layer() - self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0) + self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0, **dd) def forward(self, x): return self.conv2(self.act(self.conv1(x))) class RepViTBlock(nn.Module): - def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False): - super(RepViTBlock, self).__init__() - - self.token_mixer = RepVggDw(in_dim, kernel_size, legacy) - self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity() - self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer) + def __init__( + self, + in_dim: int, + mlp_ratio: float, + kernel_size: int, + use_se: bool, + act_layer: Type[nn.Module], + legacy: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.token_mixer = RepVggDw(in_dim, kernel_size, legacy, **dd) + self.se = SqueezeExcite(in_dim, 0.25, **dd) if use_se else nn.Identity() + self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer, **dd) def forward(self, x): x = self.token_mixer(x) @@ -164,11 +212,19 @@ def forward(self, x): class RepVitStem(nn.Module): - def __init__(self, in_chs, out_chs, act_layer): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) + self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1, **dd) self.act1 = act_layer() - self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1) + self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 
2, 1, **dd) self.stride = 4 def forward(self, x): @@ -176,12 +232,39 @@ def forward(self, x): class RepVitDownsample(nn.Module): - def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False): + def __init__( + self, + in_dim: int, + mlp_ratio: float, + out_dim: int, + kernel_size: int, + act_layer: Type[nn.Module], + legacy: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy) - self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim) - self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1) - self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer) + self.pre_block = RepViTBlock( + in_dim, + mlp_ratio, + kernel_size, + use_se=False, + act_layer=act_layer, + legacy=legacy, + **dd, + ) + self.spatial_downsample = ConvNorm( + in_dim, + in_dim, + kernel_size, + stride=2, + pad=(kernel_size - 1) // 2, + groups=in_dim, + **dd, + ) + self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1, **dd) + self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer, **dd) def forward(self, x): x = self.pre_block(x) @@ -193,15 +276,24 @@ def forward(self, x): class RepVitClassifier(nn.Module): - def __init__(self, dim, num_classes, distillation=False, drop=0.0): + def __init__( + self, + dim: int, + num_classes: int, + distillation: bool = False, + drop: float = 0.0, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.head_drop = nn.Dropout(drop) - self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = NormLinear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity() self.distillation = distillation self.distilled_training = False self.num_classes = num_classes if distillation: - self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = NormLinear(dim, num_classes, **dd) if num_classes > 0 else nn.Identity() def forward(self, x): x = self.head_drop(x) @@ -232,10 +324,31 @@ def fuse(self): class RepVitStage(nn.Module): - def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False): + def __init__( + self, + in_dim: int, + out_dim: int, + depth: int, + mlp_ratio: float, + act_layer: Type[nn.Module], + kernel_size: int = 3, + downsample: bool = True, + legacy: bool = False, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() if downsample: - self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy) + self.downsample = RepVitDownsample( + in_dim, + mlp_ratio, + out_dim, + kernel_size, + act_layer=act_layer, + legacy=legacy, + **dd, + ) else: assert in_dim == out_dim self.downsample = nn.Identity() @@ -243,7 +356,7 @@ def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, blocks = [] use_se = True for _ in range(depth): - blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy)) + blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy, **dd)) use_se = not use_se self.blocks = nn.Sequential(*blocks) @@ -257,27 +370,30 @@ def forward(self, x): class RepVit(nn.Module): def __init__( self, - in_chans=3, - img_size=224, - embed_dim=(48,), - depth=(2,), - mlp_ratio=2, - global_pool='avg', - kernel_size=3, - 
num_classes=1000, - act_layer=nn.GELU, - distillation=True, - drop_rate=0.0, - legacy=False, + in_chans: int = 3, + img_size: int = 224, + embed_dim: Tuple[int, ...] = (48,), + depth: Tuple[int, ...] = (2,), + mlp_ratio: float = 2, + global_pool: str = 'avg', + kernel_size: int = 3, + num_classes: int = 1000, + act_layer: Type[nn.Module] = nn.GELU, + distillation: bool = True, + drop_rate: float = 0.0, + legacy: bool = False, + device=None, + dtype=None, ): - super(RepVit, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.grad_checkpointing = False self.global_pool = global_pool self.embed_dim = embed_dim self.num_classes = num_classes in_dim = embed_dim[0] - self.stem = RepVitStem(in_chans, in_dim, act_layer) + self.stem = RepVitStem(in_chans, in_dim, act_layer, **dd) stride = self.stem.stride resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))]) @@ -298,6 +414,7 @@ def __init__( kernel_size=kernel_size, downsample=downsample, legacy=legacy, + **dd, ) ) stage_stride = 2 if downsample else 1 @@ -309,7 +426,7 @@ def __init__( self.num_features = self.head_hidden_size = embed_dim[-1] self.head_drop = nn.Dropout(drop_rate) - self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation) + self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation, **dd) @torch.jit.ignore def group_matcher(self, coarse=False): @@ -324,11 +441,12 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False, device=None, dtype=None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation) + dd = {'device': device, 'dtype': dtype} + self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation, **dd) @torch.jit.ignore def set_distilled_training(self, enable=True): diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 691f929b91..589bf2ddb2 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -3,6 +3,7 @@ Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 """ import math +from typing import Optional, Type import torch import torch.nn as nn @@ -23,21 +24,24 @@ class Bottle2neck(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - cardinality=1, - base_width=26, - scale=4, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=None, - attn_layer=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 26, + scale: int = 4, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Optional[Type[nn.Module]] = None, + attn_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, **_, ): - super(Bottle2neck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None self.num_scales = max(1, scale - 1) @@ -46,16 +50,24 @@ def __init__( outplanes = planes * self.expansion first_dilation = first_dilation or dilation - self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) - 
self.bn1 = norm_layer(width * scale) + self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False, **dd) + self.bn1 = norm_layer(width * scale, **dd) convs = [] bns = [] for i in range(self.num_scales): convs.append(nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, groups=cardinality, bias=False)) - bns.append(norm_layer(width)) + width, + width, + kernel_size=3, + stride=stride, + padding=first_dilation, + dilation=first_dilation, + groups=cardinality, + bias=False, + **dd, + )) + bns.append(norm_layer(width, **dd)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) if self.is_first: @@ -64,9 +76,9 @@ def __init__( else: self.pool = None - self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - self.se = attn_layer(outplanes) if attn_layer is not None else None + self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False, **dd) + self.bn3 = norm_layer(outplanes, **dd) + self.se = attn_layer(outplanes, **dd) if attn_layer is not None else None self.relu = act_layer(inplace=True) self.downsample = downsample diff --git a/timm/models/resnest.py b/timm/models/resnest.py index ed25c6e448..0695b38774 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -6,6 +6,8 @@ Modified for torchscript compat, and consistency with timm by Ross Wightman """ +from typing import Optional, Type + from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -23,27 +25,30 @@ class ResNestBottleneck(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - radix=1, - cardinality=1, - base_width=64, - avd=False, - avd_first=False, - is_first=False, - reduce_first=1, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - aa_layer=None, - drop_block=None, - drop_path=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + radix: int = 1, + cardinality: int = 1, + base_width: int = 64, + avd: bool = False, + avd_first: bool = False, + is_first: bool = False, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - super(ResNestBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() assert reduce_first == 1 # not supported assert attn_layer is None, 'attn_layer is not supported' # not supported assert aa_layer is None, 'aa_layer is not supported' # TODO not yet supported @@ -57,29 +62,47 @@ def __init__( avd_stride = 0 self.radix = radix - self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) - self.bn1 = norm_layer(group_width) + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False, **dd) + self.bn1 = norm_layer(group_width, **dd) self.act1 = act_layer(inplace=True) self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None if self.radix >= 1: self.conv2 = SplitAttn( - group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block) + 
group_width, + group_width, + kernel_size=3, + stride=stride, + padding=first_dilation, + dilation=first_dilation, + groups=cardinality, + radix=radix, + norm_layer=norm_layer, + drop_layer=drop_block, + **dd, + ) self.bn2 = nn.Identity() self.drop_block = nn.Identity() self.act2 = nn.Identity() else: self.conv2 = nn.Conv2d( - group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, groups=cardinality, bias=False) - self.bn2 = norm_layer(group_width) + group_width, + group_width, + kernel_size=3, + stride=stride, + padding=first_dilation, + dilation=first_dilation, + groups=cardinality, + bias=False, + **dd, + ) + self.bn2 = norm_layer(group_width, **dd) self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act2 = act_layer(inplace=True) self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None - self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) - self.bn3 = norm_layer(planes*4) + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False, **dd) + self.bn3 = norm_layer(planes * 4, **dd) self.act3 = act_layer(inplace=True) self.downsample = downsample self.drop_path = drop_path diff --git a/timm/models/resnet.py b/timm/models/resnet.py index ca07682b80..a92c5df6ce 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -55,6 +55,8 @@ def __init__( aa_layer: Optional[Type[nn.Module]] = None, drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ) -> None: """ Args: @@ -74,7 +76,8 @@ def __init__( drop_block: DropBlock layer class. drop_path: Optional DropPath layer instance. """ - super(BasicBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock does not support changing base width' @@ -84,18 +87,32 @@ def __init__( use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, - dilation=first_dilation, bias=False) - self.bn1 = norm_layer(first_planes) + inplanes, + first_planes, + kernel_size=3, + stride=1 if use_aa else stride, + padding=first_dilation, + dilation=first_dilation, + bias=False, + **dd, + ) + self.bn1 = norm_layer(first_planes, **dd) self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act1 = act_layer(inplace=True) - self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa) + self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa, **dd) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) - self.bn2 = norm_layer(outplanes) + first_planes, + outplanes, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + **dd, + ) + self.bn2 = norm_layer(outplanes, **dd) - self.se = create_attn(attn_layer, outplanes) + self.se = create_attn(attn_layer, outplanes, **dd) self.act2 = act_layer(inplace=True) self.downsample = downsample @@ -158,6 +175,8 @@ def __init__( aa_layer: Optional[Type[nn.Module]] = None, drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ) -> None: """ Args: @@ -177,7 +196,8 @@ def __init__( drop_block: DropBlock layer class. 
drop_path: Optional DropPath layer instance. """ - super(Bottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first @@ -185,22 +205,30 @@ def __init__( first_dilation = first_dilation or dilation use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False, **dd) + self.bn1 = norm_layer(first_planes, **dd) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, width, kernel_size=3, stride=1 if use_aa else stride, - padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) - self.bn2 = norm_layer(width) + first_planes, + width, + kernel_size=3, + stride=1 if use_aa else stride, + padding=first_dilation, + dilation=first_dilation, + groups=cardinality, + bias=False, + **dd, + ) + self.bn2 = norm_layer(width, **dd) self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act2 = act_layer(inplace=True) - self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa) + self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa, **dd) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False, **dd) + self.bn3 = norm_layer(outplanes, **dd) - self.se = create_attn(attn_layer, outplanes) + self.se = create_attn(attn_layer, outplanes, **dd) self.act3 = act_layer(inplace=True) self.downsample = downsample @@ -251,7 +279,10 @@ def downsample_conv( dilation: int = 1, first_dilation: Optional[int] = None, norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ) -> nn.Module: + dd = {'device': device, 'dtype': dtype} norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 @@ -259,8 +290,16 @@ def downsample_conv( return nn.Sequential(*[ nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), - norm_layer(out_channels) + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=p, + dilation=first_dilation, + bias=False, + **dd + ), + norm_layer(out_channels, **dd) ]) @@ -272,7 +311,10 @@ def downsample_avg( dilation: int = 1, first_dilation: Optional[int] = None, norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ) -> nn.Module: + dd = {'device': device, 'dtype': dtype} norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 if stride == 1 and dilation == 1: @@ -283,8 +325,8 @@ def downsample_avg( return nn.Sequential(*[ pool, - nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), - norm_layer(out_channels) + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False, **dd), + norm_layer(out_channels, **dd) ]) @@ -314,6 +356,8 @@ def make_blocks( avg_down: bool = False, drop_block_rate: float = 0., drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs, ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]: """Create ResNet stages with specified block configurations. 
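# --- Illustrative sketch, not part of the patch -------------------------------
# With `device`/`dtype` threaded through `make_blocks` and the block
# constructors above, a ResNet can be built directly on a target device, e.g.
# as meta tensors for deferred/low-memory initialization, rather than building
# on CPU and calling `.to()` afterwards. A hedged usage sketch, assuming a timm
# build that includes this change:
import torch
from timm.models.resnet import ResNet, BasicBlock

model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=10, device='meta')
assert model.conv1.weight.is_meta            # stem conv created on the meta device
assert model.layer1[0].conv1.weight.is_meta  # stage blocks too, via make_blocks(**dd)
# -------------------------------------------------------------------------------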
@@ -334,6 +378,7 @@ def make_blocks( Returns: Tuple of stage modules list and feature info list. """ + dd = {'device': device, 'dtype': dtype} stages = [] feature_info = [] net_num_blocks = sum(block_repeats) @@ -359,6 +404,7 @@ def make_blocks( dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'), + **dd, ) downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) @@ -376,6 +422,7 @@ def make_blocks( first_dilation=prev_dilation, drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs, + **dd, )) prev_dilation = dilation inplanes = planes * block_fn.expansion @@ -444,6 +491,8 @@ def __init__( drop_block_rate: float = 0., zero_init_last: bool = True, block_args: Optional[Dict[str, Any]] = None, + device=None, + dtype=None, ): """ Args: @@ -475,7 +524,8 @@ def __init__( zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) block_args (dict): Extra kwargs to pass through to block module """ - super(ResNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} block_args = block_args or dict() assert output_stride in (8, 16, 32) self.num_classes = num_classes @@ -493,25 +543,25 @@ def __init__( if 'tiered' in stem_type: stem_chs = (3 * (stem_width // 4), stem_width) self.conv1 = nn.Sequential(*[ - nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), - norm_layer(stem_chs[0]), + nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False, **dd), + norm_layer(stem_chs[0], **dd), act_layer(inplace=True), - nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), - norm_layer(stem_chs[1]), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False, **dd), + norm_layer(stem_chs[1], **dd), act_layer(inplace=True), - nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) + nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False, **dd)]) else: - self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = norm_layer(inplanes) + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd) + self.bn1 = norm_layer(inplanes, **dd) self.act1 = act_layer(inplace=True) self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] # Stem pooling. The name 'maxpool' remains for weight compatibility. 
if replace_stem_pool: self.maxpool = nn.Sequential(*filter(None, [ - nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), - create_aa(aa_layer, channels=inplanes, stride=2) if aa_layer is not None else None, - norm_layer(inplanes), + nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False, **dd), + create_aa(aa_layer, channels=inplanes, stride=2, **dd) if aa_layer is not None else None, + norm_layer(inplanes, **dd), act_layer(inplace=True), ])) else: @@ -521,7 +571,7 @@ def __init__( else: self.maxpool = nn.Sequential(*[ nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - aa_layer(channels=inplanes, stride=2)]) + aa_layer(channels=inplanes, stride=2, **dd)]) else: self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -544,6 +594,7 @@ def __init__( drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args, + **dd, ) for stage in stage_modules: self.add_module(*stage) # layer1, layer2, etc @@ -551,7 +602,7 @@ def __init__( # Head (Pooling and Classifier) self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd) self.init_weights(zero_init_last=zero_init_last) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 7c1424d93c..48c8131368 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -64,6 +64,8 @@ def __init__( norm_layer: Optional[Callable] = None, proj_layer: Optional[Callable] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize PreActBasic block. @@ -81,6 +83,7 @@ def __init__( proj_layer: Projection/downsampling layer type. drop_path_rate: Stochastic depth drop rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() first_dilation = first_dilation or dilation conv_layer = conv_layer or StdConv2d @@ -98,14 +101,15 @@ def __init__( preact=True, conv_layer=conv_layer, norm_layer=norm_layer, + **dd, ) else: self.downsample = None - self.norm1 = norm_layer(in_chs) - self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - self.norm2 = norm_layer(mid_chs) - self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups) + self.norm1 = norm_layer(in_chs, **dd) + self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd) + self.norm2 = norm_layer(mid_chs, **dd) + self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def zero_init_last(self) -> None: @@ -158,6 +162,8 @@ def __init__( norm_layer: Optional[Callable] = None, proj_layer: Optional[Callable] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize PreActBottleneck block. @@ -175,6 +181,7 @@ def __init__( proj_layer: Projection/downsampling layer type. drop_path_rate: Stochastic depth drop rate. 
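+            device: Optional device to create parameters and buffers on.
+            dtype: Optional dtype for created parameters and buffers.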
""" + dd = {'device': device, 'dtype': dtype} super().__init__() first_dilation = first_dilation or dilation conv_layer = conv_layer or StdConv2d @@ -192,16 +199,17 @@ def __init__( preact=True, conv_layer=conv_layer, norm_layer=norm_layer, + **dd, ) else: self.downsample = None - self.norm1 = norm_layer(in_chs) - self.conv1 = conv_layer(in_chs, mid_chs, 1) - self.norm2 = norm_layer(mid_chs) - self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - self.norm3 = norm_layer(mid_chs) - self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.norm1 = norm_layer(in_chs, **dd) + self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd) + self.norm2 = norm_layer(mid_chs, **dd) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd) + self.norm3 = norm_layer(mid_chs, **dd) + self.conv3 = conv_layer(mid_chs, out_chs, 1, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def zero_init_last(self) -> None: @@ -249,7 +257,10 @@ def __init__( norm_layer: Optional[Callable] = None, proj_layer: Optional[Callable] = None, drop_path_rate: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() first_dilation = first_dilation or dilation act_layer = act_layer or nn.ReLU @@ -267,16 +278,17 @@ def __init__( preact=False, conv_layer=conv_layer, norm_layer=norm_layer, + **dd, ) else: self.downsample = None - self.conv1 = conv_layer(in_chs, mid_chs, 1) - self.norm1 = norm_layer(mid_chs) - self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) - self.norm2 = norm_layer(mid_chs) - self.conv3 = conv_layer(mid_chs, out_chs, 1) - self.norm3 = norm_layer(out_chs, apply_act=False) + self.conv1 = conv_layer(in_chs, mid_chs, 1, **dd) + self.norm1 = norm_layer(mid_chs, **dd) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups, **dd) + self.norm2 = norm_layer(mid_chs, **dd) + self.conv3 = conv_layer(mid_chs, out_chs, 1, **dd) + self.norm3 = norm_layer(out_chs, apply_act=False, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act3 = act_layer(inplace=True) @@ -324,10 +336,13 @@ def __init__( preact: bool = True, conv_layer: Optional[Callable] = None, norm_layer: Optional[Callable] = None, + device=None, + dtype=None, ): - super(DownsampleConv, self).__init__() - self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) - self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride, **dd) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. 
@@ -354,16 +369,19 @@ def __init__( preact: bool = True, conv_layer: Optional[Callable] = None, norm_layer: Optional[Callable] = None, + device=None, + dtype=None, ): - super(DownsampleAvg, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() avg_stride = stride if dilation == 1 else 1 if stride > 1 or dilation > 1: avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) else: self.pool = nn.Identity() - self.conv = conv_layer(in_chs, out_chs, 1, stride=1) - self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + self.conv = conv_layer(in_chs, out_chs, 1, stride=1, **dd) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -396,7 +414,7 @@ def __init__( norm_layer: Optional[Callable] = None, **block_kwargs: Any, ): - super(ResNetStage, self).__init__() + super().__init__() self.grad_checkpointing = False first_dilation = 1 if dilation in (1, 2) else 2 @@ -459,7 +477,10 @@ def create_resnetv2_stem( preact: bool = True, conv_layer: Callable = StdConv2d, norm_layer: Callable = partial(GroupNormAct, num_groups=32), + device=None, + dtype=None, ) -> nn.Sequential: + dd = {'device': device, 'dtype': dtype} stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') @@ -470,18 +491,18 @@ def create_resnetv2_stem( stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py else: stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets - stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2) - stem['norm1'] = norm_layer(stem_chs[0]) - stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) - stem['norm2'] = norm_layer(stem_chs[1]) - stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1) + stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2, **dd) + stem['norm1'] = norm_layer(stem_chs[0], **dd) + stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, **dd) + stem['norm2'] = norm_layer(stem_chs[1], **dd) + stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1, **dd) if not preact: - stem['norm3'] = norm_layer(out_chs) + stem['norm3'] = norm_layer(out_chs, **dd) else: # The usual 7x7 stem conv - stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2, **dd) if not preact: - stem['norm'] = norm_layer(out_chs) + stem['norm'] = norm_layer(out_chs, **dd) if 'fixed' in stem_type: # 'fixed' SAME padding approximation that is used in BiT models @@ -522,6 +543,8 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., zero_init_last: bool = False, + device=None, + dtype=None, ): """ Args: @@ -544,6 +567,7 @@ def __init__( zero_init_last: zero-init last weight in residual path (default: False) """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate wf = width_factor @@ -559,6 +583,7 @@ def __init__( preact, conv_layer=conv_layer, norm_layer=norm_layer, + **dd, ) stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) @@ -592,6 +617,7 @@ def __init__( norm_layer=norm_layer, block_dpr=bdpr, 
block_fn=block_fn, + **dd, ) prev_chs = out_chs curr_stride *= stride @@ -599,13 +625,14 @@ def __init__( self.stages.add_module(str(stage_idx), stage) self.num_features = self.head_hidden_size = prev_chs - self.norm = norm_layer(self.num_features) if preact else nn.Identity() + self.norm = norm_layer(self.num_features, **dd) if preact else nn.Identity() self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True, + **dd, ) self.init_weights(zero_init_last=zero_init_last) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 77b801db87..4d2dd5207d 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -12,7 +12,7 @@ from functools import partial from math import ceil -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -49,6 +49,8 @@ def __init__( act_layer: str = 'swish', dw_act_layer: str = 'relu6', drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): """Initialize LinearBottleneck. @@ -64,14 +66,15 @@ def __init__( dw_act_layer: Activation layer for depthwise. drop_path: Drop path module. """ - super(LinearBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs self.in_channels = in_chs self.out_channels = out_chs if exp_ratio != 1.: dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) - self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer) + self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer, **dd) else: dw_chs = in_chs self.conv_exp = None @@ -84,14 +87,15 @@ def __init__( dilation=dilation[0], groups=dw_chs, apply_act=False, + **dd, ) if se_ratio > 0: - self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) + self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div), **dd) else: self.se = None self.act_dw = create_act_layer(dw_act_layer) - self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False) + self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False, **dd) self.drop_path = drop_path def feat_channels(self, exp: bool = False) -> int: @@ -178,6 +182,8 @@ def _build_blocks( act_layer: str = 'swish', dw_act_layer: str = 'relu6', drop_path_rate: float = 0., + device=None, + dtype=None, ) -> Tuple[List[nn.Module], List[Dict[str, Any]]]: """Build ReXNet blocks from configuration. @@ -194,6 +200,7 @@ def _build_blocks( Returns: Tuple of (features list, feature_info list). """ + dd = {'device': device, 'dtype': dtype} feat_chs = [prev_chs] feature_info = [] curr_stride = 2 @@ -221,6 +228,7 @@ def _build_blocks( act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path, + **dd, )) curr_stride *= stride dilation = next_dilation @@ -228,7 +236,7 @@ def _build_blocks( feat_chs += [features[-1].feat_channels()] pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')] - features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer)) + features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer, **dd)) return features, feature_info @@ -255,6 +263,8 @@ def __init__( dw_act_layer: str = 'relu6', drop_rate: float = 0.2, drop_path_rate: float = 0., + device=None, + dtype=None, ): """Initialize ReXNet. 
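The payoff of threading `device`/`dtype` through every constructor is that a model can be built directly on a target device or dtype rather than created on CPU and moved afterwards. A hedged usage sketch against this branch (direct class construction is shown, and it assumes all submodule inits tolerate meta tensors):

import torch
from timm.models.rexnet import RexNet

# Build with parameters allocated on the meta device (no real storage),
# useful for deferred or sharded initialization...
model = RexNet(num_classes=10, device='meta')
# ...then materialize empty weights on a real device before loading a checkpoint.
model = model.to_empty(device='cpu')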
@@ -274,7 +284,8 @@ def __init__( drop_rate: Dropout rate. drop_path_rate: Drop path rate. """ - super(RexNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False @@ -282,7 +293,7 @@ def __init__( assert output_stride in (32, 16, 8) stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) - self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer) + self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer, **dd) block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div) features, self.feature_info = _build_blocks( @@ -294,11 +305,12 @@ def __init__( act_layer, dw_act_layer, drop_path_rate, + **dd, ) self.num_features = self.head_hidden_size = features[-1].out_channels self.features = nn.Sequential(*features) - self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) + self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate, **dd) efficientnet_init_weights(self) @@ -336,7 +348,7 @@ def get_classifier(self) -> nn.Module: """ return self.head.fc - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None) -> None: """Reset the classifier. Args: @@ -344,7 +356,12 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) global_pool: Global pooling type. """ self.num_classes = num_classes - self.head.reset(num_classes, global_pool) + if device is not None or dtype is not None: + dd = {'device': device, 'dtype': dtype} + pool_type = global_pool if global_pool is not None else self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type, self.drop_rate, **dd) + else: + self.head.reset(num_classes, global_pool) def forward_intermediates( self, diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index dc19ff41d1..7937aeb73b 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -9,7 +9,7 @@ Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch """ -from typing import List +from typing import List, Type import torch import torch.nn as nn @@ -25,7 +25,7 @@ class SequentialList(nn.Sequential): def __init__(self, *args): - super(SequentialList, self).__init__(*args) + super().__init__(*args) @torch.jit._overload_method # noqa: F811 def forward(self, x): @@ -45,7 +45,7 @@ def forward(self, x) -> List[torch.Tensor]: class SelectSeq(nn.Module): def __init__(self, mode='index', index=0): - super(SelectSeq, self).__init__() + super().__init__() self.mode = mode self.index = index @@ -66,30 +66,43 @@ def forward(self, x) -> torch.Tensor: return torch.cat(x, dim=1) -def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): +def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} if padding is None: padding = ((stride - 1) + dilation * (k - 1)) // 2 return nn.Sequential( - nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False), - nn.BatchNorm2d(out_chs), + nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, 
bias=False, **dd), + nn.BatchNorm2d(out_chs, **dd), nn.ReLU(inplace=True) ) class SelecSlsBlock(nn.Module): - def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1): - super(SelecSlsBlock, self).__init__() + def __init__( + self, + in_chs: int, + skip_chs: int, + mid_chs: int, + out_chs: int, + is_first: bool, + stride: int, + dilation: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.stride = stride self.is_first = is_first assert stride in [1, 2] # Process input with 4 conv blocks with the same number of input and output channels - self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation) - self.conv2 = conv_bn(mid_chs, mid_chs, 1) - self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3) - self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1) - self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3) - self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) + self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation, **dd) + self.conv2 = conv_bn(mid_chs, mid_chs, 1, **dd) + self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3, **dd) + self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1, **dd) + self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3, **dd) + self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1, **dd) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if not isinstance(x, list): @@ -122,14 +135,24 @@ class SelecSls(nn.Module): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ - def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): + def __init__( + self, + cfg, + num_classes: int = 1000, + in_chans: int = 3, + drop_rate: float = 0.0, + global_pool: str = 'avg', + device=None, + dtype=None, + ): self.num_classes = num_classes - super(SelecSls, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} - self.stem = conv_bn(in_chans, 32, stride=2) - self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.stem = conv_bn(in_chans, 32, stride=2, **dd) + self.features = SequentialList(*[cfg['block'](*block_args, **dd) for block_args in cfg['features']]) self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way - self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) + self.head = nn.Sequential(*[conv_bn(*conv_args, **dd) for conv_args in cfg['head']]) self.num_features = self.head_hidden_size = cfg['num_features'] self.feature_info = cfg['feature_info'] @@ -138,6 +161,7 @@ def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool self.num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) for n, m in self.named_modules(): @@ -162,7 +186,13 @@ def get_classifier(self) -> nn.Module: def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool, self.fc = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + device=self.fc.weight.device if hasattr(self.fc, 'weight') else None, + dtype=self.fc.weight.dtype if hasattr(self.fc, 'weight') else None, + ) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/senet.py b/timm/models/senet.py index 9884c61419..6a1bc99eb2 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -13,6 +13,7 @@ """ 
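+# NOTE: the optional `device`/`dtype` constructor args in this file are bundled into a `dd`
+# dict and forwarded to every conv/norm/SE submodule, so SENet variants can be instantiated
+# directly on a target device or dtype.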
import math from collections import OrderedDict +from typing import Type, Optional, Tuple import torch import torch.nn as nn @@ -36,11 +37,12 @@ def _weight_init(m): class SEModule(nn.Module): - def __init__(self, channels, reduction): - super(SEModule, self).__init__() - self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) + def __init__(self, channels: int, reduction: int, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, **dd) self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, **dd) self.sigmoid = nn.Sigmoid() def forward(self, x): @@ -87,18 +89,36 @@ class SEBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): - super(SEBottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes * 2) + def __init__( + self, + inplanes: int, + planes: int, + groups: int, + reduction: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(planes * 2, **dd) self.conv2 = nn.Conv2d( - planes * 2, planes * 4, kernel_size=3, stride=stride, - padding=1, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(planes * 4) - self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) + planes * 2, + planes * 4, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False, + **dd, + ) + self.bn2 = nn.BatchNorm2d(planes * 4, **dd) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False, **dd) + self.bn3 = nn.BatchNorm2d(planes * 4, **dd) self.relu = nn.ReLU(inplace=True) - self.se_module = SEModule(planes * 4, reduction=reduction) + self.se_module = SEModule(planes * 4, reduction=reduction, **dd) self.downsample = downsample self.stride = stride @@ -111,16 +131,27 @@ class SEResNetBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): - super(SEResNetBottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) + def __init__( + self, + inplanes: int, + planes: int, + groups: int, + reduction: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride, **dd) + self.bn1 = nn.BatchNorm2d(planes, **dd) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **dd) + self.bn2 = nn.BatchNorm2d(planes, **dd) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, **dd) + self.bn3 = nn.BatchNorm2d(planes * 4, **dd) self.relu = nn.ReLU(inplace=True) - self.se_module = SEModule(planes * 4, 
reduction=reduction) + self.se_module = SEModule(planes * 4, reduction=reduction, **dd) self.downsample = downsample self.stride = stride @@ -131,17 +162,29 @@ class SEResNeXtBottleneck(Bottleneck): """ expansion = 4 - def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4): - super(SEResNeXtBottleneck, self).__init__() + def __init__( + self, + inplanes: int, + planes: int, + groups: int, + reduction: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + base_width: int = 4, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() width = math.floor(planes * (base_width / 64)) * groups - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1, **dd) + self.bn1 = nn.BatchNorm2d(width, **dd) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False, **dd) + self.bn2 = nn.BatchNorm2d(width, **dd) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False, **dd) + self.bn3 = nn.BatchNorm2d(planes * 4, **dd) self.relu = nn.ReLU(inplace=True) - self.se_module = SEModule(planes * 4, reduction=reduction) + self.se_module = SEModule(planes * 4, reduction=reduction, **dd) self.downsample = downsample self.stride = stride @@ -149,14 +192,25 @@ def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=Non class SEResNetBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): - super(SEResNetBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(planes) + def __init__( + self, + inplanes: int, + planes: int, + groups: int, + reduction: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(planes, **dd) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False, **dd) + self.bn2 = nn.BatchNorm2d(planes, **dd) self.relu = nn.ReLU(inplace=True) - self.se_module = SEModule(planes, reduction=reduction) + self.se_module = SEModule(planes, reduction=reduction, **dd) self.downsample = downsample self.stride = stride @@ -183,9 +237,22 @@ def forward(self, x): class SENet(nn.Module): def __init__( - self, block, layers, groups, reduction, drop_rate=0.2, - in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, - downsample_padding=0, num_classes=1000, global_pool='avg'): + self, + block: Type[nn.Module], + layers: Tuple[int, ...], + groups: int, + reduction: int, + drop_rate: float = 0.2, + in_chans: int = 3, + inplanes: int = 64, + input_3x3: bool = False, + downsample_kernel_size: int = 1, + downsample_padding: int = 0, + num_classes: int = 1000, + global_pool: str = 'avg', + device=None, + dtype=None, 
+ ): """ Parameters ---------- @@ -229,27 +296,27 @@ def __init__( num_classes (int): Number of outputs in `last_linear` layer. - For all models: 1000 """ - super(SENet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.inplanes = inplanes self.num_classes = num_classes self.drop_rate = drop_rate if input_3x3: layer0_modules = [ - ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), - ('bn1', nn.BatchNorm2d(64)), + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False, **dd)), + ('bn1', nn.BatchNorm2d(64, **dd)), ('relu1', nn.ReLU(inplace=True)), - ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), - ('bn2', nn.BatchNorm2d(64)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False, **dd)), + ('bn2', nn.BatchNorm2d(64, **dd)), ('relu2', nn.ReLU(inplace=True)), - ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), - ('bn3', nn.BatchNorm2d(inplanes)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False, **dd)), + ('bn3', nn.BatchNorm2d(inplanes, **dd)), ('relu3', nn.ReLU(inplace=True)), ] else: layer0_modules = [ - ('conv1', nn.Conv2d( - in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), - ('bn1', nn.BatchNorm2d(inplanes)), + ('conv1', nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd)), + ('bn1', nn.BatchNorm2d(inplanes, **dd)), ('relu1', nn.ReLU(inplace=True)), ] self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) @@ -263,7 +330,8 @@ def __init__( groups=groups, reduction=reduction, downsample_kernel_size=1, - downsample_padding=0 + downsample_padding=0, + **dd, ) self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] self.layer2 = self._make_layer( @@ -274,7 +342,8 @@ def __init__( groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, - downsample_padding=downsample_padding + downsample_padding=downsample_padding, + **dd, ) self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] self.layer3 = self._make_layer( @@ -285,7 +354,8 @@ def __init__( groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, - downsample_padding=downsample_padding + downsample_padding=downsample_padding, + **dd, ) self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] self.layer4 = self._make_layer( @@ -296,31 +366,37 @@ def __init__( groups=groups, reduction=reduction, downsample_kernel_size=downsample_kernel_size, - downsample_padding=downsample_padding + downsample_padding=downsample_padding, + **dd, ) self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] self.num_features = self.head_hidden_size = 512 * block.expansion self.global_pool, self.last_linear = create_classifier( - self.num_features, self.num_classes, pool_type=global_pool) + self.num_features, + self.num_classes, + pool_type=global_pool, + **dd, + ) for m in self.modules(): _weight_init(m) def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, - downsample_kernel_size=1, downsample_padding=0): + downsample_kernel_size=1, downsample_padding=0, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, - stride=stride, padding=downsample_padding, 
bias=False), - nn.BatchNorm2d(planes * block.expansion), + stride=stride, padding=downsample_padding, bias=False, **dd), + nn.BatchNorm2d(planes * block.expansion, **dd), ) - layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] + layers = [block(self.inplanes, planes, groups, reduction, stride, downsample, **dd)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes, groups, reduction)) + layers.append(block(self.inplanes, planes, groups, reduction, **dd)) return nn.Sequential(*layers) diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 86c4b1df4d..1c3b21acae 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -9,7 +9,7 @@ import math from functools import partial from itertools import accumulate -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -58,7 +58,7 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals class RNNIdentity(nn.Module): def __init__(self, *args, **kwargs): - super(RNNIdentity, self).__init__() + super().__init__() def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: return x, None @@ -73,9 +73,12 @@ def __init__( num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", - with_fc=True, + union: str = "cat", + with_fc: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.input_size = input_size @@ -90,14 +93,14 @@ def __init__( self.fc = None if with_fc: if union == "cat": - self.fc = nn.Linear(2 * self.output_size, input_size) + self.fc = nn.Linear(2 * self.output_size, input_size, **dd) elif union == "add": - self.fc = nn.Linear(self.output_size, input_size) + self.fc = nn.Linear(self.output_size, input_size, **dd) elif union == "vertical": - self.fc = nn.Linear(self.output_size, input_size) + self.fc = nn.Linear(self.output_size, input_size, **dd) self.with_horizontal = False elif union == "horizontal": - self.fc = nn.Linear(self.output_size, input_size) + self.fc = nn.Linear(self.output_size, input_size, **dd) self.with_vertical = False else: raise ValueError("Unrecognized union: " + union) @@ -167,10 +170,13 @@ def __init__( num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", - with_fc=True, + union: str = "cat", + with_fc: bool = True, + device=None, + dtype=None, ): - super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + dd = {'device': device, 'dtype': dtype} + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc, device, dtype) if self.with_vertical: self.rnn_v = nn.LSTM( input_size, @@ -179,6 +185,7 @@ def __init__( batch_first=True, bias=bias, bidirectional=bidirectional, + **dd, ) if self.with_horizontal: self.rnn_h = nn.LSTM( @@ -188,29 +195,33 @@ def __init__( batch_first=True, bias=bias, bidirectional=bidirectional, + **dd, ) class Sequencer2dBlock(nn.Module): def __init__( self, - dim, - hidden_size, - mlp_ratio=3.0, - rnn_layer=LSTM2d, - mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - num_layers=1, - bidirectional=True, - union="cat", - with_fc=True, - drop=0., - drop_path=0., + dim: int, + hidden_size: int, + mlp_ratio: float = 3.0, + rnn_layer: Type[nn.Module] = LSTM2d, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: 
Type[nn.Module] = nn.GELU, + num_layers: int = 1, + bidirectional: bool = True, + union: str = "cat", + with_fc: bool = True, + drop: float = 0., + drop_path: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() channels_dim = int(mlp_ratio * dim) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.rnn_tokens = rnn_layer( dim, hidden_size, @@ -218,10 +229,11 @@ def __init__( bidirectional=bidirectional, union=union, with_fc=with_fc, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim, **dd) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd) def forward(self, x): x = x + self.drop_path(self.rnn_tokens(self.norm1(x))) @@ -243,9 +255,17 @@ def forward(self, x): class Downsample2d(nn.Module): - def __init__(self, input_dim, output_dim, patch_size): + def __init__( + self, + input_dim: int, + output_dim: int, + patch_size: int, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size) + self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size, **dd) def forward(self, x): x = x.permute(0, 3, 1, 2) @@ -257,28 +277,31 @@ def forward(self, x): class Sequencer2dStage(nn.Module): def __init__( self, - dim, - dim_out, - depth, - patch_size, - hidden_size, - mlp_ratio, - downsample=False, - block_layer=Sequencer2dBlock, - rnn_layer=LSTM2d, - mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - num_layers=1, - bidirectional=True, - union="cat", - with_fc=True, - drop=0., - drop_path=0., + dim: int, + dim_out: int, + depth: int, + patch_size: int, + hidden_size: int, + mlp_ratio: float, + downsample: bool = False, + block_layer: Type[nn.Module] = Sequencer2dBlock, + rnn_layer: Type[nn.Module] = LSTM2d, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + num_layers: int = 1, + bidirectional: bool = True, + union: str = "cat", + with_fc: bool = True, + drop: float = 0., + drop_path: Union[float, List[float]] = 0., + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} if downsample: - self.downsample = Downsample2d(dim, dim_out, patch_size) + self.downsample = Downsample2d(dim, dim_out, patch_size, **dd) else: assert dim == dim_out self.downsample = nn.Identity() @@ -299,6 +322,7 @@ def __init__( with_fc=with_fc, drop=drop, drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path, + **dd, )) self.blocks = nn.Sequential(*blocks) @@ -311,30 +335,33 @@ def forward(self, x): class Sequencer2d(nn.Module): def __init__( self, - num_classes=1000, - img_size=224, - in_chans=3, - global_pool='avg', - layers=(4, 3, 8, 3), - patch_sizes=(7, 2, 2, 1), - embed_dims=(192, 384, 384, 384), - hidden_sizes=(48, 96, 96, 96), - mlp_ratios=(3.0, 3.0, 3.0, 3.0), - block_layer=Sequencer2dBlock, - rnn_layer=LSTM2d, - mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - act_layer=nn.GELU, - num_rnn_layers=1, - bidirectional=True, - union="cat", - with_fc=True, - drop_rate=0., - drop_path_rate=0., - nlhb=False, - stem_norm=False, + num_classes: int = 1000, + img_size: int = 224, + in_chans: int = 3, + 
global_pool: str = 'avg', + layers: Tuple[int, ...] = (4, 3, 8, 3), + patch_sizes: Tuple[int, ...] = (7, 2, 2, 1), + embed_dims: Tuple[int, ...] = (192, 384, 384, 384), + hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96), + mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0), + block_layer: Type[nn.Module] = Sequencer2dBlock, + rnn_layer: Type[nn.Module] = LSTM2d, + mlp_layer: Type[nn.Module] = Mlp, + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + act_layer: Type[nn.Module] = nn.GELU, + num_rnn_layers: int = 1, + bidirectional: bool = True, + union: str = "cat", + with_fc: bool = True, + drop_rate: float = 0., + drop_path_rate: float = 0., + nlhb: bool = False, + stem_norm: bool = False, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg') self.num_classes = num_classes self.global_pool = global_pool @@ -351,6 +378,7 @@ def __init__( norm_layer=norm_layer if stem_norm else None, flatten=False, output_fmt='NHWC', + **dd, ) assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) @@ -377,18 +405,20 @@ def __init__( with_fc=with_fc, drop=drop_rate, drop_path=drop_path_rate, + **dd, )] prev_dim = embed_dims[i] self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.norm = norm_layer(embed_dims[-1]) + self.norm = norm_layer(embed_dims[-1], **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, input_fmt=self.output_fmt, + **dd, ) self.init_weights(nlhb=nlhb) diff --git a/timm/models/shvit.py b/timm/models/shvit.py index c165f1a280..396bc64d72 100644 --- a/timm/models/shvit.py +++ b/timm/models/shvit.py @@ -11,7 +11,7 @@ year={2024} } """ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn @@ -57,12 +57,15 @@ def __init__( stride: int = 1, padding: int = 0, bn_weight_init: int = 1, + device=None, + dtype=None, **kwargs, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.add_module('c', nn.Conv2d( - in_channels, out_channels, kernel_size, stride, padding, bias=False, **kwargs)) - self.add_module('bn', nn.BatchNorm2d(out_channels)) + in_channels, out_channels, kernel_size, stride, padding, bias=False, **dd, **kwargs)) + self.add_module('bn', nn.BatchNorm2d(out_channels, **dd)) nn.init.constant_(self.bn.weight, bn_weight_init) nn.init.constant_(self.bn.bias, 0) @@ -95,10 +98,13 @@ def __init__( out_features: int, bias: bool = True, std: float = 0.02, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.add_module('bn', nn.BatchNorm1d(in_features)) - self.add_module('l', nn.Linear(in_features, out_features, bias=bias)) + self.add_module('bn', nn.BatchNorm1d(in_features, **dd)) + self.add_module('l', nn.Linear(in_features, out_features, bias=bias, **dd)) trunc_normal_(self.l.weight, std=std) if bias: nn.init.constant_(self.l.bias, 0) @@ -120,15 +126,23 @@ def fuse(self) -> nn.Linear: class PatchMerging(nn.Module): - def __init__(self, dim: int, out_dim: int, act_layer: LayerType = nn.ReLU): + def __init__( + self, + dim: int, + out_dim: int, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() hid_dim = int(dim * 4) - self.conv1 = Conv2dNorm(dim, hid_dim) + self.conv1 = 
Conv2dNorm(dim, hid_dim, **dd) self.act1 = act_layer() - self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) + self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd) self.act2 = act_layer() - self.se = SqueezeExcite(hid_dim, 0.25) - self.conv3 = Conv2dNorm(hid_dim, out_dim) + self.se = SqueezeExcite(hid_dim, 0.25, **dd) + self.conv3 = Conv2dNorm(hid_dim, out_dim, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) @@ -141,11 +155,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FFN(nn.Module): - def __init__(self, dim: int, embed_dim: int, act_layer: LayerType = nn.ReLU): + def __init__( + self, + dim: int, + embed_dim: int, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.pw1 = Conv2dNorm(dim, embed_dim) + self.pw1 = Conv2dNorm(dim, embed_dim, **dd) self.act = act_layer() - self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0) + self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pw1(x) @@ -161,19 +183,22 @@ def __init__( dim: int, qk_dim: int, pdim: int, - norm_layer: LayerType = GroupNorm1, - act_layer: LayerType = nn.ReLU, + norm_layer: Type[nn.Module] = GroupNorm1, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.scale = qk_dim ** -0.5 self.qk_dim = qk_dim self.dim = dim self.pdim = pdim - self.pre_norm = norm_layer(pdim) + self.pre_norm = norm_layer(pdim, **dd) - self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim) - self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0)) + self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim, **dd) + self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0, **dd)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, _, H, W = x.shape @@ -197,16 +222,19 @@ def __init__( qk_dim: int, pdim: int, type: str, - norm_layer: LayerType = GroupNorm1, - act_layer: LayerType = nn.ReLU, + norm_layer: Type[nn.Module] = GroupNorm1, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0)) + self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0, **dd)) if type == "s": - self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer)) + self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer, **dd)) else: self.mixer = nn.Identity() - self.ffn = Residual(FFN(dim, int(dim * 2))) + self.ffn = Residual(FFN(dim, int(dim * 2), **dd)) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) @@ -224,21 +252,24 @@ def __init__( pdim: int, type: str, depth: int, - norm_layer: LayerType = GroupNorm1, - act_layer: LayerType = nn.ReLU, + norm_layer: Type[nn.Module] = GroupNorm1, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False self.downsample = nn.Sequential( - Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)), - Residual(FFN(prev_dim, int(prev_dim * 2), act_layer)), - PatchMerging(prev_dim, dim, act_layer), - Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim)), - Residual(FFN(dim, int(dim * 2), act_layer)), + Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim, **dd)), + 
Residual(FFN(prev_dim, int(prev_dim * 2), act_layer, **dd)), + PatchMerging(prev_dim, dim, act_layer, **dd), + Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, **dd)), + Residual(FFN(dim, int(dim * 2), act_layer, **dd)), ) if prev_dim != dim else nn.Identity() self.blocks = nn.Sequential(*[ - BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth) + BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer, **dd) for _ in range(depth) ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -262,10 +293,13 @@ def __init__( depth: Tuple[int, int, int] = (1, 2, 3), types: Tuple[str, str, str] = ("s", "s", "s"), drop_rate: float = 0., - norm_layer: LayerType = GroupNorm1, - act_layer: LayerType = nn.ReLU, + norm_layer: Type[nn.Module] = GroupNorm1, + act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] @@ -273,13 +307,13 @@ def __init__( # Patch embedding stem_chs = embed_dim[0] self.patch_embed = nn.Sequential( - Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1), + Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1, **dd), act_layer(), - Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1), + Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1, **dd), act_layer(), - Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1), + Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1, **dd), act_layer(), - Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1) + Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1, **dd) ) # Build SHViT blocks @@ -295,6 +329,7 @@ def __init__( depth=depth[i], norm_layer=norm_layer, act_layer=act_layer, + **dd, )) prev_chs = embed_dim[i] self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}')) @@ -304,7 +339,7 @@ def __init__( self.num_features = self.head_hidden_size = embed_dim[-1] self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.head = NormLinear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity() @torch.jit.ignore def no_weight_decay(self) -> Set: diff --git a/timm/models/sknet.py b/timm/models/sknet.py index b12df2319f..5884d20261 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -9,6 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ import math +from typing import Optional, Type from torch import nn as nn @@ -24,27 +25,30 @@ class SelectiveKernelBasic(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - cardinality=1, - base_width=64, - sk_kwargs=None, - reduce_first=1, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - aa_layer=None, - drop_block=None, - drop_path=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + sk_kwargs: Optional[dict] = None, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[nn.Module] = None, + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - 
super(SelectiveKernelBasic, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() sk_kwargs = sk_kwargs or {} - conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer) + conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd) assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first @@ -52,11 +56,24 @@ def __init__( first_dilation = first_dilation or dilation self.conv1 = SelectiveKernel( - inplanes, first_planes, stride=stride, dilation=first_dilation, - aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs) + inplanes, + first_planes, + stride=stride, + dilation=first_dilation, + aa_layer=aa_layer, + drop_layer=drop_block, + **conv_kwargs, + **sk_kwargs, + ) self.conv2 = ConvNormAct( - first_planes, outplanes, kernel_size=3, dilation=dilation, apply_act=False, **conv_kwargs) - self.se = create_attn(attn_layer, outplanes) + first_planes, + outplanes, + kernel_size=3, + dilation=dilation, + apply_act=False, + **conv_kwargs, + ) + self.se = create_attn(attn_layer, outplanes, **dd) self.act = act_layer(inplace=True) self.downsample = downsample self.drop_path = drop_path @@ -85,27 +102,30 @@ class SelectiveKernelBottleneck(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - cardinality=1, - base_width=64, - sk_kwargs=None, - reduce_first=1, - dilation=1, - first_dilation=None, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - attn_layer=None, - aa_layer=None, - drop_block=None, - drop_path=None, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + cardinality: int = 1, + base_width: int = 64, + sk_kwargs: Optional[dict] = None, + reduce_first: int = 1, + dilation: int = 1, + first_dilation: Optional[int] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[nn.Module] = None, + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - super(SelectiveKernelBottleneck, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() sk_kwargs = sk_kwargs or {} - conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer) + conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion @@ -113,10 +133,18 @@ def __init__( self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv2 = SelectiveKernel( - first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, - aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs) + first_planes, + width, + stride=stride, + dilation=first_dilation, + groups=cardinality, + aa_layer=aa_layer, + drop_layer=drop_block, + **conv_kwargs, + **sk_kwargs, + ) self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs) - self.se = create_attn(attn_layer, outplanes) + self.se = create_attn(attn_layer, outplanes, **dd) self.act = act_layer(inplace=True) self.downsample = downsample self.drop_path = drop_path diff --git a/timm/models/starnet.py b/timm/models/starnet.py index 414b0aa8a4..1cacd05cf9 100644 --- a/timm/models/starnet.py +++ b/timm/models/starnet.py @@ -9,7 +9,7 @@ Created by: 
Xu Ma (Email: ma.xu1@northeastern.edu) Modified Date: Mar/29/2024 """ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type import torch import torch.nn as nn @@ -34,13 +34,16 @@ def __init__( stride: int = 1, padding: int = 0, with_bn: bool = True, + device=None, + dtype=None, **kwargs, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.add_module('conv', nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=padding, **kwargs)) + in_channels, out_channels, kernel_size, stride=stride, padding=padding, **dd, **kwargs)) if with_bn: - self.add_module('bn', nn.BatchNorm2d(out_channels)) + self.add_module('bn', nn.BatchNorm2d(out_channels, **dd)) nn.init.constant_(self.bn.weight, 1) nn.init.constant_(self.bn.bias, 0) @@ -51,14 +54,17 @@ def __init__( dim: int, mlp_ratio: int = 3, drop_path: float = 0., - act_layer: LayerType = nn.ReLU6, + act_layer: Type[nn.Module] = nn.ReLU6, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True) - self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False) - self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False) - self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True) - self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False) + self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True, **dd) + self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd) + self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd) + self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True, **dd) + self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False, **dd) self.act = act_layer() self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -80,13 +86,16 @@ def __init__( mlp_ratio: int = 4, drop_rate: float = 0., drop_path_rate: float = 0., - act_layer: LayerType = nn.ReLU6, + act_layer: Type[nn.Module] = nn.ReLU6, num_classes: int = 1000, in_chans: int = 3, global_pool: str = 'avg', output_stride: int = 32, + device=None, + dtype=None, **kwargs, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert output_stride == 32 self.num_classes = num_classes @@ -97,7 +106,7 @@ def __init__( # stem layer self.stem = nn.Sequential( - ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1), + ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1, **dd), act_layer(), ) prev_chs = stem_chs @@ -108,8 +117,8 @@ def __init__( cur = 0 for i_layer in range(len(depths)): embed_dim = base_dim * 2 ** i_layer - down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1) - blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer) for i in range(depths[i_layer])] + down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1, **dd) + blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer, **dd) for i in range(depths[i_layer])] cur += depths[i_layer] prev_chs = embed_dim stages.append(nn.Sequential(down_sampler, *blocks)) @@ -118,10 +127,10 @@ def __init__( self.stages = nn.Sequential(*stages) # head self.num_features = self.head_hidden_size = prev_chs - self.norm = nn.BatchNorm2d(self.num_features) + self.norm = nn.BatchNorm2d(self.num_features, **dd) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): @@ -162,7 +171,11 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): # NOTE: cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.head = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() + self.head = Linear( + self.head_hidden_size, num_classes, + device=self.head.weight.device if isinstance(self.head, nn.Linear) else None, + dtype=self.head.weight.dtype if isinstance(self.head, nn.Linear) else None, + ) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py index 38df6f1638..1a9900f1f0 100644 --- a/timm/models/swiftformer.py +++ b/timm/models/swiftformer.py @@ -11,7 +11,7 @@ } """ import re -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.nn as nn @@ -28,11 +28,12 @@ class LayerScale2d(nn.Module): - def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} super().__init__() self.inplace = inplace self.gamma = nn.Parameter( - init_values * torch.ones(dim, 1, 1), requires_grad=True) + init_values * torch.ones(dim, 1, 1, **dd), requires_grad=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return 
x.mul_(self.gamma) if self.inplace else x * self.gamma @@ -51,14 +52,17 @@ def __init__( patch_size: int = 16, stride: int = 16, padding: int = 0, - norm_layer: LayerType = nn.BatchNorm2d, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) padding = to_2tuple(padding) - self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride, padding) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride, padding, **dd) + self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) @@ -78,18 +82,21 @@ def __init__( hidden_dim: int = 64, kernel_size: int = 3, drop_path: float = 0., - act_layer: LayerType = nn.GELU, - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, use_layer_scale: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim) - self.norm = norm_layer(dim) - self.pwconv1 = nn.Conv2d(dim, hidden_dim, 1) + self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, **dd) + self.norm = norm_layer(dim, **dd) + self.pwconv1 = nn.Conv2d(dim, hidden_dim, 1, **dd) self.act = act_layer() - self.pwconv2 = nn.Conv2d(hidden_dim, dim, 1) + self.pwconv2 = nn.Conv2d(hidden_dim, dim, 1, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.layer_scale = LayerScale2d(dim, 1) if use_layer_scale else nn.Identity() + self.layer_scale = LayerScale2d(dim, 1, **dd) if use_layer_scale else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: input = x @@ -114,17 +121,20 @@ def __init__( in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, - act_layer: LayerType = nn.GELU, - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.norm1 = norm_layer(in_features) - self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.norm1 = norm_layer(in_features, **dd) + self.fc1 = nn.Conv2d(in_features, hidden_features, 1, **dd) self.act = act_layer() - self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1, **dd) self.drop = nn.Dropout(drop) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -143,16 +153,24 @@ class EfficientAdditiveAttention(nn.Module): Input: tensor in shape [B, C, H, W] Output: tensor in shape [B, C, H, W] """ - def __init__(self, in_dims: int = 512, token_dim: int = 256, num_heads: int = 1): + def __init__( + self, + in_dims: int = 512, + token_dim: int = 256, + num_heads: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.scale_factor = token_dim ** -0.5 - self.to_query = nn.Linear(in_dims, token_dim * num_heads) - self.to_key = nn.Linear(in_dims, token_dim * num_heads) + self.to_query = nn.Linear(in_dims, token_dim * num_heads, **dd) + self.to_key = nn.Linear(in_dims, token_dim * 
num_heads, **dd) - self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1)) + self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1, **dd)) - self.proj = nn.Linear(token_dim * num_heads, token_dim * num_heads) - self.final = nn.Linear(token_dim * num_heads, token_dim) + self.proj = nn.Linear(token_dim * num_heads, token_dim * num_heads, **dd) + self.final = nn.Linear(token_dim * num_heads, token_dim, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: B, _, H, W = x.shape @@ -181,17 +199,20 @@ def __init__( kernel_size: int = 3, drop_path: float = 0., use_layer_scale: bool = True, - act_layer: LayerType = nn.GELU, - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim) - self.norm = norm_layer(dim) - self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1) + self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, **dd) + self.norm = norm_layer(dim, **dd) + self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1, **dd) self.act = act_layer() - self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1) + self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.layer_scale = LayerScale2d(dim, 1) if use_layer_scale else nn.Identity() + self.layer_scale = LayerScale2d(dim, 1, **dd) if use_layer_scale else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: skip = x @@ -218,30 +239,35 @@ def __init__( mlp_ratio: float = 4., drop_rate: float = 0., drop_path: float = 0., - act_layer: LayerType = nn.GELU, - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.local_representation = LocalRepresentation( dim=dim, use_layer_scale=use_layer_scale, act_layer=act_layer, norm_layer=norm_layer, + **dd, ) - self.attn = EfficientAdditiveAttention(in_dims=dim, token_dim=dim) + self.attn = EfficientAdditiveAttention(in_dims=dim, token_dim=dim, **dd) self.linear = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value) \ + self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value, **dd) \ if use_layer_scale else nn.Identity() - self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value) \ + self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value, **dd) \ if use_layer_scale else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -263,14 +289,17 @@ def __init__( index: int, layers: List[int], mlp_ratio: float = 4., - act_layer: LayerType = nn.GELU, - norm_layer: LayerType = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, drop_rate: float = 0., drop_path_rate: float = 0., use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, - downsample: Optional[LayerType] = None, + downsample: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False self.downsample = downsample if downsample is not None else nn.Identity() @@ -288,6 +317,7 @@ def __init__( norm_layer=norm_layer, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, + **dd, )) else: blocks.append(ConvEncoder( @@ -298,6 +328,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, use_layer_scale=use_layer_scale, + **dd, )) self.blocks = nn.Sequential(*blocks) @@ -317,7 +348,7 @@ def __init__( embed_dims: List[int] = [48, 56, 112, 220], mlp_ratios: int = 4, downsamples: List[bool] = [False, True, True, True], - act_layer: LayerType = nn.GELU, + act_layer: Type[nn.Module] = nn.GELU, down_patch_size: int = 3, down_stride: int = 2, down_pad: int = 1, @@ -329,20 +360,23 @@ def __init__( global_pool: str = 'avg', output_stride: int = 32, in_chans: int = 3, + device=None, + dtype=None, **kwargs, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert output_stride == 32 self.num_classes = num_classes self.global_pool = global_pool self.feature_info = [] self.stem = nn.Sequential( - nn.Conv2d(in_chans, embed_dims[0] // 2, 3, 2, 1), - nn.BatchNorm2d(embed_dims[0] // 2), + nn.Conv2d(in_chans, embed_dims[0] // 2, 3, 2, 1, **dd), + nn.BatchNorm2d(embed_dims[0] // 2, **dd), nn.ReLU(), - nn.Conv2d(embed_dims[0] // 2, embed_dims[0], 3, 2, 1), - nn.BatchNorm2d(embed_dims[0]), + nn.Conv2d(embed_dims[0] // 2, embed_dims[0], 3, 2, 1, **dd), + nn.BatchNorm2d(embed_dims[0], **dd), nn.ReLU(), ) prev_dim = embed_dims[0] @@ -355,6 +389,7 @@ def __init__( patch_size=down_patch_size, stride=down_stride, padding=down_pad, + **dd, ) if downsamples[i] else nn.Identity() stage = Stage( dim=embed_dims[i], @@ -367,6 +402,7 @@ def __init__( use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, downsample=downsample, + **dd, ) prev_dim = embed_dims[i] stages.append(stage) @@ -375,11 +411,11 @@ def __init__( # Classifier head self.num_features = self.head_hidden_size = out_chs = embed_dims[-1] - self.norm = nn.BatchNorm2d(out_chs) + self.norm = nn.BatchNorm2d(out_chs, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity() # assuming model is always distilled (valid for current checkpoints, will split def if that changes) - self.head_dist = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = Linear(out_chs, num_classes, **dd) if num_classes > 0 else 
nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token self._initialize_weights() @@ -423,8 +459,9 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + device, dtype = (self.head.weight.device, self.head.weight.dtype) if hasattr(self.head, 'weight') else (None, None) + self.head = Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() + self.head_dist = Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() @torch.jit.ignore def set_distilled_training(self, enable: bool = True): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 98409975ab..c0dd3c0497 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,7 +17,7 @@ # -------------------------------------------------------- import logging import math -from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union, Type import torch import torch.nn as nn @@ -77,7 +77,7 @@ def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, return x -def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor: +def get_relative_position_index(win_h: int, win_w: int, device=None) -> torch.Tensor: """Get pair-wise relative position index for each token inside the window. Args: @@ -88,7 +88,10 @@ def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor: Relative position index tensor. """ # get pair-wise relative position index for each token inside the window - coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w))) # 2, Wh, Ww + coords = torch.stack(ndgrid( + torch.arange(win_h, device=device, dtype=torch.long), + torch.arange(win_w, device=device, dtype=torch.long), + )) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 @@ -114,6 +117,8 @@ def __init__( qkv_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): """ Args: @@ -125,6 +130,7 @@ def __init__( attn_drop: Dropout ratio of attention weight. proj_drop: Dropout ratio of output.
""" + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.window_size = to_2tuple(window_size) # Wh, Ww @@ -137,14 +143,19 @@ def __init__( self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH - self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads, **dd)) # get pair-wise relative position index for each token inside the window - self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) + self.register_buffer( + "relative_position_index", + get_relative_position_index(win_h, win_w, device=device), + persistent=False, + ) - self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(attn_dim, dim) + self.proj = nn.Linear(attn_dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) trunc_normal_(self.relative_position_bias_table, std=.02) @@ -169,7 +180,11 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None: new_window_size=self.window_size, new_bias_shape=new_bias_shape, )) - self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) + self.register_buffer( + "relative_position_index", + get_relative_position_index(win_h, win_w, device=self.relative_position_bias_table.device), + persistent=False, + ) def _get_rel_pos_bias(self) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[ @@ -241,8 +256,10 @@ def __init__( proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """ Args: @@ -261,6 +278,7 @@ def __init__( act_layer: Activation layer. norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -271,7 +289,7 @@ def __init__( self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = WindowAttention( dim, num_heads=num_heads, @@ -280,25 +298,32 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(**dd), persistent=False, ) - def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + def get_attn_mask( + self, + x: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Optional[torch.Tensor]: if any(self.shift_size): # calculate attention mask for SW-MSA if x is not None: @@ -307,8 +332,8 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens dtype = x.dtype else: H, W = self.input_resolution - device = None - dtype = None + device = device + dtype = dtype H = math.ceil(H / self.window_size[0]) * self.window_size[0] W = math.ceil(W / self.window_size[1]) * self.window_size[1] img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device) # 1 H W 1 @@ -372,9 +397,11 @@ def set_input_size( self.window_size, self.shift_size = self._calc_window_shift(window_size) self.window_area = self.window_size[0] * self.window_size[1] self.attn.set_window_size(self.window_size) + device = self.attn_mask.device if self.attn_mask is not None else None + dtype = self.attn_mask.dtype if self.attn_mask is not None else None self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype), persistent=False, ) @@ -444,7 +471,9 @@ def __init__( self, dim: int, out_dim: Optional[int] = None, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """ Args: @@ -452,11 +481,12 @@ def __init__( out_dim: Number of output channels (or 2 * dim if None) norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.out_dim = out_dim or 2 * dim - self.norm = norm_layer(4 * dim) - self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) + self.norm = norm_layer(4 * dim, **dd) + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -502,7 +532,9 @@ def __init__( proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0., - norm_layer: Callable = nn.LayerNorm, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """ Args: @@ -521,6 +553,7 @@ def __init__( drop_path: Stochastic depth rate. norm_layer: Normalization layer. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -536,6 +569,7 @@ def __init__( dim=dim, out_dim=out_dim, norm_layer=norm_layer, + **dd, ) else: assert dim == out_dim @@ -558,6 +592,7 @@ def __init__( attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, + **dd, ) for i in range(depth)]) @@ -631,9 +666,11 @@ def __init__( proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1, - embed_layer: Callable = PatchEmbed, - norm_layer: Union[str, Callable] = nn.LayerNorm, + embed_layer: Type[nn.Module] = PatchEmbed, + norm_layer: Union[str, Type[nn.Module]] = nn.LayerNorm, weight_init: str = '', + device=None, + dtype=None, **kwargs, ): """ @@ -656,6 +693,7 @@ def __init__( norm_layer (nn.Module): Normalization layer. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg') self.num_classes = num_classes self.global_pool = global_pool @@ -678,6 +716,7 @@ def __init__( norm_layer=norm_layer, strict_img_size=strict_img_size, output_fmt='NHWC', + **dd, ) patch_grid = self.patch_embed.grid_size @@ -715,6 +754,7 @@ def __init__( attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + **dd, )] in_dim = out_dim if i > 0: @@ -722,13 +762,14 @@ def __init__( self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')] self.layers = nn.Sequential(*layers) - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, input_fmt=self.output_fmt, + **dd, ) if weight_init != 'skip': self.init_weights(weight_init) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 96e2be2730..04bdedb01b 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -33,7 +33,10 @@ _int_or_tuple_2_t = Union[int, Tuple[int, int]] -def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor: +def window_partition( + x: torch.Tensor, + window_size: Tuple[int, int], +) -> torch.Tensor: """Partition into non-overlapping windows. Args: @@ -50,7 +53,11 @@ def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Ten @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor: +def window_reverse( + windows: torch.Tensor, + window_size: Tuple[int, int], + img_size: Tuple[int, int], +) -> torch.Tensor: """Merge windows back to feature map. Args: @@ -85,6 +92,8 @@ def __init__( attn_drop: float = 0., proj_drop: float = 0., pretrained_window_size: Tuple[int, int] = (0, 0), + device=None, + dtype=None, ) -> None: """Initialize window attention module. @@ -98,6 +107,7 @@ def __init__( proj_drop: Dropout ratio of output. pretrained_window_size: The height and width of the window in pre-training. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww @@ -105,36 +115,38 @@ def __init__( self.num_heads = num_heads self.qkv_bias_separate = qkv_bias_separate - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1), **dd))) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential( - nn.Linear(2, 512, bias=True), + nn.Linear(2, 512, bias=True, **dd), nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False) + nn.Linear(512, num_heads, bias=False, **dd) ) - self.qkv = nn.Linear(dim, dim * 3, bias=False) + self.qkv = nn.Linear(dim, dim * 3, bias=False, **dd) if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.register_buffer('k_bias', torch.zeros(dim), persistent=False) - self.v_bias = nn.Parameter(torch.zeros(dim)) + self.q_bias = nn.Parameter(torch.zeros(dim, **dd)) + self.register_buffer('k_bias', torch.zeros(dim, **dd), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(dim, **dd)) else: self.q_bias = None self.k_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) - self._make_pair_wise_relative_positions() + self._make_pair_wise_relative_positions(**dd) - def _make_pair_wise_relative_positions(self) -> None: + def _make_pair_wise_relative_positions(self, device=None, dtype=None) -> None: """Create pair-wise relative position index and coordinates table.""" # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32) + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), self.window_size[0], device=device, dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), self.window_size[1], device=device, dtype=torch.float32) relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w)) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 if self.pretrained_window_size[0] > 0: @@ -146,11 +158,11 @@ def _make_pair_wise_relative_positions(self) -> None: relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / math.log2(8) - self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + self.register_buffer("relative_coords_table", relative_coords_table.to(dtype=dtype), persistent=False) # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) + coords_h = torch.arange(self.window_size[0], device=device, dtype=torch.long) + coords_w = torch.arange(self.window_size[1], device=device, dtype=torch.long) coords = torch.stack(ndgrid(coords_h, coords_w)) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww @@ -169,8 +181,11 @@ def set_window_size(self, window_size: Tuple[int, int]) -> None: """ window_size = to_2tuple(window_size) if window_size != self.window_size: + assert self.relative_coords_table 
is not None + device = self.relative_coords_table.device + dtype = self.relative_coords_table.dtype self.window_size = window_size - self._make_pair_wise_relative_positions() + self._make_pair_wise_relative_positions(device=device, dtype=dtype) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward pass of window attention. @@ -248,6 +263,8 @@ def __init__( act_layer: LayerType = "gelu", norm_layer: Type[nn.Module] = nn.LayerNorm, pretrained_window_size: _int_or_tuple_2_t = 0, + device=None, + dtype=None, ): """ Args: @@ -266,6 +283,7 @@ def __init__( norm_layer: Normalization layer. pretrained_window_size: Window size in pretraining. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.input_resolution = to_2tuple(input_resolution) @@ -286,8 +304,9 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, pretrained_window_size=to_2tuple(pretrained_window_size), + **dd, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = Mlp( @@ -295,17 +314,23 @@ def __init__( hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(**dd), persistent=False, ) - def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + def get_attn_mask( + self, + x: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Optional[torch.Tensor]: """Generate attention mask for shifted window attention. Args: @@ -317,9 +342,9 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens if any(self.shift_size): # calculate attention mask for SW-MSA if x is None: - img_mask = torch.zeros((1, *self.input_resolution, 1)) # 1 H W 1 + img_mask = torch.zeros((1, *self.input_resolution, 1), device=device, dtype=dtype) # 1 H W 1 else: - img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1 + img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), device=x.device, dtype=x.dtype) # 1 H W 1 cnt = 0 for h in ( (0, -self.window_size[0]), @@ -394,9 +419,11 @@ def set_input_size( self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] self.attn.set_window_size(self.window_size) + device = self.attn_mask.device if self.attn_mask is not None else None + dtype = self.attn_mask.dtype if self.attn_mask is not None else None self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype), persistent=False, ) @@ -466,7 +493,9 @@ def __init__( self, dim: int, out_dim: Optional[int] = None, - norm_layer: Type[nn.Module] = nn.LayerNorm + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """ Args: @@ -474,11 +503,12 @@ def __init__( out_dim (int): Number of output channels (or 2 * dim if None) norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.out_dim = out_dim or 2 * dim - self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) - self.norm = norm_layer(self.out_dim) + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd) + self.norm = norm_layer(self.out_dim, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, C = x.shape @@ -516,10 +546,12 @@ def __init__( proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., - act_layer: Union[str, Callable] = 'gelu', + act_layer: Union[str, Type[nn.Module]] = 'gelu', norm_layer: Type[nn.Module] = nn.LayerNorm, pretrained_window_size: _int_or_tuple_2_t = 0, output_nchw: bool = False, + device=None, + dtype=None, ) -> None: """ Args: @@ -542,6 +574,7 @@ def __init__( pretrained_window_size: Local window size in pretraining. output_nchw: Output tensors on NCHW format instead of NHWC. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -554,7 +587,7 @@ def __init__( # patch merging / downsample layer if downsample: - self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer) + self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer, **dd) else: assert dim == out_dim self.downsample = nn.Identity() @@ -577,6 +610,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, pretrained_window_size=pretrained_window_size, + **dd, ) for i in range(depth)]) @@ -663,8 +697,10 @@ def __init__( attn_drop_rate: float = 0., drop_path_rate: float = 0.1, act_layer: Union[str, Callable] = 'gelu', - norm_layer: Callable = nn.LayerNorm, + norm_layer: Type[nn.Module] = nn.LayerNorm, pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0), + device=None, + dtype=None, **kwargs, ): """ @@ -690,6 +726,7 @@ def __init__( output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default. """ super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes assert global_pool in ('', 'avg') @@ -712,6 +749,7 @@ def __init__( norm_layer=norm_layer, strict_img_size=strict_img_size, output_fmt='NHWC', + **dd, ) grid_size = self.patch_embed.grid_size @@ -739,6 +777,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, pretrained_window_size=pretrained_window_sizes[i], + **dd, )] in_dim = out_dim if i > 0: @@ -746,13 +785,14 @@ def __init__( self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] self.layers = nn.Sequential(*layers) - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, input_fmt=self.output_fmt, + **dd, ) self.apply(self._init_weights) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 7ac7d7aef3..34aefbcebc 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -107,16 +107,19 @@ class WindowMultiHeadAttention(nn.Module): """ def __init__( - self, - dim: int, - num_heads: int, - window_size: Tuple[int, int], - drop_attn: float = 0.0, - drop_proj: float = 0.0, - meta_hidden_dim: int = 384, # FIXME what's the optimal value? 
- sequential_attn: bool = False, + self, + dim: int, + num_heads: int, + window_size: Tuple[int, int], + drop_attn: float = 0.0, + drop_proj: float = 0.0, + meta_hidden_dim: int = 384, # FIXME what's the optimal value? + sequential_attn: bool = False, + device=None, + dtype=None, ) -> None: - super(WindowMultiHeadAttention, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() assert dim % num_heads == 0, \ "The number of input features (in_features) are not divisible by the number of heads (num_heads)." self.in_features: int = dim @@ -124,9 +127,9 @@ def __init__( self.num_heads: int = num_heads self.sequential_attn: bool = sequential_attn - self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True) + self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True, **dd) self.attn_drop = nn.Dropout(drop_attn) - self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True) + self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True, **dd) self.proj_drop = nn.Dropout(drop_proj) # meta network for positional encodings self.meta_mlp = Mlp( @@ -134,24 +137,29 @@ def __init__( hidden_features=meta_hidden_dim, out_features=num_heads, act_layer=nn.ReLU, - drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without? + drop=(0.125, 0.), # FIXME should there be stochasticity, appears to 'overfit' without? + **dd, ) # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads))) + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads, **dd))) self._make_pair_wise_relative_positions() def _make_pair_wise_relative_positions(self) -> None: """Initialize the pair-wise relative positions to compute the positional biases.""" device = self.logit_scale.device coordinates = torch.stack(ndgrid( - torch.arange(self.window_size[0], device=device), - torch.arange(self.window_size[1], device=device) - ), dim=0).flatten(1) + torch.arange(self.window_size[0], device=device, dtype=torch.float32), + torch.arange(self.window_size[1], device=device, dtype=torch.float32), + )).flatten(1) relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() relative_coordinates_log = torch.sign(relative_coordinates) * torch.log( 1.0 + relative_coordinates.abs()) - self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False) + self.register_buffer( + "relative_coordinates_log", + relative_coordinates_log.to(self.logit_scale.dtype), + persistent=False, + ) def set_window_size(self, window_size: Tuple[int, int]) -> None: """Update window size and regenerate relative position coordinates. 
@@ -249,8 +257,11 @@ def __init__( extra_norm: bool = False, sequential_attn: bool = False, norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): - super(SwinTransformerV2CrBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.dim: int = dim self.feat_size: Tuple[int, int] = feat_size self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) @@ -268,8 +279,9 @@ def __init__( drop_attn=drop_attn, drop_proj=proj_drop, sequential_attn=sequential_attn, + **dd, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() # mlp branch @@ -278,17 +290,18 @@ def __init__( hidden_features=int(dim * mlp_ratio), drop=proj_drop, out_features=dim, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper. # Also being used as final network norm and optional stage ending norm while still in a C-last format. - self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() + self.norm3 = norm_layer(dim, **dd) if extra_norm else nn.Identity() self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(**dd), persistent=False, ) self.init_weights() @@ -310,15 +323,20 @@ def _calc_window_shift( shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, target_shift_size)] return tuple(window_size), tuple(shift_size) - def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + def get_attn_mask( + self, + x: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Optional[torch.Tensor]: """Method generates the attention mask used in shift case.""" # Make masks for shift case if any(self.shift_size): # calculate attention mask for SW-MSA if x is None: - img_mask = torch.zeros((1, *self.feat_size, 1)) # 1 H W 1 + img_mask = torch.zeros((1, *self.feat_size, 1), device=device, dtype=dtype) # 1 H W 1 else: - img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1 + img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), device=x.device, dtype=x.dtype) # 1 H W 1 cnt = 0 for h in ( (0, -self.window_size[0]), @@ -358,9 +376,11 @@ def set_input_size(self, feat_size: Tuple[int, int], window_size: Tuple[int, int self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] self.attn.set_window_size(self.window_size) + device = self.attn_mask.device if self.attn_mask is not None else None + dtype = self.attn_mask.dtype if self.attn_mask is not None else None self.register_buffer( "attn_mask", - None if self.dynamic_mask else self.get_attn_mask(), + None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype), persistent=False, ) @@ -432,16 +452,23 @@ class PatchMerging(nn.Module): This class implements the patch merging as a strided convolution with a normalization before. """ - def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: + def __init__( + self, + dim: int, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, + ) -> None: """Initialize patch merging layer. Args: dim: Number of input channels. 
norm_layer: Type of normalization layer to be utilized. """ - super(PatchMerging, self).__init__() - self.norm = norm_layer(4 * dim) - self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False) + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.norm = norm_layer(4 * dim, **dd) + self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass of patch merging. @@ -472,8 +499,10 @@ def __init__( patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, - norm_layer: Optional[Callable] = None, + norm_layer: Optional[Type[nn.Module]] = None, strict_img_size: bool = True, + device=None, + dtype=None, ) -> None: """Initialize patch embedding. @@ -485,6 +514,7 @@ def __init__( norm_layer: Normalization layer. strict_img_size: Enforce strict image size. """ + dd = {'device': device, 'dtype': dtype} super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -494,8 +524,8 @@ def __init__( self.num_patches = self.grid_size[0] * self.grid_size[1] self.strict_img_size = strict_img_size - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd) + self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity() def set_input_size(self, img_size: Tuple[int, int]) -> None: """Update input image size. @@ -548,32 +578,35 @@ class SwinTransformerV2CrStage(nn.Module): """ def __init__( - self, - embed_dim: int, - depth: int, - downscale: bool, - num_heads: int, - feat_size: Tuple[int, int], - window_size: Tuple[int, int], - always_partition: bool = False, - dynamic_mask: bool = False, - mlp_ratio: float = 4.0, - init_values: Optional[float] = 0.0, - proj_drop: float = 0.0, - drop_attn: float = 0.0, - drop_path: Union[List[float], float] = 0.0, - norm_layer: Type[nn.Module] = nn.LayerNorm, - extra_norm_period: int = 0, - extra_norm_stage: bool = False, - sequential_attn: bool = False, + self, + embed_dim: int, + depth: int, + downscale: bool, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + always_partition: bool = False, + dynamic_mask: bool = False, + mlp_ratio: float = 4.0, + init_values: Optional[float] = 0.0, + proj_drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: Union[List[float], float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + extra_norm_period: int = 0, + extra_norm_stage: bool = False, + sequential_attn: bool = False, + device=None, + dtype=None, ): - super(SwinTransformerV2CrStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.downscale: bool = downscale self.grad_checkpointing: bool = False self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size if downscale: - self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) + self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer, **dd) embed_dim = embed_dim * 2 else: self.downsample = nn.Identity() @@ -601,6 +634,7 @@ def _extra_norm(index): extra_norm=_extra_norm(index), sequential_attn=sequential_attn, norm_layer=norm_layer, + **dd, ) for index in range(depth)] ) @@ -670,33 +704,36 @@ class SwinTransformerV2Cr(nn.Module): """ def __init__( - self, - img_size: Tuple[int, int] = (224, 224), - patch_size: int = 4, - 
window_size: Optional[int] = None, - window_ratio: int = 8, - always_partition: bool = False, - strict_img_size: bool = True, - in_chans: int = 3, - num_classes: int = 1000, - embed_dim: int = 96, - depths: Tuple[int, ...] = (2, 2, 6, 2), - num_heads: Tuple[int, ...] = (3, 6, 12, 24), - mlp_ratio: float = 4.0, - init_values: Optional[float] = 0., - drop_rate: float = 0.0, - proj_drop_rate: float = 0.0, - attn_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - norm_layer: Type[nn.Module] = nn.LayerNorm, - extra_norm_period: int = 0, - extra_norm_stage: bool = False, - sequential_attn: bool = False, - global_pool: str = 'avg', - weight_init='skip', - **kwargs: Any + self, + img_size: Tuple[int, int] = (224, 224), + patch_size: int = 4, + window_size: Optional[int] = None, + window_ratio: int = 8, + always_partition: bool = False, + strict_img_size: bool = True, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + mlp_ratio: float = 4.0, + init_values: Optional[float] = 0., + drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + extra_norm_period: int = 0, + extra_norm_stage: bool = False, + sequential_attn: bool = False, + global_pool: str = 'avg', + weight_init: str = 'skip', + device=None, + dtype=None, + **kwargs: Any ) -> None: - super(SwinTransformerV2Cr, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} img_size = to_2tuple(img_size) self.num_classes: int = num_classes self.patch_size: int = patch_size @@ -711,6 +748,7 @@ def __init__( embed_dim=embed_dim, norm_layer=norm_layer, strict_img_size=strict_img_size, + **dd, ) grid_size = self.patch_embed.grid_size if window_size is None: @@ -741,6 +779,7 @@ def __init__( extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm sequential_attn=sequential_attn, norm_layer=norm_layer, + **dd, )] if stage_idx != 0: in_dim *= 2 @@ -753,6 +792,7 @@ def __init__( num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) # current weight init skips custom init and uses pytorch layer defaults, seems to work well diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 2da814ab45..a70d570f22 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -10,7 +10,7 @@ import itertools from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -27,10 +27,23 @@ class ConvNorm(torch.nn.Sequential): - def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + def __init__( + self, + in_chs: int, + out_chs: int, + ks: int = 1, + stride: int = 1, + pad: int = 0, + dilation: int = 1, + groups: int = 1, + bn_weight_init: float = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) - self.bn = nn.BatchNorm2d(out_chs) + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd) + self.bn = nn.BatchNorm2d(out_chs, **dd) torch.nn.init.constant_(self.bn.weight, bn_weight_init) torch.nn.init.constant_(self.bn.bias, 0) @@ -50,12 +63,20 @@ def fuse(self): class PatchEmbed(nn.Module): - def __init__(self, in_chs, 
out_chs, act_layer): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.stride = 4 - self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) + self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1, **dd) self.act = act_layer() - self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1) + self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1, **dd) def forward(self, x): x = self.conv1(x) @@ -65,14 +86,24 @@ def forward(self, x): class MBConv(nn.Module): - def __init__(self, in_chs, out_chs, expand_ratio, act_layer, drop_path): + def __init__( + self, + in_chs: int, + out_chs: int, + expand_ratio: float, + act_layer: Type[nn.Module], + drop_path: float, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() mid_chs = int(in_chs * expand_ratio) - self.conv1 = ConvNorm(in_chs, mid_chs, ks=1) + self.conv1 = ConvNorm(in_chs, mid_chs, ks=1, **dd) self.act1 = act_layer() - self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs) + self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs, **dd) self.act2 = act_layer() - self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0) + self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0, **dd) self.act3 = act_layer() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -90,13 +121,21 @@ def forward(self, x): class PatchMerging(nn.Module): - def __init__(self, dim, out_dim, act_layer): + def __init__( + self, + dim: int, + out_dim: int, + act_layer: Type[nn.Module], + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0) + self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0, **dd) self.act1 = act_layer() - self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim) + self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim, **dd) self.act2 = act_layer() - self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0) + self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0, **dd) def forward(self, x): x = self.conv1(x) @@ -110,19 +149,26 @@ def forward(self, x): class ConvLayer(nn.Module): def __init__( self, - dim, - depth, - act_layer, - drop_path=0., - conv_expand_ratio=4., + dim: int, + depth: int, + act_layer: Type[nn.Module], + drop_path: Union[float, List[float]] = 0., + conv_expand_ratio: float = 4., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.depth = depth self.blocks = nn.Sequential(*[ MBConv( - dim, dim, conv_expand_ratio, act_layer, + dim, + dim, + conv_expand_ratio, + act_layer, drop_path[i] if isinstance(drop_path, list) else drop_path, + **dd, ) for i in range(depth) ]) @@ -135,21 +181,24 @@ def forward(self, x): class NormMlp(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - norm_layer=nn.LayerNorm, - act_layer=nn.GELU, - drop=0., + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + drop: float = 0., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.norm = norm_layer(in_features) - self.fc1 = nn.Linear(in_features, hidden_features) + self.norm 
= norm_layer(in_features, **dd) + self.fc1 = nn.Linear(in_features, hidden_features, **dd) self.act = act_layer() self.drop1 = nn.Dropout(drop) - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, **dd) self.drop2 = nn.Dropout(drop) def forward(self, x): @@ -168,12 +217,15 @@ class Attention(torch.nn.Module): def __init__( self, - dim, - key_dim, - num_heads=8, - attn_ratio=4, - resolution=(14, 14), + dim: int, + key_dim: int, + num_heads: int = 8, + attn_ratio: int = 4, + resolution: Tuple[int, int] = (14, 14), + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert isinstance(resolution, tuple) and len(resolution) == 2 self.num_heads = num_heads @@ -185,9 +237,9 @@ def __init__( self.resolution = resolution self.fused_attn = use_fused_attn() - self.norm = nn.LayerNorm(dim) - self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim)) - self.proj = nn.Linear(self.out_dim, dim) + self.norm = nn.LayerNorm(dim, **dd) + self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim), **dd) + self.proj = nn.Linear(self.out_dim, dim, **dd) points = list(itertools.product(range(resolution[0]), range(resolution[1]))) N = len(points) @@ -199,8 +251,12 @@ def __init__( if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets), **dd)) + self.register_buffer( + 'attention_bias_idxs', + torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), + persistent=False, + ) self.attention_bias_cache = {} @torch.no_grad() @@ -261,15 +317,18 @@ class TinyVitBlock(nn.Module): def __init__( self, - dim, - num_heads, - window_size=7, - mlp_ratio=4., - drop=0., - drop_path=0., - local_conv_size=3, - act_layer=nn.GELU + dim: int, + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4., + drop: float = 0., + drop_path: float = 0., + local_conv_size: int = 3, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.num_heads = num_heads @@ -281,20 +340,27 @@ def __init__( head_dim = dim // num_heads window_resolution = (window_size, window_size) - self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + self.attn = Attention( + dim, + head_dim, + num_heads, + attn_ratio=1, + resolution=window_resolution, + **dd, + ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.mlp = NormMlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() pad = local_conv_size // 2 - self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim, **dd) def forward(self, x): B, H, W, C = x.shape @@ -363,19 +429,21 @@ class TinyVitStage(nn.Module): def __init__( self, - dim, - out_dim, - depth, - num_heads, - window_size, - mlp_ratio=4., - drop=0., - drop_path=0., - downsample=None, - local_conv_size=3, - act_layer=nn.GELU, + dim: int, + out_dim: int, + depth: int, + num_heads: int, + window_size: int, + mlp_ratio: float = 4., + drop: float = 0., + drop_path: Union[float, List[float]] = 0., + downsample: Optional[Type[nn.Module]] = None, + local_conv_size: int = 3, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): - + dd = {'device': device, 'dtype': dtype} super().__init__() self.depth = depth self.out_dim = out_dim @@ -386,6 +454,7 @@ def __init__( dim=dim, out_dim=out_dim, act_layer=act_layer, + **dd, ) else: self.downsample = nn.Identity() @@ -402,6 +471,7 @@ def __init__( drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, local_conv_size=local_conv_size, act_layer=act_layer, + **dd, ) for i in range(depth)]) @@ -419,22 +489,25 @@ def extra_repr(self) -> str: class TinyVit(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - embed_dims=(96, 192, 384, 768), - depths=(2, 2, 6, 2), - num_heads=(3, 6, 12, 24), - window_sizes=(7, 7, 14, 7), - mlp_ratio=4., - drop_rate=0., - drop_path_rate=0.1, - use_checkpoint=False, - mbconv_expand_ratio=4.0, - local_conv_size=3, - act_layer=nn.GELU, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dims: Tuple[int, ...] = (96, 192, 384, 768), + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + window_sizes: Tuple[int, ...] 
= (7, 7, 14, 7), + mlp_ratio: float = 4., + drop_rate: float = 0., + drop_path_rate: float = 0.1, + use_checkpoint: bool = False, + mbconv_expand_ratio: float = 4.0, + local_conv_size: int = 3, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.depths = depths @@ -446,6 +519,7 @@ def __init__( in_chs=in_chans, out_chs=embed_dims[0], act_layer=act_layer, + **dd, ) # stochastic depth rate rule @@ -464,6 +538,7 @@ def __init__( act_layer=act_layer, drop_path=dpr[:depths[stage_idx]], conv_expand_ratio=mbconv_expand_ratio, + **dd, ) else: out_dim = embed_dims[stage_idx] @@ -480,6 +555,7 @@ def __init__( drop_path=drop_path_rate, downsample=PatchMerging, act_layer=act_layer, + **dd, ) prev_dim = out_dim stride *= 2 @@ -495,6 +571,7 @@ def __init__( num_classes, pool_type=global_pool, norm_layer=norm_layer_cf, + **dd, ) # init weights diff --git a/timm/models/tnt.py b/timm/models/tnt.py index a801062425..85435aa465 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -10,7 +10,7 @@ https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch """ import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -29,7 +29,18 @@ class Attention(nn.Module): """ Multi-Head Attention """ - def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + hidden_dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.hidden_dim = hidden_dim self.num_heads = num_heads @@ -37,10 +48,10 @@ def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., p self.head_dim = head_dim self.scale = head_dim ** -0.5 - self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) - self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop, inplace=True) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop, inplace=True) def forward(self, x): @@ -65,23 +76,26 @@ class Block(nn.Module): def __init__( self, - dim, - dim_out, - num_pixel, - num_heads_in=4, - num_heads_out=12, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - legacy=False, + dim: int, + dim_out: int, + num_pixel: int, + num_heads_in: int = 4, + num_heads_out: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + legacy: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() # Inner transformer - self.norm_in = norm_layer(dim) + self.norm_in = norm_layer(dim, **dd) self.attn_in = Attention( dim, dim, @@ -89,28 +103,30 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) - self.norm_mlp_in = norm_layer(dim) + self.norm_mlp_in = norm_layer(dim, **dd) self.mlp_in = Mlp( in_features=dim, hidden_features=int(dim * 4), out_features=dim, act_layer=act_layer, drop=proj_drop, + 
**dd, ) self.legacy = legacy if self.legacy: - self.norm1_proj = norm_layer(dim) - self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True) + self.norm1_proj = norm_layer(dim, **dd) + self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True, **dd) self.norm2_proj = None else: - self.norm1_proj = norm_layer(dim * num_pixel) - self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False) - self.norm2_proj = norm_layer(dim_out) + self.norm1_proj = norm_layer(dim * num_pixel, **dd) + self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False, **dd) + self.norm2_proj = norm_layer(dim_out, **dd) # Outer transformer - self.norm_out = norm_layer(dim_out) + self.norm_out = norm_layer(dim_out, **dd) self.attn_out = Attention( dim_out, dim_out, @@ -118,16 +134,18 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm_mlp = norm_layer(dim_out) + self.norm_mlp = norm_layer(dim_out, **dd) self.mlp = Mlp( in_features=dim_out, hidden_features=int(dim_out * mlp_ratio), out_features=dim_out, act_layer=act_layer, drop=proj_drop, + **dd, ) def forward(self, pixel_embed, patch_embed): @@ -157,13 +175,16 @@ class PixelEmbed(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - in_dim=48, - stride=4, - legacy=False, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + in_dim: int = 48, + stride: int = 4, + legacy: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -178,7 +199,7 @@ def __init__( new_patch_size = [math.ceil(ps / stride) for ps in patch_size] self.new_patch_size = new_patch_size - self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) + self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride, **dd) if self.legacy: self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) else: @@ -221,28 +242,31 @@ class TNT(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - inner_dim=48, - depth=12, - num_heads_inner=4, - num_heads_outer=12, - mlp_ratio=4., - qkv_bias=False, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=nn.LayerNorm, - first_stride=4, - legacy=False, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + embed_dim: int = 768, + inner_dim: int = 48, + depth: int = 12, + num_heads_inner: int = 4, + num_heads_outer: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, + first_stride: int = 4, + legacy: bool = False, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes self.global_pool = global_pool @@ -257,6 +281,7 @@ def __init__( in_dim=inner_dim, stride=first_stride, legacy=legacy, + **dd, ) num_patches = self.pixel_embed.num_patches r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else 
patch_size @@ -264,13 +289,13 @@ def __init__( new_patch_size = self.pixel_embed.new_patch_size num_pixel = new_patch_size[0] * new_patch_size[1] - self.norm1_proj = norm_layer(num_pixel * inner_dim) - self.proj = nn.Linear(num_pixel * inner_dim, embed_dim) - self.norm2_proj = norm_layer(embed_dim) + self.norm1_proj = norm_layer(num_pixel * inner_dim, **dd) + self.proj = nn.Linear(num_pixel * inner_dim, embed_dim, **dd) + self.norm2_proj = norm_layer(embed_dim, **dd) - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1])) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) + self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) + self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1], **dd)) self.pos_drop = nn.Dropout(p=pos_drop_rate) dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule @@ -289,14 +314,15 @@ def __init__( drop_path=dpr[i], norm_layer=norm_layer, legacy=legacy, + **dd, )) self.blocks = nn.ModuleList(blocks) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.patch_pos, std=.02) @@ -340,7 +366,9 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('', 'token', 'avg') self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 26d44c5f0c..7e318f6377 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -7,7 +7,7 @@ """ from collections import OrderedDict from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -26,25 +26,36 @@ class BasicBlock(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - use_se=True, - aa_layer=None, - drop_path_rate=0. 
- ): - super(BasicBlock, self).__init__() + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + use_se: bool = True, + aa_layer: Optional[Type[nn.Module]] = None, + drop_path_rate: float = 0., + device=None, + dtype=None, + ) -> None: + dd = {'device': device, 'dtype': dtype} + super().__init__() self.downsample = downsample self.stride = stride act_layer = partial(nn.LeakyReLU, negative_slope=1e-3) - self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer) - self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False) + self.conv1 = ConvNormAct( + inplanes, + planes, + kernel_size=3, + stride=stride, + act_layer=act_layer, + aa_layer=aa_layer, + **dd, + ) + self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, **dd) self.act = nn.ReLU(inplace=True) rd_chs = max(planes * self.expansion // 4, 64) - self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None + self.se = SEModule(planes * self.expansion, rd_channels=rd_chs, **dd) if use_se else None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def forward(self, x): @@ -66,30 +77,38 @@ class Bottleneck(nn.Module): def __init__( self, - inplanes, - planes, - stride=1, - downsample=None, - use_se=True, - act_layer=None, - aa_layer=None, - drop_path_rate=0., - ): - super(Bottleneck, self).__init__() + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + use_se: bool = True, + act_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_path_rate: float = 0., + device=None, + dtype=None, + ) -> None: + dd = {'device': device, 'dtype': dtype} + super().__init__() self.downsample = downsample self.stride = stride act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3) - self.conv1 = ConvNormAct( - inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer) + self.conv1 = ConvNormAct(inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, **dd) self.conv2 = ConvNormAct( - planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer) + planes, + planes, + kernel_size=3, + stride=stride, + act_layer=act_layer, + aa_layer=aa_layer, + **dd, + ) reduction_chs = max(planes * self.expansion // 8, 64) - self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None + self.se = SEModule(planes, rd_channels=reduction_chs, **dd) if use_se else None - self.conv3 = ConvNormAct( - planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False) + self.conv3 = ConvNormAct(planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, **dd) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act = nn.ReLU(inplace=True) @@ -112,19 +131,22 @@ def forward(self, x): class TResNet(nn.Module): def __init__( self, - layers, - in_chans=3, - num_classes=1000, - width_factor=1.0, - v2=False, - global_pool='fast', - drop_rate=0., - drop_path_rate=0., - ): + layers: List[int], + in_chans: int = 3, + num_classes: int = 1000, + width_factor: float = 1.0, + v2: bool = False, + global_pool: str = 'fast', + drop_rate: float = 0., + drop_path_rate: float = 0., + device=None, + dtype=None, + ) -> None: + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False - super(TResNet, self).__init__() 
aa_layer = BlurPool2d act_layer = nn.LeakyReLU @@ -137,19 +159,19 @@ def __init__( self.planes = self.planes // 8 * 8 dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True) - conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer) + conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer, **dd) layer1 = self._make_layer( Bottleneck if v2 else BasicBlock, - self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0]) + self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0], **dd) layer2 = self._make_layer( Bottleneck if v2 else BasicBlock, - self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1]) + self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1], **dd) layer3 = self._make_layer( Bottleneck, - self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2]) + self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2], **dd) layer4 = self._make_layer( Bottleneck, - self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3]) + self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3], **dd) # body self.body = nn.Sequential(OrderedDict([ @@ -171,7 +193,7 @@ def __init__( # head self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd) # model initialization for m in self.modules(): @@ -187,7 +209,20 @@ def __init__( if isinstance(m, Bottleneck): nn.init.zeros_(m.conv3.bn.weight) - def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None, drop_path_rate=0.): + def _make_layer( + self, + block, + planes, + blocks, + stride=1, + use_se=True, + aa_layer=None, + drop_path_rate=0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + downsample = None if stride != 1 or self.inplanes != planes * block.expansion: layers = [] @@ -195,7 +230,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non # avg pooling before 1x1 conv layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) layers += [ConvNormAct( - self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)] + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, **dd)] downsample = nn.Sequential(*layers) layers = [] @@ -208,6 +243,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non use_se=use_se, aa_layer=aa_layer, drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate, + **dd, )) self.inplanes = planes * block.expansion return nn.Sequential(*layers) diff --git a/timm/models/twins.py b/timm/models/twins.py index 2d1e23bd8d..c27fd2be42 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -13,7 +13,7 @@ # -------------------------------------------------------- import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -38,9 +38,19 @@ class LocallyGroupedAttn(nn.Module): """ 
fused_attn: torch.jit.Final[bool] - def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + def __init__( + self, + dim: int, + num_heads: int = 8, + attn_drop: float = 0., + proj_drop: float = 0., + ws: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} assert ws != 1 - super(LocallyGroupedAttn, self).__init__() + super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." self.dim = dim @@ -49,9 +59,9 @@ def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.qkv = nn.Linear(dim, dim * 3, bias=True, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) self.ws = ws @@ -135,7 +145,17 @@ class GlobalSubSampleAttn(nn.Module): """ fused_attn: torch.jit.Final[bool] - def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + def __init__( + self, + dim: int, + num_heads: int = 8, + attn_drop: float = 0., + proj_drop: float = 0., + sr_ratio: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." @@ -145,16 +165,16 @@ def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.q = nn.Linear(dim, dim, bias=True) - self.kv = nn.Linear(dim, dim * 2, bias=True) + self.q = nn.Linear(dim, dim, bias=True, **dd) + self.kv = nn.Linear(dim, dim * 2, bias=True, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) self.sr_ratio = sr_ratio if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, **dd) + self.norm = nn.LayerNorm(dim, **dd) else: self.sr = None self.norm = None @@ -193,33 +213,37 @@ class Block(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - sr_ratio=1, - ws=None, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + sr_ratio: int = 1, + ws: Optional[int] = None, + device=None, + dtype=None, ): super().__init__() - self.norm1 = norm_layer(dim) + dd = {'device': device, 'dtype': dtype} + self.norm1 = norm_layer(dim, **dd) if ws is None: - self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop) + self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop, **dd) elif ws == 1: - self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio) + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio, **dd) else: - self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws) + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -231,10 +255,18 @@ def forward(self, x, size: Size_): class PosConv(nn.Module): # PEG from https://arxiv.org/abs/2102.10882 - def __init__(self, in_chans, embed_dim=768, stride=1): - super(PosConv, self).__init__() + def __init__( + self, + in_chans: int, + embed_dim: int = 768, + stride: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.proj = nn.Sequential( - nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), + nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim, **dd), ) self.stride = stride @@ -255,7 +287,16 @@ class PatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -266,8 +307,8 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): f"img_size {img_size} should be divided by patch_size {patch_size}." self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = nn.LayerNorm(embed_dim) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd) + self.norm = nn.LayerNorm(embed_dim, **dd) def forward(self, x) -> Tuple[torch.Tensor, Size_]: B, C, H, W = x.shape @@ -286,26 +327,29 @@ class Twins(nn.Module): """ def __init__( self, - img_size=224, - patch_size=4, - in_chans=3, - num_classes=1000, - global_pool='avg', - embed_dims=(64, 128, 256, 512), - num_heads=(1, 2, 4, 8), - mlp_ratios=(4, 4, 4, 4), - depths=(3, 4, 6, 3), - sr_ratios=(8, 4, 2, 1), - wss=None, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=partial(nn.LayerNorm, eps=1e-6), - block_cls=Block, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + num_heads: Tuple[int, ...] = (1, 2, 4, 8), + mlp_ratios: Tuple[float, ...] = (4, 4, 4, 4), + depths: Tuple[int, ...] = (3, 4, 6, 3), + sr_ratios: Tuple[int, ...] 
= (8, 4, 2, 1), + wss: Optional[Tuple[int, ...]] = None, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6), + block_cls: Any = Block, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.global_pool = global_pool self.depths = depths @@ -318,7 +362,7 @@ def __init__( self.patch_embeds = nn.ModuleList() self.pos_drops = nn.ModuleList() for i in range(len(depths)): - self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i], **dd)) self.pos_drops.append(nn.Dropout(p=pos_drop_rate)) prev_chs = embed_dims[i] img_size = tuple(t // patch_size for t in img_size) @@ -338,19 +382,20 @@ def __init__( drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], - ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])], - ) + ws=1 if wss is None or i % 2 == 1 else wss[k], + **dd, + ) for i in range(depths[k])]) self.blocks.append(_block) self.feature_info += [dict(module=f'block.{k}', num_chs=embed_dims[k], reduction=2**(2+k))] cur += depths[k] - self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim, **dd) for embed_dim in embed_dims]) - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) # classification head self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() # init weights self.apply(self._init_weights) @@ -387,7 +432,9 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('', 'avg') self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def _init_weights(self, m): if isinstance(m, nn.Linear): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 23c8834894..f51396cc2d 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -44,7 +44,9 @@ def __init__( drop_rate: float = 0.2, act_layer: Type[nn.Module] = nn.ReLU, conv_layer: Type[nn.Module] = nn.Conv2d, - ): + device=None, + dtype=None, + ) -> None: """Initialize ConvMlp. Args: @@ -56,13 +58,14 @@ def __init__( act_layer: Activation layer type. conv_layer: Convolution layer type. 
""" - super(ConvMlp, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.input_kernel_size = kernel_size mid_features = int(out_features * mlp_ratio) - self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True) + self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True, **dd) self.act1 = act_layer(True) self.drop = nn.Dropout(drop_rate) - self.fc2 = conv_layer(mid_features, out_features, 1, bias=True) + self.fc2 = conv_layer(mid_features, out_features, 1, bias=True, **dd) self.act2 = act_layer(True) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -105,6 +108,8 @@ def __init__( norm_layer: Optional[Type[nn.Module]] = None, global_pool: str = 'avg', drop_rate: float = 0., + device=None, + dtype=None, ) -> None: """Initialize VGG model. @@ -120,7 +125,8 @@ def __init__( global_pool: Global pooling type. drop_rate: Dropout rate. """ - super(VGG, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} assert output_stride == 32 self.num_classes = num_classes self.drop_rate = drop_rate @@ -140,9 +146,9 @@ def __init__( net_stride *= 2 else: v = cast(int, v) - conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1) + conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1, **dd) if norm_layer is not None: - layers += [conv2d, norm_layer(v), act_layer(inplace=True)] + layers += [conv2d, norm_layer(v, **dd), act_layer(inplace=True)] else: layers += [conv2d, act_layer(inplace=True)] prev_chs = v @@ -159,12 +165,14 @@ def __init__( drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer, + **dd, ) self.head = ClassifierHead( self.head_hidden_size, num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) self._initialize_weights() diff --git a/timm/models/visformer.py b/timm/models/visformer.py index deadd98a6a..31f1015b6d 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -7,11 +7,13 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ +from typing import Optional, Union, Type, Any import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import to_2tuple, trunc_normal_, DropPath, calculate_drop_path_rates, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn + from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -22,14 +24,17 @@ class SpatialMlp(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0., - group=8, - spatial_conv=False, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + drop: float = 0., + group: int = 8, + spatial_conv: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -45,17 +50,17 @@ def __init__( hidden_features = in_features * 2 self.hidden_features = hidden_features self.group = group - self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False, **dd) self.act1 = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if self.spatial_conv: self.conv2 = nn.Conv2d( - hidden_features, hidden_features, 3, stride=1, padding=1, 
groups=self.group, bias=False) + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False, **dd) self.act2 = act_layer() else: self.conv2 = None self.act2 = None - self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False, **dd) self.drop3 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -73,7 +78,17 @@ def forward(self, x): class Attention(nn.Module): fused_attn: torch.jit.Final[bool] - def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + num_heads: int = 8, + head_dim_ratio: float = 1., + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.dim = dim self.num_heads = num_heads @@ -82,9 +97,9 @@ def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop= self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn(experimental=True) - self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -112,19 +127,22 @@ def forward(self, x): class Block(nn.Module): def __init__( self, - dim, - num_heads, - head_dim_ratio=1., - mlp_ratio=4., - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=LayerNorm2d, - group=8, - attn_disabled=False, - spatial_conv=False, + dim: int, + num_heads: int, + head_dim_ratio: float = 1., + mlp_ratio: float = 4., + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm2d, + group: int = 8, + attn_disabled: bool = False, + spatial_conv: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.spatial_conv = spatial_conv self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -132,16 +150,17 @@ def __init__( self.norm1 = None self.attn = None else: - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=proj_drop, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = SpatialMlp( in_features=dim, hidden_features=int(dim * mlp_ratio), @@ -149,6 +168,7 @@ def __init__( drop=proj_drop, group=group, spatial_conv=spatial_conv, + **dd, ) def forward(self, x): @@ -161,31 +181,34 @@ def forward(self, x): class Visformer(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - init_channels=32, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4., - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - norm_layer=LayerNorm2d, - attn_stage='111', - use_pos_embed=True, - spatial_conv='111', - vit_stem=False, - group=8, - global_pool='avg', - conv_init=False, - embed_norm=None, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + init_channels: Optional[int] = 32, + embed_dim: int = 384, + depth: Union[int, tuple] = 12, + num_heads: int = 6, + mlp_ratio: float = 4., + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + norm_layer: Type[nn.Module] = LayerNorm2d, + attn_stage: str = '111', + use_pos_embed: bool = True, + spatial_conv: str = '111', + vit_stem: bool = False, + group: int = 8, + global_pool: str = 'avg', + conv_init: bool = False, + embed_norm: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, ): super().__init__() + dd = {'device': device, 'dtype': dtype} img_size = to_2tuple(img_size) self.num_classes = num_classes self.embed_dim = embed_dim @@ -213,6 +236,7 @@ def __init__( embed_dim=embed_dim, norm_layer=embed_norm, flatten=False, + **dd, ) img_size = [x // patch_size for x in img_size] else: @@ -225,12 +249,13 @@ def __init__( embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False, + **dd, ) img_size = [x // (patch_size // 2) for x in img_size] else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), - nn.BatchNorm2d(self.init_channels), + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False, **dd), + nn.BatchNorm2d(self.init_channels, **dd), nn.ReLU(inplace=True) ) img_size = [x // 2 for x in img_size] @@ -241,14 +266,15 @@ def __init__( embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False, + **dd, ) img_size = [x // (patch_size // 4) for x in img_size] if self.use_pos_embed: if self.vit_stem: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd)) else: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size, **dd)) self.pos_drop = nn.Dropout(p=pos_drop_rate) else: self.pos_embed1 = None @@ -266,6 +292,7 @@ def __init__( group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1'), + **dd, ) for i in range(self.stage_num1) ]) @@ -279,10 +306,11 @@ def __init__( embed_dim=embed_dim, norm_layer=embed_norm, flatten=False, + **dd, ) img_size = [x // (patch_size // 8) for x in img_size] if self.use_pos_embed: - self.pos_embed2 = 
nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd)) else: self.pos_embed2 = None else: @@ -300,6 +328,7 @@ def __init__( group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1'), + **dd, ) for i in range(self.stage_num1, self.stage_num1+self.stage_num2) ]) @@ -313,10 +342,11 @@ def __init__( embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False, + **dd, ) img_size = [x // (patch_size // 8) for x in img_size] if self.use_pos_embed: - self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size, **dd)) else: self.pos_embed3 = None else: @@ -334,15 +364,22 @@ def __init__( group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1'), + **dd, ) for i in range(self.stage_num1+self.stage_num2, depth) ]) self.num_features = self.head_hidden_size = embed_dim if self.vit_stem else embed_dim * 2 - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) # head - global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + global_pool, head = create_classifier( + self.num_features, + self.num_classes, + pool_type=global_pool, + device=device, + dtype=dtype, + ) self.global_pool = global_pool self.head_drop = nn.Dropout(drop_rate) self.head = head @@ -389,7 +426,10 @@ def get_classifier(self) -> nn.Module: def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes - self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.global_pool, self.head = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, device=device, dtype=dtype) def forward_features(self, x): if self.stem is not None: diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index d65174c535..376825f08f 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -66,6 +66,7 @@ get_norm_layer, maybe_add_mask, LayerType, + LayerScale, ) from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -78,35 +79,6 @@ _logger = logging.getLogger(__name__) -class LayerScale(nn.Module): - """Layer scale module. - - References: - - https://arxiv.org/abs/2103.17239 - """ - - def __init__( - self, - dim: int, - init_values: float = 1e-5, - inplace: bool = False, - ) -> None: - """Initialize LayerScale module. - - Args: - dim: Dimension. - init_values: Initial value for scaling. - inplace: If True, perform inplace operations. - """ - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply layer scaling.""" - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class Block(nn.Module): """Transformer block with pre-normalization.""" @@ -127,6 +99,8 @@ def __init__( act_layer: Type[nn.Module] = nn.GELU, norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, + device=None, + dtype=None, ) -> None: """Initialize Block. @@ -146,7 +120,9 @@ def __init__( mlp_layer: MLP layer. 
""" super().__init__() - self.norm1 = norm_layer(dim) + dd = {'device': device, 'dtype': dtype} + + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, @@ -157,11 +133,12 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, + **dd ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), @@ -169,8 +146,9 @@ def __init__( norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, + **dd, ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -197,8 +175,11 @@ def __init__( act_layer: Type[nn.Module] = nn.GELU, norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, + device = None, + dtype = None, ) -> None: super().__init__() + dd = {'device': device, 'dtype': dtype} self.init_values = init_values self.attn = Attention( @@ -211,8 +192,9 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, + **dd, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = mlp_layer( @@ -222,8 +204,9 @@ def __init__( norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.init_weights() @@ -264,8 +247,11 @@ def __init__( act_layer: Type[nn.Module] = nn.GELU, norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Optional[Type[nn.Module]] = None, + device = None, + dtype = None, ) -> None: super().__init__() + dd = {'device': device, 'dtype': dtype} assert dim % num_heads == 0, 'dim should be divisible by num_heads' assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported' self.num_heads = num_heads @@ -275,26 +261,26 @@ def __init__( mlp_hidden_dim = int(mlp_ratio * dim) in_proj_out_dim = mlp_hidden_dim + 3 * dim - self.in_norm = norm_layer(dim) - self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) + self.in_norm = norm_layer(dim, **dd) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd) self.in_split = [mlp_hidden_dim] + [dim] * 3 if qkv_bias: self.register_buffer('qkv_bias', None) self.register_parameter('mlp_bias', None) else: - self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False) - self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) + self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd)) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias) + self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd) self.mlp_drop = nn.Dropout(proj_drop) self.mlp_act = act_layer() - self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias) + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd) - self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() + self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -365,14 +351,17 @@ def __init__( act_layer: Type[nn.Module] = nn.GELU, norm_layer: Type[nn.Module] = LayerNorm, mlp_layer: Type[nn.Module] = Mlp, + device = None, + dtype = None ) -> None: + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() self.ffns = nn.ModuleList() for _ in range(num_parallel): self.attns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), + ('norm', norm_layer(dim, **dd)), ('attn', Attention( dim, num_heads=num_heads, @@ -383,12 +372,13 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, + **dd, )), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) ]))) self.ffns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), + ('norm', norm_layer(dim, **dd)), ('mlp', mlp_layer( dim, hidden_features=int(dim * mlp_ratio), @@ -396,8 +386,9 @@ def __init__( norm_layer=norm_layer if scale_mlp_norm else None, bias=proj_bias, drop=proj_drop, + **dd, )), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) @@ -491,6 +482,8 @@ def __init__( act_layer: Optional[LayerType] = None, block_fn: Type[nn.Module] = Block, mlp_layer: Type[nn.Module] = Mlp, + device=None, + dtype=None, ) -> None: """ Args: @@ -524,6 +517,7 @@ def __init__( block_fn: Transformer block layer. """ super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') @@ -558,17 +552,18 @@ def __init__( bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) dynamic_img_pad=dynamic_img_pad, **embed_args, + **dd, ) num_patches = self.patch_embed.num_patches reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens if not pos_embed or pos_embed == 'none': self.pos_embed = None else: - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim, **dd) * .02) self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( @@ -577,7 +572,7 @@ def __init__( ) else: self.patch_drop = nn.Identity() - self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + self.norm_pre = norm_layer(embed_dim, **dd) if pre_norm else nn.Identity() dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.Sequential(*[ @@ -597,11 +592,12 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, + **dd, ) for i in range(depth)]) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)] - self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity() + self.norm = norm_layer(embed_dim, **dd) if final_norm and not use_fc_norm else nn.Identity() # Classifier Head if global_pool == 'map': @@ -611,12 +607,13 @@ def __init__( mlp_ratio=mlp_ratio, norm_layer=norm_layer, act_layer=act_layer, + **dd, ) else: self.attn_pool = None - self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity() + self.fc_norm = norm_layer(embed_dim, **dd) if final_norm and use_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() if weight_init != 'skip': 
self.init_weights(weight_init) @@ -646,6 +643,7 @@ def init_weights(self, mode: str = '') -> None: nn.init.normal_(self.cls_token, std=1e-6) if self.reg_token is not None: nn.init.normal_(self.reg_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m: nn.Module) -> None: diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 7290bb7697..b244c75589 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -40,7 +40,10 @@ def __init__( padding: Union[str, int, Tuple[int, ...]] = "", norm_layer: Type[nn.Module] = nn.BatchNorm2d, act_layer: Type[nn.Module] = nn.ReLU, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() if isinstance(channels, int): # a default tiered channel strategy @@ -64,10 +67,15 @@ def __init__( apply_act=not last_conv, norm_layer=norm_layer, act_layer=act_layer, + **dd, )) in_chs = channels[i] +def _dd_from_kwargs(**kwargs): + return {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)} + + def _resnetv2(layers=(3, 4, 9), **kwargs): """ ResNet-V2 backbone helper""" padding_same = kwargs.get('padding_same', True) @@ -75,11 +83,23 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) if len(layers): backbone = ResNetV2( - layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), - preact=False, stem_type=stem_type, conv_layer=conv_layer) + layers=layers, + num_classes=0, + global_pool='', + in_chans=kwargs.get('in_chans', 3), + preact=False, + stem_type=stem_type, + conv_layer=conv_layer, + **_dd_from_kwargs(**kwargs), + ) else: backbone = create_resnetv2_stem( - kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) + kwargs.get('in_chans', 3), + stem_type=stem_type, + preact=False, + conv_layer=conv_layer, + **_dd_from_kwargs(**kwargs), + ) return backbone @@ -343,7 +363,13 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs) -> VisionTransformer: def vit_small_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer: """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. """ - backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + backbone = resnet26d( + pretrained=pretrained, + in_chans=kwargs.get('in_chans', 3), + features_only=True, + out_indices=[4], + **_dd_from_kwargs(**kwargs), + ) model_args = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3) model = _create_vision_transformer_hybrid( 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) @@ -354,7 +380,13 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer: def vit_small_resnet50d_s16_224(pretrained=False, **kwargs) -> VisionTransformer: """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. 
""" - backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) + backbone = resnet50d( + pretrained=pretrained, + in_chans=kwargs.get('in_chans', 3), + features_only=True, + out_indices=[3], + **_dd_from_kwargs(**kwargs), + ) model_args = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3) model = _create_vision_transformer_hybrid( 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) @@ -365,7 +397,13 @@ def vit_small_resnet50d_s16_224(pretrained=False, **kwargs) -> VisionTransformer def vit_base_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer: """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. """ - backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + backbone = resnet26d( + pretrained=pretrained, + in_chans=kwargs.get('in_chans', 3), + features_only=True, + out_indices=[4], + **_dd_from_kwargs(**kwargs), + ) model_args = dict(embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer_hybrid( 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) @@ -376,7 +414,13 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs) -> VisionTransformer: def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer: """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. """ - backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + backbone = resnet50d( + pretrained=pretrained, + in_chans=kwargs.get('in_chans', 3), + features_only=True, + out_indices=[4], + **_dd_from_kwargs(**kwargs), + ) model_args = dict(embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer_hybrid( 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) @@ -394,6 +438,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer: padding=0, in_chans=kwargs.get('in_chans', 3), act_layer=nn.GELU, + **_dd_from_kwargs(**kwargs), ) model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True) model = _create_vision_transformer_hybrid( diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 8a5e5c6d65..492d6e2727 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -19,7 +19,17 @@ from torch.jit import Final from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, RelPosMlp, RelPosBias, use_fused_attn, LayerType +from timm.layers import ( + PatchEmbed, + Mlp, + LayerScale, + DropPath, + calculate_drop_path_rates, + RelPosMlp, + RelPosBias, + use_fused_attn, + LayerType, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint @@ -36,15 +46,18 @@ class RelPosAttention(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=False, - qk_norm=False, - rel_pos_cls=None, - attn_drop=0., - proj_drop=0., - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + rel_pos_cls: Optional[Type[nn.Module]] = None, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 
'dtype': dtype} super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -52,12 +65,12 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.rel_pos = rel_pos_cls(num_heads=num_heads, **dd) if rel_pos_cls else None self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): @@ -97,35 +110,28 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): return x -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class RelPosBlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - qk_norm=False, - rel_pos_cls=None, - init_values=None, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + rel_pos_cls: Optional[Type[nn.Module]] = None, + init_values: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = RelPosAttention( dim, num_heads, @@ -134,19 +140,22 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=proj_drop, + norm_layer=norm_layer, + **dd, ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): @@ -159,19 +168,22 @@ class ResPostRelPosBlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - qk_norm=False, - rel_pos_cls=None, - init_values=None, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + rel_pos_cls: Optional[Type[nn.Module]] = None, + init_values: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.init_values = init_values @@ -183,8 +195,10 @@ def __init__( rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=proj_drop, + norm_layer=norm_layer, + **dd, ) - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = Mlp( @@ -192,8 +206,9 @@ def __init__( hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.init_weights() @@ -248,7 +263,9 @@ def __init__( embed_layer: Type[nn.Module] = PatchEmbed, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, - block_fn: Type[nn.Module] = RelPosBlock + block_fn: Type[nn.Module] = RelPosBlock, + device=None, + dtype=None, ): """ Args: @@ -279,6 +296,7 @@ def __init__( act_layer: MLP activation layer """ super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg', 'token') assert class_token or global_pool != 'token' norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) @@ -295,6 +313,7 @@ def __init__( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + **dd, ) feat_size = self.patch_embed.grid_size r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size @@ -310,11 +329,11 @@ def __init__( rel_pos_cls = partial(RelPosBias, **rel_pos_args) self.shared_rel_pos = None if shared_rel_pos: - self.shared_rel_pos = rel_pos_cls(num_heads=num_heads) + self.shared_rel_pos = rel_pos_cls(num_heads=num_heads, **dd) # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... 
rel_pos_cls = None - self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None + self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim, **dd)) if class_token else None dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule self.blocks = nn.ModuleList([ @@ -331,16 +350,17 @@ def __init__( drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, + **dd, ) for i in range(depth)]) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] - self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity() + self.norm = norm_layer(embed_dim, **dd) if not fc_norm else nn.Identity() # Classifier Head - self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity() + self.fc_norm = norm_layer(embed_dim, **dd) if fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) @@ -380,12 +400,13 @@ def set_grad_checkpointing(self, enable=True): def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None): + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() def forward_intermediates( self, diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index a8cce36c0c..d3624c9d89 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -11,14 +11,29 @@ """ import logging from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ - Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + calculate_drop_path_rates, + PatchDropout, + LayerNorm2d, + LayerScale, + ClassifierHead, + NormMlpClassifierHead, + Format, + resample_abs_pos_embed_nhwc, + RotaryEmbeddingCat, + apply_rot_embed_cat, + to_2tuple, + use_fused_attn, +) from torch.jit import Final from ._builder import build_model_with_cfg @@ -60,8 +75,8 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. 
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + q_coords = torch.arange(q_size, dtype=torch.float32)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size, dtype=torch.float32)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] @@ -70,11 +85,11 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor def get_decomposed_rel_pos_bias( - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. @@ -108,17 +123,20 @@ class Attention(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=True, - qk_norm=False, - attn_drop=0., - proj_drop=0., - norm_layer=nn.LayerNorm, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, use_rel_pos: bool = False, input_size: Optional[Tuple[int, int]] = None, rope: Optional[nn.Module] = None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads @@ -126,11 +144,11 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) self.use_rel_pos = use_rel_pos if self.use_rel_pos: @@ -139,10 +157,8 @@ def __init__( input_size is not None ), "Input size must be provided if using relative positional encoding." 
# initialize relative positional embeddings - self.rel_pos_h = nn.Parameter(torch.zeros( - 2 * input_size[0] - 1, self.head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros( - 2 * input_size[1] - 1, self.head_dim)) + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim, **dd)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim, **dd)) self.rope = rope def forward(self, x): @@ -186,40 +202,33 @@ def forward(self, x): return x -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class Block(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=True, - qk_norm=False, - proj_drop=0., - attn_drop=0., - init_values=None, - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - mlp_layer=Mlp, - use_rel_pos=False, - window_size=0, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Type[nn.Module] = Mlp, + use_rel_pos: bool = False, + window_size: int = 0, input_size=None, rope=None, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.window_size = window_size - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = Attention( dim, num_heads=num_heads, @@ -231,18 +240,20 @@ def __init__( use_rel_pos=use_rel_pos, input_size=input_size if window_size == 0 else (window_size, window_size), rope=rope, + **dd, ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, + **dd, ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): @@ -344,11 +355,11 @@ def __init__( attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: str = '', - embed_layer: Callable = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), - norm_layer: Optional[Callable] = nn.LayerNorm, - act_layer: Optional[Callable] = nn.GELU, - block_fn: Callable = Block, - mlp_layer: Callable = Mlp, + embed_layer: Type[nn.Module] = partial(PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False), + norm_layer: Optional[Type[nn.Module]] = nn.LayerNorm, + act_layer: Optional[Type[nn.Module]] = nn.GELU, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, use_abs_pos: bool = True, use_rel_pos: bool = False, use_rope: bool = False, @@ -357,7 +368,9 @@ def __init__( neck_chans: int = 256, global_pool: str = 'avg', head_hidden_size: Optional[int] = None, - ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None + ref_feat_shape: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None, + device=None, + dtype=None, ): """ Args: @@ -391,6 +404,7 @@ def __init__( ref_feat_shape: Tuple of reference feature shapes for ROPE, (global, local) """ super().__init__() + dd = {'device': device, 'dtype': dtype} norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU @@ -405,13 +419,14 @@ def __init__( in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used + **dd, ) grid_size = self.patch_embed.grid_size r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. - self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim, **dd)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=pos_drop_rate) @@ -422,7 +437,7 @@ def __init__( ) else: self.patch_drop = nn.Identity() - self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + self.norm_pre = norm_layer(embed_dim, **dd) if pre_norm else nn.Identity() if use_rope: assert not use_rel_pos, "ROPE and relative pos embeddings should not be enabled at same time" @@ -468,6 +483,7 @@ def __init__( window_size=window_size if i not in global_attn_indexes else 0, input_size=grid_size, rope=self.rope_window if i not in global_attn_indexes else self.rope_global, + **dd, ) for i in range(depth)]) self.feature_info = [ @@ -480,16 +496,18 @@ def __init__( neck_chans, kernel_size=1, bias=False, + **dd, ), - LayerNorm2d(neck_chans), + LayerNorm2d(neck_chans, **dd), nn.Conv2d( neck_chans, neck_chans, kernel_size=3, padding=1, bias=False, + **dd, ), - LayerNorm2d(neck_chans), + LayerNorm2d(neck_chans, **dd), ) self.num_features = neck_chans else: @@ -497,7 +515,7 @@ def __init__( self.neck = nn.Identity() else: # should have a final norm with standard ClassifierHead - self.neck = LayerNorm2d(embed_dim) + self.neck = LayerNorm2d(embed_dim, **dd) neck_chans = embed_dim # Classifier Head @@ -508,6 +526,7 @@ def __init__( hidden_size=head_hidden_size, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) else: self.head = ClassifierHead( @@ -515,6 +534,7 @@ def __init__( num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) @torch.jit.ignore diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 0effa5bbfc..5fc1556a96 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -80,14 +80,17 @@ def 
__init__( norm_layer: str = 'layernorm2d', norm_eps: float = 1e-6, bias: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) self.out_chs = out_chs - self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias) - self.norm1 = norm_act_layer(out_chs) - self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias) + self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias, **dd) + self.norm1 = norm_act_layer(out_chs, **dd) + self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias, **dd) named_apply(_init_conv, self) @@ -105,12 +108,15 @@ def __init__( dim_out: int, pool_type: str = 'avg2', bias: bool = True, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) if dim != dim_out: - self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias, **dd) # 1x1 conv else: self.expand = nn.Identity() @@ -125,17 +131,20 @@ class StridedConv(nn.Module): """ def __init__( self, - kernel_size=3, - stride=2, - padding=1, - in_chans=3, - embed_dim=768 + kernel_size: int = 3, + stride: int = 2, + padding: int = 1, + in_chans: int = 3, + embed_dim: int = 768, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) - self.norm = norm_layer(in_chans) # affine over C + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding, **dd) + self.norm = norm_layer(in_chans, **dd) # affine over C def forward(self, x): x = self.norm(x) @@ -157,27 +166,30 @@ def __init__( norm_eps: float = 1e-6, act_layer: str = 'gelu', expand_ratio: float = 4.0, + device=None, + dtype=None, ): - super(MbConvLNBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs mid_chs = make_divisible(out_chs * expand_ratio) prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) if stride == 2: - self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True) + self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True, **dd) elif in_chs != out_chs: - self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True) + self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True, **dd) else: self.shortcut = nn.Identity() - self.pre_norm = prenorm_act_layer(in_chs, apply_act=False) + self.pre_norm = prenorm_act_layer(in_chs, apply_act=False, **dd) self.down = nn.Identity() - self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True) + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True, **dd) self.act1 = create_act_layer(act_layer, inplace=True) self.conv2_kxk = create_conv2d( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True, **dd) self.act2 = create_act_layer(act_layer, inplace=True) - self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True) + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True, **dd) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -213,13 +225,17 @@ def __init__( cfg: VitCfg, img_size: Union[int, Tuple[int, int]] = 224, # place holder in_chans: int = 3, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.grad_checkpointing = False self.stem = Stem( in_chs=in_chans, out_chs=cfg.stem_width, + **dd, ) stages = [] @@ -231,6 +247,7 @@ def __init__( in_chs = stage_in_chs if d==0 else dim, out_chs = dim, stride = 2 if d == 0 else 1, + **dd, ) for d in range(cfg.depths[s]) ] @@ -240,7 +257,8 @@ def __init__( self.pool = StridedConv( stride=2, in_chans=cfg.embed_dim[1], - embed_dim=cfg.embed_dim[2] + embed_dim=cfg.embed_dim[2], + **dd, ) def forward(self, x): @@ -256,21 +274,24 @@ def forward(self, x): class GeGluMlp(nn.Module): def __init__( self, - in_features, - hidden_features, - act_layer = 'gelu', - norm_layer = None, - bias = True, - drop = 0.0, + in_features: int, + hidden_features: int, + act_layer: str = 'gelu', + norm_layer: Optional[str] = None, + bias: bool = True, + drop: float = 0.0, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6) - self.norm = norm_layer(in_features) - self.w0 = nn.Linear(in_features, hidden_features, bias=bias) + self.norm = norm_layer(in_features, **dd) + self.w0 = nn.Linear(in_features, hidden_features, bias=bias, **dd) self.act = create_act_layer(act_layer) - self.w1 = nn.Linear(in_features, hidden_features, bias=bias) - self.w2 = nn.Linear(hidden_features, in_features, bias=bias) + self.w1 = nn.Linear(in_features, hidden_features, bias=bias, **dd) + self.w2 = nn.Linear(hidden_features, in_features, bias=bias, **dd) def forward(self, x): x = self.norm(x) @@ -282,7 +303,8 @@ def forward(self, x): def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): out_indices = kwargs.pop('out_indices', 3) assert embed_cfg is not None - backbone = MbConvStages(cfg=embed_cfg, in_chans=kwargs.get('in_chans', 3)) + dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)} + backbone = MbConvStages(cfg=embed_cfg, in_chans=kwargs.get('in_chans', 3), **dd) kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set diff --git a/timm/models/volo.py b/timm/models/volo.py index f417dc6df1..db231aeff8 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -49,6 +49,8 @@ def __init__( qkv_bias: bool = False, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): """Initialize OutlookAttention. @@ -62,6 +64,7 @@ def __init__( attn_drop: Attention dropout rate. proj_drop: Projection dropout rate. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() head_dim = dim // num_heads self.num_heads = num_heads @@ -70,11 +73,11 @@ def __init__( self.stride = stride self.scale = head_dim ** -0.5 - self.v = nn.Linear(dim, dim, bias=qkv_bias) - self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads) + self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) @@ -128,9 +131,11 @@ def __init__( mlp_ratio: float = 3., attn_drop: float = 0., drop_path: float = 0., - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, qkv_bias: bool = False, + device=None, + dtype=None, ): """Initialize Outlooker block. @@ -147,8 +152,9 @@ def __init__( norm_layer: Normalization layer type. qkv_bias: Whether to use bias in linear layers. """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = OutlookAttention( dim, num_heads, @@ -157,14 +163,16 @@ def __init__( stride=stride, qkv_bias=qkv_bias, attn_drop=attn_drop, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -193,6 +201,8 @@ def __init__( qkv_bias: bool = False, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): """Initialize Attention module. @@ -203,15 +213,16 @@ def __init__( attn_drop: Attention dropout rate. proj_drop: Projection dropout rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.fused_attn = use_fused_attn() - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -258,8 +269,10 @@ def __init__( qkv_bias: bool = False, attn_drop: float = 0., drop_path: float = 0., - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """Initialize Transformer block. @@ -273,13 +286,14 @@ def __init__( act_layer: Activation layer type. norm_layer: Normalization layer type. """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop) + self.norm1 = norm_layer(dim, **dd) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, **dd) self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + self.norm2 = norm_layer(dim, **dd) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, **dd) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -307,6 +321,8 @@ def __init__( qkv_bias: bool = False, attn_drop: float = 0., proj_drop: float = 0., + device=None, + dtype=None, ): """Initialize ClassAttention. @@ -318,6 +334,7 @@ def __init__( attn_drop: Attention dropout rate. proj_drop: Projection dropout rate. """ + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads if head_dim is not None: @@ -327,10 +344,10 @@ def __init__( self.head_dim = head_dim self.scale = head_dim ** -0.5 - self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias) - self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias) + self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias, **dd) + self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -371,8 +388,10 @@ def __init__( drop: float = 0., attn_drop: float = 0., drop_path: float = 0., - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + device=None, + dtype=None, ): """Initialize ClassBlock. @@ -388,8 +407,9 @@ def __init__( act_layer: Activation layer type. norm_layer: Normalization layer type. """ + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim, **dd) self.attn = ClassAttention( dim, num_heads=num_heads, @@ -397,15 +417,17 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + **dd, ) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim, **dd) self.mlp = Mlp( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, + **dd, ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -424,18 +446,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([cls_embed, x[:, 1:]], dim=1) -def get_block(block_type: str, **kargs: Any) -> nn.Module: +def get_block(block_type: str, **kwargs: Any) -> nn.Module: """Get block based on type. Args: block_type: Type of block ('ca' for ClassBlock). - **kargs: Additional keyword arguments for block. + **kwargs: Additional keyword arguments for block. Returns: The requested block module. """ if block_type == 'ca': - return ClassBlock(**kargs) + return ClassBlock(**kwargs) + else: + assert False, f'Invalid block type: {block_type}' def rand_bbox(size: Tuple[int, ...], lam: float, scale: int = 1) -> Tuple[int, int, int, int]: @@ -483,6 +507,8 @@ def __init__( in_chans: int = 3, hidden_dim: int = 64, embed_dim: int = 384, + device=None, + dtype=None, ): """Initialize PatchEmbed. @@ -497,25 +523,31 @@ def __init__( hidden_dim: Hidden dimension for stem convolution. embed_dim: Output embedding dimension. 
""" + dd = {'device': device, 'dtype': dtype} super().__init__() assert patch_size in [4, 8, 16] if stem_conv: self.conv = nn.Sequential( - nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112 - nn.BatchNorm2d(hidden_dim), + nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False, **dd), + nn.BatchNorm2d(hidden_dim, **dd), nn.ReLU(inplace=True), - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 - nn.BatchNorm2d(hidden_dim), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd), + nn.BatchNorm2d(hidden_dim, **dd), nn.ReLU(inplace=True), - nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112 - nn.BatchNorm2d(hidden_dim), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd), + nn.BatchNorm2d(hidden_dim, **dd), nn.ReLU(inplace=True), ) else: self.conv = None self.proj = nn.Conv2d( - hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride) + hidden_dim, + embed_dim, + kernel_size=patch_size // stem_stride, + stride=patch_size // stem_stride, + **dd, + ) self.num_patches = (img_size // patch_size) * (img_size // patch_size) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -536,7 +568,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Downsample(nn.Module): """Downsampling module between stages.""" - def __init__(self, in_embed_dim: int, out_embed_dim: int, patch_size: int = 2): + def __init__( + self, + in_embed_dim: int, + out_embed_dim: int, + patch_size: int = 2, + device=None, + dtype=None, + ): """Initialize Downsample. Args: @@ -545,7 +584,8 @@ def __init__(self, in_embed_dim: int, out_embed_dim: int, patch_size: int = 2): patch_size: Patch size for downsampling. """ super().__init__() - self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size) + dd = {'device': device, 'dtype': dtype} + self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size, **dd) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -575,6 +615,8 @@ def outlooker_blocks( qkv_bias: bool = False, attn_drop: float = 0, drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs: Any, ) -> nn.Sequential: """Generate outlooker layers for stage 1. @@ -610,6 +652,9 @@ def outlooker_blocks( qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr, + device=device, + dtype=dtype, + **kwargs, )) blocks = nn.Sequential(*blocks) return blocks @@ -654,6 +699,7 @@ def transformer_blocks( qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr, + **kwargs, )) blocks = nn.Sequential(*blocks) return blocks @@ -681,11 +727,13 @@ def __init__( pos_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0., - norm_layer: Callable = nn.LayerNorm, + norm_layer: Type[nn.Module] = nn.LayerNorm, post_layers: Optional[Tuple[str, ...]] = ('ca', 'ca'), use_aux_head: bool = True, use_mix_token: bool = False, pooling_scale: int = 2, + device=None, + dtype=None, ): """Initialize VOLO model. @@ -714,6 +762,7 @@ def __init__( pooling_scale: Pooling scale factor. 
""" super().__init__() + dd = {'device': device, 'dtype': dtype} num_layers = len(layers) mlp_ratio = to_ntuple(num_layers)(mlp_ratio) img_size = to_2tuple(img_size) @@ -735,12 +784,13 @@ def __init__( in_chans=in_chans, hidden_dim=stem_hidden_dim, embed_dim=embed_dims[0], + **dd, ) r = patch_size # initial positional encoding, we add positional encoding after outlooker blocks patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale) - self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1])) + self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1], **dd)) self.pos_drop = nn.Dropout(p=pos_drop_rate) # set the main block in network @@ -761,6 +811,7 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer, + **dd, ) else: # stage 2 @@ -775,6 +826,7 @@ def __init__( drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, + **dd, ) network.append(stage) self.stage_ends.append(block_idx) @@ -782,7 +834,7 @@ def __init__( block_idx += 1 if downsamples[i]: # downsampling between two stages - network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) + network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2, **dd)) r *= 2 block_idx += 1 @@ -800,22 +852,24 @@ def __init__( qkv_bias=qkv_bias, attn_drop=attn_drop_rate, drop_path=0., - norm_layer=norm_layer) + norm_layer=norm_layer, + **dd, + ) for i in range(len(post_layers)) ]) - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1], **dd)) trunc_normal_(self.cls_token, std=.02) # set output type if use_aux_head: - self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.aux_head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() else: self.aux_head = None - self.norm = norm_layer(self.num_features) + self.norm = norm_layer(self.num_features, **dd) # Classifier head self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) self.apply(self._init_weights) @@ -891,9 +945,13 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear( + self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() if self.aux_head is not None: - self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.aux_head = nn.Linear( + self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through token processing stages. 
diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 948ea501dc..ed8df58f23 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type import torch import torch.nn as nn @@ -28,8 +28,8 @@ class SequentialAppendList(nn.Sequential): - def __init__(self, *args): - super(SequentialAppendList, self).__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args) def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: for i, module in enumerate(self): @@ -45,22 +45,25 @@ class OsaBlock(nn.Module): def __init__( self, - in_chs, - mid_chs, - out_chs, - layer_per_block, - residual=False, - depthwise=False, - attn='', - norm_layer=BatchNormAct2d, - act_layer=nn.ReLU, - drop_path=None, + in_chs: int, + mid_chs: int, + out_chs: int, + layer_per_block: int, + residual: bool = False, + depthwise: bool = False, + attn: str = '', + norm_layer: Type[nn.Module] = BatchNormAct2d, + act_layer: Type[nn.Module] = nn.ReLU, + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - super(OsaBlock, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.residual = residual self.depthwise = depthwise - conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **dd) next_in_chs = in_chs if self.depthwise and next_in_chs != mid_chs: @@ -83,7 +86,7 @@ def __init__( next_in_chs = in_chs + layer_per_block * mid_chs self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs) - self.attn = create_attn(attn, out_chs) if attn else None + self.attn = create_attn(attn, out_chs, **dd) if attn else None self.drop_path = drop_path @@ -106,20 +109,23 @@ class OsaStage(nn.Module): def __init__( self, - in_chs, - mid_chs, - out_chs, - block_per_stage, - layer_per_block, - downsample=True, - residual=True, - depthwise=False, - attn='ese', - norm_layer=BatchNormAct2d, - act_layer=nn.ReLU, - drop_path_rates=None, + in_chs: int, + mid_chs: int, + out_chs: int, + block_per_stage: int, + layer_per_block: int, + downsample: bool = True, + residual: bool = True, + depthwise: bool = False, + attn: str = 'ese', + norm_layer: Type[nn.Module] = BatchNormAct2d, + act_layer: Type[nn.Module] = nn.ReLU, + drop_path_rates: Optional[List[float]] = None, + device=None, + dtype=None, ): - super(OsaStage, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() self.grad_checkpointing = False if downsample: @@ -144,7 +150,8 @@ def __init__( attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, - drop_path=drop_path + drop_path=drop_path, + **dd, )] in_chs = out_chs self.blocks = nn.Sequential(*blocks) @@ -163,15 +170,17 @@ class VovNet(nn.Module): def __init__( self, - cfg, - in_chans=3, - num_classes=1000, - global_pool='avg', - output_stride=32, - norm_layer=BatchNormAct2d, - act_layer=nn.ReLU, - drop_rate=0., - drop_path_rate=0., + cfg: dict, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + output_stride: int = 32, + norm_layer: Type[nn.Module] = BatchNormAct2d, + act_layer: Type[nn.Module] = nn.ReLU, + drop_rate: float = 0., + drop_path_rate: float = 0., + device=None, + dtype=None, **kwargs, ): """ @@ -187,7 +196,8 @@ def __init__( drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) 
kwargs (dict): Extra kwargs overlayed onto cfg """ - super(VovNet, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride == 32 # FIXME support dilation @@ -199,7 +209,7 @@ def __init__( stage_out_chs = cfg["stage_out_chs"] block_per_stage = cfg["block_per_stage"] layer_per_block = cfg["layer_per_block"] - conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer, **dd) # Stem module last_stem_stride = stem_stride // 2 @@ -237,7 +247,7 @@ def __init__( self.stages = nn.Sequential(*stages) self.head_hidden_size = self.num_features - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd) for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): diff --git a/timm/models/xception.py b/timm/models/xception.py index e1f92abfa0..6f0b5dc590 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -24,6 +24,7 @@ import torch.jit import torch.nn as nn import torch.nn.functional as F +from typing import Optional from timm.layers import create_classifier from ._builder import build_model_with_cfg @@ -33,12 +34,32 @@ class SeparableConv2d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): - super(SeparableConv2d, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() self.conv1 = nn.Conv2d( - in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False) - self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False) + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + groups=in_channels, + bias=False, + **dd, + ) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False, **dd) def forward(self, x): x = self.conv1(x) @@ -47,12 +68,23 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True): - super(Block, self).__init__() + def __init__( + self, + in_channels: int, + out_channels: int, + reps: int, + strides: int = 1, + start_with_relu: bool = True, + grow_first: bool = True, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} + super().__init__() if out_channels != in_channels or strides != 1: - self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False) - self.skipbn = nn.BatchNorm2d(out_channels) + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False, **dd) + self.skipbn = nn.BatchNorm2d(out_channels, **dd) else: self.skip = None @@ -65,8 +97,8 @@ def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=T inc = in_channels outc = in_channels if i < (reps - 1) else out_channels rep.append(nn.ReLU(inplace=True)) - rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1)) - rep.append(nn.BatchNorm2d(outc)) + rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1, **dd)) + rep.append(nn.BatchNorm2d(outc, **dd)) if not start_with_relu: rep = rep[1:] @@ -96,47 +128,56 @@ class Xception(nn.Module): 
https://arxiv.org/pdf/1610.02357.pdf """ - def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + def __init__( + self, + num_classes: int = 1000, + in_chans: int = 3, + drop_rate: float = 0., + global_pool: str = 'avg', + device=None, + dtype=None, + ): """ Constructor Args: num_classes: number of classes """ - super(Xception, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} self.drop_rate = drop_rate self.global_pool = global_pool self.num_classes = num_classes self.num_features = self.head_hidden_size = 2048 - self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) - self.bn1 = nn.BatchNorm2d(32) + self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False, **dd) + self.bn1 = nn.BatchNorm2d(32, **dd) self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(32, 64, 3, bias=False) - self.bn2 = nn.BatchNorm2d(64) + self.conv2 = nn.Conv2d(32, 64, 3, bias=False, **dd) + self.bn2 = nn.BatchNorm2d(64, **dd) self.act2 = nn.ReLU(inplace=True) - self.block1 = Block(64, 128, 2, 2, start_with_relu=False) - self.block2 = Block(128, 256, 2, 2) - self.block3 = Block(256, 728, 2, 2) + self.block1 = Block(64, 128, 2, 2, start_with_relu=False, **dd) + self.block2 = Block(128, 256, 2, 2, **dd) + self.block3 = Block(256, 728, 2, 2, **dd) - self.block4 = Block(728, 728, 3, 1) - self.block5 = Block(728, 728, 3, 1) - self.block6 = Block(728, 728, 3, 1) - self.block7 = Block(728, 728, 3, 1) + self.block4 = Block(728, 728, 3, 1, **dd) + self.block5 = Block(728, 728, 3, 1, **dd) + self.block6 = Block(728, 728, 3, 1, **dd) + self.block7 = Block(728, 728, 3, 1, **dd) - self.block8 = Block(728, 728, 3, 1) - self.block9 = Block(728, 728, 3, 1) - self.block10 = Block(728, 728, 3, 1) - self.block11 = Block(728, 728, 3, 1) + self.block8 = Block(728, 728, 3, 1, **dd) + self.block9 = Block(728, 728, 3, 1, **dd) + self.block10 = Block(728, 728, 3, 1, **dd) + self.block11 = Block(728, 728, 3, 1, **dd) - self.block12 = Block(728, 1024, 2, 2, grow_first=False) + self.block12 = Block(728, 1024, 2, 2, grow_first=False, **dd) - self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) - self.bn3 = nn.BatchNorm2d(1536) + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1, **dd) + self.bn3 = nn.BatchNorm2d(1536, **dd) self.act3 = nn.ReLU(inplace=True) - self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) - self.bn4 = nn.BatchNorm2d(self.num_features) + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1, **dd) + self.bn4 = nn.BatchNorm2d(self.num_features, **dd) self.act4 = nn.ReLU(inplace=True) self.feature_info = [ dict(num_chs=64, reduction=2, module='act2'), @@ -146,7 +187,7 @@ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg' dict(num_chs=2048, reduction=32, module='act4'), ] - self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd) # #------- init weights -------- for m in self.modules(): diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 3c67fa8036..18e1f0b4a9 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -32,21 +32,31 @@ def __init__( padding: PadType = '', act_layer: Type[nn.Module] = nn.ReLU, norm_layer: Type[nn.Module] = nn.BatchNorm2d, + device=None, + dtype=None, ): - super(SeparableConv2d, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() 
self.kernel_size = kernel_size self.dilation = dilation # depthwise convolution self.conv_dw = create_conv2d( - in_chs, in_chs, kernel_size, stride=stride, - padding=padding, dilation=dilation, depthwise=True) - self.bn_dw = norm_layer(in_chs) + in_chs, + in_chs, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + depthwise=True, + **dd, + ) + self.bn_dw = norm_layer(in_chs, **dd) self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity() # pointwise convolution - self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) - self.bn_pw = norm_layer(out_chs) + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd) + self.bn_pw = norm_layer(out_chs, **dd) self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity() def forward(self, x): @@ -71,20 +81,30 @@ def __init__( act_layer: Type[nn.Module] = nn.ReLU, norm_layer: Type[nn.Module] = nn.BatchNorm2d, first_act: bool = True, + device=None, + dtype=None, ): - super(PreSeparableConv2d, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) self.kernel_size = kernel_size self.dilation = dilation - self.norm = norm_act_layer(in_chs, inplace=True) if first_act else nn.Identity() + self.norm = norm_act_layer(in_chs, inplace=True, **dd) if first_act else nn.Identity() # depthwise convolution self.conv_dw = create_conv2d( - in_chs, in_chs, kernel_size, stride=stride, - padding=padding, dilation=dilation, depthwise=True) + in_chs, + in_chs, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + depthwise=True, + **dd, + ) # pointwise convolution - self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd) def forward(self, x): x = self.norm(x) @@ -105,16 +125,26 @@ def __init__( no_skip: bool = False, act_layer: Type[nn.Module] = nn.ReLU, norm_layer: Optional[Type[nn.Module]] = None, - drop_path: Optional[nn.Module] = None + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - super(XceptionModule, self).__init__() + dd = {'device': device, 'dtype': dtype} + super().__init__() out_chs = to_3tuple(out_chs) self.in_channels = in_chs self.out_channels = out_chs[-1] self.no_skip = no_skip if not no_skip and (self.out_channels != self.in_channels or stride != 1): self.shortcut = ConvNormAct( - in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, apply_act=False) + in_chs, + self.out_channels, + 1, + stride=stride, + norm_layer=norm_layer, + apply_act=False, + **dd, + ) else: self.shortcut = None @@ -124,8 +154,16 @@ def __init__( if start_with_relu: self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0)) self.stack.add_module(f'conv{i + 1}', SeparableConv2d( - in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, - act_layer=separable_act_layer, norm_layer=norm_layer)) + in_chs, + out_chs[i], + 3, + stride=stride if i == 2 else 1, + dilation=dilation, + padding=pad_type, + act_layer=separable_act_layer, + norm_layer=norm_layer, + **dd, + )) in_chs = out_chs[i] self.drop_path = drop_path @@ -153,19 +191,22 @@ def __init__( no_skip: bool = False, act_layer: Type[nn.Module] = nn.ReLU, norm_layer: Optional[Type[nn.Module]] = None, - drop_path: Optional[nn.Module] = None + drop_path: Optional[nn.Module] = None, + device=None, + dtype=None, ): - super(PreXceptionModule, self).__init__() + dd = {'device': device, 'dtype': 
dtype} + super().__init__() out_chs = to_3tuple(out_chs) self.in_channels = in_chs self.out_channels = out_chs[-1] self.no_skip = no_skip if not no_skip and (self.out_channels != self.in_channels or stride != 1): - self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride) + self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride, **dd) else: self.shortcut = nn.Identity() - self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True) + self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True, **dd) self.stack = nn.Sequential() for i in range(3): self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d( @@ -178,6 +219,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, first_act=i > 0, + **dd, )) in_chs = out_chs[i] @@ -210,17 +252,20 @@ def __init__( drop_rate: float = 0., drop_path_rate: float = 0., global_pool: str = 'avg', + device=None, + dtype=None, ): - super(XceptionAligned, self).__init__() + super().__init__() + dd = {'device': device, 'dtype': dtype} assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False - layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, **dd) self.stem = nn.Sequential(*[ ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), - create_conv2d(32, 64, kernel_size=3, stride=1) if preact else + create_conv2d(32, 64, kernel_size=3, stride=1, **dd) if preact else ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args) ]) @@ -257,6 +302,7 @@ def __init__( num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate, + **dd, ) @torch.jit.ignore diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 271578adf8..7833a8e1ca 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -13,7 +13,7 @@ import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Type, Any import torch import torch.nn as nn @@ -38,9 +38,17 @@ class PositionalEncodingFourier(nn.Module): - https://github.com/facebookresearch/xcit/blob/master/xcit.py """ - def __init__(self, hidden_dim=32, dim=768, temperature=10000): + def __init__( + self, + hidden_dim: int = 32, + dim: int = 768, + temperature: float = 10000, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd) self.scale = 2 * math.pi self.temperature = temperature self.hidden_dim = hidden_dim @@ -65,18 +73,29 @@ def forward(self, B: int, H: int, W: int): return pos.repeat(B, 1, 1, 1) # (B, C, H, W) -def conv3x3(in_planes, out_planes, stride=1): +def conv3x3(in_planes, out_planes, stride=1, device=None, dtype=None): """3x3 convolution + batch norm""" + dd = {'device': device, 'dtype': dtype} return torch.nn.Sequential( - nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), - nn.BatchNorm2d(out_planes) + nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, **dd), + nn.BatchNorm2d(out_planes, **dd) ) class ConvPatchEmbed(nn.Module): """Image to Patch Embedding using multiple convolutional layers""" - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU): + def __init__( + self, + img_size: int = 224, + 
patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + act_layer: Type[nn.Module] = nn.GELU, + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() img_size = to_2tuple(img_size) num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) @@ -86,21 +105,21 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_l if patch_size == 16: self.proj = torch.nn.Sequential( - conv3x3(in_chans, embed_dim // 8, 2), + conv3x3(in_chans, embed_dim // 8, 2, **dd), act_layer(), - conv3x3(embed_dim // 8, embed_dim // 4, 2), + conv3x3(embed_dim // 8, embed_dim // 4, 2, **dd), act_layer(), - conv3x3(embed_dim // 4, embed_dim // 2, 2), + conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd), act_layer(), - conv3x3(embed_dim // 2, embed_dim, 2), + conv3x3(embed_dim // 2, embed_dim, 2, **dd), ) elif patch_size == 8: self.proj = torch.nn.Sequential( - conv3x3(in_chans, embed_dim // 4, 2), + conv3x3(in_chans, embed_dim // 4, 2, **dd), act_layer(), - conv3x3(embed_dim // 4, embed_dim // 2, 2), + conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd), act_layer(), - conv3x3(embed_dim // 2, embed_dim, 2), + conv3x3(embed_dim // 2, embed_dim, 2, **dd), ) else: raise('For convolutional projection, patch size has to be in [8, 16]') @@ -119,18 +138,26 @@ class LPI(nn.Module): 3x3 convolutions with GeLU and BatchNorm2d """ - def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3): + def __init__( + self, + in_features: int, + out_features: Optional[int] = None, + act_layer: Type[nn.Module] = nn.GELU, + kernel_size: int = 3, + device=None, + dtype=None, + ): super().__init__() + dd = {'device': device, 'dtype': dtype} out_features = out_features or in_features - padding = kernel_size // 2 self.conv1 = torch.nn.Conv2d( - in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features) + in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features, **dd) self.act = act_layer() - self.bn = nn.BatchNorm2d(in_features) + self.bn = nn.BatchNorm2d(in_features, **dd) self.conv2 = torch.nn.Conv2d( - in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features) + in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features, **dd) def forward(self, x, H: int, W: int): B, N, C = x.shape @@ -148,31 +175,46 @@ class ClassAttentionBlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - eta=1., - tokens_norm=False, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + eta: Optional[float] = 1., + tokens_norm: bool = False, + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) - + self.norm1 = norm_layer(dim, **dd) self.attn = ClassAttn( - dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + **dd, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) + self.norm2 = norm_layer(dim, **dd) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + **dd, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() if eta is not None: # LayerScale Initialization (no layerscale when None) - self.gamma1 = nn.Parameter(eta * torch.ones(dim)) - self.gamma2 = nn.Parameter(eta * torch.ones(dim)) + self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd)) + self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd)) else: self.gamma1, self.gamma2 = 1.0, 1.0 @@ -182,7 +224,8 @@ def __init__( def forward(self, x): x_norm1 = self.norm1(x) x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) - x = x + self.drop_path(self.gamma1 * x_attn) + x = x + self.drop_path1(self.gamma1 * x_attn) + if self.tokens_norm: x = self.norm2(x) else: @@ -191,7 +234,7 @@ def forward(self, x): cls_token = x[:, 0:1] cls_token = self.gamma2 * self.mlp(cls_token) x = torch.cat([cls_token, x[:, 1:]], dim=1) - x = x_res + self.drop_path(x) + x = x_res + self.drop_path2(x) return x @@ -202,14 +245,24 @@ class XCA(nn.Module): normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h) """ - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + device=None, + dtype=None, + ): + dd = {'device': device, 'dtype': dtype} super().__init__() self.num_heads = num_heads self.fused_attn = use_fused_attn(experimental=True) - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + self.proj = nn.Linear(dim, dim, **dd) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): @@ -244,38 +297,56 @@ def no_weight_decay(self): class XCABlock(nn.Module): def __init__( self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - proj_drop=0., - attn_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - eta=1., + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + eta: float = 1., + device=None, + dtype=None, ): + dd = {'device': device, 'dtype': dtype} super().__init__() - self.norm1 = norm_layer(dim) - self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm1 = norm_layer(dim, **dd) + self.attn = XCA( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + **dd, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm3 = norm_layer(dim) - self.local_mp = LPI(in_features=dim, act_layer=act_layer) + self.norm3 = norm_layer(dim, **dd) + self.local_mp = LPI(in_features=dim, act_layer=act_layer, **dd) + self.drop_path3 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop) + self.norm2 = norm_layer(dim, **dd) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + **dd, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.gamma1 = nn.Parameter(eta * torch.ones(dim)) - self.gamma3 = nn.Parameter(eta * torch.ones(dim)) - self.gamma2 = nn.Parameter(eta * torch.ones(dim)) + self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd)) + self.gamma3 = nn.Parameter(eta * torch.ones(dim, **dd)) + self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd)) def forward(self, x, H: int, W: int): - x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + x = x + self.drop_path1(self.gamma1 * self.attn(self.norm1(x))) # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 - x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W)) - x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + x = x + self.drop_path3(self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x))) return x @@ -288,27 +359,29 @@ class Xcit(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='token', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - drop_rate=0., - pos_drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - act_layer=None, - norm_layer=None, - cls_attn_layers=2, - use_pos_embed=True, - eta=1., - tokens_norm=False, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + act_layer: Optional[Type[nn.Module]] = None, + norm_layer: Optional[Type[nn.Module]] = None, + cls_attn_layers: int = 2, + use_pos_embed: bool = True, + eta: float = 1., + tokens_norm: bool = False, + device=None, + dtype=None, ): """ Args: @@ -337,6 +410,7 @@ def __init__( interaction (class LPI) and the patch embedding (class ConvPatchEmbed) """ super().__init__() + dd = {'device': device, 'dtype': dtype} assert global_pool in ('', 'avg', 'token') img_size = to_2tuple(img_size) assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \ @@ -355,12 +429,13 @@ def __init__( in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer, + **dd, ) r = patch_size - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if use_pos_embed: - self.pos_embed = PositionalEncodingFourier(dim=embed_dim) + self.pos_embed = PositionalEncodingFourier(dim=embed_dim, **dd) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=pos_drop_rate) @@ -377,6 +452,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, eta=eta, + **dd, ) for _ in range(depth)]) self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)] @@ -393,13 +469,14 @@ def __init__( norm_layer=norm_layer, eta=eta, 
tokens_norm=tokens_norm, + **dd, ) for _ in range(cls_attn_layers)]) # Classifier head - self.norm = norm_layer(embed_dim) + self.norm = norm_layer(embed_dim, **dd) self.head_drop = nn.Dropout(drop_rate) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity() # Init weights trunc_normal_(self.cls_token, std=.02) @@ -436,7 +513,9 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + device = self.head.weight.device if hasattr(self.head, 'weight') else None + dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None + self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity() def forward_intermediates( self,