74 changes: 45 additions & 29 deletions likelihood_runtime/OPTIMIZATION_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ follow-up (see the bottom of this doc).
| Local GPU (RTX 2060) fp64 + mp | ✅ run, numbers below |
| HPC A100 sma (4 cells) | ✅ run 2026-05-21 — interferometer/delaunay + datacube/delaunay × sma × fp64 + mp |
| HPC A100 alma (4 cells) | ✅ run 2026-05-22 — unblocked by PyAutoArray#329 (apply_sparse_operator now accepts TransformerNUFFT) |
| HPC A100 alma_high | ⏸ blocked — see [nufft_simulator_chunking](https://github.com/PyAutoLabs/PyAutoPrompt/blob/main/autoarray/nufft_simulator_chunking.md) |
| HPC A100 alma_high (4 cells) | ✅ run 2026-05-22 — unblocked by PyAutoArray#330 (TransformerNUFFT chunk_size knob caps the nufftax gather buffer) |
| Imaging cells fresh CPU/GPU | ⚠ blocked by upstream `Grid2DIrregular.mask` bug — table rows show the pre-existing v2026.5.8.2 / v2026.5.14.2 data |

## Headline numbers (full pipeline, single JIT per call)
Expand Down Expand Up @@ -334,23 +334,37 @@ the cached dirty image, χ² is `inversion.fast_chi_squared`. Setup-time
- **mp is a wash** (45.3 vs 44.8 ms — essentially identical), same pattern
as sma. fp64 is the right default.

### alma_high (5M / 10M visibilities) — simulator blocked
### alma_high (5M visibilities, high-res `pixel_scale=0.025`)

| Config | full pipeline | Notes |
|-------------------|---------------|-------|
| hpc_a100_fp64 | — | dataset can't be simulated; see follow-up |
| hpc_a100_mp | — | same |

Simulator can't generate the dataset: `nufftax.spread.interp_2d_impl`
allocates `2 × N_vis × nspread² × dtype = ~15.7 GB` in a single gather
buffer even at 5M visibilities, exceeding A100 headroom. Tracked in
[`PyAutoPrompt/autoarray/nufft_simulator_chunking.md`](https://github.com/PyAutoLabs/PyAutoPrompt/blob/main/autoarray/nufft_simulator_chunking.md).
Re-running the alma_high SLURM submits unblocks once that chunking
lands.

vmap is intentionally skipped on this cell — opt in with `DELAUNAY_VMAP=1`
per the script's design. Delaunay mesh construction doesn't batch
cleanly along the parameter axis.
| Config | full pipeline | vmap (batch=3) | log_evidence |
|-------------------|---------------|------------------------------|--------------|
| hpc_a100_fp64 | **98 ms** | _vmap intentionally skipped_ | −60 243 535.86 |
| hpc_a100_mp | 101 ms | _vmap intentionally skipped_ | −60 243 535.86 |

Unblocked by [PyAutoArray#330](https://github.com/PyAutoLabs/PyAutoArray/pull/330)
(TransformerNUFFT `chunk_size` knob caps the nufftax gather buffer at
`2 × chunk_size × nspread² × dtype_size`; per-instrument default for
alma_high is `chunk_size=1_000_000`). Simulator runs cleanly on A100 in
~80 s for the full 5M-vis dataset.
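
As a rough check of the buffer formula above, the sketch below evaluates `2 × chunk_size × nspread² × dtype_size` for the unchunked 5M-visibility case and for a single 1M-visibility chunk. It assumes `nspread² = 196` (the figure quoted in the simulator config comment) and an 8-byte element size. Both are assumptions about nufftax internals, not values read from its source, and `gather_buffer_gb` is a hypothetical helper.

```python
def gather_buffer_gb(n_points: int, nspread_sq: int = 196, dtype_size: int = 8) -> float:
    """Estimated gather-buffer size in GB (10^9 bytes) for n_points visibilities.

    nspread_sq and dtype_size are assumed values, not read from nufftax.
    """
    return 2 * n_points * nspread_sq * dtype_size / 1e9


unchunked = gather_buffer_gb(5_000_000)  # all 5M visibilities in one gather
chunked = gather_buffer_gb(1_000_000)    # one chunk at chunk_size=1_000_000

print(f"unchunked: {unchunked:.2f} GB, per chunk: {chunked:.2f} GB")
```

Under these assumptions the unchunked buffer comes out near the ~15.7 GB reported before the fix, and a single chunk near the ~3 GB quoted for the alma preset.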

**Per-call scaling validated.** Across the three instrument presets
(same model, same Hilbert pixel budget, sparse-operator path):

| Instrument | n_vis | pixel_scale | mask radius (px) | per-call (fp64) |
|------------|-------|-------------|------------------|-----------------|
| sma | 190 | 0.1 | 35 | 33 ms |
| alma | 1 M | 0.05 | 70 | 45 ms |
| alma_high | 5 M | 0.025 | 140 | 98 ms |

Per-call cost scales **with mask radius (in pixels)**, not with
visibility count. Going alma → alma_high doubles the mask diameter
(4× more mask pixels → 4× more FFT work), and the per-call time
~doubles (45 → 98 ms). Going sma → alma → alma_high, visibility count
scales ~26 300× (190 → 1M → 5M; 5263× for sma → alma alone) but per-call
time scales only 3× (33 → 98
ms). This is the clearest empirical confirmation yet of the W-Tilde
sparse-formalism prediction: per-likelihood cost is dominated by the
mask-extent FFT (`O(N_mask · log N_mask)`), and visibility count enters
only the one-shot, setup-time NUFFT precision-matrix precompute.
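
The growth factors quoted above can be recomputed directly from the table. The sketch below is illustrative only: the `presets` dictionary and the `growth` helper are hypothetical names, the numbers are copied from the table, and the mask-pixel factor assumes a circular mask whose pixel count grows with the square of its radius.

```python
# (n_vis, mask radius in pixels, per-call fp64 time in ms), copied from the table
presets = {
    "sma": (190, 35, 33.0),
    "alma": (1_000_000, 70, 45.0),
    "alma_high": (5_000_000, 140, 98.0),
}


def growth(a: str, b: str) -> tuple[float, float, float]:
    """Return (visibility, mask-pixel, per-call-time) growth factors from preset a to b."""
    n_a, r_a, t_a = presets[a]
    n_b, r_b, t_b = presets[b]
    # Circular mask: pixel count scales with radius squared (assumption).
    return n_b / n_a, (r_b / r_a) ** 2, t_b / t_a


for a, b in [("sma", "alma"), ("alma", "alma_high"), ("sma", "alma_high")]:
    vis, area, t = growth(a, b)
    print(f"{a} -> {b}: n_vis x{vis:,.0f}, mask pixels x{area:.0f}, per-call x{t:.2f}")
```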

**Key findings (sma)**

Expand Down Expand Up @@ -443,19 +457,21 @@ Per-call cube-JIT timing is still pending — opt in with `CUBE_FULL_JIT=1`
on a follow-up SLURM run. Expected based on the interferometer alma row
(45 ms × 34 channels ≈ 1.5 s/cube, give or take XLA fusion savings).

### alma_high (5M / 10M visibilities) — simulator blocked
### alma_high (5M visibilities × 34 channels, `pixel_scale=0.025`)

| Config | full pipeline | Notes |
|-------------------|---------------|-------|
| hpc_a100_fp64 | — | dataset can't be simulated; see follow-up |
| hpc_a100_mp | — | same |

Simulator can't generate the dataset: `nufftax.spread.interp_2d_impl`
allocates `2 × N_vis × nspread² × dtype = ~15.7 GB` in a single gather
buffer even at 5M visibilities, exceeding A100 headroom. Tracked in
[`PyAutoPrompt/autoarray/nufft_simulator_chunking.md`](https://github.com/PyAutoLabs/PyAutoPrompt/blob/main/autoarray/nufft_simulator_chunking.md).
Re-running the alma_high SLURM submits unblocks once that chunking
(equivalent to the planned `TransformerNUFFT.chunk_size` knob) lands.
| Config | full pipeline (cube) | log_evidence (cube) | log_evidence/channel | Notes |
|-------------------|----------------------------|---------------------|----------------------|-------|
| hpc_a100_fp64 | **eager baseline only** | −2 048 222 823.68 | −60 241 847.76 | runtime variant; cube-JIT skipped (opt in via `CUBE_FULL_JIT=1`) |
| hpc_a100_mp | eager baseline only | −2 048 222 823.67 | −60 241 847.75 | same |

Unblocked by [PyAutoArray#330](https://github.com/PyAutoLabs/PyAutoArray/pull/330).
All 34 channels finished `apply_sparse_operator` on A100 within the
~31-minute SLURM wall budget at the chunked `chunk_size=1_000_000`
setting (longer than alma's 21-min run due to the 4× larger mask FFT
per channel at `pixel_scale=0.025`). Per-channel eager log_evidence
matches `interferometer/delaunay/alma_high` to within ~0.005%; the small
drift comes from fixed-seed-driven model parameter differences between the
two scripts and is well within the math-equivalence threshold.
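
The per-channel agreement can be checked by hand from the fp64 values in the two alma_high tables. The snippet below is a minimal sketch using only those reported numbers; the variable names are illustrative.

```python
n_channels = 34
cube_log_evidence = -2_048_222_823.68   # datacube alma_high, hpc_a100_fp64 row
single_log_evidence = -60_243_535.86    # interferometer alma_high, hpc_a100_fp64 row

# Average log evidence per channel of the cube.
per_channel = cube_log_evidence / n_channels

# Relative difference against the single-dataset interferometer run.
rel_diff = abs(per_channel - single_log_evidence) / abs(single_log_evidence)

print(f"per-channel: {per_channel:.2f}, relative difference: {rel_diff:.4%}")
```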

**Headline finding (local data)** — **the 34-channel cube drops from 197 s to 18 s on
GPU** (10.9× faster), making per-cube fits genuinely interactive on RTX 2060.
Expand Down
20 changes: 16 additions & 4 deletions likelihood_runtime/datacube/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,31 @@ def jit_profile(func, label, *args, n_repeats=10):
radius=mask_radius,
)

transformer_chunk_size = INSTRUMENTS[instrument].get("transformer_chunk_size", None)


def _build_transformer(uv_wavelengths, real_space_mask):
    """Inject per-instrument chunk_size into TransformerNUFFT without needing a
    transformer_kwargs API on Interferometer.from_fits. Required for alma_high
    (5M visibilities) to cap the nufftax gather buffer (PyAutoArray#330)."""
    return al.TransformerNUFFT(
        uv_wavelengths=uv_wavelengths,
        real_space_mask=real_space_mask,
        chunk_size=transformer_chunk_size,
    )


with timer.section("dataset_list_load"):
# apply_sparse_operator: precompute the visibility-space sparse precision
# operator so per-fit curvature assembly uses the FFT-based sparse path
# instead of a dense DFT for every source pixel. Unblocked by
# PyAutoArray#316 (the Pmax > 1 extent-indexing fix); on Delaunay this was
# previously guarded with NotImplementedError.
# instead of a dense DFT for every source pixel.
dataset_list = [
al.Interferometer.from_fits(
data_path=dataset_path / "data.fits",
noise_map_path=dataset_path / "noise_map.fits",
uv_wavelengths_path=dataset_path / "uv_wavelengths.fits",
real_space_mask=real_space_mask,
transformer_class=al.TransformerNUFFT,
transformer_class=_build_transformer,
).apply_sparse_operator(use_jax=True, show_progress=False)
for _ in range(n_channels)
]
Expand Down
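
The `_build_transformer` wrapper in this file relies on the loader accepting any callable in place of a class. The sketch below shows the same idea with `functools.partial` on stand-in objects. `FakeTransformer` and `load_dataset` are hypothetical and not part of PyAutoArray, and the sketch assumes the loader calls `transformer_class` with only the `uv_wavelengths` and `real_space_mask` keyword arguments, as the wrapper's signature suggests.

```python
from functools import partial


class FakeTransformer:
    """Hypothetical stand-in for a transformer that accepts an extra chunk_size kwarg."""

    def __init__(self, uv_wavelengths, real_space_mask, chunk_size=None):
        self.uv_wavelengths = uv_wavelengths
        self.real_space_mask = real_space_mask
        self.chunk_size = chunk_size


def load_dataset(transformer_class):
    """Hypothetical loader: it only ever passes the two standard arguments."""
    return transformer_class(uv_wavelengths=[(0.0, 1.0)], real_space_mask="mask")


# Pre-bind chunk_size so the loader never needs to know about it.
build_transformer = partial(FakeTransformer, chunk_size=1_000_000)

transformer = load_dataset(build_transformer)
print(transformer.chunk_size)
```

A named function, as used in the diff, behaves the same way and leaves room for a docstring explaining why the extra argument is needed.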
15 changes: 14 additions & 1 deletion likelihood_runtime/interferometer/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,33 @@ def jit_profile(func, label, *args, n_repeats=10):
)

mask_radius = INSTRUMENTS[instrument]["mask_radius"]
transformer_chunk_size = INSTRUMENTS[instrument].get("transformer_chunk_size", None)

real_space_mask = al.Mask2D.circular(
shape_native=real_space_shape,
pixel_scales=pixel_scale,
radius=mask_radius,
)


def _build_transformer(uv_wavelengths, real_space_mask):
    """Inject per-instrument chunk_size into TransformerNUFFT without needing a
    transformer_kwargs API on Interferometer.from_fits. Required for alma_high
    (5M visibilities) to cap the nufftax gather buffer (PyAutoArray#330)."""
    return al.TransformerNUFFT(
        uv_wavelengths=uv_wavelengths,
        real_space_mask=real_space_mask,
        chunk_size=transformer_chunk_size,
    )


with timer.section("dataset_load"):
dataset = al.Interferometer.from_fits(
data_path=dataset_path / "data.fits",
noise_map_path=dataset_path / "noise_map.fits",
uv_wavelengths_path=dataset_path / "uv_wavelengths.fits",
real_space_mask=real_space_mask,
transformer_class=al.TransformerNUFFT,
transformer_class=_build_transformer,
)

with timer.section("apply_sparse_operator"):
Expand Down
28 changes: 21 additions & 7 deletions simulators/interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"noise_sigma": 1000.0,
"seed": 1,
"transformer": "dft", # 190 vis × 256² grid; DFT is cheap and exact
"transformer_chunk_size": None, # one-shot; sma is tiny
},
"alma": {
"pixel_scale": 0.05,
Expand All @@ -65,16 +66,18 @@
"noise_sigma": 100.0,
"seed": 1,
"transformer": "nufft", # 1M vis × 800² grid → DFT memory blowup; use nufftax
"transformer_chunk_size": None, # 1M vis × nspread²=196 ≈ 3 GB gather buffer; fits A100 one-shot
},
"alma_high": {
"pixel_scale": 0.125,
"pixel_scale": 0.025,
"real_space_shape": (800, 800),
"mask_radius": 3.5,
"n_visibilities": 5_000_000,
"uv_scale": 2.0e6,
"noise_sigma": 100.0,
"seed": 1,
"transformer": "nufft", # 5M vis × 800² grid; nufftax index buffer ~7.8 GB at eps=1e-6 (10M would need 15.7 GB, exceeds practical A100 headroom)
"transformer": "nufft", # 5M vis × 800² grid; needs chunking via PyAutoArray#330
"transformer_chunk_size": 1_000_000, # caps gather buffer ~3 GB / chunk
},
}

Expand Down Expand Up @@ -121,10 +124,20 @@ def simulate(instrument: str = "sma", output_root: Path | None = None) -> Path:
noise_sigma = config["noise_sigma"]
seed = config["seed"]
transformer_choice = config.get("transformer", "dft").lower()
transformer_class = {
"dft": al.TransformerDFT,
"nufft": al.TransformerNUFFT,
}[transformer_choice]
transformer_chunk_size = config.get("transformer_chunk_size", None)
if transformer_choice == "nufft":
    # Closure injection so chunk_size flows into TransformerNUFFT.__init__
    # without needing a transformer_kwargs API in Interferometer.from_fits.
    def transformer_class(uv_wavelengths, real_space_mask):
        return al.TransformerNUFFT(
            uv_wavelengths=uv_wavelengths,
            real_space_mask=real_space_mask,
            chunk_size=transformer_chunk_size,
        )
elif transformer_choice == "dft":
    transformer_class = al.TransformerDFT
else:
    raise ValueError(f"Unknown transformer '{transformer_choice}'")

root = output_root if output_root is not None else _REPO_ROOT
dataset_path = root / "dataset" / "interferometer" / instrument
Expand Down Expand Up @@ -331,7 +344,8 @@ def _image_fn(grid_array):
"n_visibilities": n_visibilities,
"uv_scale": uv_scale,
"noise_sigma": noise_sigma,
"transformer": transformer_class.__name__,
"transformer": _TRANSFORMER_CLASS[transformer_choice],
"transformer_chunk_size": transformer_chunk_size,
},
"phases": phases,
"key_timings": {
Expand Down