12
12
import ot
13
13
from ot .sliced import get_random_projections
14
14
from ot .backend import tf , torch
15
+ from contextlib import nullcontext
15
16
16
17
17
18
def test_get_random_projections ():
@@ -790,8 +791,6 @@ def test_min_pivot_sliced(nx):
790
791
ot .sliced .min_pivot_sliced (x [1 :, :], y , thetas = thetas )
791
792
792
793
793
- @pytest .skip_backend ("tf" ) # skips because of array assignment
794
- @pytest .skip_backend ("jax" )
795
794
def test_expected_sliced (nx ):
796
795
n = 10
797
796
n_proj = 10
@@ -805,28 +804,35 @@ def test_expected_sliced(nx):
805
804
thetas = ot .sliced .get_random_projections (d , n_proj , seed = 0 ).T
806
805
thetas_b = nx .from_numpy (thetas )
807
806
808
- expected_plan , expected_cost = ot .sliced .expected_sliced (x , y , thetas = thetas )
809
- expected_plan_b , expected_cost_b , _ = ot .sliced .expected_sliced (
810
- x_b , y_b , thetas = thetas_b , log = True
807
+ context = (
808
+ nullcontext ()
809
+ if str (nx ) not in ["tf" , "jax" ]
810
+ else pytest .raises (NotImplementedError )
811
811
)
812
812
813
- np .testing .assert_almost_equal (expected_plan , nx .to_numpy (expected_plan_b ))
814
- np .testing .assert_almost_equal (expected_cost , nx .to_numpy (expected_cost_b ))
813
+ with context :
814
+ expected_plan , expected_cost = ot .sliced .expected_sliced (x , y , thetas = thetas )
815
+ expected_plan_b , expected_cost_b , _ = ot .sliced .expected_sliced (
816
+ x_b , y_b , thetas = thetas_b , log = True
817
+ )
815
818
816
- # result should be a coarse upper-bound of W2
817
- w2 = ot .emd2 (ot .unif (n ), ot .unif (n ), ot .dist (x , y ))
818
- assert expected_cost >= w2
819
- assert expected_cost <= 3 * w2
819
+ np .testing .assert_almost_equal (expected_plan , nx .to_numpy (expected_plan_b ))
820
+ np .testing .assert_almost_equal (expected_cost , nx .to_numpy (expected_cost_b ))
820
821
821
- # test without provided thetas
822
- ot .sliced .expected_sliced (x , y , n_proj = n_proj , log = True )
822
+ # result should be a coarse upper-bound of W2
823
+ w2 = ot .emd2 (ot .unif (n ), ot .unif (n ), ot .dist (x , y ))
824
+ assert expected_cost >= w2
825
+ assert expected_cost <= 3 * w2
823
826
824
- # test with invalid shapes
825
- with pytest .raises (AssertionError ):
826
- ot .sliced .min_pivot_sliced (x [1 :, :], y , thetas = thetas )
827
+ # test without provided thetas
828
+ ot .sliced .expected_sliced (x , y , n_proj = n_proj , log = True )
829
+
830
+ # test with invalid shapes
831
+ with pytest .raises (AssertionError ):
832
+ ot .sliced .min_pivot_sliced (x [1 :, :], y , thetas = thetas )
827
833
828
- # with a small temperature (i.e. large beta),
829
- # the cost should be close to min_pivot
830
- _ , expected_cost = ot .sliced .expected_sliced (x , y , thetas = thetas , beta = 100.0 )
831
- _ , min_cost = ot .sliced .min_pivot_sliced (x , y , thetas = thetas )
832
- np .testing .assert_almost_equal (expected_cost , min_cost , decimal = 3 )
834
+ # with a small temperature (i.e. large beta), the cost should be close
835
+ # to min_pivot
836
+ _ , expected_cost = ot .sliced .expected_sliced (x , y , thetas = thetas , beta = 100.0 )
837
+ _ , min_cost = ot .sliced .min_pivot_sliced (x , y , thetas = thetas )
838
+ np .testing .assert_almost_equal (expected_cost , min_cost , decimal = 3 )
0 commit comments