Skip to content

Commit ef7fe4e

Browse files
authored
[MRG] speedup and update tests and wheels and prepare for release 0.9.6 (#759)
* fatser tests * add tests * add new build wheels * add python 3.13 * keep version macos * python 3.12 on macosx * release * ste release version number * remove macos-13 * move solve_batch and solve_gromov_btac and dist in main ot
1 parent 17699d7 commit ef7fe4e

File tree

10 files changed

+132
-60
lines changed

10 files changed

+132
-60
lines changed

.github/workflows/build_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
strategy:
4848
max-parallel: 4
4949
matrix:
50-
python-version: ["3.9", "3.10", "3.11", "3.12"]
50+
python-version: ["3.10", "3.11", "3.12", "3.13"]
5151

5252
steps:
5353
- uses: actions/checkout@v4
@@ -78,7 +78,7 @@ jobs:
7878
- name: Set up Python
7979
uses: actions/setup-python@v5
8080
with:
81-
python-version: "3.12"
81+
python-version: "3.13"
8282
- name: Install dependencies
8383
run: |
8484
python -m pip install --upgrade pip setuptools
@@ -98,7 +98,7 @@ jobs:
9898
strategy:
9999
max-parallel: 4
100100
matrix:
101-
os: [macos-latest, macos-13]
101+
os: [macos-latest]
102102
python-version: ["3.12"]
103103

104104
steps:

.github/workflows/build_wheels.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
3131
- name: Install cibuildwheel
3232
run: |
33-
python -m pip install cibuildwheel==2.23.3
33+
python -m pip install cibuildwheel==3.1.4
3434
3535
- name: Build wheels
3636
env:
@@ -65,7 +65,7 @@ jobs:
6565
6666
- name: Install cibuildwheel
6767
run: |
68-
python -m pip install cibuildwheel==2.16.4
68+
python -m pip install cibuildwheel==3.1.4
6969
7070
- name: Set up QEMU
7171
if: runner.os == 'Linux'

.github/workflows/build_wheels_weekly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
3030
- name: Install cibuildwheel
3131
run: |
32-
python -m pip install cibuildwheel==2.23.3
32+
python -m pip install cibuildwheel==3.1.4
3333
3434
- name: Set up QEMU
3535
if: runner.os == 'Linux'

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Releases
22

3-
## 0.9.6dev
3+
## 0.9.6
44

55
#### New features
66
- Implement CG solvers for partial FGW (PR #687)
@@ -28,6 +28,7 @@
2828
- Removed release information from quickstart guide (PR #744)
2929
- Implement batch parallel solvers in ot.batch (PR #745)
3030
- Update REAMDE with new API and reorganize examples (PR #754)
31+
- Speedup and update tests and wheels (PR #759)
3132

3233
#### Closed issues
3334
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

examples/backends/plot_ot_batch.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
ot.dist(samples_source[i], samples_target[i])
5555
) # List of cost matrices n_samples x n_samples
5656
# Batched approach
57-
M_batch = ot.batch.dist_batch(
57+
M_batch = ot.dist_batch(
5858
samples_source, samples_target
5959
) # Array of cost matrices n_problems x n_samples x n_samples
6060

@@ -88,7 +88,7 @@
8888
results_values_list.append(res.value_linear)
8989

9090
# Batched approach
91-
results_batch = ot.batch.solve_batch(
91+
results_batch = ot.solve_batch(
9292
M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
9393
)
9494
results_values_batch = results_batch.value_linear
@@ -131,8 +131,8 @@ def benchmark_naive(samples_source, samples_target):
131131

132132
def benchmark_batch(samples_source, samples_target):
133133
start = perf_counter()
134-
M_batch = ot.batch.dist_batch(samples_source, samples_target)
135-
res_batch = ot.batch.solve_batch(
134+
M_batch = ot.dist_batch(samples_source, samples_target)
135+
res_batch = ot.solve_batch(
136136
M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
137137
)
138138
end = perf_counter()
@@ -176,8 +176,7 @@ def benchmark_batch(samples_source, samples_target):
176176
# If your data is on a GPU, :func:`ot.batch.solve_gromov_batch`
177177
# is significantly faster AND provides better objective values.
178178

179-
from ot import solve_gromov
180-
from ot.batch import solve_gromov_batch
179+
from ot import solve_gromov, solve_gromov_batch
181180

182181

183182
def benchmark_naive_gw(samples_source, samples_target):
@@ -195,8 +194,8 @@ def benchmark_naive_gw(samples_source, samples_target):
195194

196195
def benchmark_batch_gw(samples_source, samples_target):
197196
start = perf_counter()
198-
C1_batch = ot.batch.dist_batch(samples_source, samples_source)
199-
C2_batch = ot.batch.dist_batch(samples_target, samples_target)
197+
C1_batch = ot.dist_batch(samples_source, samples_source)
198+
C2_batch = ot.dist_batch(samples_target, samples_target)
200199
res_batch = solve_gromov_batch(
201200
C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol
202201
)

ot/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@
7272
from .solvers import solve, solve_gromov, solve_sample
7373
from .lowrank import lowrank_sinkhorn
7474

75-
from .batch import solve_batch, solve_gromov_batch
75+
from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch
7676

7777
# utils functions
7878
from .utils import dist, unif, tic, toc, toq
7979

80-
__version__ = "0.9.6dev0"
80+
__version__ = "0.9.6"
8181

8282
__all__ = [
8383
"emd",
@@ -139,4 +139,6 @@
139139
"lowrank_gromov_wasserstein_samples",
140140
"solve_batch",
141141
"solve_gromov_batch",
142+
"solve_sample_batch",
143+
"dist_batch",
142144
]

test/gromov/test_partial.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,8 @@ def test_partial_fgw2_gradients():
465465
@pytest.skip_backend("tf", reason="test very slow with tf backend")
466466
def test_entropic_partial_gromov_wasserstein(nx):
467467
rng = np.random.RandomState(42)
468-
n_samples = 20 # nb samples
469-
n_noise = 10 # nb of samples (noise)
468+
n_samples = 10 # nb samples
469+
n_noise = 5 # nb of samples (noise)
470470

471471
p = ot.unif(n_samples + n_noise)
472472
psub = ot.unif(n_samples - 5 + n_noise)
@@ -516,6 +516,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
516516
log=True,
517517
symmetric=list_sym[i],
518518
verbose=True,
519+
numItermax=10,
519520
)
520521

521522
resb, logb = ot.gromov.entropic_partial_gromov_wasserstein(
@@ -530,6 +531,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
530531
log=True,
531532
symmetric=False,
532533
verbose=True,
534+
numItermax=10,
533535
)
534536

535537
resb_ = nx.to_numpy(resb)
@@ -552,6 +554,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
552554
log=False,
553555
symmetric=list_sym[i],
554556
verbose=True,
557+
numItermax=10,
555558
)
556559

557560
resb = ot.gromov.entropic_partial_gromov_wasserstein(
@@ -564,6 +567,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
564567
log=False,
565568
symmetric=False,
566569
verbose=True,
570+
numItermax=10,
567571
)
568572

569573
resb_ = nx.to_numpy(resb)
@@ -573,11 +577,25 @@ def test_entropic_partial_gromov_wasserstein(nx):
573577
# tests with different number of samples across spaces
574578
m = 0.5
575579
res, log = ot.gromov.entropic_partial_gromov_wasserstein(
576-
C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True
580+
C1,
581+
C1sub,
582+
p=p,
583+
q=psub,
584+
reg=1e4,
585+
m=m,
586+
log=True,
587+
numItermax=10,
577588
)
578589

579590
resb, logb = ot.gromov.entropic_partial_gromov_wasserstein(
580-
C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True
591+
C1b,
592+
C1subb,
593+
p=pb,
594+
q=psubb,
595+
reg=1e4,
596+
m=m,
597+
log=True,
598+
numItermax=10,
581599
)
582600

583601
resb_ = nx.to_numpy(resb)
@@ -589,10 +607,26 @@ def test_entropic_partial_gromov_wasserstein(nx):
589607
# tests for pGW2
590608
for loss_fun in ["square_loss", "kl_loss"]:
591609
w0, log0 = ot.gromov.entropic_partial_gromov_wasserstein2(
592-
C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True
610+
C1,
611+
C2,
612+
p=None,
613+
q=q,
614+
reg=1e4,
615+
m=m,
616+
loss_fun=loss_fun,
617+
log=True,
618+
numItermax=10,
593619
)
594620
w0_val = ot.gromov.entropic_partial_gromov_wasserstein2(
595-
C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False
621+
C1b,
622+
C2b,
623+
p=pb,
624+
q=None,
625+
reg=1e4,
626+
m=m,
627+
loss_fun=loss_fun,
628+
log=False,
629+
numItermax=10,
596630
)
597631
np.testing.assert_allclose(w0, w0_val, rtol=1e-8)
598632

@@ -666,6 +700,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
666700
log=True,
667701
symmetric=list_sym[i],
668702
verbose=True,
703+
numItermax=10,
669704
)
670705

671706
resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein(
@@ -681,6 +716,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
681716
log=True,
682717
symmetric=False,
683718
verbose=True,
719+
numItermax=10,
684720
)
685721

686722
resb_ = nx.to_numpy(resb)
@@ -704,6 +740,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
704740
log=False,
705741
symmetric=list_sym[i],
706742
verbose=True,
743+
numItermax=10,
707744
)
708745

709746
resb = ot.gromov.entropic_partial_fused_gromov_wasserstein(
@@ -717,6 +754,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
717754
log=False,
718755
symmetric=False,
719756
verbose=True,
757+
numItermax=10,
720758
)
721759

722760
resb_ = nx.to_numpy(resb)
@@ -726,11 +764,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
726764
# tests with different number of samples across spaces
727765
m = 0.5
728766
res, log = ot.gromov.entropic_partial_fused_gromov_wasserstein(
729-
M11sub, C1, C1sub, p=p, q=psub, reg=1e4, m=m, log=True
767+
M11sub,
768+
C1,
769+
C1sub,
770+
p=p,
771+
q=psub,
772+
reg=1e4,
773+
m=m,
774+
log=True,
775+
numItermax=10,
730776
)
731777

732778
resb, logb = ot.gromov.entropic_partial_fused_gromov_wasserstein(
733-
M11subb, C1b, C1subb, p=pb, q=psubb, reg=1e4, m=m, log=True
779+
M11subb,
780+
C1b,
781+
C1subb,
782+
p=pb,
783+
q=psubb,
784+
reg=1e4,
785+
m=m,
786+
log=True,
787+
numItermax=10,
734788
)
735789

736790
resb_ = nx.to_numpy(resb)
@@ -742,9 +796,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
742796
# tests for pGW2
743797
for loss_fun in ["square_loss", "kl_loss"]:
744798
w0, log0 = ot.gromov.entropic_partial_fused_gromov_wasserstein2(
745-
M12, C1, C2, p=None, q=q, reg=1e4, m=m, loss_fun=loss_fun, log=True
799+
M12,
800+
C1,
801+
C2,
802+
p=None,
803+
q=q,
804+
reg=1e4,
805+
m=m,
806+
loss_fun=loss_fun,
807+
log=True,
808+
numItermax=10,
746809
)
747810
w0_val = ot.gromov.entropic_partial_fused_gromov_wasserstein2(
748-
M12b, C1b, C2b, p=pb, q=None, reg=1e4, m=m, loss_fun=loss_fun, log=False
811+
M12b,
812+
C1b,
813+
C2b,
814+
p=pb,
815+
q=None,
816+
reg=1e4,
817+
m=m,
818+
loss_fun=loss_fun,
819+
log=False,
820+
numItermax=10,
749821
)
750822
np.testing.assert_allclose(w0, w0_val, rtol=1e-8)

test/test_da.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,8 +912,8 @@ def test_emd_laplace_class(nx):
912912
def test_nearest_brenier_potential(nx):
913913
X = nx.ones((2, 2))
914914
for ssnb in [
915-
ot.da.NearestBrenierPotential(log=True),
916-
ot.da.NearestBrenierPotential(log=False),
915+
ot.da.NearestBrenierPotential(log=True, its=5),
916+
ot.da.NearestBrenierPotential(log=False, its=5),
917917
]:
918918
ssnb.fit(Xs=X, Xt=X)
919919
G_lu = ssnb.transform(Xs=X)

0 commit comments

Comments
 (0)