Description
Describe the bug
When calling ot.gromov.quantized_fused_gromov_wasserstein_partitioned(..., build_OT=True), the returned full transport plan T appears to be indexed in a partition-concatenated node order (nodes grouped by partition id), rather than the original node order of the input spaces (C1, p) and (C2, q).
As a result, T does not satisfy the marginal constraints with respect to the input marginals:
- T.sum(axis=1) does not match p
- T.sum(axis=0) does not match q
Instead, the marginals match permuted versions p_perm and q_perm obtained by concatenating node indices partition-by-partition (in increasing partition id order). If T is reindexed back to the original node ordering using the inverse permutations, the expected marginal constraints are restored.
The API/documentation describes T as a coupling “between the two spaces” when build_OT=True. I interpreted this as being indexed in the original node order; if the partition-concatenated ordering is intended, it would be helpful to document that explicitly and/or return the required permutations (or provide a helper) to map T back to the original indexing.
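The kind of helper suggested above could look like the following minimal sketch. It uses NumPy only and assumes the ordering hypothesized in this report: nodes grouped by increasing partition id, keeping their original relative order inside each partition. The names partition_order and reindex_plan_to_original are hypothetical and not part of POT; the toy coupling is a product coupling built only to simulate the observed behavior.

```python
import numpy as np


def partition_order(part):
    """Node indices grouped partition by partition, in increasing partition id."""
    part = np.asarray(part)
    # A stable sort keeps the original relative order of nodes inside each partition.
    return np.argsort(part, kind="stable")


def reindex_plan_to_original(T, part1, part2):
    """Hypothetical helper: undo an assumed partition-concatenated ordering of T."""
    perm1 = partition_order(part1)
    perm2 = partition_order(part2)
    T_orig = np.empty_like(T)
    # Row i / column j of T is assumed to correspond to original node perm1[i] / perm2[j].
    T_orig[np.ix_(perm1, perm2)] = T
    return T_orig


# Toy check: simulate the reordering on a known coupling, then undo it.
rng = np.random.default_rng(0)
part1 = np.array([2, 0, 1, 0, 2])
part2 = np.array([1, 1, 0, 2])
p = rng.random(5); p /= p.sum()
q = rng.random(4); q /= q.sum()

T_true = np.outer(p, q)  # valid coupling of (p, q) in original node order
T_perm = T_true[partition_order(part1)][:, partition_order(part2)]  # simulated output
T_back = reindex_plan_to_original(T_perm, part1, part2)

print("row-sum error before reindexing:", np.max(np.abs(T_perm.sum(axis=1) - p)))
print("row-sum error after reindexing: ", np.max(np.abs(T_back.sum(axis=1) - p)))
print("col-sum error after reindexing: ", np.max(np.abs(T_back.sum(axis=0) - q)))
```

Writing into T_orig with np.ix_ is equivalent to indexing T with the inverse permutations (np.argsort of each permutation), as done in the code sample in this report.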
To Reproduce
Steps to reproduce the behavior:
- Save the code sample below as a Python script (e.g. mre_qgw_partition_order.py). It uses two small 2D point clouds with non-uniform marginals to make the issue visible.
- Run it with Python (e.g. python mre_qgw_partition_order.py).
The script prints six diagnostics, in three pairs:
- max abs error (row-sum of T vs. p) and max abs error (col-sum of T vs. q) are non-zero.
- max abs error (row-sum of T vs. p_perm) and max abs error (col-sum of T vs. q_perm) are ~0.
- After reindexing (T_fix), the row-sum and col-sum errors against p and q are ~0 again.
Code sample
import numpy as np
import ot
rng = np.random.default_rng(0)
# generate small 2D point clouds
n1, n2 = 8, 9 # number of points in each point cloud
npart1, npart2 = 3, 4 # number of partitions in each point cloud
X1 = rng.normal(size=(n1, 2))
X2 = rng.normal(size=(n2, 2))
C1 = ot.dist(X1, X1, metric="sqeuclidean")
C2 = ot.dist(X2, X2, metric="sqeuclidean")
# make the marginals random (and not uniform) so that it's easy to see the effect of permutation
p = rng.random(n1); p /= p.sum()
q = rng.random(n2); q /= q.sum()
part1 = ot.gromov.get_graph_partition(C1, npart1, part_method="random", random_state=0)
rep1 = ot.gromov.get_graph_representants(C1, part1, rep_method="random", random_state=0)
CR1, list_R1, list_p1 = ot.gromov.format_partitioned_graph(C1, p, part1, rep1)
part2 = ot.gromov.get_graph_partition(C2, npart2, part_method="random", random_state=1)
rep2 = ot.gromov.get_graph_representants(C2, part2, rep_method="random", random_state=1)
CR2, list_R2, list_p2 = ot.gromov.format_partitioned_graph(C2, q, part2, rep2)
_, _, T = ot.gromov.quantized_fused_gromov_wasserstein_partitioned(
CR1, CR2, list_R1, list_R2, list_p1, list_p2, alpha=1.0, build_OT=True, log=False)
# Check the marginal constraints of returned transport plan T
print("max abs error (row-sum of T vs. p):", np.max(np.abs(T.sum(axis=1) - p)))
print("max abs error (col-sum of T vs. q):", np.max(np.abs(T.sum(axis=0) - q)))
# Hypothesis: POT returns T in "partition-concatenated" order (nodes grouped by part labels),
# so its row/col sums match permuted marginals p_perm/q_perm rather than p/q.
part_ids1 = np.unique(part1)
perm1 = np.concatenate([np.flatnonzero(part1 == pid) for pid in part_ids1]) # concatenate indices of each partition in order
p_perm = p[perm1]
part_ids2 = np.unique(part2)
perm2 = np.concatenate([np.flatnonzero(part2 == pid) for pid in part_ids2])
q_perm = q[perm2]
print("max abs error (row-sum of T vs. p_perm):", np.max(np.abs(T.sum(axis=1) - p_perm)))
print("max abs error (col-sum of T vs. q_perm):", np.max(np.abs(T.sum(axis=0) - q_perm)))
# Potential fix: reindex T to original ordering
inv1 = np.argsort(perm1)
inv2 = np.argsort(perm2)
T_fix = T[inv1][:, inv2]
print("max abs error (row-sum of T_fix vs. p):", np.max(np.abs(T_fix.sum(axis=1) - p)))
print("max abs error (col-sum of T_fix vs. q):", np.max(np.abs(T_fix.sum(axis=0) - q)))
Expected behavior
When build_OT=True, I expected the returned full transport plan T to be indexed in the original node order of the input spaces (C1, p) and (C2, q), so that:
- T.sum(axis=1) matches p
- T.sum(axis=0) matches q
(Alternatively, if T is intentionally returned in a different ordering, I expected this ordering (and a mapping back to original indices) to be clearly documented / returned.)
Environment:
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Environment A
macOS-15.6.1-arm64-arm-64bit
Python 3.12.10
NumPy 2.0.0
SciPy 1.15.2
POT 0.9.5
Environment B
Linux-6.8.0-90-generic-x86_64-with-glibc2.35
Python 3.12.2 | packaged by conda-forge
NumPy 2.1.3
SciPy 1.15.3
POT 0.9.6.post1
Additional context
After reindexing T back to the original node order (i.e., undoing the partition-concatenated ordering), the resulting coupling and objective value agree with the outputs I obtain from the reference implementation by the original authors (QuantizedGromovWasserstein): https://github.com/trneedham/QuantizedGromovWasserstein
You can find two files in the attachment:
- qGW_POT_fix.py (original code with the relevant part extracted)
- qGW_POT.py (added a potential fix on reindexing issue)
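The objective comparison mentioned under Additional context also depends on the indexing of T. As an illustration only, the sketch below evaluates the square-loss Gromov-Wasserstein objective with plain NumPy; this report does not state which objective was compared, so the choice of loss is an assumption, and gw_square_loss is a hypothetical name, not a POT function. The point is that the value is unchanged when T and the cost matrices are permuted consistently, but generally differs when a reordered T is evaluated against C1 and C2 in their original order.

```python
import numpy as np


def gw_square_loss(C1, C2, T):
    """sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l], without a 4D tensor."""
    p = T.sum(axis=1)
    q = T.sum(axis=0)
    term1 = p @ (C1 ** 2) @ p               # sum_{i,k} C1[i,k]^2 p_i p_k
    term2 = q @ (C2 ** 2) @ q               # sum_{j,l} C2[j,l]^2 q_j q_l
    cross = np.sum((C1 @ T @ C2.T) * T)     # sum_{i,j,k,l} C1[i,k] C2[j,l] T[i,j] T[k,l]
    return term1 + term2 - 2.0 * cross


rng = np.random.default_rng(0)
n1, n2 = 5, 4
X1 = rng.normal(size=(n1, 2))
X2 = rng.normal(size=(n2, 2))
C1 = ((X1[:, None, :] - X1[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean costs
C2 = ((X2[:, None, :] - X2[None, :, :]) ** 2).sum(axis=-1)

T = rng.random((n1, n2)); T /= T.sum()      # arbitrary coupling in original node order
perm1 = rng.permutation(n1)                 # stand-in for a partition-concatenated order
perm2 = rng.permutation(n2)
T_perm = T[perm1][:, perm2]

ref = gw_square_loss(C1, C2, T)
consistent = gw_square_loss(C1[np.ix_(perm1, perm1)], C2[np.ix_(perm2, perm2)], T_perm)
mismatched = gw_square_loss(C1, C2, T_perm)  # reordered T against original-order costs

print("original order:       ", ref)
print("consistently permuted:", consistent)
print("mismatched ordering:  ", mismatched)
```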