diff --git a/ndindex/chunking.py b/ndindex/chunking.py index 5e3e9517..77b619cf 100644 --- a/ndindex/chunking.py +++ b/ndindex/chunking.py @@ -93,8 +93,6 @@ def num_chunks(self, shape): """ shape = asshape(shape) d = [ceiling(i, c) for i, c in zip(shape, self)] - if 0 in d: - return 1 return prod(d) def indices(self, shape): @@ -134,7 +132,7 @@ def indices(self, shape): raise ValueError("chunks dimensions must equal the array dimensions") d = [ceiling(i, c) for i, c in zip(shape, self)] if 0 in d: - yield Tuple(*[Slice(0, bool(i)*chunk_size, 1) for i, chunk_size in zip(d, self)]).expand(shape) + return for p in product(*[range(i) for i in d]): # p = (0, 0, 0), (0, 0, 1), ... yield Tuple(*[Slice(chunk_size*i, min(chunk_size*(i + 1), n), 1) diff --git a/ndindex/tests/test_chunking.py b/ndindex/tests/test_chunking.py index 6e6eaccc..34a7ff7a 100644 --- a/ndindex/tests/test_chunking.py +++ b/ndindex/tests/test_chunking.py @@ -82,6 +82,8 @@ def test_indices_error(): def test_num_chunks(chunk_size, shape): chunk_size = ChunkSize(chunk_size) assert chunk_size.num_chunks(shape) == len(list(chunk_size.indices(shape))) + if 0 in shape: + assert chunk_size.num_chunks(shape) == 0 @given(chunk_sizes(), chunk_shapes) def test_indices(chunk_size, shape):