Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions _profile_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ProfileCLI:
output_dir: Optional[Path]
use_mixed_precision: bool
instrument: Optional[str]
vmap_probe: bool


def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI:
Expand Down Expand Up @@ -88,6 +89,16 @@ def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI:
"interferometer/datacube cells, 'hst' for imaging)."
),
)
parser.add_argument(
"--vmap-probe",
action="store_true",
help=(
"Probe mode: JIT-vmap the full pipeline at batch=2 and batch=4, "
"read compiled.memory_analysis(), write a vmap_probe.json with "
"the recommended A100 batch_size, and exit before the steady-"
"state timing loop. See vram/README.md for methodology."
),
)

args, _unknown = parser.parse_known_args()
config_name = args.config_name or default_config_name
Expand All @@ -97,6 +108,7 @@ def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI:
output_dir=output_dir,
use_mixed_precision=bool(args.use_mixed_precision),
instrument=args.instrument,
vmap_probe=bool(args.vmap_probe),
)


Expand Down
84 changes: 84 additions & 0 deletions instruments/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# `instruments` — per-instrument dataset presets

This subpackage owns the per-instrument configuration dicts that drive
both the simulators and the profiling cells. Two modules:

- `instruments.imaging` — `INSTRUMENTS` for imaging (hst, jwst, ao, euclid).
- `instruments.interferometer` — `INSTRUMENTS` for interferometer
(sma, alma, alma_high, jvla).

## Why a separate package?

The INSTRUMENTS dicts used to live inside `simulators/imaging.py` and
`simulators/interferometer.py`. As the repo grew, multiple consumers ended
up reading them:

- `simulators/*.py` — drives dataset simulation.
- `likelihood_runtime/{imaging,interferometer,datacube}/*.py` — reads
`pixel_scale`, `mask_radius`, `real_space_shape`, `transformer_chunk_size`
for setting up the profiling fit.
- `likelihood_breakdown/{imaging,interferometer,datacube}/*.py` — reads the same fields as `likelihood_runtime/` above.
- `vram/config.py` — uses the instrument keys to index the
`VMAP_BATCH` lookup table.

Splitting the dicts into a dedicated home means:

- Each consumer imports from one canonical location.
- Adding a new instrument is one row in one file (plus a probe + a
`VMAP_BATCH` entry).
- Helpers like `mask_radius_pixels(instrument)` can centralise math that
was previously inlined across multiple files.

## Schema

### Imaging fields

| Field | Type | Meaning |
|-------|------|---------|
| `pixel_scale` | float | arcsec / pixel |
| `mask_radius` | float | arcsec (circular mask) |
| `psf_shape` | tuple[int, int] | PSF kernel shape (n_y, n_x) |
| `psf_sigma` | float | Gaussian PSF width (arcsec) |
| `seed` | int | RNG seed for noise generation |

### Interferometer fields

| Field | Type | Meaning |
|-------|------|---------|
| `pixel_scale` | float | arcsec / pixel |
| `real_space_shape` | tuple[int, int] | (n_y, n_x) real-space image grid |
| `mask_radius` | float | arcsec (circular mask) |
| `n_visibilities` | int | number of (u, v) baselines |
| `uv_scale` | float | RNG sampling scale for (u, v) |
| `noise_sigma` | float | noise per visibility |
| `seed` | int | RNG seed |
| `transformer` | "dft" or "nufft" | transformer class |
| `transformer_chunk_size` | int or None | NUFFT gather-buffer cap; `None` runs the NUFFT in one shot |

## Helpers

- `imaging.mask_radius_pixels(instrument) -> int` — mask radius / pixel_scale, rounded.
- `imaging.shape_native(instrument) -> tuple[int, int]` — square data grid shape `(n, n)` with `n = ceil(2 * mask_radius / pixel_scale)`.
- `interferometer.mask_radius_pixels(instrument) -> int` — same calculation, using the interferometer presets.
- `interferometer.transformer_chunk_size_for(instrument) -> int | None` — convenience accessor.

## Backward compatibility

The legacy import paths still work:

```python
from simulators.imaging import INSTRUMENTS # still valid
from simulators.interferometer import INSTRUMENTS # still valid
```

These re-export from `instruments.{imaging,interferometer}` so existing
consumers don't have to migrate. New code should prefer
`from instruments.imaging import INSTRUMENTS`.

## Adding a new instrument

1. Add a row to the appropriate `INSTRUMENTS` dict.
2. Simulate the dataset by running `python simulators/<imaging|interferometer>.py --instrument <name>`.
3. Run a `vram/` probe job (see `vram/README.md`) on the A100.
4. Add the resulting `VMAP_BATCH` entry to `vram/config.py`.
5. Re-run the regular profile sweep to confirm the chosen vmap batch size holds at steady state.
20 changes: 20 additions & 0 deletions instruments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Per-instrument dataset presets — single source of truth across the repo.

This package decouples instrument configuration (pixel_scale, mask_radius,
PSF shape, visibility count, transformer config, ...) from the simulator
and likelihood-fit code that consume it. Multiple consumers — simulators,
``likelihood_runtime/``, ``likelihood_breakdown/``, ``vram/`` — read the
same dicts, so they live in their own module.

Public API::

from instruments.imaging import INSTRUMENTS, mask_radius_pixels
from instruments.interferometer import INSTRUMENTS, transformer_chunk_size_for

The legacy ``from simulators.{imaging,interferometer} import INSTRUMENTS``
imports continue to work via re-exports in those modules.
"""

from instruments import imaging, interferometer # noqa: F401

__all__ = ["imaging", "interferometer"]
77 changes: 77 additions & 0 deletions instruments/imaging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Per-instrument imaging dataset presets.

The single source of truth for imaging dataset geometry + simulator
configuration across `autolens_profiling`. Consumed by:

- ``simulators/imaging.py`` (re-exports ``INSTRUMENTS`` and uses every field
to drive the simulator).
- ``likelihood_runtime/imaging/{delaunay,mge,pixelization}.py`` (read
``pixel_scale`` and ``mask_radius`` to set up the dataset for profiling).
- ``likelihood_breakdown/imaging/*.py`` (same as above).
- ``vram/config.py`` (uses instrument keys to index the vmap batch_size
table — only the keys are referenced there, not the field values).

Each preset's fields:

- ``pixel_scale`` — arcsec per pixel.
- ``mask_radius`` — circular mask radius in arcsec.
- ``psf_shape`` — (n_y, n_x) shape of the simulated PSF kernel.
- ``psf_sigma`` — Gaussian PSF width in arcsec.
- ``seed`` — RNG seed for noise generation in the simulator.

To add a new instrument: append a row, then probe the per-(cell, instrument)
vmap batch size via ``vram/`` and add the matching rows in
``vram/config.py:VMAP_BATCH``.
"""

from __future__ import annotations


# Imaging presets keyed by instrument name.
#
# Units: ``pixel_scale`` is arcsec per pixel; ``mask_radius`` and
# ``psf_sigma`` are in arcsec. ``psf_shape`` is the (n_y, n_x) PSF kernel
# shape and ``seed`` is the RNG seed used for noise generation.
#
# Insertion order is part of the table: consumers that iterate the dict see
# the presets in the order written here.
INSTRUMENTS: dict[str, dict] = {
    # 3.5 / 0.1 -> 35-pixel mask radius.
    "euclid": {
        "pixel_scale": 0.1,
        "mask_radius": 3.5,
        "psf_shape": (21, 21),
        "psf_sigma": 0.1,
        "seed": 1,
    },
    # 3.5 / 0.05 -> 70-pixel mask radius.
    "hst": {
        "pixel_scale": 0.05,
        "mask_radius": 3.5,
        "psf_shape": (21, 21),
        "psf_sigma": 0.05,
        "seed": 1,
    },
    # 3.5 / 0.03 -> ~117-pixel mask radius.
    "jwst": {
        "pixel_scale": 0.03,
        "mask_radius": 3.5,
        "psf_shape": (21, 21),
        "psf_sigma": 0.03,
        "seed": 1,
    },
    # 3.5 / 0.01 -> 350-pixel mask radius.
    "ao": {
        "pixel_scale": 0.01,
        "mask_radius": 3.5,
        "psf_shape": (21, 21),
        "psf_sigma": 0.01,
        "seed": 1,
    },
}


def mask_radius_pixels(instrument: str) -> int:
    """Return an imaging preset's mask radius converted to whole pixels.

    Computed as ``mask_radius / pixel_scale`` (arcsec divided by arcsec per
    pixel) and rounded to the nearest integer with the built-in ``round``.

    Args:
        instrument: Key into ``INSTRUMENTS``.

    Returns:
        The rounded mask radius in pixels.

    Raises:
        KeyError: If ``instrument`` is not a key of ``INSTRUMENTS``.
    """
    preset = INSTRUMENTS[instrument]
    radius_arcsec = preset["mask_radius"]
    arcsec_per_pixel = preset["pixel_scale"]
    return int(round(radius_arcsec / arcsec_per_pixel))


def shape_native(instrument: str) -> tuple[int, int]:
    """Return the square native grid shape that bounds the circular mask.

    The side length is ``ceil(2 * mask_radius / pixel_scale)`` pixels, the
    formula this module documents for the simulator, so consumers can size
    their grids consistently.

    The quotient is computed with true division and snapped to the nearest
    integer when it lies within floating-point noise of one. Float floor
    division (``-(-a // b)``) floors the exact quotient of the *binary*
    values, so a ``pixel_scale`` whose float sits just below its decimal
    value (e.g. ``0.03``) turns an exactly integral diameter into one pixel
    too many.

    Args:
        instrument: Key into ``INSTRUMENTS``.

    Returns:
        ``(n, n)`` where ``n`` is the mask diameter in pixels, rounded up.

    Raises:
        KeyError: If ``instrument`` is not a key of ``INSTRUMENTS``.
    """
    import math  # local import: this module has no top-level ``math`` import

    cfg = INSTRUMENTS[instrument]
    diameter_pixels = 2 * cfg["mask_radius"] / cfg["pixel_scale"]

    # Treat values like 199.99999999999997 or 200.00000000000003 as 200 so
    # representation error in ``pixel_scale`` cannot shift the grid size.
    nearest = round(diameter_pixels)
    if math.isclose(diameter_pixels, nearest, rel_tol=1e-9, abs_tol=1e-9):
        n = int(nearest)
    else:
        n = math.ceil(diameter_pixels)
    return (n, n)
97 changes: 97 additions & 0 deletions instruments/interferometer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Per-instrument interferometer dataset presets.

Single source of truth for interferometer dataset geometry + simulator
configuration. Consumed by:

- ``simulators/interferometer.py`` (re-exports ``INSTRUMENTS`` and uses every
field to drive the simulator + the lensed-source NUFFT transformer).
- ``likelihood_runtime/interferometer/{delaunay,mge,pixelization}.py`` (read
``pixel_scale``, ``real_space_shape``, ``mask_radius``, ``transformer_chunk_size``).
- ``likelihood_runtime/datacube/delaunay.py`` (same as above; per-channel).
- ``likelihood_breakdown/interferometer/*.py`` (same).
- ``vram/config.py`` (uses instrument keys to index the vmap batch_size table).

Each preset's fields:

- ``pixel_scale`` — arcsec per pixel.
- ``real_space_shape`` — (n_y, n_x) of the real-space image grid.
- ``mask_radius`` — circular mask radius in arcsec.
- ``n_visibilities`` — number of (u, v) baselines in the dataset.
- ``uv_scale`` — RNG sampling scale for (u, v) coordinates.
- ``noise_sigma`` — noise per visibility (in data units).
- ``seed`` — RNG seed for noise + uv generation.
- ``transformer`` — ``"dft"`` or ``"nufft"`` (selects the
transformer in both simulator and runtime).
- ``transformer_chunk_size`` — ``None`` for one-shot NUFFT, or a positive
integer to cap the nufftax gather buffer (PyAutoArray#330). Required at
alma_high / jvla scale.
"""

from __future__ import annotations

from typing import Optional


# Interferometer presets keyed by instrument name.
#
# ``pixel_scale`` is arcsec per pixel, ``mask_radius`` is in arcsec and
# ``real_space_shape`` is the (n_y, n_x) real-space image grid. Insertion
# order is part of the table: iterating consumers see presets in this order.
INSTRUMENTS: dict[str, dict] = {
    "sma": {
        "pixel_scale": 0.1,
        "real_space_shape": (256, 256),
        "mask_radius": 3.5,
        "n_visibilities": 190,
        "uv_scale": 3.0e5,
        "noise_sigma": 1000.0,
        "seed": 1,
        # 190 vis × 256² grid: the DFT is cheap and exact at this size.
        "transformer": "dft",
        # sma is tiny, so the transform runs one-shot (no chunking).
        "transformer_chunk_size": None,
    },
    "alma": {
        "pixel_scale": 0.05,
        "real_space_shape": (800, 800),
        "mask_radius": 3.5,
        "n_visibilities": 1_000_000,
        "uv_scale": 2.0e6,
        "noise_sigma": 100.0,
        "seed": 1,
        # 1M vis × 800² grid would blow up DFT memory; use nufftax instead.
        "transformer": "nufft",
        # 1M vis × nspread²=196 ≈ 3 GB, which fits an A100 one-shot.
        "transformer_chunk_size": None,
    },
    "alma_high": {
        "pixel_scale": 0.025,
        "real_space_shape": (800, 800),
        "mask_radius": 3.5,
        "n_visibilities": 5_000_000,
        "uv_scale": 2.0e6,
        "noise_sigma": 100.0,
        "seed": 1,
        # 5M vis × 800² grid; needs chunking via PyAutoArray#330.
        "transformer": "nufft",
        # Caps the gather buffer at ~3 GB per chunk.
        "transformer_chunk_size": 1_000_000,
    },
    "jvla": {
        "pixel_scale": 0.01,
        "real_space_shape": (800, 800),
        "mask_radius": 3.5,
        "n_visibilities": 25_000_000,
        "uv_scale": 2.0e6,
        "noise_sigma": 100.0,
        "seed": 1,
        # 25M vis stretch test; mask_radius = 3.5 / 0.01 = 350-px radius
        # (700-px mask diameter).
        "transformer": "nufft",
        # 25 chunks × ~3 GB gather buffer each.
        "transformer_chunk_size": 1_000_000,
    },
}


# Maps each preset's ``"transformer"`` value to a transformer class name.
TRANSFORMER_CLASS_NAME: dict[str, str] = {
    "dft": "TransformerDFT",
    "nufft": "TransformerNUFFT",
}


def mask_radius_pixels(instrument: str) -> int:
    """Return an interferometer preset's mask radius in whole pixels.

    The radius in arcsec is divided by the preset's ``pixel_scale`` (arcsec
    per pixel) and rounded to the nearest integer with the built-in ``round``.

    Args:
        instrument: Key into ``INSTRUMENTS``.

    Returns:
        The rounded mask radius in pixels.

    Raises:
        KeyError: If ``instrument`` is not a key of ``INSTRUMENTS``.
    """
    preset = INSTRUMENTS[instrument]
    radius_in_pixels = preset["mask_radius"] / preset["pixel_scale"]
    return int(round(radius_in_pixels))


def transformer_chunk_size_for(instrument: str) -> Optional[int]:
    """Return the NUFFT chunk size configured for ``instrument``.

    Args:
        instrument: Key into ``INSTRUMENTS``.

    Returns:
        The preset's ``transformer_chunk_size``; ``None`` means a one-shot
        transform. ``None`` is also returned when the preset does not define
        the key.

    Raises:
        KeyError: If ``instrument`` is not a key of ``INSTRUMENTS``.
    """
    preset = INSTRUMENTS[instrument]
    # ``dict.get`` keeps the lookup lenient for presets without the key.
    chunk_size = preset.get("transformer_chunk_size")
    return chunk_size
Loading
Loading