-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_profile_cli.py
More file actions
209 lines (181 loc) · 6.95 KB
/
_profile_cli.py
File metadata and controls
209 lines (181 loc) · 6.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""Shared CLI / JSON / auto-simulate helpers for the likelihood scripts.

Used by every script under ``likelihood/{imaging,interferometer,
datacube,point_source}/`` so the per-script boilerplate stays minimal
and the sweep-driver flags (``--config-name``, ``--output-dir``,
``--use-mixed-precision``, ``--instrument``, ``--vmap-probe``) and the
dataset auto-simulate hook are defined in one place.

Designed to be imported via ``sys.path`` manipulation, since the scripts
live under multiple sibling directories::

    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
    from _profile_cli import (
        parse_profile_cli, device_info_dict, resolve_output_paths,
        auto_simulate_if_missing,
    )
"""
from __future__ import annotations

import argparse
import os
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Sequence
@dataclass(frozen=True)
class ProfileCLI:
    """Immutable bundle of the sweep-driver flags returned by ``parse_profile_cli``."""

    # Output-filename label from ``--config-name`` (or the caller's fallback);
    # ``None`` keeps the script's single-config filename pattern.
    config_name: Optional[str]
    # Absolute results directory from ``--output-dir``; ``None`` means the
    # per-script default directory should be used.
    output_dir: Optional[Path]
    # True when ``--use-mixed-precision`` was passed.
    use_mixed_precision: bool
    # Instrument preset from ``--instrument``; ``None`` lets the per-cell
    # script fall back to its own module-level default.
    instrument: Optional[str]
    # True when ``--vmap-probe`` was passed.
    vmap_probe: bool


def parse_profile_cli(
    default_config_name: Optional[str] = None,
    *,
    argv: Optional[Sequence[str]] = None,
) -> ProfileCLI:
    """Parse the sweep-driver flags shared by every per-cell profile script.

    Only the flags defined below are consumed. Anything else on the command
    line is ignored here (via ``parse_known_args``) so the per-script parser
    can handle it.

    Args:
        default_config_name: Fallback used when ``--config-name`` is omitted
            or empty. Callers may derive it (e.g. from the
            ``JAX_PLATFORM_NAME`` env var) or leave it ``None`` to preserve
            the existing single-config filename pattern.
        argv: Argument list to parse. ``None`` (the default) means argparse's
            standard ``sys.argv[1:]``. Pass an explicit list to parse
            programmatically or in tests.

    Returns:
        ``ProfileCLI(config_name, output_dir, use_mixed_precision,
        instrument, vmap_probe)``.

        - ``output_dir`` is user-expanded (``~``) and resolved to an
          absolute path.
        - ``instrument`` is ``None`` when ``--instrument`` is omitted. In that
          case per-cell scripts keep their module-level hardcoded default
          (typically ``"sma"`` or ``"hst"``).

    Note:
        argparse's default ``-h/--help`` is left enabled, so ``-h`` prints
        only these sweep flags and exits.
    """
    parser = argparse.ArgumentParser(
        description="Multi-config likelihood profiling driver flags.",
        # Disable prefix matching so an abbreviated flag meant for a
        # per-script parser (e.g. ``--config``) is never silently captured
        # as one of the sweep flags below.
        allow_abbrev=False,
    )
    parser.add_argument(
        "--config-name",
        default=None,
        help=(
            "Output-filename label for the multi-config sweep "
            "(e.g. local_cpu_fp64, local_gpu_mp, hpc_a100_fp64). "
            "When omitted, the script keeps its single-config filename pattern."
        ),
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help=(
            "Override results dir. Defaults to "
            "<autolens_profiling>/results/likelihood/<class>/."
        ),
    )
    parser.add_argument(
        "--use-mixed-precision",
        action="store_true",
        help=(
            "Pass use_mixed_precision=True to al.Settings — "
            "targeted fp32 paths in the JAX inversion."
        ),
    )
    parser.add_argument(
        "--instrument",
        default=None,
        help=(
            "Instrument preset to profile. When omitted, the per-cell "
            "script's module-level default applies (typically 'sma' for "
            "interferometer/datacube cells, 'hst' for imaging)."
        ),
    )
    parser.add_argument(
        "--vmap-probe",
        action="store_true",
        help=(
            "Probe mode: JIT-vmap the full pipeline at batch=2 and batch=4, "
            "read compiled.memory_analysis(), write a vmap_probe.json with "
            "the recommended A100 batch_size, and exit before the steady-"
            "state timing loop. See vram/README.md for methodology."
        ),
    )
    # parse_known_args (rather than parse_args) leaves flags owned by the
    # per-script parser in ``_unknown`` instead of erroring on them.
    args, _unknown = parser.parse_known_args(argv)

    config_name = args.config_name or default_config_name
    # expanduser() first: a literal "~" (e.g. from a quoted argument) would
    # otherwise resolve to "<cwd>/~/...".
    output_dir = (
        Path(args.output_dir).expanduser().resolve() if args.output_dir else None
    )
    return ProfileCLI(
        config_name=config_name,
        output_dir=output_dir,
        use_mixed_precision=bool(args.use_mixed_precision),
        instrument=args.instrument,
        vmap_probe=bool(args.vmap_probe),
    )
def device_info_dict() -> dict:
    """Summarise the JAX backend and first device of the current process.

    ``jax`` is imported inside the function rather than at module level, so
    importing this helper module does not pull in JAX.

    Returns a dict with these keys:

    - ``"backend"``: the value of ``jax.default_backend()``.
    - ``"device"``: ``str(jax.devices()[0])``.
    - ``"nvidia_smi"`` (only on the ``"gpu"`` backend, and only when the call
      succeeds): the output of an ``nvidia-smi`` query for
      ``name,memory.used,memory.total``, run with a 3 s timeout. Output lines
      are joined with ``"; "``.
    """
    import jax

    backend = jax.default_backend()
    info = {"backend": backend, "device": str(jax.devices()[0])}

    if backend != "gpu":
        return info

    smi_command = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.total",
        "--format=csv,noheader",
    ]
    try:
        raw = subprocess.check_output(
            smi_command,
            stderr=subprocess.DEVNULL,
            timeout=3,
        )
        info["nvidia_smi"] = "; ".join(raw.decode().strip().split("\n"))
    except Exception:
        # Deliberate best effort: a missing, failing or slow nvidia-smi must
        # not break the profiling run, so the key is simply left out.
        pass
    return info
def resolve_output_paths(
    cli: ProfileCLI,
    default_dir: Path,
    default_basename: str,
) -> tuple[Path, Path]:
    """Return the ``(json_path, png_path)`` pair a per-cell script writes to.

    Directory: ``cli.output_dir`` when it is not ``None``, otherwise
    ``default_dir``. The directory (and any missing parents) is created if it
    does not exist yet.

    File stem: ``cli.config_name`` when it is truthy, otherwise
    ``default_basename``. The fallback preserves the existing single-config
    filename pattern.
    """
    target = default_dir if cli.output_dir is None else cli.output_dir
    target.mkdir(parents=True, exist_ok=True)

    stem = cli.config_name if cli.config_name else default_basename
    json_path = target / f"{stem}.json"
    png_path = target / f"{stem}.png"
    return json_path, png_path
def auto_simulate_if_missing(
    dataset_path: Path,
    *,
    dataset_type: str,
    instrument: str,
    workspace_root: Path,
) -> None:
    """Invoke the matching simulator script when the dataset needs simulating.

    The gate is ``al.util.dataset.should_simulate(str(dataset_path))``, which
    also handles the ``PYAUTO_SMALL_DATASETS=1`` cleanup case. When it returns
    False this function does nothing.

    Otherwise ``<workspace_root>/simulators/<dataset_type>.py`` is run in a
    subprocess with the current interpreter and the arguments
    ``--instrument <instrument> --output-root <workspace_root>``.
    ``dataset_type`` is expected to be one of ``imaging``,
    ``interferometer`` or ``point_source``. The simulator is expected to
    write both the likelihood-fit dataset and its own profiling JSON+PNG
    under that root.

    ``autolens`` is imported lazily so this helper can live in any module
    without forcing the heavy import chain on every caller.

    Raises:
        FileNotFoundError: The simulator script for ``dataset_type`` does
            not exist.
        subprocess.CalledProcessError: The simulator exits with a non-zero
            status (``check=True``).
    """
    import sys

    # Deferred import: autolens has a heavy import chain with side effects.
    import autolens as al

    if not al.util.dataset.should_simulate(str(dataset_path)):
        return

    simulator_script = workspace_root / "simulators" / f"{dataset_type}.py"
    if not simulator_script.exists():
        raise FileNotFoundError(
            f"Auto-simulate could not find simulator script at {simulator_script}. "
            f"Expected one of imaging.py / interferometer.py / point_source.py "
            f"under simulators/."
        )

    # flush=True: the child writes directly to the inherited stdout fd, so
    # when our stdout is block-buffered (piped/redirected, e.g. sweep logs)
    # this banner would otherwise appear *after* the simulator's output.
    print(
        f"  [auto-simulate] {dataset_path} missing; invoking "
        f"simulators/{dataset_type}.py --instrument {instrument}",
        flush=True,
    )
    subprocess.run(
        [
            sys.executable,
            str(simulator_script),
            "--instrument", instrument,
            "--output-root", str(workspace_root),
        ],
        check=True,
    )