diff --git a/autogalaxy/galaxy/galaxy_model_csv.py b/autogalaxy/galaxy/galaxy_model_csv.py index d1c2483c..fc2bd3a3 100644 --- a/autogalaxy/galaxy/galaxy_model_csv.py +++ b/autogalaxy/galaxy/galaxy_model_csv.py @@ -334,7 +334,16 @@ def galaxy_af_models_from_csv_tables(*tables: GalaxyModelTable) -> Dict[str, Any galaxy_models: Dict[str, Any] = {} for galaxy_name, rows in by_galaxy.items(): redshift = _resolve_redshift(galaxy_name, rows) - attrs = {row.attr_name: af.Model(row.profile_class, **row.params) for row in rows} + attrs: Dict[str, Any] = {} + for row in rows: + model = af.Model(row.profile_class) + for name, value in row.params.items(): + if isinstance(value, tuple): + for i, component in enumerate(value): + setattr(model, f"{name}_{i}", component) + else: + setattr(model, name, value) + attrs[row.attr_name] = model galaxy_models[galaxy_name] = af.Model(Galaxy, redshift=redshift, **attrs) return galaxy_models diff --git a/test_autogalaxy/galaxy/test_galaxy_model_csv.py b/test_autogalaxy/galaxy/test_galaxy_model_csv.py index 5005eb88..a446c08a 100644 --- a/test_autogalaxy/galaxy/test_galaxy_model_csv.py +++ b/test_autogalaxy/galaxy/test_galaxy_model_csv.py @@ -209,6 +209,30 @@ def test__af_models_round_trip(tmp_path): assert galaxy_model.mass.b0 == 3.0 +def test__af_models__tuple_param_supports_prior_override(tmp_path): + point_csv = tmp_path / "point.csv" + + ag.galaxy_models_to_csv( + profiles_by_galaxy={"source_0": {"point_0": ag.ps.Point(centre=(0.3, 0.5))}}, + file_path=point_csv, + family="point", + redshifts={"source_0": 1.0}, + ) + + galaxy_models = ag.galaxy_af_models_from_csv_tables( + ag.galaxy_models_from_csv(point_csv, family="point"), + ) + + point_attr = galaxy_models["source_0"].point_0 + point_attr.centre_0 = af.GaussianPrior(mean=0.3, sigma=3.0) + point_attr.centre_1 = af.GaussianPrior(mean=0.5, sigma=3.0) + + instance = galaxy_models["source_0"].instance_from_unit_vector([0.5, 0.5]) + assert isinstance(instance.point_0, ag.ps.Point) + assert isinstance(instance.point_0.centre, tuple) + assert len(instance.point_0.centre) == 2 + + def test__redshift_consistency_check__raises(tmp_path): mass_csv = tmp_path / "mass.csv" light_csv = tmp_path / "light.csv"