25
25
26
26
# This is experimental, undocumented, and could go away any second.
27
27
# Consider yourself warned.
28
-
28
+ from collections . abc import Set
29
29
from typing import ClassVar
30
30
31
31
import pymbolic .geometric_algebra .primitives as prim
32
32
from pymbolic .geometric_algebra import MultiVector
33
33
from pymbolic .mapper import (
34
34
CachedMapper ,
35
+ CollectedT ,
35
36
Collector as CollectorBase ,
36
37
CombineMapper as CombineMapperBase ,
37
38
IdentityMapper as IdentityMapperBase ,
39
+ P ,
40
+ ResultT ,
38
41
WalkMapper as WalkMapperBase ,
39
42
)
40
43
from pymbolic .mapper .constant_folder import (
46
49
PREC_NONE ,
47
50
StringifyMapper as StringifyMapperBase ,
48
51
)
52
+ from pymbolic .primitives import Expression
49
53
50
54
51
- class IdentityMapper (IdentityMapperBase ):
52
- def map_multivector_variable (self , expr ):
55
+ class IdentityMapper (IdentityMapperBase [P ]):
56
+ def map_nabla (
57
+ self , expr : prim .Nabla , * args : P .args , ** kwargs : P .kwargs ) -> Expression :
53
58
return expr
54
59
55
- map_nabla = map_multivector_variable
56
- map_nabla_component = map_multivector_variable
60
+ def map_nabla_component (self ,
61
+ expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs ) -> Expression :
62
+ return expr
57
63
58
- def map_derivative_source (self , expr ):
59
- operand = self .rec (expr .operand )
64
+ def map_derivative_source (self ,
65
+ expr : prim .DerivativeSource , * args : P .args , ** kwargs : P .kwargs
66
+ ) -> Expression :
67
+ operand = self .rec (expr .operand , * args , ** kwargs )
60
68
if operand is expr .operand :
61
69
return expr
62
70
63
71
return type (expr )(operand , expr .nabla_id )
64
72
65
73
66
- class CombineMapper (CombineMapperBase ):
67
- def map_derivative_source (self , expr ):
68
- return self .rec (expr .operand )
74
+ class CombineMapper (CombineMapperBase [ResultT , P ]):
75
+ def map_derivative_source (
76
+ self , expr : prim .DerivativeSource , * args : P .args , ** kwargs : P .kwargs
77
+ ) -> ResultT :
78
+ return self .rec (expr .operand , * args , ** kwargs )
69
79
70
80
71
- class Collector (CollectorBase ):
72
- def map_nabla (self , expr ):
81
+ class Collector (CollectorBase [CollectedT , P ]):
82
+ def map_nabla (self ,
83
+ expr : prim .Nabla , * args : P .args , ** kwargs : P .kwargs
84
+ ) -> Set [CollectedT ]:
73
85
return set ()
74
86
75
- map_nabla_component = map_nabla
87
+ def map_nabla_component (self ,
88
+ expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs
89
+ ) -> Set [CollectedT ]:
90
+ return set ()
76
91
77
92
78
- class WalkMapper (WalkMapperBase ):
79
- def map_nabla (self , expr , * args ) :
80
- self .visit (expr , * args )
81
- self .post_visit (expr )
93
+ class WalkMapper (WalkMapperBase [ P ] ):
94
+ def map_nabla (self , expr : prim . Nabla , * args : P . args , ** kwargs : P . kwargs ) -> None :
95
+ self .visit (expr , * args , ** kwargs )
96
+ self .post_visit (expr , * args , ** kwargs )
82
97
83
- def map_nabla_component (self , expr , * args ):
84
- self .visit (expr , * args )
85
- self .post_visit (expr )
98
+ def map_nabla_component (
99
+ self , expr : prim .NablaComponent , * args : P .args , ** kwargs : P .kwargs
100
+ ) -> None :
101
+ self .visit (expr , * args , ** kwargs )
102
+ self .post_visit (expr , * args , ** kwargs )
86
103
87
- def map_derivative_source (self , expr , * args ):
88
- if not self .visit (expr , * args ):
104
+ def map_derivative_source (
105
+ self , expr , * args : P .args , ** kwargs : P .kwargs
106
+ ) -> None :
107
+ if not self .visit (expr , * args , ** kwargs ):
89
108
return
90
109
91
- self .rec (expr .operand )
92
- self .post_visit (expr )
110
+ self .rec (expr .operand , * args , ** kwargs )
111
+ self .post_visit (expr , * args , ** kwargs )
93
112
94
113
95
114
class EvaluationMapper (EvaluationMapperBase ):
@@ -106,7 +125,7 @@ def map_derivative_source(self, expr):
106
125
return type (expr )(operand , expr .nabla_id )
107
126
108
127
109
- class StringifyMapper (StringifyMapperBase ):
128
+ class StringifyMapper (StringifyMapperBase [[]] ):
110
129
AXES : ClassVar [dict [int , str ]] = {0 : "x" , 1 : "y" , 2 : "z" }
111
130
112
131
def map_nabla (self , expr , enclosing_prec ):
0 commit comments