diff --git a/.gitignore b/.gitignore index 99c2801..5a9d230 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,6 @@ dataset/**/lensed_source.fits results/likelihood/ results/simulators/ results/searches/ + +# Latent profiling sweep outputs — regenerated by latent/sweep.py, not committed. +latent/results/ diff --git a/config/latent.yaml b/config/latent.yaml new file mode 100644 index 0000000..ea197f0 --- /dev/null +++ b/config/latent.yaml @@ -0,0 +1,14 @@ +# Workspace overrides for the library lensing latent toggles. Enables +# all 5 default latents so per-latent profile scripts in latent/imaging/ +# can isolate each one via runtime `conf.instance.push(...)` overrides +# (each script writes a single-key temp config before constructing the +# Analysis). +# +# This file's presence on the search path also means a one-off manual +# run of any latent-enabled fit from this repo produces the full +# catalogue without further config. +total_lens_flux_mujy: true +total_lensed_source_flux_mujy: true +total_source_flux_mujy: true +magnification: true +effective_einstein_radius: true diff --git a/latent/README.md b/latent/README.md new file mode 100644 index 0000000..4639f31 --- /dev/null +++ b/latent/README.md @@ -0,0 +1,129 @@ +# latent + +Per-latent runtime profiling for the PyAutoLens library latent-variable catalogue. The headline question is: + +> *"How long does each latent cost per call, and is the closure cache for `effective_einstein_radius` actually helping?"* + +Run scripts here — or, more commonly, the [`sweep.py`](sweep.py) driver — when you need to predict the cost of enabling a latent at scale (every-sample mode), compare CPU vs GPU throughput for the same latent, or measure first-call vs cached-call timing for the JAX-jit path through `LensCalc.einstein_radius_jit_from()`. + +For the latent **values** (correctness, not timing) see `autolens_workspace_test/scripts/latent/latent_variables_smoke.py`. The two are deliberately disjoint — this package times; that one validates. + +## Methodology + +Each script measures **one** latent in isolation: + +1. **Conf override.** Before constructing the Analysis, the script writes a temporary `config/latent.yaml` and calls `conf.instance.push(...)` to mark only the target latent as enabled. PyAutoFit's `compute_latent_samples` therefore dispatches just this one function, no contamination from the other four. +2. **Eager numpy baseline.** `AnalysisImaging(..., use_jax=False)` + a single call to `compute_latent_variables(parameters, model)`. This is the correctness reference and the worst-case (un-JIT'd) cost. +3. **Single-call JIT.** `jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model))` — records `lower`, `compile`, `first-call`, and `steady-state × 10` (steady-state averaged). The steady-state number is what production code in N-draws mode actually pays per draw. +4. **Vmap batched.** `jax.jit(jax.vmap(...))` with batch=3 — records the per-call cost as `batch_time / batch_size`. Vmap is the honest measurement: per-sample JIT on a concrete `ModelInstance` can constant-fold parts of the computation and read 20-30× faster than reality (see memory `feedback_jax_pure_callback_const_fold`). +5. **Closure cache delta** (effective_einstein_radius only). The LensCalc `_zero_contour_cache` at `autogalaxy/operate/lens_calc.py:1580-1586` memoises the `(eigen_fn, ZeroSolver)` pair. We time first-call and second-call on the same fresh `LensCalc` to surface the cache hit. Expected: second call is ~20-50% faster on numpy, much more on JIT (one full recompile avoided). + +All JAX timings use `block_until_ready()` to force synchronous measurement. Errors that arise from missing optional deps (`jax_zero_contour`, jax extras) are recorded in `jit_error` / `vmap_error` fields rather than failing the script — sweeps need to keep going past per-config failures. + +## The 6-config matrix + +Same matrix as `likelihood_runtime/`: + +| Config | Backend | Precision | Env / Flag | +|--------|---------|-----------|------------| +| `local_cpu_fp64` | CPU | fp64 | `JAX_PLATFORM_NAME=cpu JAX_PLATFORMS=cpu` | +| `local_cpu_mp` | CPU | mixed | same + `--use-mixed-precision` | +| `local_gpu_fp64` | RTX 2060 | fp64 | `JAX_PLATFORM_NAME=cuda JAX_PLATFORMS=cuda,cpu` | +| `local_gpu_mp` | RTX 2060 | mixed | same + `--use-mixed-precision` | +| `hpc_a100_fp64` | A100 (80 GB) | fp64 | SLURM-dispatched separately | +| `hpc_a100_mp` | A100 | mixed | same + `--use-mixed-precision` | + +The `cuda,cpu` listing on GPU configs is load-bearing — the `effective_einstein_radius` path needs a CPU device available even when the primary platform is CUDA, because `ZeroSolver` uses `jax.lax.cond` / `jax.lax.while_loop` that occasionally fall back to host evaluation under specific solver states. + +## What mixed precision means here + +For the four flux latents (`total_lens_flux_mujy`, `total_lensed_source_flux_mujy`, `total_source_flux_mujy`, `magnification`), mixed precision affects the upstream `AnalysisImaging.fit_from(instance)` call — specifically the PSF convolution and the mapping-matrix accumulation if the lens / source uses linear light profiles. The latent itself is just a reduction (sum + magzero conversion), so its direct cost is unchanged; mp moves the needle by making the fit cheaper to build per sample. + +For `effective_einstein_radius`, mixed precision is essentially a no-op — the `ZeroSolver` and the underlying deflection-field evaluation stay in fp64. The Einstein radius is sensitive enough that downcasting would compromise the zero-contour fidelity. + +Expect: mp helps the flux latents in proportion to the underlying fit cost (~5-20%), and is neutral on `effective_einstein_radius`. + +## Scripts + +| Script | Latent | Cost class | Notes | +|--------|--------|-----------|-------| +| `imaging/total_lens_flux_mujy.py` | `total_lens_flux_mujy` | trivial | Sum over `fit.galaxy_image_dict[fit.galaxies[0]].array` + magzero conversion. ~µs scale once JIT'd. | +| `imaging/total_lensed_source_flux_mujy.py` | `total_lensed_source_flux_mujy` | trivial | Same shape as above, source index `[-1]`. | +| `imaging/total_source_flux_mujy.py` | `total_source_flux_mujy` | low | Evaluates `tracer_linear_light_profiles_to_light_profiles.galaxies[-1].image_2d_from(grid=...)` — heavier than the dict-lookup latents because it computes a fresh source-plane image. ~10x the dict-lookup variants. | +| `imaging/magnification.py` | `magnification` | low | Composes the lensed and intrinsic source fluxes; cost is dominated by the `total_source_flux_mujy` recompute. | +| `imaging/effective_einstein_radius.py` | `effective_einstein_radius` | high | The marquee — JIT path through `LensCalc.einstein_radius_jit_from` → `ZeroSolver.zero_contour_finder` → `jnp.roll` shoelace. First-call dominated by JAX trace + ZeroSolver compile. Closure cache hit on second call removes the `_make_eigen_fn` rebuild. | + +## Driving the matrix — `sweep.py` and `aggregate.py` + +```bash +# All 5 latents, local CPU + GPU x fp64 + mp (8 configs total per latent) +python latent/sweep.py + +# Restrict to one latent during iteration +python latent/sweep.py --only imaging/effective_einstein_radius + +# Skip a backend +python latent/sweep.py --skip-gpu # CPU only +python latent/sweep.py --skip-cpu # GPU only + +# Aggregate per-config JSONs into a single comparison artefact +python latent/aggregate.py +``` + +Per-config JSONs land at `/imaging//.json`. The aggregator produces `/imaging//comparison.json` + `.png` with one row per config and the production cost (steady-state JIT for N-draws mode; eager numpy for the every-sample fallback). + +Default output root: `/autolens_workspace_developer/jax_profiling/results/latent/`. Mirrors the `likelihood_runtime/` precedent. + +## GPU practicalities + +If you're running locally on the RTX 2060 Max-Q (6 GB), set: + +```bash +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 +renice -n 10 -p $$ +``` + +The `XLA_PYTHON_CLIENT_MEM_FRACTION=0.5` cap stops JAX from preallocating most of the 6 GB VRAM (which makes the desktop unusable). The renice keeps the profiling job from stealing CPU from interactive work. The HPC A100 (80 GB) doesn't need either. + +## How to read the output + +Each per-config JSON looks like: + +```json +{ + "latent_key": "effective_einstein_radius", + "config_name": "local_cpu_fp64", + "use_mixed_precision": false, + "eager_value": 0.0, + "eager_time_s": 0.0175, + "closure_cache_first_call_s": 0.0170, + "closure_cache_second_call_s": 0.0133, + "jit_value": ..., + "jit_lower_s": ..., + "jit_compile_s": ..., + "jit_first_call_s": ..., + "jit_steady_state_s": ..., + "jit_error": null, + "vmap_per_call_s": ..., + "vmap_value": ..., + "vmap_error": null +} +``` + +Read in this order: + +1. **`jit_steady_state_s`** — the per-call cost in production. This is the headline for N-draws-from-PDF mode (`compute_latent_variables` runs `latent_draw_via_pdf_size=100` times per fit). If it's larger than `eager_time_s`, JIT isn't helping for this latent (typical for the trivial flux latents, where eager numpy is already a few µs and JIT compile dominates). +2. **`vmap_per_call_s` vs `jit_steady_state_s`** — should be similar. If vmap is dramatically faster, the JIT path is hitting a constant-fold and the single-call number is overstated. +3. **`closure_cache_first_call_s` vs `_second_call_s`** (Einstein radius only) — the cache delta on numpy. A small delta (<10%) means the cache is being used but the per-call work dominates (i.e. cache hit doesn't save much). A large delta (>30%) means the cache is the right optimisation. Zero delta means the cache isn't being hit at all — investigate. +4. **`jit_error` / `vmap_error`** — non-null means the optional JAX extras (`jax_zero_contour` for Einstein radius, others as appropriate) aren't installed. Numpy fallback timings remain valid; install the extras to fill in the JIT/vmap columns. + +The aggregator surfaces the production-cost column (steady-state JIT, or eager if JIT failed) as the headline, with first-call and compile times in adjacent columns for full provenance. + +## When the cache helps / hurts + +The `_zero_contour_cache` at `lens_calc.py:1580-1586` memoises by `(kind, pixel_scales, tol, max_newton)`. Two scenarios: + +- **Cache helps**: every sample in a posterior draw uses the same solver settings (the default), so every call after the first reuses the same `(eigen_fn, ZeroSolver)` pair. First-call pays the `_make_eigen_fn` cost (which is the dominant cost on the JIT path); every subsequent call is pure compute. Expect 30-60% cache-hit speedup on JIT; 15-25% on numpy. +- **Cache hurts** (rare): if downstream code constructs a fresh `LensCalc` per call (instead of reusing one), the cache never hits. The `total_source_flux_mujy` and `effective_einstein_radius` latents both do `LensCalc.from_mass_obj(fit.tracer)` per call — but within a single `compute_latent_variables` invocation, the LensCalc is constructed once and the cache lives on it. Across calls (different posterior draws), JAX's JIT compile cache picks up the slack. + +If you see the cache delta drop to zero across runs, suspect that the calling code is rebuilding LensCalcs between samples instead of reusing one. The current PyAutoLens dispatcher does it the right way — this is more relevant for hand-rolled custom Analysis subclasses (see memory `feedback_jax_closure_cache_busts` for the pattern). diff --git a/latent/aggregate.py b/latent/aggregate.py new file mode 100644 index 0000000..425ae3c --- /dev/null +++ b/latent/aggregate.py @@ -0,0 +1,302 @@ +"""Aggregate per-config JSONs for a swept latent cell into comparison.{json,png}. + +Reads every ``.json`` under a cell's output dir (see +``sweep.py``) and produces a single ``comparison.json`` and a +``comparison.png`` bar chart. + +The ``comparison.png`` plots first-call vs steady-state vs vmap-per-call time +per config on a log scale. For the ``effective_einstein_radius`` cell, the +chart additionally surfaces the ``closure_cache_first_call_s`` and +``closure_cache_second_call_s`` fields when present, so the LensCalc closure +warm-up cost is immediately visible. + +Usage:: + + # All latent cells under the default sweep output root + python latent/aggregate.py + + # One cell only + python latent/aggregate.py --cell imaging/effective_einstein_radius +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_WT_ROOT = _REPO_ROOT.parent +_DEFAULT_OUTPUT_ROOT = _WT_ROOT / "autolens_workspace_developer" / "jax_profiling" / "results" / "latent" + +# Canonical ordering of cells — drives auto-discovery sort order. +_CELLS: list[tuple[str, str]] = [ + ("imaging", "total_lens_flux_mujy"), + ("imaging", "total_lensed_source_flux_mujy"), + ("imaging", "total_source_flux_mujy"), + ("imaging", "magnification"), + ("imaging", "effective_einstein_radius"), +] + +# Stable config ordering across all comparison tables. +_CONFIG_ORDER = ( + "local_cpu_fp64", + "local_cpu_mp", + "local_gpu_fp64", + "local_gpu_mp", + "hpc_a100_fp64", + "hpc_a100_mp", +) + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) + p.add_argument( + "--output-root", + type=Path, + default=_DEFAULT_OUTPUT_ROOT, + help=f"Root output dir. Default: {_DEFAULT_OUTPUT_ROOT}", + ) + p.add_argument( + "--cell", + nargs="+", + default=None, + metavar="CLASS/LATENT", + help="Only aggregate these cells; default = auto-discover under --output-root.", + ) + return p.parse_args() + + +def _discover_cells(output_root: Path) -> list[tuple[str, str]]: + """Find every / subdir that contains at least one config JSON.""" + if not output_root.exists(): + return [] + + def _has_config_json(d: Path) -> bool: + return any(p.stem in _CONFIG_ORDER for p in d.glob("*.json")) + + found: list[tuple[str, str]] = [] + for cls_dir in sorted(output_root.iterdir()): + if not cls_dir.is_dir(): + continue + for latent_dir in sorted(cls_dir.iterdir()): + if latent_dir.is_dir() and _has_config_json(latent_dir): + found.append((cls_dir.name, latent_dir.name)) + return found + + +def _read_config(json_path: Path) -> dict: + data = json.loads(json_path.read_text()) + data.setdefault("config_name", json_path.stem) + return data + + +def _aggregate_cell(cell_dir: Path) -> dict: + configs: dict[str, dict] = {} + for json_path in sorted(cell_dir.glob("*.json")): + if json_path.name == "comparison.json": + continue + try: + configs[json_path.stem] = _read_config(json_path) + except Exception as exc: + sys.stderr.write(f" warn: failed to read {json_path}: {exc}\n") + + # Reorder by the canonical config order, then any extras at the end. + ordered: dict[str, dict] = {} + for name in _CONFIG_ORDER: + if name in configs: + ordered[name] = configs.pop(name) + for name, data in sorted(configs.items()): + ordered[name] = data + + return {"configs": ordered} + + +def _format_seconds(t: float | None) -> str: + if t is None or not np.isfinite(t): + return "—" + if t >= 1.0: + return f"{t:.2f}s" + if t >= 1e-3: + return f"{t * 1e3:.1f}ms" + return f"{t * 1e6:.0f}μs" + + +def _render_table(comparison: dict, cell_id: str) -> str: + lines = [f"=== {cell_id} ==="] + is_einstein = "effective_einstein_radius" in cell_id + if is_einstein: + header = ("config", "eager_time", "jit_first", "jit_steady", "vmap_per_call", + "cache_first", "cache_second") + else: + header = ("config", "eager_time", "jit_first", "jit_steady", "vmap_per_call") + rows = [header] + for name, cfg in comparison["configs"].items(): + eager_t = cfg.get("eager_time_s") + jit_first = cfg.get("jit_first_call_s") + jit_steady = cfg.get("jit_steady_state_s") + vmap = cfg.get("vmap_per_call_s") + if is_einstein: + cc_first = cfg.get("closure_cache_first_call_s") + cc_second = cfg.get("closure_cache_second_call_s") + rows.append(( + name, + _format_seconds(eager_t), + _format_seconds(jit_first), + _format_seconds(jit_steady), + _format_seconds(vmap), + _format_seconds(cc_first), + _format_seconds(cc_second), + )) + else: + rows.append(( + name, + _format_seconds(eager_t), + _format_seconds(jit_first), + _format_seconds(jit_steady), + _format_seconds(vmap), + )) + col_w = [max(len(r[i]) for r in rows) for i in range(len(rows[0]))] + for r in rows: + lines.append(" " + " ".join(s.ljust(w) for s, w in zip(r, col_w))) + return "\n".join(lines) + + +def _render_png(comparison: dict, cell_id: str, png_path: Path) -> None: + """Bar chart: first-call vs steady-state vs vmap-per-call, per config. + + For effective_einstein_radius, also plots closure_cache first/second call + as additional series when present. + """ + configs = comparison["configs"] + if not configs: + return + + config_names = list(configs.keys()) + is_einstein = "effective_einstein_radius" in cell_id + + # Collect timing series. + series: dict[str, list[float]] = { + "jit_first_call": [], + "jit_steady_state": [], + "vmap_per_call": [], + } + if is_einstein: + series["closure_cache_first"] = [] + series["closure_cache_second"] = [] + + for cname in config_names: + cfg = configs[cname] + series["jit_first_call"].append(cfg.get("jit_first_call_s", np.nan)) + series["jit_steady_state"].append(cfg.get("jit_steady_state_s", np.nan)) + series["vmap_per_call"].append(cfg.get("vmap_per_call_s", np.nan)) + if is_einstein: + series["closure_cache_first"].append( + cfg.get("closure_cache_first_call_s", np.nan) + ) + series["closure_cache_second"].append( + cfg.get("closure_cache_second_call_s", np.nan) + ) + + # Drop series that are entirely nan/zero — nothing to plot. + series = { + k: v for k, v in series.items() + if any(np.isfinite(x) and x > 0 for x in v) + } + if not series: + return + + n_cfgs = len(config_names) + n_series = len(series) + y = np.arange(n_cfgs) + bar_height = 0.8 / n_series + + cmap = plt.get_cmap("tab10") + label_map = { + "jit_first_call": "JIT first call", + "jit_steady_state": "JIT steady-state", + "vmap_per_call": "vmap per-call", + "closure_cache_first": "closure cache — first", + "closure_cache_second": "closure cache — second", + } + + fig, ax = plt.subplots(figsize=(10, max(3, 0.45 * n_cfgs + 1.5))) + for j, (key, vals) in enumerate(series.items()): + arr = np.array(vals, dtype=float) + offset = (j - (n_series - 1) / 2) * bar_height + ax.barh( + y + offset, + arr, + height=bar_height, + label=label_map.get(key, key), + color=cmap(j % cmap.N), + edgecolor="white", + ) + + ax.set_yticks(y) + ax.set_yticklabels(config_names, fontsize=9) + ax.invert_yaxis() + ax.set_xscale("log") + ax.set_xlabel("Time per call (s, log scale)") + ax.set_title(f"{cell_id} — latent timings", fontsize=11, fontweight="bold") + ax.legend(loc="lower right", fontsize=8) + ax.grid(True, axis="x", linestyle=":", alpha=0.5) + fig.tight_layout() + fig.savefig(png_path, dpi=150) + plt.close(fig) + + +def main() -> int: + args = _parse_args() + + if args.cell: + cells: list[tuple[str, str]] = [] + for spec in args.cell: + parts = spec.split("/") + if len(parts) != 2: + sys.stderr.write( + f"bad --cell argument: {spec!r} (expected class/latent)\n" + ) + return 2 + cells.append((parts[0], parts[1])) + else: + cells = _discover_cells(args.output_root) + + if not cells: + sys.stderr.write(f"no cells found under {args.output_root}\n") + return 1 + + for (cls, latent) in cells: + cell_id = f"{cls}/{latent}" + cell_dir = args.output_root / cls / latent + if not cell_dir.exists(): + sys.stderr.write(f" skipping {cell_id}: dir missing\n") + continue + + comparison = _aggregate_cell(cell_dir) + if not comparison["configs"]: + sys.stderr.write(f" skipping {cell_id}: no per-config JSONs found\n") + continue + + comparison_path = cell_dir / "comparison.json" + png_path = cell_dir / "comparison.png" + comparison_path.write_text(json.dumps(comparison, indent=2, default=str)) + _render_png(comparison, cell_id, png_path) + + print(_render_table(comparison, cell_id)) + print(f" -> {comparison_path}") + print(f" -> {png_path}\n") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/latent/imaging/effective_einstein_radius.py b/latent/imaging/effective_einstein_radius.py new file mode 100644 index 0000000..d8c77a0 --- /dev/null +++ b/latent/imaging/effective_einstein_radius.py @@ -0,0 +1,238 @@ +""" +Latent profiling: effective_einstein_radius +============================================= + +Profiles the effective_einstein_radius computation in isolation. + +This is the most expensive of the five default latents. It solves for the +tangential critical curve via ``LensCalc.einstein_radius_jit_from``, which +wraps ``jax_zero_contour.ZeroSolver`` — a marching-squares contour finder +that uses ``jax.lax.cond`` / ``jax.lax.while_loop`` for early termination. +The while-loop makes this latent incompatible with ``jax.vmap``; the +production code path therefore uses ``LATENT_BATCH_MODE='jit'`` (loop over +samples in Python, one JIT call per sample) rather than a batched vmap. + +Cache behaviour: ``LensCalc`` maintains a ``_zero_contour_cache`` dict keyed +on ``(kind, pixel_scales, tol, max_newton)``. On the first call the +``(f, ZeroSolver)`` pair is built and cached; subsequent calls reuse it and +therefore hit JAX's XLA compile cache. This script explicitly surfaces the +first-call vs second-call cost by constructing two separate ``LensCalc`` +instances (each starting with an empty cache) and timing them independently. +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import tempfile + +import autofit as af +import autolens as al +from autolens import fixtures +from autoconf import conf +from autolens.analysis.latent import LATENT_FUNCTIONS + +# AUTOLENS_PROFILING_SMOKE=1 short-circuit. +import sys as _sys, os as _os +if _os.environ.get("AUTOLENS_PROFILING_SMOKE") == "1": + print(f"[smoke] {__file__}: imports OK; exiting.") + _sys.exit(0) + +LATENT_KEY = "effective_einstein_radius" + + +def _push_single_latent_config(latent_key: str) -> Path: + """Write a temp config dir with only latent_key enabled and push it.""" + tmpdir = Path(tempfile.mkdtemp(prefix="latent_cfg_")) + yaml_lines = [ + f"{k}: {'true' if k == latent_key else 'false'}" + for k in LATENT_FUNCTIONS + ] + (tmpdir / "latent.yaml").write_text("\n".join(yaml_lines) + "\n") + conf.instance.push(str(tmpdir)) + return tmpdir + + +def _time_closure_cache(tracer, dataset, xp=np): + """Time first vs second call of the effective_einstein_radius latent function. + + Constructs a fresh LensCalc for each call so the closure cache starts empty, + isolating the first-build vs cache-hit cost difference. + + Returns (first_call_s, second_call_s) floats, or (nan, nan) on error. + """ + from autolens.analysis.latent import effective_einstein_radius + from autolens.imaging.fit_imaging import FitImaging + + # Build a minimal fit object that effective_einstein_radius can use. + # We need fit.tracer and fit.dataset.grids.lp. + try: + # Config already pushed by main(); the pushed config has this latent enabled. + + analysis = al.AnalysisImaging(dataset=dataset, use_jax=(xp is not np), magzero=25.0) + + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + params = model.physical_values_from_prior_medians + instance = model.instance_from_vector(vector=params) + fit = al.FitImaging( + dataset=dataset, + tracer=al.Tracer(galaxies=list(instance.galaxies)), + xp=xp, + ) + + # First call — fresh LensCalc, cold closure cache. + t0 = time.perf_counter() + _ = effective_einstein_radius(fit=fit, magzero=25.0, xp=xp) + first_call_s = time.perf_counter() - t0 + + # Second call — LensCalc on `fit.tracer` still has its cache populated + # from above. Construct a new LensCalc explicitly to test warm-cache path. + from autogalaxy.operate.lens_calc import LensCalc + lc2 = LensCalc.from_mass_obj(fit.tracer) + init_guess = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]) + # Warm the ZeroSolver cache on lc2 first (mirrors the cold first call). + if xp is not np: + _ = lc2.einstein_radius_jit_from(init_guess=init_guess) + t0 = time.perf_counter() + _ = lc2.einstein_radius_jit_from(init_guess=init_guess) + second_call_s = time.perf_counter() - t0 + else: + _ = lc2.einstein_radius_from(grid=dataset.grids.lp) + t0 = time.perf_counter() + _ = lc2.einstein_radius_from(grid=dataset.grids.lp) + second_call_s = time.perf_counter() - t0 + + return first_call_s, second_call_s + except Exception: + return float("nan"), float("nan") + + +def main(config_name: str, output_dir: Path, use_mixed_precision: bool) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + # Isolate THIS latent only via conf override + _push_single_latent_config(LATENT_KEY) + + dataset = fixtures.make_masked_imaging_7x7() + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False, magzero=25.0) + analysis_jax = al.AnalysisImaging(dataset=dataset, use_jax=True, magzero=25.0) + + params = jnp.array(model.physical_values_from_prior_medians) + + # === Eager numpy baseline === + t0 = time.perf_counter() + eager_values = analysis_np.compute_latent_variables(np.asarray(params), model) + eager_t = time.perf_counter() - t0 + eager_value = float(eager_values[0]) + + # === Closure cache first-call vs second-call (numpy path) === + closure_cache_first_call_s, closure_cache_second_call_s = _time_closure_cache( + tracer=None, dataset=dataset, xp=np + ) + + # === JIT compile + first call + steady-state === + fn = jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model)) + lower_t = compile_t = first_t = steady_t = float("nan") + jit_value = float("nan") + jit_error = None + try: + t0 = time.perf_counter() + lowered = fn.lower(params) + lower_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled = lowered.compile() + compile_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + first = compiled(params) + try: + jax.block_until_ready(first[0]) + except Exception: + pass + first_t = time.perf_counter() - t0 + jit_value = float(first[0]) + + steady_ts = [] + for _ in range(10): + t0 = time.perf_counter() + r = compiled(params) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + steady_ts.append(time.perf_counter() - t0) + steady_t = float(np.mean(steady_ts)) + except Exception as exc: + jit_error = repr(exc) + + # === vmap batched === + # NOTE: effective_einstein_radius uses jax.lax.while_loop / lax.cond under + # jax_zero_contour, which is incompatible with jax.vmap. We still attempt + # it so the sweep surfaces the error cleanly rather than silently skipping. + vmap_t = float("nan") + vmap_value = float("nan") + vmap_error = None + batch_size = 3 + batched = jnp.tile(params[None, :], (batch_size, 1)) + try: + vfn = jax.jit(jax.vmap(lambda p: analysis_jax.compute_latent_variables(p, model))) + warm = vfn(batched) + try: + jax.block_until_ready(warm[0]) + except Exception: + pass + t0 = time.perf_counter() + r = vfn(batched) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + vmap_t = (time.perf_counter() - t0) / batch_size + vmap_value = float(r[0][0]) + except Exception as exc: + vmap_error = repr(exc) + + record = { + "latent_key": LATENT_KEY, + "config_name": config_name, + "use_mixed_precision": use_mixed_precision, + "eager_value": eager_value, + "eager_time_s": eager_t, + "closure_cache_first_call_s": closure_cache_first_call_s, + "closure_cache_second_call_s": closure_cache_second_call_s, + "jit_value": jit_value, + "jit_lower_s": lower_t, + "jit_compile_s": compile_t, + "jit_first_call_s": first_t, + "jit_steady_state_s": steady_t, + "jit_error": jit_error, + "vmap_per_call_s": vmap_t, + "vmap_value": vmap_value, + "vmap_error": vmap_error, + } + out_path = output_dir / f"{config_name}.json" + out_path.write_text(json.dumps(record, indent=2)) + print(f"WROTE {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", required=True) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--use-mixed-precision", action="store_true") + args = parser.parse_args() + main(args.config_name, args.output_dir, args.use_mixed_precision) diff --git a/latent/imaging/magnification.py b/latent/imaging/magnification.py new file mode 100644 index 0000000..9a5a754 --- /dev/null +++ b/latent/imaging/magnification.py @@ -0,0 +1,165 @@ +""" +Latent profiling: magnification +================================= + +Profiles the magnification computation in isolation. + +The magnification latent is defined as the ratio of image-plane (lensed) +to source-plane (intrinsic) flux: ``total_lensed_source_flux_mujy / total_source_flux_mujy``. +The magzero conversion cancels in the ratio, so the result is dimensionless and +independent of the chosen zero-point. Under JIT the two flux paths both compile +into the same trace; expect this latent to cost roughly the sum of those two +components minus any shared subexpressions XLA can fuse. Because the division +of two scalar JIT outputs is trivially vectorisable, vmap is expected to work +cleanly for this latent. +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import tempfile + +import autofit as af +import autolens as al +from autolens import fixtures +from autoconf import conf +from autolens.analysis.latent import LATENT_FUNCTIONS + +# AUTOLENS_PROFILING_SMOKE=1 short-circuit. +import sys as _sys, os as _os +if _os.environ.get("AUTOLENS_PROFILING_SMOKE") == "1": + print(f"[smoke] {__file__}: imports OK; exiting.") + _sys.exit(0) + +LATENT_KEY = "magnification" + + +def _push_single_latent_config(latent_key: str) -> Path: + """Write a temp config dir with only latent_key enabled and push it.""" + tmpdir = Path(tempfile.mkdtemp(prefix="latent_cfg_")) + yaml_lines = [ + f"{k}: {'true' if k == latent_key else 'false'}" + for k in LATENT_FUNCTIONS + ] + (tmpdir / "latent.yaml").write_text("\n".join(yaml_lines) + "\n") + conf.instance.push(str(tmpdir)) + return tmpdir + + +def main(config_name: str, output_dir: Path, use_mixed_precision: bool) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + # Isolate THIS latent only via conf override + _push_single_latent_config(LATENT_KEY) + + dataset = fixtures.make_masked_imaging_7x7() + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False, magzero=25.0) + analysis_jax = al.AnalysisImaging(dataset=dataset, use_jax=True, magzero=25.0) + + params = jnp.array(model.physical_values_from_prior_medians) + + # === Eager numpy baseline === + t0 = time.perf_counter() + eager_values = analysis_np.compute_latent_variables(np.asarray(params), model) + eager_t = time.perf_counter() - t0 + eager_value = float(eager_values[0]) + + # === JIT compile + first call + steady-state === + fn = jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model)) + lower_t = compile_t = first_t = steady_t = float("nan") + jit_value = float("nan") + jit_error = None + try: + t0 = time.perf_counter() + lowered = fn.lower(params) + lower_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled = lowered.compile() + compile_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + first = compiled(params) + try: + jax.block_until_ready(first[0]) + except Exception: + pass + first_t = time.perf_counter() - t0 + jit_value = float(first[0]) + + steady_ts = [] + for _ in range(10): + t0 = time.perf_counter() + r = compiled(params) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + steady_ts.append(time.perf_counter() - t0) + steady_t = float(np.mean(steady_ts)) + except Exception as exc: + jit_error = repr(exc) + + # === vmap batched === + vmap_t = float("nan") + vmap_value = float("nan") + vmap_error = None + batch_size = 3 + batched = jnp.tile(params[None, :], (batch_size, 1)) + try: + vfn = jax.jit(jax.vmap(lambda p: analysis_jax.compute_latent_variables(p, model))) + warm = vfn(batched) + try: + jax.block_until_ready(warm[0]) + except Exception: + pass + t0 = time.perf_counter() + r = vfn(batched) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + vmap_t = (time.perf_counter() - t0) / batch_size + vmap_value = float(r[0][0]) + except Exception as exc: + vmap_error = repr(exc) + + record = { + "latent_key": LATENT_KEY, + "config_name": config_name, + "use_mixed_precision": use_mixed_precision, + "eager_value": eager_value, + "eager_time_s": eager_t, + "jit_value": jit_value, + "jit_lower_s": lower_t, + "jit_compile_s": compile_t, + "jit_first_call_s": first_t, + "jit_steady_state_s": steady_t, + "jit_error": jit_error, + "vmap_per_call_s": vmap_t, + "vmap_value": vmap_value, + "vmap_error": vmap_error, + } + out_path = output_dir / f"{config_name}.json" + out_path.write_text(json.dumps(record, indent=2)) + print(f"WROTE {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", required=True) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--use-mixed-precision", action="store_true") + args = parser.parse_args() + main(args.config_name, args.output_dir, args.use_mixed_precision) diff --git a/latent/imaging/total_lens_flux_mujy.py b/latent/imaging/total_lens_flux_mujy.py new file mode 100644 index 0000000..cbd3e22 --- /dev/null +++ b/latent/imaging/total_lens_flux_mujy.py @@ -0,0 +1,164 @@ +""" +Latent profiling: total_lens_flux_mujy +======================================== + +Profiles the total_lens_flux_mujy computation in isolation. + +This latent sums the image-plane flux of the lens galaxy (galaxies[0]) and +converts it to microjanskies via the AB-magnitude zero-point. It is the +cheapest of the three flux latents because it reads directly from +``fit.galaxy_image_dict`` without re-tracing through the source plane. The +dominant cost is the magzero-normalised AB-mag conversion, which is a small +handful of transcendental operations on a scalar. Under JIT the whole chain +compiles to a single fused kernel. +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import tempfile + +import autofit as af +import autolens as al +from autolens import fixtures +from autoconf import conf +from autolens.analysis.latent import LATENT_FUNCTIONS + +# AUTOLENS_PROFILING_SMOKE=1 short-circuit. +import sys as _sys, os as _os +if _os.environ.get("AUTOLENS_PROFILING_SMOKE") == "1": + print(f"[smoke] {__file__}: imports OK; exiting.") + _sys.exit(0) + +LATENT_KEY = "total_lens_flux_mujy" + + +def _push_single_latent_config(latent_key: str) -> Path: + """Write a temp config dir with only latent_key enabled and push it.""" + tmpdir = Path(tempfile.mkdtemp(prefix="latent_cfg_")) + yaml_lines = [ + f"{k}: {'true' if k == latent_key else 'false'}" + for k in LATENT_FUNCTIONS + ] + (tmpdir / "latent.yaml").write_text("\n".join(yaml_lines) + "\n") + conf.instance.push(str(tmpdir)) + return tmpdir + + +def main(config_name: str, output_dir: Path, use_mixed_precision: bool) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + # Isolate THIS latent only via conf override + _push_single_latent_config(LATENT_KEY) + + dataset = fixtures.make_masked_imaging_7x7() + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False, magzero=25.0) + analysis_jax = al.AnalysisImaging(dataset=dataset, use_jax=True, magzero=25.0) + + params = jnp.array(model.physical_values_from_prior_medians) + + # === Eager numpy baseline === + t0 = time.perf_counter() + eager_values = analysis_np.compute_latent_variables(np.asarray(params), model) + eager_t = time.perf_counter() - t0 + eager_value = float(eager_values[0]) + + # === JIT compile + first call + steady-state === + fn = jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model)) + lower_t = compile_t = first_t = steady_t = float("nan") + jit_value = float("nan") + jit_error = None + try: + t0 = time.perf_counter() + lowered = fn.lower(params) + lower_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled = lowered.compile() + compile_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + first = compiled(params) + try: + jax.block_until_ready(first[0]) + except Exception: + pass + first_t = time.perf_counter() - t0 + jit_value = float(first[0]) + + steady_ts = [] + for _ in range(10): + t0 = time.perf_counter() + r = compiled(params) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + steady_ts.append(time.perf_counter() - t0) + steady_t = float(np.mean(steady_ts)) + except Exception as exc: + jit_error = repr(exc) + + # === vmap batched === + vmap_t = float("nan") + vmap_value = float("nan") + vmap_error = None + batch_size = 3 + batched = jnp.tile(params[None, :], (batch_size, 1)) + try: + vfn = jax.jit(jax.vmap(lambda p: analysis_jax.compute_latent_variables(p, model))) + warm = vfn(batched) + try: + jax.block_until_ready(warm[0]) + except Exception: + pass + t0 = time.perf_counter() + r = vfn(batched) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + vmap_t = (time.perf_counter() - t0) / batch_size + vmap_value = float(r[0][0]) + except Exception as exc: + vmap_error = repr(exc) + + record = { + "latent_key": LATENT_KEY, + "config_name": config_name, + "use_mixed_precision": use_mixed_precision, + "eager_value": eager_value, + "eager_time_s": eager_t, + "jit_value": jit_value, + "jit_lower_s": lower_t, + "jit_compile_s": compile_t, + "jit_first_call_s": first_t, + "jit_steady_state_s": steady_t, + "jit_error": jit_error, + "vmap_per_call_s": vmap_t, + "vmap_value": vmap_value, + "vmap_error": vmap_error, + } + out_path = output_dir / f"{config_name}.json" + out_path.write_text(json.dumps(record, indent=2)) + print(f"WROTE {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", required=True) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--use-mixed-precision", action="store_true") + args = parser.parse_args() + main(args.config_name, args.output_dir, args.use_mixed_precision) diff --git a/latent/imaging/total_lensed_source_flux_mujy.py b/latent/imaging/total_lensed_source_flux_mujy.py new file mode 100644 index 0000000..1de3224 --- /dev/null +++ b/latent/imaging/total_lensed_source_flux_mujy.py @@ -0,0 +1,165 @@ +""" +Latent profiling: total_lensed_source_flux_mujy +================================================= + +Profiles the total_lensed_source_flux_mujy computation in isolation. + +This latent measures the image-plane integrated flux of the source galaxy +after gravitational lensing, reading from ``fit.galaxy_image_dict[tracer.galaxies[-1]]`` +and converting to microjanskies via the AB-magnitude zero-point. The lensed +flux is the observable that constrains both the lens mass model and source +luminosity; comparing it against total_source_flux_mujy gives the magnification. +Under JIT the path is a sum over a masked array followed by two scalar +transcendental operations — typically compiles to the same order of cost as +total_lens_flux_mujy. +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import tempfile + +import autofit as af +import autolens as al +from autolens import fixtures +from autoconf import conf +from autolens.analysis.latent import LATENT_FUNCTIONS + +# AUTOLENS_PROFILING_SMOKE=1 short-circuit. +import sys as _sys, os as _os +if _os.environ.get("AUTOLENS_PROFILING_SMOKE") == "1": + print(f"[smoke] {__file__}: imports OK; exiting.") + _sys.exit(0) + +LATENT_KEY = "total_lensed_source_flux_mujy" + + +def _push_single_latent_config(latent_key: str) -> Path: + """Write a temp config dir with only latent_key enabled and push it.""" + tmpdir = Path(tempfile.mkdtemp(prefix="latent_cfg_")) + yaml_lines = [ + f"{k}: {'true' if k == latent_key else 'false'}" + for k in LATENT_FUNCTIONS + ] + (tmpdir / "latent.yaml").write_text("\n".join(yaml_lines) + "\n") + conf.instance.push(str(tmpdir)) + return tmpdir + + +def main(config_name: str, output_dir: Path, use_mixed_precision: bool) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + # Isolate THIS latent only via conf override + _push_single_latent_config(LATENT_KEY) + + dataset = fixtures.make_masked_imaging_7x7() + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False, magzero=25.0) + analysis_jax = al.AnalysisImaging(dataset=dataset, use_jax=True, magzero=25.0) + + params = jnp.array(model.physical_values_from_prior_medians) + + # === Eager numpy baseline === + t0 = time.perf_counter() + eager_values = analysis_np.compute_latent_variables(np.asarray(params), model) + eager_t = time.perf_counter() - t0 + eager_value = float(eager_values[0]) + + # === JIT compile + first call + steady-state === + fn = jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model)) + lower_t = compile_t = first_t = steady_t = float("nan") + jit_value = float("nan") + jit_error = None + try: + t0 = time.perf_counter() + lowered = fn.lower(params) + lower_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled = lowered.compile() + compile_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + first = compiled(params) + try: + jax.block_until_ready(first[0]) + except Exception: + pass + first_t = time.perf_counter() - t0 + jit_value = float(first[0]) + + steady_ts = [] + for _ in range(10): + t0 = time.perf_counter() + r = compiled(params) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + steady_ts.append(time.perf_counter() - t0) + steady_t = float(np.mean(steady_ts)) + except Exception as exc: + jit_error = repr(exc) + + # === vmap batched === + vmap_t = float("nan") + vmap_value = float("nan") + vmap_error = None + batch_size = 3 + batched = jnp.tile(params[None, :], (batch_size, 1)) + try: + vfn = jax.jit(jax.vmap(lambda p: analysis_jax.compute_latent_variables(p, model))) + warm = vfn(batched) + try: + jax.block_until_ready(warm[0]) + except Exception: + pass + t0 = time.perf_counter() + r = vfn(batched) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + vmap_t = (time.perf_counter() - t0) / batch_size + vmap_value = float(r[0][0]) + except Exception as exc: + vmap_error = repr(exc) + + record = { + "latent_key": LATENT_KEY, + "config_name": config_name, + "use_mixed_precision": use_mixed_precision, + "eager_value": eager_value, + "eager_time_s": eager_t, + "jit_value": jit_value, + "jit_lower_s": lower_t, + "jit_compile_s": compile_t, + "jit_first_call_s": first_t, + "jit_steady_state_s": steady_t, + "jit_error": jit_error, + "vmap_per_call_s": vmap_t, + "vmap_value": vmap_value, + "vmap_error": vmap_error, + } + out_path = output_dir / f"{config_name}.json" + out_path.write_text(json.dumps(record, indent=2)) + print(f"WROTE {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", required=True) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--use-mixed-precision", action="store_true") + args = parser.parse_args() + main(args.config_name, args.output_dir, args.use_mixed_precision) diff --git a/latent/imaging/total_source_flux_mujy.py b/latent/imaging/total_source_flux_mujy.py new file mode 100644 index 0000000..0fcbc15 --- /dev/null +++ b/latent/imaging/total_source_flux_mujy.py @@ -0,0 +1,167 @@ +""" +Latent profiling: total_source_flux_mujy +========================================== + +Profiles the total_source_flux_mujy computation in isolation. + +This latent measures the intrinsic source-plane flux of the source galaxy, +unconvolved by lensing magnification. It reads from +``fit.tracer_linear_light_profiles_to_light_profiles`` so that linear light +profiles (whose intensities are solved by the inversion at fit time) contribute +their correct values. This property is a pass-through on non-linear fits, so +both the numpy and JAX code paths work uniformly. The source-plane image is +evaluated on the lens-plane grid, which is more expensive than reading a +pre-computed image from galaxy_image_dict: the light profile's +``image_2d_from`` call is inside the JIT trace. Expect compile times and +steady-state cost somewhat higher than the lensed flux latent. +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np + +import tempfile + +import autofit as af +import autolens as al +from autolens import fixtures +from autoconf import conf +from autolens.analysis.latent import LATENT_FUNCTIONS + +# AUTOLENS_PROFILING_SMOKE=1 short-circuit. +import sys as _sys, os as _os +if _os.environ.get("AUTOLENS_PROFILING_SMOKE") == "1": + print(f"[smoke] {__file__}: imports OK; exiting.") + _sys.exit(0) + +LATENT_KEY = "total_source_flux_mujy" + + +def _push_single_latent_config(latent_key: str) -> Path: + """Write a temp config dir with only latent_key enabled and push it.""" + tmpdir = Path(tempfile.mkdtemp(prefix="latent_cfg_")) + yaml_lines = [ + f"{k}: {'true' if k == latent_key else 'false'}" + for k in LATENT_FUNCTIONS + ] + (tmpdir / "latent.yaml").write_text("\n".join(yaml_lines) + "\n") + conf.instance.push(str(tmpdir)) + return tmpdir + + +def main(config_name: str, output_dir: Path, use_mixed_precision: bool) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + # Isolate THIS latent only via conf override + _push_single_latent_config(LATENT_KEY) + + dataset = fixtures.make_masked_imaging_7x7() + lens = af.Model(al.Galaxy, redshift=0.5, mass=al.mp.Isothermal, bulge=al.lp.Sersic) + source = af.Model(al.Galaxy, redshift=1.0, bulge=al.lp.Sersic) + model = af.Collection(galaxies=af.Collection(lens=lens, source=source)) + + analysis_np = al.AnalysisImaging(dataset=dataset, use_jax=False, magzero=25.0) + analysis_jax = al.AnalysisImaging(dataset=dataset, use_jax=True, magzero=25.0) + + params = jnp.array(model.physical_values_from_prior_medians) + + # === Eager numpy baseline === + t0 = time.perf_counter() + eager_values = analysis_np.compute_latent_variables(np.asarray(params), model) + eager_t = time.perf_counter() - t0 + eager_value = float(eager_values[0]) + + # === JIT compile + first call + steady-state === + fn = jax.jit(lambda p: analysis_jax.compute_latent_variables(p, model)) + lower_t = compile_t = first_t = steady_t = float("nan") + jit_value = float("nan") + jit_error = None + try: + t0 = time.perf_counter() + lowered = fn.lower(params) + lower_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled = lowered.compile() + compile_t = time.perf_counter() - t0 + + t0 = time.perf_counter() + first = compiled(params) + try: + jax.block_until_ready(first[0]) + except Exception: + pass + first_t = time.perf_counter() - t0 + jit_value = float(first[0]) + + steady_ts = [] + for _ in range(10): + t0 = time.perf_counter() + r = compiled(params) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + steady_ts.append(time.perf_counter() - t0) + steady_t = float(np.mean(steady_ts)) + except Exception as exc: + jit_error = repr(exc) + + # === vmap batched === + vmap_t = float("nan") + vmap_value = float("nan") + vmap_error = None + batch_size = 3 + batched = jnp.tile(params[None, :], (batch_size, 1)) + try: + vfn = jax.jit(jax.vmap(lambda p: analysis_jax.compute_latent_variables(p, model))) + warm = vfn(batched) + try: + jax.block_until_ready(warm[0]) + except Exception: + pass + t0 = time.perf_counter() + r = vfn(batched) + try: + jax.block_until_ready(r[0]) + except Exception: + pass + vmap_t = (time.perf_counter() - t0) / batch_size + vmap_value = float(r[0][0]) + except Exception as exc: + vmap_error = repr(exc) + + record = { + "latent_key": LATENT_KEY, + "config_name": config_name, + "use_mixed_precision": use_mixed_precision, + "eager_value": eager_value, + "eager_time_s": eager_t, + "jit_value": jit_value, + "jit_lower_s": lower_t, + "jit_compile_s": compile_t, + "jit_first_call_s": first_t, + "jit_steady_state_s": steady_t, + "jit_error": jit_error, + "vmap_per_call_s": vmap_t, + "vmap_value": vmap_value, + "vmap_error": vmap_error, + } + out_path = output_dir / f"{config_name}.json" + out_path.write_text(json.dumps(record, indent=2)) + print(f"WROTE {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config-name", required=True) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--use-mixed-precision", action="store_true") + args = parser.parse_args() + main(args.config_name, args.output_dir, args.use_mixed_precision) diff --git a/latent/sweep.py b/latent/sweep.py new file mode 100644 index 0000000..3865aab --- /dev/null +++ b/latent/sweep.py @@ -0,0 +1,274 @@ +"""Multi-config latent profiling driver. + +Runs each latent cell across the CPU/GPU x fp64/mp matrix (4 configs per +cell locally; HPC A100 configs are dispatched separately via +`z_projects/profiling/hpc/sync`). + +Each subprocess invokes the per-latent script under +``autolens_profiling/latent//.py`` with the CLI args +``--config-name``, ``--output-dir``, ``--use-mixed-precision``. +Per-config JSONs land at:: + + ///.json + ///.log (captured stdout/stderr) + +Default ``--output-root`` is +``autolens_workspace_developer/jax_profiling/results/latent`` — mirrors the +``jit/`` convention used by ``likelihood_runtime/sweep.py`` and is read by +``aggregate.py`` to produce ``comparison.json`` / ``comparison.png``. + +Usage:: + + # All latent cells, all backends + python latent/sweep.py + + # Skip the heaviest latent during iteration + python latent/sweep.py --skip imaging/effective_einstein_radius + + # Single latent, CPU fp64 only + python latent/sweep.py --only imaging/magnification --skip-gpu --skip-mp +""" + +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + + +_REPO_ROOT = Path(__file__).resolve().parents[1] # autolens_profiling/ +_WT_ROOT = _REPO_ROOT.parent # PyAutoLabs-wt// (or PyAutoLabs/) +_DEFAULT_OUTPUT_ROOT = _WT_ROOT / "autolens_workspace_developer" / "jax_profiling" / "results" / "latent" +_DEFAULT_PYTHON = "/home/jammy/venv/PyAutoGPU/bin/python" + + +# (dataset_class, latent_name). Order is roughly cheapest -> heaviest so +# failures surface quickly during iteration. +CELLS: list[tuple[str, str]] = [ + ("imaging", "total_lens_flux_mujy"), + ("imaging", "total_lensed_source_flux_mujy"), + ("imaging", "total_source_flux_mujy"), + ("imaging", "magnification"), + ("imaging", "effective_einstein_radius"), +] + + +@dataclass(frozen=True) +class SweepConfig: + name: str + env_overrides: dict[str, str] + extra_args: tuple[str, ...] + is_gpu: bool + + +# CPU configs explicitly pin platform to cpu. GPU configs explicitly pin to +# cuda — we DO NOT let JAX fall back to CPU on GPU rows, so a missing CUDA +# device fails the run loudly rather than silently producing a CPU number. +CONFIGS: list[SweepConfig] = [ + SweepConfig( + name="local_cpu_fp64", + env_overrides={"JAX_PLATFORM_NAME": "cpu", "JAX_PLATFORMS": "cpu"}, + extra_args=(), + is_gpu=False, + ), + SweepConfig( + name="local_cpu_mp", + env_overrides={"JAX_PLATFORM_NAME": "cpu", "JAX_PLATFORMS": "cpu"}, + extra_args=("--use-mixed-precision",), + is_gpu=False, + ), + SweepConfig( + name="local_gpu_fp64", + env_overrides={"JAX_PLATFORM_NAME": "cuda", "JAX_PLATFORMS": "cuda,cpu"}, + extra_args=(), + is_gpu=True, + ), + SweepConfig( + name="local_gpu_mp", + env_overrides={"JAX_PLATFORM_NAME": "cuda", "JAX_PLATFORMS": "cuda,cpu"}, + extra_args=("--use-mixed-precision",), + is_gpu=True, + ), +] + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) + p.add_argument( + "--only", + nargs="+", + default=None, + metavar="CLASS/LATENT", + help="Only run these cells (e.g. imaging/magnification imaging/effective_einstein_radius).", + ) + p.add_argument( + "--skip", + nargs="+", + default=(), + metavar="CLASS/LATENT", + help="Skip these cells (applied after --only).", + ) + p.add_argument("--skip-cpu", action="store_true", help="Skip local_cpu_* configs.") + p.add_argument("--skip-gpu", action="store_true", help="Skip local_gpu_* configs.") + p.add_argument( + "--skip-mp", + action="store_true", + help="Skip the use_mixed_precision rows (just fp64).", + ) + p.add_argument( + "--output-root", + type=Path, + default=_DEFAULT_OUTPUT_ROOT, + help=f"Root output dir. Default: {_DEFAULT_OUTPUT_ROOT}", + ) + p.add_argument( + "--python", + default=_DEFAULT_PYTHON, + help=f"Python interpreter to invoke per subprocess. Default: {_DEFAULT_PYTHON}", + ) + p.add_argument( + "--dry-run", + action="store_true", + help="Print the planned subprocess commands but don't execute.", + ) + return p.parse_args() + + +def _resolve_cells(args: argparse.Namespace) -> list[tuple[str, str]]: + selected = CELLS + if args.only: + wanted = {c for c in args.only} + selected = [(c, m) for (c, m) in selected if f"{c}/{m}" in wanted] + missing = wanted - {f"{c}/{m}" for (c, m) in selected} + if missing: + sys.stderr.write(f"warning: --only includes unknown cells: {sorted(missing)}\n") + skip = set(args.skip) + selected = [(c, m) for (c, m) in selected if f"{c}/{m}" not in skip] + return selected + + +def _resolve_configs(args: argparse.Namespace) -> list[SweepConfig]: + configs = list(CONFIGS) + if args.skip_cpu: + configs = [c for c in configs if c.is_gpu] + if args.skip_gpu: + configs = [c for c in configs if not c.is_gpu] + if args.skip_mp: + configs = [c for c in configs if "--use-mixed-precision" not in c.extra_args] + return configs + + +def _run_one( + python: str, + script_path: Path, + config: SweepConfig, + out_dir: Path, + dry_run: bool, +) -> tuple[bool, float, str]: + """Run one (cell, config) pair as a subprocess. Returns (ok, elapsed, log_path).""" + out_dir.mkdir(parents=True, exist_ok=True) + log_path = out_dir / f"{config.name}.log" + + cmd = [ + python, + str(script_path), + "--config-name", config.name, + "--output-dir", str(out_dir), + *config.extra_args, + ] + + env = dict(os.environ) + env.update(config.env_overrides) + # numba + matplotlib cache dirs — same workaround the per-cell scripts use. + env.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache") + env.setdefault("MPLCONFIGDIR", "/tmp/matplotlib") + + print(f"\n--- [{config.name}] {script_path.relative_to(_REPO_ROOT)} ---") + print(f" cmd: {' '.join(cmd)}") + print(f" env: {config.env_overrides}") + + if dry_run: + return True, 0.0, "" + + t0 = time.time() + try: + with open(log_path, "w") as log: + proc = subprocess.run( + cmd, + env=env, + stdout=log, + stderr=subprocess.STDOUT, + check=False, + ) + elapsed = time.time() - t0 + ok = proc.returncode == 0 + print(f" {'OK ' if ok else 'FAIL'} ({elapsed:.1f}s, exit={proc.returncode}) -> {log_path.name}") + return ok, elapsed, str(log_path) + except KeyboardInterrupt: + elapsed = time.time() - t0 + print(f" INTERRUPTED after {elapsed:.1f}s; partial log -> {log_path}") + raise + + +def main() -> int: + args = _parse_args() + cells = _resolve_cells(args) + configs = _resolve_configs(args) + + print(f"sweep_latent: {len(cells)} cells x {len(configs)} configs " + f"= {len(cells) * len(configs)} runs") + print(f" cells: {[f'{c}/{m}' for (c, m) in cells]}") + print(f" configs: {[c.name for c in configs]}") + print(f" output: {args.output_root}") + print(f" python: {args.python}") + if args.dry_run: + print(" (dry-run)") + + summary: list[tuple[str, str, bool, float]] = [] + overall_t0 = time.time() + + for (cls, latent) in cells: + script_path = _REPO_ROOT / "latent" / cls / f"{latent}.py" + if not script_path.exists(): + print(f"\n!!! missing script: {script_path}") + for cfg in configs: + summary.append((f"{cls}/{latent}", cfg.name, False, 0.0)) + continue + + out_dir = args.output_root / cls / latent + + for cfg in configs: + try: + ok, elapsed, _log = _run_one( + args.python, script_path, cfg, out_dir, args.dry_run + ) + except KeyboardInterrupt: + print("\n\nsweep interrupted by user") + return 130 + summary.append((f"{cls}/{latent}", cfg.name, ok, elapsed)) + + total = time.time() - overall_t0 + print("\n" + "=" * 70) + print(f"sweep_latent summary ({total:.1f}s total)") + print("=" * 70) + print(f" {'cell':<40}{'config':<22}{'ok':<6}{'elapsed':>10}") + print(f" {'-'*40}{'-'*22}{'-'*6}{'-'*10}") + failures = 0 + for cell, cfg, ok, t in summary: + flag = "OK" if ok else "FAIL" + if not ok: + failures += 1 + print(f" {cell:<40}{cfg:<22}{flag:<6}{t:>9.1f}s") + if failures: + print(f"\n {failures} run(s) FAILED — check the .log files in each cell's output dir.") + return 1 + print("\n All runs OK.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())