Skip to content

Commit eeb161a

Browse files
Jammy2211Jammy2211claude
authored
feat(DatasetModel): add grid_rotation_angle for multi-band rotation (#312)
Mirrors the existing grid_offset pattern so each band in a multi-band fit can be rotated as well as shifted relative to a reference dataset. Adds Grid2D / Grid2DIrregular.subtracted_and_rotated_from helpers (shift then rotate CCW about the offset point) and wires FitDataset.grids through them so lp / pixelization / blurring grids all carry the transform into the fit. Refs PyAutoLens#511, prototype by @qiuhan06 on dev_Q. Co-authored-by: Jammy2211 <JNightingale2211@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 3ddf368 commit eeb161a

7 files changed

Lines changed: 273 additions & 48 deletions

File tree

autoarray/dataset/dataset_model.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,42 @@
1-
from typing import Tuple
2-
3-
4-
class DatasetModel:
5-
def __init__(
6-
self,
7-
background_sky_level: float = 0.0,
8-
grid_offset: Tuple[float, float] = (0.0, 0.0),
9-
):
10-
"""
11-
Attributes which allow for parts of a dataset to be treated as a model, meaning they can be fitted
12-
for in the `fit` module.
13-
14-
The following aspects of a dataset can be treated as a model:
15-
16-
- `background_sky_level`: The data may have a constant signal in the background which is estimated
17-
and subtracted from the data beforehand with a degree of uncertainty. By including it in the model it can be
18-
marginalized over. Units are dimensionless and derived from the data.
19-
20-
- `grid_offset`: Two datasets may be offset from one another, for example if they are taken with different
21-
pointing positions. This offset can be included in the model and marginalized over. Units are arc seconds.
22-
23-
Parameters
24-
----------
25-
background_sky_level
26-
Overall normalisation of the sky which is added or subtracted from the data. Units are dimensionless and
27-
derived from the data, which is expected to be electrons per second in Astronomy analyses.
28-
grid_offset
29-
Offset between two datasets, in arc seconds. This is used to align datasets which are taken at different
30-
pointing positions.
31-
"""
32-
self.background_sky_level = background_sky_level
33-
self.grid_offset = grid_offset
1+
from typing import Tuple
2+
3+
4+
class DatasetModel:
5+
def __init__(
6+
self,
7+
background_sky_level: float = 0.0,
8+
grid_offset: Tuple[float, float] = (0.0, 0.0),
9+
grid_rotation_angle: float = 0.0,
10+
):
11+
"""
12+
Attributes which allow for parts of a dataset to be treated as a model, meaning they can be fitted
13+
for in the `fit` module.
14+
15+
The following aspects of a dataset can be treated as a model:
16+
17+
- `background_sky_level`: The data may have a constant signal in the background which is estimated
18+
and subtracted from the data beforehand with a degree of uncertainty. By including it in the model it can be
19+
marginalized over. Units are dimensionless and derived from the data.
20+
21+
- `grid_offset`: Two datasets may be offset from one another, for example if they are taken with different
22+
pointing positions. This offset can be included in the model and marginalized over. Units are arc seconds.
23+
24+
- `grid_rotation_angle`: Two datasets may also be rotated relative to one another (e.g. a different
25+
telescope roll angle). This rotation can be included in the model and marginalized over. Units are degrees,
26+
with positive angles applying a counter-clockwise rotation about the offset point.
27+
28+
Parameters
29+
----------
30+
background_sky_level
31+
Overall normalisation of the sky which is added or subtracted from the data. Units are dimensionless and
32+
derived from the data, which is expected to be electrons per second in Astronomy analyses.
33+
grid_offset
34+
Offset between two datasets, in arc seconds. This is used to align datasets which are taken at different
35+
pointing positions.
36+
grid_rotation_angle
37+
Rotation between two datasets, in degrees. Applied counter-clockwise about the offset point after the
38+
offset is subtracted. Used to align datasets which share a pointing centre but differ in roll angle.
39+
"""
40+
self.background_sky_level = background_sky_level
41+
self.grid_offset = grid_offset
42+
self.grid_rotation_angle = grid_rotation_angle

autoarray/fit/fit_dataset.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,26 +185,23 @@ def mask(self) -> Mask2D:
185185
@property
186186
def grids(self) -> GridsInterface:
187187
"""
188-
The grids of (y,x) coordinates associated with the dataset, adjusted by any `grid_offset` specified in
189-
the `dataset_model`. Each grid (`lp`, `pixelization`, `blurring`) has the offset subtracted from it
190-
before being returned.
188+
The grids of (y,x) coordinates associated with the dataset, adjusted by any `grid_offset` and
189+
`grid_rotation_angle` specified in the `dataset_model`. Each grid (`lp`, `pixelization`, `blurring`)
190+
has the offset subtracted from it and is then rotated counter-clockwise by `grid_rotation_angle`
191+
about the offset point before being returned.
191192
"""
192193

193-
def subtracted_from(grid, offset):
194+
offset = self.dataset_model.grid_offset
195+
angle = self.dataset_model.grid_rotation_angle
196+
197+
def shift_and_rotate(grid):
194198
if grid is None:
195199
return None
200+
return grid.subtracted_and_rotated_from(offset=offset, angle=angle, xp=self._xp)
196201

197-
return grid.subtracted_from(offset=offset, xp=self._xp)
198-
199-
lp = subtracted_from(
200-
grid=self.dataset.grids.lp, offset=self.dataset_model.grid_offset
201-
)
202-
pixelization = subtracted_from(
203-
grid=self.dataset.grids.pixelization, offset=self.dataset_model.grid_offset
204-
)
205-
blurring = subtracted_from(
206-
grid=self.dataset.grids.blurring, offset=self.dataset_model.grid_offset
207-
)
202+
lp = shift_and_rotate(self.dataset.grids.lp)
203+
pixelization = shift_and_rotate(self.dataset.grids.pixelization)
204+
blurring = shift_and_rotate(self.dataset.grids.blurring)
208205

209206
return GridsInterface(
210207
lp=lp,

autoarray/structures/grids/irregular_2d.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,39 @@ def grid_of_closest_from(self, grid_pair: "Grid2DIrregular") -> "Grid2DIrregular
279279
closest_points = self.array[closest_idx]
280280

281281
return Grid2DIrregular(closest_points)
282+
283+
def subtracted_from(self, offset, xp=np) -> "Grid2DIrregular":
284+
"""
285+
Return a new Grid2DIrregular with ``offset`` subtracted from every (y, x) coordinate.
286+
"""
287+
offset_array = xp.array(offset)
288+
return Grid2DIrregular(self.array - offset_array)
289+
290+
def subtracted_and_rotated_from(
291+
self, offset, angle: float, xp=np
292+
) -> "Grid2DIrregular":
293+
"""
294+
Return a new Grid2DIrregular where the (y, x) coordinates of this grid have an offset
295+
subtracted and are then rotated counter-clockwise by ``angle`` (in degrees) about the
296+
offset point.
297+
298+
Order matches :meth:`Grid2D.subtracted_and_rotated_from`: shift, then rotate.
299+
300+
Parameters
301+
----------
302+
offset
303+
The (y, x) offset subtracted from every grid coordinate before rotation.
304+
angle
305+
The rotation angle in degrees. Positive values rotate counter-clockwise.
306+
"""
307+
offset_array = xp.array(offset)
308+
angle_rad = xp.deg2rad(angle)
309+
cos_a = xp.cos(angle_rad)
310+
sin_a = xp.sin(angle_rad)
311+
312+
shifted = self.array - offset_array
313+
sy = shifted[:, 0]
314+
sx = shifted[:, 1]
315+
ry = sx * sin_a + sy * cos_a
316+
rx = sx * cos_a - sy * sin_a
317+
return Grid2DIrregular(xp.stack((ry, rx), axis=-1))

autoarray/structures/grids/uniform_2d.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,56 @@ def subtracted_from(
736736
over_sampler=self.over_sampler,
737737
)
738738

739+
def subtracted_and_rotated_from(
740+
self, offset: Tuple[float, float], angle: float, xp=np
741+
) -> "Grid2D":
742+
"""
743+
Return a new Grid2D where the (y, x) coordinates of this grid have an offset subtracted
744+
and are then rotated counter-clockwise by ``angle`` (in degrees) about the offset point.
745+
746+
Order: shift first, then rotate. With ``offset = (oy, ox)`` and ``angle = theta`` (degrees):
747+
748+
(y', x') = (y - oy, x - ox)
749+
y'' = y' cos(theta) + x' sin(theta)
750+
x'' = x' cos(theta) - y' sin(theta)
751+
752+
Parameters
753+
----------
754+
offset
755+
The (y, x) offset subtracted from every grid coordinate before rotation.
756+
angle
757+
The rotation angle in degrees. Positive values rotate counter-clockwise.
758+
"""
759+
offset_array = xp.array(offset)
760+
angle_rad = xp.deg2rad(angle)
761+
cos_a = xp.cos(angle_rad)
762+
sin_a = xp.sin(angle_rad)
763+
764+
def _shift_and_rotate(grid_array):
765+
shifted = grid_array - offset_array
766+
sy = shifted[:, 0]
767+
sx = shifted[:, 1]
768+
ry = sx * sin_a + sy * cos_a
769+
rx = sx * cos_a - sy * sin_a
770+
return xp.stack((ry, rx), axis=-1)
771+
772+
grid_rotated = _shift_and_rotate(self.array)
773+
over_sampled_rotated = _shift_and_rotate(self.over_sampled.array)
774+
775+
mask = Mask2D(
776+
mask=self.mask,
777+
pixel_scales=self.pixel_scales,
778+
origin=(self.origin[0] - offset[0], self.origin[1] - offset[1]),
779+
)
780+
781+
return Grid2D(
782+
values=grid_rotated,
783+
mask=mask,
784+
over_sample_size=self.over_sample_size,
785+
over_sampled=Grid2DIrregular(over_sampled_rotated),
786+
over_sampler=self.over_sampler,
787+
)
788+
739789
@property
740790
def slim(self) -> "Grid2D":
741791
"""

test_autoarray/fit/test_fit_dataset.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,64 @@ def test__grids__with_dataset_model_grid_offset__lp_and_pixelization_grids_offse
112112
assert fit.dataset_model.grid_offset == (1.0, 2.0)
113113
assert fit.grids.lp[0] == pytest.approx((0.0, -3.0), 1.0e-4)
114114
assert fit.grids.pixelization[0] == pytest.approx((0.0, -3.0), 1.0e-4)
115+
116+
117+
def test__grids__with_dataset_model_grid_rotation_angle__lp_grid_rotated_correctly(
118+
imaging_7x7, mask_2d_7x7, model_image_7x7
119+
):
120+
masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7)
121+
122+
# Rotation by 90 degrees CCW about the origin maps (y, x) -> (x, -y).
123+
fit = aa.m.MockFitImaging(
124+
dataset=masked_imaging_7x7,
125+
use_mask_in_fit=False,
126+
model_data=model_image_7x7,
127+
dataset_model=aa.DatasetModel(grid_rotation_angle=90.0),
128+
)
129+
130+
assert fit.dataset_model.grid_rotation_angle == 90.0
131+
# 90 deg CCW rotation in (y, x) order maps (y, x) -> (x, -y).
132+
original = masked_imaging_7x7.grids.lp[0]
133+
rotated = fit.grids.lp[0]
134+
assert rotated[0] == pytest.approx(original[1], 1.0e-4)
135+
assert rotated[1] == pytest.approx(-original[0], 1.0e-4)
136+
137+
138+
def test__grids__with_grid_offset_and_grid_rotation_angle__shift_then_rotate(
139+
imaging_7x7, mask_2d_7x7, model_image_7x7
140+
):
141+
masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7)
142+
143+
fit = aa.m.MockFitImaging(
144+
dataset=masked_imaging_7x7,
145+
use_mask_in_fit=False,
146+
model_data=model_image_7x7,
147+
dataset_model=aa.DatasetModel(
148+
grid_offset=(1.0, 2.0), grid_rotation_angle=90.0
149+
),
150+
)
151+
152+
# First subtract the offset, then rotate 90deg CCW: (y, x) -> (x, -y).
153+
original = masked_imaging_7x7.grids.lp[0]
154+
shifted_y = original[0] - 1.0
155+
shifted_x = original[1] - 2.0
156+
expected = (shifted_x, -shifted_y)
157+
158+
assert fit.grids.lp[0] == pytest.approx(expected, 1.0e-4)
159+
160+
161+
def test__grids__with_grid_rotation_angle_zero__matches_subtracted_from(
162+
imaging_7x7, mask_2d_7x7, model_image_7x7
163+
):
164+
masked_imaging_7x7 = imaging_7x7.apply_mask(mask=mask_2d_7x7)
165+
166+
fit_rotated = aa.m.MockFitImaging(
167+
dataset=masked_imaging_7x7,
168+
use_mask_in_fit=False,
169+
model_data=model_image_7x7,
170+
dataset_model=aa.DatasetModel(grid_offset=(1.0, 2.0), grid_rotation_angle=0.0),
171+
)
172+
173+
# angle=0 is identity rotation, so the result must equal the offset-only path.
174+
assert fit_rotated.grids.lp[0] == pytest.approx((0.0, -3.0), 1.0e-4)
175+
assert fit_rotated.grids.pixelization[0] == pytest.approx((0.0, -3.0), 1.0e-4)

test_autoarray/structures/grids/test_irregular_2d.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import shutil
33
import numpy as np
4+
import pytest
45

56
import autoarray as aa
67

@@ -125,3 +126,37 @@ def test__grid_of_closest_from():
125126
assert (
126127
grid_of_closest == np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
127128
).all()
129+
130+
131+
def test__subtracted_from():
132+
grid = aa.Grid2DIrregular(np.array([[1.0, 2.0], [3.0, 4.0]]))
133+
134+
shifted = grid.subtracted_from(offset=(0.5, -0.5))
135+
136+
assert shifted.array == pytest.approx(np.array([[0.5, 2.5], [2.5, 4.5]]), 1.0e-4)
137+
138+
139+
def test__subtracted_and_rotated_from__zero_angle_is_pure_shift():
140+
grid = aa.Grid2DIrregular(np.array([[1.0, 2.0], [3.0, 4.0]]))
141+
142+
shifted = grid.subtracted_and_rotated_from(offset=(0.5, -0.5), angle=0.0)
143+
144+
assert shifted.array == pytest.approx(np.array([[0.5, 2.5], [2.5, 4.5]]), 1.0e-4)
145+
146+
147+
def test__subtracted_and_rotated_from__90_degrees_about_origin():
148+
grid = aa.Grid2DIrregular(np.array([[1.0, 0.0], [0.0, 1.0]]))
149+
150+
rotated = grid.subtracted_and_rotated_from(offset=(0.0, 0.0), angle=90.0)
151+
152+
# 90 deg CCW in (y, x) order maps (y, x) -> (x, -y).
153+
assert rotated.array == pytest.approx(np.array([[0.0, -1.0], [1.0, 0.0]]), 1.0e-4)
154+
155+
156+
def test__subtracted_and_rotated_from__shift_first_then_rotate():
157+
grid = aa.Grid2DIrregular(np.array([[2.0, 3.0]]))
158+
159+
rotated = grid.subtracted_and_rotated_from(offset=(1.0, 1.0), angle=90.0)
160+
161+
# Shifted -> (1.0, 2.0); 90 deg CCW -> (2.0, -1.0).
162+
assert rotated.array == pytest.approx(np.array([[2.0, -1.0]]), 1.0e-4)

test_autoarray/structures/grids/test_uniform_2d.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,40 @@ def test__apply_over_sampling():
830830
grid = grid.apply_over_sampling(over_sample_size=2)
831831

832832
assert grid.over_sampled.shape[0] == 16
833+
834+
835+
def test__subtracted_and_rotated_from__zero_angle_is_pure_shift():
836+
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0, over_sample_size=1)
837+
838+
shifted = grid.subtracted_and_rotated_from(offset=(0.5, -0.5), angle=0.0)
839+
840+
expected = grid.array - np.array([0.5, -0.5])
841+
assert shifted.array == pytest.approx(expected, 1.0e-4)
842+
843+
844+
def test__subtracted_and_rotated_from__90_degrees_about_origin():
845+
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0, over_sample_size=1)
846+
847+
rotated = grid.subtracted_and_rotated_from(offset=(0.0, 0.0), angle=90.0)
848+
849+
# 90 deg CCW in (y, x) order maps (y, x) -> (x, -y).
850+
expected = np.stack((grid.array[:, 1], -grid.array[:, 0]), axis=-1)
851+
assert rotated.array == pytest.approx(expected, 1.0e-4)
852+
853+
854+
def test__subtracted_and_rotated_from__180_degrees_inverts_coordinates():
855+
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0, over_sample_size=1)
856+
857+
rotated = grid.subtracted_and_rotated_from(offset=(0.0, 0.0), angle=180.0)
858+
859+
assert rotated.array == pytest.approx(-grid.array, 1.0e-4)
860+
861+
862+
def test__subtracted_and_rotated_from__shift_first_then_rotate():
863+
grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=1.0, over_sample_size=1)
864+
865+
shifted = grid.array - np.array([1.0, 2.0])
866+
expected = np.stack((shifted[:, 1], -shifted[:, 0]), axis=-1)
867+
868+
rotated = grid.subtracted_and_rotated_from(offset=(1.0, 2.0), angle=90.0)
869+
assert rotated.array == pytest.approx(expected, 1.0e-4)

0 commit comments

Comments
 (0)