7
7
"""
8
8
from functools import partial
9
9
from math import comb # Python 3.8
10
- from typing import Optional , Type
10
+ from typing import Callable , Optional , Type , Union
11
11
12
12
import torch
13
13
import torch .nn as nn
@@ -39,8 +39,7 @@ def __init__(
39
39
device = None ,
40
40
dtype = None
41
41
) -> None :
42
- dd = {'device' : device , 'dtype' : dtype }
43
- super (BlurPool2d , self ).__init__ ()
42
+ super ().__init__ ()
44
43
assert filt_size > 1
45
44
self .channels = channels
46
45
self .filt_size = filt_size
@@ -51,12 +50,18 @@ def __init__(
51
50
# (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
52
51
coeffs = torch .tensor (
53
52
[comb (filt_size - 1 , k ) for k in range (filt_size )],
54
- ** dd ,
53
+ device = 'cpu' ,
54
+ dtype = torch .float32 ,
55
55
) / (2 ** (filt_size - 1 )) # normalise so coefficients sum to 1
56
56
blur_filter = (coeffs [:, None ] * coeffs [None , :])[None , None , :, :]
57
57
if channels is not None :
58
58
blur_filter = blur_filter .repeat (self .channels , 1 , 1 , 1 )
59
- self .register_buffer ('filt' , blur_filter , persistent = False )
59
+
60
+ self .register_buffer (
61
+ 'filt' ,
62
+ blur_filter .to (device = device , dtype = dtype ),
63
+ persistent = False ,
64
+ )
60
65
61
66
def forward (self , x : torch .Tensor ) -> torch .Tensor :
62
67
x = F .pad (x , self .padding , mode = self .pad_mode )
@@ -69,6 +74,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
74
return F .conv2d (x , weight , stride = self .stride , groups = channels )
70
75
71
76
77
+ def _normalize_aa_layer (aa_layer : LayerType ) -> Callable [..., nn .Module ]:
78
+ """Map string shorthands to callables (class or partial)."""
79
+ if isinstance (aa_layer , str ):
80
+ key = aa_layer .lower ().replace ('_' , '' ).replace ('-' , '' )
81
+ if key in ('avg' , 'avgpool' ):
82
+ return nn .AvgPool2d
83
+ if key in ('blur' , 'blurpool' ):
84
+ return BlurPool2d
85
+ if key == 'blurpc' :
86
+ # preconfigure a constant-pad BlurPool2d
87
+ return partial (BlurPool2d , pad_mode = 'constant' )
88
+ raise AssertionError (f"Unknown anti-aliasing layer ({ aa_layer } )." )
89
+ return aa_layer
90
+
91
+
92
+ def _underlying_cls (layer_callable : Callable [..., nn .Module ]):
93
+ """Return the class behind a callable (unwrap partial), else None."""
94
+ if isinstance (layer_callable , partial ):
95
+ return layer_callable .func
96
+ return layer_callable if isinstance (layer_callable , type ) else None
97
+
98
+
99
+ def _is_blurpool (layer_callable : Callable [..., nn .Module ]) -> bool :
100
+ """True if callable is BlurPool2d or a partial of it."""
101
+ cls = _underlying_cls (layer_callable )
102
+ try :
103
+ return issubclass (cls , BlurPool2d ) # cls may be None, protect below
104
+ except TypeError :
105
+ return False
106
+ except Exception :
107
+ return False
108
+
109
+
72
110
def create_aa (
73
111
aa_layer : LayerType ,
74
112
channels : Optional [int ] = None ,
@@ -77,24 +115,29 @@ def create_aa(
77
115
noop : Optional [Type [nn .Module ]] = nn .Identity ,
78
116
device = None ,
79
117
dtype = None ,
80
- ) -> nn .Module :
81
- """ Anti-aliasing """
118
+ ) -> Optional [ nn .Module ] :
119
+ """ Anti-aliasing factory that supports strings, classes, and partials. """
82
120
if not aa_layer or not enable :
83
121
return noop () if noop is not None else None
84
122
85
- if isinstance (aa_layer , str ):
86
- aa_layer = aa_layer .lower ().replace ('_' , '' ).replace ('-' , '' )
87
- if aa_layer == 'avg' or aa_layer == 'avgpool' :
88
- aa_layer = nn .AvgPool2d
89
- elif aa_layer == 'blur' or aa_layer == 'blurpool' :
90
- aa_layer = partial (BlurPool2d , device = device , dtype = dtype )
91
- elif aa_layer == 'blurpc' :
92
- aa_layer = partial (BlurPool2d , pad_mode = 'constant' , device = device , dtype = dtype )
123
+ # Resolve strings to callables
124
+ aa_layer = _normalize_aa_layer (aa_layer )
93
125
94
- else :
95
- assert False , f"Unknown anti-aliasing layer ({ aa_layer } )."
126
+ # Build kwargs we *intend* to pass
127
+ call_kwargs = {"channels" : channels , "stride" : stride }
128
+
129
+ # Only add device/dtype for BlurPool2d (or partial of it) and don't override if already provided in the partial.
130
+ if _is_blurpool (aa_layer ):
131
+ # Check if aa_layer is a partial and already has device/dtype set
132
+ existing_kw = aa_layer .keywords if isinstance (aa_layer , partial ) and aa_layer .keywords else {}
133
+ if "device" not in existing_kw and device is not None :
134
+ call_kwargs ["device" ] = device
135
+ if "dtype" not in existing_kw and dtype is not None :
136
+ call_kwargs ["dtype" ] = dtype
96
137
138
+ # Try (channels, stride, [device, dtype]) first; fall back to (stride) only
97
139
try :
98
- return aa_layer (channels = channels , stride = stride )
99
- except TypeError as e :
140
+ return aa_layer (** call_kwargs )
141
+ except TypeError :
142
+ # Some layers (e.g., AvgPool2d) may not accept 'channels' and need stride passed as kernel
100
143
return aa_layer (stride )
0 commit comments