Skip to content

Commit 0ea8d6d

Browse files
authored
Merge pull request #465 from PyAutoLabs/feature/cluster-simulator
feat: add optional redshift to PointDataset
2 parents 6f720c6 + 74fa707 commit 0ea8d6d

2 files changed

Lines changed: 127 additions & 4 deletions

File tree

autolens/point/dataset.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_BASE_HEADERS = ["name", "y", "x", "positions_noise"]
3232
_FLUX_HEADERS = ["flux", "flux_noise"]
3333
_TIME_DELAY_HEADERS = ["time_delay", "time_delay_noise"]
34+
_REDSHIFT_HEADERS = ["redshift"]
3435

3536

3637
class PointDataset:
@@ -45,6 +46,7 @@ def __init__(
4546
time_delays_noise_map: Optional[
4647
Union[float, aa.ArrayIrregular, List[float]]
4748
] = None,
49+
redshift: Optional[float] = None,
4850
):
4951
"""
5052
A collection of the data component that can be used for point-source model-fitting, for example fitting the
@@ -73,6 +75,9 @@ def __init__(
7375
The time delays of each observed point-source of light in days.
7476
time_delays_noise_map
7577
The noise-value of every observed time delay, which is typically measured from the time delay analysis.
78+
redshift
79+
The redshift of the source. Optional; when provided it is carried through CSV round-trips alongside
80+
the positions so cluster-scale workflows can encode per-source redshifts in a single spreadsheet.
7681
"""
7782

7883
self.name = name
@@ -111,6 +116,8 @@ def convert_to_array_irregular(values):
111116
self.time_delays = convert_to_array_irregular(time_delays)
112117
self.time_delays_noise_map = convert_to_array_irregular(time_delays_noise_map)
113118

119+
self.redshift = float(redshift) if redshift is not None else None
120+
114121
@property
115122
def info(self) -> str:
116123
"""
@@ -125,6 +132,7 @@ def info(self) -> str:
125132
info += f"fluxes_noise_map : {self.fluxes_noise_map}\n"
126133
info += f"time_delays : {self.time_delays}\n"
127134
info += f"time_delays_noise_map : {self.time_delays_noise_map}\n"
135+
info += f"redshift : {self.redshift}\n"
128136
return info
129137

130138
def extent_from(self, buffer: float = 0.1):
@@ -202,22 +210,28 @@ def output_to_csv(datasets: List[PointDataset], file_path: str):
202210
image.
203211
204212
The base columns (``name, y, x, positions_noise``) are always written. The
205-
optional ``flux``/``flux_noise`` and ``time_delay``/``time_delay_noise`` columns
206-
are included when *any* dataset in ``datasets`` carries those values; datasets
207-
that do not carry them leave those cells blank.
213+
optional ``flux``/``flux_noise``, ``time_delay``/``time_delay_noise`` and
214+
``redshift`` columns are included when *any* dataset in ``datasets`` carries
215+
those values; datasets that do not carry them leave those cells blank.
216+
217+
When written, every row in a given ``name`` group repeats the same ``redshift``
218+
value — the source redshift is a per-source property, not per-image.
208219
209220
This is the hand-editable / spreadsheet form preferred for strong-lens cluster
210221
workflows with tens or hundreds of multiply-imaged sources. For exact
211222
round-trip serialisation use ``output_to_json`` / ``from_json``.
212223
"""
213224
include_flux = any(d.fluxes is not None for d in datasets)
214225
include_time_delay = any(d.time_delays is not None for d in datasets)
226+
include_redshift = any(d.redshift is not None for d in datasets)
215227

216228
headers = list(_BASE_HEADERS)
217229
if include_flux:
218230
headers += _FLUX_HEADERS
219231
if include_time_delay:
220232
headers += _TIME_DELAY_HEADERS
233+
if include_redshift:
234+
headers += _REDSHIFT_HEADERS
221235

222236
rows = []
223237
for dataset in datasets:
@@ -247,6 +261,10 @@ def output_to_csv(datasets: List[PointDataset], file_path: str):
247261
row["time_delay_noise"] = (
248262
"" if time_delays_noise is None else time_delays_noise[i]
249263
)
264+
if include_redshift:
265+
row["redshift"] = (
266+
"" if dataset.redshift is None else dataset.redshift
267+
)
250268
rows.append(row)
251269

252270
csvable.output_to_csv(rows, file_path, headers=headers)
@@ -270,17 +288,47 @@ def _float_column(
270288
return [float(v) for v in raw]
271289

272290

291+
def _group_redshift(
292+
group_rows: List[dict], group_name: str
293+
) -> Optional[float]:
294+
raw = [row.get("redshift", "") for row in group_rows]
295+
populated = [v for v in raw if v not in ("", None)]
296+
297+
if not populated:
298+
return None
299+
300+
if len(populated) != len(raw):
301+
raise ValueError(
302+
f"CSV group {group_name!r} has partially populated column "
303+
f"'redshift'; every row in the group must have a value or all be blank."
304+
)
305+
306+
values = [float(v) for v in populated]
307+
if any(v != values[0] for v in values):
308+
raise ValueError(
309+
f"CSV group {group_name!r} has inconsistent 'redshift' values "
310+
f"{values!r}; a source redshift must be identical across all of its "
311+
f"image rows."
312+
)
313+
314+
return values[0]
315+
316+
273317
def list_from_csv(file_path: str) -> List[PointDataset]:
274318
"""
275319
Load a list of ``PointDataset`` objects from a CSV written by
276320
:func:`output_to_csv` (or :meth:`PointDataset.to_csv`).
277321
278322
Rows are grouped by their ``name`` column — one ``PointDataset`` per distinct
279-
name, preserving the order of first appearance. Optional columns
323+
name, preserving the order of first appearance. Optional per-image columns
280324
(``flux``/``flux_noise``, ``time_delay``/``time_delay_noise``) are carried through
281325
per-group: if every row in a group populates the column the values are loaded,
282326
if every row leaves it blank the corresponding attribute is set to ``None``, and
283327
any partial-population is rejected with a ``ValueError``.
328+
329+
The optional ``redshift`` column is per-source (not per-image): every row within
330+
a group must share the same value. A group with mixed or differing redshifts is
331+
rejected with a ``ValueError``.
284332
"""
285333
rows = csvable.list_from_csv(file_path)
286334

@@ -304,6 +352,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
304352
has_flux_noise_column = "flux_noise" in headers
305353
has_time_delay_column = "time_delay" in headers
306354
has_time_delay_noise_column = "time_delay_noise" in headers
355+
has_redshift_column = "redshift" in headers
307356

308357
datasets: List[PointDataset] = []
309358
for name, group_rows in groups.items():
@@ -332,6 +381,11 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
332381
if has_time_delay_noise_column
333382
else None
334383
)
384+
redshift = (
385+
_group_redshift(group_rows, name)
386+
if has_redshift_column
387+
else None
388+
)
335389

336390
datasets.append(
337391
PointDataset(
@@ -342,6 +396,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
342396
fluxes_noise_map=fluxes_noise_map,
343397
time_delays=time_delays,
344398
time_delays_noise_map=time_delays_noise_map,
399+
redshift=redshift,
345400
)
346401
)
347402

test_autolens/point/test_dataset.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def _assert_dataset_equal(actual: al.PointDataset, expected: al.PointDataset):
2929
_assert_array_close(actual.fluxes_noise_map, expected.fluxes_noise_map)
3030
_assert_array_close(actual.time_delays, expected.time_delays)
3131
_assert_array_close(actual.time_delays_noise_map, expected.time_delays_noise_map)
32+
if expected.redshift is None:
33+
assert actual.redshift is None
34+
else:
35+
assert actual.redshift == pytest.approx(expected.redshift)
3236

3337

3438
def test__csv_round_trip__positions_only(tmp_path):
@@ -133,6 +137,70 @@ def test__csv_list_round_trip__heterogeneous_optional_columns(tmp_path):
133137
assert loaded[1].fluxes_noise_map is None
134138

135139

140+
def test__csv_round_trip__redshift(tmp_path):
    """A single dataset carrying a redshift is unchanged after ``to_csv`` then ``from_csv``."""
    csv_path = os.path.join(tmp_path, "point_dataset.csv")

    original = al.PointDataset(
        name="source_0",
        positions=[(0.5, 1.0), (-0.25, 2.0), (1.5, -1.0)],
        positions_noise_map=[0.05, 0.05, 0.1],
        redshift=2.5,
    )
    original.to_csv(csv_path)

    reloaded = al.PointDataset.from_csv(csv_path)

    _assert_dataset_equal(reloaded, original)
    assert reloaded.redshift == pytest.approx(2.5)
155+
156+
157+
def test__csv_list_round_trip__mixed_redshift_presence(tmp_path):
    """A list mixing datasets with and without a redshift round-trips through one CSV."""
    csv_path = os.path.join(tmp_path, "point_datasets.csv")

    # First entry carries a redshift, second does not.
    expected = [
        al.PointDataset(
            name="source_0",
            positions=[(0.0, 0.0), (1.0, 1.0)],
            positions_noise_map=[0.05, 0.05],
            redshift=1.8,
        ),
        al.PointDataset(
            name="source_1",
            positions=[(2.0, 0.5), (-1.0, 0.5)],
            positions_noise_map=[0.1, 0.1],
        ),
    ]
    al.output_to_csv(expected, csv_path)

    reloaded = al.list_from_csv(csv_path)

    # The name check also pins the number of loaded datasets before pairing them up.
    assert [d.name for d in reloaded] == ["source_0", "source_1"]
    for actual, wanted in zip(reloaded, expected):
        _assert_dataset_equal(actual, wanted)
    assert reloaded[0].redshift == pytest.approx(1.8)
    assert reloaded[1].redshift is None
180+
181+
182+
def test__list_from_csv__inconsistent_redshift_raises(tmp_path):
    """Two rows of one group with different redshifts make ``list_from_csv`` raise."""
    csv_path = os.path.join(tmp_path, "point_datasets.csv")

    csv_lines = [
        "name,y,x,positions_noise,redshift\n",
        "source_0,0.0,0.0,0.05,1.5\n",
        "source_0,1.0,1.0,0.05,2.0\n",
    ]
    with open(csv_path, "w") as handle:
        handle.writelines(csv_lines)

    with pytest.raises(ValueError, match="inconsistent 'redshift'"):
        al.list_from_csv(csv_path)
191+
192+
193+
def test__list_from_csv__partial_redshift_raises(tmp_path):
    """A group where only some rows populate ``redshift`` makes ``list_from_csv`` raise."""
    csv_path = os.path.join(tmp_path, "point_datasets.csv")

    csv_lines = [
        "name,y,x,positions_noise,redshift\n",
        "source_0,0.0,0.0,0.05,1.5\n",
        "source_0,1.0,1.0,0.05,\n",
    ]
    with open(csv_path, "w") as handle:
        handle.writelines(csv_lines)

    with pytest.raises(ValueError, match="partially populated column 'redshift'"):
        al.list_from_csv(csv_path)
202+
203+
136204
def test__from_csv__multiple_groups_requires_name(tmp_path):
137205
datasets = [
138206
al.PointDataset(

0 commit comments

Comments
 (0)