@@ -168,6 +168,11 @@ class UnsupportedArrayError(ValueError):
168
168
P = ParamSpec ("P" )
169
169
170
170
171
+ def _verify_is_array (expr : ArrayOrNames ) -> Array :
172
+ assert isinstance (expr , Array )
173
+ return expr
174
+
175
+
171
176
class Mapper (Generic [ResultT , FunctionResultT , P ]):
172
177
"""A class that when called with a :class:`pytato.Array` recursively
173
178
iterates over the DAG, calling the *_mapper_method* of each node. Users of
@@ -321,10 +326,7 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
321
326
implement default mapper methods; for that, see :class:`CopyMapper`.
322
327
323
328
"""
324
- def rec_ary (self , expr : Array ) -> Array :
325
- res = self .rec (expr )
326
- assert isinstance (res , Array )
327
- return res
329
+ pass
328
330
329
331
# }}}
330
332
@@ -341,10 +343,7 @@ class TransformMapperWithExtraArgs(
341
343
The logic in :class:`TransformMapper` purposely does not take the extra
342
344
arguments to keep the cost of its each call frame low.
343
345
"""
344
- def rec_ary (self , expr : Array , * args : P .args , ** kwargs : P .kwargs ) -> Array :
345
- res = self .rec (expr , * args , ** kwargs )
346
- assert isinstance (res , Array )
347
- return res
346
+ pass
348
347
349
348
# }}}
350
349
@@ -390,33 +389,33 @@ def map_placeholder(self, expr: Placeholder) -> Array:
390
389
non_equality_tags = expr .non_equality_tags )
391
390
392
391
def map_stack (self , expr : Stack ) -> Array :
393
- arrays = tuple (self .rec_ary (arr ) for arr in expr .arrays )
392
+ arrays = tuple (_verify_is_array ( self .rec (arr ) ) for arr in expr .arrays )
394
393
return Stack (arrays = arrays , axis = expr .axis , axes = expr .axes , tags = expr .tags ,
395
394
non_equality_tags = expr .non_equality_tags )
396
395
397
396
def map_concatenate (self , expr : Concatenate ) -> Array :
398
- arrays = tuple (self .rec_ary (arr ) for arr in expr .arrays )
397
+ arrays = tuple (_verify_is_array ( self .rec (arr ) ) for arr in expr .arrays )
399
398
return Concatenate (arrays = arrays , axis = expr .axis ,
400
399
axes = expr .axes , tags = expr .tags ,
401
400
non_equality_tags = expr .non_equality_tags )
402
401
403
402
def map_roll (self , expr : Roll ) -> Array :
404
- return Roll (array = self .rec_ary (expr .array ),
403
+ return Roll (array = _verify_is_array ( self .rec (expr .array ) ),
405
404
shift = expr .shift ,
406
405
axis = expr .axis ,
407
406
axes = expr .axes ,
408
407
tags = expr .tags ,
409
408
non_equality_tags = expr .non_equality_tags )
410
409
411
410
def map_axis_permutation (self , expr : AxisPermutation ) -> Array :
412
- return AxisPermutation (array = self .rec_ary (expr .array ),
411
+ return AxisPermutation (array = _verify_is_array ( self .rec (expr .array ) ),
413
412
axis_permutation = expr .axis_permutation ,
414
413
axes = expr .axes ,
415
414
tags = expr .tags ,
416
415
non_equality_tags = expr .non_equality_tags )
417
416
418
417
def _map_index_base (self , expr : IndexBase ) -> Array :
419
- return type (expr )(self .rec_ary (expr .array ),
418
+ return type (expr )(_verify_is_array ( self .rec (expr .array ) ),
420
419
indices = self .rec_idx_or_size_tuple (expr .indices ),
421
420
axes = expr .axes ,
422
421
tags = expr .tags ,
@@ -453,7 +452,7 @@ def map_size_param(self, expr: SizeParam) -> Array:
453
452
454
453
def map_einsum (self , expr : Einsum ) -> Array :
455
454
return Einsum (expr .access_descriptors ,
456
- tuple (self .rec_ary (arg ) for arg in expr .args ),
455
+ tuple (_verify_is_array ( self .rec (arg ) ) for arg in expr .args ),
457
456
axes = expr .axes ,
458
457
redn_axis_to_redn_descr = expr .redn_axis_to_redn_descr ,
459
458
tags = expr .tags ,
@@ -470,7 +469,7 @@ def map_named_array(self, expr: NamedArray) -> Array:
470
469
471
470
def map_dict_of_named_arrays (self ,
472
471
expr : DictOfNamedArrays ) -> DictOfNamedArrays :
473
- return DictOfNamedArrays ({key : self .rec_ary (val .expr )
472
+ return DictOfNamedArrays ({key : _verify_is_array ( self .rec (val .expr ) )
474
473
for key , val in expr .items ()},
475
474
tags = expr .tags
476
475
)
@@ -498,7 +497,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult) -> Array:
498
497
non_equality_tags = expr .non_equality_tags )
499
498
500
499
def map_reshape (self , expr : Reshape ) -> Array :
501
- return Reshape (self .rec_ary (expr .array ),
500
+ return Reshape (_verify_is_array ( self .rec (expr .array ) ),
502
501
newshape = self .rec_idx_or_size_tuple (expr .newshape ),
503
502
order = expr .order ,
504
503
axes = expr .axes ,
@@ -509,10 +508,10 @@ def map_distributed_send_ref_holder(
509
508
self , expr : DistributedSendRefHolder ) -> Array :
510
509
return DistributedSendRefHolder (
511
510
send = DistributedSend (
512
- data = self .rec_ary (expr .send .data ),
511
+ data = _verify_is_array ( self .rec (expr .send .data ) ),
513
512
dest_rank = expr .send .dest_rank ,
514
513
comm_tag = expr .send .comm_tag ),
515
- passthrough_data = self .rec_ary (expr .passthrough_data ),
514
+ passthrough_data = _verify_is_array ( self .rec (expr .passthrough_data ) ),
516
515
)
517
516
518
517
def map_distributed_recv (self , expr : DistributedRecv ) -> Array :
@@ -590,19 +589,21 @@ def map_placeholder(self,
590
589
non_equality_tags = expr .non_equality_tags )
591
590
592
591
def map_stack (self , expr : Stack , * args : P .args , ** kwargs : P .kwargs ) -> Array :
593
- arrays = tuple (self .rec_ary (arr , * args , ** kwargs ) for arr in expr .arrays )
592
+ arrays = tuple (
593
+ _verify_is_array (self .rec (arr , * args , ** kwargs )) for arr in expr .arrays )
594
594
return Stack (arrays = arrays , axis = expr .axis , axes = expr .axes , tags = expr .tags ,
595
595
non_equality_tags = expr .non_equality_tags )
596
596
597
597
def map_concatenate (self ,
598
598
expr : Concatenate , * args : P .args , ** kwargs : P .kwargs ) -> Array :
599
- arrays = tuple (self .rec_ary (arr , * args , ** kwargs ) for arr in expr .arrays )
599
+ arrays = tuple (
600
+ _verify_is_array (self .rec (arr , * args , ** kwargs )) for arr in expr .arrays )
600
601
return Concatenate (arrays = arrays , axis = expr .axis ,
601
602
axes = expr .axes , tags = expr .tags ,
602
603
non_equality_tags = expr .non_equality_tags )
603
604
604
605
def map_roll (self , expr : Roll , * args : P .args , ** kwargs : P .kwargs ) -> Array :
605
- return Roll (array = self .rec_ary (expr .array , * args , ** kwargs ),
606
+ return Roll (array = _verify_is_array ( self .rec (expr .array , * args , ** kwargs ) ),
606
607
shift = expr .shift ,
607
608
axis = expr .axis ,
608
609
axes = expr .axes ,
@@ -611,7 +612,8 @@ def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array:
611
612
612
613
def map_axis_permutation (self , expr : AxisPermutation ,
613
614
* args : P .args , ** kwargs : P .kwargs ) -> Array :
614
- return AxisPermutation (array = self .rec_ary (expr .array , * args , ** kwargs ),
615
+ return AxisPermutation (array = _verify_is_array (
616
+ self .rec (expr .array , * args , ** kwargs )),
615
617
axis_permutation = expr .axis_permutation ,
616
618
axes = expr .axes ,
617
619
tags = expr .tags ,
@@ -620,7 +622,7 @@ def map_axis_permutation(self, expr: AxisPermutation,
620
622
def _map_index_base (self ,
621
623
expr : IndexBase , * args : P .args , ** kwargs : P .kwargs ) -> Array :
622
624
assert isinstance (expr , _SuppliedAxesAndTagsMixin )
623
- return type (expr )(self .rec_ary (expr .array , * args , ** kwargs ),
625
+ return type (expr )(_verify_is_array ( self .rec (expr .array , * args , ** kwargs ) ),
624
626
indices = self .rec_idx_or_size_tuple (expr .indices ,
625
627
* args , ** kwargs ),
626
628
axes = expr .axes ,
@@ -660,7 +662,8 @@ def map_size_param(self,
660
662
661
663
def map_einsum (self , expr : Einsum , * args : P .args , ** kwargs : P .kwargs ) -> Array :
662
664
return Einsum (expr .access_descriptors ,
663
- tuple (self .rec_ary (arg , * args , ** kwargs ) for arg in expr .args ),
665
+ tuple (_verify_is_array (
666
+ self .rec (arg , * args , ** kwargs )) for arg in expr .args ),
664
667
axes = expr .axes ,
665
668
redn_axis_to_redn_descr = expr .redn_axis_to_redn_descr ,
666
669
tags = expr .tags ,
@@ -679,7 +682,8 @@ def map_named_array(self,
679
682
def map_dict_of_named_arrays (self ,
680
683
expr : DictOfNamedArrays , * args : P .args , ** kwargs : P .kwargs
681
684
) -> DictOfNamedArrays :
682
- return DictOfNamedArrays ({key : self .rec_ary (val .expr , * args , ** kwargs )
685
+ return DictOfNamedArrays ({key : _verify_is_array (
686
+ self .rec (val .expr , * args , ** kwargs ))
683
687
for key , val in expr .items ()},
684
688
tags = expr .tags ,
685
689
)
@@ -711,7 +715,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult,
711
715
712
716
def map_reshape (self , expr : Reshape ,
713
717
* args : P .args , ** kwargs : P .kwargs ) -> Array :
714
- return Reshape (self .rec_ary (expr .array , * args , ** kwargs ),
718
+ return Reshape (_verify_is_array ( self .rec (expr .array , * args , ** kwargs ) ),
715
719
newshape = self .rec_idx_or_size_tuple (expr .newshape ,
716
720
* args , ** kwargs ),
717
721
order = expr .order ,
@@ -723,10 +727,11 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder,
723
727
* args : P .args , ** kwargs : P .kwargs ) -> Array :
724
728
return DistributedSendRefHolder (
725
729
send = DistributedSend (
726
- data = self .rec_ary (expr .send .data , * args , ** kwargs ),
730
+ data = _verify_is_array ( self .rec (expr .send .data , * args , ** kwargs ) ),
727
731
dest_rank = expr .send .dest_rank ,
728
732
comm_tag = expr .send .comm_tag ),
729
- passthrough_data = self .rec_ary (expr .passthrough_data , * args , ** kwargs ))
733
+ passthrough_data = _verify_is_array (
734
+ self .rec (expr .passthrough_data , * args , ** kwargs )))
730
735
731
736
def map_distributed_recv (self , expr : DistributedRecv ,
732
737
* args : P .args , ** kwargs : P .kwargs ) -> Array :
@@ -1619,7 +1624,7 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays,
1619
1624
if not source_dict :
1620
1625
data = {}
1621
1626
else :
1622
- data = {name : copy_mapper .rec_ary (val .expr )
1627
+ data = {name : _verify_is_array ( copy_mapper .rec (val .expr ) )
1623
1628
for name , val in sorted (source_dict .items ())}
1624
1629
1625
1630
return DictOfNamedArrays (data , tags = source_dict .tags )
0 commit comments