@@ -82,6 +82,19 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
8282 ):
8383 yield array_namespace , device , dtype
8484 yield array_namespace , "mps" , "float32"
85+
86+ elif array_namespace == "array_api_strict" :
87+ try :
88+ import array_api_strict # noqa
89+
90+ yield array_namespace , array_api_strict .Device ("CPU_DEVICE" ), "float64"
91+ yield array_namespace , array_api_strict .Device ("device1" ), "float32"
92+ except ImportError :
93+ # Those combinations will typically be skipped by pytest if
94+ # array_api_strict is not installed but we still need to see them in
95+ # the test output.
96+ yield array_namespace , "CPU_DEVICE" , "float64"
97+ yield array_namespace , "device1" , "float32"
8598 else :
8699 yield array_namespace , None , None
87100
@@ -582,12 +595,14 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
582595 if namespace .__name__ == "array_api_strict" and hasattr (
583596 namespace , "set_array_api_strict_flags"
584597 ):
585- namespace .set_array_api_strict_flags (api_version = "2023 .12" )
598+ namespace .set_array_api_strict_flags (api_version = "2024 .12" )
586599
587600 return namespace , is_array_api_compliant
588601
589602
590- def get_namespace_and_device (* array_list , remove_none = True , remove_types = (str ,)):
603+ def get_namespace_and_device (
604+ * array_list , remove_none = True , remove_types = (str ,), xp = None
605+ ):
591606 """Combination into one single function of `get_namespace` and `device`.
592607
593608 Parameters
@@ -598,6 +613,10 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,))
598613 Whether to ignore None objects passed in arrays.
599614 remove_types : tuple or list, default=(str,)
600615 Types to ignore in the arrays.
616+ xp : module, default=None
617+ Precomputed array namespace module. When passed, typically from a caller
618+ that has already performed inspection of its own inputs, skips array
619+ namespace inspection.
601620
602621 Returns
603622 -------
@@ -610,16 +629,20 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,))
610629 device : device
611630 `device` object (see the "Device Support" section of the array API spec).
612631 """
632+ skip_remove_kwargs = dict (remove_none = False , remove_types = [])
633+
613634 array_list = _remove_non_arrays (
614635 * array_list ,
615636 remove_none = remove_none ,
616637 remove_types = remove_types ,
617638 )
639+ arrays_device = device (* array_list , ** skip_remove_kwargs )
618640
619- skip_remove_kwargs = dict (remove_none = False , remove_types = [])
641+ if xp is None :
642+ xp , is_array_api = get_namespace (* array_list , ** skip_remove_kwargs )
643+ else :
644+ xp , is_array_api = xp , True
620645
621- xp , is_array_api = get_namespace (* array_list , ** skip_remove_kwargs )
622- arrays_device = device (* array_list , ** skip_remove_kwargs )
623646 if is_array_api :
624647 return xp , is_array_api , arrays_device
625648 else :
@@ -769,49 +792,66 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
769792 return sum_ / scale
770793
771794
795+ def _xlogy (x , y , xp = None ):
796+ # TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed
797+ xp , _ , device_ = get_namespace_and_device (x , y , xp = xp )
798+
799+ with numpy .errstate (divide = "ignore" , invalid = "ignore" ):
800+ temp = x * xp .log (y )
801+ return xp .where (x == 0.0 , xp .asarray (0.0 , dtype = temp .dtype , device = device_ ), temp )
802+
803+
772804def _nanmin (X , axis = None , xp = None ):
773805 # TODO: refactor once nan-aware reductions are standardized:
774806 # https://github.com/data-apis/array-api/issues/621
775- xp , _ = get_namespace (X , xp = xp )
807+ xp , _ , device_ = get_namespace_and_device (X , xp = xp )
776808 if _is_numpy_namespace (xp ):
777809 return xp .asarray (numpy .nanmin (X , axis = axis ))
778810
779811 else :
780812 mask = xp .isnan (X )
781- X = xp .min (xp .where (mask , xp .asarray (+ xp .inf , device = device (X )), X ), axis = axis )
813+ X = xp .min (
814+ xp .where (mask , xp .asarray (+ xp .inf , dtype = X .dtype , device = device_ ), X ),
815+ axis = axis ,
816+ )
782817 # Replace Infs from all NaN slices with NaN again
783818 mask = xp .all (mask , axis = axis )
784819 if xp .any (mask ):
785- X = xp .where (mask , xp .asarray (xp .nan ), X )
820+ X = xp .where (mask , xp .asarray (xp .nan , dtype = X . dtype , device = device_ ), X )
786821 return X
787822
788823
789824def _nanmax (X , axis = None , xp = None ):
790825 # TODO: refactor once nan-aware reductions are standardized:
791826 # https://github.com/data-apis/array-api/issues/621
792- xp , _ = get_namespace (X , xp = xp )
827+ xp , _ , device_ = get_namespace_and_device (X , xp = xp )
793828 if _is_numpy_namespace (xp ):
794829 return xp .asarray (numpy .nanmax (X , axis = axis ))
795830
796831 else :
797832 mask = xp .isnan (X )
798- X = xp .max (xp .where (mask , xp .asarray (- xp .inf , device = device (X )), X ), axis = axis )
833+ X = xp .max (
834+ xp .where (mask , xp .asarray (- xp .inf , dtype = X .dtype , device = device_ ), X ),
835+ axis = axis ,
836+ )
799837 # Replace Infs from all NaN slices with NaN again
800838 mask = xp .all (mask , axis = axis )
801839 if xp .any (mask ):
802- X = xp .where (mask , xp .asarray (xp .nan ), X )
840+ X = xp .where (mask , xp .asarray (xp .nan , dtype = X . dtype , device = device_ ), X )
803841 return X
804842
805843
806844def _nanmean (X , axis = None , xp = None ):
807845 # TODO: refactor once nan-aware reductions are standardized:
808846 # https://github.com/data-apis/array-api/issues/621
809- xp , _ = get_namespace (X , xp = xp )
847+ xp , _ , device_ = get_namespace_and_device (X , xp = xp )
810848 if _is_numpy_namespace (xp ):
811849 return xp .asarray (numpy .nanmean (X , axis = axis ))
812850 else :
813851 mask = xp .isnan (X )
814- total = xp .sum (xp .where (mask , xp .asarray (0.0 , device = device (X )), X ), axis = axis )
852+ total = xp .sum (
853+ xp .where (mask , xp .asarray (0.0 , dtype = X .dtype , device = device_ ), X ), axis = axis
854+ )
815855 count = xp .sum (xp .astype (xp .logical_not (mask ), X .dtype ), axis = axis )
816856 return total / count
817857
@@ -868,6 +908,8 @@ def _convert_to_numpy(array, xp):
868908 return array .cpu ().numpy ()
869909 elif xp_name in {"array_api_compat.cupy" , "cupy" }: # pragma: nocover
870910 return array .get ()
911+ elif xp_name in {"array_api_strict" }:
912+ return numpy .asarray (xp .asarray (array , device = xp .Device ("CPU_DEVICE" )))
871913
872914 return numpy .asarray (array )
873915
0 commit comments