Skip to content

Commit

Permalink
Add tests and use numpyro backend by default (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Aug 1, 2023
1 parent 353c6d3 commit 98404c4
Show file tree
Hide file tree
Showing 13 changed files with 653 additions and 110 deletions.
5 changes: 3 additions & 2 deletions coix/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
if _use_fori_loop(targets, num_targets):

def body_fun(i, q):
assert callable(targets)
p = extend(compose(momentum, targets(i), suffix=False), refreshment)
return propose(p, compose(refreshment, compose(leapfrog, q)))

Expand All @@ -155,7 +156,7 @@ def body_fun(i, q):

targets = [compose(momentum, p, suffix=False) for p in targets]
q = targets[0]
loss_fns = [None] * (len(targets) - 2) + [iwae_loss]
loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,)
for p, loss_fn in zip(targets[1:], loss_fns):
q = compose(refreshment, compose(leapfrog, q))
q = propose(extend(p, refreshment), q, loss_fn=loss_fn)
Expand Down Expand Up @@ -413,7 +414,7 @@ def body_fun(i, q):
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
loss_fns = [None] * (len(proposals) - 2) + [iwae_loss]
loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,)
for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns):
q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn)
return q
116 changes: 116 additions & 0 deletions coix/algo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for algo.py."""

import functools

import coix
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import optax

coix.set_backend("coix.oryx")

np.random.seed(0)
num_data, dim = 4, 2
data = np.random.randn(num_data, dim).astype(np.float32)
loc_p = np.random.randn(dim).astype(np.float32)
precision_p = np.random.rand(dim).astype(np.float32)
scale_p = np.sqrt(1 / precision_p)
precision_x = np.random.rand(dim).astype(np.float32)
scale_x = np.sqrt(1 / precision_x)
precision_q = precision_p + num_data * precision_x
loc_q = (data.sum(0) * precision_x + loc_p * precision_p) / precision_q
log_scale_q = -0.5 * np.log(precision_q)


def model(params, key):
del params
key_z, key_next = random.split(key)
z = coix.rv(dist.Normal(loc_p, scale_p), name="z")(key_z)
z = jnp.broadcast_to(z, (num_data, dim))
x = coix.rv(dist.Normal(z, scale_x), obs=data, name="x")
return key_next, z, x


def guide(params, key, *args):
del args
key, _ = random.split(key) # split here to test tie_in
scale_q = jnp.exp(params["log_scale_q"])
z = coix.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key)
return z


def check_ess(make_program):
params = {"loc_q": loc_q, "log_scale_q": log_scale_q}
p = jax.vmap(functools.partial(model, params))
q = jax.vmap(functools.partial(guide, params))
program = make_program(p, q)

keys = random.split(random.PRNGKey(0), 5)
ess = coix.traced_evaluate(program)(keys)[2]["ess"]
np.testing.assert_allclose(ess, 5.0)


def run_inference(make_program, num_steps=1000):
"""Performs inference given an algorithm `make_program`."""

def loss_fn(params, key):
p = jax.vmap(functools.partial(model, params))
q = jax.vmap(functools.partial(guide, params))
program = make_program(p, q)

keys = random.split(key, 5)
metrics = coix.traced_evaluate(program)(keys)[2]
return metrics["loss"], metrics

init_params = {
"loc_q": jnp.zeros_like(loc_q),
"log_scale_q": jnp.zeros_like(log_scale_q),
}
params, _ = coix.util.train(
loss_fn, init_params, optax.adam(0.01), num_steps=num_steps
)

np.testing.assert_allclose(params["loc_q"], loc_q, atol=0.2)
np.testing.assert_allclose(params["log_scale_q"], log_scale_q, atol=0.2)


def test_apgs():
check_ess(lambda p, q: coix.algo.apgs(p, [q]))
run_inference(lambda p, q: coix.algo.apgs(p, [q]))


def test_rws():
check_ess(coix.algo.rws)
run_inference(coix.algo.rws)


def test_svi_elbo():
check_ess(coix.algo.svi)
run_inference(coix.algo.svi)


def test_svi_iwae():
check_ess(coix.algo.svi_iwae)
run_inference(coix.algo.svi_iwae)


def test_svi_stl():
check_ess(coix.algo.svi_stl)
run_inference(coix.algo.svi_stl)
69 changes: 43 additions & 26 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ def wrapped(*args, **kwargs):
log_probs = list(p_log_probs.values()) + list(q_log_probs.values())
batch_ndims = util.get_batch_ndims(log_probs)

assert "log_weight" in q_metrics
in_log_weight = q_metrics["log_weight"]
in_log_weight = jnp.sum(
in_log_weight,
axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
)
if "log_weight" in q_metrics:
in_log_weight = q_metrics["log_weight"]
in_log_weight = jnp.sum(
in_log_weight,
axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
)
else:
in_log_weight = util.get_log_weight(q_trace, batch_ndims)
p_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
Expand All @@ -154,7 +156,7 @@ def wrapped(*args, **kwargs):
# Note: We include superfluous variables, whose `name in p_trace`.
q_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in q_log_probs.items()
for lp in q_log_probs.values()
)
incremental_log_weight = p_log_weight - q_log_weight
log_weight = in_log_weight + incremental_log_weight
Expand Down Expand Up @@ -207,12 +209,20 @@ def _maybe_get_along_first_axis(x, idx, n, squeeze=False):
x = np.array(x)
# Special treatment for cascades.
if hasattr(x, "value"):
x.value = _maybe_get_along_first_axis(
util.get_site_value(x), idx, n, squeeze=squeeze
setattr(
x,
"value",
_maybe_get_along_first_axis(
util.get_site_value(x), idx, n, squeeze=squeeze
),
)
if hasattr(x, "log_density"):
x.log_density = _maybe_get_along_first_axis(
util.get_site_log_prob(x), idx, n, squeeze=squeeze
setattr(
x,
"log_density",
_maybe_get_along_first_axis(
util.get_site_log_prob(x), idx, n, squeeze=squeeze
),
)
if (
isinstance(x, (np.ndarray, jnp.ndarray))
Expand Down Expand Up @@ -247,7 +257,7 @@ def fn(*args, **kwargs):
if util.can_extract_key(args):
key_r, key_q = _split_key(args[0])
# We just need a single key for resampling.
key_r = key_r.reshape((-1, 2)).sum(0)
key_r = key_r.reshape((-1, 2))[0]
args = (key_q,) + args[1:]
else:
key_r = core.prng_key()
Expand Down Expand Up @@ -310,12 +320,17 @@ def _add_missing_metrics(metrics, trace):
batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1)
log_weight = util.get_log_weight(trace, batch_ndims)
full_metrics["log_weight"] = log_weight
if batch_ndims: # leftmost dimension is particle dimension
ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0)
full_metrics["ess"] = ess.mean()
n = log_weight.shape[0]
log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n)
full_metrics["log_Z"] = log_z.mean()
else:
batch_ndims = metrics["log_weight"].ndim
log_weight = metrics["log_weight"]
# leftmost dimension is particle dimension
if batch_ndims and "ess" not in metrics:
assert "log_Z" not in metrics
ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0)
full_metrics["ess"] = ess.mean()
n = log_weight.shape[0]
log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n)
full_metrics["log_Z"] = log_z.mean()
if "loss" not in metrics:
full_metrics["loss"] = jnp.array(0.0)
if "log_density" not in metrics:
Expand All @@ -339,17 +354,18 @@ def fori_loop(lower, upper, body_fun, init_program):
"""

def fn(*args, **kwargs):
if util.can_extract_key(args):
key = args[0]
def trace_arg_key(fn, key):
return core.traced_evaluate(fn)(key, *args[1:], **kwargs)

def trace_fn(fn, key):
return core.traced_evaluate(fn)(key, *args[1:], **kwargs)
def trace_with_seed(fn, key):
return core.traced_evaluate(fn, seed=key)(*args, **kwargs)

if util.can_extract_key(args):
key = args[0]
trace_fn = trace_arg_key
else:
key = core.prng_key()

def trace_fn(fn, key):
return core.traced_evaluate(fn, seed=key)(*args, **kwargs)
trace_fn = trace_with_seed

key_body, key_init = _split_key(key)

Expand Down Expand Up @@ -420,7 +436,7 @@ def wrapped(*args, **kwargs):

p_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
for lp in p_log_probs.values()
)

marginal_trace = {
Expand All @@ -431,6 +447,7 @@ def wrapped(*args, **kwargs):
new_memory = {
name: util.get_site_value(site) for name, site in marginal_trace.items()
}
assert not isinstance(p_log_weight, int)
num_particles = p_log_weight.shape[0]
batch_dim = p_log_weight.ndim
flat_memory = {
Expand Down
Loading

0 comments on commit 98404c4

Please sign in to comment.