
Commit b4f4f24

update example with changes from pymc#7342

1 parent 6b607f9 · commit b4f4f24

File tree

2 files changed: +335 −304 lines

examples/gaussian_processes/HSGP-Basic.ipynb: +311 −289 (large diff not rendered by default)

examples/gaussian_processes/HSGP-Basic.myst.md: +24 −15
@@ -5,9 +5,9 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: pymc-examples
+  display_name: pymc-dev
   language: python
-  name: pymc-examples
+  name: pymc-dev
 ---
 
 (hsgp)=
@@ -58,6 +58,12 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
+
+# Sample on the CPU
+%env CUDA_VISIBLE_DEVICES=''
+# import jax
+# import numpyro
+# numpyro.set_host_device_count(6)
 ```
 
 ```{code-cell} ipython3
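A note on the added lines above: `%env CUDA_VISIBLE_DEVICES=''` hides the GPU so JAX-based samplers run on the CPU, and the commented `numpyro` lines are what you would enable to spread chains across several CPU devices. A minimal sketch of how that backend is then invoked, assuming `numpyro` is installed (the toy model is illustrative, not from the notebook):

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    # With numpyro.set_host_device_count(6) executed before JAX initializes,
    # the six chains can run on six separate CPU devices.
    idata = pm.sample(nuts_sampler="numpyro", chains=6)
```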
@@ -325,12 +331,12 @@ In practice, you'll need to infer the lengthscale from the data, so the HSGP nee
 For example, if you're using the `Matern52` covariance and your data ranges from $x=-5$ to $x=95$, and the bulk of your lengthscale prior is between $\ell=1$ and $\ell=50$, then the smallest recommended values are $m=543$ and $c=3.7$, as you can see below:
 
 ```{code-cell} ipython3
-hsgp_params = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
-    x=np.linspace(-5, 95), lengthscale_range=[1, 50], cov_func="matern52"
+m, c = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
+    x_range=[-5, 95], lengthscale_range=[1, 50], cov_func="matern52"
 )
 
-print("Recommended smallest number of basis vectors (m):", hsgp_params.m)
-print("Recommended smallest scaling factor (c):", np.round(hsgp_params.c, 1))
+print("Recommended smallest number of basis vectors (m):", m)
+print("Recommended smallest scaling factor (c):", np.round(c, 1))
 ```
 
 ### The HSGP approximate Gram matrix
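For context, the `m` and `c` recommended by `approx_hsgp_hyperparams` above are meant to be handed straight to `pm.gp.HSGP`. A minimal sketch of that hand-off (the 1-D inputs and the lengthscale prior here are illustrative, not from the notebook):

```python
import numpy as np
import pymc as pm

X = np.linspace(-5, 95, 200)[:, None]  # inputs spanning the same range as above

with pm.Model():
    ell = pm.LogNormal("ell", mu=np.log(10), sigma=1.0)  # wide lengthscale prior (illustrative)
    cov_func = pm.gp.cov.Matern52(input_dim=1, ls=ell)
    # m is specified per active dimension, hence the list
    gp = pm.gp.HSGP(m=[543], c=3.7, cov_func=cov_func)
    f = gp.prior("f", X=X)
```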
@@ -355,7 +361,8 @@ K = cov_func(X).eval()
 
 ## Calculate the HSGP approximate Gram matrix
 # Center or "scale" X so we can work with Xs (important)
-Xs = X - np.mean(X, axis=0)
+X_center = (np.max(X, axis=0) - np.min(X, axis=0)) / 2.0
+Xs = X - X_center
 
 # Calculate L given Xs and c
 m, c = [20, 20], 2.0
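To see what the new centering accomplishes, here is a quick standalone check, assuming for illustration inputs uniform on $[0, 10]$ in each column (when the minimum is 0, `(max - min) / 2` coincides with the midpoint `(max + min) / 2`). After centering, `Xs` is symmetric about zero, which is what the boundary factor `c` scales:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(100, 2))  # illustrative inputs on [0, 10]^2

X_center = (np.max(X, axis=0) - np.min(X, axis=0)) / 2.0
Xs = X - X_center  # now roughly symmetric about 0, i.e. in [-5, 5]^2

# The approximation lives on [-L, L] per dimension, with L = c * max(|Xs|)
c = 2.0
S = np.max(np.abs(Xs), axis=0)
print("L =", c * S)  # approximately [10, 10] here
```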
@@ -376,7 +383,7 @@ def calculate_Kapprox(Xs, L, m):
 fig, axs = plt.subplots(2, 4, figsize=(14, 7), sharey=True)
 
 axs[0, 0].imshow(K, cmap="inferno", vmin=0, vmax=1)
-axs[0, 0].set(xlabel="x1", ylabel="x2", title=f"True Gram matrix\nTrue $\ell$ = {chosen_ell}")
+axs[0, 0].set(xlabel="x1", ylabel="x2", title=f"True Gram matrix\nTrue $\\ell$ = {chosen_ell}")
 axs[1, 0].axis("off")
 im_kwargs = {
     "cmap": "inferno",
@@ -424,6 +431,8 @@ K_approx = calculate_Kapprox(Xs, L, m)
 axs[1, 3].imshow(K_approx, **im_kwargs)
 axs[1, 3].set_title(f"m = {m}, c = {c}")
 
+for ax in axs.flatten():
+    ax.grid(False)
 plt.tight_layout();
 ```
 
@@ -549,25 +558,25 @@ with pm.Model() as model:
     beta = pm.Normal("beta", mu=0.0, sigma=10.0, shape=2)
 
     # Prior on the HSGP
-    eta = pm.HalfNormal("eta", 0.5)
+    eta = pm.Exponential("eta", scale=2.0)
     ell_params = pm.find_constrained_prior(
         pm.Lognormal, lower=0.5, upper=5.0, mass=0.9, init_guess={"mu": 1.0, "sigma": 1.0}
     )
     ell = pm.Lognormal("ell", **ell_params)
     cov_func = eta**2 * pm.gp.cov.Matern52(input_dim=2, ls=ell)
 
     # m and c control the fidelity of the approximation
-    m0, m1, c = 30, 30, 2.5
+    m0, m1, c = 30, 30, 2.0
     gp = pm.gp.HSGP(m=[m0, m1], c=c, cov_func=cov_func)
 
-    phi, sqrt_psd = gp.prior_linearized(Xs=X_gp)
+    phi, sqrt_psd = gp.prior_linearized(X=X_gp)
 
     basis_coeffs = pm.Normal("basis_coeffs", size=gp.n_basis_vectors)
     f = pm.Deterministic("f", phi @ (basis_coeffs * sqrt_psd))
 
     mu = pm.Deterministic("mu", beta[0] + beta[1] * X_fe + f)
 
-    sigma = pm.HalfNormal("sigma", 0.5)
+    sigma = pm.Exponential("sigma", scale=2.0)
     pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y_tr, shape=X_gp.shape[0])
 
     idata = pm.sample_prior_predictive()
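For reference, the line `f = pm.Deterministic("f", phi @ (basis_coeffs * sqrt_psd))` in the hunk above is the matrix form of the HSGP basis expansion, with $\phi_j$ the Laplacian eigenfunctions (the columns of `phi`), $\lambda_j$ the corresponding eigenvalues, $S_\theta$ the power spectral density of the covariance function, and $\beta_j$ the standard-normal basis coefficients:

$$
f(x) \approx \sum_{j=1}^{m} \sqrt{S_\theta\left(\sqrt{\lambda_j}\right)}\, \beta_j \, \phi_j(x)
$$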
@@ -585,7 +594,7 @@ Before sampling and looking at the results, there are a few things to pay attent
 
 First, `prior_linearized` returns the eigenvector basis, `phi`, and the square root of the power spectrum at the eigenvalues, `sqrt_psd`. You have to construct the HSGP approximation from these. The following are the relevant lines of code, showing both the centered and non-centered parameterization.
 ```python
-phi, sqrt_psd = gp.prior_linearized(Xs=Xs)
+phi, sqrt_psd = gp.prior_linearized(X=X)
 
 ## non-centered
 basis_coeffs= pm.Normal("basis_coeffs", size=gp.n_basis_vectors)
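The excerpt above cuts off before the centered version; for completeness, the centered counterpart would look something like this sketch (the exact cell lies outside this hunk, so treat the code as illustrative):

```python
## centered
# The prior scale of each coefficient absorbs the square root of the
# power spectrum instead of multiplying it in afterwards.
basis_coeffs = pm.Normal("basis_coeffs", sigma=sqrt_psd, size=gp.n_basis_vectors)
f = pm.Deterministic("f", phi @ basis_coeffs)
```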
@@ -607,7 +616,7 @@ nu = pm.Gamma("nu", alpha=2, beta=0.1)
 basis_coeffs= pm.StudentT("basis_coeffs", nu=nu, size=gp.n_basis_vectors)
 f = pm.Deterministic("f", phi @ (beta * sqrt_psd))
 ```
-where we use a $\text{Gamma}(\alpha=2, \beta=0.1)$ prior for $\nu$, which places around 50% probability that $\nu > 30$, the point where a Student-T roughly becomes indistinguishable from a Gaussian.
+where we use a $\text{Gamma}(\alpha=2, \beta=0.1)$ prior for $\nu$, which places around 50% probability that $\nu > 30$, the point where a Student-T roughly becomes indistinguishable from a Gaussian. See [this link](https://github.com/stan-dev/stan/wiki/prior-choice-recommendations#prior-for-degrees-of-freedom-in-students-t-distribution) for more information.
 
 +++
 
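Two asides on this hunk: first, `phi @ (beta * sqrt_psd)` in the context line above presumably means `basis_coeffs * sqrt_psd`, matching the earlier cells. Second, the tail mass implied by the $\text{Gamma}(\alpha=2, \beta=0.1)$ prior is easy to check directly with scipy (not part of the notebook; note scipy parameterizes the Gamma by shape and scale $= 1/\beta$):

```python
from scipy import stats

# P(nu > 30) under Gamma(alpha=2, beta=0.1)
print(stats.gamma.sf(30, a=2, scale=1 / 0.1))
```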
@@ -639,7 +648,7 @@ az.plot_trace(
 );
 ```
 
-Sampling went great, but, interestingly, we seem to have a bias in the model, for `eta`, `ell` and `sigma`. It's not the focus of this notebook, but it'd be interesting to dive into this in a real use-case.
+Sampling went great, but, interestingly, we seem to have a bias in the posterior for `sigma`. It's not the focus of this notebook, but it'd be interesting to dive into this in a real use-case.
 
 +++
 

0 commit comments