From a64fb1cbf882d4904b9d43a6b2154adf819cc417 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 24 Apr 2026 20:16:05 +0100 Subject: [PATCH] fix: skip FitException samples in compute_latent_samples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Model assertions (e.g. `gaussian_0.centre > gaussian_1.centre`) can leave assertion-violating parameter vectors in a non-linear search's posterior sample list — the sampler's own likelihood path rejects them as -inf, but they're still recorded as dead points. Post-fit, `compute_latent_samples` iterates every sample and calls `model.instance_from_vector`, which runs `check_assertions` and raises `FitException`. Before this change, one bad sample killed the whole batch and the entire fit call. Handle it per-sample: the non-JAX `batched_compute_latent` now catches `FitException` and substitutes a NaN row, and the batch loop drops those rows via a new row-mask before the existing per-latent column mask runs. JAX path unchanged — user JAX latent functions can't raise Python exceptions inside `jit`/`vmap` anyway. Added test: `test_compute_latent_samples_skips_fit_exception_samples` seeds a 2-sample `SamplesPDF` where the second sample would raise `FitException`; asserts only the first sample survives in the output. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/non_linear/analysis/analysis.py | 30 +++++++++++--- .../analysis/test_latent_variables.py | 39 +++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 40aef6dba..f456e37f2 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -6,6 +6,7 @@ import time from typing import Optional, Dict +from autofit import exc from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.samples.summary import SamplesSummary @@ -173,8 +174,17 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.") else: + n_latents = len(self.LATENT_KEYS) + nan_row = np.full(n_latents, np.nan) + + def _safe_compute(xx): + try: + return compute_latent_for_model(xx) + except exc.FitException: + return nan_row + def batched_compute_latent(x): - return np.array([compute_latent_for_model(xx) for xx in x]) + return np.array([_safe_compute(xx) for xx in x]) parameter_array = np.array(samples.parameter_lists) latent_samples = [] @@ -183,6 +193,7 @@ def batched_compute_latent(x): for i in range(0, len(parameter_array), batch_size): batch = parameter_array[i:i + batch_size] + batch_samples = samples.sample_list[i:i + batch_size] # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) @@ -193,10 +204,19 @@ def batched_compute_latent(x): mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, mask] else: - mask = np.all(np.isfinite(latent_values_batch), axis=0) - latent_values_batch = latent_values_batch[:, mask] - - for sample, values in zip(samples.sample_list[i:i 
+ batch_size], latent_values_batch): + # Drop samples whose latent computation failed (e.g. FitException from + # model assertions surfaced as a NaN row in _safe_compute). This leaves + # the per-latent column mask to continue handling degenerate latent + # dimensions that produce NaN for all remaining samples. + row_mask = np.all(np.isfinite(latent_values_batch), axis=1) + latent_values_batch = latent_values_batch[row_mask] + batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep] + + if len(latent_values_batch): + col_mask = np.all(np.isfinite(latent_values_batch), axis=0) + latent_values_batch = latent_values_batch[:, col_mask] + + for sample, values in zip(batch_samples, latent_values_batch): kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} diff --git a/test_autofit/analysis/test_latent_variables.py b/test_autofit/analysis/test_latent_variables.py index 8965cb840..c4e15f6a6 100644 --- a/test_autofit/analysis/test_latent_variables.py +++ b/test_autofit/analysis/test_latent_variables.py @@ -77,6 +77,45 @@ def test_compute_latent_samples(latent_samples): assert latent_samples.model.instance_from_vector([1.0]).fwhm == 1.0 +class AssertionAnalysis(af.Analysis): + + LATENT_KEYS = ["fwhm"] + + def log_likelihood_function(self, instance): + return 1.0 + + def compute_latent_variables(self, parameters, model): + if parameters[0] < 0: + raise af.exc.FitException("assertion violated") + instance = model.instance_from_vector(vector=parameters) + return (instance.fwhm,) + + +def test_compute_latent_samples_skips_fit_exception_samples(): + analysis = AssertionAnalysis() + latent_samples = analysis.compute_latent_samples( + SamplesPDF( + model=af.Model(af.ex.Gaussian), + sample_list=[ + af.Sample( + log_likelihood=1.0, + log_prior=0.0, + weight=1.0, + kwargs={"centre": 1.0, "normalization": 2.0, "sigma": 3.0}, + ), + af.Sample( + log_likelihood=-1.0, + log_prior=0.0, + weight=0.0, + kwargs={"centre": -1.0, "normalization": 2.0, "sigma": 3.0}, + 
), + ], + ), + ) + assert len(latent_samples.sample_list) == 1 + assert latent_samples.sample_list[0].kwargs == {"fwhm": 7.0644601350928475} + + def test_info(latent_samples): info = result_info_from(latent_samples) assert (