diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 1a4e790f..7acf4fab 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -80,7 +80,7 @@ def _empty_like(array): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros(array.shape, array.dtype) + return self._array_context.np.zeros(array.shape, array.dtype) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 337a70ea..8c8e73de 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -535,7 +535,6 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: pt_prg = pt.generate_loopy(transformed_dag, options=opts, - cl_device=self.queue.device, function_name=function_name, target=self.get_target() ).bind_to_context(self.context) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 90449f0a..79328c15 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): return pytato_program, name_in_program_to_tags, name_in_program_to_axes -def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): +def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, + fn_name=""): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -550,16 +551,13 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): - # got an array expression => evaluate it - from warnings import warn - warn(f"Argument array '{arg_id}' to a compiled function is " - "unevaluated. Evaluating just-in-time, at " - "considerable expense. This is deprecated and will stop " - "working in 2023. To avoid this warning, force evaluation " - "of all arguments via freeze/thaw.", - DeprecationWarning, stacklevel=4) - - arg = actx.freeze(arg) + # got an array expression => abort + raise ValueError( + f"Argument '{arg_id}' to the '{fn_name}' compiled function is a" + " pytato array expression. Evaluating it just-in-time" + " potentially causes a significant overhead on each call to the" + " function and is therefore unsupported. " + ) else: raise NotImplementedError(type(arg)) @@ -567,15 +565,6 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return input_kwargs_for_loopy - -def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): - from warnings import warn - warn("_args_to_cl_buffer has been renamed to" - " _args_to_device_buffers. This will be" - " an error in 2023.", DeprecationWarning, stacklevel=2) - return _args_to_device_buffers(actx, input_id_to_name_in_program, - arg_id_to_arg) - # }}} @@ -631,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.loopy.BoundPyOpenCLExecutable input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -642,8 +631,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + fn_name = self.pytato_program.program.entrypoint + input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -674,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.loopy.BoundPyOpenCLExecutable input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] @@ -684,8 +675,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array + fn_name = self.pytato_program.program.entrypoint + input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -723,7 +716,7 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.python.BoundJAXPythonProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -731,8 +724,10 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: + fn_name = self.pytato_program.entrypoint + input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) out_dict = self.pytato_program(**input_kwargs_for_loopy) @@ -754,15 +749,17 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.python.BoundJAXPythonProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: + fn_name = self.pytato_program.entrypoint + input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) _evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index a9158fda..5b864e6c 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -93,7 +93,7 @@ def zeros(self, shape, dtype): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros( + return self._array_context.np.zeros( array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) return self._array_context._rec_map_container( diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 11ccbb1f..904b8ad9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1143,7 +1143,6 @@ def test_actx_compile_kwargs(actx_factory): def test_actx_compile_with_tuple_output_keys(actx_factory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. - from arraycontext import from_numpy, to_numpy actx = actx_factory() rng = np.random.default_rng() @@ -1157,11 +1156,11 @@ def my_rhs(scale, vel): v_x = rng.uniform(size=10) v_y = rng.uniform(size=10) - vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) scaled_speed = compiled_rhs(3.14, vel=vel) - result = to_numpy(scaled_speed, actx)[0, 0] + result = actx.to_numpy(scaled_speed)[0, 0] np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) @@ -1287,6 +1286,8 @@ class ArrayContainerWithNumpy: u: np.ndarray v: DOFArray + __array_ufunc__ = None + def test_array_container_with_numpy(actx_factory): actx = actx_factory() @@ -1405,14 +1406,16 @@ def test_compile_anonymous_function(actx_factory): # See https://github.com/inducer/grudge/issues/287 actx = actx_factory() + + ones = actx.thaw(actx.freeze( + actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1 + )) + f = actx.compile(lambda x: 2*x+40) - np.testing.assert_allclose( - actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), - 42) + np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) + f = actx.compile(partial(lambda x: 2*x+40)) - np.testing.assert_allclose( - actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), - 42) + np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) @pytest.mark.parametrize( diff --git a/test/testlib.py b/test/testlib.py index 3f085207..da33deae 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -160,7 +160,7 @@ def array_context(self): @with_container_arithmetic( bcasts_across_obj_array=False, - bcast_container_types=(DOFArray, np.ndarray), + container_types_bcast_across=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -173,6 +173,8 @@ class MyContainerDOFBcast: momentum: np.ndarray enthalpy: DOFArray | np.ndarray + __array_ufunc__ = None + @property def array_context(self): if isinstance(self.mass, np.ndarray): @@ -209,6 +211,8 @@ class Velocity2D: v: ArrayContainer array_context: ArrayContext + __array_ufunc__ = None + @with_array_context.register(Velocity2D) # https://github.com/python/mypy/issues/13040