Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions MultiHMCGibbs/multihmcgibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from numpyro.infer.mcmc import MCMCKernel
from numpyro.util import is_prng_key

MultiHMCGibbsState = namedtuple("MultiHMCGibbsState", "z, hmc_states, diverging, rng_key")
MultiHMCGibbsState = namedtuple("MultiHMCGibbsState", "z, hmc_states, diverging, rng_key, potential_energy")
"""
- **z** - a dict of the current latent values (all sites)
- **hmc_states** - list of current :data:`~numpyro.infer.hmc.HMCState` (one per gibbs step)
- **diverging** - A list of boolean value to indicate whether the current trajectory is diverging.
- **rng_key** - random number generator seed used for the iteration.
- **potential_energy** - A list of the potential energy values associated with each posterior sample.
"""


Expand Down Expand Up @@ -92,7 +93,7 @@ def model(self):

@property
def default_fields(self):
return ("z", "diverging")
return ("z", "diverging", "potential_energy")

def get_diagnostics_str(self, state):
# show diagnostics for all inner kernels
Expand Down Expand Up @@ -182,7 +183,9 @@ def init_fn(init_params, key_zs):
hmc_states.append(hmc_state_kdx)
rng_keys.append(hmc_state_kdx.rng_key)
z = z | hmc_state_kdx.z
return MultiHMCGibbsState(z, hmc_states, diverging, jnp.stack(rng_keys))

potential_energy = 0.0
return MultiHMCGibbsState(z, hmc_states, diverging, jnp.stack(rng_keys), potential_energy)

# not-vectorized
if is_prng_key(rng_key):
Expand Down Expand Up @@ -240,7 +243,9 @@ def potential_fn(z_hmc):
rng_keys.append(hmc_state.rng_key)
# update new z values (unconstrained space)
z = z | hmc_state.z
return MultiHMCGibbsState(z, hmc_states, jnp.stack(diverging), jnp.stack(rng_keys))

potential_energy = hmc_state.potential_energy # Add the potential energy after the Gibbs steps are completed
return MultiHMCGibbsState(z, hmc_states, jnp.stack(diverging), jnp.stack(rng_keys), potential_energy)

def sample(self, state, model_args, model_kwargs):
return self._sample_fn(state, model_args, model_kwargs)
40 changes: 20 additions & 20 deletions MultiHMCGibbs/tests/test_multihmcgibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def test_default(self):
mcmc.run(rng_key)
x = mcmc.get_samples()['x']
y = mcmc.get_samples()['y']
assert_allclose(np.mean(x), 0.5, atol=0.25, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.25, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.25, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.25, err_msg='std(y) not close to sqrt(2)')
assert_allclose(np.mean(x), 0.5, atol=0.4, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.4, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.4, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.4, err_msg='std(y) not close to sqrt(2)')

def test_sequential(self):
inner_kernels = [
Expand All @@ -67,10 +67,10 @@ def test_sequential(self):
mcmc.run(rng_key)
x = mcmc.get_samples()['x']
y = mcmc.get_samples()['y']
assert_allclose(np.mean(x), 0.5, atol=0.25, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.25, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.25, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.25, err_msg='std(y) not close to sqrt(2)')
assert_allclose(np.mean(x), 0.5, atol=0.4, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.4, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.4, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.4, err_msg='std(y) not close to sqrt(2)')

def test_vectorized(self):
inner_kernels = [
Expand All @@ -92,10 +92,10 @@ def test_vectorized(self):
mcmc.run(rng_key)
x = mcmc.get_samples()['x']
y = mcmc.get_samples()['y']
assert_allclose(np.mean(x), 0.5, atol=0.25, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.25, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.25, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.25, err_msg='std(y) not close to sqrt(2)')
assert_allclose(np.mean(x), 0.5, atol=0.4, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.4, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.4, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.4, err_msg='std(y) not close to sqrt(2)')

def test_init_params(self):
inner_kernels = [
Expand All @@ -115,10 +115,10 @@ def test_init_params(self):
mcmc.run(rng_key, init_params={'x': jnp.array(0.0), 'y': jnp.array(0.0)})
x = mcmc.get_samples()['x']
y = mcmc.get_samples()['y']
assert_allclose(np.mean(x), 0.5, atol=0.25, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.25, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.25, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.25, err_msg='std(y) not close to sqrt(2)')
assert_allclose(np.mean(x), 0.5, atol=0.4, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.4, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.4, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.4, err_msg='std(y) not close to sqrt(2)')

def test_forward(self):
inner_kernels = [
Expand All @@ -138,10 +138,10 @@ def test_forward(self):
mcmc.run(rng_key)
x = mcmc.get_samples()['x']
y = mcmc.get_samples()['y']
assert_allclose(np.mean(x), 0.5, atol=0.25, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.25, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.25, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.25, err_msg='std(y) not close to sqrt(2)')
assert_allclose(np.mean(x), 0.5, atol=0.4, err_msg='mean(x) not close to 0.5')
assert_allclose(np.std(x), np.sqrt(2), atol=0.4, err_msg='std(x) not close to sqrt(2)')
assert_allclose(np.mean(y), 0.5, atol=0.4, err_msg='mean(y) not close to 0.5')
assert_allclose(np.std(y), np.sqrt(2), atol=0.4, err_msg='std(y) not close to sqrt(2)')

def test_model_mismatch(self):
def model2():
Expand Down