@@ -256,46 +256,70 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256
256
257
257
def rec_map_array_container (
258
258
f : Callable [[Any ], Any ],
259
- ary : ArrayOrContainerT ) -> ArrayOrContainerT :
259
+ ary : ArrayOrContainerT ,
260
+ leaf_class : Optional [type ] = None ) -> ArrayOrContainerT :
260
261
r"""Applies *f* recursively to an :class:`ArrayContainer`.
261
262
262
263
For a non-recursive version see :func:`map_array_container`.
263
264
264
265
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265
266
or an instance of a base array type.
266
267
"""
267
- return _map_array_container_impl (f , ary , recursive = True )
268
+ return _map_array_container_impl (f , ary , leaf_cls = leaf_class , recursive = True )
268
269
269
270
270
271
def mapped_over_array_containers (
271
- f : Callable [[Any ], Any ]) -> Callable [[ArrayOrContainerT ], ArrayOrContainerT ]:
272
+ f : Optional [Callable [[Any ], Any ]] = None ,
273
+ leaf_class : Optional [type ] = None ) -> Union [
274
+ Callable [[ArrayOrContainerT ], ArrayOrContainerT ],
275
+ Callable [
276
+ [Callable [[Any ], Any ]],
277
+ Callable [[ArrayOrContainerT ], ArrayOrContainerT ]]]:
272
278
"""Decorator around :func:`rec_map_array_container`."""
273
- wrapper = partial (rec_map_array_container , f )
274
- update_wrapper (wrapper , f )
275
- return wrapper
279
+ def decorator (g : Callable [[Any ], Any ]) -> Callable [
280
+ [ArrayOrContainerT ], ArrayOrContainerT ]:
281
+ wrapper = partial (rec_map_array_container , g , leaf_class = leaf_class )
282
+ update_wrapper (wrapper , g )
283
+ return wrapper
284
+ if f is not None :
285
+ return decorator (f )
286
+ else :
287
+ return decorator
276
288
277
289
278
- def rec_multimap_array_container (f : Callable [..., Any ], * args : Any ) -> Any :
290
+ def rec_multimap_array_container (
291
+ f : Callable [..., Any ],
292
+ * args : Any ,
293
+ leaf_class : Optional [type ] = None ) -> Any :
279
294
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280
295
281
296
For a non-recursive version see :func:`multimap_array_container`.
282
297
283
298
:param args: all :class:`ArrayContainer` arguments must be of the same
284
299
type and with the same structure (same number of components, etc.).
285
300
"""
286
- return _multimap_array_container_impl (f , * args , recursive = True )
301
+ return _multimap_array_container_impl (
302
+ f , * args , leaf_cls = leaf_class , recursive = True )
287
303
288
304
289
305
def multimapped_over_array_containers (
290
- f : Callable [..., Any ]) -> Callable [..., Any ]:
306
+ f : Optional [Callable [..., Any ]] = None ,
307
+ leaf_class : Optional [type ] = None ) -> Union [
308
+ Callable [..., Any ],
309
+ Callable [[Callable [..., Any ]], Callable [..., Any ]]]:
291
310
"""Decorator around :func:`rec_multimap_array_container`."""
292
- # can't use functools.partial, because its result is insufficiently
293
- # function-y to be used as a method definition.
294
- def wrapper (* args : Any ) -> Any :
295
- return rec_multimap_array_container (f , * args )
311
+ def decorator (g : Callable [..., Any ]) -> Callable [..., Any ]:
312
+ # can't use functools.partial, because its result is insufficiently
313
+ # function-y to be used as a method definition.
314
+ def wrapper (* args : Any ) -> Any :
315
+ return rec_multimap_array_container (g , * args , leaf_class = leaf_class )
316
+ update_wrapper (wrapper , g )
317
+ return wrapper
318
+ if f is not None :
319
+ return decorator (f )
320
+ else :
321
+ return decorator
296
322
297
- update_wrapper (wrapper , f )
298
- return wrapper
299
323
300
324
# }}}
301
325
@@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401
425
def rec_map_reduce_array_container (
402
426
reduce_func : Callable [[Iterable [Any ]], Any ],
403
427
map_func : Callable [[Any ], Any ],
404
- ary : ArrayOrContainerT ) -> "DeviceArray" :
428
+ ary : ArrayOrContainerT ,
429
+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
405
430
"""Perform a map-reduce over array containers recursively.
406
431
407
432
:param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +465,26 @@ def rec_map_reduce_array_container(
440
465
or any other such traversal.
441
466
"""
442
467
def rec (_ary : ArrayOrContainerT ) -> ArrayOrContainerT :
443
- try :
444
- iterable = serialize_container (_ary )
445
- except NotAnArrayContainerError :
468
+ if type (_ary ) is leaf_class :
446
469
return map_func (_ary )
447
470
else :
448
- return reduce_func ([
449
- rec (subary ) for _ , subary in iterable
450
- ])
471
+ try :
472
+ iterable = serialize_container (_ary )
473
+ except NotAnArrayContainerError :
474
+ return map_func (_ary )
475
+ else :
476
+ return reduce_func ([
477
+ rec (subary ) for _ , subary in iterable
478
+ ])
451
479
452
480
return rec (ary )
453
481
454
482
455
483
def rec_multimap_reduce_array_container (
456
484
reduce_func : Callable [[Iterable [Any ]], Any ],
457
485
map_func : Callable [..., Any ],
458
- * args : Any ) -> "DeviceArray" :
486
+ * args : Any ,
487
+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
459
488
r"""Perform a map-reduce over multiple array containers recursively.
460
489
461
490
:param reduce_func: callable used to reduce over the components of any
@@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478
507
479
508
return _multimap_array_container_impl (
480
509
map_func , * args ,
481
- reduce_func = _reduce_wrapper , leaf_cls = None , recursive = True )
510
+ reduce_func = _reduce_wrapper , leaf_cls = leaf_class , recursive = True )
482
511
483
512
# }}}
484
513
0 commit comments