Skip to content

Commit ba5f514

Browse files
committed
More typing of GA mappers
1 parent a365dbf commit ba5f514

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

pymbolic/geometric_algebra/mapper.py

+44-24
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,22 @@
2525

2626
# This is experimental, undocumented, and could go away any second.
2727
# Consider yourself warned.
28+
from collections.abc import Set
2829

2930
from typing import ClassVar
3031

32+
from pymbolic.primitives import Expression
3133
import pymbolic.geometric_algebra.primitives as prim
3234
from pymbolic.geometric_algebra import MultiVector
3335
from pymbolic.mapper import (
3436
CachedMapper,
37+
CollectedT,
3538
Collector as CollectorBase,
3639
CombineMapper as CombineMapperBase,
3740
IdentityMapper as IdentityMapperBase,
41+
ResultT,
3842
WalkMapper as WalkMapperBase,
43+
P,
3944
)
4045
from pymbolic.mapper.constant_folder import (
4146
ConstantFoldingMapper as ConstantFoldingMapperBase,
@@ -48,48 +53,63 @@
4853
)
4954

5055

51-
class IdentityMapper(IdentityMapperBase):
52-
def map_multivector_variable(self, expr):
56+
class IdentityMapper(IdentityMapperBase[P]):
57+
def map_nabla(
58+
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression:
5359
return expr
5460

55-
map_nabla = map_multivector_variable
56-
map_nabla_component = map_multivector_variable
61+
def map_nabla_component(self,
62+
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression:
63+
return expr
5764

58-
def map_derivative_source(self, expr):
59-
operand = self.rec(expr.operand)
65+
def map_derivative_source(self,
66+
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
67+
) -> Expression:
68+
operand = self.rec(expr.operand, *args, **kwargs)
6069
if operand is expr.operand:
6170
return expr
6271

6372
return type(expr)(operand, expr.nabla_id)
6473

6574

66-
class CombineMapper(CombineMapperBase):
67-
def map_derivative_source(self, expr):
68-
return self.rec(expr.operand)
75+
class CombineMapper(CombineMapperBase[ResultT, P]):
76+
def map_derivative_source(
77+
self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
78+
) -> ResultT:
79+
return self.rec(expr.operand, *args, **kwargs)
6980

7081

71-
class Collector(CollectorBase):
72-
def map_nabla(self, expr):
82+
class Collector(CollectorBase[CollectedT, P]):
83+
def map_nabla(self,
84+
expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
85+
) -> Set[CollectedT]:
7386
return set()
7487

75-
map_nabla_component = map_nabla
88+
def map_nabla_component(self,
89+
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
90+
) -> Set[CollectedT]:
91+
return set()
7692

7793

78-
class WalkMapper(WalkMapperBase):
79-
def map_nabla(self, expr, *args):
80-
self.visit(expr, *args)
81-
self.post_visit(expr)
94+
class WalkMapper(WalkMapperBase[P]):
95+
def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
96+
self.visit(expr, *args, **kwargs)
97+
self.post_visit(expr, *args, **kwargs)
8298

83-
def map_nabla_component(self, expr, *args):
84-
self.visit(expr, *args)
85-
self.post_visit(expr)
99+
def map_nabla_component(
100+
self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
101+
) -> None:
102+
self.visit(expr, *args, **kwargs)
103+
self.post_visit(expr, *args, **kwargs)
86104

87-
def map_derivative_source(self, expr, *args):
88-
if not self.visit(expr, *args):
105+
def map_derivative_source(
106+
self, expr, *args: P.args, **kwargs: P.kwargs
107+
) -> None:
108+
if not self.visit(expr, *args, **kwargs):
89109
return
90110

91-
self.rec(expr.operand)
92-
self.post_visit(expr)
111+
self.rec(expr.operand, *args, **kwargs)
112+
self.post_visit(expr, *args, **kwargs)
93113

94114

95115
class EvaluationMapper(EvaluationMapperBase):
@@ -106,7 +126,7 @@ def map_derivative_source(self, expr):
106126
return type(expr)(operand, expr.nabla_id)
107127

108128

109-
class StringifyMapper(StringifyMapperBase):
129+
class StringifyMapper(StringifyMapperBase[[]]):
110130
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}
111131

112132
def map_nabla(self, expr, enclosing_prec):

0 commit comments

Comments
 (0)