Wasserstein Circle distance doesn't seem correct? #738

Open
ckp95 opened this issue May 20, 2025 · 4 comments
Comments

ckp95 commented May 20, 2025

Describe the bug

Forgive me if I've misunderstood what the wasserstein_circle function is supposed to do. But I would have thought that if we have

d1 = wasserstein_circle(arr1, arr2)

and then add the same amount to both arrays (i.e. rotate both samples by the same angle around the circle), we should get the same answer:

d2 = wasserstein_circle(arr1 + delta, arr2 + delta)
assert d1 == d2

But the first example I tried fails.

To Reproduce

import numpy as np
import ot

sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])

d1 = ot.wasserstein_circle(sample1, sample2)

delta = 0.02

d2 = ot.wasserstein_circle(sample1 + delta, sample2 + delta)

assert d1 == d2 # fails

Expected behavior

wasserstein_circle should be rotationally symmetric, i.e. it should obey the property

d1 = wasserstein_circle(arr1, arr2)
d2 = wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
assert d1 == d2

This should hold for all real delta (up to floating-point inaccuracies), because rotating both samples together amounts to just turning your head to the side.

Or am I just misunderstanding how the input is supposed to be represented here?
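
For what it's worth, the invariance can be checked independently of POT with a brute-force reference. This is only a sketch under assumptions: it assumes the circle is parameterised by [0, 1) (as the `% 1` above implies), p=1, and two uniformly weighted samples of equal size, in which case an optimal plan is a one-to-one matching. The helper names `circle_dist` and `brute_force_w1_circle` are made up for illustration and are not part of POT.

```python
from itertools import permutations

import numpy as np


def circle_dist(x, y):
    """Geodesic distance between points on a circle of circumference 1."""
    d = np.abs(x - y) % 1.0
    return np.minimum(d, 1.0 - d)


def brute_force_w1_circle(u, v):
    """Reference W1 for two equal-size, uniformly weighted samples.

    Tries every one-to-one matching, so it is only usable for tiny inputs.
    """
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    if u.ndim != 1 or u.shape != v.shape:
        raise ValueError("expected two 1-D samples of equal size")
    cost = circle_dist(u[:, None], v[None, :])  # pairwise cost matrix
    rows = np.arange(len(u))
    # Average cost of each matching; the smallest one is the W1 distance.
    return min(cost[rows, list(perm)].mean() for perm in permutations(range(len(v))))


sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])
delta = 0.02

d1 = brute_force_w1_circle(sample1, sample2)
d2 = brute_force_w1_circle((sample1 + delta) % 1, (sample2 + delta) % 1)
print(np.isclose(d1, d2))
```

Because the cost depends only on the difference of the two positions modulo 1, shifting both samples by the same delta leaves every entry of the cost matrix unchanged up to rounding, which is the property I expected from wasserstein_circle.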

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): NixOS 24.11
  • Python version: 3.12.7
  • How was POT installed (source, pip, conda): Nix

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-6.6.67-x86_64-with-glibc2.40
Python 3.12.7 (main, Oct  1 2024, 02:05:46) [GCC 13.3.0]
NumPy 1.26.4
SciPy 1.14.1
POT 0.9.4

Contributor

clbonet commented May 20, 2025

You are right, wasserstein_circle should be rotationally symmetric. The problem seems to come from the closed form we try to approximate in the case p=1, which is the default for wasserstein_circle and does not always work well. If you try the function binary_search_circle instead, the result should be rotationally symmetric.

I just proposed in #PR736 to change the default behaviour of wasserstein_circle so that it always uses binary_search_circle, which seems to work better in general.

Author

ckp95 commented May 20, 2025

Okay, good to know I'm not going crazy 😅

Have you considered using hypothesis to property-test this library? It works particularly well for mathematical code with lots of clean invariants / symmetries. This property-test immediately finds a minimal counterexample:

import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import ot


@given(
    arr1=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    arr2=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    delta=st.floats(min_value=-1, max_value=1)
)
def test_wasserstein_circle_is_rotationally_symmetric(arr1, arr2, delta):
    arr1 = np.array(arr1)
    arr2 = np.array(arr2)

    [d1] = ot.wasserstein_circle(arr1, arr2)
    [d2] = ot.wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)

    assert np.isclose(d1, d2, atol=1e-3)
E       Falsifying example: test_wasserstein_circle_is_rotationally_symmetric(
E           arr1=[0.0],
E           arr2=[0.125],
E           delta=0.5,
E       )

The binary_search_circle function passes this test 🙂

Author

ckp95 commented May 20, 2025

"The binary_search_circle function passes this test 🙂"

Actually not quite; it triggers a warning:

E               RuntimeWarning: invalid value encountered in divide
E               Falsifying example: test_binary_search_circle_is_rotationally_symmetric(
E                   arr1=[0.5],
E                   arr2=[0.25],
E                   delta=0.5,
E               )

Standalone example:

import numpy as np
import ot
import warnings
warnings.filterwarnings("error")

arr1 = np.array([0.5])
arr2 = np.array([0.25])
delta = 0.5

[d1] = ot.binary_search_circle(arr1, arr2)
[d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)

assert np.isclose(d1, d2, atol=1e-3)
---------------------------------------------------------------------------
RuntimeWarning                            Traceback (most recent call last)
Cell In[16], line 9
      6 delta = 0.5
      8 [d1] = ot.binary_search_circle(arr1, arr2)
----> 9 [d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
     11 assert np.isclose(d1, d2, atol=1e-3)

File /nix/store/nlrk6vd18px7gg88yjaxgn0l8dalpbcg-python3-3.12.7-env/lib/python3.12/site-packages/ot/lp/solver_1d.py:737, in binary_search_circle(u_values, v_values, u_weights, v_weights, p, Lm, Lp, tm, tp, eps, require_sort, log)
    734     Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
    736     mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
--> 737     tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
    738     done[nx.prod(mask, axis=-1) > 0] = 1
    739 elif nx.any(1 - done):

RuntimeWarning: invalid value encountered in divide

Not sure if this is a real bug but thought it was worth recording.

Contributor

clbonet commented May 20, 2025

Sorry, my bad ^^. I shouldn't have used this implementation for the default behaviour...

I didn't know about the hypothesis library. It seems very useful! I am reassured that binary_search_circle seems to pass the test...

Concerning binary_search_circle, thank you for noticing this little bug. I think the warning arises when dCptm - dCmtp == 0, but it is not a problem, because the corresponding values are not updated thanks to the condition mask_end > 0. However, it would be nice to avoid raising a warning in this case. I will see if I can fix this.
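
One possible way to avoid the warning, sketched here in plain NumPy only: perform the division just where the denominator passes the threshold, instead of dividing everywhere and masking afterwards. This is an illustration, not the actual fix; the POT code goes through the backend `nx`, which may need a different approach, and `masked_divide` is a hypothetical helper name. The 0.001 tolerance mirrors the threshold used for mask_end in the traceback.

```python
import warnings

import numpy as np


def masked_divide(num, den, tol=1e-3):
    """Divide only where |den| > tol; the other entries are left at 0."""
    num = np.asarray(num, dtype=float)
    den = np.asarray(den, dtype=float)
    keep = np.abs(den) > tol
    out = np.zeros(np.broadcast(num, den).shape)
    # `where=` skips the masked-out entries entirely, so 0/0 is never computed.
    np.divide(num, den, out=out, where=keep)
    return out, keep


with warnings.catch_warnings():
    warnings.simplefilter("error")  # any RuntimeWarning would raise here
    ratio, keep = masked_divide([1.0, 0.0, 3.0], [2.0, 0.0, 0.0])
```

The returned `keep` mask plays the role of mask_end > 0: only those positions carry a meaningful quotient, so an update such as tc[keep] = ratio[keep] would behave as before without emitting the warning.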
