Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hpc/batch_gpu/error/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.gitignore
!.gitignore
2 changes: 2 additions & 0 deletions hpc/batch_gpu/output/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.gitignore
!.gitignore
49 changes: 49 additions & 0 deletions hpc/batch_gpu/submit_imaging_mge_a100_hst_fp64
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/bin/bash -l
#
# A100 first-class search profiling: searches/nautilus/imaging/mge × hst × fp64.
#
# Drives af.Nautilus end-to-end (visualization, samples I/O, search.summary)
# on the HST imaging MGE model from the autolens_profiling/searches package.
# Mirrors the resource budget of the sibling likelihood profiling submit
# (z_projects/profiling/hpc/batch_gpu/submit_imaging_mge_a100_hst_fp64) but
# allocates more wall time because a first-class fit runs the full Nautilus
# convergence loop, not a one-shot likelihood evaluation.

#SBATCH -J search_nautilus_imaging_mge_hst_fp64
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=64gb
#SBATCH --time=2:00:00
#SBATCH -o output/output.%A.out
#SBATCH -e error/error.%A.err
#SBATCH --mail-type=END,FAIL
#SBATCH --mail-user=james.w.nightingale@durham.ac.uk

export AP_ROOT=/mnt/ral/jnightin/autolens_profiling
source /mnt/ral/jnightin/PyAutoNSS/PyAutoNSS/bin/activate

export JAX_PLATFORM_NAME=cuda
export JAX_PLATFORMS=cuda,cpu
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export JAX_ENABLE_X64=True
export NUMBA_CACHE_DIR=/tmp/numba_cache
export MPLCONFIGDIR=/tmp/matplotlib

nvidia-smi

echo "=========================================="
date
echo "Cell: searches/nautilus/imaging/mge"
echo "Instrument: hst"
echo "Precision: fp64"

cd $AP_ROOT
python3 searches/nautilus/imaging/mge.py \
--instrument hst \
--config-name hpc_a100_fp64 \
--output-dir $AP_ROOT/results/searches/nautilus/imaging/mge/hst

echo "Finished."
date
185 changes: 136 additions & 49 deletions searches/README.md
Original file line number Diff line number Diff line change
@@ -1,72 +1,159 @@
# searches
# `searches/` — first-class search profiling

Sampler / search profiling for the PyAutoLens HST MGE lens-modelling likelihood. Each subfolder drives a single sampler family directly against the real likelihood — bypassing `af.NonLinearSearch` — so the per-sampler convergence characteristics (wall time, likelihood evaluations, posterior ESS, evals/time to ML) can be compared on identical footing.
This section profiles **first-class PyAutoFit search objects** end-to-end:
`af.Nautilus` today, with the registry shape ready for `af.DynestyStatic`,
`af.BlackJAXNUTS`, `af.Emcee`, etc. Unlike `likelihood_runtime/` (which
profiles `analysis.log_likelihood_function` in isolation), every cell here
runs `search.fit(model=model, analysis=analysis)` — so visualization,
samples I/O, `samples_info.json`, latent variables, and every other piece
of PyAutoFit machinery is exercised and measured.

## Why bypass `af.NonLinearSearch`?
## Design

`af.NonLinearSearch` adds caching, multi-process forking, output formatting, and result hierarchies that are valuable for production fits but obscure the underlying sampler's cost. The scripts in this section call the sampler library directly and instrument every likelihood evaluation through a shared `MLTracker`. The result is a clean apples-to-apples comparison of:
| Dimension | Values |
|----------------|---------------------------------------------------------------------------|
| Sampler | `nautilus` (more to come via `_samplers.SAMPLER_BUILDERS`) |
| Dataset class | `imaging`, `interferometer`, `point_source`, `datacube` |
| Model type | `mge`, `pixelization`, `delaunay`, `image_plane`, `source_plane` |
| Instrument | per-dataset-class (HST/Euclid/JWST/AO; SMA/ALMA/ALMA-high/JVLA; simple) |
| Hardware | `local_cpu`, `local_gpu`, `hpc_a100` (external dispatch) |
| Precision | `fp64`, `mp` (mixed precision via `al.Settings(use_mixed_precision=...)`) |

- Wall time and likelihood-evaluation count to **Nautilus's default convergence** (`n_eff=10000`, `f_live=0.01`).
- Per-evaluation likelihood cost (NumPy baseline vs JAX-JIT'd path).
- Evals-to-ML and time-to-ML — the eval index and wall time at which the running max log L first came within 1 nat of the final maximum.

## Shared helpers

| File | Role |
|------|------|
| [`_setup.py`](./_setup.py) | Builds the HST imaging dataset, the MGE + Isothermal + ExternalShear lens model with an MGE source bulge, and the `AnalysisImaging` object. The dataset, mask, and model mirror the reference setup in [`likelihood/imaging/mge.py`](../likelihood/imaging/mge.py) so likelihood values are directly comparable across the two sections. |
| [`_metrics.py`](./_metrics.py) | `MLTracker` — records the log-likelihood and wall time of every evaluation, computes evals-to-ML and time-to-ML headline numbers. Also offers `MLTracker.from_log_l_history` for samplers that JIT their likelihood and only expose log-L per dead/live point post hoc. |

## Supported samplers

| Sampler | Folder | Status | Notes |
|---------|--------|--------|-------|
| Nautilus | [`nautilus/`](./nautilus/README.md) | ✓ profiled | Both NumPy (`simple.py`) and JAX-JIT (`jax.py`) variants. |
| Dynesty | _planned_ | not yet mirrored | Static nested sampling; reference scripts at `autolens_workspace_developer/searches_minimal/dynesty_simple.py`. |
| Emcee | _planned_ | not yet mirrored | Affine-invariant ensemble MCMC. |
| BlackJAX (NUTS, SMC) | _planned_ | not yet mirrored | Pure-JAX HMC family. Gradient pathology surfaced in upstream `sweep_findings.md`; HMC viability depends on first fixing NaN-gradient hot spots. |
| NumPyro (ESS) | _planned_ | not yet mirrored | Ensemble slice sampler under JAX. |
| PocoMC | _planned_ | not yet mirrored | Preconditioned Monte Carlo. |
| NSS (simple, jit, grad) | _planned_ | not yet mirrored | Nested slice sampler; `nss_jit.py` shows VRAM ceiling on consumer GPUs (see `sweep_findings.md`). |
| LBFGS | _planned_ | not yet mirrored | Not a sampler; serves as the maximum-likelihood reference point. |

Each row above corresponds to one or more scripts under `autolens_workspace_developer/searches_minimal/`; the mirror migration here under their own follow-up prompts.

## Versioned artifacts

Each script writes a JSON + PNG pair to:
Layout:

```
results/searches/<sampler>/<script>_summary_v<al.__version__>.{json,png}
searches/
README.md # this file
_setup.py # dataset/model/analysis dispatchers
_samplers.py # sampler registry + per-(ds, model) n_live
_metrics.py # viz wall-time interception + result reader
_runner.py # shared driver (every leaf calls run_search)
sweep.py # matrix driver, resume-by-default
aggregate.py # comparison.json + comparison.png per cell
nautilus/
imaging/{mge, pixelization, delaunay}.py
interferometer/{mge, pixelization, delaunay}.py
point_source/{image_plane, source_plane}.py
datacube/delaunay.py
```

The JSON carries the structured timings + sampler config + best-fit summary. The PNG is a bar chart of the headline timings (wall time, time per eval, time to ML; plus JIT compile time on JAX scripts).
## Key design choices

**First-class only.** No more wrapping `nautilus.Sampler` directly. The
old `simple.py` / `jax.py` scripts are deleted. Every cell goes through
`af.Nautilus.fit(model, analysis)`, so visualization, output writes,
sample I/O, and latent-variable computation are part of the profile.

**SLaM-matched `n_live`.** Per `autolens_workspace/scripts/guides/modeling/
slam_start_here.py`: MGE / point-source / parametric phases use
`n_live=200` (matches `source_lp[1]`); pixelization / Delaunay phases
use `n_live=150` (matches `source_pix[1]`).

**`number_of_cores=1` always.** This profile measures per-evaluation
end-to-end cost. Production scaling via `number_of_cores > 1` is a
separate axis a future sweep can introduce.

**JAX rows force `force_x1_cpu=True` and `use_jax_vmap=True`.** This is
mandatory: `nautilus.Sampler` forking under multiprocessing corrupts
JAX state. The trade-off is one batched evaluation per Nautilus step.

**Visualization wall-time is split out.** `_metrics.attach_viz_timer`
wraps every visualize-family hook on the analysis (`visualize`,
`visualize_combined`, `visualize_before_fit`,
`visualize_before_fit_combined`) plus the search's `plot_results`. The
JSON reports `total_wall_s`, `viz_wall_s` and the derived
`sampler_wall_s = total_wall_s - viz_wall_s` so you can ask both "how
long did the full first-class fit take?" and "how much was viz?".

**`force_pickle_overwrite=True` on every search.** Defeats the
`.completed`-file resume that would otherwise return cached results
the second time you run the same `path_prefix`. Combined with
unique-per-(sampler, ds, model, instrument, config) `path_prefix`, this
keeps repeated sweep runs honest.

## Datacube multi-channel fitting

`datacube/delaunay.py` fits `_DATACUBE_N_CHANNELS` (default 4) identical
interferometer channels via `af.FactorGraphModel`. Each channel becomes
its own `al.AnalysisInterferometer`, wrapped in an `af.AnalysisFactor`
paired with `model.copy()`, then combined under a single global model —
the same pattern documented in
`autolens_workspace/scripts/multi/modeling.py`. The N channels are
identical copies of the per-instrument dataset; the profile measures
cube-cost scaling, not band-wavelength variation.

To change the channel count, edit `_DATACUBE_N_CHANNELS` in `_setup.py`
(34 matches the existing ALMA cube fiducial; 4 keeps profiling
turnaround sane).

## What this *doesn't* profile (yet)

- **Pool scaling.** `number_of_cores > 1` sweeps are future work.
- **Adapt-image regeneration across phases.** Pixelization / Delaunay
cells use a truth-derived `lensed_source.fits` cached next to the
dataset. Production SLaM regenerates this between phases.
- **A100 dispatch.** The local sweep generates only CPU and laptop-GPU
rows. The `hpc_a100_fp64` / `hpc_a100_mp` config names exist in
`sweep.py` for parity with `likelihood_runtime/`; the actual dispatch
to RAL HPC happens externally (same mechanism as the likelihood
sweep).
- **Samplers other than Nautilus.** The registry is in place; adding
`dynesty`, `blackjax_nuts`, `emcee`, etc. is one function per sampler
in `_samplers.py`.

## Running

Single cell (CPU NumPy, fastest path):

Old versions are retained alongside new ones; Phase 4's dashboard surfaces the latest per axis.

## Running a script
```bash
python searches/nautilus/imaging/mge.py \
--instrument hst --config-name local_cpu_fp64
```

From the repo root (cwd matters because `_setup.build_dataset()` resolves `dataset/imaging/hst/` relative to the repo root via `Path(__file__).resolve().parent.parent`):
Single cell (laptop GPU, JAX-vmap):

```bash
cd autolens_profiling
python searches/nautilus/simple.py
python searches/nautilus/jax.py
JAX_PLATFORM_NAME=cuda JAX_PLATFORMS=cuda,cpu \
XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 \
python searches/nautilus/imaging/mge.py \
--instrument hst --config-name local_gpu_fp64
```

Or as modules:
Full sweep (every cell × instrument × config) — warning, this is long:

```bash
python -m searches.nautilus.simple
python -m searches.nautilus.jax
python searches/sweep.py
```

Both invocation styles work — each script injects the repo root into `sys.path` before importing `searches._{setup,metrics}` for robustness.
Iteration sweep (one cell, one instrument, CPU only):

**Requirements:** `nautilus-sampler` for the Nautilus scripts (`pip install nautilus-sampler`). The JAX variant additionally needs a working JAX install.
```bash
python searches/sweep.py \
--only nautilus/imaging/mge \
--instrument hst \
--skip-gpu --skip-mp
```

**Codex / sandboxed runs:**
Aggregate post-sweep:

```bash
NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python searches/nautilus/simple.py
python searches/aggregate.py
```

## Output layout

```
results/searches/
<sampler>/<dataset_class>/<model>/<instrument>/
<config_name>.json # per-config headline metrics
<config_name>.png # per-config bar chart
<config_name>.log # subprocess stdout/stderr (sweep only)
comparison.json # cross-config aggregation (aggregate.py)
comparison.png # cross-config bar chart (aggregate.py)
```

The PyAutoFit search itself writes its own output (`samples.csv`,
`samples_info.json`, `search.summary`, visualization, ...) to the
autoconf `output_path` under `path_prefix=searches/<sampler>/
<dataset_class>/<model>/<instrument>`. The metric JSON+PNG above live
separately under `results/searches/`.
Loading
Loading