diff --git a/_profile_cli.py b/_profile_cli.py index 1dd1f60..94f32cb 100644 --- a/_profile_cli.py +++ b/_profile_cli.py @@ -34,6 +34,7 @@ class ProfileCLI: output_dir: Optional[Path] use_mixed_precision: bool instrument: Optional[str] + vmap_probe: bool def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI: @@ -88,6 +89,16 @@ def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI: "interferometer/datacube cells, 'hst' for imaging)." ), ) + parser.add_argument( + "--vmap-probe", + action="store_true", + help=( + "Probe mode: JIT-vmap the full pipeline at batch=2 and batch=4, " + "read compiled.memory_analysis(), write a vmap_probe.json with " + "the recommended A100 batch_size, and exit before the steady-" + "state timing loop. See vram/README.md for methodology." + ), + ) args, _unknown = parser.parse_known_args() config_name = args.config_name or default_config_name @@ -97,6 +108,7 @@ def parse_profile_cli(default_config_name: Optional[str] = None) -> ProfileCLI: output_dir=output_dir, use_mixed_precision=bool(args.use_mixed_precision), instrument=args.instrument, + vmap_probe=bool(args.vmap_probe), ) diff --git a/instruments/README.md b/instruments/README.md new file mode 100644 index 0000000..1075260 --- /dev/null +++ b/instruments/README.md @@ -0,0 +1,84 @@ +# `instruments` — per-instrument dataset presets + +This subpackage owns the per-instrument configuration dicts that drive +both the simulators and the profiling cells. Two modules: + +- `instruments.imaging` — `INSTRUMENTS` for imaging (hst, jwst, ao, euclid). +- `instruments.interferometer` — `INSTRUMENTS` for interferometer + (sma, alma, alma_high, jvla). + +## Why a separate package? + +The INSTRUMENTS dicts used to live inside `simulators/imaging.py` and +`simulators/interferometer.py`. As the repo grew, multiple consumers ended +up reading them: + +- `simulators/*.py` — drives dataset simulation. 
+- `likelihood_runtime/{imaging,interferometer,datacube}/*.py` — reads + `pixel_scale`, `mask_radius`, `real_space_shape`, `transformer_chunk_size` + for setting up the profiling fit. +- `likelihood_breakdown/{imaging,interferometer,datacube}/*.py` — same. +- `vram/config.py` — uses the instrument keys to index the + `VMAP_BATCH` lookup table. + +Splitting the dicts into a dedicated home means: + +- Each consumer imports from one canonical location. +- Adding a new instrument is one row in one file (plus a probe + a + `VMAP_BATCH` entry). +- Helpers like `mask_radius_pixels(instrument)` can centralise math that + was previously inlined across multiple files. + +## Schema + +### Imaging fields + +| Field | Type | Meaning | +|-------|------|---------| +| `pixel_scale` | float | arcsec / pixel | +| `mask_radius` | float | arcsec (circular mask) | +| `psf_shape` | tuple[int, int] | PSF kernel shape (n_y, n_x) | +| `psf_sigma` | float | Gaussian PSF width (arcsec) | +| `seed` | int | RNG seed for noise generation | + +### Interferometer fields + +| Field | Type | Meaning | +|-------|------|---------| +| `pixel_scale` | float | arcsec / pixel | +| `real_space_shape` | tuple[int, int] | (n_y, n_x) real-space image grid | +| `mask_radius` | float | arcsec (circular mask) | +| `n_visibilities` | int | number of (u, v) baselines | +| `uv_scale` | float | RNG sampling scale for (u, v) | +| `noise_sigma` | float | noise per visibility | +| `seed` | int | RNG seed | +| `transformer` | "dft" or "nufft" | transformer class | +| `transformer_chunk_size` | int or None | NUFFT gather-buffer cap | + +## Helpers + +- `imaging.mask_radius_pixels(instrument) -> int` — mask radius / pixel_scale, rounded. +- `imaging.shape_native(instrument) -> tuple[int, int]` — data grid shape derived from mask. +- `interferometer.mask_radius_pixels(instrument) -> int` — same math, on interferometer. +- `interferometer.transformer_chunk_size_for(instrument) -> int | None` — convenience accessor. 
+ +## Backward compatibility + +The legacy import paths still work: + +```python +from simulators.imaging import INSTRUMENTS # still valid +from simulators.interferometer import INSTRUMENTS # still valid +``` + +These re-export from `instruments.{imaging,interferometer}` so existing +consumers don't have to migrate. New code should prefer +`from instruments.imaging import INSTRUMENTS`. + +## Adding a new instrument + +1. Add a row to the appropriate `INSTRUMENTS` dict. +2. Simulate the dataset by running `python simulators/<imaging|interferometer>.py --instrument <name>`. +3. Run a `vram/` probe job (see `vram/README.md`) on the A100. +4. Add the resulting `VMAP_BATCH` entry to `vram/config.py`. +5. Re-run the regular profile sweep to confirm vmap holds at steady state. diff --git a/instruments/__init__.py b/instruments/__init__.py new file mode 100644 index 0000000..34fe52d --- /dev/null +++ b/instruments/__init__.py @@ -0,0 +1,20 @@ +"""Per-instrument dataset presets — single source of truth across the repo. + +This package decouples instrument configuration (pixel_scale, mask_radius, +PSF shape, visibility count, transformer config, ...) from the simulator +and likelihood-fit code that consume it. Multiple consumers — simulators, +``likelihood_runtime/``, ``likelihood_breakdown/``, ``vram/`` — read the +same dicts, so they live in their own module. + +Public API:: + + from instruments.imaging import INSTRUMENTS, mask_radius_pixels + from instruments.interferometer import INSTRUMENTS, transformer_chunk_size_for + +The legacy ``from simulators.{imaging,interferometer} import INSTRUMENTS`` +imports continue to work via re-exports in those modules. +""" + +from instruments import imaging, interferometer # noqa: F401 + +__all__ = ["imaging", "interferometer"] diff --git a/instruments/imaging.py b/instruments/imaging.py new file mode 100644 index 0000000..ab5dd71 --- /dev/null +++ b/instruments/imaging.py @@ -0,0 +1,77 @@ +"""Per-instrument imaging dataset presets. 
+ +The single source of truth for imaging dataset geometry + simulator +configuration across `autolens_profiling`. Consumed by: + +- ``simulators/imaging.py`` (re-exports ``INSTRUMENTS`` and uses every field + to drive the simulator). +- ``likelihood_runtime/imaging/{delaunay,mge,pixelization}.py`` (read + ``pixel_scale`` and ``mask_radius`` to set up the dataset for profiling). +- ``likelihood_breakdown/imaging/*.py`` (same as above). +- ``vram/config.py`` (uses instrument keys to index the vmap batch_size + table — only the keys are referenced there, not the field values). + +Each preset's fields: + +- ``pixel_scale`` — arcsec per pixel. +- ``mask_radius`` — circular mask radius in arcsec. +- ``psf_shape`` — (n_y, n_x) shape of the simulated PSF kernel. +- ``psf_sigma`` — Gaussian PSF width in arcsec. +- ``seed`` — RNG seed for noise generation in the simulator. + +To add a new instrument: append a row, then probe the per-(cell, instrument) +vmap batch size via ``vram/`` and add the matching rows in +``vram/config.py:VMAP_BATCH``. +""" + +from __future__ import annotations + + +INSTRUMENTS: dict[str, dict] = { + "euclid": { + "pixel_scale": 0.1, + "mask_radius": 3.5, + "psf_shape": (21, 21), + "psf_sigma": 0.1, + "seed": 1, + }, + "hst": { + "pixel_scale": 0.05, + "mask_radius": 3.5, + "psf_shape": (21, 21), + "psf_sigma": 0.05, + "seed": 1, + }, + "jwst": { + "pixel_scale": 0.03, + "mask_radius": 3.5, + "psf_shape": (21, 21), + "psf_sigma": 0.03, + "seed": 1, + }, + "ao": { + "pixel_scale": 0.01, + "mask_radius": 3.5, + "psf_shape": (21, 21), + "psf_sigma": 0.01, + "seed": 1, + }, +} + + +def mask_radius_pixels(instrument: str) -> int: + """Mask radius in pixels = ``mask_radius_arcsec / pixel_scale``.""" + cfg = INSTRUMENTS[instrument] + return int(round(cfg["mask_radius"] / cfg["pixel_scale"])) + + +def shape_native(instrument: str) -> tuple[int, int]: + """Native data grid shape derived from mask radius + pixel scale. 
+ + The simulator uses ``shape_pixels = ceil(2 * mask_radius / pixel_scale)`` + (with a tight bounding box around the unmasked circle). This helper + replicates that math so consumers can size their grids consistently. + """ + cfg = INSTRUMENTS[instrument] + n = int(-(-2 * cfg["mask_radius"] // cfg["pixel_scale"])) # ceil-div + return (n, n) diff --git a/instruments/interferometer.py b/instruments/interferometer.py new file mode 100644 index 0000000..eda02cc --- /dev/null +++ b/instruments/interferometer.py @@ -0,0 +1,97 @@ +"""Per-instrument interferometer dataset presets. + +Single source of truth for interferometer dataset geometry + simulator +configuration. Consumed by: + +- ``simulators/interferometer.py`` (re-exports ``INSTRUMENTS`` and uses every + field to drive the simulator + the lensed-source NUFFT transformer). +- ``likelihood_runtime/interferometer/{delaunay,mge,pixelization}.py`` (read + ``pixel_scale``, ``real_space_shape``, ``mask_radius``, ``transformer_chunk_size``). +- ``likelihood_runtime/datacube/delaunay.py`` (same as above; per-channel). +- ``likelihood_breakdown/interferometer/*.py`` (same). +- ``vram/config.py`` (uses instrument keys to index the vmap batch_size table). + +Each preset's fields: + +- ``pixel_scale`` — arcsec per pixel. +- ``real_space_shape`` — (n_y, n_x) of the real-space image grid. +- ``mask_radius`` — circular mask radius in arcsec. +- ``n_visibilities`` — number of (u, v) baselines in the dataset. +- ``uv_scale`` — RNG sampling scale for (u, v) coordinates. +- ``noise_sigma`` — noise per visibility (in data units). +- ``seed`` — RNG seed for noise + uv generation. +- ``transformer`` — ``"dft"`` or ``"nufft"`` (selects the + transformer in both simulator and runtime). +- ``transformer_chunk_size`` — ``None`` for one-shot NUFFT, or a positive + integer to cap the nufftax gather buffer (PyAutoArray#330). Required at + alma_high / jvla scale. 
+""" + +from __future__ import annotations + +from typing import Optional + + +INSTRUMENTS: dict[str, dict] = { + "sma": { + "pixel_scale": 0.1, + "real_space_shape": (256, 256), + "mask_radius": 3.5, + "n_visibilities": 190, + "uv_scale": 3.0e5, + "noise_sigma": 1000.0, + "seed": 1, + "transformer": "dft", # 190 vis × 256² grid; DFT is cheap and exact + "transformer_chunk_size": None, # sma is tiny; one-shot + }, + "alma": { + "pixel_scale": 0.05, + "real_space_shape": (800, 800), + "mask_radius": 3.5, + "n_visibilities": 1_000_000, + "uv_scale": 2.0e6, + "noise_sigma": 100.0, + "seed": 1, + "transformer": "nufft", # 1M vis × 800² grid → DFT memory blowup; use nufftax + "transformer_chunk_size": None, # 1M vis × nspread²=196 ≈ 3 GB; fits A100 one-shot + }, + "alma_high": { + "pixel_scale": 0.025, + "real_space_shape": (800, 800), + "mask_radius": 3.5, + "n_visibilities": 5_000_000, + "uv_scale": 2.0e6, + "noise_sigma": 100.0, + "seed": 1, + "transformer": "nufft", # 5M vis × 800² grid; needs chunking via PyAutoArray#330 + "transformer_chunk_size": 1_000_000, # caps gather buffer ~3 GB / chunk + }, + "jvla": { + "pixel_scale": 0.01, + "real_space_shape": (800, 800), + "mask_radius": 3.5, + "n_visibilities": 25_000_000, + "uv_scale": 2.0e6, + "noise_sigma": 100.0, + "seed": 1, + "transformer": "nufft", # 25M vis stretch test; mask_radius=3.5/0.01 = 350-px radius (700-px mask diameter) + "transformer_chunk_size": 1_000_000, # 25 chunks × ~3 GB gather buffer each + }, +} + + +TRANSFORMER_CLASS_NAME: dict[str, str] = { + "dft": "TransformerDFT", + "nufft": "TransformerNUFFT", +} + + +def mask_radius_pixels(instrument: str) -> int: + """Mask radius in pixels = ``mask_radius_arcsec / pixel_scale``.""" + cfg = INSTRUMENTS[instrument] + return int(round(cfg["mask_radius"] / cfg["pixel_scale"])) + + +def transformer_chunk_size_for(instrument: str) -> Optional[int]: + """Per-instrument NUFFT chunk_size (None for one-shot).""" + return 
INSTRUMENTS[instrument].get("transformer_chunk_size") diff --git a/likelihood_runtime/OPTIMIZATION_NOTES.md b/likelihood_runtime/OPTIMIZATION_NOTES.md index ce39aba..5b7605d 100644 --- a/likelihood_runtime/OPTIMIZATION_NOTES.md +++ b/likelihood_runtime/OPTIMIZATION_NOTES.md @@ -139,30 +139,22 @@ PyAutoLens v2026.5.14.2. The pre-existing v2026.5.8.2 sweep data in **mp verdict** — modest on CPU (~14 % win), neutral elsewhere. **Useful only at CPU scale**; skip on GPU. -### Per-instrument A100 runtime sweep (2026-05-24) - -Full-pipeline single-JIT cost per likelihood call across the 4 imaging -instrument presets. Same model, same rectangular mesh, same regularization -— only the dataset's pixel_scale (and hence mask shape) changes. - -| Instrument | pixel_scale | mask shape (px) | fp64 | mp | -|------------|-------------|------------------|------|-----| -| hst | 0.05 | 140 × 140 | 53 ms | 51 ms | -| jwst | 0.03 | 234 × 234 | 53 ms | 51 ms | -| ao | 0.01 | 700 × 700 | 53 ms | 51 ms | -| euclid | 0.1 | 70 × 70 | 54 ms | 51 ms | - -**Imaging/pixelization is essentially instrument-INDEPENDENT.** The -35×35 rectangular source mesh (1225 nodes) dominates the per-call FFT -budget; the data-side mask shape barely matters. Going euclid → ao -the data grid scales 100× in pixel count, but per-call time changes -<2%. This is exactly the inverse of the interferometer/pixelization -result, where mask-FFT extent drove a 6× per-call spread. - -**mp is uniformly a small win** (~4-5%) across all 4 imaging -instruments on this cell, with no scaling story — fixed-size FFTs -amortize the mixed-precision overhead the same way regardless of -instrument. +### Per-instrument A100 runtime sweep (2026-05-25, corrected) + +*Previous "instrument-independent" finding was an artefact of a bug +that always ran HST regardless of --instrument. 
Fixed in this PR.* + +| Instrument | mask (px) | single_jit fp64 | vmap fp64 (batch) | speedup | +|------------|-----------|-----------------|-------------------|---------| +| euclid | 70 | 44 ms | **14 ms** (b=64) | **3.1×** | +| hst | 140 | 53 ms | 33 ms (b=16) | 1.6× | +| jwst | 234 | 74 ms | 69 ms (b=8) | 1.1× | +| ao | 700 | 326 ms | 330 ms (b=1) | 1.0× | + +**Imaging/pixelization IS instrument-dependent** — AO at 700-px mask is +7.4× slower per single-JIT call than euclid at 70-px. vmap helps at +euclid/hst (batch 64/16 → 1.6-3× speedup) but is a wash at jwst/ao +(batch 8/1 → VRAM-limited). --- @@ -202,19 +194,20 @@ CPU rows have no fresh measurement available. **mp verdict** — barely measurable (~5 % on GPU). Skip; it's not worth the correctness-budget pressure. -### Per-instrument A100 runtime sweep (2026-05-24) +### Per-instrument A100 runtime sweep (2026-05-25, corrected) -| Instrument | pixel_scale | mask shape (px) | fp64 | mp | -|------------|-------------|------------------|------|-----| -| hst | 0.05 | 140 × 140 | 86 ms | 90 ms | -| jwst | 0.03 | 234 × 234 | 80 ms | 78 ms | -| ao | 0.01 | 700 × 700 | 85 ms | 80 ms | -| euclid | 0.1 | 70 × 70 | 77 ms | 78 ms | +| Instrument | mask (px) | single_jit fp64 | vmap fp64 (batch) | speedup | +|------------|-----------|-----------------|-------------------|---------| +| euclid | 70 | 65 ms | **46 ms** (b=64) | **1.4×** | +| hst | 140 | 83 ms | 107 ms (b=16) | 0.8× | +| jwst | 234 | 133 ms | 199 ms (b=8) | 0.7× | +| ao | 700 | 558 ms | 1308 ms (b=1) | 0.4× | -**Delaunay also instrument-quasi-independent on A100** (77-86 ms -spread, ~12% variation). Hilbert-mesh + triangulation cost dominates; -the source-plane mesh has ~1000 nodes regardless of instrument. -**mp is a wash** on this cell across all 4 instruments. +**Delaunay IS instrument-dependent** — AO is 8.6× slower per single-JIT +call than euclid. 
vmap helps for euclid (1.4× at batch=64) but actively +HURTS for hst and larger instruments (the per-replica VRAM cost forces small batches +where vmap overhead dominates). At batch=1 (AO), vmap is 2.3× SLOWER +than single-JIT — skip it entirely for AO-class datasets. --- @@ -223,23 +216,22 @@ the source-plane mesh has ~1000 nodes regardless of instrument. *MGE-decomposed source (Gaussian basis, ~25 Gaussians). Isothermal + ExternalShear lens. Lowest per-call cost in the imaging suite.* -### Per-instrument A100 runtime sweep (2026-05-24) +### Per-instrument A100 runtime sweep (2026-05-25, corrected) -| Instrument | pixel_scale | mask shape (px) | fp64 | mp | -|------------|-------------|------------------|------|-----| -| hst | 0.05 | 140 × 140 | 5.8 ms | 6.4 ms | -| jwst | 0.03 | 234 × 234 | 6.0 ms | 6.0 ms | -| ao | 0.01 | 700 × 700 | 6.0 ms | 5.8 ms | -| euclid | 0.1 | 70 × 70 | 5.9 ms | 5.8 ms | +| Instrument | mask (px) | single_jit fp64 | vmap fp64 (batch) | speedup | +|------------|-----------|-----------------|-------------------|---------| +| euclid | 70 | 5.7 ms | **0.2 ms** (b=64) | **29×** | +| hst | 140 | 6.1 ms | **0.4 ms** (b=64) | **16×** | +| jwst | 234 | 6.7 ms | **1.0 ms** (b=64) | **7×** | +| ao | 700 | — (blocked) | — (correctness bug) | — | -**Fastest cell in the entire imaging-side sweep at ~6 ms / call.** -Analytical light + parametric mass + Gaussian basis convolution — no -mesh construction, no sparse-operator setup, no large FFT. Per-call -cost is essentially constant across all 4 instruments (the data grid -shape barely registers). +**Fastest cell in the sweep AND massive vmap wins.** MGE's tiny per-replica +memory (~6-42 MB) fits batch=64 for all instruments → vmap amortises +the fixed per-call overhead down to ~1 ms or less per call. Production samplers +should ALWAYS use vmap for MGE (AO excepted until its bug is fixed; see below). -**mp is a wash** at this scale — the kernel is too small for -mixed-precision matmul gains to surface. 
+**AO is blocked** — vmap at batch=64 produces wildly inconsistent +log_evidence across replicas (correctness bug, separate investigation). --- diff --git a/likelihood_runtime/datacube/delaunay.py b/likelihood_runtime/datacube/delaunay.py index 5b15499..aa3f0b0 100644 --- a/likelihood_runtime/datacube/delaunay.py +++ b/likelihood_runtime/datacube/delaunay.py @@ -128,6 +128,7 @@ auto_simulate_if_missing, ) from simulators.interferometer import INSTRUMENTS # noqa: E402 +from vram import vmap_batch_for, write_probe_json, ProbeResult # noqa: E402 _cli = parse_profile_cli() instrument = _cli.instrument or "sma" # default; override via --instrument (cube is N copies of the per-instrument dataset) @@ -394,6 +395,30 @@ def _build_transformer(uv_wavelengths, real_space_mask): cube_log_evidence_ref = float(sum(log_evidence_per_channel)) print(f" cube reference log_evidence (sum) = {cube_log_evidence_ref:.6f}") +# =================================================================== +# PART B.5 — vmap-probe mode (early exit, intentionally skipped) +# =================================================================== +# +# The datacube cell does not vmap over parameters — the natural batching axis +# is "channels" (datasets), not "parameters". If --vmap-probe is set, write a +# vmap_probe.json noting the intentional skip and exit early. 
+ +if _cli.vmap_probe: + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "datacube")) + / "vmap_probe.json" + ) + probe_path.parent.mkdir(parents=True, exist_ok=True) + import json + probe_path.write_text(json.dumps({ + "dataset": "datacube", + "model": "delaunay", + "instrument": instrument, + "recommended_batch_size": None, + "note": "datacube vmap intentionally skipped — natural batching axis is channels, not parameters", + }, indent=2)) + print(f" vmap_probe: cell intentionally skipped — wrote {probe_path}") + sys.exit(0) # =================================================================== # PART C — Full-pipeline cube JIT (sum of per-channel log_likelihoods) diff --git a/likelihood_runtime/imaging/delaunay.py b/likelihood_runtime/imaging/delaunay.py index 42f74ff..a623f07 100644 --- a/likelihood_runtime/imaging/delaunay.py +++ b/likelihood_runtime/imaging/delaunay.py @@ -80,9 +80,15 @@ auto_simulate_if_missing, ) from simulators.imaging import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() -instrument = "hst" # <-- change this to profile a different instrument +instrument = _cli.instrument or "hst" # default; override via --instrument # --------------------------------------------------------------------------- @@ -389,35 +395,56 @@ def full_pipeline_from_params(params_tree): print(f" full log_likelihood = {full_result}") +# =================================================================== +# PART C.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. 
+ +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1,), + dataset="imaging", + model="delaunay", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "imaging")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + # =================================================================== # PART D — vmap + correctness # =================================================================== print("\n--- vmap batched evaluation ---") -# WARNING: The vmap compilation for the Delaunay pipeline takes 20+ minutes on CPU. -# The XLA graph for a batched Delaunay inversion (including scipy triangulation, -# border relocation, interpolation, mapping matrix construction, and PSF convolution) -# is extremely large. The single-call JIT above compiles in ~2s and runs in ~1.8s, -# but vmap recompiles the entire graph for batch_size independent evaluations. -# -# This is likely a candidate for optimisation — either via custom_vjp to avoid -# retracing the full pipeline, or by restructuring the Delaunay steps to reduce -# the XLA graph size. For now, skip vmap by default and run it only when explicitly -# requested via DELAUNAY_VMAP=1 environment variable. +batch_size = vmap_batch_for("imaging", "delaunay", instrument) or 3 -import os -run_vmap = os.environ.get("DELAUNAY_VMAP", "0") == "1" +# Skip vmap if vmap_batch_for explicitly returns None for this cell. 
+_vmap_skipped = vmap_batch_for("imaging", "delaunay", instrument) is None -if not run_vmap: - print(" SKIPPED: vmap compilation takes 20+ minutes for Delaunay pipeline.") - print(" Set DELAUNAY_VMAP=1 to run this section.") - vmap_batch_time = None +if _vmap_skipped: + print(" SKIPPED: vmap_batch_for() returned None for this (cell, instrument).") vmap_per_call = None vmap_speedup = None + vmap_batch_time = None + result_vmap = None + vmapped_full = None + parameters = None else: - - batch_size = 3 parameters = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), params_tree, @@ -516,16 +543,13 @@ def full_pipeline_from_params(params_tree): "edge_zeroed_pixels": int(edge_pixels_total), }, "full_pipeline_single_jit": full_pipeline_per_call, - "vmap": "SKIPPED — compilation takes 20+ minutes (set DELAUNAY_VMAP=1)", -} - -if vmap_per_call is not None: - likelihood_summary["vmap"] = { + "vmap": "SKIPPED — vmap_batch_for() returned None for this (cell, instrument)" if _vmap_skipped else { "batch_size": batch_size, "batch_time": vmap_batch_time, "per_call": vmap_per_call, - "speedup_vs_single_jit": round(vmap_speedup, 1), - } + "speedup_vs_single_jit": round(vmap_speedup, 1) if vmap_speedup is not None else None, + }, +} dict_path, chart_path = resolve_output_paths( _cli, @@ -568,7 +592,7 @@ def full_pipeline_from_params(params_tree): rtol=1e-3, err_msg=f"imaging/delaunay[{instrument}]: regression — full log_evidence drifted", ) -if run_vmap: +if not _vmap_skipped: np.testing.assert_allclose( np.array(result_vmap), EXPECTED_LOG_EVIDENCE_HST, diff --git a/likelihood_runtime/imaging/mge.py b/likelihood_runtime/imaging/mge.py index d4ad6a1..4aa186a 100644 --- a/likelihood_runtime/imaging/mge.py +++ b/likelihood_runtime/imaging/mge.py @@ -87,9 +87,15 @@ auto_simulate_if_missing, ) from simulators.imaging import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + 
+ write_probe_json, +) _cli = parse_profile_cli() -instrument = "hst" # <-- change this to profile a different instrument +instrument = _cli.instrument or "hst" # default; override via --instrument # --------------------------------------------------------------------------- @@ -347,13 +353,45 @@ def full_pipeline_from_params(params_tree): print(f" full log_likelihood = {full_result}") +# =================================================================== +# PART C.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at three +# batch sizes, reads ``compiled.memory_analysis()`` for each, and writes a +# ``vmap_probe.json`` with the recommended A100 batch_size — then exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. + +if _cli.vmap_probe: + # mge has cheap XLA compile (~10s); use multi-point fit to catch + # any rematerialisation non-linearity at the (1, 4, 16) regime. + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="imaging", + model="mge", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "imaging")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + # =================================================================== # PART D — vmap + correctness # =================================================================== print("\n--- vmap batched evaluation ---") -batch_size = 3 +batch_size = vmap_batch_for("imaging", "mge", instrument) or 3 # Build the batched pytree: every leaf gets a fresh leading batch axis. 
No # flat-vector reshaping required — JAX walks the pytree via the registration @@ -457,6 +495,10 @@ def full_pipeline_from_params(params_tree): "per_call": vmap_per_call, "speedup_vs_single_jit": round(vmap_speedup, 1), }, + "memory_mb": { + "output": memory_analysis.output_size_in_bytes / 1024**2, + "temp": memory_analysis.temp_size_in_bytes / 1024**2, + }, } dict_path, chart_path = resolve_output_paths( diff --git a/likelihood_runtime/imaging/pixelization.py b/likelihood_runtime/imaging/pixelization.py index 1ae7ae7..e034a75 100644 --- a/likelihood_runtime/imaging/pixelization.py +++ b/likelihood_runtime/imaging/pixelization.py @@ -73,9 +73,15 @@ auto_simulate_if_missing, ) from simulators.imaging import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() -instrument = "hst" # <-- change this to profile a different instrument +instrument = _cli.instrument or "hst" # default; override via --instrument # --------------------------------------------------------------------------- @@ -371,6 +377,36 @@ def full_pipeline_from_params(params_tree): print(f" full log_likelihood = {full_result}") +# =================================================================== +# PART C.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. 
+ +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="imaging", + model="pixelization", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "imaging")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + # =================================================================== # PART D — vmap + correctness # =================================================================== @@ -381,7 +417,7 @@ def full_pipeline_from_params(params_tree): print("\n--- vmap batched evaluation ---") -batch_size = 3 +batch_size = vmap_batch_for("imaging", "pixelization", instrument) or 3 vmap_batch_time = None vmap_per_call = None vmap_speedup = None diff --git a/likelihood_runtime/interferometer/delaunay.py b/likelihood_runtime/interferometer/delaunay.py index a7da5ae..76d637c 100644 --- a/likelihood_runtime/interferometer/delaunay.py +++ b/likelihood_runtime/interferometer/delaunay.py @@ -83,7 +83,6 @@ pytree support landed in PyAutoFit#1222. 
""" -import os import numpy as np import jax import jax.numpy as jnp @@ -123,6 +122,12 @@ auto_simulate_if_missing, ) from simulators.interferometer import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() instrument = _cli.instrument or "sma" # default; override via --instrument @@ -431,18 +436,46 @@ def full_pipeline_from_params(params_tree): print(" Eager-vs-JIT correctness PASSED") # =================================================================== -# PART D — vmap (opt-in) + correctness +# PART C.5 — vmap-probe mode (early exit) # =================================================================== # -# Delaunay vmap compilation can take 20+ minutes on CPU due to the size of -# the triangulation + interpolation XLA graph. Skipped by default — set -# DELAUNAY_VMAP=1 to opt in. +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. 
+ +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1,), + dataset="interferometer", + model="delaunay", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "interferometer")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + +# =================================================================== +# PART D — vmap + correctness +# =================================================================== print("\n--- vmap batched evaluation ---") -run_vmap = os.environ.get("DELAUNAY_VMAP", "0") == "1" +batch_size = vmap_batch_for("interferometer", "delaunay", instrument) or 3 + +# Skip vmap if vmap_batch_for explicitly returns None for this cell. 
+_vmap_skipped = vmap_batch_for("interferometer", "delaunay", instrument) is None -batch_size = 3 vmap_batch_time = None vmap_per_call = None vmap_speedup = None @@ -451,8 +484,8 @@ def full_pipeline_from_params(params_tree): parameters = None _n_leaves = len(jax.tree_util.tree_leaves(params_tree)) -if not run_vmap: - print(" SKIPPED: opt-in via DELAUNAY_VMAP=1 (compilation can take 20+ minutes).") +if _vmap_skipped: + print(" SKIPPED: vmap_batch_for() returned None for this (cell, instrument).") elif _n_leaves == 0: print(f" SKIPPED: model has 0 free parameters (all fixed to truth); " f"vmap requires at least one array leaf.") @@ -552,8 +585,8 @@ def full_pipeline_from_params(params_tree): # --- Save results dictionary --- if vmap_per_call is None: - if not run_vmap: - vmap_payload = "SKIPPED — opt-in via DELAUNAY_VMAP=1" + if _vmap_skipped: + vmap_payload = "SKIPPED — vmap_batch_for() returned None for this (cell, instrument)" else: vmap_payload = "SKIPPED — model has 0 free parameters (all fixed to truth)" else: diff --git a/likelihood_runtime/interferometer/mge.py b/likelihood_runtime/interferometer/mge.py index cb931e0..3dd9b2d 100644 --- a/likelihood_runtime/interferometer/mge.py +++ b/likelihood_runtime/interferometer/mge.py @@ -86,6 +86,12 @@ auto_simulate_if_missing, ) from simulators.interferometer import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() import argparse as _argparse # noqa: E402 @@ -94,7 +100,7 @@ _local_args, _ = _local_parser.parse_known_args() USE_DFT = bool(_local_args.use_dft) -instrument = "sma" # <-- change this to profile a different instrument +instrument = _cli.instrument or "sma" # default; override via --instrument # --------------------------------------------------------------------------- @@ -321,13 +327,43 @@ def full_pipeline_from_params(params_tree): print(f" full log_likelihood = {full_result}") +# 
=================================================================== +# PART B.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. + +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="interferometer", + model="mge", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "interferometer")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + # =================================================================== # PART C — vmap + correctness # =================================================================== print("\n--- vmap batched evaluation ---") -batch_size = 3 +batch_size = vmap_batch_for("interferometer", "mge", instrument) or 3 parameters = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), diff --git a/likelihood_runtime/interferometer/pixelization.py b/likelihood_runtime/interferometer/pixelization.py index c1bb6e2..4ae7d59 100644 --- a/likelihood_runtime/interferometer/pixelization.py +++ b/likelihood_runtime/interferometer/pixelization.py @@ -83,9 +83,15 @@ auto_simulate_if_missing, ) from simulators.interferometer import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, 
+) _cli = parse_profile_cli() -instrument = "sma" # <-- change this to profile a different instrument +instrument = _cli.instrument or "sma" # default; override via --instrument mesh_pixels_yx = 32 # 32x32 = 1024 source pixels — 1000-tier production fiducial mesh_shape = (mesh_pixels_yx, mesh_pixels_yx) @@ -337,6 +343,36 @@ def full_pipeline_from_params(params_tree): print(f" full log_likelihood = {full_result}") +# =================================================================== +# PART B.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. + +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="interferometer", + model="pixelization", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "interferometer")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) + # =================================================================== # PART C — vmap + correctness # =================================================================== @@ -347,7 +383,7 @@ def full_pipeline_from_params(params_tree): print("\n--- vmap batched evaluation ---") -batch_size = 3 +batch_size = vmap_batch_for("interferometer", "pixelization", instrument) or 3 vmap_batch_time = None vmap_per_call = None vmap_speedup = None diff --git 
a/likelihood_runtime/point_source/image_plane.py b/likelihood_runtime/point_source/image_plane.py index 7955387..a0ed1ce 100644 --- a/likelihood_runtime/point_source/image_plane.py +++ b/likelihood_runtime/point_source/image_plane.py @@ -69,6 +69,12 @@ auto_simulate_if_missing, ) from simulators.point_source import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() matplotlib.use("Agg") @@ -277,6 +283,35 @@ def full_pipeline_from_params(params_tree): full_pipeline_per_call = timer.records[-1][1] / 10 print(f" full log_likelihood = {full_result}") +# =================================================================== +# PART B.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. 
+ +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="point_source", + model="image_plane", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "point_source")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) # =================================================================== # PART C — vmap over the full pipeline @@ -284,7 +319,7 @@ def full_pipeline_from_params(params_tree): print("\n--- vmap batched evaluation ---") -batch_size = 3 +batch_size = vmap_batch_for("point_source", "image_plane", instrument) or 3 batched_params = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), diff --git a/likelihood_runtime/point_source/source_plane.py b/likelihood_runtime/point_source/source_plane.py index ead517e..0073548 100644 --- a/likelihood_runtime/point_source/source_plane.py +++ b/likelihood_runtime/point_source/source_plane.py @@ -51,6 +51,12 @@ auto_simulate_if_missing, ) from simulators.point_source import INSTRUMENTS # noqa: E402 +from vram import ( # noqa: E402 + probe_vmap_memory, + recommend_batch_size, + vmap_batch_for, + write_probe_json, +) _cli = parse_profile_cli() matplotlib.use("Agg") @@ -281,6 +287,35 @@ def full_pipeline_from_params(params_tree): " >>> See module docstring for the proposed library fix." 
) +# =================================================================== +# PART B.5 — vmap-probe mode (early exit) +# =================================================================== +# +# When ``--vmap-probe`` is set the script JIT-vmaps the pipeline at the +# configured batch sizes, reads ``compiled.memory_analysis()``, writes a +# ``vmap_probe.json`` with the recommended A100 batch_size, and exits +# before the full vmap timing loop. See ``vram/README.md`` for methodology. + +if _cli.vmap_probe: + probe = probe_vmap_memory( + full_pipeline_from_params, + params_tree, + batch_sizes=(1, 4, 16), + dataset="point_source", + model="source_plane", + instrument=instrument, + ) + recommended = recommend_batch_size(probe) + probe_path = ( + (_cli.output_dir or (_workspace_root / "results" / "likelihood" / "point_source")) + / "vmap_probe.json" + ) + write_probe_json(probe, recommended, probe_path) + print(f"\n vmap_probe samples: {probe.samples}") + print(f" per_replica: {probe.per_replica_mb:.1f} MB / replica") + print(f" recommended batch: {recommended}") + print(f" written to: {probe_path}") + sys.exit(0) # =================================================================== # PART C — JIT-able prefix: tracer ray-trace of observed positions @@ -326,7 +361,7 @@ def ray_trace_to_source_plane(params_tree, positions_raw): print("\n--- vmap over ray-trace prefix ---") -batch_size = 3 +batch_size = vmap_batch_for("point_source", "source_plane", instrument) or 3 batched_params = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), diff --git a/simulators/imaging.py b/simulators/imaging.py index 3e1e863..0161956 100644 --- a/simulators/imaging.py +++ b/simulators/imaging.py @@ -27,39 +27,13 @@ python simulators/imaging.py --instrument euclid """ +import sys from pathlib import Path - -INSTRUMENTS = { - "euclid": { - "pixel_scale": 0.1, - "mask_radius": 3.5, - "psf_shape": (21, 21), - "psf_sigma": 0.1, - "seed": 1, - }, - "hst": { - 
"pixel_scale": 0.05, - "mask_radius": 3.5, - "psf_shape": (21, 21), - "psf_sigma": 0.05, - "seed": 1, - }, - "jwst": { - "pixel_scale": 0.03, - "mask_radius": 3.5, - "psf_shape": (21, 21), - "psf_sigma": 0.03, - "seed": 1, - }, - "ao": { - "pixel_scale": 0.01, - "mask_radius": 3.5, - "psf_shape": (21, 21), - "psf_sigma": 0.01, - "seed": 1, - }, -} +# Soft-transition re-export — INSTRUMENTS now lives in `instruments/imaging.py`. +# Existing `from simulators.imaging import INSTRUMENTS` consumers keep working. +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from instruments.imaging import INSTRUMENTS # noqa: E402, F401 _REPO_ROOT = Path(__file__).resolve().parents[1] # autolens_profiling/ @@ -284,8 +258,10 @@ def _image_fn(grid_array): lensed_source = psf.convolved_image_from( image=lensed_source_unblurred, blurring_image=None ) - lensed_source.output_to_fits( - file_path=dataset_path / "lensed_source.fits", overwrite=True + al.output_to_fits( + values=lensed_source.native_for_fits, + file_path=dataset_path / "lensed_source.fits", + overwrite=True, ) with timer.section("output_json"): diff --git a/simulators/interferometer.py b/simulators/interferometer.py index 3206273..efc37f2 100644 --- a/simulators/interferometer.py +++ b/simulators/interferometer.py @@ -33,70 +33,17 @@ python simulators/interferometer.py --instrument alma_high # 10M vis """ +import sys from pathlib import Path - -# --------------------------------------------------------------------------- -# Instrument definitions — single source of truth -# --------------------------------------------------------------------------- -# -# Each preset bundles BOTH simulation-time and likelihood-fit-time fields, -# so the likelihood scripts that import this dict get pixel_scale / -# real_space_shape / mask_radius while the simulator code in this module -# also gets n_visibilities / uv_scale / noise_sigma / seed. 
- -INSTRUMENTS = { - "sma": { - "pixel_scale": 0.1, - "real_space_shape": (256, 256), - "mask_radius": 3.5, - "n_visibilities": 190, - "uv_scale": 3.0e5, - "noise_sigma": 1000.0, - "seed": 1, - "transformer": "dft", # 190 vis × 256² grid; DFT is cheap and exact - "transformer_chunk_size": None, # one-shot; sma is tiny - }, - "alma": { - "pixel_scale": 0.05, - "real_space_shape": (800, 800), - "mask_radius": 3.5, - "n_visibilities": 1_000_000, - "uv_scale": 2.0e6, - "noise_sigma": 100.0, - "seed": 1, - "transformer": "nufft", # 1M vis × 800² grid → DFT memory blowup; use nufftax - "transformer_chunk_size": None, # 1M vis × nspread²=196 ≈ 3 GB gather buffer; fits A100 one-shot - }, - "alma_high": { - "pixel_scale": 0.025, - "real_space_shape": (800, 800), - "mask_radius": 3.5, - "n_visibilities": 5_000_000, - "uv_scale": 2.0e6, - "noise_sigma": 100.0, - "seed": 1, - "transformer": "nufft", # 5M vis × 800² grid; needs chunking via PyAutoArray#330 - "transformer_chunk_size": 1_000_000, # caps gather buffer ~3 GB / chunk - }, - "jvla": { - "pixel_scale": 0.01, - "real_space_shape": (800, 800), - "mask_radius": 3.5, - "n_visibilities": 25_000_000, - "uv_scale": 2.0e6, - "noise_sigma": 100.0, - "seed": 1, - "transformer": "nufft", # 25M vis stretch test; mask_radius=3.5/0.01 = 350-px radius (700-px mask diameter) - "transformer_chunk_size": 1_000_000, # 25 chunks × ~3 GB gather buffer each - }, -} - - -_TRANSFORMER_CLASS = { - "dft": "TransformerDFT", - "nufft": "TransformerNUFFT", -} +# Soft-transition re-export — INSTRUMENTS now lives in +# `instruments/interferometer.py`. Existing +# `from simulators.interferometer import INSTRUMENTS` consumers keep working. 
+sys.path.insert(0, str(Path(__file__).resolve().parents[1])) +from instruments.interferometer import ( # noqa: E402, F401 + INSTRUMENTS, + TRANSFORMER_CLASS_NAME as _TRANSFORMER_CLASS, +) _REPO_ROOT = Path(__file__).resolve().parents[1] # autolens_profiling/ diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_vram_probe.py b/test/test_vram_probe.py new file mode 100644 index 0000000..b0b4c57 --- /dev/null +++ b/test/test_vram_probe.py @@ -0,0 +1,139 @@ +"""Unit tests for ``vram.probe`` extrapolation math. + +These tests construct synthetic ``ProbeResult`` objects (no JAX dependency) and +verify the linear-fit + budget-extrapolation logic. Lets us iterate on the math +without paying for HPC probe-job cycles. + +Run:: + + cd autolens_profiling + python -m pytest test/test_vram_probe.py +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Make `vram` importable without installing the repo. +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from vram.probe import ProbeResult, ProbeSample, recommend_batch_size # noqa: E402 + + +def _peak_sample(batch: int, peak_mb: float) -> ProbeSample: + """Helper: construct a ProbeSample with only peak_bytes set.""" + return ProbeSample(batch_size=batch, peak_bytes=int(peak_mb * 1024**2)) + + +def _legacy_sample(batch: int, output_mb: float, temp_mb: float) -> ProbeSample: + """Helper: ProbeSample with no peak (legacy JAX fallback to output+temp).""" + return ProbeSample( + batch_size=batch, + peak_bytes=0, + output_bytes=int(output_mb * 1024**2), + temp_bytes=int(temp_mb * 1024**2), + ) + + +def test_per_replica_two_point_fit(): + """Linear coefficient from two samples.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 100.0), _peak_sample(4, 400.0)], + ) + assert probe.per_replica_mb == 100.0 + assert probe.constant_overhead_mb == 0.0 + + +def 
test_per_replica_with_overhead(): + """Linear fit with non-zero intercept.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 200.0), _peak_sample(4, 500.0)], + ) + # peak(B) = 200 + (B-1)*100 = 100 + 100*B → per_replica=100, overhead=100 + assert probe.per_replica_mb == 100.0 + assert probe.constant_overhead_mb == 100.0 + + +def test_per_replica_single_sample_assumes_zero_overhead(): + """Single sample falls back to peak / batch as the per-replica estimate.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(4, 800.0)], + ) + assert probe.per_replica_mb == 200.0 + assert probe.constant_overhead_mb == 0.0 + + +def test_recommend_capped_at_max_batch(): + """Tiny per-replica → recommendation is capped, not unbounded.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 1.0), _peak_sample(4, 4.0)], + ) + assert recommend_batch_size(probe, max_batch=64) == 64 + + +def test_recommend_respects_budget(): + """Large per-replica → recommendation drops below the cap.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 5_000.0), _peak_sample(4, 20_000.0)], + ) + # per_replica = 5 GB. 
budget = 65 GB, safety = 1.15 → 5.75 GB / replica → ~11.3 + rec = recommend_batch_size(probe, vram_budget_gb=65.0, safety_factor=1.15) + assert 10 <= rec <= 12, f"expected ~11, got {rec}" + + +def test_recommend_floor_at_1(): + """Per-replica > budget → still returns at least 1.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 200_000.0)], # 200 GB per replica — absurd + ) + rec = recommend_batch_size(probe, vram_budget_gb=65.0, safety_factor=1.15) + assert rec == 1 + + +def test_legacy_fallback_uses_output_plus_temp(): + """If peak_memory_in_bytes is absent (older JAX), fall back to output+temp.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[ + _legacy_sample(1, output_mb=10.0, temp_mb=90.0), # 100 MB total + _legacy_sample(4, output_mb=10.0, temp_mb=390.0), # 400 MB total + ], + ) + assert probe.per_replica_mb == 100.0 + assert probe.constant_overhead_mb == 0.0 + + +def test_safety_factor_tightens_recommendation(): + """Larger safety_factor → smaller recommended batch.""" + probe = ProbeResult( + dataset="t", + model="t", + instrument="t", + samples=[_peak_sample(1, 1_000.0), _peak_sample(4, 4_000.0)], + ) + no_safety = recommend_batch_size(probe, vram_budget_gb=65.0, safety_factor=1.0) + with_safety = recommend_batch_size( + probe, vram_budget_gb=65.0, safety_factor=1.5 + ) + assert with_safety < no_safety diff --git a/vram/README.md b/vram/README.md new file mode 100644 index 0000000..e89397d --- /dev/null +++ b/vram/README.md @@ -0,0 +1,95 @@ +# `vram` — vmap batch_size investigation + per-cell config + +This subpackage owns the VRAM-budget logic for the likelihood profiling +sweep on A100 80 GB. Two responsibilities: + +1. **Probe** (`vram.probe`) — measure how the compiled vmap program's + memory footprint scales with batch size, extrapolate the largest batch + that fits the device's VRAM budget. +2. 
**Config** (`vram.config`) — the curated per-(dataset, model, + instrument) batch_size table that runtime cell scripts look up via + `vmap_batch_for(...)`. + +## Why a separate subpackage? + +The runtime cell scripts (`likelihood_runtime/{imaging,interferometer, +datacube}/*.py`) used to hard-code `batch_size = 3` everywhere. Production +sampling uses much larger batches (≥ 10), and the right batch size varies +with model/instrument (a 700-px mask AO dataset needs a smaller batch +than a 70-px Euclid one). Splitting the logic out: + +- Keeps each runtime cell terse — one import + one call to + `vmap_batch_for(...)`. +- Centralises the probe / extrapolation math so it can be unit-tested. +- Makes the per-(cell, instrument) batch_size table reviewable as data, + not as scattered constants. + +## Public API + +From `vram`: + +| Name | Purpose | +|------|---------| +| `vmap_batch_for(dataset, model, instrument)` | Return per-cell batch_size (or `None` if vmap is intentionally skipped or the cell hasn't been probed). | +| `probe_vmap_memory(func, args_pytree, batch_sizes=(1, 4))` | JIT-vmap `func` at each batch, read `compiled.memory_analysis()`, return a `ProbeResult`. | +| `recommend_batch_size(probe, vram_budget_gb=65.0, safety_factor=1.15, max_batch=64)` | Linear extrapolation → max batch fitting in budget. | +| `write_probe_json(probe, recommended, path)` | Serialise probe + recommendation to JSON. | + +## How `probe_vmap_memory` works + +1. For each batch size `B` in `batch_sizes` (default `(1, 4)`): + 1. Broadcast each leaf of `args_pytree` along a new leading axis of size `B`. + 2. `jax.jit(jax.vmap(func)).lower(parameters).compile()`. + 3. Read `compiled.memory_analysis().peak_memory_in_bytes` (preferred, on + modern JAX) or fall back to `output_size_in_bytes + temp_size_in_bytes` + (older JAX). +2. Fit a linear model: `peak_mb ≈ overhead + B * per_replica`. +3.
`recommend_batch_size` returns + `floor((budget - safety_factor * overhead) / (safety_factor * per_replica))`, + capped at `max_batch` (default 64). + +**Why `peak_memory_in_bytes` and not `output + temp`?** Those are sequential +phases — peak is the actual maximum simultaneous allocation including XLA +rematerialisation. Summing output+temp double-counts buffer reuse and over- +reports memory. Peak is what XLA actually allocates on device. + +## Methodology — A100 80 GB budget + +- Hard ceiling: **80 GB** on the RAL A100s. +- Soft budget (default): **65 GB**. Leaves ~15 GB for JAX runtime overhead, + CUDA driver allocations, allocator fragmentation, and per-call activation + slack that static analysis doesn't fully account for. +- **Safety factor: 1.15×** on `per_replica_mb`. The XLA static estimate + typically under-counts the real runtime peak by ~10-15% on complex graphs. +- Cap: **64**. XLA compile time scales superlinearly with batch on some + cells (notably `delaunay`, due to scipy-Delaunay-via-`pure_callback` + inflating the XLA graph). 64 is roughly where diminishing returns kick in + for production sampling. + +All defaults are configurable via kwargs on `recommend_batch_size` / +`probe_vmap_memory` if a different device or workload needs different limits. + +## Batch_sizes selection per cell + +XLA recompiles for each new batch_size (no cache reuse). Compile cost varies: + +- **mge / pixelization** — ~10 s/compile. Use a multi-point fit: + `batch_sizes=(1, 4, 16)` catches the ~8/16/32 rematerialisation phase + transitions JAX may exhibit. +- **delaunay** — 10-30 min/compile on big graphs (scipy.Delaunay via + `pure_callback` bloats the XLA graph). Use single-point: + `batch_sizes=(1,)` and accept linear extrapolation. If the chosen batch + OOMs at run time, manually re-probe at a smaller batch. + +## Adding a new instrument + +1. Add an entry to the appropriate `INSTRUMENTS` dict in + `instruments/{imaging,interferometer}.py` (the `simulators/*` modules + only re-export it). +2.
Create a probe SLURM script at + `z_projects/profiling/hpc/batch_gpu/probe_vmap__` + (clone any existing one). +3. Run the probe; pull the resulting JSON to + `output/runtime///vmap_probe.json`. +4. Read `recommended_batch_size` and add the row to `VMAP_BATCH` in + `vram/config.py`. +5. Re-run the regular profile to confirm the chosen batch holds at + steady state (vmap completes, doesn't OOM). diff --git a/vram/__init__.py b/vram/__init__.py new file mode 100644 index 0000000..2724d07 --- /dev/null +++ b/vram/__init__.py @@ -0,0 +1,32 @@ +"""VRAM / vmap-batch utilities for the likelihood profiling sweep. + +Two responsibilities live here: + +1. **Probing** — measure the per-replica VRAM cost of a vmapped likelihood + function on a given device, so we can pick the largest batch_size that + fits the device's memory budget. See ``vram.probe``. + +2. **Configuration** — the curated table of per-(dataset, model, instrument) + batch sizes derived from the probe results. Runtime cell scripts import + ``vram.vmap_batch_for(...)`` to look up the production batch size for + their cell. See ``vram.config``. + +See ``vram/README.md`` for methodology and how to extend. +""" + +from vram.config import VMAP_BATCH, vmap_batch_for +from vram.probe import ( + ProbeResult, + probe_vmap_memory, + recommend_batch_size, + write_probe_json, +) + +__all__ = [ + "VMAP_BATCH", + "vmap_batch_for", + "ProbeResult", + "probe_vmap_memory", + "recommend_batch_size", + "write_probe_json", +] diff --git a/vram/config.py b/vram/config.py new file mode 100644 index 0000000..d8a81c9 --- /dev/null +++ b/vram/config.py @@ -0,0 +1,122 @@ +"""Per-(dataset, model, instrument) vmap batch_size table for A100 80 GB. + +Populated empirically from ``vram.probe`` runs on the RAL A100 cluster +(2026-05-24). 
Each value is the recommended batch_size for +``jax.jit(jax.vmap(likelihood))`` on an NVIDIA A100 80 GB PCIe, with: + +- 65 GB effective VRAM budget (15 GB headroom for JAX runtime, driver + allocations, and fragmentation). +- 1.15× safety factor on the per-replica static memory estimate. +- Cap at 64 for XLA compile-time tractability. + +A value of ``None`` means vmap is **intentionally skipped or blocked** +for that cell: + +- Datacube cells: natural batching axis is "channels", not "parameters". +- Interferometer mge at ALMA+ scale: inherently blocked. MGE's mapping + matrix is fully dense (every Gaussian maps to every pixel), so the + per-call NUFFT cost is O(N_vis × N_src) and can't use the + sparse-operator shortcut. ``transform_mapping_matrix`` chunking would + cap per-chunk memory but per-call time remains prohibitive at 1M+ vis. +- Interferometer pixelization at ALMA+ scale: blocked on + ``transform_mapping_matrix`` not being chunked. Pixelization's mapping + IS sparse (localized mesh) and could eventually use the + sparse-operator path (like delaunay); just not implemented yet. + +Keys are ``(dataset, model, instrument)`` tuples — flat for easy lookup. + +To add a new instrument: + 1. Run the probe SLURM script for it (see ``vram/README.md``). + 2. Read the ``recommended_batch_size`` from the probe JSON. + 3. Add the row below. + 4. Re-run the regular profile to confirm the chosen batch holds at + steady state. +""" + +from __future__ import annotations + +from typing import Optional + +VMAP_BATCH: dict[tuple[str, str, str], Optional[int]] = { + # ========================================================================= + # Imaging cells — 4 instruments, 3 cells. + # Per-replica cost dominated by mapping_matrix (n_mask × n_source × 8 bytes). + # AO (700-px mask) is the most constrained; euclid (70-px) the cheapest. 
+ # ========================================================================= + # + # delaunay (1500-node Hilbert mesh) + # NOTE: probe-recommended sizes halved for hst/jwst/ao after cuFFT + # scratch-allocator failures at the probe-predicted batch. The static + # memory_analysis() doesn't account for cuFFT batched-plan scratch. + ("imaging", "delaunay", "euclid"): 64, # 270 MB / replica — probe OK + ("imaging", "delaunay", "hst"): 16, # 922 MB / replica — probe said 62, cuFFT failed + ("imaging", "delaunay", "jwst"): 8, # 2,415 MB / replica — probe said 23, cuFFT failed + ("imaging", "delaunay", "ao"): 1, # 17,485 MB / replica — probe said 3, OOM at 3 + # + # pixelization (35×35 = 1225-node rectangular mesh) + ("imaging", "pixelization", "euclid"): 64, # 273 MB / replica — probe OK + ("imaging", "pixelization", "hst"): 16, # 931 MB / replica — probe said 62, cuFFT failed + ("imaging", "pixelization", "jwst"): 8, # 2,428 MB / replica — probe said 23, cuFFT failed + ("imaging", "pixelization", "ao"): 1, # 17,537 MB / replica — probe said 3 + # + # mge (~25 analytical Gaussians — small, constant per-replica cost) + ("imaging", "mge", "euclid"): 64, # 6 MB / replica + ("imaging", "mge", "hst"): 64, # 16 MB / replica + ("imaging", "mge", "jwst"): 64, # 42 MB / replica + ("imaging", "mge", "ao"): None, # 296 MB / replica — vmap CORRECTNESS bug: 3 distinct log_ev in batch=64. Separate investigation. + # + # ========================================================================= + # Interferometer cells — 4 instruments, 3 cells (mge/pix blocked at ALMA+). + # Delaunay uses the W-Tilde sparse-operator path (per-call cost is + # mask-FFT-dominated, NOT visibility-count-dominated). mge/pixelization + # use the full NUFFT mapping matrix (blocked at 1M+ vis — see note above). 
+ # ========================================================================= + # + # delaunay (1000-node Hilbert mesh, sparse-operator path) + ("interferometer", "delaunay", "sma"): 64, # 92 MB / replica — probe OK + ("interferometer", "delaunay", "alma"): 64, # 322 MB / replica — probe OK + ("interferometer", "delaunay", "alma_high"): 16, # 1,243 MB / replica — probe said 46, OOM at runtime + ("interferometer", "delaunay", "jvla"): 3, # 7,689 MB / replica — probe said 7, OOM at runtime + # + # mge — sma only; alma+ INHERENTLY blocked (dense model, O(N_vis × N_src)). + ("interferometer", "mge", "sma"): 64, # 160 MB / replica + ("interferometer", "mge", "alma"): None, # blocked: dense mapping → 62 GB gather buffer at 1M vis + ("interferometer", "mge", "alma_high"): None, + ("interferometer", "mge", "jvla"): None, + # + # pixelization — sma only; alma+ blocked on unchunked transform_mapping_matrix. + # Pixelization mapping IS sparse — could eventually use sparse-operator path. + ("interferometer", "pixelization", "sma"): 64, # 93 MB / replica + ("interferometer", "pixelization", "alma"): None, # blocked: needs sparse-operator or chunked NUFFT mapping + ("interferometer", "pixelization", "alma_high"): None, + ("interferometer", "pixelization", "jvla"): None, + # + # ========================================================================= + # Datacube — intentionally skipped (parameter-axis vmap not meaningful; + # cube batching is over channels, handled by the per-channel loop). + # ========================================================================= + ("datacube", "delaunay", "sma"): None, + ("datacube", "delaunay", "alma"): None, + ("datacube", "delaunay", "alma_high"): None, + # + # ========================================================================= + # Point source — tiny per-replica cost, cap at 64. 
+ # ========================================================================= + ("point_source", "image_plane", "simple"): 64, # 3 MB / replica + ("point_source", "source_plane", "simple"): 64, # <1 MB / replica +} + + +def vmap_batch_for( + dataset: str, model: str, instrument: str +) -> Optional[int]: + """Return the per-(dataset, model, instrument) vmap batch_size for A100. + + Returns ``None`` when vmap is intentionally skipped (cube cells), + blocked (interferometer mge/pix at ALMA+), or when the cell hasn't + been probed yet. Callers should default to a small fallback (typically + 3) when ``None`` is returned for an un-probed cell, and skip vmap + entirely when ``None`` is returned for a known-blocked or known-skipped + cell. + """ + return VMAP_BATCH.get((dataset, model, instrument)) diff --git a/vram/probe.py b/vram/probe.py new file mode 100644 index 0000000..dc74f16 --- /dev/null +++ b/vram/probe.py @@ -0,0 +1,214 @@ +"""vmap-batch memory probe utilities. + +Given a JAX-traceable likelihood function and a parameter pytree, measure +the compiled program's VRAM footprint at two batch sizes and extrapolate +the maximum batch_size that fits a target device memory budget. + +The probe avoids running the steady-state timing loop — it only does the +lower + compile sequence required to read ``compiled.memory_analysis()``. +Two batch sizes (default 1 and 4) are needed so we can decompose the +total program memory into a per-replica linear coefficient + a constant +overhead term: + + memory_at_batch(B) ≈ constant_overhead + B * per_replica_memory + +Then the largest batch fitting in ``vram_budget`` is:: + + floor((vram_budget - constant_overhead) / per_replica_memory) + +In practice JAX rematerialisation can make the relationship sub-linear at +large batch, so we add a safety floor (``vram_budget_gb`` defaults to 65 GB +on an 80 GB A100) and cap at ``max_batch`` (default 64) to keep compile +times tractable.
+""" + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Sequence + + +@dataclass(frozen=True) +class ProbeSample: + """One probe sample: the compiled vmap program's memory cost at a given batch. + + ``peak_bytes`` is the field XLA reports as ``peak_memory_in_bytes`` on the + compiled program — the maximum simultaneous device allocation across the + full computation, including rematerialised activations. This is what we + extrapolate against for sizing the production batch. + + ``output_bytes`` + ``temp_bytes`` are retained for backward compat with older + JAX versions that don't expose ``peak_memory_in_bytes`` (we fall back to + the sum, which over-counts but is conservative). + """ + + batch_size: int + peak_bytes: int + output_bytes: int = 0 + temp_bytes: int = 0 + + @property + def peak_mb(self) -> float: + return self.peak_bytes / 1024**2 + + @property + def total_mb(self) -> float: + """Conservative upper bound (output+temp sum). Use ``peak_mb`` if present.""" + return (self.output_bytes + self.temp_bytes) / 1024**2 + + @property + def effective_mb(self) -> float: + """The memory number to extrapolate against — prefer peak, fall back to sum.""" + if self.peak_bytes > 0: + return self.peak_mb + return self.total_mb + + +@dataclass(frozen=True) +class ProbeResult: + """Outcome of a vmap memory probe across multiple batch sizes.""" + + dataset: str + model: str + instrument: str + samples: list[ProbeSample] = field(default_factory=list) + + @property + def per_replica_mb(self) -> float: + """Linear coefficient of (peak) program memory vs batch_size. + + Computed from the two extreme samples. If only one sample exists, + we fall back to ``effective_mb / batch_size`` (assumes zero constant + overhead — conservative for the upper bound). 
+ """ + if len(self.samples) < 2: + s = self.samples[0] + return s.effective_mb / s.batch_size + s_lo, s_hi = self.samples[0], self.samples[-1] + return (s_hi.effective_mb - s_lo.effective_mb) / ( + s_hi.batch_size - s_lo.batch_size + ) + + @property + def constant_overhead_mb(self) -> float: + """Intercept of the program-memory-vs-batch linear fit.""" + if len(self.samples) < 2: + return 0.0 + s_lo = self.samples[0] + return s_lo.effective_mb - s_lo.batch_size * self.per_replica_mb + + +def probe_vmap_memory( + func, + args_pytree, + batch_sizes: Sequence[int] = (1, 4), + *, + dataset: str = "", + model: str = "", + instrument: str = "", +) -> ProbeResult: + """JIT-vmap ``func`` at each batch size, read ``compiled.memory_analysis()``. + + ``func`` must accept a pytree of parameters whose leaves are JAX arrays. + ``args_pytree`` is the single-replica pytree (NOT pre-batched); we + broadcast it to each batch size internally. + + Reads ``peak_memory_in_bytes`` when available (post-JAX-0.4.X) — this is + the maximum simultaneous device allocation including rematerialised + activations, and is the correct number to extrapolate against. Falls back + to ``output_size_in_bytes + temp_size_in_bytes`` for older JAX versions + (an over-estimate, but conservative). + + For cells with expensive compile (delaunay: 10-30 min/batch), call with + ``batch_sizes=(1,)`` and accept single-point extrapolation. For cheap-compile + cells (mge / pixelization: ~10 s/batch), use ``(1, 4, 16)`` for a multi-point + fit that catches XLA rematerialisation non-linearity. 
+ """ + import jax + import jax.numpy as jnp + + samples: list[ProbeSample] = [] + for B in batch_sizes: + parameters = jax.tree_util.tree_map( + lambda leaf: jnp.broadcast_to(leaf, (B, *leaf.shape)), + args_pytree, + ) + vmapped = jax.jit(jax.vmap(func)) + lowered = vmapped.lower(parameters) + compiled = lowered.compile() + mem = compiled.memory_analysis() + peak_bytes = int(getattr(mem, "peak_memory_in_bytes", 0)) + samples.append( + ProbeSample( + batch_size=int(B), + peak_bytes=peak_bytes, + output_bytes=int(mem.output_size_in_bytes), + temp_bytes=int(mem.temp_size_in_bytes), + ) + ) + return ProbeResult( + dataset=dataset, model=model, instrument=instrument, samples=samples + ) + + +def recommend_batch_size( + probe: ProbeResult, + *, + vram_budget_gb: float = 65.0, + safety_factor: float = 1.15, + max_batch: int = 64, +) -> int: + """Recommend the largest batch_size that fits ``vram_budget_gb`` on device. + + Computed by linear extrapolation of (peak) memory vs batch_size from the + probe samples; capped at ``max_batch`` to keep XLA compile time tractable. + + Defaults targeted at A100 80 GB: + - ``vram_budget_gb=65`` — leaves ~15 GB headroom for JAX runtime, CUDA driver + allocations, allocator fragmentation, and per-call activation slack that + static analysis doesn't account for. + - ``safety_factor=1.15`` — multiplier on the static peak estimate, per the + industry rule-of-thumb that XLA static analysis under-counts the real + runtime peak by ~10-15% on complex graphs. + - ``max_batch=64`` — compile time scales superlinearly with batch on some + cells (notably delaunay), and 64 is roughly where diminishing returns kick + in for production samplers. 
+ """ + budget_mb = vram_budget_gb * 1024 + per_replica = probe.per_replica_mb * safety_factor + overhead = probe.constant_overhead_mb * safety_factor + if per_replica <= 0: + return max_batch + raw = int((budget_mb - overhead) / per_replica) + return max(1, min(raw, max_batch)) + + +def write_probe_json( + probe: ProbeResult, + recommended_batch_size: int, + output_path: Path, + *, + vram_budget_gb: float = 65.0, + safety_factor: float = 1.15, + max_batch_cap: int = 64, + extra: dict | None = None, +) -> None: + """Serialise probe samples + recommendation to JSON.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "dataset": probe.dataset, + "model": probe.model, + "instrument": probe.instrument, + "samples": [asdict(s) for s in probe.samples], + "per_replica_mb": round(probe.per_replica_mb, 3), + "constant_overhead_mb": round(probe.constant_overhead_mb, 3), + "vram_budget_gb": vram_budget_gb, + "safety_factor": safety_factor, + "max_batch_cap": max_batch_cap, + "recommended_batch_size": recommended_batch_size, + } + if extra is not None: + payload.update(extra) + output_path.write_text(json.dumps(payload, indent=2))