Skip to content

Commit 1810b0e

Browse files
majosminducer
andauthored
Allow specifying leaf class in recursive map and map-reduce (#128)
* allow specifying leaf class in recursive map and map-reduce * revert broken changes to decorators * add leaf_class to decorators * add tests for [multi]mapped_over_array_containers Co-authored-by: Andreas Klöckner <[email protected]>
1 parent 5c64c75 commit 1810b0e

File tree

2 files changed

+139
-32
lines changed

2 files changed

+139
-32
lines changed

arraycontext/container/traversal.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -256,46 +256,70 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256256

257257
def rec_map_array_container(
258258
f: Callable[[Any], Any],
259-
ary: ArrayOrContainerT) -> ArrayOrContainerT:
259+
ary: ArrayOrContainerT,
260+
leaf_class: Optional[type] = None) -> ArrayOrContainerT:
260261
r"""Applies *f* recursively to an :class:`ArrayContainer`.
261262
262263
For a non-recursive version see :func:`map_array_container`.
263264
264265
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265266
or an instance of a base array type.
266267
"""
267-
return _map_array_container_impl(f, ary, recursive=True)
268+
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)
268269

269270

270271
def mapped_over_array_containers(
271-
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
272+
f: Optional[Callable[[Any], Any]] = None,
273+
leaf_class: Optional[type] = None) -> Union[
274+
Callable[[ArrayOrContainerT], ArrayOrContainerT],
275+
Callable[
276+
[Callable[[Any], Any]],
277+
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
272278
"""Decorator around :func:`rec_map_array_container`."""
273-
wrapper = partial(rec_map_array_container, f)
274-
update_wrapper(wrapper, f)
275-
return wrapper
279+
def decorator(g: Callable[[Any], Any]) -> Callable[
280+
[ArrayOrContainerT], ArrayOrContainerT]:
281+
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
282+
update_wrapper(wrapper, g)
283+
return wrapper
284+
if f is not None:
285+
return decorator(f)
286+
else:
287+
return decorator
276288

277289

278-
def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
290+
def rec_multimap_array_container(
291+
f: Callable[..., Any],
292+
*args: Any,
293+
leaf_class: Optional[type] = None) -> Any:
279294
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280295
281296
For a non-recursive version see :func:`multimap_array_container`.
282297
283298
:param args: all :class:`ArrayContainer` arguments must be of the same
284299
type and with the same structure (same number of components, etc.).
285300
"""
286-
return _multimap_array_container_impl(f, *args, recursive=True)
301+
return _multimap_array_container_impl(
302+
f, *args, leaf_cls=leaf_class, recursive=True)
287303

288304

289305
def multimapped_over_array_containers(
290-
f: Callable[..., Any]) -> Callable[..., Any]:
306+
f: Optional[Callable[..., Any]] = None,
307+
leaf_class: Optional[type] = None) -> Union[
308+
Callable[..., Any],
309+
Callable[[Callable[..., Any]], Callable[..., Any]]]:
291310
"""Decorator around :func:`rec_multimap_array_container`."""
292-
# can't use functools.partial, because its result is insufficiently
293-
# function-y to be used as a method definition.
294-
def wrapper(*args: Any) -> Any:
295-
return rec_multimap_array_container(f, *args)
311+
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
312+
# can't use functools.partial, because its result is insufficiently
313+
# function-y to be used as a method definition.
314+
def wrapper(*args: Any) -> Any:
315+
return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
316+
update_wrapper(wrapper, g)
317+
return wrapper
318+
if f is not None:
319+
return decorator(f)
320+
else:
321+
return decorator
296322

297-
update_wrapper(wrapper, f)
298-
return wrapper
299323

300324
# }}}
301325

@@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401425
def rec_map_reduce_array_container(
402426
reduce_func: Callable[[Iterable[Any]], Any],
403427
map_func: Callable[[Any], Any],
404-
ary: ArrayOrContainerT) -> "DeviceArray":
428+
ary: ArrayOrContainerT,
429+
leaf_class: Optional[type] = None) -> "DeviceArray":
405430
"""Perform a map-reduce over array containers recursively.
406431
407432
:param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +465,26 @@ def rec_map_reduce_array_container(
440465
or any other such traversal.
441466
"""
442467
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
443-
try:
444-
iterable = serialize_container(_ary)
445-
except NotAnArrayContainerError:
468+
if type(_ary) is leaf_class:
446469
return map_func(_ary)
447470
else:
448-
return reduce_func([
449-
rec(subary) for _, subary in iterable
450-
])
471+
try:
472+
iterable = serialize_container(_ary)
473+
except NotAnArrayContainerError:
474+
return map_func(_ary)
475+
else:
476+
return reduce_func([
477+
rec(subary) for _, subary in iterable
478+
])
451479

452480
return rec(ary)
453481

454482

455483
def rec_multimap_reduce_array_container(
456484
reduce_func: Callable[[Iterable[Any]], Any],
457485
map_func: Callable[..., Any],
458-
*args: Any) -> "DeviceArray":
486+
*args: Any,
487+
leaf_class: Optional[type] = None) -> "DeviceArray":
459488
r"""Perform a map-reduce over multiple array containers recursively.
460489
461490
:param reduce_func: callable used to reduce over the components of any
@@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478507

479508
return _multimap_array_container_impl(
480509
map_func, *args,
481-
reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
510+
reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)
482511

483512
# }}}
484513

test/test_arraycontext.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
756756
assert result is not None
757757

758758

759+
def test_container_map(actx_factory):
760+
actx = actx_factory()
761+
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
762+
_get_test_containers(actx)
763+
764+
# {{{ check
765+
766+
def _check_allclose(f, arg1, arg2, atol=2.0e-14):
767+
from arraycontext import NotAnArrayContainerError
768+
try:
769+
arg1_iterable = serialize_container(arg1)
770+
arg2_iterable = serialize_container(arg2)
771+
except NotAnArrayContainerError:
772+
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
773+
else:
774+
arg1_subarrays = [
775+
subarray for _, subarray in arg1_iterable]
776+
arg2_subarrays = [
777+
subarray for _, subarray in arg2_iterable]
778+
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
779+
_check_allclose(f, subarray1, subarray2)
780+
781+
def func(x):
782+
return x + 1
783+
784+
from arraycontext import rec_map_array_container
785+
result = rec_map_array_container(func, 1)
786+
assert result == 2
787+
788+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
789+
result = rec_map_array_container(func, ary)
790+
_check_allclose(func, ary, result)
791+
792+
from arraycontext import mapped_over_array_containers
793+
794+
@mapped_over_array_containers
795+
def mapped_func(x):
796+
return func(x)
797+
798+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
799+
result = mapped_func(ary)
800+
_check_allclose(func, ary, result)
801+
802+
@mapped_over_array_containers(leaf_class=DOFArray)
803+
def check_leaf(x):
804+
assert isinstance(x, DOFArray)
805+
806+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
807+
check_leaf(ary)
808+
809+
# }}}
810+
811+
759812
def test_container_multimap(actx_factory):
760813
actx = actx_factory()
761814
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
@@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
764817
# {{{ check
765818

766819
def _check_allclose(f, arg1, arg2, atol=2.0e-14):
767-
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
820+
from arraycontext import NotAnArrayContainerError
821+
try:
822+
arg1_iterable = serialize_container(arg1)
823+
arg2_iterable = serialize_container(arg2)
824+
except NotAnArrayContainerError:
825+
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
826+
else:
827+
arg1_subarrays = [
828+
subarray for _, subarray in arg1_iterable]
829+
arg2_subarrays = [
830+
subarray for _, subarray in arg2_iterable]
831+
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
832+
_check_allclose(f, subarray1, subarray2)
768833

769834
def func_all_scalar(x, y):
770835
return x + y
@@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
779844
result = rec_multimap_array_container(func_all_scalar, 1, 2)
780845
assert result == 3
781846

782-
from functools import partial
783847
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
784848
result = rec_multimap_array_container(func_first_scalar, 1, ary)
785-
rec_multimap_array_container(
786-
partial(_check_allclose, lambda x: 1 + x),
787-
ary, result)
849+
_check_allclose(lambda x: 1 + x, ary, result)
788850

789851
result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
790-
rec_multimap_array_container(
791-
partial(_check_allclose, lambda x: 4 * x),
792-
ary, result)
852+
_check_allclose(lambda x: 4 * x, ary, result)
853+
854+
from arraycontext import multimapped_over_array_containers
855+
856+
@multimapped_over_array_containers
857+
def mapped_func(a, subary1, b, subary2):
858+
return func_multiple_scalar(a, subary1, b, subary2)
859+
860+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
861+
result = mapped_func(2, ary, 2, ary)
862+
_check_allclose(lambda x: 4 * x, ary, result)
863+
864+
@multimapped_over_array_containers(leaf_class=DOFArray)
865+
def check_leaf(a, subary1, b, subary2):
866+
assert isinstance(subary1, DOFArray)
867+
assert isinstance(subary2, DOFArray)
868+
869+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
870+
check_leaf(2, ary, 2, ary)
793871

794872
with pytest.raises(AssertionError):
795873
rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)

0 commit comments

Comments
 (0)