Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b294edf

Browse files
committedNov 13, 2024·
More typing of GA mappers
1 parent a365dbf commit b294edf

File tree

1 file changed

+44
-25
lines changed

1 file changed

+44
-25
lines changed
 

‎pymbolic/geometric_algebra/mapper.py

+44-25
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,19 @@
2525

2626
# This is experimental, undocumented, and could go away any second.
2727
# Consider yourself warned.
28-
28+
from collections.abc import Set
2929
from typing import ClassVar
3030

3131
import pymbolic.geometric_algebra.primitives as prim
3232
from pymbolic.geometric_algebra import MultiVector
3333
from pymbolic.mapper import (
3434
CachedMapper,
35+
CollectedT,
3536
Collector as CollectorBase,
3637
CombineMapper as CombineMapperBase,
3738
IdentityMapper as IdentityMapperBase,
39+
P,
40+
ResultT,
3841
WalkMapper as WalkMapperBase,
3942
)
4043
from pymbolic.mapper.constant_folder import (
@@ -46,50 +49,66 @@
4649
PREC_NONE,
4750
StringifyMapper as StringifyMapperBase,
4851
)
52+
from pymbolic.primitives import Expression
4953

5054

51-
class IdentityMapper(IdentityMapperBase):
52-
def map_multivector_variable(self, expr):
55+
class IdentityMapper(IdentityMapperBase[P]):
56+
def map_nabla(
57+
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression:
5358
return expr
5459

55-
map_nabla = map_multivector_variable
56-
map_nabla_component = map_multivector_variable
60+
def map_nabla_component(self,
61+
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression:
62+
return expr
5763

58-
def map_derivative_source(self, expr):
59-
operand = self.rec(expr.operand)
64+
def map_derivative_source(self,
65+
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
66+
) -> Expression:
67+
operand = self.rec(expr.operand, *args, **kwargs)
6068
if operand is expr.operand:
6169
return expr
6270

6371
return type(expr)(operand, expr.nabla_id)
6472

6573

66-
class CombineMapper(CombineMapperBase):
67-
def map_derivative_source(self, expr):
68-
return self.rec(expr.operand)
74+
class CombineMapper(CombineMapperBase[ResultT, P]):
75+
def map_derivative_source(
76+
self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
77+
) -> ResultT:
78+
return self.rec(expr.operand, *args, **kwargs)
6979

7080

71-
class Collector(CollectorBase):
72-
def map_nabla(self, expr):
81+
class Collector(CollectorBase[CollectedT, P]):
82+
def map_nabla(self,
83+
expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
84+
) -> Set[CollectedT]:
7385
return set()
7486

75-
map_nabla_component = map_nabla
87+
def map_nabla_component(self,
88+
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
89+
) -> Set[CollectedT]:
90+
return set()
7691

7792

78-
class WalkMapper(WalkMapperBase):
79-
def map_nabla(self, expr, *args):
80-
self.visit(expr, *args)
81-
self.post_visit(expr)
93+
class WalkMapper(WalkMapperBase[P]):
94+
def map_nabla(self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> None:
95+
self.visit(expr, *args, **kwargs)
96+
self.post_visit(expr, *args, **kwargs)
8297

83-
def map_nabla_component(self, expr, *args):
84-
self.visit(expr, *args)
85-
self.post_visit(expr)
98+
def map_nabla_component(
99+
self, expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
100+
) -> None:
101+
self.visit(expr, *args, **kwargs)
102+
self.post_visit(expr, *args, **kwargs)
86103

87-
def map_derivative_source(self, expr, *args):
88-
if not self.visit(expr, *args):
104+
def map_derivative_source(
105+
self, expr, *args: P.args, **kwargs: P.kwargs
106+
) -> None:
107+
if not self.visit(expr, *args, **kwargs):
89108
return
90109

91-
self.rec(expr.operand)
92-
self.post_visit(expr)
110+
self.rec(expr.operand, *args, **kwargs)
111+
self.post_visit(expr, *args, **kwargs)
93112

94113

95114
class EvaluationMapper(EvaluationMapperBase):
@@ -106,7 +125,7 @@ def map_derivative_source(self, expr):
106125
return type(expr)(operand, expr.nabla_id)
107126

108127

109-
class StringifyMapper(StringifyMapperBase):
128+
class StringifyMapper(StringifyMapperBase[[]]):
110129
AXES: ClassVar[dict[int, str]] = {0: "x", 1: "y", 2: "z"}
111130

112131
def map_nabla(self, expr, enclosing_prec):

0 commit comments

Comments
 (0)
Please sign in to comment.