Skip to content

Commit 4aeaed4

Browse files
upgrade 'unevaluated array as argument' warning to error (#305)
* upgrade unevaluated array as argument warning logger.warning does not deduplicate, in contrast to warnings.warn * remove deprecated function * also print kernel name * fix func names * use warnings.warn instead * make it an error * fix test * fix a few deprecation-related warnings * moar warnings * avoid buggy ruff version astral-sh/ruff#16874 * make fn_name optional * Revert "avoid buggy ruff version" This reverts commit 392c62c. * fix types
1 parent 21a89f1 commit 4aeaed4

File tree

6 files changed

+44
-41
lines changed

6 files changed

+44
-41
lines changed

arraycontext/impl/jax/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _empty_like(array):
8080

8181
def zeros_like(self, ary):
8282
def _zeros_like(array):
83-
return self._array_context.zeros(array.shape, array.dtype)
83+
return self._array_context.np.zeros(array.shape, array.dtype)
8484

8585
return self._array_context._rec_map_container(
8686
_zeros_like, ary, default_scalar=0)

arraycontext/impl/pytato/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,6 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
535535

536536
pt_prg = pt.generate_loopy(transformed_dag,
537537
options=opts,
538-
cl_device=self.queue.device,
539538
function_name=function_name,
540539
target=self.get_target()
541540
).bind_to_context(self.context)

arraycontext/impl/pytato/compile.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
529529
return pytato_program, name_in_program_to_tags, name_in_program_to_axes
530530

531531

532-
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
532+
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
533+
fn_name="<unknown>"):
533534
input_kwargs_for_loopy = {}
534535

535536
for arg_id, arg in arg_id_to_arg.items():
@@ -550,32 +551,20 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
550551
# got a frozen array => do nothing
551552
pass
552553
elif isinstance(arg, pt.Array):
553-
# got an array expression => evaluate it
554-
from warnings import warn
555-
warn(f"Argument array '{arg_id}' to a compiled function is "
556-
"unevaluated. Evaluating just-in-time, at "
557-
"considerable expense. This is deprecated and will stop "
558-
"working in 2023. To avoid this warning, force evaluation "
559-
"of all arguments via freeze/thaw.",
560-
DeprecationWarning, stacklevel=4)
561-
562-
arg = actx.freeze(arg)
554+
# got an array expression => abort
555+
raise ValueError(
556+
f"Argument '{arg_id}' to the '{fn_name}' compiled function is a"
557+
" pytato array expression. Evaluating it just-in-time"
558+
" potentially causes a significant overhead on each call to the"
559+
" function and is therefore unsupported. "
560+
)
563561
else:
564562
raise NotImplementedError(type(arg))
565563

566564
input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg
567565

568566
return input_kwargs_for_loopy
569567

570-
571-
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
572-
from warnings import warn
573-
warn("_args_to_cl_buffer has been renamed to"
574-
" _args_to_device_buffers. This will be"
575-
" an error in 2023.", DeprecationWarning, stacklevel=2)
576-
return _args_to_device_buffers(actx, input_id_to_name_in_program,
577-
arg_id_to_arg)
578-
579568
# }}}
580569

581570

@@ -631,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
631620
type of the callable.
632621
"""
633622
actx: PytatoPyOpenCLArrayContext
634-
pytato_program: pt.target.BoundProgram
623+
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
635624
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
636625
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
637626
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
@@ -642,8 +631,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
642631
from .utils import get_cl_axes_from_pt_axes
643632
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
644633

634+
fn_name = self.pytato_program.program.entrypoint
635+
645636
input_kwargs_for_loopy = _args_to_device_buffers(
646-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
637+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
647638

648639
evt, out_dict = self.pytato_program(queue=self.actx.queue,
649640
allocator=self.actx.allocator,
@@ -674,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
674665
Name of the output array in the program.
675666
"""
676667
actx: PytatoPyOpenCLArrayContext
677-
pytato_program: pt.target.BoundProgram
668+
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
678669
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
679670
output_tags: frozenset[Tag]
680671
output_axes: tuple[pt.Axis, ...]
@@ -684,8 +675,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
684675
from .utils import get_cl_axes_from_pt_axes
685676
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
686677

678+
fn_name = self.pytato_program.program.entrypoint
679+
687680
input_kwargs_for_loopy = _args_to_device_buffers(
688-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
681+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
689682

690683
evt, out_dict = self.pytato_program(queue=self.actx.queue,
691684
allocator=self.actx.allocator,
@@ -723,16 +716,18 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
723716
type of the callable.
724717
"""
725718
actx: PytatoJAXArrayContext
726-
pytato_program: pt.target.BoundProgram
719+
pytato_program: pt.target.python.BoundJAXPythonProgram
727720
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
728721
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
729722
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
730723
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
731724
output_template: ArrayContainer
732725

733726
def __call__(self, arg_id_to_arg) -> ArrayContainer:
727+
fn_name = self.pytato_program.entrypoint
728+
734729
input_kwargs_for_loopy = _args_to_device_buffers(
735-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
730+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
736731

737732
out_dict = self.pytato_program(**input_kwargs_for_loopy)
738733

@@ -754,15 +749,17 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
754749
Name of the output array in the program.
755750
"""
756751
actx: PytatoJAXArrayContext
757-
pytato_program: pt.target.BoundProgram
752+
pytato_program: pt.target.python.BoundJAXPythonProgram
758753
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
759754
output_tags: frozenset[Tag]
760755
output_axes: tuple[pt.Axis, ...]
761756
output_name: str
762757

763758
def __call__(self, arg_id_to_arg) -> ArrayContainer:
759+
fn_name = self.pytato_program.entrypoint
760+
764761
input_kwargs_for_loopy = _args_to_device_buffers(
765-
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
762+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
766763

767764
_evt, out_dict = self.pytato_program(**input_kwargs_for_loopy)
768765

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def zeros(self, shape, dtype):
9393

9494
def zeros_like(self, ary):
9595
def _zeros_like(array):
96-
return self._array_context.zeros(
96+
return self._array_context.np.zeros(
9797
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)
9898

9999
return self._array_context._rec_map_container(

test/test_arraycontext.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,6 @@ def test_actx_compile_kwargs(actx_factory):
11431143
def test_actx_compile_with_tuple_output_keys(actx_factory):
11441144
# arraycontext.git<=3c9aee68 would fail due to a bug in output
11451145
# key stringification logic.
1146-
from arraycontext import from_numpy, to_numpy
11471146
actx = actx_factory()
11481147
rng = np.random.default_rng()
11491148

@@ -1157,11 +1156,11 @@ def my_rhs(scale, vel):
11571156
v_x = rng.uniform(size=10)
11581157
v_y = rng.uniform(size=10)
11591158

1160-
vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)
1159+
vel = actx.from_numpy(Velocity2D(v_x, v_y, actx))
11611160

11621161
scaled_speed = compiled_rhs(3.14, vel=vel)
11631162

1164-
result = to_numpy(scaled_speed, actx)[0, 0]
1163+
result = actx.to_numpy(scaled_speed)[0, 0]
11651164
np.testing.assert_allclose(result.u, -3.14*v_y)
11661165
np.testing.assert_allclose(result.v, 3.14*v_x)
11671166

@@ -1287,6 +1286,8 @@ class ArrayContainerWithNumpy:
12871286
u: np.ndarray
12881287
v: DOFArray
12891288

1289+
__array_ufunc__ = None
1290+
12901291

12911292
def test_array_container_with_numpy(actx_factory):
12921293
actx = actx_factory()
@@ -1405,14 +1406,16 @@ def test_compile_anonymous_function(actx_factory):
14051406

14061407
# See https://github.com/inducer/grudge/issues/287
14071408
actx = actx_factory()
1409+
1410+
ones = actx.thaw(actx.freeze(
1411+
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
1412+
))
1413+
14081414
f = actx.compile(lambda x: 2*x+40)
1409-
np.testing.assert_allclose(
1410-
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
1411-
42)
1415+
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)
1416+
14121417
f = actx.compile(partial(lambda x: 2*x+40))
1413-
np.testing.assert_allclose(
1414-
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
1415-
42)
1418+
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)
14161419

14171420

14181421
@pytest.mark.parametrize(

test/testlib.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def array_context(self):
160160

161161
@with_container_arithmetic(
162162
bcasts_across_obj_array=False,
163-
bcast_container_types=(DOFArray, np.ndarray),
163+
container_types_bcast_across=(DOFArray, np.ndarray),
164164
matmul=True,
165165
rel_comparison=True,
166166
_cls_has_array_context_attr=True,
@@ -173,6 +173,8 @@ class MyContainerDOFBcast:
173173
momentum: np.ndarray
174174
enthalpy: DOFArray | np.ndarray
175175

176+
__array_ufunc__ = None
177+
176178
@property
177179
def array_context(self):
178180
if isinstance(self.mass, np.ndarray):
@@ -209,6 +211,8 @@ class Velocity2D:
209211
v: ArrayContainer
210212
array_context: ArrayContext
211213

214+
__array_ufunc__ = None
215+
212216

213217
@with_array_context.register(Velocity2D)
214218
# https://github.com/python/mypy/issues/13040

0 commit comments

Comments
 (0)