Skip to content

Commit bf6a947

Browse files
committed
Configure, pass mypy, mark typed
1 parent ae19d5f commit bf6a947

15 files changed

+342
-155
lines changed

.github/workflows/ci.yml

+17
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ jobs:
4444
4545
run_pylint pymbolic test/test_*.py
4646
47+
mypy:
48+
name: Mypy
49+
runs-on: ubuntu-latest
50+
steps:
51+
- uses: actions/checkout@v4
52+
-
53+
uses: actions/setup-python@v5
54+
with:
55+
python-version: '3.x'
56+
- name: "Main Script"
57+
run: |
58+
curl -L -O https://tiker.net/ci-support-v0
59+
. ./ci-support-v0
60+
build_py_project_in_venv
61+
python -m pip install mypy numpy
62+
./run-mypy.sh
63+
4764
pytest:
4865
name: Pytest on Py${{ matrix.python-version }}
4966
runs-on: ubuntu-latest

.gitlab-ci.yml

+12
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ Pylint:
4040
except:
4141
- tags
4242

43+
Mypy:
44+
script: |
45+
curl -L -O https://tiker.net/ci-support-v0
46+
. ./ci-support-v0
47+
build_py_project_in_venv
48+
python -m pip install mypy
49+
./run-mypy.sh
50+
tags:
51+
- python3
52+
except:
53+
- tags
54+
4355
Documentation:
4456
script:
4557
- EXTRA_INSTALL="numpy sympy"

doc/conf.py

+2
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@
2222
"numpy": ("https://numpy.org/doc/stable/", None),
2323
"python": ("https://docs.python.org/3", None),
2424
"sympy": ("https://docs.sympy.org/dev/", None),
25+
"typing_extensions":
26+
("https://typing-extensions.readthedocs.io/en/latest/", None),
2527
}

pymbolic/imperative/statement.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sys import intern
2626
from pytools import RecordWithoutPickling
27+
from pymbolic.typing import not_none
2728

2829

2930
# {{{ statemetn classes
@@ -88,7 +89,7 @@ def get_dependency_mapper(self, include_calls="descend_args"):
8889
# {{{ statement with condition
8990

9091
class ConditionalStatement(Statement):
91-
__doc__ = Statement.__doc__ + """
92+
__doc__ = not_none(Statement.__doc__) + """
9293
.. attribute:: condition
9394
9495
The instruction condition as a :mod:`pymbolic` expression (`True` if the

pymbolic/interop/ast.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,9 @@ def map_constant(self, expr: ScalarT) -> ast.expr:
288288
def map_call(self, expr: p.Call) -> ast.expr:
289289
return ast.Call(
290290
func=self.rec(expr.function),
291-
args=[self.rec(param) for param in expr.parameters])
291+
args=[self.rec(param) for param in expr.parameters],
292+
keywords=[],
293+
)
292294

293295
def map_call_with_kwargs(self, expr) -> ast.expr:
294296
return ast.Call(
@@ -354,7 +356,7 @@ def map_bitwise_and(self, expr) -> ast.expr:
354356
ast.BitAnd())
355357

356358
def map_logical_not(self, expr) -> ast.expr:
357-
return ast.UnaryOp(self.rec(expr.child), ast.Not())
359+
return ast.UnaryOp(ast.Not(), self.rec(expr.child))
358360

359361
def map_logical_or(self, expr) -> ast.expr:
360362
return ast.BoolOp(ast.Or(), [self.rec(child)
@@ -376,6 +378,7 @@ def map_if(self, expr: p.If) -> ast.expr:
376378
orelse=self.rec(expr.else_))
377379

378380
def map_nan(self, expr: p.NaN) -> ast.expr:
381+
assert expr.data_type is not None
379382
if isinstance(expr.data_type(float("nan")), float):
380383
return ast.Call(
381384
ast.Name(id="float"),

pymbolic/interop/matchpy/__init__.py

+39-37
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
14
"""
25
Interoperability with :mod:`matchpy.functions` for pattern-matching and
36
term-rewriting.
@@ -38,22 +41,21 @@
3841
"""
3942

4043

41-
import numpy as np
4244
import abc
4345
import pymbolic.primitives as p
4446

45-
from typing import (Union, ClassVar, Optional, Iterator, Mapping,
46-
Generic, TypeVar, Tuple, Iterable, Callable)
47+
from typing_extensions import TypeAlias
48+
from typing import (ClassVar, Iterator, Mapping,
49+
Generic, TypeVar, Iterable, Callable)
4750
from dataclasses import dataclass, fields, field
4851

4952
from matchpy import (Operation, Arity, Expression, Atom as BaseAtom,
5053
Wildcard as BaseWildcard, ReplacementRule)
51-
from numbers import Number
5254
from functools import partial
55+
from pymbolic.typing import ScalarT
5356

5457

55-
ScalarT = Union[Number, int, np.bool_, bool]
56-
ExprT = Expression
58+
ExprT: TypeAlias = Expression
5759
ConstantT = TypeVar("ConstantT")
5860
ToMatchpyT = Callable[[p.Expression], ExprT]
5961
FromMatchpyT = Callable[[ExprT], p.Expression]
@@ -70,7 +72,7 @@
7072
@op_dataclass
7173
class _Constant(BaseAtom, Generic[ConstantT]):
7274
value: ConstantT
73-
variable_name: Optional[str] = None
75+
variable_name: str | None = None
7476

7577
@property
7678
def head(self):
@@ -89,7 +91,7 @@ def __lt__(self, other):
8991

9092
@op_dataclass
9193
class Scalar(_Constant[ScalarT]):
92-
_mapper_method: [str] = "map_scalar"
94+
_mapper_method: str = "map_scalar"
9395

9496

9597
@op_dataclass
@@ -104,8 +106,8 @@ class ComparisonOp(_Constant[str]):
104106

105107
@op_dataclass
106108
class TupleOp(Operation):
107-
_operands: Tuple[ExprT, ...]
108-
variable_name: Optional[str] = non_operand_field(default=None)
109+
_operands: tuple[ExprT, ...]
110+
variable_name: str | None = non_operand_field(default=None)
109111

110112
arity: ClassVar[Arity] = Arity.variadic
111113
name: ClassVar[str] = "tuple"
@@ -129,7 +131,7 @@ def variable_name(self):
129131
pass
130132

131133
@property
132-
def operands(self) -> Tuple[Expression]:
134+
def operands(self) -> tuple[Expression, ...]:
133135
return tuple(getattr(self, field.name)
134136
for field in fields(self)
135137
if not field.metadata.get("not_an_operand", False))
@@ -150,15 +152,15 @@ def name(self) -> str:
150152
class Variable(PymbolicOp):
151153
id: Id
152154
arity: ClassVar[Arity] = Arity.unary
153-
variable_name: Optional[str] = non_operand_field(default=None)
155+
variable_name: str | None = non_operand_field(default=None)
154156
_mapper_method: ClassVar[str] = "map_variable"
155157

156158

157159
@op_dataclass
158160
class Call(PymbolicOp):
159161
function: ExprT
160162
args: TupleOp
161-
variable_name: Optional[str] = non_operand_field(default=None)
163+
variable_name: str | None = non_operand_field(default=None)
162164

163165
arity: ClassVar[Arity] = Arity.binary
164166
_mapper_method: ClassVar[str] = "map_call"
@@ -168,7 +170,7 @@ class Call(PymbolicOp):
168170
class Subscript(PymbolicOp):
169171
aggregate: ExprT
170172
indices: TupleOp
171-
variable_name: Optional[str] = non_operand_field(default=None)
173+
variable_name: str | None = non_operand_field(default=None)
172174

173175
arity: ClassVar[Arity] = Arity.binary
174176
_mapper_method: ClassVar[str] = "map_subscript"
@@ -182,7 +184,7 @@ class _BinaryOp(PymbolicOp):
182184
x2: ExprT
183185

184186
arity: ClassVar[Arity] = Arity.binary
185-
variable_name: Optional[str] = non_operand_field(default=None)
187+
variable_name: str | None = non_operand_field(default=None)
186188

187189

188190
@op_dataclass
@@ -224,8 +226,8 @@ class RightShift(_BinaryOp):
224226

225227
@variadic_op_dataclass
226228
class _VariadicCommAssocOp(PymbolicOp):
227-
children: Tuple[ExprT, ...]
228-
variable_name: Optional[str] = non_operand_field(default=None)
229+
children: tuple[ExprT, ...]
230+
variable_name: str | None = non_operand_field(default=None)
229231

230232
commutative: ClassVar[bool] = True
231233
associative: ClassVar[bool] = True
@@ -236,7 +238,7 @@ def __init__(self, *children: ExprT, variable_name=None):
236238
object.__setattr__(self, "variable_name", variable_name)
237239

238240
@property
239-
def operands(self) -> Tuple[ExprT, ...]:
241+
def operands(self) -> tuple[ExprT, ...]:
240242
return self.children
241243

242244

@@ -283,7 +285,7 @@ class BitwiseXor(_VariadicCommAssocOp):
283285
class _UnaryOp(PymbolicOp):
284286
x: ExprT
285287
arity: ClassVar[Arity] = Arity.unary
286-
variable_name: Optional[str] = non_operand_field(default=None)
288+
variable_name: str | None = non_operand_field(default=None)
287289

288290

289291
@op_dataclass
@@ -303,7 +305,7 @@ class Comparison(PymbolicOp):
303305
left: ExprT
304306
operator: ComparisonOp
305307
right: ExprT
306-
variable_name: Optional[str] = non_operand_field(default=None)
308+
variable_name: str | None = non_operand_field(default=None)
307309

308310
arity: ClassVar[Arity] = Arity.ternary
309311
_mapper_method: ClassVar[str] = "map_comparison"
@@ -314,7 +316,7 @@ class If(PymbolicOp):
314316
condition: ExprT
315317
then: ExprT
316318
else_: ExprT
317-
variable_name: Optional[str] = non_operand_field(default=None)
319+
variable_name: str | None = non_operand_field(default=None)
318320

319321
arity: ClassVar[Arity] = Arity.ternary
320322
_mapper_method: ClassVar[str] = "map_if"
@@ -325,16 +327,16 @@ class Wildcard(BaseWildcard):
325327
# {{{ FIXME: This should go into matchpy itself.
326328

327329
@classmethod
328-
def dot(cls, name=None) -> "Wildcard":
330+
def dot(cls, name=None) -> Wildcard:
329331
return cls(min_count=1, fixed_size=True, variable_name=name)
330332

331333
@classmethod
332-
def star(cls, name=None) -> "Wildcard":
334+
def star(cls, name=None) -> Wildcard:
333335
# FIXME: This should go into matchpy itself.
334336
return cls(min_count=0, fixed_size=False, variable_name=name)
335337

336338
@classmethod
337-
def plus(cls, name=None) -> "Wildcard":
339+
def plus(cls, name=None) -> Wildcard:
338340
# FIXME: This should go into matchpy itself.
339341
return cls(min_count=1, fixed_size=False, variable_name=name)
340342

@@ -343,7 +345,7 @@ def plus(cls, name=None) -> "Wildcard":
343345
# }}}
344346

345347

346-
def _get_operand_at_path(expr: PymbolicOp, path: Tuple[int, ...]) -> PymbolicOp:
348+
def _get_operand_at_path(expr: PymbolicOp, path: tuple[int, ...]) -> PymbolicOp:
347349
result = expr
348350

349351
while path:
@@ -355,9 +357,9 @@ def _get_operand_at_path(expr: PymbolicOp, path: Tuple[int, ...]) -> PymbolicOp:
355357

356358
def match(subject: p.Expression,
357359
pattern: p.Expression,
358-
to_matchpy_expr: Optional[ToMatchpyT] = None,
359-
from_matchpy_expr: Optional[FromMatchpyT] = None
360-
) -> Iterator[Mapping[str, Union[p.Expression, ScalarT]]]:
360+
to_matchpy_expr: ToMatchpyT | None = None,
361+
from_matchpy_expr: FromMatchpyT | None = None
362+
) -> Iterator[Mapping[str, p.Expression | ScalarT]]:
361363
from matchpy import match, Pattern
362364
from .tofrom import (ToMatchpyExpressionMapper,
363365
FromMatchpyExpressionMapper)
@@ -378,10 +380,10 @@ def match(subject: p.Expression,
378380

379381
def match_anywhere(subject: p.Expression,
380382
pattern: p.Expression,
381-
to_matchpy_expr: Optional[ToMatchpyT] = None,
382-
from_matchpy_expr: Optional[FromMatchpyT] = None
383-
) -> Iterator[Tuple[Mapping[str, Union[p.Expression, ScalarT]],
384-
Union[p.Expression, ScalarT]]
383+
to_matchpy_expr: ToMatchpyT | None = None,
384+
from_matchpy_expr: FromMatchpyT | None = None
385+
) -> Iterator[tuple[Mapping[str, p.Expression | ScalarT],
386+
p.Expression | ScalarT]
385387
]:
386388
from matchpy import match_anywhere, Pattern
387389
from .tofrom import (ToMatchpyExpressionMapper,
@@ -404,8 +406,8 @@ def match_anywhere(subject: p.Expression,
404406

405407
def make_replacement_rule(pattern: p.Expression,
406408
replacement: Callable[..., p.Expression],
407-
to_matchpy_expr: Optional[ToMatchpyT] = None,
408-
from_matchpy_expr: Optional[FromMatchpyT] = None
409+
to_matchpy_expr: ToMatchpyT | None = None,
410+
from_matchpy_expr: FromMatchpyT | None = None
409411
) -> ReplacementRule:
410412
"""
411413
Returns a :class:`matchpy.functions.ReplacementRule` from the objects
@@ -429,9 +431,9 @@ def make_replacement_rule(pattern: p.Expression,
429431

430432
def replace_all(expression: p.Expression,
431433
rules: Iterable[ReplacementRule],
432-
to_matchpy_expr: Optional[ToMatchpyT] = None,
433-
from_matchpy_expr: Optional[FromMatchpyT] = None
434-
) -> Union[p.Expression, Tuple[p.Expression, ...]]:
434+
to_matchpy_expr: ToMatchpyT | None = None,
435+
from_matchpy_expr: FromMatchpyT | None = None
436+
) -> p.Expression | tuple[p.Expression, ...]:
435437
import collections.abc as abc
436438
from .tofrom import (ToMatchpyExpressionMapper,
437439
FromMatchpyExpressionMapper)

pymbolic/interop/maxima.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636

3737
import re
3838
from sys import intern
39-
from typing import ClassVar, List, Tuple
39+
from typing import ClassVar
4040

4141
import numpy as np
4242

4343
import pytools
4444

4545
from pymbolic.mapper.stringifier import StringifyMapper
46-
from pymbolic.parser import Parser as ParserBase, FinalizedTuple
46+
from pymbolic.parser import Parser as ParserBase, FinalizedTuple, LexTable
4747

4848

4949
IN_PROMPT_RE = re.compile(br"\(%i([0-9]+)\) ")
@@ -95,7 +95,7 @@ class MaximaParser(ParserBase):
9595
imag_unit = intern("imag_unit")
9696
euler_number = intern("euler_number")
9797

98-
lex_table: ClassVar[List[Tuple[str, str]]] = [
98+
lex_table: ClassVar[LexTable] = [
9999
(power_sym, pytools.lex.RE(r"\^")),
100100
(imag_unit, pytools.lex.RE(r"%i")),
101101
(euler_number, pytools.lex.RE(r"%e")),

pymbolic/mapper/constant_folder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class ConstantFoldingMapper(
9595
IdentityMapper.map_common_subexpression
9696

9797

98-
class CommutativeConstantFoldingMapper(
98+
# Yes, map_product incompatible: missing *args, **kwargs
99+
class CommutativeConstantFoldingMapper( # type: ignore[misc]
99100
CSECachingMapperMixin,
100101
CommutativeConstantFoldingMapperBase,
101102
IdentityMapper):

0 commit comments

Comments
 (0)