Skip to content

upgrade 'unevaluated array as argument' warning to error #305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(array.shape, array.dtype)
return self._array_context.np.zeros(array.shape, array.dtype)

return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
Expand Down Expand Up @@ -113,7 +113,7 @@
"""
if order in "AK":
from warnings import warn
warn(f"ravel with order='{order}' not supported by JAX,"

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.

Check warning on line 116 in arraycontext/impl/jax/fake_numpy.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

ravel with order='A' not supported by JAX, using order=C.
" using order=C.", stacklevel=1)
order = "C"

Expand Down
1 change: 0 additions & 1 deletion arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:

pt_prg = pt.generate_loopy(transformed_dag,
options=opts,
cl_device=self.queue.device,
function_name=function_name,
target=self.get_target()
).bind_to_context(self.context)
Expand Down
53 changes: 25 additions & 28 deletions arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
return pytato_program, name_in_program_to_tags, name_in_program_to_axes


def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
fn_name="<unknown>"):
input_kwargs_for_loopy = {}

for arg_id, arg in arg_id_to_arg.items():
Expand All @@ -550,32 +551,20 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
# got a frozen array => do nothing
pass
elif isinstance(arg, pt.Array):
# got an array expression => evaluate it
from warnings import warn
warn(f"Argument array '{arg_id}' to a compiled function is "
"unevaluated. Evaluating just-in-time, at "
"considerable expense. This is deprecated and will stop "
"working in 2023. To avoid this warning, force evaluation "
"of all arguments via freeze/thaw.",
DeprecationWarning, stacklevel=4)

arg = actx.freeze(arg)
# got an array expression => abort
raise ValueError(
f"Argument '{arg_id}' to the '{fn_name}' compiled function is a"
" pytato array expression. Evaluating it just-in-time"
" potentially causes a significant overhead on each call to the"
" function and is therefore unsupported. "
)
else:
raise NotImplementedError(type(arg))

input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg

return input_kwargs_for_loopy


def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
from warnings import warn
warn("_args_to_cl_buffer has been renamed to"
" _args_to_device_buffers. This will be"
" an error in 2023.", DeprecationWarning, stacklevel=2)
return _args_to_device_buffers(actx, input_id_to_name_in_program,
arg_id_to_arg)

# }}}


Expand Down Expand Up @@ -631,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
type of the callable.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
Expand All @@ -642,8 +631,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array

fn_name = self.pytato_program.program.entrypoint
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the fn_names below appear to be inconsistent. Is there a specific reason for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this is hopefully fixed in b1c3a73. Part of the challenge here is that CompiledPyOpenCLFunctionReturningArray appears to be unused.


input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
Expand Down Expand Up @@ -674,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
Name of the output array in the program.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.BoundProgram
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
Expand All @@ -684,8 +675,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array

fn_name = self.pytato_program.program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
Expand Down Expand Up @@ -723,16 +716,18 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
type of the callable.
"""
actx: PytatoJAXArrayContext
pytato_program: pt.target.BoundProgram
pytato_program: pt.target.python.BoundJAXPythonProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer

def __call__(self, arg_id_to_arg) -> ArrayContainer:
fn_name = self.pytato_program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

out_dict = self.pytato_program(**input_kwargs_for_loopy)

Expand All @@ -754,15 +749,17 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
Name of the output array in the program.
"""
actx: PytatoJAXArrayContext
pytato_program: pt.target.BoundProgram
pytato_program: pt.target.python.BoundJAXPythonProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
output_name: str

def __call__(self, arg_id_to_arg) -> ArrayContainer:
fn_name = self.pytato_program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)

_evt, out_dict = self.pytato_program(**input_kwargs_for_loopy)

Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def zeros(self, shape, dtype):

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(
return self._array_context.np.zeros(
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)

return self._array_context._rec_map_container(
Expand Down
21 changes: 12 additions & 9 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,6 @@ def test_actx_compile_kwargs(actx_factory):
def test_actx_compile_with_tuple_output_keys(actx_factory):
# arraycontext.git<=3c9aee68 would fail due to a bug in output
# key stringification logic.
from arraycontext import from_numpy, to_numpy
actx = actx_factory()
rng = np.random.default_rng()

Expand All @@ -1157,11 +1156,11 @@ def my_rhs(scale, vel):
v_x = rng.uniform(size=10)
v_y = rng.uniform(size=10)

vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)
vel = actx.from_numpy(Velocity2D(v_x, v_y, actx))

scaled_speed = compiled_rhs(3.14, vel=vel)

result = to_numpy(scaled_speed, actx)[0, 0]
result = actx.to_numpy(scaled_speed)[0, 0]
np.testing.assert_allclose(result.u, -3.14*v_y)
np.testing.assert_allclose(result.v, 3.14*v_x)

Expand Down Expand Up @@ -1287,6 +1286,8 @@ class ArrayContainerWithNumpy:
u: np.ndarray
v: DOFArray

__array_ufunc__ = None


def test_array_container_with_numpy(actx_factory):
actx = actx_factory()
Expand Down Expand Up @@ -1405,14 +1406,16 @@ def test_compile_anonymous_function(actx_factory):

# See https://github.com/inducer/grudge/issues/287
actx = actx_factory()

ones = actx.thaw(actx.freeze(
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
))

f = actx.compile(lambda x: 2*x+40)
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)

f = actx.compile(partial(lambda x: 2*x+40))
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)


@pytest.mark.parametrize(
Expand Down
6 changes: 5 additions & 1 deletion test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def array_context(self):

@with_container_arithmetic(
bcasts_across_obj_array=False,
bcast_container_types=(DOFArray, np.ndarray),
container_types_bcast_across=(DOFArray, np.ndarray),
matmul=True,
rel_comparison=True,
_cls_has_array_context_attr=True,
Expand All @@ -173,6 +173,8 @@ class MyContainerDOFBcast:
momentum: np.ndarray
enthalpy: DOFArray | np.ndarray

__array_ufunc__ = None

@property
def array_context(self):
if isinstance(self.mass, np.ndarray):
Expand Down Expand Up @@ -209,6 +211,8 @@ class Velocity2D:
v: ArrayContainer
array_context: ArrayContext

__array_ufunc__ = None


@with_array_context.register(Velocity2D)
# https://github.com/python/mypy/issues/13040
Expand Down
Loading