@@ -465,8 +465,8 @@ def test_partial_fgw2_gradients():
465
465
@pytest .skip_backend ("tf" , reason = "test very slow with tf backend" )
466
466
def test_entropic_partial_gromov_wasserstein (nx ):
467
467
rng = np .random .RandomState (42 )
468
- n_samples = 20 # nb samples
469
- n_noise = 10 # nb of samples (noise)
468
+ n_samples = 10 # nb samples
469
+ n_noise = 5 # nb of samples (noise)
470
470
471
471
p = ot .unif (n_samples + n_noise )
472
472
psub = ot .unif (n_samples - 5 + n_noise )
@@ -516,6 +516,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
516
516
log = True ,
517
517
symmetric = list_sym [i ],
518
518
verbose = True ,
519
+ numItermax = 10 ,
519
520
)
520
521
521
522
resb , logb = ot .gromov .entropic_partial_gromov_wasserstein (
@@ -530,6 +531,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
530
531
log = True ,
531
532
symmetric = False ,
532
533
verbose = True ,
534
+ numItermax = 10 ,
533
535
)
534
536
535
537
resb_ = nx .to_numpy (resb )
@@ -552,6 +554,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
552
554
log = False ,
553
555
symmetric = list_sym [i ],
554
556
verbose = True ,
557
+ numItermax = 10 ,
555
558
)
556
559
557
560
resb = ot .gromov .entropic_partial_gromov_wasserstein (
@@ -564,6 +567,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
564
567
log = False ,
565
568
symmetric = False ,
566
569
verbose = True ,
570
+ numItermax = 10 ,
567
571
)
568
572
569
573
resb_ = nx .to_numpy (resb )
@@ -573,11 +577,25 @@ def test_entropic_partial_gromov_wasserstein(nx):
573
577
# tests with different number of samples across spaces
574
578
m = 0.5
575
579
res , log = ot .gromov .entropic_partial_gromov_wasserstein (
576
- C1 , C1sub , p = p , q = psub , reg = 1e4 , m = m , log = True
580
+ C1 ,
581
+ C1sub ,
582
+ p = p ,
583
+ q = psub ,
584
+ reg = 1e4 ,
585
+ m = m ,
586
+ log = True ,
587
+ numItermax = 10 ,
577
588
)
578
589
579
590
resb , logb = ot .gromov .entropic_partial_gromov_wasserstein (
580
- C1b , C1subb , p = pb , q = psubb , reg = 1e4 , m = m , log = True
591
+ C1b ,
592
+ C1subb ,
593
+ p = pb ,
594
+ q = psubb ,
595
+ reg = 1e4 ,
596
+ m = m ,
597
+ log = True ,
598
+ numItermax = 10 ,
581
599
)
582
600
583
601
resb_ = nx .to_numpy (resb )
@@ -589,10 +607,26 @@ def test_entropic_partial_gromov_wasserstein(nx):
589
607
# tests for pGW2
590
608
for loss_fun in ["square_loss" , "kl_loss" ]:
591
609
w0 , log0 = ot .gromov .entropic_partial_gromov_wasserstein2 (
592
- C1 , C2 , p = None , q = q , reg = 1e4 , m = m , loss_fun = loss_fun , log = True
610
+ C1 ,
611
+ C2 ,
612
+ p = None ,
613
+ q = q ,
614
+ reg = 1e4 ,
615
+ m = m ,
616
+ loss_fun = loss_fun ,
617
+ log = True ,
618
+ numItermax = 10 ,
593
619
)
594
620
w0_val = ot .gromov .entropic_partial_gromov_wasserstein2 (
595
- C1b , C2b , p = pb , q = None , reg = 1e4 , m = m , loss_fun = loss_fun , log = False
621
+ C1b ,
622
+ C2b ,
623
+ p = pb ,
624
+ q = None ,
625
+ reg = 1e4 ,
626
+ m = m ,
627
+ loss_fun = loss_fun ,
628
+ log = False ,
629
+ numItermax = 10 ,
596
630
)
597
631
np .testing .assert_allclose (w0 , w0_val , rtol = 1e-8 )
598
632
@@ -666,6 +700,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
666
700
log = True ,
667
701
symmetric = list_sym [i ],
668
702
verbose = True ,
703
+ numItermax = 10 ,
669
704
)
670
705
671
706
resb , logb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
@@ -681,6 +716,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
681
716
log = True ,
682
717
symmetric = False ,
683
718
verbose = True ,
719
+ numItermax = 10 ,
684
720
)
685
721
686
722
resb_ = nx .to_numpy (resb )
@@ -704,6 +740,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
704
740
log = False ,
705
741
symmetric = list_sym [i ],
706
742
verbose = True ,
743
+ numItermax = 10 ,
707
744
)
708
745
709
746
resb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
@@ -717,6 +754,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
717
754
log = False ,
718
755
symmetric = False ,
719
756
verbose = True ,
757
+ numItermax = 10 ,
720
758
)
721
759
722
760
resb_ = nx .to_numpy (resb )
@@ -726,11 +764,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
726
764
# tests with different number of samples across spaces
727
765
m = 0.5
728
766
res , log = ot .gromov .entropic_partial_fused_gromov_wasserstein (
729
- M11sub , C1 , C1sub , p = p , q = psub , reg = 1e4 , m = m , log = True
767
+ M11sub ,
768
+ C1 ,
769
+ C1sub ,
770
+ p = p ,
771
+ q = psub ,
772
+ reg = 1e4 ,
773
+ m = m ,
774
+ log = True ,
775
+ numItermax = 10 ,
730
776
)
731
777
732
778
resb , logb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
733
- M11subb , C1b , C1subb , p = pb , q = psubb , reg = 1e4 , m = m , log = True
779
+ M11subb ,
780
+ C1b ,
781
+ C1subb ,
782
+ p = pb ,
783
+ q = psubb ,
784
+ reg = 1e4 ,
785
+ m = m ,
786
+ log = True ,
787
+ numItermax = 10 ,
734
788
)
735
789
736
790
resb_ = nx .to_numpy (resb )
@@ -742,9 +796,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
742
796
# tests for pGW2
743
797
for loss_fun in ["square_loss" , "kl_loss" ]:
744
798
w0 , log0 = ot .gromov .entropic_partial_fused_gromov_wasserstein2 (
745
- M12 , C1 , C2 , p = None , q = q , reg = 1e4 , m = m , loss_fun = loss_fun , log = True
799
+ M12 ,
800
+ C1 ,
801
+ C2 ,
802
+ p = None ,
803
+ q = q ,
804
+ reg = 1e4 ,
805
+ m = m ,
806
+ loss_fun = loss_fun ,
807
+ log = True ,
808
+ numItermax = 10 ,
746
809
)
747
810
w0_val = ot .gromov .entropic_partial_fused_gromov_wasserstein2 (
748
- M12b , C1b , C2b , p = pb , q = None , reg = 1e4 , m = m , loss_fun = loss_fun , log = False
811
+ M12b ,
812
+ C1b ,
813
+ C2b ,
814
+ p = pb ,
815
+ q = None ,
816
+ reg = 1e4 ,
817
+ m = m ,
818
+ loss_fun = loss_fun ,
819
+ log = False ,
820
+ numItermax = 10 ,
749
821
)
750
822
np .testing .assert_allclose (w0 , w0_val , rtol = 1e-8 )
0 commit comments