36
36
import operator as op
37
37
from collections .abc import Mapping
38
38
from functools import reduce
39
- from typing import TYPE_CHECKING , Any
39
+ from typing import TYPE_CHECKING , cast
40
40
41
41
import pymbolic .primitives as p
42
- from pymbolic .mapper import CachedMapper , CSECachingMapperMixin , Mapper
42
+ from pymbolic .mapper import CachedMapper , CSECachingMapperMixin , Mapper , ResultT
43
43
from pymbolic .typing import ExpressionT
44
44
45
45
@@ -53,7 +53,7 @@ class UnknownVariableError(Exception):
53
53
pass
54
54
55
55
56
- class EvaluationMapper (Mapper [Any , []], CSECachingMapperMixin ):
56
+ class EvaluationMapper (Mapper [ResultT , []], CSECachingMapperMixin ):
57
57
"""Example usage:
58
58
59
59
.. doctest::
@@ -70,9 +70,9 @@ class EvaluationMapper(Mapper[Any, []], CSECachingMapperMixin):
70
70
110
71
71
"""
72
72
73
- context : Mapping [str , Any ]
73
+ context : Mapping [str , ResultT ]
74
74
75
- def __init__ (self , context : Mapping [str , Any ] | None = None ) -> None :
75
+ def __init__ (self , context : Mapping [str , ResultT ] | None = None ) -> None :
76
76
"""
77
77
:arg context: a mapping from variable names to values
78
78
"""
@@ -81,147 +81,151 @@ def __init__(self, context: Mapping[str, Any] | None = None) -> None:
81
81
82
82
self .context = context
83
83
84
- def map_constant (self , expr : object ) -> Any :
85
- return expr
84
+ def map_constant (self , expr : object ) -> ResultT :
85
+ return cast ( ResultT , expr )
86
86
87
- def map_variable (self , expr : p .Variable ) -> None :
87
+ def map_variable (self , expr : p .Variable ) -> ResultT :
88
88
try :
89
89
return self .context [expr .name ]
90
90
except KeyError :
91
91
raise UnknownVariableError (expr .name ) from None
92
92
93
- def map_call (self , expr : p .Call ) -> Any :
94
- return self .rec (expr .function )(* [self .rec (par ) for par in expr .parameters ])
93
+ def map_call (self , expr : p .Call ) -> ResultT :
94
+ return self .rec (expr .function )(* [self .rec (par ) for par in expr .parameters ]) # type: ignore[operator]
95
95
96
- def map_call_with_kwargs (self , expr : p .CallWithKwargs ) -> Any :
96
+ def map_call_with_kwargs (self , expr : p .CallWithKwargs ) -> ResultT :
97
97
args = [self .rec (par ) for par in expr .parameters ]
98
98
kwargs = {
99
99
k : self .rec (v )
100
100
for k , v in expr .kw_parameters .items ()}
101
101
102
- return self .rec (expr .function )(* args , ** kwargs )
102
+ return self .rec (expr .function )(* args , ** kwargs ) # type: ignore[operator]
103
103
104
- def map_subscript (self , expr : p .Subscript ) -> Any :
105
- return self .rec (expr .aggregate )[self .rec (expr .index )]
104
+ def map_subscript (self , expr : p .Subscript ) -> ResultT :
105
+ return self .rec (expr .aggregate )[self .rec (expr .index )] # type: ignore[index]
106
106
107
- def map_lookup (self , expr : p .Lookup ) -> Any :
107
+ def map_lookup (self , expr : p .Lookup ) -> ResultT :
108
108
return getattr (self .rec (expr .aggregate ), expr .name )
109
109
110
- def map_sum (self , expr : p .Sum ) -> Any :
111
- return sum (self .rec (child ) for child in expr .children )
110
+ def map_sum (self , expr : p .Sum ) -> ResultT :
111
+ return sum (self .rec (child ) for child in expr .children ) # type: ignore[return-value, misc]
112
112
113
- def map_product (self , expr : p .Product ) -> Any :
113
+ def map_product (self , expr : p .Product ) -> ResultT :
114
114
from pytools import product
115
115
return product (self .rec (child ) for child in expr .children )
116
116
117
- def map_quotient (self , expr : p .Quotient ) -> Any :
118
- return self .rec (expr .numerator ) / self .rec (expr .denominator )
117
+ def map_quotient (self , expr : p .Quotient ) -> ResultT :
118
+ return self .rec (expr .numerator ) / self .rec (expr .denominator ) # type: ignore[operator]
119
119
120
- def map_floor_div (self , expr : p .FloorDiv ) -> Any :
121
- return self .rec (expr .numerator ) // self .rec (expr .denominator )
120
+ def map_floor_div (self , expr : p .FloorDiv ) -> ResultT :
121
+ return self .rec (expr .numerator ) // self .rec (expr .denominator ) # type: ignore[operator]
122
122
123
- def map_remainder (self , expr : p .Remainder ) -> Any :
124
- return self .rec (expr .numerator ) % self .rec (expr .denominator )
123
+ def map_remainder (self , expr : p .Remainder ) -> ResultT :
124
+ return self .rec (expr .numerator ) % self .rec (expr .denominator ) # type: ignore[operator]
125
125
126
- def map_power (self , expr : p .Power ) -> Any :
127
- return self .rec (expr .base ) ** self .rec (expr .exponent )
126
+ def map_power (self , expr : p .Power ) -> ResultT :
127
+ return self .rec (expr .base ) ** self .rec (expr .exponent ) # type: ignore[operator]
128
128
129
- def map_left_shift (self , expr : p .LeftShift ) -> Any :
130
- return self .rec (expr .shiftee ) << self .rec (expr .shift )
129
+ def map_left_shift (self , expr : p .LeftShift ) -> ResultT :
130
+ return self .rec (expr .shiftee ) << self .rec (expr .shift ) # type: ignore[operator]
131
131
132
- def map_right_shift (self , expr : p .RightShift ) -> Any :
133
- return self .rec (expr .shiftee ) >> self .rec (expr .shift )
132
+ def map_right_shift (self , expr : p .RightShift ) -> ResultT :
133
+ return self .rec (expr .shiftee ) >> self .rec (expr .shift ) # type: ignore[operator]
134
134
135
- def map_bitwise_not (self , expr : p .BitwiseNot ) -> Any :
135
+ def map_bitwise_not (self , expr : p .BitwiseNot ) -> ResultT :
136
136
# ??? Why, pylint, why ???
137
137
# pylint: disable=invalid-unary-operand-type
138
- return ~ self .rec (expr .child )
138
+ return ~ self .rec (expr .child ) # type: ignore[operator]
139
139
140
- def map_bitwise_or (self , expr : p .BitwiseOr ) -> Any :
140
+ def map_bitwise_or (self , expr : p .BitwiseOr ) -> ResultT :
141
141
return reduce (op .or_ , (self .rec (ch ) for ch in expr .children ))
142
142
143
- def map_bitwise_xor (self , expr : p .BitwiseXor ) -> Any :
143
+ def map_bitwise_xor (self , expr : p .BitwiseXor ) -> ResultT :
144
144
return reduce (op .xor , (self .rec (ch ) for ch in expr .children ))
145
145
146
- def map_bitwise_and (self , expr : p .BitwiseAnd ) -> Any :
146
+ def map_bitwise_and (self , expr : p .BitwiseAnd ) -> ResultT :
147
147
return reduce (op .and_ , (self .rec (ch ) for ch in expr .children ))
148
148
149
- def map_logical_not (self , expr : p .LogicalNot ) -> Any :
149
+ def map_logical_not (self , expr : p .LogicalNot ) -> bool : # type: ignore[override]
150
150
return not self .rec (expr .child )
151
151
152
- def map_logical_or (self , expr : p .LogicalOr ) -> Any :
152
+ def map_logical_or (self , expr : p .LogicalOr ) -> bool : # type: ignore[override]
153
153
return any (self .rec (ch ) for ch in expr .children )
154
154
155
- def map_logical_and (self , expr : p .LogicalAnd ) -> Any :
155
+ def map_logical_and (self , expr : p .LogicalAnd ) -> bool : # type: ignore[override]
156
156
return all (self .rec (ch ) for ch in expr .children )
157
157
158
- def map_list (self , expr : list [ExpressionT ]) -> Any :
159
- return [self .rec (child ) for child in expr ]
158
+ def map_list (self , expr : list [ExpressionT ]) -> ResultT :
159
+ return [self .rec (child ) for child in expr ] # type: ignore[return-value]
160
160
161
- def map_numpy_array (self , expr : np .ndarray ) -> Any :
161
+ def map_numpy_array (self , expr : np .ndarray ) -> ResultT :
162
162
import numpy
163
163
result = numpy .empty (expr .shape , dtype = object )
164
164
for i in numpy .ndindex (expr .shape ):
165
165
result [i ] = self .rec (expr [i ])
166
- return result
166
+ return result # type: ignore[return-value]
167
167
168
- def map_multivector (self , expr : MultiVector ) -> Any :
168
+ def map_multivector (self , expr : MultiVector ) -> ResultT :
169
169
return expr .map (lambda ch : self .rec (ch ))
170
170
171
- def map_common_subexpression_uncached (self , expr : p .CommonSubexpression ) -> Any :
171
+ def map_common_subexpression_uncached (self , expr : p .CommonSubexpression ) -> ResultT :
172
172
return self .rec (expr .child )
173
173
174
- def map_if (self , expr : p .If ) -> Any :
174
+ def map_if (self , expr : p .If ) -> ResultT :
175
175
if self .rec (expr .condition ):
176
176
return self .rec (expr .then )
177
177
else :
178
178
return self .rec (expr .else_ )
179
179
180
- def map_comparison (self , expr : p .Comparison ) -> Any :
180
+ def map_comparison (self , expr : p .Comparison ) -> ResultT :
181
181
import operator
182
182
return getattr (operator , expr .operator_to_name [expr .operator ])(
183
183
self .rec (expr .left ), self .rec (expr .right ))
184
184
185
- def map_min (self , expr : p .Min ) -> Any :
186
- return min (self .rec (child ) for child in expr .children )
185
+ def map_min (self , expr : p .Min ) -> ResultT :
186
+ return min (self .rec (child ) for child in expr .children ) # type: ignore[type-var]
187
187
188
- def map_max (self , expr : p .Max ) -> Any :
189
- return max (self .rec (child ) for child in expr .children )
188
+ def map_max (self , expr : p .Max ) -> ResultT :
189
+ return max (self .rec (child ) for child in expr .children ) # type: ignore[type-var]
190
190
191
- def map_tuple (self , expr : tuple [ExpressionT , ...]) -> Any :
192
- return tuple ([self .rec (child ) for child in expr ])
191
+ def map_tuple (self , expr : tuple [ExpressionT , ...]) -> ResultT :
192
+ return tuple ([self .rec (child ) for child in expr ]) # type: ignore[return-value]
193
193
194
- def map_nan (self , expr : p .NaN ) -> Any :
194
+ def map_nan (self , expr : p .NaN ) -> ResultT :
195
195
if expr .data_type is None :
196
196
from math import nan
197
- return nan
197
+ return nan # type:ignore[return-value]
198
198
else :
199
199
return expr .data_type (float ("nan" ))
200
200
201
201
202
- class CachedEvaluationMapper (CachedMapper , EvaluationMapper ):
202
+ class CachedEvaluationMapper (CachedMapper [ ResultT , []], EvaluationMapper [ ResultT ] ):
203
203
def __init__ (self , context = None ):
204
204
CachedMapper .__init__ (self )
205
205
EvaluationMapper .__init__ (self , context = context )
206
206
207
207
208
- class FloatEvaluationMapper (EvaluationMapper ):
209
- def map_constant (self , expr ):
208
+ class FloatEvaluationMapper (EvaluationMapper [ float ] ):
209
+ def map_constant (self , expr ) -> float :
210
210
return float (expr )
211
211
212
- def map_rational (self , expr ):
212
+ def map_rational (self , expr ) -> float :
213
213
return self .rec (expr .numerator ) / self .rec (expr .denominator )
214
214
215
215
216
- class CachedFloatEvaluationMapper (CachedEvaluationMapper ):
217
- def map_constant (self , expr ):
216
+ class CachedFloatEvaluationMapper (CachedEvaluationMapper [ float ] ):
217
+ def map_constant (self , expr ) -> float :
218
218
return float (expr )
219
219
220
- def map_rational (self , expr ):
220
+ def map_rational (self , expr ) -> float :
221
221
return self .rec (expr .numerator ) / self .rec (expr .denominator )
222
222
223
223
224
- def evaluate (expression , context = None , mapper_cls = CachedEvaluationMapper ) -> Any :
224
+ def evaluate (
225
+ expression : ExpressionT ,
226
+ context : Mapping [str , ResultT ] | None = None ,
227
+ mapper_cls : type [EvaluationMapper [ResultT ]] = CachedEvaluationMapper ,
228
+ ) -> ResultT :
225
229
"""
226
230
:arg mapper_cls: A :class:`type` of the evaluation mapper
227
231
whose instance performs the evaluation.
@@ -231,7 +235,11 @@ def evaluate(expression, context=None, mapper_cls=CachedEvaluationMapper) -> Any
231
235
return mapper_cls (context )(expression )
232
236
233
237
234
- def evaluate_kw (expression , mapper_cls = CachedEvaluationMapper , ** context ) -> Any :
238
+ def evaluate_kw (
239
+ expression : ExpressionT ,
240
+ mapper_cls : type [EvaluationMapper [ResultT ]] = CachedEvaluationMapper ,
241
+ ** context : ResultT ,
242
+ ) -> ResultT :
235
243
"""
236
244
:arg mapper_cls: A :class:`type` of the evaluation mapper
237
245
whose instance performs the evaluation.
@@ -240,7 +248,7 @@ def evaluate_kw(expression, mapper_cls=CachedEvaluationMapper, **context) -> Any
240
248
241
249
242
250
def evaluate_to_float (expression , context = None ,
243
- mapper_cls = CachedFloatEvaluationMapper ) -> Any :
251
+ mapper_cls = CachedFloatEvaluationMapper ) -> float :
244
252
"""
245
253
:arg mapper_cls: A :class:`type` of the evaluation mapper
246
254
whose instance performs the evaluation.
0 commit comments