From 6f950bc2461c79c9a35420ac97de9f1e3971bf92 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 20 May 2026 18:00:23 +0100 Subject: [PATCH] fix(csv): preserve TuplePrior on af.Model built from tuple-param rows galaxy_af_models_from_csv_tables previously did af.Model(cls, **params), which for tuple-valued params (e.g. centre=(0.3, 0.5)) bypassed PyAutoFit's TuplePrior auto-create path and stored `centre` as a raw tuple attribute. Later `.centre_0 = GaussianPrior(...)` overrides then created ghost direct attributes alongside the raw tuple, so at sample time the constructor was called as `Point(centre=(0.3, 0.5), centre_0=..., centre_1=...)` and raised `TypeError: Point.__init__() got an unexpected keyword argument 'centre_0'`. Build af.Model(cls) first (which triggers the TuplePrior auto-create branch in PyAutoFit for tuple-defaulted ctor args), then setattr each tuple param component-wise. Scalar params unchanged. Later prior overrides on `.centre_0`/`.centre_1` now delegate into the auto-created TuplePrior instead of producing ghost direct attributes. Unblocks autolens_workspace/scripts/cluster/{start_here,modeling}.py. Co-Authored-By: Claude Opus 4.7 (1M context) --- autogalaxy/galaxy/galaxy_model_csv.py | 11 ++++++++- .../galaxy/test_galaxy_model_csv.py | 24 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/autogalaxy/galaxy/galaxy_model_csv.py b/autogalaxy/galaxy/galaxy_model_csv.py index d1c2483c..fc2bd3a3 100644 --- a/autogalaxy/galaxy/galaxy_model_csv.py +++ b/autogalaxy/galaxy/galaxy_model_csv.py @@ -334,7 +334,16 @@ def galaxy_af_models_from_csv_tables(*tables: GalaxyModelTable) -> Dict[str, Any galaxy_models: Dict[str, Any] = {} for galaxy_name, rows in by_galaxy.items(): redshift = _resolve_redshift(galaxy_name, rows) - attrs = {row.attr_name: af.Model(row.profile_class, **row.params) for row in rows} + attrs: Dict[str, Any] = {} + for row in rows: + model = af.Model(row.profile_class) + for name, value in row.params.items(): + if isinstance(value, tuple): + for i, component in enumerate(value): + setattr(model, f"{name}_{i}", component) + else: + setattr(model, name, value) + attrs[row.attr_name] = model galaxy_models[galaxy_name] = af.Model(Galaxy, redshift=redshift, **attrs) return galaxy_models diff --git a/test_autogalaxy/galaxy/test_galaxy_model_csv.py b/test_autogalaxy/galaxy/test_galaxy_model_csv.py index 5005eb88..a446c08a 100644 --- a/test_autogalaxy/galaxy/test_galaxy_model_csv.py +++ b/test_autogalaxy/galaxy/test_galaxy_model_csv.py @@ -209,6 +209,30 @@ def test__af_models_round_trip(tmp_path): assert galaxy_model.mass.b0 == 3.0 +def test__af_models__tuple_param_supports_prior_override(tmp_path): + point_csv = tmp_path / "point.csv" + + ag.galaxy_models_to_csv( + profiles_by_galaxy={"source_0": {"point_0": ag.ps.Point(centre=(0.3, 0.5))}}, + file_path=point_csv, + family="point", + redshifts={"source_0": 1.0}, + ) + + galaxy_models = ag.galaxy_af_models_from_csv_tables( + ag.galaxy_models_from_csv(point_csv, family="point"), + ) + + point_attr = galaxy_models["source_0"].point_0 + point_attr.centre_0 = af.GaussianPrior(mean=0.3, sigma=3.0) + point_attr.centre_1 = af.GaussianPrior(mean=0.5, sigma=3.0) + + instance = galaxy_models["source_0"].instance_from_unit_vector([0.5, 0.5]) + assert isinstance(instance.point_0, ag.ps.Point) + assert isinstance(instance.point_0.centre, tuple) + assert len(instance.point_0.centre) == 2 + + def test__redshift_consistency_check__raises(tmp_path): mass_csv = tmp_path / "mass.csv" light_csv = tmp_path / "light.csv"