@@ -21,6 +21,12 @@ class Pool(Module):
21
21
def __init__ (self , kernel_size , stride = None , padding = 0 , dilation = None , return_indices = None , ceil_mode = False , count_include_pad = True , op = "maximum" ):
22
22
assert dilation == None
23
23
assert return_indices == None or op == "maximum"
24
+ if self .kernel_size [0 ] <= 0 or self .kernel_size [1 ] <= 0 :
25
+ raise RuntimeError (f"kernel_size must be greater than zero, but got { kernel_size } " )
26
+ if self .stride [0 ] <= 0 or self .stride [1 ] <= 0 :
27
+ raise RuntimeError (f"stride must be greater than zero, but got { stride } " )
28
+ if self .padding [0 ] < 0 or self .padding [1 ] < 0 :
29
+ raise RuntimeError (f"padding must be non-negative, but got { padding } " )
24
30
self .return_indices = return_indices
25
31
self .kernel_size = kernel_size if isinstance (kernel_size , tuple ) else (kernel_size , kernel_size )
26
32
self .op = op
@@ -29,12 +35,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
29
35
self .padding = padding if isinstance (padding , tuple ) else (padding , padding )
30
36
self .ceil_mode = ceil_mode
31
37
self .count_include_pad = count_include_pad and padding != 0
32
- if self .kernel_size [0 ] <= 0 or self .kernel_size [1 ] <= 0 :
33
- raise RuntimeError (f"kernel_size must be greater than zero, but got { kernel_size } " )
34
- if self .stride [0 ] <= 0 or self .stride [1 ] <= 0 :
35
- raise RuntimeError (f"stride must be greater than zero, but got { stride } " )
36
- if self .padding [0 ] < 0 or self .padding [1 ] < 0 :
37
- raise RuntimeError (f"padding must be non-negative, but got { padding } " )
38
38
39
39
def execute (self , x ):
40
40
N ,C ,H ,W = x .shape
@@ -203,6 +203,12 @@ class Pool3d(Module):
203
203
def __init__ (self , kernel_size , stride = None , padding = 0 , dilation = None , return_indices = None , ceil_mode = False , count_include_pad = True , op = "maximum" ):
204
204
assert dilation == None
205
205
assert return_indices == None or op == "maximum"
206
+ if self .kernel_size [0 ] <= 0 or self .kernel_size [1 ] <= 0 or self .kernel_size [2 ] <= 0 :
207
+ raise RuntimeError (f"kernel_size must be greater than zero, but got { kernel_size } " )
208
+ if self .stride [0 ] <= 0 or self .stride [1 ] <= 0 or self .stride [2 ] <= 0 :
209
+ raise RuntimeError (f"stride must be greater than zero, but got { stride } " )
210
+ if self .padding [0 ] < 0 or self .padding [1 ] < 0 or self .padding [2 ] < 0 :
211
+ raise RuntimeError (f"padding must be non-negative, but got { padding } " )
206
212
self .return_indices = return_indices
207
213
self .kernel_size = _triple (kernel_size )
208
214
self .op = op
@@ -211,12 +217,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
211
217
self .padding = _triple (padding )
212
218
self .ceil_mode = ceil_mode
213
219
self .count_include_pad = count_include_pad and padding != 0
214
- if self .kernel_size [0 ] <= 0 or self .kernel_size [1 ] <= 0 or self .kernel_size [2 ] <= 0 :
215
- raise RuntimeError (f"kernel_size must be greater than zero, but got { kernel_size } " )
216
- if self .stride [0 ] <= 0 or self .stride [1 ] <= 0 or self .stride [2 ] <= 0 :
217
- raise RuntimeError (f"stride must be greater than zero, but got { stride } " )
218
- if self .padding [0 ] < 0 or self .padding [1 ] < 0 or self .padding [2 ] < 0 :
219
- raise RuntimeError (f"padding must be non-negative, but got { padding } " )
220
220
221
221
def execute (self , x ):
222
222
N ,C ,D ,H ,W = x .shape
@@ -518,7 +518,7 @@ def execute(self, x):
518
518
f"i3*{ self .sh } +i6" , # Hid
519
519
f"i4*{ self .sw } +i7" , # Wid
520
520
])
521
- return xx .reduce ("maximun " , [5 ,6 ,7 ])
521
+ return xx .reduce ("maximum " , [5 ,6 ,7 ])
522
522
523
523
def pool (x , kernel_size , op , padding = 0 , stride = None ):
524
524
return Pool (kernel_size , stride , padding , op = op )(x )
0 commit comments