25
25
26
26
# This is experimental, undocumented, and could go away any second.
27
27
# Consider yourself warned.
28
+ from collections .abc import Set
28
29
29
30
from typing import ClassVar
30
31
32
+ from pymbolic .primitives import Expression
31
33
import pymbolic .geometric_algebra .primitives as prim
32
34
from pymbolic .geometric_algebra import MultiVector
33
35
from pymbolic .mapper import (
34
36
CachedMapper ,
37
+ CollectedT ,
35
38
Collector as CollectorBase ,
36
39
CombineMapper as CombineMapperBase ,
37
40
IdentityMapper as IdentityMapperBase ,
41
+ ResultT ,
38
42
WalkMapper as WalkMapperBase ,
43
+ P ,
39
44
)
40
45
from pymbolic .mapper .constant_folder import (
41
46
ConstantFoldingMapper as ConstantFoldingMapperBase ,
48
53
)
49
54
50
55
51
- class IdentityMapper (IdentityMapperBase ):
52
- def map_multivector_variable (self , expr ):
56
+ class IdentityMapper (IdentityMapperBase [P ]):
57
+ def map_nabla (
58
+ self , expr : prim .Nabla , * args : P .args , ** kwargs : P .kwargs ) -> Expression :
53
59
return expr
54
60
55
- map_nabla = map_multivector_variable
56
- map_nabla_component = map_multivector_variable
61
+ def map_nabla_component (self ,
62
+ expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs ) -> Expression :
63
+ return expr
57
64
58
- def map_derivative_source (self , expr ):
59
- operand = self .rec (expr .operand )
65
+ def map_derivative_source (self ,
66
+ expr : prim .DerivativeSource , * args : P .args , ** kwargs : P .kwargs
67
+ ) -> Expression :
68
+ operand = self .rec (expr .operand , * args , ** kwargs )
60
69
if operand is expr .operand :
61
70
return expr
62
71
63
72
return type (expr )(operand , expr .nabla_id )
64
73
65
74
66
- class CombineMapper (CombineMapperBase ):
67
- def map_derivative_source (self , expr ):
68
- return self .rec (expr .operand )
75
+ class CombineMapper (CombineMapperBase [ResultT , P ]):
76
+ def map_derivative_source (
77
+ self , expr : prim .DerivativeSource , * args : P .args , ** kwargs : P .kwargs
78
+ ) -> ResultT :
79
+ return self .rec (expr .operand , * args , ** kwargs )
69
80
70
81
71
- class Collector (CollectorBase ):
72
- def map_nabla (self , expr ):
82
+ class Collector (CollectorBase [CollectedT , P ]):
83
+ def map_nabla (self ,
84
+ expr : prim .Nabla , * args : P .args , ** kwargs : P .kwargs
85
+ ) -> Set [CollectedT ]:
73
86
return set ()
74
87
75
- map_nabla_component = map_nabla
88
+ def map_nabla_component (self ,
89
+ expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs
90
+ ) -> Set [CollectedT ]:
91
+ return set ()
76
92
77
93
78
- class WalkMapper (WalkMapperBase ):
79
- def map_nabla (self , expr , * args ) :
80
- self .visit (expr , * args )
81
- self .post_visit (expr )
94
+ class WalkMapper (WalkMapperBase [ P ] ):
95
+ def map_nabla (self , expr : prim . Nabla , * args : P . args , ** kwargs : P . kwargs ) -> None :
96
+ self .visit (expr , * args , ** kwargs )
97
+ self .post_visit (expr , * args , ** kwargs )
82
98
83
- def map_nabla_component (self , expr , * args ):
84
- self .visit (expr , * args )
85
- self .post_visit (expr )
99
+ def map_nabla_component (
100
+ self , expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs
101
+ ) -> None :
102
+ self .visit (expr , * args , ** kwargs )
103
+ self .post_visit (expr , * args , ** kwargs )
86
104
87
- def map_derivative_source (self , expr , * args ):
88
- if not self .visit (expr , * args ):
105
+ def map_derivative_source (
106
+ self , expr , * args : P .args , ** kwargs : P .kwargs
107
+ ) -> None :
108
+ if not self .visit (expr , * args , ** kwargs ):
89
109
return
90
110
91
- self .rec (expr .operand )
92
- self .post_visit (expr )
111
+ self .rec (expr .operand , * args , ** kwargs )
112
+ self .post_visit (expr , * args , ** kwargs )
93
113
94
114
95
115
class EvaluationMapper (EvaluationMapperBase ):
@@ -106,7 +126,7 @@ def map_derivative_source(self, expr):
106
126
return type (expr )(operand , expr .nabla_id )
107
127
108
128
109
- class StringifyMapper (StringifyMapperBase ):
129
+ class StringifyMapper (StringifyMapperBase [[]] ):
110
130
AXES : ClassVar [dict [int , str ]] = {0 : "x" , 1 : "y" , 2 : "z" }
111
131
112
132
def map_nabla (self , expr , enclosing_prec ):
0 commit comments