Skip to content

fix: skip FitException samples in compute_latent_samples#1233

Merged
Jammy2211 merged 1 commit intomainfrom
feature/latent-fitexception-safe
Apr 26, 2026
Merged

fix: skip FitException samples in compute_latent_samples#1233
Jammy2211 merged 1 commit intomainfrom
feature/latent-fitexception-safe

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • Model assertions (e.g. gaussian_0.centre > gaussian_1.centre) can leave assertion-violating parameter vectors in a non-linear search's posterior sample_list — the sampler's likelihood path rejects them as -inf, but they're still recorded as dead points. Post-fit, compute_latent_samples calls model.instance_from_vector on every sample, which runs check_assertions and raises FitException. Before this change, one bad sample killed the whole batch and the entire search.fit() call.
  • Fix: wrap the non-JAX per-sample call with try/except FitException, substitute a NaN row, then drop those rows via a new row-mask before the existing per-latent column-mask. JAX path is untouched (user JAX latent functions can't raise Python exceptions inside jit/vmap).

Why

Surfaced while stabilising autofit_workspace_test/scripts/features/assertion.py: with PYAUTO_TEST_MODE=1 (reduced iterations) the fit only takes ~6s but CI was flaky (~1-in-2 runs hit this crash). With this fix, the workspace override flips back to TEST_MODE=1 and runs in ~6s instead of ~67–107s (PR autofit_workspace_test#TBD).

Test plan

  • New unit test test_compute_latent_samples_skips_fit_exception_samples (2-sample aggregator, one raises FitException → only the surviving sample appears in output).
  • pytest test_autofit/analysis/test_latent_variables.py → 6 passed
  • pytest test_autofit/non_linear/ test_autofit/analysis/ → 253 passed
  • Integration: 5 consecutive run_smoke.py invocations with TEST_MODE=1, all 9/9 pass, assertion.py in 4–9s range.

🤖 Generated with Claude Code

Model assertions (e.g. `gaussian_0.centre > gaussian_1.centre`) can leave
assertion-violating parameter vectors in a non-linear search's posterior
sample list — the sampler's own likelihood path rejects them as -inf, but
they're still recorded as dead points. Post-fit, `compute_latent_samples`
iterates every sample and calls `model.instance_from_vector`, which runs
`check_assertions` and raises `FitException`. Before this change, one bad
sample killed the whole batch and the entire fit call.

Handle it per-sample: the non-JAX `batched_compute_latent` now catches
`FitException` and substitutes a NaN row, and the batch loop drops those
rows via a new row-mask before the existing per-latent column mask runs.
JAX path unchanged — user JAX latent functions can't raise Python
exceptions inside `jit`/`vmap` anyway.

Added test: `test_compute_latent_samples_skips_fit_exception_samples`
seeds a 2-sample aggregator where the second sample would raise
`FitException`; asserts only the first sample survives in the output.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 675b93c into main Apr 26, 2026
3 checks passed
@Jammy2211 Jammy2211 deleted the feature/latent-fitexception-safe branch April 26, 2026 10:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant