Remove use_jax_for_visualization; add visualization warmup#1297
Merged
Conversation
The separate use_jax_for_visualization flag added complexity without benefit — profiling showed it triggered 234 individual JAX JIT compilations on the first quick update (20s+). Removing the flag and adding a warmup call in Fitness.__init__ pays this cost before sampling starts, so every quick update during the search is fast (~2s). Changes: - Analysis.__init__: remove use_jax_for_visualization parameter - fit_for_visualization: simplify to plain fit_from delegation - supports_jax_visualization: now returns self._use_jax - Fitness.__init__: add _warmup_visualization when JAX + quick updates are active — calls fit_for_visualization once with prior-median instance to warm per-function JIT caches before sampling Closes #1296 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This was referenced May 27, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Removes the
use_jax_for_visualizationflag fromAnalysis.__init__and adds a visualization warmup call inFitness.__init__. Profiling revealed the flag triggered 234 individual JAX JIT compilations (~20s) on the first quick update because the profile methods dispatch to JAX individually via decorators. The warmup pays this cost before sampling starts so every quick update during the search is fast (~2s).Closes #1296
API Changes
Analysis.__init__no longer acceptsuse_jax_for_visualization. Visualization now automatically followsuse_jax— if the search uses JAX, visualization does too. The_jitted_fit_fromlazy JIT cache on Analysis is removed; the warmup inFitness.__init__is a better approach (pre-compiles before sampling, not on first quick update). See full details below.Test Plan
pytest test_autofit/— 1412 passed, 1 skippeduse_jax_for_visualizationin library codestart_here.pywithlive_visual_update=Trueand verify quick updates are fast after warmupFull API Changes (for automation & release notes)
Removed
Analysis.__init__(..., use_jax_for_visualization: bool = False)— the parameter no longer exists. Passing it via**kwargsis silently absorbed but has no effect.Analysis._use_jax_for_visualizationattribute — no longer set.Analysis._jitted_fit_fromlazy cache — thejax.jit(self.fit_from)wrapping infit_for_visualizationis removed entirely.Added
Fitness._warmup_visualization()— called automatically inFitness.__init__wheniterations_per_quick_update is not Noneand_xpis JAX. Logs "Warming up visualization..." and callsanalysis.fit_for_visualization(instance)+fit.model_dataonce with prior-median instance to warm per-function JIT caches.Changed Behaviour
Analysis.fit_for_visualization(instance)— no longer dispatches between plain and JIT paths. Always callsself.fit_from(instance). Whenuse_jax=True, the profile methods insidefit_fromdispatch to JAX via decorators (same as before withuse_jax_for_visualization=False).Analysis.supports_jax_visualizationproperty — now returnsself._use_jaxinstead of the removedself._use_jax_for_visualization. Effect: any analysis withuse_jax=Truenow reports supporting JAX visualization.Migration
al.AnalysisImaging(dataset=dataset, use_jax=True, use_jax_for_visualization=True)al.AnalysisImaging(dataset=dataset, use_jax=True)— visualization followsuse_jaxautomatically.use_jax_for_visualization=Trueis absorbed by**kwargsand has no effect — it won't crash, but it does nothing.🤖 Generated with Claude Code