Skip to content

Commit 2a8f5c7

Browse files
Jammy2211claude
authored andcommitted
Merge: fix(jax) — parameterization cache + Fitness pytree auto-register (#1300)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2 parents f24e915 + 48bcb80 commit 2a8f5c7

3 files changed

Lines changed: 59 additions & 4 deletions

File tree

autofit/mapper/prior_model/abstract.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import functools
32
import inspect
43
import json
54
import logging
@@ -1860,12 +1859,26 @@ def order_no(self) -> str:
18601859
]
18611860
return ":".join(values)
18621861

1863-
@functools.cached_property
1862+
@property
18641863
def parameterization(self) -> str:
18651864
"""
18661865
Describes the path to each of the PriorModels, its class
1867-
and its number of free parameters
1868-
"""
1866+
and its number of free parameters.
1867+
1868+
Cached on first access in ``self.__dict__`` under the
1869+
``_`` -prefixed key ``_parameterization_cache`` so that
1870+
``Collection._instance_for_arguments`` and
1871+
``ModelInstance.dict`` (which iterate ``__dict__`` and filter
1872+
underscore-prefixed keys) do not propagate the cached string
1873+
onto the constructed ``ModelInstance``. A plain
1874+
``functools.cached_property`` writes to ``__dict__[name]``
1875+
without a leading underscore, which would leak the string as
1876+
a non-array JAX pytree leaf and break ``jax.jit(fit_from)``.
1877+
"""
1878+
cached = self.__dict__.get("_parameterization_cache")
1879+
if cached is not None:
1880+
return cached
1881+
18691882
from .prior_model import Model
18701883

18711884
formatter = TextFormatter(line_length=info_whitespace())
@@ -1900,6 +1913,7 @@ def parameterization(self) -> str:
19001913
for group in find_groups(paths, limit=0):
19011914
formatter.add(*group)
19021915

1916+
self.__dict__["_parameterization_cache"] = formatter.text
19031917
return formatter.text
19041918

19051919
@property

autofit/non_linear/fitness.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ def __init__(
122122
self.use_jax_vmap = use_jax_vmap
123123
self.use_jax_jit = use_jax_jit
124124

125+
if getattr(self.analysis, "_use_jax", False):
126+
from autofit.jax.pytrees import enable_pytrees, register_model
127+
128+
enable_pytrees()
129+
register_model(self.model)
130+
125131
self._call = self.call
126132

127133
if self.use_jax_vmap:

test_autofit/mapper/test_parameterization.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,41 @@ def test_tuple_instance_model_info(self, mapper):
141141
assert len(info.split("\n")) == len(mapper.info.split("\n"))
142142

143143

144+
def test_parameterization_cache_does_not_leak_into_instance():
145+
"""Regression: ``parameterization`` is cached in
146+
``self.__dict__["_parameterization_cache"]`` so that
147+
``Collection._instance_for_arguments`` and ``ModelInstance.dict``
148+
(which skip underscore-prefixed keys) do not propagate the cached
149+
string onto the constructed instance. A plain
150+
``functools.cached_property`` would write to ``__dict__["parameterization"]``
151+
without an underscore, leaking the string into ``ModelInstance.dict``
152+
and downstream JAX pytree flattening — see commit 4564ae9a1."""
153+
154+
model = af.Collection(gaussian=af.Model(af.ex.Gaussian))
155+
156+
# Touch model.info → exercises the same propagation path that every
157+
# workspace script hits at construction time.
158+
_ = model.info
159+
_ = model.parameterization # second access uses the cache
160+
161+
# The cache must live behind an underscore key on the model.
162+
assert "_parameterization_cache" in model.__dict__
163+
assert "parameterization" not in model.__dict__
164+
165+
instance = model.instance_from_prior_medians()
166+
167+
# Neither the cached key nor the public name may appear on the
168+
# constructed instance.
169+
assert "parameterization" not in instance.__dict__
170+
assert "_parameterization_cache" not in instance.__dict__
171+
assert "parameterization" not in instance.dict
172+
assert "_parameterization_cache" not in instance.dict
173+
174+
# The instance must yield only model components when iterated.
175+
for child in instance:
176+
assert not isinstance(child, str)
177+
178+
144179
def test_integer_attributes():
145180
model = af.Model(af.ex.Gaussian)
146181

0 commit comments

Comments
 (0)