fix: skip FitException samples in compute_latent_samples #1233
Merged
Model assertions (e.g. `gaussian_0.centre > gaussian_1.centre`) can leave assertion-violating parameter vectors in a non-linear search's posterior sample list: the sampler's own likelihood path rejects them as -inf, but they're still recorded as dead points. Post-fit, `compute_latent_samples` iterates every sample and calls `model.instance_from_vector`, which runs `check_assertions` and raises `FitException`. Before this change, one bad sample killed the whole batch and the entire fit call.

Handle it per-sample: the non-JAX `batched_compute_latent` now catches `FitException` and substitutes a NaN row, and the batch loop drops those rows via a new row-mask before the existing per-latent column mask runs.

JAX path unchanged: user JAX latent functions can't raise Python exceptions inside `jit`/`vmap` anyway.

Added test: `test_compute_latent_samples_skips_fit_exception_samples` seeds a 2-sample aggregator where the second sample would raise `FitException` and asserts that only the first sample survives in the output.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
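The per-sample handling described above can be illustrated with a minimal, standalone sketch. It does not use the library's code: `FitException` is redefined locally as a stand-in, and `latent_from_vector` and `compute_latent_rows` are hypothetical names invented for this example. It uses plain Python lists where the real batch loop presumably works on arrays, and it leaves out the existing per-latent column mask.

```python
import math


class FitException(Exception):
    """Local stand-in for the library's exception, so the sketch runs on its own."""


def latent_from_vector(vector):
    # Hypothetical latent function: enforces centre_0 > centre_1, the way a
    # model assertion would, then returns two latent values.
    centre_0, centre_1 = vector
    if not centre_0 > centre_1:
        raise FitException("assertion violated: centre_0 <= centre_1")
    return [centre_0 - centre_1, centre_0 + centre_1]


def compute_latent_rows(vectors, n_latents=2):
    rows = []
    for vector in vectors:
        try:
            rows.append(latent_from_vector(vector))
        except FitException:
            # Substitute a NaN row so the batch keeps one row per sample.
            rows.append([math.nan] * n_latents)

    # Row mask: True for samples whose row is not entirely NaN.
    row_mask = [not all(math.isnan(value) for value in row) for row in rows]
    kept_rows = [row for row, keep in zip(rows, row_mask) if keep]
    return kept_rows, row_mask


# The middle sample violates the assertion and is dropped; the others survive.
vectors = [(3.0, 1.0), (1.0, 3.0), (5.0, 2.0)]
latents, kept = compute_latent_rows(vectors)
```

Keeping a placeholder row for each failed sample, rather than skipping it inside the loop, means the mask still lines up with the input samples, so any per-sample data (weights, for instance) can be filtered with the same mask.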
Summary
- Model assertions (e.g. `gaussian_0.centre > gaussian_1.centre`) can leave assertion-violating parameter vectors in a non-linear search's posterior `sample_list`: the sampler's likelihood path rejects them as -inf, but they're still recorded as dead points. Post-fit, `compute_latent_samples` calls `model.instance_from_vector` on every sample, which runs `check_assertions` and raises `FitException`. Before this change, one bad sample killed the whole batch and the entire `search.fit()` call.
- The non-JAX `batched_compute_latent` now wraps each sample in `try/except FitException`, substitutes a NaN row, then drops those rows via a new row-mask before the existing per-latent column-mask. JAX path is untouched (user JAX latent functions can't raise Python exceptions inside `jit`/`vmap`).

Why

Surfaced while stabilising `autofit_workspace_test/scripts/features/assertion.py`: with `PYAUTO_TEST_MODE=1` (reduced iterations) the fit only takes ~6s, but CI was flaky (~1-in-2 runs hit this crash). With this fix, the workspace override flips back to `TEST_MODE=1` and runs in ~6s instead of ~67–107s (PR autofit_workspace_test#TBD).

Test plan

- `test_compute_latent_samples_skips_fit_exception_samples` (2-sample aggregator, one raises `FitException` → only the surviving sample appears in output).
- `pytest test_autofit/analysis/test_latent_variables.py` → 6 passed
- `pytest test_autofit/non_linear/ test_autofit/analysis/` → 253 passed
- `run_smoke.py` invocations with `TEST_MODE=1`: all 9/9 pass, `assertion.py` in 4–9s range.

🤖 Generated with Claude Code