Skip to content

Commit bc360c0

Browse files
Jammy2211Jammy2211claude
authored
feat(AdaptImages): rotate cached mesh grid with DatasetModel transforms (#416)
The image-plane mesh grid cached in galaxy_name_image_plane_mesh_grid_dict now follows the same shift+rotate transform that FitDataset.grids applies to the data grid, so adaptive pixelization fits stay aligned in multi-band work. Also plumbs dataset_model + xp through adapt_images_via_instance_from and adds a grid_rotation_angle prior to config/priors/dataset_model.yaml. This is the actual fix for the source-reconstruction misalignment in @qiuhan06's dev_Q prototype; he had it correct but never re-tested it. Adds 3 chi^2=0 simulate-and-fit tests in test_simulate_and_fit_imaging.py covering offset-only, rotation-only, and combined cases. Refs PyAutoLens#511, PyAutoArray#312. Co-authored-by: Jammy2211 <JNightingale2211@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2ae5972 commit bc360c0

5 files changed

Lines changed: 162 additions & 6 deletions

File tree

autogalaxy/analysis/adapt_images/adapt_images.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,12 @@ def model_image(self) -> aa.Array2D:
157157
return adapt_model_image
158158

159159
def updated_via_instance_from(
160-
self, instance, mask=None, galaxies: Optional[List["Galaxy"]] = None
160+
self,
161+
instance,
162+
dataset_model: Optional["aa.DatasetModel"] = None,
163+
mask=None,
164+
galaxies: Optional[List["Galaxy"]] = None,
165+
xp=np,
161166
) -> "AdaptImages":
162167
"""
163168
Returns adapt-images which have been updated to map galaxy instances instead of galaxy names.
@@ -174,10 +179,19 @@ def updated_via_instance_from(
174179
galaxy instances are also created on-fly from the database. Database images do not have a mask, so it is
175180
also applied to the adapt images on-the-fly during database loading.
176181
182+
When a ``dataset_model`` is supplied with a non-trivial ``grid_offset`` or ``grid_rotation_angle``, the cached
183+
``galaxy_name_image_plane_mesh_grid_dict`` entries are transformed into the same frame as the dataset's
184+
image-plane grid (which ``FitDataset.grids`` rotates by the same amount). Without this transform the cached
185+
mesh and the data grid would sit in different frames, producing a misaligned source reconstruction.
186+
177187
Parameters
178188
----------
179189
instance
180190
The instance of the model-fit (e.g. in a non-linear search) which is used to update the adapt images.
191+
dataset_model
192+
The dataset model whose ``grid_offset`` and ``grid_rotation_angle`` are applied to cached mesh grids so
193+
they remain consistent with the rotated/shifted data grid produced by ``FitDataset.grids``. If ``None``,
194+
the cached mesh grids are passed through unchanged.
181195
mask
182196
A mask which can be applied to the adapt images, which is used when setting up the adaptive images
183197
via the aggregator and autofit database tools.
@@ -188,6 +202,8 @@ def updated_via_instance_from(
188202
galaxy instances into fresh objects. When ``None`` the path list is populated in ``path_instance_tuples_for_class``
189203
order, which matches ``Analysis.galaxies_via_instance_from`` for the common case (no
190204
``extra_galaxies`` / ``scaling_galaxies``).
205+
xp
206+
Array backend (``numpy`` or ``jax.numpy``) used when transforming cached mesh grids.
191207
192208
Returns
193209
-------
@@ -226,9 +242,14 @@ def updated_via_instance_from(
226242
galaxy_name = str(galaxy_name)
227243

228244
if galaxy_name in self.galaxy_name_image_plane_mesh_grid_dict:
229-
galaxy_image_plane_mesh_grid_dict[galaxy] = (
230-
self.galaxy_name_image_plane_mesh_grid_dict[galaxy_name]
231-
)
245+
cached_mesh = self.galaxy_name_image_plane_mesh_grid_dict[galaxy_name]
246+
if dataset_model is not None:
247+
cached_mesh = cached_mesh.subtracted_and_rotated_from(
248+
offset=dataset_model.grid_offset,
249+
angle=dataset_model.grid_rotation_angle,
250+
xp=xp,
251+
)
252+
galaxy_image_plane_mesh_grid_dict[galaxy] = cached_mesh
232253

233254
if galaxies is not None:
234255
galaxy_path_list = [path_by_id.get(id(g)) for g in galaxies]

autogalaxy/analysis/analysis/dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,16 @@ def save_results(self, paths: af.DirectoryPaths, result: ResultDataset):
172172
def adapt_images_via_instance_from(
173173
self,
174174
instance: af.ModelInstance,
175+
dataset_model: Optional[aa.DatasetModel] = None,
175176
galaxies=None,
177+
xp=np,
176178
) -> AdaptImages:
177179
try:
178180
return self.adapt_images.updated_via_instance_from(
179-
instance=instance, galaxies=galaxies
181+
instance=instance,
182+
dataset_model=dataset_model,
183+
galaxies=galaxies,
184+
xp=xp,
180185
)
181186
except AttributeError:
182187
pass

autogalaxy/config/priors/dataset_model.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,8 @@ DatasetModel:
66
type: Constant
77
value: 0.0
88
grid_offset_1:
9+
type: Constant
10+
value: 0.0
11+
grid_rotation_angle:
912
type: Constant
1013
value: 0.0

autogalaxy/imaging/model/analysis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
153153
dataset_model = self.dataset_model_via_instance_from(instance=instance)
154154

155155
adapt_images = self.adapt_images_via_instance_from(
156-
instance=instance, galaxies=galaxies
156+
instance=instance,
157+
dataset_model=dataset_model,
158+
galaxies=galaxies,
159+
xp=self._xp,
157160
)
158161

159162
return FitImaging(

test_autogalaxy/imaging/test_simulate_and_fit_imaging.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,130 @@ def test__perfect_fit__simulate_and_reload__chi_squared_zero():
8181
shutil.rmtree(file_path)
8282

8383

84+
def _perfect_fit_dataset(galaxies, grid):
85+
"""Helper: simulate noiseless imaging and zero out the noise map for chi^2 tests."""
86+
psf = ag.Convolver.from_gaussian(
87+
shape_native=(3, 3), pixel_scales=grid.pixel_scales[0], sigma=0.05, normalize=True
88+
)
89+
simulator = ag.SimulatorImaging(
90+
exposure_time=300.0, psf=psf, add_poisson_noise_to_data=False
91+
)
92+
dataset = simulator.via_galaxies_from(galaxies=galaxies, grid=grid)
93+
dataset.noise_map = ag.Array2D.ones(
94+
shape_native=dataset.data.shape_native, pixel_scales=grid.pixel_scales
95+
)
96+
return dataset
97+
98+
99+
def test__perfect_fit__sim_offset_centre__fit_with_dataset_model_grid_offset__chi_squared_zero():
100+
"""Sim a profile with offset centre; fit with origin profile + DatasetModel.grid_offset."""
101+
grid = ag.Grid2D.uniform(shape_native=(31, 31), pixel_scales=0.2, over_sample_size=1)
102+
centre = (0.3, 0.2)
103+
104+
sim_galaxy = ag.Galaxy(
105+
redshift=0.5,
106+
light=ag.lp.Sersic(centre=centre, intensity=0.5, effective_radius=0.5),
107+
)
108+
dataset = _perfect_fit_dataset([sim_galaxy], grid)
109+
mask = ag.Mask2D.circular(
110+
shape_native=dataset.data.shape_native, pixel_scales=0.2, radius=2.5
111+
)
112+
masked = dataset.apply_mask(mask=mask)
113+
114+
fit_galaxy = ag.Galaxy(
115+
redshift=0.5,
116+
light=ag.lp.Sersic(centre=(0.0, 0.0), intensity=0.5, effective_radius=0.5),
117+
)
118+
dataset_model = ag.DatasetModel(grid_offset=centre)
119+
fit = ag.FitImaging(
120+
dataset=masked, galaxies=[fit_galaxy], dataset_model=dataset_model
121+
)
122+
123+
assert fit.chi_squared == pytest.approx(0.0, abs=1e-4)
124+
125+
126+
def test__perfect_fit__sim_rotated_ellipse__fit_with_dataset_model_grid_rotation__chi_squared_zero():
127+
"""Sim a rotated ellipse; fit with axis-aligned profile + DatasetModel.grid_rotation_angle.
128+
129+
Convention: profile with ell-angle theta is equivalent to grid rotated by -theta,
130+
so fit with grid_rotation_angle=-theta to recover chi^2 = 0.
131+
"""
132+
grid = ag.Grid2D.uniform(shape_native=(31, 31), pixel_scales=0.2, over_sample_size=1)
133+
theta = 15.0
134+
135+
sim_galaxy = ag.Galaxy(
136+
redshift=0.5,
137+
light=ag.lp.Sersic(
138+
centre=(0.0, 0.0),
139+
ell_comps=ag.convert.ell_comps_from(axis_ratio=0.6, angle=theta),
140+
intensity=0.5,
141+
effective_radius=0.5,
142+
),
143+
)
144+
dataset = _perfect_fit_dataset([sim_galaxy], grid)
145+
mask = ag.Mask2D.circular(
146+
shape_native=dataset.data.shape_native, pixel_scales=0.2, radius=2.5
147+
)
148+
masked = dataset.apply_mask(mask=mask)
149+
150+
fit_galaxy = ag.Galaxy(
151+
redshift=0.5,
152+
light=ag.lp.Sersic(
153+
centre=(0.0, 0.0),
154+
ell_comps=ag.convert.ell_comps_from(axis_ratio=0.6, angle=0.0),
155+
intensity=0.5,
156+
effective_radius=0.5,
157+
),
158+
)
159+
dataset_model = ag.DatasetModel(grid_rotation_angle=-theta)
160+
fit = ag.FitImaging(
161+
dataset=masked, galaxies=[fit_galaxy], dataset_model=dataset_model
162+
)
163+
164+
assert fit.chi_squared == pytest.approx(0.0, abs=1e-4)
165+
166+
167+
def test__perfect_fit__sim_offset_and_rotated__fit_with_dataset_model_offset_and_rotation__chi_squared_zero():
168+
"""Combined: sim with offset centre AND rotated ellipse, fit with identity profile +
169+
DatasetModel(grid_offset, grid_rotation_angle)."""
170+
grid = ag.Grid2D.uniform(shape_native=(31, 31), pixel_scales=0.2, over_sample_size=1)
171+
centre = (0.3, 0.2)
172+
theta = 12.0
173+
174+
sim_galaxy = ag.Galaxy(
175+
redshift=0.5,
176+
light=ag.lp.Sersic(
177+
centre=centre,
178+
ell_comps=ag.convert.ell_comps_from(axis_ratio=0.6, angle=theta),
179+
intensity=0.5,
180+
effective_radius=0.5,
181+
),
182+
)
183+
dataset = _perfect_fit_dataset([sim_galaxy], grid)
184+
mask = ag.Mask2D.circular(
185+
shape_native=dataset.data.shape_native, pixel_scales=0.2, radius=2.5
186+
)
187+
masked = dataset.apply_mask(mask=mask)
188+
189+
fit_galaxy = ag.Galaxy(
190+
redshift=0.5,
191+
light=ag.lp.Sersic(
192+
centre=(0.0, 0.0),
193+
ell_comps=ag.convert.ell_comps_from(axis_ratio=0.6, angle=0.0),
194+
intensity=0.5,
195+
effective_radius=0.5,
196+
),
197+
)
198+
dataset_model = ag.DatasetModel(
199+
grid_offset=centre, grid_rotation_angle=-theta
200+
)
201+
fit = ag.FitImaging(
202+
dataset=masked, galaxies=[fit_galaxy], dataset_model=dataset_model
203+
)
204+
205+
assert fit.chi_squared == pytest.approx(0.0, abs=1e-4)
206+
207+
84208
def test__simulate_imaging_data_and_fit__standard_galaxies__known_figure_of_merit():
85209
grid = ag.Grid2D.uniform(shape_native=(31, 31), pixel_scales=0.2)
86210

0 commit comments

Comments
 (0)