diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 40aef6dba..f456e37f2 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -6,6 +6,7 @@ import time from typing import Optional, Dict +from autofit import exc from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.samples.summary import SamplesSummary @@ -173,8 +174,17 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.") else: + n_latents = len(self.LATENT_KEYS) + nan_row = np.full(n_latents, np.nan) + + def _safe_compute(xx): + try: + return compute_latent_for_model(xx) + except exc.FitException: + return nan_row + def batched_compute_latent(x): - return np.array([compute_latent_for_model(xx) for xx in x]) + return np.array([_safe_compute(xx) for xx in x]) parameter_array = np.array(samples.parameter_lists) latent_samples = [] @@ -183,6 +193,7 @@ def batched_compute_latent(x): for i in range(0, len(parameter_array), batch_size): batch = parameter_array[i:i + batch_size] + batch_samples = samples.sample_list[i:i + batch_size] # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) @@ -193,10 +204,19 @@ def batched_compute_latent(x): mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, mask] else: - mask = np.all(np.isfinite(latent_values_batch), axis=0) - latent_values_batch = latent_values_batch[:, mask] - - for sample, values in zip(samples.sample_list[i:i + batch_size], latent_values_batch): + # Drop samples whose latent computation failed (e.g. FitException from + # model assertions surfaced as a NaN row in _safe_compute). This leaves + # the per-latent column mask to continue handling degenerate latent + # dimensions that produce NaN for all remaining samples. + row_mask = np.all(np.isfinite(latent_values_batch), axis=1) + latent_values_batch = latent_values_batch[row_mask] + batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep] + + if len(latent_values_batch): + col_mask = np.all(np.isfinite(latent_values_batch), axis=0) + latent_values_batch = latent_values_batch[:, col_mask] + + for sample, values in zip(batch_samples, latent_values_batch): kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} diff --git a/test_autofit/analysis/test_latent_variables.py b/test_autofit/analysis/test_latent_variables.py index 8965cb840..c4e15f6a6 100644 --- a/test_autofit/analysis/test_latent_variables.py +++ b/test_autofit/analysis/test_latent_variables.py @@ -77,6 +77,45 @@ def test_compute_latent_samples(latent_samples): assert latent_samples.model.instance_from_vector([1.0]).fwhm == 1.0 +class AssertionAnalysis(af.Analysis): + + LATENT_KEYS = ["fwhm"] + + def log_likelihood_function(self, instance): + return 1.0 + + def compute_latent_variables(self, parameters, model): + if parameters[0] < 0: + raise af.exc.FitException("assertion violated") + instance = model.instance_from_vector(vector=parameters) + return (instance.fwhm,) + + +def test_compute_latent_samples_skips_fit_exception_samples(): + analysis = AssertionAnalysis() + latent_samples = analysis.compute_latent_samples( + SamplesPDF( + model=af.Model(af.ex.Gaussian), + sample_list=[ + af.Sample( + log_likelihood=1.0, + log_prior=0.0, + weight=1.0, + kwargs={"centre": 1.0, "normalization": 2.0, "sigma": 3.0}, + ), + af.Sample( + log_likelihood=-1.0, + log_prior=0.0, + weight=0.0, + kwargs={"centre": -1.0, "normalization": 2.0, "sigma": 3.0}, + ), + ], + ), + ) + assert len(latent_samples.sample_list) == 1 + assert latent_samples.sample_list[0].kwargs == {"fwhm": 7.0644601350928475} + + def test_info(latent_samples): info = result_info_from(latent_samples) assert (