Skip to content

Commit dee37a8

Browse files
committed
Enable TC ruff rules, fix
1 parent d6e82bb commit dee37a8

21 files changed

+119
-67
lines changed

pymbolic/algorithm.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@
4343
from pytools import MovedFunctionDeprecationWrapper, memoize
4444

4545

46-
if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", None):
46+
if TYPE_CHECKING:
47+
import numpy as np
48+
49+
50+
if getattr(sys, "_BUILDING_SPHINX_DOCS", None):
4751
import numpy as np
4852

4953

pymbolic/geometric_algebra/__init__.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def __init__(
567567
f"are supported for 'data': shape {data.shape}")
568568

569569
dimensions, = data.shape
570-
data_dict = {(i,): cast(CoeffT, xi) for i, xi in enumerate(data)}
570+
data_dict = {(i,): cast("CoeffT", xi) for i, xi in enumerate(data)}
571571

572572
if space is None:
573573
space = get_euclidean_space(dimensions)
@@ -579,7 +579,7 @@ def __init__(
579579
elif isinstance(data, Mapping):
580580
data_dict = data
581581
else:
582-
data_dict = {0: cast(CoeffT, data)}
582+
data_dict = {0: cast("CoeffT", data)}
583583

584584
if space is None:
585585
raise ValueError("No 'space' provided")
@@ -595,16 +595,16 @@ def __init__(
595595
assert isinstance(basis_indices, tuple)
596596

597597
bits, sign = space.bits_and_sign(basis_indices)
598-
new_coeff = cast(CoeffT,
599-
new_data.setdefault(bits, cast(CoeffT, 0)) # type: ignore[operator]
598+
new_coeff = cast("CoeffT",
599+
new_data.setdefault(bits, cast("CoeffT", 0)) # type: ignore[operator]
600600
+ sign*coeff)
601601

602602
if is_zero(new_coeff):
603603
del new_data[bits]
604604
else:
605605
new_data[bits] = new_coeff
606606
else:
607-
new_data = cast(dict[int, CoeffT], data_dict)
607+
new_data = cast("dict[int, CoeffT]", data_dict)
608608

609609
# }}}
610610

@@ -691,8 +691,8 @@ def __add__(self, other) -> MultiVector:
691691
from pymbolic.primitives import is_zero
692692
new_data = {}
693693
for bits in all_bits:
694-
new_coeff = (self.data.get(bits, cast(CoeffT, 0))
695-
+ other.data.get(bits, cast(CoeffT, 0)))
694+
new_coeff = (self.data.get(bits, cast("CoeffT", 0))
695+
+ other.data.get(bits, cast("CoeffT", 0)))
696696

697697
if not is_zero(new_coeff):
698698
new_data[bits] = new_coeff
@@ -741,7 +741,7 @@ def _generic_product(self,
741741
coeff = (weight
742742
* canonical_reordering_sign(sbits, obits)
743743
* scoeff * ocoeff)
744-
new_coeff = new_data.setdefault(new_bits, cast(CoeffT, 0)) + coeff
744+
new_coeff = new_data.setdefault(new_bits, cast("CoeffT", 0)) + coeff
745745
if is_zero(new_coeff):
746746
del new_data[new_bits]
747747
else:
@@ -1134,7 +1134,7 @@ def componentwise(f: Callable[[CoeffT], CoeffT], expr: T) -> T:
11341134
"""
11351135

11361136
if isinstance(expr, MultiVector):
1137-
return cast(T, expr.map(f))
1137+
return cast("T", expr.map(f))
11381138

11391139
from pytools.obj_array import obj_array_vectorize
11401140
return obj_array_vectorize(f, expr)

pymbolic/geometric_algebra/mapper.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525

2626
# This is experimental, undocumented, and could go away any second.
2727
# Consider yourself warned.
28-
from collections.abc import Set
29-
from typing import ClassVar
28+
from typing import TYPE_CHECKING, ClassVar
3029

3130
import pymbolic.geometric_algebra.primitives as prim
3231
from pymbolic.geometric_algebra import MultiVector
@@ -49,7 +48,12 @@
4948
PREC_NONE,
5049
StringifyMapper as StringifyMapperBase,
5150
)
52-
from pymbolic.primitives import ExpressionNode
51+
52+
53+
if TYPE_CHECKING:
54+
from collections.abc import Set
55+
56+
from pymbolic.primitives import ExpressionNode
5357

5458

5559
class IdentityMapper(IdentityMapperBase[P]):

pymbolic/geometric_algebra/primitives.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@
2626
# This is experimental, undocumented, and could go away any second.
2727
# Consider yourself warned.
2828

29-
from collections.abc import Hashable
30-
from typing import ClassVar
29+
from typing import TYPE_CHECKING, ClassVar
3130

3231
from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
33-
from pymbolic.typing import Expression
32+
33+
34+
if TYPE_CHECKING:
35+
from collections.abc import Hashable
36+
37+
from pymbolic.typing import Expression
3438

3539

3640
class MultiVectorVariable(Variable):

pymbolic/interop/ast.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@
2727
"""
2828

2929
import ast
30-
from typing import Any, ClassVar
30+
from typing import TYPE_CHECKING, Any, ClassVar
3131

3232
import pymbolic.primitives as p
3333
from pymbolic.mapper import CachedMapper
34-
from pymbolic.typing import Expression
34+
35+
36+
if TYPE_CHECKING:
37+
from pymbolic.typing import Expression
3538

3639

3740
__doc__ = r'''

pymbolic/interop/matchpy/mapper.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
4-
from typing import Any
3+
from typing import TYPE_CHECKING, Any
54

6-
from pymbolic.interop.matchpy import PymbolicOp
5+
6+
if TYPE_CHECKING:
7+
from collections.abc import Callable
8+
9+
from pymbolic.interop.matchpy import PymbolicOp
710

811

912
class Mapper:

pymbolic/interop/matchpy/tofrom.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
43
from dataclasses import dataclass
5-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
65

76
import multiset
87
import numpy as np
@@ -12,7 +11,12 @@
1211
import pymbolic.primitives as p
1312
from pymbolic.interop.matchpy.mapper import Mapper as BaseMatchPyMapper
1413
from pymbolic.mapper import Mapper as BasePymMapper
15-
from pymbolic.typing import Scalar as PbScalar
14+
15+
16+
if TYPE_CHECKING:
17+
from collections.abc import Callable
18+
19+
from pymbolic.typing import Scalar as PbScalar
1620

1721

1822
# {{{ to matchpy

pymbolic/mapper/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def __call__(self,
429429
method_name = getattr(expr, "mapper_method", None)
430430
if method_name is not None:
431431
method = cast(
432-
Callable[Concatenate[Expression, P], ResultT] | None,
432+
"Callable[Concatenate[Expression, P], ResultT] | None",
433433
getattr(self, method_name, None)
434434
)
435435
if method is not None:
@@ -973,7 +973,7 @@ def map_multivector(self,
973973
*args: P.args, **kwargs: P.kwargs
974974
) -> Expression:
975975
# True fact: MultiVectors aren't expressions
976-
return expr.map(lambda ch: cast(ArithmeticExpression,
976+
return expr.map(lambda ch: cast("ArithmeticExpression",
977977
self.rec(ch, *args, **kwargs))) # type: ignore[return-value]
978978

979979
def map_common_subexpression(self,
@@ -1012,7 +1012,7 @@ def map_derivative(self,
10121012
def map_slice(self,
10131013
expr: p.Slice,
10141014
*args: P.args, **kwargs: P.kwargs) -> Expression:
1015-
children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([
1015+
children: p.SliceChildrenT = cast("p.SliceChildrenT", tuple([
10161016
None if child is None else self.rec(child, *args, **kwargs)
10171017
for child in expr.children
10181018
]))

pymbolic/mapper/coefficient.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def map_product(self, expr: p.Product) -> CoeffsT:
6868
for i, child_coeffs in enumerate(children_coeffs):
6969
if i != idx_of_child_with_vars:
7070
assert len(child_coeffs) == 1
71-
other_coeffs *= cast(ArithmeticExpression, child_coeffs[1])
71+
other_coeffs *= cast("ArithmeticExpression", child_coeffs[1])
7272

7373
if idx_of_child_with_vars is None:
7474
return {1: other_coeffs}

pymbolic/mapper/collector.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,18 @@
2626
THE SOFTWARE.
2727
"""
2828

29-
from collections.abc import Sequence, Set
30-
from typing import cast
29+
from typing import TYPE_CHECKING, cast
3130

3231
import pymbolic
3332
import pymbolic.primitives as p
3433
from pymbolic.mapper import IdentityMapper
35-
from pymbolic.mapper.dependency import DependenciesT
36-
from pymbolic.typing import ArithmeticExpression, Expression
34+
35+
36+
if TYPE_CHECKING:
37+
from collections.abc import Sequence, Set
38+
39+
from pymbolic.mapper.dependency import DependenciesT
40+
from pymbolic.typing import ArithmeticExpression, Expression
3741

3842

3943
class TermCollector(IdentityMapper[[]]):
@@ -110,7 +114,7 @@ def exponent(term: Expression) -> ArithmeticExpression:
110114

111115
base_exp_set = frozenset(
112116
(base, exp) for base, exp in cleaned_base2exp.items())
113-
return base_exp_set, cast(ArithmeticExpression,
117+
return base_exp_set, cast("ArithmeticExpression",
114118
self.rec(pymbolic.flattened_product(coefficients)))
115119

116120
def map_sum(self, expr: p.Sum) -> Expression:

pymbolic/mapper/constant_folder.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
THE SOFTWARE.
2828
"""
2929

30-
from collections.abc import Callable
30+
31+
from typing import TYPE_CHECKING
3132

3233
from pymbolic.mapper import (
3334
CSECachingMapperMixin,
@@ -38,6 +39,10 @@
3839
from pymbolic.typing import ArithmeticExpression, Expression
3940

4041

42+
if TYPE_CHECKING:
43+
from collections.abc import Callable
44+
45+
4146
class ConstantFoldingMapperBase(Mapper[Expression, []]):
4247
def is_constant(self, expr):
4348
from pymbolic.mapper.dependency import DependencyMapper

pymbolic/mapper/distributor.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@
2727
THE SOFTWARE.
2828
"""
2929

30-
from typing import cast
30+
from typing import TYPE_CHECKING, cast
3131

3232
import pymbolic
3333
import pymbolic.primitives as p
3434
from pymbolic.mapper import IdentityMapper
3535
from pymbolic.mapper.collector import TermCollector
3636
from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper
37-
from pymbolic.typing import ArithmeticExpression, Expression
37+
38+
39+
if TYPE_CHECKING:
40+
from pymbolic.typing import ArithmeticExpression, Expression
3841

3942

4043
class DistributeMapper(IdentityMapper[[]]):
@@ -118,7 +121,7 @@ def map_power(self, expr: p.Power) -> Expression:
118121
newbase = self.rec(expr.base)
119122
if isinstance(newbase, p.Product):
120123
return self.rec(pymbolic.flattened_product([
121-
cast(ArithmeticExpression, child)**expr.exponent
124+
cast("ArithmeticExpression", child)**expr.exponent
122125
for child in newbase.children
123126
]))
124127

pymbolic/mapper/evaluator.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,20 @@
3434
"""
3535

3636
import operator as op
37-
from collections.abc import Mapping
3837
from functools import reduce
3938
from typing import TYPE_CHECKING, cast
4039

41-
import pymbolic.primitives as p
4240
from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper, ResultT
43-
from pymbolic.typing import Expression
4441

4542

4643
if TYPE_CHECKING:
44+
from collections.abc import Mapping
45+
4746
import numpy as np
4847

48+
import pymbolic.primitives as p
4949
from pymbolic.geometric_algebra import MultiVector
50+
from pymbolic.typing import Expression
5051

5152

5253
class UnknownVariableError(Exception):
@@ -82,7 +83,7 @@ def __init__(self, context: Mapping[str, ResultT] | None = None) -> None:
8283
self.context = context
8384

8485
def map_constant(self, expr: object) -> ResultT:
85-
return cast(ResultT, expr)
86+
return cast("ResultT", expr)
8687

8788
def map_variable(self, expr: p.Variable) -> ResultT:
8889
try:

pymbolic/mapper/flattener.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,18 @@
3131
THE SOFTWARE.
3232
"""
3333

34-
from typing import cast
34+
from typing import TYPE_CHECKING, cast
3535

3636
import pymbolic.primitives as p
3737
from pymbolic.mapper import IdentityMapper
38-
from pymbolic.typing import (
39-
ArithmeticExpression,
40-
ArithmeticOrExpressionT,
41-
Expression,
42-
)
38+
39+
40+
if TYPE_CHECKING:
41+
from pymbolic.typing import (
42+
ArithmeticExpression,
43+
ArithmeticOrExpressionT,
44+
Expression,
45+
)
4346

4447

4548
class FlattenMapper(IdentityMapper[[]]):
@@ -68,13 +71,13 @@ def is_expr_integer_valued(self, expr: Expression) -> bool:
6871
def map_sum(self, expr: p.Sum) -> Expression:
6972
from pymbolic.primitives import flattened_sum
7073
return flattened_sum([
71-
cast(ArithmeticExpression, self.rec(ch))
74+
cast("ArithmeticExpression", self.rec(ch))
7275
for ch in expr.children])
7376

7477
def map_product(self, expr: p.Product) -> Expression:
7578
from pymbolic.primitives import flattened_product
7679
return flattened_product([
77-
cast(ArithmeticExpression, self.rec(ch))
80+
cast("ArithmeticExpression", self.rec(ch))
7881
for ch in expr.children])
7982

8083
def map_quotient(self, expr: p.Quotient) -> Expression:
@@ -123,4 +126,4 @@ def map_power(self, expr: p.Power) -> Expression:
123126

124127

125128
def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT:
126-
return cast(ArithmeticOrExpressionT, FlattenMapper()(expr))
129+
return cast("ArithmeticOrExpressionT", FlattenMapper()(expr))

0 commit comments

Comments
 (0)