---
marp: true
paginate: true
---

<style>
section {
  font-size: 24px;
}
.footnote {
  font-size: 20px;
}
</style>

# **G**enerative **Ra**tio **M**atching Networks

Akash Srivastava$^{\ast,1,2}$, Kai Xu$^{\ast,3}$, Michael U. Gutmann$^{3}$, Charles Sutton$^{3,4,5}$

<br><br><br><br><br><br>

$\ast$ denotes equal contributions
$^1$MIT-IBM Watson AI Lab $^2$IBM Research $^3$University of Edinburgh $^4$Google AI $^5$Alan Turing Institute

<br>

To appear in ICLR 2020; OpenReview: https://openreview.net/forum?id=SJg7spEYDS

---

## Introduction and motivation

Implicit deep generative models: $x = \mathrm{NN}(z; \theta)$ where $z \sim$ noise

Maximum mean discrepancy networks (MMD-nets)

- :x: can only work well with **low-dimensional** data
- :white_check_mark: are very **stable** to train because they avoid the saddle-point optimization problem

Adversarial generative models (e.g. GANs, MMD-GANs)

- :white_check_mark: can generate **high-dimensional** data such as natural images
- :x: are very **difficult** to train due to the saddle-point optimization problem

Q: Can we have both :white_check_mark::white_check_mark:?
A: Yes. Generative ratio matching (GRAM) is a *stable* learning algorithm for *implicit* deep generative models that does **not** involve a saddle-point optimization problem and is therefore easy to train :tada:.

---

## Background: **maximum mean discrepancy**

The maximum mean discrepancy (MMD) between two distributions $p$ and $q$ is defined as
$$
\mathrm{MMD}_\mathcal{F}(p,q) = \sup_{f\in\mathcal{F}} \left(\mathbb{E}_p \lbrack f(x) \rbrack - \mathbb{E}_q \lbrack f(x) \rbrack \right)
$$
Gretton et al. (2012) show that it suffices to choose $\mathcal{F}$ to be the unit ball of a reproducing kernel Hilbert space (RKHS) $\mathcal{R}$ with a characteristic kernel $k$, in which case
$$
\mathrm{MMD}_\mathcal{F}(p, q) = 0 \iff p = q
$$
The empirical (Monte Carlo) estimate of the squared MMD with $x_i \sim p$ and $y_j \sim q$ is
$$
\widehat{\mathrm{MMD}}^2_\mathcal{R}(p,q) =
\frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(x_i,x_{i'})
- \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(x_i, y_j)
 + \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(y_j,y_{j'})
$$
MMD-nets train neural generators by minimizing this empirical estimate (a code sketch follows on the next slide).

---
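## Background: maximum mean discrepancy (code sketch)

A minimal NumPy sketch of the empirical estimator above, assuming a Gaussian kernel; the bandwidth `sigma` and the sample sizes are illustrative choices, not settings from the paper.

```python
import numpy as np

def gaussian_kernel(a, b, sigma=1.0):
    # pairwise k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    # biased empirical estimate of MMD^2 between samples x ~ p and y ~ q
    k_xx = gaussian_kernel(x, x, sigma)
    k_xy = gaussian_kernel(x, y, sigma)
    k_yy = gaussian_kernel(y, y, sigma)
    return k_xx.mean() - 2 * k_xy.mean() + k_yy.mean()

# samples from the same distribution give a value close to zero
x, y = np.random.randn(500, 2), np.random.randn(500, 2)
print(mmd2(x, y))
```

---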

## Background: **density ratio estimation via moment matching**

Density ratio estimation: find $\hat{r}(x) \approx r(x) = \frac{p(x)}{q(x)}$ with samples from $p$ and $q$

Matching a finite number of moments $\phi$ under the fixed-design setup gives estimates at the sample points, $\hat{\mathbf{r}}_q = [\hat{r}(x_1^q), ..., \hat{r}(x_M^q)]$ for $x_j^q \sim q$

$$
\min_r \left ( \int \phi(x)p(x) dx - \int \phi(x)r(x)q(x) dx \right )^2
$$

Huang et al. (2007) show that by changing $\phi(x)$ to $k(x, \cdot)$, where $k$ is a characteristic kernel of an RKHS, we match infinitely many moments, and the solution of the optimization below agrees with the true $r(x)$

$$
\min_{r\in\mathcal{R}} \bigg \Vert \int k(x, \cdot)p(x) dx - \int k(x, \cdot)r(x)q(x) dx \bigg \Vert_{\mathcal{R}}^2
$$

Analytical solution: $\hat{\mathbf{r}}_q = \mathbf{K}_{q,q}^{-1}\mathbf{K}_{q,p} \mathbf{1}$, where $[\mathbf{K}_{p,q}]_{i,j} = k(x_i^p, x_j^q)$, given samples $\{x_i^p\}$ and $\{x_j^q\}$ (a code sketch follows on the next slide).

---
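## Background: density ratio estimation (code sketch)

A minimal NumPy sketch of the analytical solution above under the fixed-design setup, assuming a Gaussian kernel and equal sample sizes; the ridge term `eps` is added only for numerical stability and is not part of the formula.

```python
import numpy as np

def gaussian_kernel(a, b, sigma=1.0):
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * sigma ** 2))

def ratio_estimate(x_p, x_q, sigma=1.0, eps=1e-6):
    # r_hat = K_qq^{-1} K_qp 1, evaluated at the samples x_q ~ q
    K_qq = gaussian_kernel(x_q, x_q, sigma)
    K_qp = gaussian_kernel(x_q, x_p, sigma)
    ones = np.ones(x_p.shape[0])
    # ridge term keeps the Gram matrix well conditioned
    return np.linalg.solve(K_qq + eps * np.eye(x_q.shape[0]), K_qp @ ones)

# toy check: p = N(1, 1), q = N(0, 1)
x_p = np.random.randn(500, 1) + 1.0
x_q = np.random.randn(500, 1)
print(ratio_estimate(x_p, x_q)[:5])  # estimates of p(x)/q(x) at the first few x_q
```

---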

## GRAM: an overview

Two objectives in the training loop

1. Learning a projection function $f_\theta$ that maps the data space to a low-dimensional manifold while preserving the density ratio between the data and model distributions.
   - "Preserves": $\frac{p_x(x)}{q_x(x)} = \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))}$, measured by $D(\theta) = \int q_x(x) \left( \frac{p_x(x)}{q_x(x)} - \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx$
   - :question: $\frac{p_x(x)}{q_x(x)}$ is hard to estimate in the high-dimensional space ...
2. Matching the model $G_\gamma$ to the data in the low-dimensional manifold by minimizing MMD
   - :thumbsup: MMD works well in low-dimensional spaces
   - $\mathrm{MMD} = 0$ :arrow_right: $\frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} = 1$ :arrow_right: $\frac{p_x(x)}{q_x(x)} = 1$

Both objectives use empirical estimates based on samples from the data $\{x_i^p\}$ and the model $\{x_j^q\}$.

$f_\theta$ and $G_\gamma$ are updated simultaneously.

---

## GRAM: tractable ratio matching

:one: Learning the projection function $f_\theta(x)$ by minimizing the squared difference
$$
\begin{aligned}
D(\theta)
&= \int q_x(x) \left( \frac{p_x(x)}{q_x(x)} - \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx \\
&= C - 2 \int p_x(x) \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} dx + \int q_x(x) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 dx \\
&= C - 2 \int \bar{p}(f_\theta(x)) \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} df_\theta(x) + \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) \\
&= C' - \left( \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) - 1 \right) = C' - \mathrm{PD}(\bar{q}, \bar{p})
\end{aligned}
$$
... or, equivalently, by maximizing the Pearson divergence :smile: (the two integrals in the third line are equal, and $C' = C - 1$ does not depend on $\theta$).

A reminder on LOTUS (the law of the unconscious statistician): $\int p_x(x) g(f(x)) dx = \int \bar{p}(f(x)) g(f(x)) d f(x)$, i.e. expectations can be taken under the pushforward density $\bar{p}$.

<br>

<div class="footnote">
[1]: A derivation in the reverse order for a special case of projection functions was also shown in (Sugiyama et al., 2011).
</div>

---

## GRAM: Pearson divergence maximization

Monte Carlo approximation of PD
$$
\mathrm{PD}(\bar{q}, \bar{p}) = \int \bar{q}(f_\theta(x)) \left( \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))} \right)^2 df_\theta(x) - 1 \approx \frac{1}{M} \sum_{j=1}^M \left( \frac{\bar{p}(f_\theta(x_j^q))}{\bar{q}(f_\theta(x_j^q))} \right)^2 - 1
$$
where $x^q_j \sim q_x$ or, equivalently, $f_\theta(x^q_j) \sim \bar{q}$.

Given samples $\{x_i^p\}$ and $\{x_j^q\}$, we use the density ratio estimator based on matching infinitely many moments (Huang et al., 2007; Sugiyama et al., 2012) under the fixed-design setup
$$
\hat{\mathbf{r}}_{q,\theta} = \mathbf{K}^{-1}_{q,q} \mathbf{K}_{q,p}\mathbf{1} = [\hat{r}_\theta(x_1^q), ..., \hat{r}_\theta(x_M^q)]^\top
$$
where $[\mathbf{K}_{p,q}]_{i,j} = k(f_\theta(x_i^p), f_\theta(x_j^q))$ and $r_\theta(x) = \frac{\bar{p}(f_\theta(x))}{\bar{q}(f_\theta(x))}$ (a code sketch follows on the next slide).

---
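## GRAM: Pearson divergence maximization (code sketch)

A NumPy sketch of the plug-in PD estimate above, written directly in terms of the kernel Gram matrices of the *projected* samples; the ridge term `eps` is an added assumption for numerical stability.

```python
import numpy as np

def pd_estimate(K_qq, K_qp, eps=1e-6):
    # K_qq[i,j] = k(f(x_i^q), f(x_j^q)),  K_qp[i,j] = k(f(x_i^q), f(x_j^p))
    M, N = K_qp.shape
    # fixed-design ratio estimates r_hat = K_qq^{-1} K_qp 1
    r_hat = np.linalg.solve(K_qq + eps * np.eye(M), K_qp @ np.ones(N))
    # Monte Carlo plug-in estimate of PD(q_bar, p_bar)
    return (r_hat ** 2).mean() - 1.0

# identical projected samples => r_hat ~ 1 and PD ~ 0
z = np.random.randn(200, 2)
K = np.exp(-((z[:, None, :] - z[None, :, :]) ** 2).sum(-1) / 2)
print(pd_estimate(K, K))
```

---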

## GRAM: matching projected model to projected data

:two: Minimizing the empirical estimate of the squared MMD in the low-dimensional manifold

$$
\begin{aligned}
\min_\gamma \Bigg[&\frac{1}{N^2}\sum_{i=1}^N\sum_{i'=1}^N k(f_\theta(x_i),f_\theta(x_{i'}))
- \frac{2}{NM}\sum_{i=1}^N\sum_{j=1}^M k(f_\theta(x_i), f_\theta(G_\gamma(z_j)))\\
&\quad + \frac{1}{M^2}\sum_{j=1}^M\sum_{j'=1}^M k(f_\theta(G_\gamma(z_j)),f_\theta(G_\gamma(z_{j'}))) \Bigg ]
\end{aligned}
$$

with respect to the generator parameters $\gamma$.

---

## GRAM: the complete algorithm

Loop until convergence

1. Sample a minibatch of data and generate samples from $G_\gamma$
2. Project the data and the generated samples using $f_\theta$
3. Compute the kernel Gram matrices with Gaussian kernels in the projected space
4. Compute the objectives for $f_\theta$ and $G_\gamma$ using the same kernel Gram matrices
5. Backpropagate the two objectives to obtain the gradients for $\theta$ and $\gamma$
6. Perform a gradient update for $\theta$ and $\gamma$ (a code sketch of one step follows on the next slide)

<br>

:sunglasses: Fun fact: the objectives in our GRAM algorithm both rely heavily on kernel Gram matrices.

---
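## GRAM: the complete algorithm (code sketch)

A hedged PyTorch sketch of one iteration of the loop above; the architectures, kernel bandwidth, ridge term, learning rates, and batch size are illustrative assumptions, not the settings used in the paper.

```python
import torch
import torch.nn as nn

def gram(a, b, sigma=1.0):
    # Gaussian-kernel Gram matrix between two sets of projected samples
    return torch.exp(-torch.cdist(a, b) ** 2 / (2 * sigma ** 2))

x_dim, z_dim, h_dim, proj_dim = 784, 10, 64, 8        # illustrative sizes
f = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, proj_dim))
G = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))
opt_f = torch.optim.Adam(f.parameters(), lr=1e-4)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)

def train_step(x):                                    # x: data minibatch, shape (N, x_dim)
    N = x.shape[0]
    z = torch.randn(N, z_dim)                         # 1. noise for the generator
    fx, fg = f(x), f(G(z))                            # 2. project data and model samples
    Kxx, Kxg, Kgg = gram(fx, fx), gram(fx, fg), gram(fg, fg)  # 3. kernel Gram matrices
    # 4a. objective for f_theta: maximize the Pearson-divergence estimate
    r_hat = torch.linalg.solve(Kgg + 1e-5 * torch.eye(N), Kxg.t() @ torch.ones(N, 1))
    loss_f = -((r_hat ** 2).mean() - 1.0)
    # 4b. objective for G_gamma: minimize the squared MMD in the projected space
    loss_G = Kxx.mean() - 2 * Kxg.mean() + Kgg.mean()
    # 5. gradients of each objective w.r.t. its own parameters only
    grads_f = torch.autograd.grad(loss_f, list(f.parameters()), retain_graph=True)
    grads_G = torch.autograd.grad(loss_G, list(G.parameters()))
    for p, g in zip(f.parameters(), grads_f):
        p.grad = g
    for p, g in zip(G.parameters(), grads_G):
        p.grad = g
    opt_f.step(); opt_G.step()                        # 6. simultaneous update

train_step(torch.randn(64, x_dim))                    # one step on dummy data
```

Both objectives reuse the same Gram matrices `Kxx`, `Kxg`, `Kgg`, and each network receives gradients only from its own objective.

---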

## How do GRAM-nets compare to other deep generative models?

| GAN | MMD-net | MMD-GAN | GRAM-net |
| - | - | - | - |
|  |  |  |  |

---

## Illustration of GRAM training

| | |
| - | - |
|  |  |

Blue: data; orange: model samples
Top: original space; bottom: projected space

---

## Evaluation: model stability



x-axis: noise dimension; y-axis: generator layer size

---

## Evaluation: model stability (continued)



x-axis: noise dimension; y-axis: generator layer size

---

## Quantitative results: sample quality



---

## Qualitative results: random samples

 

---

## The end