diff --git a/arraycontext/context.py b/arraycontext/context.py index 210a9b89..3bfcf8b4 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -557,7 +557,8 @@ def clone(self) -> Self: "setup-only" array context "leaks" into the application. """ - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: """Compiles *f* for repeated use on this array context. *f* is expected to be a `pure function `__ performing an array computation. @@ -573,6 +574,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: it may be called only once (or a few times). :arg f: the function executing the computation. + :arg single_version_only: If *True*, raise an error if *f* is compiled + more than once (due to different input argument types). :return: a function with the same signature as *f*. """ return f diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8d..74ef2ca3 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -630,9 +630,11 @@ def call_loopy(self, program, **kwargs): return call_loopy(program, processed_kwargs, entrypoint) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller - return LazilyPyOpenCLCompilingFunctionCaller(self, f) + return LazilyPyOpenCLCompilingFunctionCaller(self, + f, single_version_only) def transform_dag(self, dag: pytato.DictOfNamedArrays ) -> pytato.DictOfNamedArrays: @@ -844,9 +846,10 @@ def _thaw(ary): self._rec_map_container(_thaw, array, self._frozen_array_types), actx=self) - def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: + def compile(self, f: Callable[..., Any], + single_version_only: bool = False) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller - return LazilyJAXCompilingFunctionCaller(self, f) + return LazilyJAXCompilingFunctionCaller(self, f, single_version_only) def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c1091..35e20039 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -263,6 +263,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] + single_version_only: bool program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor], CompiledFunction] = field(default_factory=lambda: {}) @@ -322,15 +323,32 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. """ - arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( - args, kwargs) - - try: - compiled_f = self.program_cache[arg_id_to_descr] - except KeyError: - pass + if not self.single_version_only: + arg_id_to_arg, arg_id_to_descr = \ + _get_arg_id_to_arg_and_arg_id_to_descr(args, kwargs) + + try: + compiled_f = self.program_cache[arg_id_to_descr] + except KeyError: + pass + else: + return compiled_f(arg_id_to_arg) else: - return compiled_f(arg_id_to_arg) + assert len(self.program_cache) <= 1 + + try: + arg_id_to_descr, compiled_f = next(iter(self.program_cache.items())) + except StopIteration: + arg_id_to_arg, arg_id_to_descr = \ + _get_arg_id_to_arg_and_arg_id_to_descr(args, kwargs) + else: + if __debug__: + current_arg_id_to_arg, current_arg_id_to_descr = \ + _get_arg_id_to_arg_and_arg_id_to_descr(args, kwargs) + assert arg_id_to_descr == current_arg_id_to_descr + assert self.arg_id_to_arg == current_arg_id_to_arg # pylint: disable=access-member-before-definition + + return compiled_f(self.arg_id_to_arg) # pylint: disable=access-member-before-definition dict_of_named_arrays = {} output_id_to_name_in_program = {} @@ -373,6 +391,9 @@ def _as_dict_of_named_arrays(keys, ary): output_id_to_name_in_program=output_id_to_name_in_program, output_template=output_template) + if self.single_version_only: + self.arg_id_to_arg: Mapping[tuple[Hashable, ...], Any] = arg_id_to_arg + self.program_cache[arg_id_to_descr] = compiled_func return compiled_func(arg_id_to_arg)