diff --git a/autolens/point/dataset.py b/autolens/point/dataset.py index c14ef46bc..55d2b722e 100644 --- a/autolens/point/dataset.py +++ b/autolens/point/dataset.py @@ -31,6 +31,7 @@ _BASE_HEADERS = ["name", "y", "x", "positions_noise"] _FLUX_HEADERS = ["flux", "flux_noise"] _TIME_DELAY_HEADERS = ["time_delay", "time_delay_noise"] +_REDSHIFT_HEADERS = ["redshift"] class PointDataset: @@ -45,6 +46,7 @@ def __init__( time_delays_noise_map: Optional[ Union[float, aa.ArrayIrregular, List[float]] ] = None, + redshift: Optional[float] = None, ): """ A collection of the data component that can be used for point-source model-fitting, for example fitting the @@ -73,6 +75,9 @@ def __init__( The time delays of each observed point-source of light in days. time_delays_noise_map The noise-value of every observed time delay, which is typically measured from the time delay analysis. + redshift + The redshift of the source. Optional; when provided it is carried through CSV round-trips alongside + the positions so cluster-scale workflows can encode per-source redshifts in a single spreadsheet. """ self.name = name @@ -111,6 +116,8 @@ def convert_to_array_irregular(values): self.time_delays = convert_to_array_irregular(time_delays) self.time_delays_noise_map = convert_to_array_irregular(time_delays_noise_map) + self.redshift = float(redshift) if redshift is not None else None + @property def info(self) -> str: """ @@ -125,6 +132,7 @@ def info(self) -> str: info += f"fluxes_noise_map : {self.fluxes_noise_map}\n" info += f"time_delays : {self.time_delays}\n" info += f"time_delays_noise_map : {self.time_delays_noise_map}\n" + info += f"redshift : {self.redshift}\n" return info def extent_from(self, buffer: float = 0.1): @@ -202,9 +210,12 @@ def output_to_csv(datasets: List[PointDataset], file_path: str): image. The base columns (``name, y, x, positions_noise``) are always written. The - optional ``flux``/``flux_noise`` and ``time_delay``/``time_delay_noise`` columns - are included when *any* dataset in ``datasets`` carries those values; datasets - that do not carry them leave those cells blank. + optional ``flux``/``flux_noise``, ``time_delay``/``time_delay_noise`` and + ``redshift`` columns are included when *any* dataset in ``datasets`` carries + those values; datasets that do not carry them leave those cells blank. + + When written, every row in a given ``name`` group repeats the same ``redshift`` + value — the source redshift is a per-source property, not per-image. This is the hand-editable / spreadsheet form preferred for strong-lens cluster workflows with tens or hundreds of multiply-imaged sources. For exact @@ -212,12 +223,15 @@ def output_to_csv(datasets: List[PointDataset], file_path: str): """ include_flux = any(d.fluxes is not None for d in datasets) include_time_delay = any(d.time_delays is not None for d in datasets) + include_redshift = any(d.redshift is not None for d in datasets) headers = list(_BASE_HEADERS) if include_flux: headers += _FLUX_HEADERS if include_time_delay: headers += _TIME_DELAY_HEADERS + if include_redshift: + headers += _REDSHIFT_HEADERS rows = [] for dataset in datasets: @@ -247,6 +261,10 @@ def output_to_csv(datasets: List[PointDataset], file_path: str): row["time_delay_noise"] = ( "" if time_delays_noise is None else time_delays_noise[i] ) + if include_redshift: + row["redshift"] = ( + "" if dataset.redshift is None else dataset.redshift + ) rows.append(row) csvable.output_to_csv(rows, file_path, headers=headers) @@ -270,17 +288,47 @@ def _float_column( return [float(v) for v in raw] +def _group_redshift( + group_rows: List[dict], group_name: str +) -> Optional[float]: + raw = [row.get("redshift", "") for row in group_rows] + populated = [v for v in raw if v not in ("", None)] + + if not populated: + return None + + if len(populated) != len(raw): + raise ValueError( + f"CSV group {group_name!r} has partially populated column " + f"'redshift'; every row in the group must have a value or all be blank." + ) + + values = [float(v) for v in populated] + if any(v != values[0] for v in values): + raise ValueError( + f"CSV group {group_name!r} has inconsistent 'redshift' values " + f"{values!r}; a source redshift must be identical across all of its " + f"image rows." + ) + + return values[0] + + def list_from_csv(file_path: str) -> List[PointDataset]: """ Load a list of ``PointDataset`` objects from a CSV written by :func:`output_to_csv` (or :meth:`PointDataset.to_csv`). Rows are grouped by their ``name`` column — one ``PointDataset`` per distinct - name, preserving the order of first appearance. Optional columns + name, preserving the order of first appearance. Optional per-image columns (``flux``/``flux_noise``, ``time_delay``/``time_delay_noise``) are carried through per-group: if every row in a group populates the column the values are loaded, if every row leaves it blank the corresponding attribute is set to ``None``, and any partial-population is rejected with a ``ValueError``. + + The optional ``redshift`` column is per-source (not per-image): every row within + a group must share the same value. A group with mixed or differing redshifts is + rejected with a ``ValueError``. """ rows = csvable.list_from_csv(file_path) @@ -304,6 +352,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]: has_flux_noise_column = "flux_noise" in headers has_time_delay_column = "time_delay" in headers has_time_delay_noise_column = "time_delay_noise" in headers + has_redshift_column = "redshift" in headers datasets: List[PointDataset] = [] for name, group_rows in groups.items(): @@ -332,6 +381,11 @@ def list_from_csv(file_path: str) -> List[PointDataset]: if has_time_delay_noise_column else None ) + redshift = ( + _group_redshift(group_rows, name) + if has_redshift_column + else None + ) datasets.append( PointDataset( @@ -342,6 +396,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]: fluxes_noise_map=fluxes_noise_map, time_delays=time_delays, time_delays_noise_map=time_delays_noise_map, + redshift=redshift, ) ) diff --git a/test_autolens/point/test_dataset.py b/test_autolens/point/test_dataset.py index e4da0b0e2..4470ed8e0 100644 --- a/test_autolens/point/test_dataset.py +++ b/test_autolens/point/test_dataset.py @@ -29,6 +29,10 @@ def _assert_dataset_equal(actual: al.PointDataset, expected: al.PointDataset): _assert_array_close(actual.fluxes_noise_map, expected.fluxes_noise_map) _assert_array_close(actual.time_delays, expected.time_delays) _assert_array_close(actual.time_delays_noise_map, expected.time_delays_noise_map) + if expected.redshift is None: + assert actual.redshift is None + else: + assert actual.redshift == pytest.approx(expected.redshift) def test__csv_round_trip__positions_only(tmp_path): @@ -133,6 +137,70 @@ def test__csv_list_round_trip__heterogeneous_optional_columns(tmp_path): assert loaded[1].fluxes_noise_map is None +def test__csv_round_trip__redshift(tmp_path): + dataset = al.PointDataset( + name="source_0", + positions=[(0.5, 1.0), (-0.25, 2.0), (1.5, -1.0)], + positions_noise_map=[0.05, 0.05, 0.1], + redshift=2.5, + ) + + file_path = os.path.join(tmp_path, "point_dataset.csv") + dataset.to_csv(file_path) + + loaded = al.PointDataset.from_csv(file_path) + + _assert_dataset_equal(loaded, dataset) + assert loaded.redshift == pytest.approx(2.5) + + +def test__csv_list_round_trip__mixed_redshift_presence(tmp_path): + with_redshift = al.PointDataset( + name="source_0", + positions=[(0.0, 0.0), (1.0, 1.0)], + positions_noise_map=[0.05, 0.05], + redshift=1.8, + ) + without_redshift = al.PointDataset( + name="source_1", + positions=[(2.0, 0.5), (-1.0, 0.5)], + positions_noise_map=[0.1, 0.1], + ) + + file_path = os.path.join(tmp_path, "point_datasets.csv") + al.output_to_csv([with_redshift, without_redshift], file_path) + + loaded = al.list_from_csv(file_path) + + assert [d.name for d in loaded] == ["source_0", "source_1"] + _assert_dataset_equal(loaded[0], with_redshift) + _assert_dataset_equal(loaded[1], without_redshift) + assert loaded[0].redshift == pytest.approx(1.8) + assert loaded[1].redshift is None + + +def test__list_from_csv__inconsistent_redshift_raises(tmp_path): + file_path = os.path.join(tmp_path, "point_datasets.csv") + with open(file_path, "w") as f: + f.write("name,y,x,positions_noise,redshift\n") + f.write("source_0,0.0,0.0,0.05,1.5\n") + f.write("source_0,1.0,1.0,0.05,2.0\n") + + with pytest.raises(ValueError, match="inconsistent 'redshift'"): + al.list_from_csv(file_path) + + +def test__list_from_csv__partial_redshift_raises(tmp_path): + file_path = os.path.join(tmp_path, "point_datasets.csv") + with open(file_path, "w") as f: + f.write("name,y,x,positions_noise,redshift\n") + f.write("source_0,0.0,0.0,0.05,1.5\n") + f.write("source_0,1.0,1.0,0.05,\n") + + with pytest.raises(ValueError, match="partially populated column 'redshift'"): + al.list_from_csv(file_path) + + def test__from_csv__multiple_groups_requires_name(tmp_path): datasets = [ al.PointDataset(