Skip to content

Commit ca4f336

Browse files
committed
add _verify_is_array to avoid the need for rec_ary
the latter inflates recursion depth
1 parent 119da92 commit ca4f336

File tree

5 files changed

+67
-50
lines changed

5 files changed

+67
-50
lines changed

pytato/distributed/partition.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,13 @@
9090
DistributedSendRefHolder,
9191
)
9292
from pytato.scalar_expr import SCALAR_CLASSES
93-
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper
93+
from pytato.transform import (
94+
ArrayOrNames,
95+
CachedWalkMapper,
96+
CombineMapper,
97+
CopyMapper,
98+
_verify_is_array,
99+
)
94100

95101

96102
if TYPE_CHECKING:
@@ -396,7 +402,7 @@ def _make_distributed_partition(
396402

397403
for name, val in name_to_part_output.items():
398404
assert name not in name_to_output
399-
name_to_output[name] = comm_replacer.rec_ary(val)
405+
name_to_output[name] = _verify_is_array(comm_replacer.rec(val))
400406

401407
comm_ids = part_comm_ids[part_id]
402408

pytato/transform/__init__.py

+34-29
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ class UnsupportedArrayError(ValueError):
168168
P = ParamSpec("P")
169169

170170

171+
def _verify_is_array(expr: ArrayOrNames) -> Array:
172+
assert isinstance(expr, Array)
173+
return expr
174+
175+
171176
class Mapper(Generic[ResultT, FunctionResultT, P]):
172177
"""A class that when called with a :class:`pytato.Array` recursively
173178
iterates over the DAG, calling the *_mapper_method* of each node. Users of
@@ -321,10 +326,7 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
321326
implement default mapper methods; for that, see :class:`CopyMapper`.
322327
323328
"""
324-
def rec_ary(self, expr: Array) -> Array:
325-
res = self.rec(expr)
326-
assert isinstance(res, Array)
327-
return res
329+
pass
328330

329331
# }}}
330332

@@ -341,10 +343,7 @@ class TransformMapperWithExtraArgs(
341343
The logic in :class:`TransformMapper` purposely does not take the extra
342344
arguments to keep the cost of its each call frame low.
343345
"""
344-
def rec_ary(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> Array:
345-
res = self.rec(expr, *args, **kwargs)
346-
assert isinstance(res, Array)
347-
return res
346+
pass
348347

349348
# }}}
350349

@@ -390,33 +389,33 @@ def map_placeholder(self, expr: Placeholder) -> Array:
390389
non_equality_tags=expr.non_equality_tags)
391390

392391
def map_stack(self, expr: Stack) -> Array:
393-
arrays = tuple(self.rec_ary(arr) for arr in expr.arrays)
392+
arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays)
394393
return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags,
395394
non_equality_tags=expr.non_equality_tags)
396395

397396
def map_concatenate(self, expr: Concatenate) -> Array:
398-
arrays = tuple(self.rec_ary(arr) for arr in expr.arrays)
397+
arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays)
399398
return Concatenate(arrays=arrays, axis=expr.axis,
400399
axes=expr.axes, tags=expr.tags,
401400
non_equality_tags=expr.non_equality_tags)
402401

403402
def map_roll(self, expr: Roll) -> Array:
404-
return Roll(array=self.rec_ary(expr.array),
403+
return Roll(array=_verify_is_array(self.rec(expr.array)),
405404
shift=expr.shift,
406405
axis=expr.axis,
407406
axes=expr.axes,
408407
tags=expr.tags,
409408
non_equality_tags=expr.non_equality_tags)
410409

411410
def map_axis_permutation(self, expr: AxisPermutation) -> Array:
412-
return AxisPermutation(array=self.rec_ary(expr.array),
411+
return AxisPermutation(array=_verify_is_array(self.rec(expr.array)),
413412
axis_permutation=expr.axis_permutation,
414413
axes=expr.axes,
415414
tags=expr.tags,
416415
non_equality_tags=expr.non_equality_tags)
417416

418417
def _map_index_base(self, expr: IndexBase) -> Array:
419-
return type(expr)(self.rec_ary(expr.array),
418+
return type(expr)(_verify_is_array(self.rec(expr.array)),
420419
indices=self.rec_idx_or_size_tuple(expr.indices),
421420
axes=expr.axes,
422421
tags=expr.tags,
@@ -453,7 +452,7 @@ def map_size_param(self, expr: SizeParam) -> Array:
453452

454453
def map_einsum(self, expr: Einsum) -> Array:
455454
return Einsum(expr.access_descriptors,
456-
tuple(self.rec_ary(arg) for arg in expr.args),
455+
tuple(_verify_is_array(self.rec(arg)) for arg in expr.args),
457456
axes=expr.axes,
458457
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr,
459458
tags=expr.tags,
@@ -470,7 +469,7 @@ def map_named_array(self, expr: NamedArray) -> Array:
470469

471470
def map_dict_of_named_arrays(self,
472471
expr: DictOfNamedArrays) -> DictOfNamedArrays:
473-
return DictOfNamedArrays({key: self.rec_ary(val.expr)
472+
return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr))
474473
for key, val in expr.items()},
475474
tags=expr.tags
476475
)
@@ -498,7 +497,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult) -> Array:
498497
non_equality_tags=expr.non_equality_tags)
499498

500499
def map_reshape(self, expr: Reshape) -> Array:
501-
return Reshape(self.rec_ary(expr.array),
500+
return Reshape(_verify_is_array(self.rec(expr.array)),
502501
newshape=self.rec_idx_or_size_tuple(expr.newshape),
503502
order=expr.order,
504503
axes=expr.axes,
@@ -509,10 +508,10 @@ def map_distributed_send_ref_holder(
509508
self, expr: DistributedSendRefHolder) -> Array:
510509
return DistributedSendRefHolder(
511510
send=DistributedSend(
512-
data=self.rec_ary(expr.send.data),
511+
data=_verify_is_array(self.rec(expr.send.data)),
513512
dest_rank=expr.send.dest_rank,
514513
comm_tag=expr.send.comm_tag),
515-
passthrough_data=self.rec_ary(expr.passthrough_data),
514+
passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)),
516515
)
517516

518517
def map_distributed_recv(self, expr: DistributedRecv) -> Array:
@@ -590,19 +589,21 @@ def map_placeholder(self,
590589
non_equality_tags=expr.non_equality_tags)
591590

592591
def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array:
593-
arrays = tuple(self.rec_ary(arr, *args, **kwargs) for arr in expr.arrays)
592+
arrays = tuple(
593+
_verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays)
594594
return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags,
595595
non_equality_tags=expr.non_equality_tags)
596596

597597
def map_concatenate(self,
598598
expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array:
599-
arrays = tuple(self.rec_ary(arr, *args, **kwargs) for arr in expr.arrays)
599+
arrays = tuple(
600+
_verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays)
600601
return Concatenate(arrays=arrays, axis=expr.axis,
601602
axes=expr.axes, tags=expr.tags,
602603
non_equality_tags=expr.non_equality_tags)
603604

604605
def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array:
605-
return Roll(array=self.rec_ary(expr.array, *args, **kwargs),
606+
return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)),
606607
shift=expr.shift,
607608
axis=expr.axis,
608609
axes=expr.axes,
@@ -611,7 +612,8 @@ def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array:
611612

612613
def map_axis_permutation(self, expr: AxisPermutation,
613614
*args: P.args, **kwargs: P.kwargs) -> Array:
614-
return AxisPermutation(array=self.rec_ary(expr.array, *args, **kwargs),
615+
return AxisPermutation(array=_verify_is_array(
616+
self.rec(expr.array, *args, **kwargs)),
615617
axis_permutation=expr.axis_permutation,
616618
axes=expr.axes,
617619
tags=expr.tags,
@@ -620,7 +622,7 @@ def map_axis_permutation(self, expr: AxisPermutation,
620622
def _map_index_base(self,
621623
expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array:
622624
assert isinstance(expr, _SuppliedAxesAndTagsMixin)
623-
return type(expr)(self.rec_ary(expr.array, *args, **kwargs),
625+
return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)),
624626
indices=self.rec_idx_or_size_tuple(expr.indices,
625627
*args, **kwargs),
626628
axes=expr.axes,
@@ -660,7 +662,8 @@ def map_size_param(self,
660662

661663
def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array:
662664
return Einsum(expr.access_descriptors,
663-
tuple(self.rec_ary(arg, *args, **kwargs) for arg in expr.args),
665+
tuple(_verify_is_array(
666+
self.rec(arg, *args, **kwargs)) for arg in expr.args),
664667
axes=expr.axes,
665668
redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr,
666669
tags=expr.tags,
@@ -679,7 +682,8 @@ def map_named_array(self,
679682
def map_dict_of_named_arrays(self,
680683
expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs
681684
) -> DictOfNamedArrays:
682-
return DictOfNamedArrays({key: self.rec_ary(val.expr, *args, **kwargs)
685+
return DictOfNamedArrays({key: _verify_is_array(
686+
self.rec(val.expr, *args, **kwargs))
683687
for key, val in expr.items()},
684688
tags=expr.tags,
685689
)
@@ -711,7 +715,7 @@ def map_loopy_call_result(self, expr: LoopyCallResult,
711715

712716
def map_reshape(self, expr: Reshape,
713717
*args: P.args, **kwargs: P.kwargs) -> Array:
714-
return Reshape(self.rec_ary(expr.array, *args, **kwargs),
718+
return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)),
715719
newshape=self.rec_idx_or_size_tuple(expr.newshape,
716720
*args, **kwargs),
717721
order=expr.order,
@@ -723,10 +727,11 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder,
723727
*args: P.args, **kwargs: P.kwargs) -> Array:
724728
return DistributedSendRefHolder(
725729
send=DistributedSend(
726-
data=self.rec_ary(expr.send.data, *args, **kwargs),
730+
data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)),
727731
dest_rank=expr.send.dest_rank,
728732
comm_tag=expr.send.comm_tag),
729-
passthrough_data=self.rec_ary(expr.passthrough_data, *args, **kwargs))
733+
passthrough_data=_verify_is_array(
734+
self.rec(expr.passthrough_data, *args, **kwargs)))
730735

731736
def map_distributed_recv(self, expr: DistributedRecv,
732737
*args: P.args, **kwargs: P.kwargs) -> Array:
@@ -1619,7 +1624,7 @@ def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays,
16191624
if not source_dict:
16201625
data = {}
16211626
else:
1622-
data = {name: copy_mapper.rec_ary(val.expr)
1627+
data = {name: _verify_is_array(copy_mapper.rec(val.expr))
16231628
for name, val in sorted(source_dict.items())}
16241629

16251630
return DictOfNamedArrays(data, tags=source_dict.tags)

pytato/transform/calls.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from pytato.function import Call, NamedCallResult
4242
from pytato.tags import InlineCallTag
43-
from pytato.transform import ArrayOrNames, CopyMapper
43+
from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array
4444

4545

4646
if TYPE_CHECKING:
@@ -79,7 +79,7 @@ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
7979
substitutor = PlaceholderSubstitutor(new_expr.bindings)
8080

8181
return DictOfNamedArrays(
82-
{name: substitutor.rec_ary(ret)
82+
{name: _verify_is_array(substitutor.rec(ret))
8383
for name, ret in new_expr.function.returns.items()},
8484
tags=new_expr.tags
8585
)

pytato/transform/einsum_distributive_law.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from pytato.transform import (
6060
MappedT,
6161
TransformMapperWithExtraArgs,
62+
_verify_is_array,
6263
)
6364
from pytato.utils import are_shapes_equal
6465

@@ -186,38 +187,42 @@ def map_index_lambda(self,
186187
and isinstance(hlo.x2, Array)
187188
and are_shapes_equal(hlo.x1.shape, hlo.x2.shape))
188189
# https://github.com/python/mypy/issues/16499
189-
return self.rec_ary(hlo.x1, ctx) + self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return]
190+
return ( # type: ignore[no-any-return]
191+
_verify_is_array(self.rec(hlo.x1, ctx))
192+
+ _verify_is_array(self.rec(hlo.x2, ctx)))
190193
elif hlo.binary_op == BinaryOpType.SUB:
191194
assert (isinstance(hlo.x1, Array)
192195
and isinstance(hlo.x2, Array)
193196
and are_shapes_equal(hlo.x1.shape, hlo.x2.shape))
194197
assert are_shapes_equal(hlo.x1.shape, hlo.x2.shape)
195198
# https://github.com/python/mypy/issues/16499
196-
return self.rec_ary(hlo.x1, ctx) - self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return]
199+
return ( # type: ignore[no-any-return]
200+
_verify_is_array(self.rec(hlo.x1, ctx))
201+
- _verify_is_array(self.rec(hlo.x2, ctx)))
197202
elif hlo.binary_op == BinaryOpType.MULT:
198203
if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2):
199204
# https://github.com/python/mypy/issues/16499
200-
return self.rec_ary(hlo.x1, ctx) * hlo.x2 # type: ignore[no-any-return]
205+
return _verify_is_array(self.rec(hlo.x1, ctx)) * hlo.x2 # type: ignore[no-any-return]
201206
else:
202207
assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1)
203208
# https://github.com/python/mypy/issues/16499
204-
return hlo.x1 * self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return]
209+
return hlo.x1 * _verify_is_array(self.rec(hlo.x2, ctx)) # type: ignore[no-any-return]
205210
elif hlo.binary_op == BinaryOpType.TRUEDIV:
206211
if isinstance(hlo.x1, Array) and np.isscalar(hlo.x2):
207212
# https://github.com/python/mypy/issues/16499
208-
return self.rec_ary(hlo.x1, ctx) / hlo.x2 # type: ignore[no-any-return]
213+
return _verify_is_array(self.rec(hlo.x1, ctx)) / hlo.x2 # type: ignore[no-any-return]
209214
else:
210215
assert isinstance(hlo.x2, Array) and np.isscalar(hlo.x1)
211216
# https://github.com/python/mypy/issues/16499
212-
return hlo.x1 / self.rec_ary(hlo.x2, ctx) # type: ignore[no-any-return]
217+
return hlo.x1 / _verify_is_array(self.rec(hlo.x2, ctx)) # type: ignore[no-any-return]
213218
else:
214219
raise NotImplementedError(hlo)
215220
else:
216221
rec_expr = IndexLambda(
217222
expr=expr.expr,
218223
shape=expr.shape,
219224
dtype=expr.dtype,
220-
bindings=immutabledict({name: self.rec_ary(bnd, None)
225+
bindings=immutabledict({name: _verify_is_array(self.rec(bnd, None))
221226
for name, bnd in sorted(expr.bindings.items())}),
222227
var_to_reduction_descr=expr.var_to_reduction_descr,
223228
tags=expr.tags,
@@ -244,12 +249,13 @@ def map_einsum(self,
244249
tags=expr.tags,
245250
axes=expr.axes,
246251
)
247-
return self.rec_ary(expr.args[distributive_law_descr.ioperand], ctx)
252+
return _verify_is_array(
253+
self.rec(expr.args[distributive_law_descr.ioperand], ctx))
248254
else:
249255
assert isinstance(distributive_law_descr, DoNotDistribute)
250256
rec_expr = Einsum(
251257
expr.access_descriptors,
252-
tuple(self.rec_ary(arg, None) for arg in expr.args),
258+
tuple(_verify_is_array(self.rec(arg, None)) for arg in expr.args),
253259
expr.redn_axis_to_redn_descr,
254260
tags=expr.tags,
255261
axes=expr.axes
@@ -260,7 +266,7 @@ def map_einsum(self,
260266
def map_stack(self,
261267
expr: Stack,
262268
ctx: _EinsumDistributiveLawMapperContext | None) -> Array:
263-
rec_expr = Stack(tuple(self.rec_ary(ary, None)
269+
rec_expr = Stack(tuple(_verify_is_array(self.rec(ary, None))
264270
for ary in expr.arrays),
265271
expr.axis,
266272
tags=expr.tags,
@@ -271,7 +277,7 @@ def map_concatenate(self,
271277
expr: Concatenate,
272278
ctx: _EinsumDistributiveLawMapperContext | None
273279
) -> Array:
274-
rec_expr = Concatenate(tuple(self.rec_ary(ary, None)
280+
rec_expr = Concatenate(tuple(_verify_is_array(self.rec(ary, None))
275281
for ary in expr.arrays),
276282
expr.axis,
277283
tags=expr.tags,
@@ -282,7 +288,7 @@ def map_roll(self,
282288
expr: Roll,
283289
ctx: _EinsumDistributiveLawMapperContext | None
284290
) -> Array:
285-
rec_expr = Roll(self.rec_ary(expr.array, None),
291+
rec_expr = Roll(_verify_is_array(self.rec(expr.array, None)),
286292
expr.shift,
287293
expr.axis,
288294
tags=expr.tags,
@@ -293,7 +299,7 @@ def map_axis_permutation(self,
293299
expr: AxisPermutation,
294300
ctx: _EinsumDistributiveLawMapperContext | None
295301
) -> Array:
296-
rec_expr = AxisPermutation(self.rec_ary(expr.array, None),
302+
rec_expr = AxisPermutation(_verify_is_array(self.rec(expr.array, None)),
297303
expr.axis_permutation,
298304
tags=expr.tags,
299305
axes=expr.axes)
@@ -303,7 +309,7 @@ def _map_index_base(self,
303309
expr: IndexBase,
304310
ctx: _EinsumDistributiveLawMapperContext | None
305311
) -> Array:
306-
rec_expr = type(expr)(self.rec_ary(expr.array, None),
312+
rec_expr = type(expr)(_verify_is_array(self.rec(expr.array, None)),
307313
expr.indices,
308314
tags=expr.tags,
309315
axes=expr.axes)
@@ -317,7 +323,7 @@ def map_reshape(self,
317323
expr: Reshape,
318324
ctx: _EinsumDistributiveLawMapperContext | None
319325
) -> Array:
320-
rec_expr = Reshape(self.rec_ary(expr.array, None),
326+
rec_expr = Reshape(_verify_is_array(self.rec(expr.array, None)),
321327
expr.newshape,
322328
expr.order,
323329
tags=expr.tags,

pytato/transform/remove_broadcasts_einsum.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from typing import cast
3232

3333
from pytato.array import Array, Einsum, EinsumAxisDescriptor
34-
from pytato.transform import CopyMapper, MappedT
34+
from pytato.transform import CopyMapper, MappedT, _verify_is_array
3535
from pytato.utils import are_shape_components_equal
3636

3737

@@ -42,7 +42,7 @@ def map_einsum(self, expr: Einsum) -> Array:
4242
descr_to_axis_len = expr._access_descr_to_axis_len()
4343

4444
for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True):
45-
arg = self.rec_ary(arg)
45+
arg = _verify_is_array(self.rec(arg))
4646
axes_to_squeeze: list[int] = []
4747
for idim, acc_descr in enumerate(acc_descrs):
4848
if not are_shape_components_equal(arg.shape[idim],

0 commit comments

Comments
 (0)