Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pymbolic typing #98

Merged
merged 9 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.x']
python-version: ['3.10', '3.x']
steps:
- uses: actions/checkout@v4
-
Expand Down
15 changes: 0 additions & 15 deletions MANIFEST.in

This file was deleted.

2 changes: 1 addition & 1 deletion contrib/pov-nodes/nodal_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
unit_to_barycentric(mp.warp_and_blend_nodes(3, n, node_tuples))
).T
]
id_to_node = dict(list(zip(node_tuples, nodes)))
id_to_node = dict(list(zip(node_tuples, nodes, strict=True)))


def get_ball_radius(nid):
Expand Down
4 changes: 2 additions & 2 deletions contrib/pov-nodes/pov.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, name, args=(), opts=(), **kwargs):

args = list(args)
for i in range(len(args)):
if isinstance(args[i], (tuple, list, np.ndarray)):
if isinstance(args[i], tuple | list | np.ndarray):
args[i] = Vector(args[i])

self.args = args
Expand All @@ -95,7 +95,7 @@ def write(self, file):
for key, val in list(self.kwargs.items()):
if val is None:
file.writeln(key)
elif isinstance(val, (tuple, list, np.ndarray)):
elif isinstance(val, tuple | list | np.ndarray):
val = Vector(*val)
file.writeln(f"{key} {val}")
else:
Expand Down
1 change: 1 addition & 0 deletions examples/plot-basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
for (i, j), basis_func in zip(
gnitstam(p, dims),
simplex_onb(dims, p),
strict=True,
):

all_nodes.append([*plot_nodes, stretch_factor * i, stretch_factor * j])
Expand Down
2 changes: 1 addition & 1 deletion modepy/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""


from typing import Callable, Sequence
from collections.abc import Callable, Sequence
from warnings import warn

import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion modepy/modal_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def create_decay_baseline(mode_number_vector: np.ndarray, n: int) -> np.ndarray:
modal_coefficients[zeros] = 1 # irrelevant, just keeps log from NaNing

# NOTE: mypy seems to be confused by the __itruediv__ argument types
modal_coefficients /= la.norm(modal_coefficients) # type: ignore[misc]
modal_coefficients /= la.norm(modal_coefficients) # type: ignore[arg-type,misc]

return modal_coefficients

Expand Down
30 changes: 15 additions & 15 deletions modepy/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,10 @@

import math
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Iterable, Sequence
from functools import partial, singledispatch
from typing import (
TYPE_CHECKING,
Callable,
Hashable,
Iterable,
Sequence,
Tuple,
TypeVar,
)

Expand Down Expand Up @@ -134,7 +130,9 @@ def _cse(expr, prefix):
def _where(op_a, comp, op_b, then, else_):
from pymbolic.primitives import Comparison, Expression, If
if isinstance(op_a, Expression) or isinstance(op_b, Expression):
return If(Comparison(op_a, comp, op_b), then, else_)
return If(
Comparison(op_a, Comparison.name_to_operator[comp], op_b),
then, else_)

import operator
comp_op = getattr(operator, comp)
Expand Down Expand Up @@ -203,7 +201,7 @@ def jacobi(alpha: float, beta: float, n: int, x: RealValueT) -> RealValueT:

bnew = -(alpha*alpha-beta*beta)/(h1*(h1+2.))
pl.append(_cse(
(-aold*pl[i-1] + np.multiply(x-bnew, pl[i]))/anew,
(-aold*pl[i-1] + (x-bnew) * pl[i])/anew,
prefix=f"jac_p{i+1}"))
aold = anew

Expand Down Expand Up @@ -577,7 +575,7 @@ def __call__(self, x):
# Likely we're evaluating symbolically.
result = 1

for d, function in zip(self.dims_per_function, self.functions):
for d, function in zip(self.dims_per_function, self.functions, strict=True):
result *= function(x[n:n + d])
n += d

Expand Down Expand Up @@ -654,7 +652,7 @@ def __call__(self, x):
n = 0
for ider, derivative in enumerate(self.derivatives):
f = 0
for iaxis, function in zip(self.dims_per_function, derivative):
for iaxis, function in zip(self.dims_per_function, derivative, strict=True):
components = function(x[f:f + iaxis])

if isinstance(components, tuple):
Expand Down Expand Up @@ -723,7 +721,7 @@ def symbolicize_function(
# {{{ basis interface

BasisFunctionType = Callable[[np.ndarray], np.ndarray]
BasisGradientType = Callable[[np.ndarray], Tuple[np.ndarray, ...]]
BasisGradientType = Callable[[np.ndarray], tuple[np.ndarray, ...]]


class BasisNotOrthonormal(Exception):
Expand Down Expand Up @@ -973,7 +971,8 @@ def part_flat_tuple(iterable: Iterable[tuple[bool, Hashable]]
part_flat_tuple((flatten, umid[mid_index_i])
for flatten, umid, mid_index_i in zip(
is_all_singletons_with_int,
underlying_mode_id_lists, mode_index_tuple))
underlying_mode_id_lists, mode_index_tuple,
strict=True))
for mode_index_tuple in self._mode_index_tuples)

@property
Expand All @@ -998,7 +997,7 @@ def gradients(self) -> tuple[BasisGradientType, ...]:
tuple(
tuple(func[is_deriv][ibasis][mid_i]
for ibasis, (is_deriv, mid_i) in enumerate(
zip(deriv_indicator_vec, mid)))
zip(deriv_indicator_vec, mid, strict=True)))
for deriv_indicator_vec in wandering_element(self._nbases)),
dims_per_function=self._dims_per_basis)
for mid in self._mode_index_tuples)
Expand All @@ -1015,7 +1014,7 @@ def _orthonormal_basis_for_tp(
raise ValueError("spatial dimensions of shape and space must match")

bases = [orthonormal_basis_for_space(b, s)
for b, s in zip(space.bases, shape.bases)]
for b, s in zip(space.bases, shape.bases, strict=True)]

return TensorProductBasis(
bases,
Expand All @@ -1030,7 +1029,8 @@ def _basis_for_tp(space: TensorProductSpace, shape: TensorProductShape):
if space.spatial_dim != shape.dim:
raise ValueError("spatial dimensions of shape and space must match")

bases = [basis_for_space(b, s) for b, s in zip(space.bases, shape.bases)]
bases = [basis_for_space(b, s)
for b, s in zip(space.bases, shape.bases, strict=True)]
return TensorProductBasis(
bases,
dims_per_basis=tuple(b.spatial_dim for b in space.bases))
Expand All @@ -1043,7 +1043,7 @@ def _monomial_basis_for_tp(space: TensorProductSpace, shape: TensorProductShape)

bases = [
monomial_basis_for_space(b, s)
for b, s in zip(space.bases, shape.bases)]
for b, s in zip(space.bases, shape.bases, strict=True)]

return TensorProductBasis(
bases,
Expand Down
6 changes: 3 additions & 3 deletions modepy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@

# }}}

from collections.abc import Sequence
from functools import partial, singledispatch
from typing import Sequence

import numpy as np
import numpy.linalg as la
Expand Down Expand Up @@ -555,7 +555,7 @@ def _equispaced_nodes_for_tp(

return tensor_product_nodes([
equispaced_nodes_for_space(b, s)
for b, s in zip(space.bases, shape.bases)
for b, s in zip(space.bases, shape.bases, strict=True)
])


Expand All @@ -571,7 +571,7 @@ def _edge_clustered_nodes_for_tp(

return tensor_product_nodes([
edge_clustered_nodes_for_space(b, s)
for b, s in zip(space.bases, shape.bases)
for b, s in zip(space.bases, shape.bases, strict=True)
])


Expand Down
6 changes: 3 additions & 3 deletions modepy/quadrature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
THE SOFTWARE.
"""

from collections.abc import Callable, Iterable, Sequence
from functools import singledispatch
from numbers import Number
from typing import Callable, Iterable, Sequence

import numpy as np

Expand All @@ -79,7 +79,7 @@ def __gt__(self, other: object) -> bool:
return bool(isinstance(other, Number))

def __ge__(self, other: object) -> bool:
return bool(isinstance(other, (Number, _Inf)))
return bool(isinstance(other, Number | _Inf))


inf = _Inf()
Expand Down Expand Up @@ -298,7 +298,7 @@ def _quadrature_for_tp(
else:
quad = TensorProductQuadrature([
quadrature_for_space(sp, shp)
for sp, shp in zip(space.bases, shape.bases)
for sp, shp in zip(space.bases, shape.bases, strict=True)
])

assert all(quad.exact_to >= getattr(s, "order", 0) for s in space.bases)
Expand Down
3 changes: 2 additions & 1 deletion modepy/quadrature/grundmann_moeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def __init__(self, order: int, dimension: int) -> None:

dim_factor = 2**n
for p, w in points_to_weights.items():
real_p = reduce(add, (a/b * v for (a, b), v in zip(p, vertices)))
real_p = reduce(add, (a/b * v
for (a, b), v in zip(p, vertices, strict=True)))
nodes.append(real_p)
weights.append(dim_factor * w)

Expand Down
5 changes: 3 additions & 2 deletions modepy/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,10 @@

import contextlib
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from functools import partial, singledispatch
from typing import Any, Callable, Sequence
from typing import Any

import numpy as np

Expand Down Expand Up @@ -291,7 +292,7 @@ def face_normal(face: Face, normalize: bool = True) -> np.ndarray:
from operator import xor as outerprod

from pymbolic.geometric_algebra import MultiVector
surface_ps = reduce(outerprod, [
surface_ps: MultiVector = reduce(outerprod, [
MultiVector(face_vertices[:, i+1] - face_vertices[:, 0])
for i in range(face.dim)])

Expand Down
6 changes: 3 additions & 3 deletions modepy/test/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""


from typing import Tuple, cast
from typing import cast

import numpy as np
import numpy.linalg as la
Expand Down Expand Up @@ -110,8 +110,8 @@ def test_tensor_product_diag_mass_matrix(shape: mp.Shape) -> None:
# Note that gll_diag_mass_mat is not a good approximation of gll_ref_mass_mat
# in the matrix norm sense!

for mid, func in zip(basis.mode_ids, basis.functions):
if max(cast(Tuple[int, ...], mid)) < order - 1:
for mid, func in zip(basis.mode_ids, basis.functions, strict=True):
if max(cast(tuple[int, ...], mid)) < order - 1:
err = np.abs(
gll_ref_mass_mat @ func(gll_quad.nodes)
- gll_diag_mass_mat @ func(gll_quad.nodes))
Expand Down
8 changes: 4 additions & 4 deletions modepy/test/test_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_basis_grad(dim, shape_cls, order, basis_getter):

from pytools import wandering_element
from pytools.convergence import EOCRecorder
for bf, gradbf in zip(basis.functions, basis.gradients):
for bf, gradbf in zip(basis.functions, basis.gradients, strict=True):
eoc_rec = EOCRecorder()
for h in [1e-2, 1e-3]:
r = mp.random_nodes_for_shape(shape, nnodes=1000, rng=rng)
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_symbolic_basis(shape, order, basis_getter):
rng = np.random.Generator(np.random.PCG64(17))
r = mp.random_nodes_for_shape(shape, 10000, rng=rng)

for func, sym_func in zip(basis.functions, sym_basis):
for func, sym_func in zip(basis.functions, sym_basis, strict=True):
strmap = MyStringifyMapper()
s = strmap(sym_func)
for name, val in strmap.cse_name_list:
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_symbolic_basis(shape, order, basis_getter):

sym_grad_basis = [mp.symbolicize_function(f, shape.dim) for f in basis.gradients]

for grad, sym_grad in zip(basis.gradients, sym_grad_basis):
for grad, sym_grad in zip(basis.gradients, sym_grad_basis, strict=True):
strmap = MyStringifyMapper()
s = strmap(sym_grad)
for name, val in strmap.cse_name_list:
Expand All @@ -290,7 +290,7 @@ def test_symbolic_basis(shape, order, basis_getter):
sym_val = (sym_val,)
ref_val = (ref_val,)

for sv_i, rv_i in zip(sym_val, ref_val):
for sv_i, rv_i in zip(sym_val, ref_val, strict=True):
ref_norm = la.norm(rv_i, np.inf)
err = la.norm(sv_i-rv_i, np.inf)
if ref_norm:
Expand Down
3 changes: 2 additions & 1 deletion modepy/test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,8 @@ def test_tensor_product_vdm_dim_by_dim(dim):
x_r = reshape_array_for_tensor_product_space(space, x)
vdm_dimbydim_x_r = x_r

for i, (subspace, subshape) in enumerate(zip(space.bases, shape.bases)):
for i, (subspace, subshape) in enumerate(
zip(space.bases, shape.bases, strict=True)):
subnodes = mp.edge_clustered_nodes_for_space(subspace, subshape)
subbasis = mp.basis_for_space(subspace, subshape)
subvdm = mp.vandermonde(subbasis.functions, subnodes)
Expand Down
22 changes: 4 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[build-system]
build-backend = "setuptools.build_meta"
build-backend = "hatchling.build"
requires = [
"setuptools>=63",
"hatchling",
]

[project]
Expand All @@ -13,7 +13,7 @@ license = { text = "MIT" }
authors = [
{ name = "Andreas Kloeckner", email = "[email protected]" },
]
requires-python = ">=3.8"
requires-python = ">=3.10"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
Expand All @@ -32,7 +32,7 @@ classifiers = [
]
dependencies = [
"numpy",
"pymbolic",
"pymbolic>=2024.1",
"pytools",
]

Expand All @@ -53,20 +53,6 @@ Documentation = "https://documen.tician.de/modepy"
Homepage = "https://mathema.tician.de/software/modepy"
Repository = "https://github.com/inducer/modepy"

[tool.setuptools.packages.find]
include = [
"modepy*",
]

[tool.setuptools.package-dir]
# https://github.com/Infleqtion/client-superstaq/pull/715
"" = "."

[tool.setuptools.package-data]
modepy = [
"py.typed",
]

[tool.ruff]
preview = true

Expand Down
Loading