Skip to content

Remove use_jax_for_visualization; add visualization warmup#1297

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/unify-jax-visualization
May 27, 2026
Merged

Remove use_jax_for_visualization; add visualization warmup#1297
Jammy2211 merged 1 commit into
mainfrom
feature/unify-jax-visualization

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Removes the use_jax_for_visualization flag from Analysis.__init__ and adds a visualization warmup call in Fitness.__init__. Profiling revealed the flag triggered 234 individual JAX JIT compilations (~20s) on the first quick update because the profile methods dispatch to JAX individually via decorators. The warmup pays this cost before sampling starts so every quick update during the search is fast (~2s).

Closes #1296

API Changes

Analysis.__init__ no longer accepts use_jax_for_visualization. Visualization now automatically follows use_jax — if the search uses JAX, visualization does too. The _jitted_fit_from lazy JIT cache on Analysis is removed; the warmup in Fitness.__init__ is a better approach (pre-compiles before sampling, not on first quick update). See full details below.

Test Plan

  • pytest test_autofit/ — 1412 passed, 1 skipped
  • Zero remaining references to use_jax_for_visualization in library code
  • Workspace smoke tests (autolens_workspace, autogalaxy_workspace) — pending after workspace PRs land
  • Manual: run start_here.py with live_visual_update=True and verify quick updates are fast after warmup
Full API Changes (for automation & release notes)

Removed

  • Analysis.__init__(..., use_jax_for_visualization: bool = False) — the parameter no longer exists. Passing it via **kwargs is silently absorbed but has no effect.
  • Analysis._use_jax_for_visualization attribute — no longer set.
  • Analysis._jitted_fit_from lazy cache — the jax.jit(self.fit_from) wrapping in fit_for_visualization is removed entirely.

Added

  • Fitness._warmup_visualization() — called automatically in Fitness.__init__ when iterations_per_quick_update is not None and _xp is JAX. Logs "Warming up visualization..." and calls analysis.fit_for_visualization(instance) + fit.model_data once with prior-median instance to warm per-function JIT caches.

Changed Behaviour

  • Analysis.fit_for_visualization(instance) — no longer dispatches between plain and JIT paths. Always calls self.fit_from(instance). When use_jax=True, the profile methods inside fit_from dispatch to JAX via decorators (same as before with use_jax_for_visualization=False).
  • Analysis.supports_jax_visualization property — now returns self._use_jax instead of the removed self._use_jax_for_visualization. Effect: any analysis with use_jax=True now reports supporting JAX visualization.

Migration

  • Before: al.AnalysisImaging(dataset=dataset, use_jax=True, use_jax_for_visualization=True)
  • After: al.AnalysisImaging(dataset=dataset, use_jax=True) — visualization follows use_jax automatically.
  • Passing use_jax_for_visualization=True is absorbed by **kwargs and has no effect — it won't crash, but it does nothing.

🤖 Generated with Claude Code

The separate use_jax_for_visualization flag added complexity without
benefit — profiling showed it triggered 234 individual JAX JIT
compilations on the first quick update (20s+). Removing the flag and
adding a warmup call in Fitness.__init__ pays this cost before sampling
starts, so every quick update during the search is fast (~2s).

Changes:
- Analysis.__init__: remove use_jax_for_visualization parameter
- fit_for_visualization: simplify to plain fit_from delegation
- supports_jax_visualization: now returns self._use_jax
- Fitness.__init__: add _warmup_visualization when JAX + quick updates
  are active — calls fit_for_visualization once with prior-median
  instance to warm per-function JIT caches before sampling

Closes #1296

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 27, 2026
@Jammy2211 Jammy2211 merged commit aa01bd8 into main May 27, 2026
7 checks passed
@Jammy2211 Jammy2211 deleted the feature/unify-jax-visualization branch May 27, 2026 12:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

perf: unify JAX visualization with likelihood JIT path

1 participant