Skip to content

Commit cd8b19a

Browse files
committed
fix illegal parameters of Pool and Pool3d of issue Jittor#451,Jittor#453,Jittor#456,Jittor#457
1 parent 862bce9 commit cd8b19a

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def flatten(input, start_dim=0, end_dim=-1):
699699
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
700700
assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function"
701701
if len(in_shape) <= end_dim:
702-
raise IndexError("Dimension out of range (expected to be in range of [%d, %d], but got %d)" % (-len(in_shape),len(in_shape) - 1,end_dim))
702+
raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})")
703703
out_shape = []
704704
for i in range(0,start_dim,1): out_shape.append(in_shape[i])
705705
dims = 1

python/jittor/pool.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ class Pool(Module):
2121
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
2222
assert dilation == None
2323
assert return_indices == None or op == "maximum"
24+
if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0:
25+
raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
26+
if self.stride[0] <= 0 or self.stride[1] <= 0:
27+
raise RuntimeError(f"stride must be greater than zero, but got {stride}")
28+
if self.padding[0] < 0 or self.padding[1] < 0:
29+
raise RuntimeError(f"padding must be non-negative, but got {padding}")
2430
self.return_indices = return_indices
2531
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
2632
self.op = op
@@ -29,12 +35,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
2935
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
3036
self.ceil_mode = ceil_mode
3137
self.count_include_pad = count_include_pad and padding != 0
32-
if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0:
33-
raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
34-
if self.stride[0] <= 0 or self.stride[1] <= 0:
35-
raise RuntimeError(f"stride must be greater than zero, but got {stride}")
36-
if self.padding[0] < 0 or self.padding[1] < 0:
37-
raise RuntimeError(f"padding must be non-negative, but got {padding}")
3838

3939
def execute(self, x):
4040
N,C,H,W = x.shape
@@ -203,6 +203,12 @@ class Pool3d(Module):
203203
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
204204
assert dilation == None
205205
assert return_indices == None or op == "maximum"
206+
if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0:
207+
raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
208+
if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0:
209+
raise RuntimeError(f"stride must be greater than zero, but got {stride}")
210+
if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0:
211+
raise RuntimeError(f"padding must be non-negative, but got {padding}")
206212
self.return_indices = return_indices
207213
self.kernel_size = _triple(kernel_size)
208214
self.op = op
@@ -211,12 +217,6 @@ def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_in
211217
self.padding = _triple(padding)
212218
self.ceil_mode = ceil_mode
213219
self.count_include_pad = count_include_pad and padding != 0
214-
if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0:
215-
raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}")
216-
if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0:
217-
raise RuntimeError(f"stride must be greater than zero, but got {stride}")
218-
if self.padding[0] < 0 or self.padding[1] < 0 or self.padding[2] < 0:
219-
raise RuntimeError(f"padding must be non-negative, but got {padding}")
220220

221221
def execute(self, x):
222222
N,C,D,H,W = x.shape
@@ -518,7 +518,7 @@ def execute(self, x):
518518
f"i3*{self.sh}+i6", # Hid
519519
f"i4*{self.sw}+i7", # Wid
520520
])
521-
return xx.reduce("maximun", [5,6,7])
521+
return xx.reduce("maximum", [5,6,7])
522522

523523
def pool(x, kernel_size, op, padding=0, stride=None):
524524
return Pool(kernel_size, stride, padding, op=op)(x)

0 commit comments

Comments
 (0)