Skip to content

Commit 48bcb80

Browse files
Jammy2211claude
authored andcommitted
fix(jax): keep parameterization cache off ModelInstance + auto-register pytrees
Two coupled fixes restoring the JAX `jit(fit_from)` path that broke when commit 4564ae9 made `AbstractPriorModel.parameterization` a `functools.cached_property`. `cached_property` writes to `self.__dict__["parameterization"]`. After any `model.info` access, `Collection._instance_for_arguments` (which iterates `__dict__` and skips only underscore-prefixed keys) propagates the cached string onto every `ModelInstance`. The string then surfaces as a non-array JAX pytree leaf (autogalaxy_workspace_test + autolens_workspace_test `jax_likelihood_functions/*` — 38 scripts) and makes `for x in instance:` yield strings instead of profiles (autofit_workspace `overview/overview_1_the_basics.py`). Fix 1: store the cache under the underscore-prefixed key `_parameterization_cache` so both `Collection._instance_for_arguments` and `ModelInstance.dict` filter it out. Preserves the 2.7s → 0.05s perf win from 4564ae9. Fix 2: auto-call `enable_pytrees() + register_model(self.model)` from `Fitness.__init__` whenever `analysis._use_jax=True`. Both helpers are idempotent, so workspaces that still call them explicitly keep working. New JAX-enabled workspaces don't need the boilerplate. Verified locally: - 1413/1413 PyAutoFit unit tests pass + new `test_parameterization_cache_does_not_leak_into_instance` regression - `autofit_workspace/scripts/overview/overview_1_the_basics.py` runs to completion (cluster C4 reproducer) - `autolens_workspace_test/scripts/jax_likelihood_functions/imaging/rectangular.py` prints "PASS: jit(fit_from) round-trip matches NumPy scalar" (cluster C1 reproducer) Follow-up: a structural defense across the four `__dict__`-iterators in `autofit/mapper/` plus `autoarray/abstract_ndarray.py` will ship as a separate PR — a `_cached_property_names(cls)` classmethod applied as an extra filter at every leak site so the next future `@cached_property` on a model class cannot reintroduce this bug. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f24e915 commit 48bcb80

3 files changed

Lines changed: 59 additions & 4 deletions

File tree

autofit/mapper/prior_model/abstract.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import functools
32
import inspect
43
import json
54
import logging
@@ -1860,12 +1859,26 @@ def order_no(self) -> str:
18601859
]
18611860
return ":".join(values)
18621861

1863-
@functools.cached_property
1862+
@property
18641863
def parameterization(self) -> str:
18651864
"""
18661865
Describes the path to each of the PriorModels, its class
1867-
and its number of free parameters
1868-
"""
1866+
and its number of free parameters.
1867+
1868+
Cached on first access in ``self.__dict__`` under the
1869+
``_`` -prefixed key ``_parameterization_cache`` so that
1870+
``Collection._instance_for_arguments`` and
1871+
``ModelInstance.dict`` (which iterate ``__dict__`` and filter
1872+
underscore-prefixed keys) do not propagate the cached string
1873+
onto the constructed ``ModelInstance``. A plain
1874+
``functools.cached_property`` writes to ``__dict__[name]``
1875+
without a leading underscore, which would leak the string as
1876+
a non-array JAX pytree leaf and break ``jax.jit(fit_from)``.
1877+
"""
1878+
cached = self.__dict__.get("_parameterization_cache")
1879+
if cached is not None:
1880+
return cached
1881+
18691882
from .prior_model import Model
18701883

18711884
formatter = TextFormatter(line_length=info_whitespace())
@@ -1900,6 +1913,7 @@ def parameterization(self) -> str:
19001913
for group in find_groups(paths, limit=0):
19011914
formatter.add(*group)
19021915

1916+
self.__dict__["_parameterization_cache"] = formatter.text
19031917
return formatter.text
19041918

19051919
@property

autofit/non_linear/fitness.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ def __init__(
122122
self.use_jax_vmap = use_jax_vmap
123123
self.use_jax_jit = use_jax_jit
124124

125+
if getattr(self.analysis, "_use_jax", False):
126+
from autofit.jax.pytrees import enable_pytrees, register_model
127+
128+
enable_pytrees()
129+
register_model(self.model)
130+
125131
self._call = self.call
126132

127133
if self.use_jax_vmap:

test_autofit/mapper/test_parameterization.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,41 @@ def test_tuple_instance_model_info(self, mapper):
141141
assert len(info.split("\n")) == len(mapper.info.split("\n"))
142142

143143

144+
def test_parameterization_cache_does_not_leak_into_instance():
145+
"""Regression: ``parameterization`` is cached in
146+
``self.__dict__["_parameterization_cache"]`` so that
147+
``Collection._instance_for_arguments`` and ``ModelInstance.dict``
148+
(which skip underscore-prefixed keys) do not propagate the cached
149+
string onto the constructed instance. A plain
150+
``functools.cached_property`` would write to ``__dict__["parameterization"]``
151+
without an underscore, leaking the string into ``ModelInstance.dict``
152+
and downstream JAX pytree flattening — see commit 4564ae9a1."""
153+
154+
model = af.Collection(gaussian=af.Model(af.ex.Gaussian))
155+
156+
# Touch model.info → exercises the same propagation path that every
157+
# workspace script hits at construction time.
158+
_ = model.info
159+
_ = model.parameterization # second access uses the cache
160+
161+
# The cache must live behind an underscore key on the model.
162+
assert "_parameterization_cache" in model.__dict__
163+
assert "parameterization" not in model.__dict__
164+
165+
instance = model.instance_from_prior_medians()
166+
167+
# Neither the cached key nor the public name may appear on the
168+
# constructed instance.
169+
assert "parameterization" not in instance.__dict__
170+
assert "_parameterization_cache" not in instance.__dict__
171+
assert "parameterization" not in instance.dict
172+
assert "_parameterization_cache" not in instance.dict
173+
174+
# The instance must yield only model components when iterated.
175+
for child in instance:
176+
assert not isinstance(child, str)
177+
178+
144179
def test_integer_attributes():
145180
model = af.Model(af.ex.Gaussian)
146181

0 commit comments

Comments
 (0)