@@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
529
529
return pytato_program , name_in_program_to_tags , name_in_program_to_axes
530
530
531
531
532
- def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
532
+ def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ,
533
+ fn_name = "<unknown>" ):
533
534
input_kwargs_for_loopy = {}
534
535
535
536
for arg_id , arg in arg_id_to_arg .items ():
@@ -550,32 +551,20 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
550
551
# got a frozen array => do nothing
551
552
pass
552
553
elif isinstance (arg , pt .Array ):
553
- # got an array expression => evaluate it
554
- from warnings import warn
555
- warn (f"Argument array '{ arg_id } ' to a compiled function is "
556
- "unevaluated. Evaluating just-in-time, at "
557
- "considerable expense. This is deprecated and will stop "
558
- "working in 2023. To avoid this warning, force evaluation "
559
- "of all arguments via freeze/thaw." ,
560
- DeprecationWarning , stacklevel = 4 )
561
-
562
- arg = actx .freeze (arg )
554
+ # got an array expression => abort
555
+ raise ValueError (
556
+ f"Argument '{ arg_id } ' to the '{ fn_name } ' compiled function is a"
557
+ " pytato array expression. Evaluating it just-in-time"
558
+ " potentially causes a significant overhead on each call to the"
559
+ " function and is therefore unsupported. "
560
+ )
563
561
else :
564
562
raise NotImplementedError (type (arg ))
565
563
566
564
input_kwargs_for_loopy [input_id_to_name_in_program [arg_id ]] = arg
567
565
568
566
return input_kwargs_for_loopy
569
567
570
-
571
- def _args_to_cl_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
572
- from warnings import warn
573
- warn ("_args_to_cl_buffer has been renamed to"
574
- " _args_to_device_buffers. This will be"
575
- " an error in 2023." , DeprecationWarning , stacklevel = 2 )
576
- return _args_to_device_buffers (actx , input_id_to_name_in_program ,
577
- arg_id_to_arg )
578
-
579
568
# }}}
580
569
581
570
@@ -631,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
631
620
type of the callable.
632
621
"""
633
622
actx : PytatoPyOpenCLArrayContext
634
- pytato_program : pt .target .BoundProgram
623
+ pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
635
624
input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
636
625
output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
637
626
name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
@@ -642,8 +631,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
642
631
from .utils import get_cl_axes_from_pt_axes
643
632
from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
644
633
634
+ fn_name = self .pytato_program .program .entrypoint
635
+
645
636
input_kwargs_for_loopy = _args_to_device_buffers (
646
- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
637
+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
647
638
648
639
evt , out_dict = self .pytato_program (queue = self .actx .queue ,
649
640
allocator = self .actx .allocator ,
@@ -674,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
674
665
Name of the output array in the program.
675
666
"""
676
667
actx : PytatoPyOpenCLArrayContext
677
- pytato_program : pt .target .BoundProgram
668
+ pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
678
669
input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
679
670
output_tags : frozenset [Tag ]
680
671
output_axes : tuple [pt .Axis , ...]
@@ -684,8 +675,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
684
675
from .utils import get_cl_axes_from_pt_axes
685
676
from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
686
677
678
+ fn_name = self .pytato_program .program .entrypoint
679
+
687
680
input_kwargs_for_loopy = _args_to_device_buffers (
688
- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
681
+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
689
682
690
683
evt , out_dict = self .pytato_program (queue = self .actx .queue ,
691
684
allocator = self .actx .allocator ,
@@ -723,16 +716,18 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
723
716
type of the callable.
724
717
"""
725
718
actx : PytatoJAXArrayContext
726
- pytato_program : pt .target .BoundProgram
719
+ pytato_program : pt .target .python . BoundJAXPythonProgram
727
720
input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
728
721
output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
729
722
name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
730
723
name_in_program_to_axes : Mapping [str , tuple [pt .Axis , ...]]
731
724
output_template : ArrayContainer
732
725
733
726
def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
727
+ fn_name = self .pytato_program .entrypoint
728
+
734
729
input_kwargs_for_loopy = _args_to_device_buffers (
735
- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
730
+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
736
731
737
732
out_dict = self .pytato_program (** input_kwargs_for_loopy )
738
733
@@ -754,15 +749,17 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
754
749
Name of the output array in the program.
755
750
"""
756
751
actx : PytatoJAXArrayContext
757
- pytato_program : pt .target .BoundProgram
752
+ pytato_program : pt .target .python . BoundJAXPythonProgram
758
753
input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
759
754
output_tags : frozenset [Tag ]
760
755
output_axes : tuple [pt .Axis , ...]
761
756
output_name : str
762
757
763
758
def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
759
+ fn_name = self .pytato_program .entrypoint
760
+
764
761
input_kwargs_for_loopy = _args_to_device_buffers (
765
- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
762
+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
766
763
767
764
_evt , out_dict = self .pytato_program (** input_kwargs_for_loopy )
768
765
0 commit comments