Skip to content

Commit 1fe70c0

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (3297).
1 parent 0ec72bf commit 1fe70c0

File tree

294 files changed

+104304
-100863
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

294 files changed

+104304
-100863
lines changed
Binary file not shown.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
=====================================
4+
Gaussian Mixture Model OT Barycenters
5+
=====================================
6+
7+
This example illustrates the computation of a barycenter between Gaussian
8+
Mixtures in the sense of GMM-OT [69]. This computation is done using the
9+
fixed-point method for OT barycenters with generic costs [77], for which POT
10+
provides a general solver, and a specific GMM solver. Note that this is a
11+
'free-support' method, implying that the number of components of the barycenter
12+
GMM and their weights are fixed.
13+
14+
The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over
15+
the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the
16+
Bures-Wasserstein manifold), and to compute barycenters with respect to the
17+
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
18+
gaussian mixture is a finite combination of Diracs on specific gaussians, and
19+
two mixtures are compared with the 2-Wasserstein distance on this space, where
20+
ground cost the squared Bures distance between gaussians.
21+
22+
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
23+
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
24+
25+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
26+
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
27+
(2024)
28+
29+
"""
30+
31+
# Author: Eloi Tanguy <[email protected]>
32+
#
33+
# License: MIT License
34+
35+
# sphinx_gallery_thumbnail_number = 1
36+
37+
# %%
38+
# Generate data
39+
import numpy as np
40+
import matplotlib.pyplot as plt
41+
from matplotlib.patches import Ellipse
42+
import ot
43+
from ot.gmm import gmm_barycenter_fixed_point
44+
45+
46+
K = 3 # number of GMMs
47+
d = 2 # dimension
48+
n = 6 # number of components of the desired barycenter
49+
50+
51+
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
52+
rng = np.random.RandomState(seed=seed)
53+
means = rng.randn(K, d)
54+
P = rng.randn(K, d, d) * cov_scale
55+
# C[k] = P[k] @ P[k]^T + min_cov_eig * I
56+
covariances = np.einsum("kab,kcb->kac", P, P)
57+
covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)])
58+
weights = rng.random(K)
59+
weights /= np.sum(weights)
60+
return means, covariances, weights
61+
62+
63+
m_list = [5, 6, 7] # number of components in each GMM
64+
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]
65+
means_list = [] # list of means for each GMM
66+
covs_list = [] # list of covariances for each GMM
67+
w_list = [] # list of weights for each GMM
68+
69+
# generate GMMs
70+
for k in range(K):
71+
means, covs, b = get_random_gmm(
72+
m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
73+
)
74+
means = means / 2 + offsets[k][None, :]
75+
means_list.append(means)
76+
covs_list.append(covs)
77+
w_list.append(b)
78+
79+
# %%
80+
# Compute the barycenter using the fixed-point method
81+
init_means, init_covs, _ = get_random_gmm(n, d, seed=0)
82+
weights = ot.unif(K) # barycenter coefficients
83+
means_bar, covs_bar, log = gmm_barycenter_fixed_point(
84+
means_list,
85+
covs_list,
86+
w_list,
87+
init_means,
88+
init_covs,
89+
weights,
90+
iterations=3,
91+
log=True,
92+
)
93+
94+
95+
# %%
96+
# Define plotting functions
97+
98+
99+
# draw a covariance ellipse
100+
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
101+
def eigsorted(cov):
102+
vals, vecs = np.linalg.eigh(cov)
103+
order = vals.argsort()[::-1].copy()
104+
return vals[order], vecs[:, order]
105+
106+
vals, vecs = eigsorted(C)
107+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
108+
w, h = 2 * nstd * np.sqrt(vals)
109+
ell = Ellipse(
110+
xy=(mu[0], mu[1]),
111+
width=w,
112+
height=h,
113+
alpha=alpha,
114+
angle=theta,
115+
facecolor=color,
116+
edgecolor=color,
117+
label=label,
118+
fill=True,
119+
)
120+
if ax is None:
121+
ax = plt.gca()
122+
ax.add_artist(ell)
123+
124+
125+
# draw a gmm as a set of ellipses with weights shown in alpha value
126+
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
127+
for k in range(ms.shape[0]):
128+
draw_cov(
129+
ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax
130+
)
131+
132+
133+
# %%
134+
# Plot the results
135+
c_list = ["#7ED321", "#4A90E2", "#9013FE", "#F5A623"]
136+
c_bar = "#D0021B"
137+
fig, ax = plt.subplots(figsize=(6, 6))
138+
axis = [-4, 4, -2, 6]
139+
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
140+
for k in range(K):
141+
draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax)
142+
draw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax)
143+
ax.axis(axis)
144+
ax.axis("off")
145+
146+
# %%

master/_downloads/0dbd57c6090215001a0a712021c577e5/plot_GMMOT_plan.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
},
1616
"outputs": [],
1717
"source": [
18-
"# Author: Eloi Tanguy <eloi.tanguy@u-paris>\n# Remi Flamary <[email protected]>\n# Julie Delon <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt"
18+
"# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>\n# Remi Flamary <[email protected]>\n# Julie Delon <[email protected]>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt"
1919
]
2020
},
2121
{
Binary file not shown.

master/_downloads/12cf635d7b9aa9f87c0e3bdc36aaa712/plot_SSNB.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
2017.
4242
"""
4343

44-
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
44+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
4545
# License: MIT License
4646

4747
# sphinx_gallery_thumbnail_number = 3
Binary file not shown.

master/_downloads/15645a78701cc4e31af4898794deb04d/plot_generalized_free_support_barycenter.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
},
1616
"outputs": [],
1717
"source": [
18-
"# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.pylab as pl\nimport ot\nimport matplotlib.animation as animation"
18+
"# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.pylab as pl\nimport ot\nimport matplotlib.animation as animation"
1919
]
2020
},
2121
{
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)