Incorrect computation of cost_correction matrix in ot.da.EMDTransport #664
Describe the bug
It seems that the `cost_correction` matrix is computed incorrectly. This is the current code that can be found here:

The issues:
- `label_match` is meant to be False if `ys[i] != yt[i]`. However, if `ys[i] != yt[i]`, then `(ys[:, None] - yt[None, :]) != 0` is True, hence `label_match` is True although the labels do not match (the naming is confusing in this case). Therefore, either the variable should be renamed to `label_mismatch` and the comment fixed, OR it should be defined as `label_match = (ys[:, None] - yt[None, :]) == 0` with its value flipped in `cost_correction`, i.e. `cost_correction = (1 - label_match) * ...`.
- `cost_correction = label_match * missing_labels * self.limit_max` applies a cost correction only if `missing_labels` is True. However, it must not correct if `missing_labels` is True; hence, we need to flip it to `... * (1 - missing_labels) * ...`.
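To make the two points above concrete, here is a minimal NumPy sketch on toy data. It is an illustration only, not the library's code: the label arrays, the `limit_max` value, and the way `missing_labels` is built (assuming `-1` marks an unlabeled target sample) are assumptions made for this example.

```python
import numpy as np

# Hypothetical toy labels. Assumption: -1 marks a missing target label.
ys = np.array([1, 1, 2])
yt = np.array([1, 2, -1])
limit_max = 10.0  # stand-in for self.limit_max

# True where source and target labels differ (so "mismatch" is the honest name).
label_mismatch = (ys[:, None] - yt[None, :]) != 0

# Assumed construction: True in every row of a column whose target label is missing.
missing_labels = np.broadcast_to(yt[None, :] == -1, label_mismatch.shape)

# Expression as described in the issue: penalises only cells with a missing label.
current = label_mismatch * missing_labels * limit_max

# Flipped version: penalises only cells where both labels are known and differ.
proposed = label_mismatch * (1 - missing_labels) * limit_max

print(current)
print(proposed)
```

With these toy labels, the current expression puts the penalty only on the column whose target label is missing, while the flipped version penalises exactly the labelled pairs whose classes differ and leaves the unlabeled column untouched.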
Therefore, I'd propose the following change
Happy to send the corresponding PR if you agree.
Screenshots
The following screenshots show the effect of flipping the `missing_labels` value. Here we map samples across multiple Gaussian distributions with 2 labels (p = 1 and p = 2). All labels are given. Without the fix, the transport plans are not computed correctly. With the fix, only samples from the same target class are linked.

Environment (please complete the following information):
Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)