Skip to content

Commit a1e3054

Browse files
committed
Better type the EvaluationMapper
1 parent da6dded commit a1e3054

File tree

1 file changed

+72
-64
lines changed

1 file changed

+72
-64
lines changed

pymbolic/mapper/evaluator.py

+72-64
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
import operator as op
3737
from collections.abc import Mapping
3838
from functools import reduce
39-
from typing import TYPE_CHECKING, Any
39+
from typing import TYPE_CHECKING, cast
4040

4141
import pymbolic.primitives as p
42-
from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper
42+
from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper, ResultT
4343
from pymbolic.typing import ExpressionT
4444

4545

@@ -53,7 +53,7 @@ class UnknownVariableError(Exception):
5353
pass
5454

5555

56-
class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin):
56+
class EvaluationMapper(Mapper[ResultT, []], CSECachingMapperMixin):
5757
"""Example usage:
5858
5959
.. doctest::
@@ -70,9 +70,9 @@ class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin):
7070
110
7171
"""
7272

73-
context: Mapping[str, Any]
73+
context: Mapping[str, ResultT]
7474

75-
def __init__(self, context: Mapping[str, Any] | None = None) -> None:
75+
def __init__(self, context: Mapping[str, ResultT] | None = None) -> None:
7676
"""
7777
:arg context: a mapping from variable names to values
7878
"""
@@ -81,147 +81,151 @@ def __init__(self, context: Mapping[str, Any] | None = None) -> None:
8181

8282
self.context = context
8383

84-
def map_constant(self, expr: object) -> Any:
85-
return expr
84+
def map_constant(self, expr: object) -> ResultT:
85+
return cast(ResultT, expr)
8686

87-
def map_variable(self, expr: p.Variable) -> None:
87+
def map_variable(self, expr: p.Variable) -> ResultT:
8888
try:
8989
return self.context[expr.name]
9090
except KeyError:
9191
raise UnknownVariableError(expr.name) from None
9292

93-
def map_call(self, expr: p.Call) -> Any:
94-
return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters])
93+
def map_call(self, expr: p.Call) -> ResultT:
94+
return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters]) # type: ignore[operator]
9595

96-
def map_call_with_kwargs(self, expr: p.CallWithKwargs) -> Any:
96+
def map_call_with_kwargs(self, expr: p.CallWithKwargs) -> ResultT:
9797
args = [self.rec(par) for par in expr.parameters]
9898
kwargs = {
9999
k: self.rec(v)
100100
for k, v in expr.kw_parameters.items()}
101101

102-
return self.rec(expr.function)(*args, **kwargs)
102+
return self.rec(expr.function)(*args, **kwargs) # type: ignore[operator]
103103

104-
def map_subscript(self, expr: p.Subscript) -> Any:
105-
return self.rec(expr.aggregate)[self.rec(expr.index)]
104+
def map_subscript(self, expr: p.Subscript) -> ResultT:
105+
return self.rec(expr.aggregate)[self.rec(expr.index)] # type: ignore[index]
106106

107-
def map_lookup(self, expr: p.Lookup) -> Any:
107+
def map_lookup(self, expr: p.Lookup) -> ResultT:
108108
return getattr(self.rec(expr.aggregate), expr.name)
109109

110-
def map_sum(self, expr: p.Sum) -> Any:
111-
return sum(self.rec(child) for child in expr.children)
110+
def map_sum(self, expr: p.Sum) -> ResultT:
111+
return sum(self.rec(child) for child in expr.children) # type: ignore[return-value, misc]
112112

113-
def map_product(self, expr: p.Product) -> Any:
113+
def map_product(self, expr: p.Product) -> ResultT:
114114
from pytools import product
115115
return product(self.rec(child) for child in expr.children)
116116

117-
def map_quotient(self, expr: p.Quotient) -> Any:
118-
return self.rec(expr.numerator) / self.rec(expr.denominator)
117+
def map_quotient(self, expr: p.Quotient) -> ResultT:
118+
return self.rec(expr.numerator) / self.rec(expr.denominator) # type: ignore[operator]
119119

120-
def map_floor_div(self, expr: p.FloorDiv) -> Any:
121-
return self.rec(expr.numerator) // self.rec(expr.denominator)
120+
def map_floor_div(self, expr: p.FloorDiv) -> ResultT:
121+
return self.rec(expr.numerator) // self.rec(expr.denominator) # type: ignore[operator]
122122

123-
def map_remainder(self, expr: p.Remainder) -> Any:
124-
return self.rec(expr.numerator) % self.rec(expr.denominator)
123+
def map_remainder(self, expr: p.Remainder) -> ResultT:
124+
return self.rec(expr.numerator) % self.rec(expr.denominator) # type: ignore[operator]
125125

126-
def map_power(self, expr: p.Power) -> Any:
127-
return self.rec(expr.base) ** self.rec(expr.exponent)
126+
def map_power(self, expr: p.Power) -> ResultT:
127+
return self.rec(expr.base) ** self.rec(expr.exponent) # type: ignore[operator]
128128

129-
def map_left_shift(self, expr: p.LeftShift) -> Any:
130-
return self.rec(expr.shiftee) << self.rec(expr.shift)
129+
def map_left_shift(self, expr: p.LeftShift) -> ResultT:
130+
return self.rec(expr.shiftee) << self.rec(expr.shift) # type: ignore[operator]
131131

132-
def map_right_shift(self, expr: p.RightShift) -> Any:
133-
return self.rec(expr.shiftee) >> self.rec(expr.shift)
132+
def map_right_shift(self, expr: p.RightShift) -> ResultT:
133+
return self.rec(expr.shiftee) >> self.rec(expr.shift) # type: ignore[operator]
134134

135-
def map_bitwise_not(self, expr: p.BitwiseNot) -> Any:
135+
def map_bitwise_not(self, expr: p.BitwiseNot) -> ResultT:
136136
# ??? Why, pylint, why ???
137137
# pylint: disable=invalid-unary-operand-type
138-
return ~self.rec(expr.child)
138+
return ~self.rec(expr.child) # type: ignore[operator]
139139

140-
def map_bitwise_or(self, expr: p.BitwiseOr) -> Any:
140+
def map_bitwise_or(self, expr: p.BitwiseOr) -> ResultT:
141141
return reduce(op.or_, (self.rec(ch) for ch in expr.children))
142142

143-
def map_bitwise_xor(self, expr: p.BitwiseXor) -> Any:
143+
def map_bitwise_xor(self, expr: p.BitwiseXor) -> ResultT:
144144
return reduce(op.xor, (self.rec(ch) for ch in expr.children))
145145

146-
def map_bitwise_and(self, expr: p.BitwiseAnd) -> Any:
146+
def map_bitwise_and(self, expr: p.BitwiseAnd) -> ResultT:
147147
return reduce(op.and_, (self.rec(ch) for ch in expr.children))
148148

149-
def map_logical_not(self, expr: p.LogicalNot) -> Any:
149+
def map_logical_not(self, expr: p.LogicalNot) -> bool: # type: ignore[override]
150150
return not self.rec(expr.child)
151151

152-
def map_logical_or(self, expr: p.LogicalOr) -> Any:
152+
def map_logical_or(self, expr: p.LogicalOr) -> bool: # type: ignore[override]
153153
return any(self.rec(ch) for ch in expr.children)
154154

155-
def map_logical_and(self, expr: p.LogicalAnd) -> Any:
155+
def map_logical_and(self, expr: p.LogicalAnd) -> bool: # type: ignore[override]
156156
return all(self.rec(ch) for ch in expr.children)
157157

158-
def map_list(self, expr: list[ExpressionT]) -> Any:
159-
return [self.rec(child) for child in expr]
158+
def map_list(self, expr: list[ExpressionT]) -> ResultT:
159+
return [self.rec(child) for child in expr] # type: ignore[return-value]
160160

161-
def map_numpy_array(self, expr: np.ndarray) -> Any:
161+
def map_numpy_array(self, expr: np.ndarray) -> ResultT:
162162
import numpy
163163
result = numpy.empty(expr.shape, dtype=object)
164164
for i in numpy.ndindex(expr.shape):
165165
result[i] = self.rec(expr[i])
166-
return result
166+
return result # type: ignore[return-value]
167167

168-
def map_multivector(self, expr: MultiVector) -> Any:
168+
def map_multivector(self, expr: MultiVector) -> ResultT:
169169
return expr.map(lambda ch: self.rec(ch))
170170

171-
def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> Any:
171+
def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> ResultT:
172172
return self.rec(expr.child)
173173

174-
def map_if(self, expr: p.If) -> Any:
174+
def map_if(self, expr: p.If) -> ResultT:
175175
if self.rec(expr.condition):
176176
return self.rec(expr.then)
177177
else:
178178
return self.rec(expr.else_)
179179

180-
def map_comparison(self, expr: p.Comparison) -> Any:
180+
def map_comparison(self, expr: p.Comparison) -> ResultT:
181181
import operator
182182
return getattr(operator, expr.operator_to_name[expr.operator])(
183183
self.rec(expr.left), self.rec(expr.right))
184184

185-
def map_min(self, expr: p.Min) -> Any:
186-
return min(self.rec(child) for child in expr.children)
185+
def map_min(self, expr: p.Min) -> ResultT:
186+
return min(self.rec(child) for child in expr.children) # type: ignore[type-var]
187187

188-
def map_max(self, expr: p.Max) -> Any:
189-
return max(self.rec(child) for child in expr.children)
188+
def map_max(self, expr: p.Max) -> ResultT:
189+
return max(self.rec(child) for child in expr.children) # type: ignore[type-var]
190190

191-
def map_tuple(self, expr: tuple[ExpressionT, ...]) -> Any:
192-
return tuple([self.rec(child) for child in expr])
191+
def map_tuple(self, expr: tuple[ExpressionT, ...]) -> ResultT:
192+
return tuple([self.rec(child) for child in expr]) # type: ignore[return-value]
193193

194-
def map_nan(self, expr: p.NaN) -> Any:
194+
def map_nan(self, expr: p.NaN) -> ResultT:
195195
if expr.data_type is None:
196196
from math import nan
197-
return nan
197+
return nan # type:ignore[return-value]
198198
else:
199199
return expr.data_type(float("nan"))
200200

201201

202-
class CachedEvaluationMapper(CachedMapper, EvaluationMapper):
202+
class CachedEvaluationMapper(CachedMapper[ResultT, []], EvaluationMapper[ResultT]):
203203
def __init__(self, context=None):
204204
CachedMapper.__init__(self)
205205
EvaluationMapper.__init__(self, context=context)
206206

207207

208-
class FloatEvaluationMapper(EvaluationMapper):
209-
def map_constant(self, expr):
208+
class FloatEvaluationMapper(EvaluationMapper[float]):
209+
def map_constant(self, expr) -> float:
210210
return float(expr)
211211

212-
def map_rational(self, expr):
212+
def map_rational(self, expr) -> float:
213213
return self.rec(expr.numerator) / self.rec(expr.denominator)
214214

215215

216-
class CachedFloatEvaluationMapper(CachedEvaluationMapper):
217-
def map_constant(self, expr):
216+
class CachedFloatEvaluationMapper(CachedEvaluationMapper[float]):
217+
def map_constant(self, expr) -> float:
218218
return float(expr)
219219

220-
def map_rational(self, expr):
220+
def map_rational(self, expr) -> float:
221221
return self.rec(expr.numerator) / self.rec(expr.denominator)
222222

223223

224-
def evaluate(expression, context=None, mapper_cls=CachedEvaluationMapper) -> Any:
224+
def evaluate(
225+
expression: ExpressionT,
226+
context: Mapping[str, ResultT] | None = None,
227+
mapper_cls: type[EvaluationMapper[ResultT]] = CachedEvaluationMapper,
228+
) -> ResultT:
225229
"""
226230
:arg mapper_cls: A :class:`type` of the evaluation mapper
227231
whose instance performs the evaluation.
@@ -231,7 +235,11 @@ def evaluate(expression, context=None, mapper_cls=CachedEvaluationMapper) -> Any
231235
return mapper_cls(context)(expression)
232236

233237

234-
def evaluate_kw(expression, mapper_cls=CachedEvaluationMapper, **context) -> Any:
238+
def evaluate_kw(
239+
expression: ExpressionT,
240+
mapper_cls: type[EvaluationMapper[ResultT]] = CachedEvaluationMapper,
241+
**context: ResultT,
242+
) -> ResultT:
235243
"""
236244
:arg mapper_cls: A :class:`type` of the evaluation mapper
237245
whose instance performs the evaluation.
@@ -240,7 +248,7 @@ def evaluate_kw(expression, mapper_cls=CachedEvaluationMapper, **context) -> Any
240248

241249

242250
def evaluate_to_float(expression, context=None,
243-
mapper_cls=CachedFloatEvaluationMapper) -> Any:
251+
mapper_cls=CachedFloatEvaluationMapper) -> float:
244252
"""
245253
:arg mapper_cls: A :class:`type` of the evaluation mapper
246254
whose instance performs the evaluation.

0 commit comments

Comments
 (0)