Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Chain Burn-in Calculation in Thermodynamic Integration #1361

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pypesto/sample/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
logger = logging.getLogger(__name__)


def geweke_test(result: Result, zscore: float = 2.0) -> int:
def geweke_test(
result: Result, zscore: float = 2.0, chain_number: int = 0
) -> int:
"""
Calculate the burn-in of MCMC chains.

Expand All @@ -21,6 +23,9 @@ def geweke_test(result: Result, zscore: float = 2.0) -> int:
The pyPESTO result object with filled sample result.
zscore:
The Geweke test threshold.
chain_number:
The chain number to be used for the Geweke test (in a parallel tempering setting).
Usually we are only interested in the first chain.

Returns
-------
Expand All @@ -29,16 +34,17 @@ def geweke_test(result: Result, zscore: float = 2.0) -> int:
do not differ significantly regarding Geweke test -> Burn-In
"""
# Get parameter samples as numpy arrays
chain = np.asarray(result.sample_result.trace_x[0])
chain = np.asarray(result.sample_result.trace_x[chain_number])

# Calculate burn in index
burn_in = burn_in_by_sequential_geweke(chain=chain, zscore=zscore)

# Log
logger.info(f"Geweke burn-in index: {burn_in}")
if chain_number == 0:
# Log
logger.info(f"Geweke burn-in index: {burn_in}")

# Fill in burn-in value into result
result.sample_result.burn_in = burn_in
# Fill in burn-in value into result
result.sample_result.burn_in = burn_in

return burn_in

Expand Down
28 changes: 22 additions & 6 deletions pypesto/sample/parallel_tempering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..problem import Problem
from ..result import McmcPtResult, Result
from ..util import tqdm
from .diagnostics import geweke_test
from .sampler import InternalSampler, Sampler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -197,12 +198,27 @@ def compute_log_evidence(
f"Carefully check the results. Consider using beta_init='{BETA_DECAY}' for better results."
)

burn_in = result.sample_result.burn_in
trace_loglike = (
result.sample_result.trace_neglogprior[::-1, burn_in:]
- result.sample_result.trace_neglogpost[::-1, burn_in:]
)
mean_loglike_per_beta = np.mean(trace_loglike, axis=1)
# compute burn in for all chains and then estimate mean of log likelihood for each beta
mean_loglike_per_beta = []
temps = []
for i_chain in reversed(range(len(self.betas))):
burn_in_i = geweke_test(result, chain_number=i_chain)
arrjon marked this conversation as resolved.
Show resolved Hide resolved

if (
burn_in_i < result.sample_result.trace_x.shape[1]
or i_chain == len(self.betas) - 1
arrjon marked this conversation as resolved.
Show resolved Hide resolved
):
# save temperature-chain as it is converged
# last chain converges always, only samples from prior
temps.append(i_chain)
trace_loglike_i = (
result.sample_result.trace_neglogprior[i_chain, burn_in_i:]
- result.sample_result.trace_neglogpost[
i_chain, burn_in_i:
]
)
mean_loglike_per_beta.append(np.mean(trace_loglike_i))

if method == "trapezoid":
log_evidence = trapezoid(
y=mean_loglike_per_beta, x=self.betas[::-1]
Expand Down
Loading