57
57
)
58
58
59
59
import pymbolic .primitives as p
60
- from pymbolic .typing import ScalarT
60
+ from pymbolic .typing import Scalar as PbScalar
61
61
62
62
63
63
ExprT : TypeAlias = Expression
64
64
ConstantT = TypeVar ("ConstantT" )
65
- ToMatchpyT = Callable [[p .Expression ], ExprT ]
66
- FromMatchpyT = Callable [[ExprT ], p .Expression ]
65
+ ToMatchpyT = Callable [[p .ExpressionNode ], ExprT ]
66
+ FromMatchpyT = Callable [[ExprT ], p .ExpressionNode ]
67
67
68
68
69
69
_NOT_OPERAND_METADATA = {"not_an_operand" : True }
@@ -95,7 +95,7 @@ def __lt__(self, other):
95
95
96
96
97
97
@op_dataclass
98
- class Scalar (_Constant [ScalarT ]):
98
+ class Scalar (_Constant [PbScalar ]):
99
99
_mapper_method : str = "map_scalar"
100
100
101
101
@@ -360,11 +360,11 @@ def _get_operand_at_path(expr: PymbolicOp, path: tuple[int, ...]) -> PymbolicOp:
360
360
return result
361
361
362
362
363
- def match (subject : p .Expression ,
364
- pattern : p .Expression ,
363
+ def match (subject : p .ExpressionNode ,
364
+ pattern : p .ExpressionNode ,
365
365
to_matchpy_expr : ToMatchpyT | None = None ,
366
366
from_matchpy_expr : FromMatchpyT | None = None
367
- ) -> Iterator [Mapping [str , p .Expression | ScalarT ]]:
367
+ ) -> Iterator [Mapping [str , p .ExpressionNode | PbScalar ]]:
368
368
from matchpy import Pattern , match
369
369
370
370
from .tofrom import FromMatchpyExpressionMapper , ToMatchpyExpressionMapper
@@ -383,12 +383,12 @@ def match(subject: p.Expression,
383
383
for name , expr in subst .items ()}
384
384
385
385
386
- def match_anywhere (subject : p .Expression ,
387
- pattern : p .Expression ,
386
+ def match_anywhere (subject : p .ExpressionNode ,
387
+ pattern : p .ExpressionNode ,
388
388
to_matchpy_expr : ToMatchpyT | None = None ,
389
389
from_matchpy_expr : FromMatchpyT | None = None
390
- ) -> Iterator [tuple [Mapping [str , p .Expression | ScalarT ],
391
- p .Expression | ScalarT ]
390
+ ) -> Iterator [tuple [Mapping [str , p .ExpressionNode | PbScalar ],
391
+ p .ExpressionNode | PbScalar ]
392
392
]:
393
393
from matchpy import Pattern , match_anywhere
394
394
@@ -409,8 +409,8 @@ def match_anywhere(subject: p.Expression,
409
409
from_matchpy_expr (_get_operand_at_path (m_subject , path )))
410
410
411
411
412
- def make_replacement_rule (pattern : p .Expression ,
413
- replacement : Callable [..., p .Expression ],
412
+ def make_replacement_rule (pattern : p .ExpressionNode ,
413
+ replacement : Callable [..., p .ExpressionNode ],
414
414
to_matchpy_expr : ToMatchpyT | None = None ,
415
415
from_matchpy_expr : FromMatchpyT | None = None
416
416
) -> ReplacementRule :
@@ -437,11 +437,11 @@ def make_replacement_rule(pattern: p.Expression,
437
437
from_matchpy_expr ))
438
438
439
439
440
- def replace_all (expression : p .Expression ,
440
+ def replace_all (expression : p .ExpressionNode ,
441
441
rules : Iterable [ReplacementRule ],
442
442
to_matchpy_expr : ToMatchpyT | None = None ,
443
443
from_matchpy_expr : FromMatchpyT | None = None
444
- ) -> p .Expression | tuple [p .Expression , ...]:
444
+ ) -> p .ExpressionNode | tuple [p .ExpressionNode , ...]:
445
445
import collections .abc as abc
446
446
447
447
from matchpy import replace_all
0 commit comments