
Commit 787cdb2

committed
add slides for kai's anc talk
1 parent 0b7d51e commit 787cdb2

File tree

3 files changed: +227 −0 lines changed


anc.md

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
---
marp: true
paginate: true
---

<style>
section {
  font-size: 24px;
}
.footnote {
  font-size: 20px;
}
</style>

# **G**enerative **Ra**tio **M**atching Networks

Akash Srivastava$^{\ast,1,2}$, Kai Xu$^{\ast,3}$, Michael U. Gutmann$^{3}$, Charles Sutton$^{3,4,5}$

<br><br><br><br><br><br>

$\ast$ denotes equal contributions
$^1$MIT-IBM Watson AI Lab $^2$IBM Research $^3$University of Edinburgh $^4$Google AI $^5$Alan Turing Institute

<br>

To appear in ICLR 2020; OpenReview: https://openreview.net/forum?id=SJg7spEYDS

---

## Introduction and motivations

Implicit deep generative models: $x = \mathrm{NN}(z; \theta)$ where $z \sim$ noise

Maximum mean discrepancy networks (MMD-nets)

- :x: can only work well with **low-dimensional** data
- :white_check_mark: are very **stable** to train by avoiding the saddle-point optimization problem

Adversarial generative models (e.g. GANs, MMD-GANs)

- :white_check_mark: can generate **high-dimensional** data such as natural images
- :x: are very **difficult** to train due to the saddle-point optimization problem

Q: Can we have two :white_check_mark::white_check_mark:?
A: Yes. Generative ratio matching (GRAM) is a *stable* learning algorithm for *implicit* deep generative models that does **not** involve a saddle-point optimization problem and therefore is easy to train :tada:.

---

## Background: **maximum mean discrepancy**

The maximum mean discrepancy (MMD) between two distributions $p$ and $q$ is defined as
$$
\mathrm{MMD}_\mathcal{F}(p,q) = \sup_{f\in\mathcal{F}} \left(\mathbb{E}_p \lbrack f(x) \rbrack - \mathbb{E}_q \lbrack f(x) \rbrack \right)
$$
Gretton et al. (2012) show that it is sufficient to choose $\mathcal{F}$ to be a unit ball in a reproducing kernel Hilbert space (RKHS) $\mathcal{R}$ with a characteristic kernel $k$ s.t.
$$
\mathrm{MMD}_\mathcal{F}(p, q) = 0 \iff p = q
$$
The empirical estimate of the (squared) MMD with $x_i \sim p$ and $y_j \sim q$ by Monte Carlo is
$$
\widehat{\mathrm{MMD}}^2_\mathcal{R}(p,q) =
\frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(x_i,x_{i'})
- \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(x_i, y_j)
+ \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(y_j,y_{j'})
$$
MMD-nets train neural generators by minimizing this empirical estimate.
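
A minimal NumPy sketch of this estimator with a Gaussian kernel (the bandwidth `sigma` and the helper names are illustrative choices, not from the paper):

```python
import numpy as np

def gaussian_gram(X, Y, sigma=1.0):
    """Gram matrix K[i, j] = k(x_i, y_j) for a Gaussian kernel with bandwidth sigma."""
    sqdist = np.sum(X**2, axis=1)[:, None] + np.sum(Y**2, axis=1)[None, :] - 2 * X @ Y.T
    return np.exp(-sqdist / (2 * sigma**2))

def mmd2_hat(X, Y, sigma=1.0):
    """Empirical (biased) estimate of the squared MMD between samples X ~ p and Y ~ q."""
    Kxx = gaussian_gram(X, X, sigma)
    Kxy = gaussian_gram(X, Y, sigma)
    Kyy = gaussian_gram(Y, Y, sigma)
    return Kxx.mean() - 2 * Kxy.mean() + Kyy.mean()
```
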
---

## Background: **density ratio estimation via moment matching**

Density ratio estimation: find $\hat{r}(x) \approx r(x) = \frac{p(x)}{q(x)}$ with samples from $p$ and $q$

Matching finite moments under the fixed-design setup gives $\hat{\mathbf{r}}_q = [\hat{r}(x_1^q), ..., \hat{r}(x_M^q)]$ for $x^q \sim q$

$$
\min_r \left ( \int \phi(x)p(x) dx - \int \phi(x)r(x)q(x) dx \right )^2
$$

Huang et al. (2007) show that by changing $\phi(x)$ to $k(x, \cdot)$, where $k$ is a characteristic kernel of an RKHS $\mathcal{R}$, we can match infinitely many moments, and the minimizer of the objective below agrees with the true $r(x)$

$$
\min_{r\in\mathcal{R}} \bigg \Vert \int k(x, \cdot)p(x) dx - \int k(x, \cdot)r(x)q(x) dx \bigg \Vert_{\mathcal{R}}^2
$$

Analytical solution: $\hat{\mathbf{r}} = \mathbf{K}_{q,q}^{-1}\mathbf{K}_{q,p} \mathbf{1}$, where $[\mathbf{K}_{p,q}]_{i,j} = k(x_i^p, x_j^q)$ given samples $\{x_i^p\}$ and $\{x_j^q\}$.
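
A hedged NumPy sketch of this analytical solution, reusing the `gaussian_gram` helper from the MMD sketch above; the small ridge `eps` is an added assumption for numerical stability, not part of the formula:

```python
import numpy as np

def ratio_estimate(Xp, Xq, sigma=1.0, eps=1e-6):
    """Fixed-design ratio estimates at the q samples: r_hat = K_qq^{-1} K_qp 1."""
    Kqq = gaussian_gram(Xq, Xq, sigma)     # M x M
    Kqp = gaussian_gram(Xq, Xp, sigma)     # M x N
    ones = np.ones(Xp.shape[0])
    # Solve (K_qq + eps*I) r = K_qp 1 rather than forming the inverse explicitly.
    return np.linalg.solve(Kqq + eps * np.eye(Xq.shape[0]), Kqp @ ones)
```
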
---

## GRAM: an overview

Two targets in the training loop

1. Learning a projection function $f_\theta$ that maps the data space into a low-dimensional manifold that preserves the density ratio between data and model.
   - "Preserves": $\frac{p_x(x)}{q_x(x)} = \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))}$, measured by $D(\theta) = \int q_x(x) \left( \frac{p_x(x)}{q_x(x)} - \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx$
   - :question: $\frac{p_x(x)}{q_x(x)}$ is hard to estimate in the high-dimensional space ...
2. Matching the model $G_\gamma$ to data in the low-dimensional manifold by minimizing MMD
   - :thumbsup: MMD works well in low-dimensional space
   - $\mathrm{MMD} = 0$ :arrow_right: $\frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} = 1$ :arrow_right: $\frac{p_x(x)}{q_x(x)} = 1$

Both use empirical estimates based on samples from the data $\{x_i^p\}$ and the model $\{x_j^q\}$.

$f_\theta$ and $G_\gamma$ are simultaneously updated.

---

## GRAM: tractable ratio matching

:one: Learning the projection function $f_\theta(x)$ by minimizing the squared difference
$$
\begin{aligned}
D(\theta)
&= \int q_x(x) \left( \frac{p_x(x)}{q_x(x)} - \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx \\
&= C - 2 \int p_x(x) \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} dx + \int q_x(x) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx \\
&= C - 2 \int \bar{p}(f_\theta(x)) \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} df_\theta(x) + \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) \\
&= C' - \left( \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) - 1 \right) = C' - \mathrm{PD}(\bar{q}, \bar{p})
\end{aligned}
$$
... or, equivalently, by maximizing the Pearson divergence :smile:.

A reminder on LOTUS: $\int p(x) g(f(x)) dx = \int \bar{p}(f(x)) g(f(x)) d f(x)$

<br>

<div class="footnote">
[1]: A derivation in the reverse order for a special case of projection functions was also shown in (Sugiyama et al., 2011).
</div>

---

## GRAM: Pearson divergence maximization

Monte Carlo approximation of PD
$$
\mathrm{PD}(\bar{q}, \bar{p}) = \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) - 1 \approx \frac{1}{M} \sum_{j=1}^M \left( \frac{\bar{p}(f_\theta(x_j^q))}{\bar{q}(f_\theta(x_j^q))} \right)^2 - 1
$$
where $x^q_j \sim q_x$ or equivalently $f_\theta(x^q_j) \sim \bar{q}$.

Given samples $\{x_i^p\}$ and $\{x_j^q\}$, we use the density ratio estimator based on infinite moment matching (Huang et al., 2007; Sugiyama et al., 2012) under the fixed-design setup
$$
\hat{\mathbf{r}}_{q,\theta} = \mathbf{K}^{-1}_{q,q} \mathbf{K}_{q,p}\mathbf{1} = [\hat{r}_\theta(x_1^q), ..., \hat{r}_\theta(x_M^q)]^\top
$$
where $[\mathbf{K}_{p,q}]_{i,j} = k(f_\theta(x_i^p), f_\theta(x_j^q))$ and $r_\theta(x) = \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))}$.
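
Putting the two displays together, a sketch of the resulting objective for $f_\theta$, assuming the projected samples are given as NumPy arrays and reusing `ratio_estimate` from the earlier sketch (including its assumed ridge term):

```python
def pearson_divergence_hat(fx_p, fx_q, sigma=1.0):
    """MC estimate of PD(q_bar, p_bar) from projected data fx_p = f_theta(x^p)
    and projected model samples fx_q = f_theta(x^q)."""
    r_hat = ratio_estimate(fx_p, fx_q, sigma)   # length-M vector of ratio estimates at x^q
    return np.mean(r_hat**2) - 1.0
```
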
---

## GRAM: matching projected model to projected data

:two: Minimizing the empirical estimate of MMD in the low-dimensional manifold

$$
\begin{aligned}
\min_\gamma \Bigg[&\frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(f_\theta(x_i),f_\theta(x_{i'}))
- \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(f_\theta(x_i), f_\theta(G_\gamma(z_j)))\\
&\quad + \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(f_\theta(G_\gamma(z_j)),f_\theta(G_\gamma(z_{j'}))) \Bigg ]
\end{aligned}
$$

with respect to the generator parameters $\gamma$.
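
In code this is just the earlier MMD estimator applied to projected samples; a brief sketch assuming `f_theta` and `G_gamma` are NumPy-returning callables and `mmd2_hat` is the helper from the MMD background sketch:

```python
def generator_objective(x_data, z, f_theta, G_gamma, sigma=1.0):
    """Squared MMD between f_theta(data) and f_theta(G_gamma(z)); minimized over gamma."""
    fx_data = f_theta(x_data)          # projected data samples
    fx_gen = f_theta(G_gamma(z))       # projected generator samples
    return mmd2_hat(fx_data, fx_gen, sigma)
```
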
---

## GRAM: the complete algorithm

Loop until convergence

1. Sample a minibatch of data and generate samples from $G_\gamma$
2. Project data and generated samples using $f_\theta$
3. Compute the kernel Gram matrices using Gaussian kernels in the projected space
4. Compute the objectives for $f_\theta$ and $G_\gamma$ using the same kernel Gram matrices
5. Backprop the two objectives to get the gradients for $\theta$ and $\gamma$
6. Perform gradient updates for $\theta$ and $\gamma$

<br>

:sunglasses: Fun fact: the objectives in our GRAM algorithm both rely heavily on kernel Gram matrices.
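
A compact PyTorch sketch of one pass through this loop, under illustrative assumptions (fixed-bandwidth Gaussian kernel, small MLPs, a hypothetical `data_loader` of flattened inputs, and a ridge term in the ratio solve); it mirrors the six steps but is not the authors' implementation:

```python
import torch
import torch.nn as nn

d_x, d_proj, d_z, sigma, eps = 784, 32, 64, 1.0, 1e-6     # illustrative sizes
f_theta = nn.Sequential(nn.Linear(d_x, 256), nn.ReLU(), nn.Linear(256, d_proj))
G_gamma = nn.Sequential(nn.Linear(d_z, 256), nn.ReLU(), nn.Linear(256, d_x))
params_f, params_g = list(f_theta.parameters()), list(G_gamma.parameters())
opt_f = torch.optim.Adam(params_f, lr=1e-4)
opt_g = torch.optim.Adam(params_g, lr=1e-4)

def gram(a, b):
    """Gaussian-kernel Gram matrix in the projected space."""
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))

for x in data_loader:                                      # 1. minibatch of data ...
    z = torch.randn(x.shape[0], d_z)
    x_gen = G_gamma(z)                                     #    ... and generator samples
    f_p, f_q = f_theta(x), f_theta(x_gen)                  # 2. project both minibatches
    Kpp, Kqp, Kqq = gram(f_p, f_p), gram(f_q, f_p), gram(f_q, f_q)  # 3. Gram matrices

    # 4a. Objective for f_theta: maximize PD via the fixed-design ratio estimates.
    ones = torch.ones(len(f_p), 1)
    r_hat = torch.linalg.solve(Kqq + eps * torch.eye(len(f_q)), Kqp @ ones)
    loss_f = -(r_hat ** 2).mean()                          # minimize -PD (constants dropped)

    # 4b. Objective for G_gamma: squared MMD between projected data and samples.
    loss_g = Kpp.mean() - 2 * Kqp.mean() + Kqq.mean()

    # 5. Backprop each objective to its own parameters.
    grads_f = torch.autograd.grad(loss_f, params_f, retain_graph=True)
    grads_g = torch.autograd.grad(loss_g, params_g)
    for p, g in zip(params_f, grads_f):
        p.grad = g
    for p, g in zip(params_g, grads_g):
        p.grad = g

    opt_f.step()                                           # 6. gradient updates
    opt_g.step()
```
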
---

## How do GRAM-nets compare to other deep generative models?

| GAN | MMD-net | MMD-GAN | GRAM-net |
| - | - | - | - |
| ![Computation graph (GAN)](comp_graph-gan.png) | ![Computation graph (MMD-net)](comp_graph-mmdnet.png) | ![Computation graph (MMD-GAN)](comp_graph-mmdgan.png) | ![Computation graph (GRAM-net)](comp_graph-gramnet.png) |

---

## Illustration of GRAM training

| | |
| - | - |
| ![Ring: training process](ring_process.png) | ![width:350px](trace.png) |

Blue: data, Orange: samples
Top: original, Bottom: projected

---

## Evaluations: the stability of models

![Ring: stability on GAN and GRAM-nets](stability_gan_gramnet.png)

x-axis = noise dimension and y-axis = generator layer size

---

## Evaluations: the stability of models (continued)

![Ring: stability on MMD-nets and MMD-GANs](stability_mmdnet_mmdgan.png)

x-axis = noise dimension and y-axis = generator layer size

---

## Quantitative results: sample quality

![FID table](table_image.png)

---

## Qualitative results: random samples

![Samples (CIFAR10)](cifar10.png) ![Samples (CelebA)](celeba.png)

---

## The end

anc.pdf

2.65 MB
Binary file not shown.

trace.png

23 KB
