Skip to content

Commit 08c5348

Browse files
committed
raise NotImplementedError when expected_sliced is used with tf or jax
1 parent 7c9761f commit 08c5348

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

ot/sliced.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,9 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0
853853
.. note::
854854
The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases.
855855
856+
.. warning::
857+
The function runs on backend but tensorflow and jax are not supported due to array assignment.
858+
856859
Parameters
857860
----------
858861
X : torch.Tensor
@@ -888,8 +891,15 @@ def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0
888891
assert (
889892
X.shape == Y.shape
890893
), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape"
894+
891895
nx = get_backend(X, Y)
896+
if str(nx) in ["tf", "jax"]:
897+
raise NotImplementedError(
898+
f"expected_sliced is not implemented for the {str(nx)} backend due"
899+
"to array assignment."
900+
)
892901
n = X.shape[0]
902+
893903
log_dict = {}
894904
if log:
895905
perm, log_dict = sliced_permutations(

test/test_sliced.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import ot
1313
from ot.sliced import get_random_projections
1414
from ot.backend import tf, torch
15+
from contextlib import nullcontext
1516

1617

1718
def test_get_random_projections():
@@ -790,8 +791,6 @@ def test_min_pivot_sliced(nx):
790791
ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas)
791792

792793

793-
@pytest.skip_backend("tf") # skips because of array assignment
794-
@pytest.skip_backend("jax")
795794
def test_expected_sliced(nx):
796795
n = 10
797796
n_proj = 10
@@ -805,28 +804,35 @@ def test_expected_sliced(nx):
805804
thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T
806805
thetas_b = nx.from_numpy(thetas)
807806

808-
expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas)
809-
expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced(
810-
x_b, y_b, thetas=thetas_b, log=True
807+
context = (
808+
nullcontext()
809+
if str(nx) not in ["tf", "jax"]
810+
else pytest.raises(NotImplementedError)
811811
)
812812

813-
np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b))
814-
np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b))
813+
with context:
814+
expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas)
815+
expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced(
816+
x_b, y_b, thetas=thetas_b, log=True
817+
)
815818

816-
# result should be a coarse upper-bound of W2
817-
w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y))
818-
assert expected_cost >= w2
819-
assert expected_cost <= 3 * w2
819+
np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b))
820+
np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b))
820821

821-
# test without provided thetas
822-
ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True)
822+
# result should be a coarse upper-bound of W2
823+
w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y))
824+
assert expected_cost >= w2
825+
assert expected_cost <= 3 * w2
823826

824-
# test with invalid shapes
825-
with pytest.raises(AssertionError):
826-
ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas)
827+
# test without provided thetas
828+
ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True)
829+
830+
# test with invalid shapes
831+
with pytest.raises(AssertionError):
832+
ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas)
827833

828-
# with a small temperature (i.e. large beta),
829-
# the cost should be close to min_pivot
830-
_, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0)
831-
_, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas)
832-
np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)
834+
# with a small temperature (i.e. large beta), the cost should be close
835+
# to min_pivot
836+
_, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0)
837+
_, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas)
838+
np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)

0 commit comments

Comments
 (0)