1
+ from __future__ import annotations
2
+
3
+
1
4
"""
2
5
Interoperability with :mod:`matchpy.functions` for pattern-matching and
3
6
term-rewriting.
38
41
"""
39
42
40
43
41
- import numpy as np
42
44
import abc
43
45
import pymbolic .primitives as p
44
46
45
- from typing import (Union , ClassVar , Optional , Iterator , Mapping ,
46
- Generic , TypeVar , Tuple , Iterable , Callable )
47
+ from typing_extensions import TypeAlias
48
+ from typing import (ClassVar , Iterator , Mapping ,
49
+ Generic , TypeVar , Iterable , Callable )
47
50
from dataclasses import dataclass , fields , field
48
51
49
52
from matchpy import (Operation , Arity , Expression , Atom as BaseAtom ,
50
53
Wildcard as BaseWildcard , ReplacementRule )
51
- from numbers import Number
52
54
from functools import partial
55
+ from pymbolic .typing import ScalarT
53
56
54
57
55
- ScalarT = Union [Number , int , np .bool_ , bool ]
56
- ExprT = Expression
58
+ ExprT : TypeAlias = Expression
57
59
ConstantT = TypeVar ("ConstantT" )
58
60
ToMatchpyT = Callable [[p .Expression ], ExprT ]
59
61
FromMatchpyT = Callable [[ExprT ], p .Expression ]
70
72
@op_dataclass
71
73
class _Constant (BaseAtom , Generic [ConstantT ]):
72
74
value : ConstantT
73
- variable_name : Optional [ str ] = None
75
+ variable_name : str | None = None
74
76
75
77
@property
76
78
def head (self ):
@@ -89,7 +91,7 @@ def __lt__(self, other):
89
91
90
92
@op_dataclass
91
93
class Scalar (_Constant [ScalarT ]):
92
- _mapper_method : [ str ] = "map_scalar"
94
+ _mapper_method : str = "map_scalar"
93
95
94
96
95
97
@op_dataclass
@@ -104,8 +106,8 @@ class ComparisonOp(_Constant[str]):
104
106
105
107
@op_dataclass
106
108
class TupleOp (Operation ):
107
- _operands : Tuple [ExprT , ...]
108
- variable_name : Optional [ str ] = non_operand_field (default = None )
109
+ _operands : tuple [ExprT , ...]
110
+ variable_name : str | None = non_operand_field (default = None )
109
111
110
112
arity : ClassVar [Arity ] = Arity .variadic
111
113
name : ClassVar [str ] = "tuple"
@@ -129,7 +131,7 @@ def variable_name(self):
129
131
pass
130
132
131
133
@property
132
- def operands (self ) -> Tuple [Expression ]:
134
+ def operands (self ) -> tuple [Expression , ... ]:
133
135
return tuple (getattr (self , field .name )
134
136
for field in fields (self )
135
137
if not field .metadata .get ("not_an_operand" , False ))
@@ -150,15 +152,15 @@ def name(self) -> str:
150
152
class Variable (PymbolicOp ):
151
153
id : Id
152
154
arity : ClassVar [Arity ] = Arity .unary
153
- variable_name : Optional [ str ] = non_operand_field (default = None )
155
+ variable_name : str | None = non_operand_field (default = None )
154
156
_mapper_method : ClassVar [str ] = "map_variable"
155
157
156
158
157
159
@op_dataclass
158
160
class Call (PymbolicOp ):
159
161
function : ExprT
160
162
args : TupleOp
161
- variable_name : Optional [ str ] = non_operand_field (default = None )
163
+ variable_name : str | None = non_operand_field (default = None )
162
164
163
165
arity : ClassVar [Arity ] = Arity .binary
164
166
_mapper_method : ClassVar [str ] = "map_call"
@@ -168,7 +170,7 @@ class Call(PymbolicOp):
168
170
class Subscript (PymbolicOp ):
169
171
aggregate : ExprT
170
172
indices : TupleOp
171
- variable_name : Optional [ str ] = non_operand_field (default = None )
173
+ variable_name : str | None = non_operand_field (default = None )
172
174
173
175
arity : ClassVar [Arity ] = Arity .binary
174
176
_mapper_method : ClassVar [str ] = "map_subscript"
@@ -182,7 +184,7 @@ class _BinaryOp(PymbolicOp):
182
184
x2 : ExprT
183
185
184
186
arity : ClassVar [Arity ] = Arity .binary
185
- variable_name : Optional [ str ] = non_operand_field (default = None )
187
+ variable_name : str | None = non_operand_field (default = None )
186
188
187
189
188
190
@op_dataclass
@@ -224,8 +226,8 @@ class RightShift(_BinaryOp):
224
226
225
227
@variadic_op_dataclass
226
228
class _VariadicCommAssocOp (PymbolicOp ):
227
- children : Tuple [ExprT , ...]
228
- variable_name : Optional [ str ] = non_operand_field (default = None )
229
+ children : tuple [ExprT , ...]
230
+ variable_name : str | None = non_operand_field (default = None )
229
231
230
232
commutative : ClassVar [bool ] = True
231
233
associative : ClassVar [bool ] = True
@@ -236,7 +238,7 @@ def __init__(self, *children: ExprT, variable_name=None):
236
238
object .__setattr__ (self , "variable_name" , variable_name )
237
239
238
240
@property
239
- def operands (self ) -> Tuple [ExprT , ...]:
241
+ def operands (self ) -> tuple [ExprT , ...]:
240
242
return self .children
241
243
242
244
@@ -283,7 +285,7 @@ class BitwiseXor(_VariadicCommAssocOp):
283
285
class _UnaryOp (PymbolicOp ):
284
286
x : ExprT
285
287
arity : ClassVar [Arity ] = Arity .unary
286
- variable_name : Optional [ str ] = non_operand_field (default = None )
288
+ variable_name : str | None = non_operand_field (default = None )
287
289
288
290
289
291
@op_dataclass
@@ -303,7 +305,7 @@ class Comparison(PymbolicOp):
303
305
left : ExprT
304
306
operator : ComparisonOp
305
307
right : ExprT
306
- variable_name : Optional [ str ] = non_operand_field (default = None )
308
+ variable_name : str | None = non_operand_field (default = None )
307
309
308
310
arity : ClassVar [Arity ] = Arity .ternary
309
311
_mapper_method : ClassVar [str ] = "map_comparison"
@@ -314,7 +316,7 @@ class If(PymbolicOp):
314
316
condition : ExprT
315
317
then : ExprT
316
318
else_ : ExprT
317
- variable_name : Optional [ str ] = non_operand_field (default = None )
319
+ variable_name : str | None = non_operand_field (default = None )
318
320
319
321
arity : ClassVar [Arity ] = Arity .ternary
320
322
_mapper_method : ClassVar [str ] = "map_if"
@@ -325,16 +327,16 @@ class Wildcard(BaseWildcard):
325
327
# {{{ FIXME: This should go into matchpy itself.
326
328
327
329
@classmethod
328
- def dot (cls , name = None ) -> " Wildcard" :
330
+ def dot (cls , name = None ) -> Wildcard :
329
331
return cls (min_count = 1 , fixed_size = True , variable_name = name )
330
332
331
333
@classmethod
332
- def star (cls , name = None ) -> " Wildcard" :
334
+ def star (cls , name = None ) -> Wildcard :
333
335
# FIXME: This should go into matchpy itself.
334
336
return cls (min_count = 0 , fixed_size = False , variable_name = name )
335
337
336
338
@classmethod
337
- def plus (cls , name = None ) -> " Wildcard" :
339
+ def plus (cls , name = None ) -> Wildcard :
338
340
# FIXME: This should go into matchpy itself.
339
341
return cls (min_count = 1 , fixed_size = False , variable_name = name )
340
342
@@ -343,7 +345,7 @@ def plus(cls, name=None) -> "Wildcard":
343
345
# }}}
344
346
345
347
346
- def _get_operand_at_path (expr : PymbolicOp , path : Tuple [int , ...]) -> PymbolicOp :
348
+ def _get_operand_at_path (expr : PymbolicOp , path : tuple [int , ...]) -> PymbolicOp :
347
349
result = expr
348
350
349
351
while path :
@@ -355,9 +357,9 @@ def _get_operand_at_path(expr: PymbolicOp, path: Tuple[int, ...]) -> PymbolicOp:
355
357
356
358
def match (subject : p .Expression ,
357
359
pattern : p .Expression ,
358
- to_matchpy_expr : Optional [ ToMatchpyT ] = None ,
359
- from_matchpy_expr : Optional [ FromMatchpyT ] = None
360
- ) -> Iterator [Mapping [str , Union [ p .Expression , ScalarT ] ]]:
360
+ to_matchpy_expr : ToMatchpyT | None = None ,
361
+ from_matchpy_expr : FromMatchpyT | None = None
362
+ ) -> Iterator [Mapping [str , p .Expression | ScalarT ]]:
361
363
from matchpy import match , Pattern
362
364
from .tofrom import (ToMatchpyExpressionMapper ,
363
365
FromMatchpyExpressionMapper )
@@ -378,10 +380,10 @@ def match(subject: p.Expression,
378
380
379
381
def match_anywhere (subject : p .Expression ,
380
382
pattern : p .Expression ,
381
- to_matchpy_expr : Optional [ ToMatchpyT ] = None ,
382
- from_matchpy_expr : Optional [ FromMatchpyT ] = None
383
- ) -> Iterator [Tuple [Mapping [str , Union [ p .Expression , ScalarT ] ],
384
- Union [ p .Expression , ScalarT ] ]
383
+ to_matchpy_expr : ToMatchpyT | None = None ,
384
+ from_matchpy_expr : FromMatchpyT | None = None
385
+ ) -> Iterator [tuple [Mapping [str , p .Expression | ScalarT ],
386
+ p .Expression | ScalarT ]
385
387
]:
386
388
from matchpy import match_anywhere , Pattern
387
389
from .tofrom import (ToMatchpyExpressionMapper ,
@@ -404,8 +406,8 @@ def match_anywhere(subject: p.Expression,
404
406
405
407
def make_replacement_rule (pattern : p .Expression ,
406
408
replacement : Callable [..., p .Expression ],
407
- to_matchpy_expr : Optional [ ToMatchpyT ] = None ,
408
- from_matchpy_expr : Optional [ FromMatchpyT ] = None
409
+ to_matchpy_expr : ToMatchpyT | None = None ,
410
+ from_matchpy_expr : FromMatchpyT | None = None
409
411
) -> ReplacementRule :
410
412
"""
411
413
Returns a :class:`matchpy.functions.ReplacementRule` from the objects
@@ -429,9 +431,9 @@ def make_replacement_rule(pattern: p.Expression,
429
431
430
432
def replace_all (expression : p .Expression ,
431
433
rules : Iterable [ReplacementRule ],
432
- to_matchpy_expr : Optional [ ToMatchpyT ] = None ,
433
- from_matchpy_expr : Optional [ FromMatchpyT ] = None
434
- ) -> Union [ p .Expression , Tuple [p .Expression , ...] ]:
434
+ to_matchpy_expr : ToMatchpyT | None = None ,
435
+ from_matchpy_expr : FromMatchpyT | None = None
436
+ ) -> p .Expression | tuple [p .Expression , ...]:
435
437
import collections .abc as abc
436
438
from .tofrom import (ToMatchpyExpressionMapper ,
437
439
FromMatchpyExpressionMapper )
0 commit comments