Skip to content

Commit bb65db4

Browse files
committed
fix(runs): address lint and runner review issues
1 parent 57f3bec commit bb65db4

10 files changed

Lines changed: 142 additions & 54 deletions

File tree

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ COPY docker-entrypoint.sh /usr/local/bin/entrypoint.sh
105105
RUN chmod +x /usr/local/bin/entrypoint.sh
106106

107107
# ============================================================================
108-
# Bake in model checkpoints from pre-built base image on Docker Hub
108+
# Bake in model checkpoints from pre-built Harbor image
109109
# ============================================================================
110110
# Checkpoints (~10 GB) rarely change, so this layer is placed before pixi
111111
# installs to stay cached even when dependencies update.

run_experiments

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ find_sampleworks_root_upwards() {
3939
resolve_repo_root() {
4040
local source_override="${SAMPLEWORKS_SOURCE_DIR:-}"
4141
if [[ -n "$source_override" ]]; then
42+
if ! is_sampleworks_root "$source_override"; then
43+
cat >&2 <<EOF
44+
SAMPLEWORKS_SOURCE_DIR does not point to a Sampleworks checkout:
45+
$source_override
46+
EOF
47+
return 2
48+
fi
4249
printf '%s\n' "$source_override"
4350
return 0
4451
fi
@@ -191,8 +198,17 @@ for checkpoint_var_and_file in \
191198
fi
192199
done
193200

201+
needs_runtime_paths=1
202+
for arg in "$@"; do
203+
case "$arg" in
204+
--dry-run|--show|--list|-h|--help)
205+
needs_runtime_paths=0
206+
;;
207+
esac
208+
done
209+
194210
source_proteins_csv="${PROTEINS_CSV:-$DATA_DIR/proteins.csv}"
195-
if [[ -f "$source_proteins_csv" ]]; then
211+
if [[ "$needs_runtime_paths" -eq 1 && -f "$source_proteins_csv" ]]; then
196212
# The shared proteins.csv currently contains absolute /data/inputs paths,
197213
# while ACTL mounts the dataset at /mnt/diffuse-shared. Rewrite a per-run
198214
# manifest instead of requiring non-root scientists to create /data symlinks.
@@ -227,15 +243,6 @@ if [[ -n "$explicit_jobs" ]]; then
227243
display_target="$display_target --jobs $explicit_jobs"
228244
fi
229245

230-
needs_runtime_paths=1
231-
for arg in "$@"; do
232-
case "$arg" in
233-
--dry-run|--show|--list|-h|--help)
234-
needs_runtime_paths=0
235-
;;
236-
esac
237-
done
238-
239246
if [[ "$needs_runtime_paths" -eq 1 ]]; then
240247
if [[ ! -f "${PROTEINS_CSV:-$source_proteins_csv}" ]]; then
241248
cat >&2 <<EOF

run_grid_search.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,15 @@ def detect_gpus() -> list[str]:
7272
"""Return CUDA GPU identifiers visible to this grid-search process.
7373
7474
``CUDA_VISIBLE_DEVICES`` wins when set because CUDA remaps those entries to
75-
local process ordinals. Otherwise, ``nvidia-smi`` is used as a best-effort
76-
discovery mechanism and ``["0"]`` is returned as a CPU/test fallback.
75+
local process ordinals. Explicit CUDA "no device" sentinel values return an
76+
empty list. Otherwise (unset, empty, or ``all``), ``nvidia-smi`` is used as a best-effort discovery
77+
mechanism and ``["0"]`` is returned as a CPU/test fallback.
7778
"""
78-
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
79-
if cuda_visible:
79+
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
80+
cuda_visible_key = cuda_visible.lower()
81+
if cuda_visible_key in {"none", "void", "nodevfiles"}:
82+
return []
83+
if cuda_visible and cuda_visible_key != "all":
8084
gpus = [g.strip() for g in cuda_visible.split(",") if g.strip()]
8185
try:
8286
result = subprocess.run(
@@ -85,9 +89,7 @@ def detect_gpus() -> list[str]:
8589
text=True,
8690
)
8791
if result.returncode == 0:
88-
visible = [
89-
g.strip() for g in result.stdout.strip().split("\n") if g.strip()
90-
]
92+
visible = [g.strip() for g in result.stdout.strip().split("\n") if g.strip()]
9193
if all(g.isdigit() for g in gpus + visible):
9294
missing = sorted(set(gpus).difference(visible), key=int)
9395
if missing:
@@ -361,6 +363,10 @@ def main(args: argparse.Namespace):
361363
log.info(f"Detected {len(gpus)} GPUs: {gpus}")
362364
if args.max_parallel != "auto":
363365
gpus = gpus[: int(args.max_parallel)]
366+
if not gpus:
367+
raise ValueError(
368+
"No CUDA GPUs are visible; unset CUDA_VISIBLE_DEVICES=none or use a GPU pod"
369+
)
364370

365371
log_args(args, gpus)
366372

src/sampleworks/runs/cli.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ def _build_parser() -> argparse.ArgumentParser:
8484
default="",
8585
help="Preset name from experiments/ or path to a .toml file. Default: full_8gpu.",
8686
)
87-
parser.add_argument("--list", action="store_true", help="List experiments/*.toml presets and exit")
87+
parser.add_argument(
88+
"--list",
89+
action="store_true",
90+
help="List experiments/*.toml presets and exit",
91+
)
8892
parser.add_argument("--show", action="store_true", help="Print the resolved preset and exit")
8993
parser.add_argument(
9094
"--dry-run",

src/sampleworks/runs/loader.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Load presets from TOML and apply runtime overrides.
22
33
Resolution order for every string value (defaults block and ``args``):
4-
1. ``${VAR}`` references are resolved against the process environment,
4+
1. ``--set <dotted-path>=<value>`` CLI overrides are applied to the raw TOML
5+
dict by :func:`load_preset`, so overridden values participate in
6+
interpolation.
7+
2. ``${VAR}`` references are resolved against the process environment,
58
with the preset's ``[defaults]`` block filling in any unset keys.
6-
2. ``--set <dotted-path>=<value>`` CLI overrides are applied last.
79
"""
810

911
from __future__ import annotations
@@ -19,6 +21,7 @@
1921

2022

2123
_EXPERIMENTS_DIR_NAME = "experiments"
24+
_MAX_EXPAND_ITERATIONS = 32
2225
_VAR_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
2326
_TOP_LEVEL_KEYS = frozenset({"description", "defaults", "shared_args", "jobs"})
2427

@@ -325,10 +328,15 @@ def _find_in_list(items: list[Any], key: str, *, where: str) -> int:
325328
Raises
326329
------
327330
KeyError
328-
If no element with the given name exists.
331+
If no element with the given name or index exists.
329332
"""
330333
if key.isdigit() or (key.startswith("-") and key[1:].isdigit()):
331-
return int(key)
334+
index = int(key)
335+
try:
336+
items[index]
337+
except IndexError:
338+
raise KeyError(f"No list element at index {index} at {where!r}") from None
339+
return index
332340
for i, item in enumerate(items):
333341
if isinstance(item, dict) and item.get("name") == key:
334342
return i
@@ -442,6 +450,8 @@ def _expand(text: str, env: dict[str, str]) -> str:
442450
------
443451
KeyError
444452
If a referenced variable is not in ``env``.
453+
ValueError
454+
If recursive variable interpolation does not converge.
445455
"""
446456

447457
def repl(match: re.Match[str]) -> str:
@@ -451,12 +461,16 @@ def repl(match: re.Match[str]) -> str:
451461
raise KeyError(f"Undefined variable ${{{var}}} in preset (no env var, no default)")
452462
return env[var]
453463

454-
prev = None
455464
current = text
456-
while prev != current:
457-
prev = current
458-
current = _VAR_PATTERN.sub(repl, current)
459-
return current
465+
for _ in range(_MAX_EXPAND_ITERATIONS):
466+
expanded = _VAR_PATTERN.sub(repl, current)
467+
if expanded == current:
468+
return expanded
469+
current = expanded
470+
raise ValueError(
471+
f"Variable expansion did not converge for {text!r}; check for circular "
472+
"${VAR} references in [defaults], environment variables, or --set overrides."
473+
)
460474

461475

462476
def _build_preset(*, name: str, raw: dict[str, Any]) -> Preset:

src/sampleworks/runs/runner.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
DEFAULT_GRID_SEARCH_SCRIPT = "/app/run_grid_search.py"
1919
WORKSPACE_GRID_SEARCH_SCRIPT = "/home/dev/workspace/run_grid_search.py"
20+
PROCESS_SHUTDOWN_TIMEOUT_SECONDS = 10
21+
TEE_THREAD_JOIN_TIMEOUT_SECONDS = 5
2022

2123

2224
@dataclass(frozen=True)
@@ -172,9 +174,7 @@ def _validate_gpu_assignments(invocations: list[JobInvocation]) -> None:
172174
return
173175

174176
available_set = set(available)
175-
unavailable = {
176-
gpu: names for gpu, names in requested.items() if gpu not in available_set
177-
}
177+
unavailable = {gpu: names for gpu, names in requested.items() if gpu not in available_set}
178178
if unavailable:
179179
details = ", ".join(
180180
f"GPU {gpu} requested by {', '.join(names)}"
@@ -413,6 +413,7 @@ def run(preset: Preset, *, results_dir: Path, dry_run: bool = False) -> int:
413413
int
414414
``0`` if all jobs exited 0 (or ``dry_run`` was set), ``1`` otherwise.
415415
"""
416+
results_dir = results_dir.resolve()
416417
results_dir.mkdir(parents=True, exist_ok=True)
417418
invocations = build_invocations(preset, results_dir=results_dir)
418419
_validate_gpu_assignments(invocations)
@@ -444,23 +445,33 @@ def _terminate_all(jobs: list[_RunningJob]) -> None:
444445
Parameters
445446
----------
446447
jobs : list of _RunningJob
447-
Jobs whose subprocesses should be SIGTERM'd, waited on, and whose tee
448-
threads should be joined.
448+
Jobs whose subprocesses should be SIGTERM'd, escalated to SIGKILL if
449+
needed, and whose tee threads should be joined with bounded waits.
449450
"""
450451
for j in jobs:
451452
if j.proc.poll() is None:
452453
j.proc.terminate()
453454
for j in jobs:
454-
j.proc.wait()
455-
j.tee_thread.join()
455+
try:
456+
j.proc.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
457+
except subprocess.TimeoutExpired:
458+
j.proc.kill()
459+
try:
460+
j.proc.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
461+
except subprocess.TimeoutExpired:
462+
print(
463+
f"[{_ts()}] {j.inv.job.name} did not exit after SIGKILL",
464+
file=sys.stderr,
465+
)
466+
j.tee_thread.join(timeout=TEE_THREAD_JOIN_TIMEOUT_SECONDS)
456467

457468

458469
def _prepare_pixi_env(pixi_env: str) -> None:
459470
"""Prepare a pixi environment before parallel job launch.
460471
461-
``pixi run`` is deliberately called once per env even when the interpreter
462-
directory already exists, because pixi may still need to materialize PyPI
463-
packages into that environment after image startup.
472+
Preparation is skipped when a baked interpreter is already available, when
473+
prebuilt environments are required, or when ``SAMPLEWORKS_SKIP_ENV_PREPARE``
474+
is truthy. Otherwise, ``pixi run`` is called once for the environment.
464475
465476
Parameters
466477
----------
@@ -576,24 +587,39 @@ def _spawn(inv: JobInvocation) -> _RunningJob:
576587
inv.log_path.parent.mkdir(parents=True, exist_ok=True)
577588
inv.output_dir.mkdir(parents=True, exist_ok=True)
578589
log_file = open(inv.log_path, "wb")
590+
proc: subprocess.Popen[bytes] | None = None
591+
thread: threading.Thread | None = None
579592
try:
580593
proc = subprocess.Popen(
581594
inv.argv,
582595
env=inv.env,
596+
cwd=str(_pixi_project_dir()),
583597
stdout=subprocess.PIPE,
584598
stderr=subprocess.STDOUT,
585599
bufsize=0,
586600
)
601+
if proc.stdout is None:
602+
raise RuntimeError(f"Job {inv.job.name!r} started without a stdout pipe")
603+
thread = threading.Thread(
604+
target=_tee,
605+
args=(inv.job.name, proc.stdout, log_file),
606+
daemon=True,
607+
)
608+
thread.start()
587609
except BaseException:
588610
log_file.close()
611+
if proc is not None and proc.poll() is None:
612+
proc.kill()
613+
try:
614+
proc.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
615+
except subprocess.TimeoutExpired:
616+
print(
617+
f"[{_ts()}] {inv.job.name} did not exit after failed spawn cleanup",
618+
file=sys.stderr,
619+
)
589620
raise
590-
assert proc.stdout is not None
591-
thread = threading.Thread(
592-
target=_tee,
593-
args=(inv.job.name, proc.stdout, log_file),
594-
daemon=True,
595-
)
596-
thread.start()
621+
if proc is None or thread is None:
622+
raise RuntimeError(f"Job {inv.job.name!r} failed to initialize")
597623
print(f"[{_ts()}] launched {inv.job.name} (pid {proc.pid})", file=sys.stderr)
598624
return _RunningJob(inv=inv, proc=proc, tee_thread=thread)
599625

src/sampleworks/runs/schema.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Dataclasses for the preset schema.
22
33
A preset describes one or more parallel ``run_grid_search.py`` jobs. Each job
4-
is launched in its configured model environment with ``CUDA_VISIBLE_DEVICES``
5-
set to the job's GPU assignment.
4+
runs in its configured model environment, either through ``pixi run`` or a
5+
baked environment Python, with ``CUDA_VISIBLE_DEVICES`` set to the job's GPU
6+
assignment.
67
"""
78

89
from __future__ import annotations

tests/runs/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
from __future__ import annotations
44

5+
import os
6+
57
import pytest
68

79

810
@pytest.fixture(autouse=True)
911
def force_pixi_argv(monkeypatch: pytest.MonkeyPatch) -> None:
1012
"""Keep argv assertions deterministic on machines with /app/.pixi present."""
13+
monkeypatch.delenv("SAMPLEWORKS_GRID_SEARCH_SCRIPT", raising=False)
14+
monkeypatch.delenv("SAMPLEWORKS_PIXI_PROJECT_DIR", raising=False)
15+
for var in list(os.environ):
16+
if var.startswith("SAMPLEWORKS_") and var.endswith("_PYTHON"):
17+
monkeypatch.delenv(var, raising=False)
1118
monkeypatch.setenv("SAMPLEWORKS_FORCE_PIXI", "1")

tests/runs/test_cli.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ def test_dry_run_does_not_invoke_subprocess(
3939
) -> None:
4040
"""``--dry-run`` prints commands and CUDA assignment instead of executing."""
4141
monkeypatch.setenv("HOME", str(tmp_path))
42-
exit_code = cli.main([
43-
"--preset",
44-
"rf3_partial",
45-
"--dry-run",
46-
"--results-dir",
47-
str(tmp_path),
48-
])
42+
exit_code = cli.main(
43+
[
44+
"--preset",
45+
"rf3_partial",
46+
"--dry-run",
47+
"--results-dir",
48+
str(tmp_path),
49+
]
50+
)
4951
assert exit_code == 0
5052
out = capsys.readouterr().out
5153
assert "pixi run -e rf3 python /app/run_grid_search.py" in out

tests/runs/test_loader.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,27 @@ def test_set_with_unknown_top_level_key_raises(monkeypatch: pytest.MonkeyPatch)
183183
loader.load_preset("rf3_partial", overrides=["job.rf3.gpus=0"])
184184

185185

186+
def test_set_with_out_of_range_job_index_raises(monkeypatch: pytest.MonkeyPatch) -> None:
187+
"""Out-of-range list indices in overrides fail with a clear ``KeyError``."""
188+
monkeypatch.setenv("HOME", "/home/test")
189+
with pytest.raises(KeyError, match="index 99"):
190+
loader.load_preset("rf3_partial", overrides=["jobs.99.gpus=0"])
191+
192+
193+
def test_cyclic_variable_expansion_raises(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
194+
"""Cyclic ``${VAR}`` references fail fast instead of looping forever."""
195+
bad = tmp_path / "cycle.toml"
196+
bad.write_text(
197+
"[shared_args]\n"
198+
'proteins = "${A}"\n'
199+
'[[jobs]]\nname = "j"\nenv = "rf3"\ngpus = "0"\noutput_subdir = "j"\nargs = {}\n'
200+
)
201+
monkeypatch.setenv("A", "${B}")
202+
monkeypatch.setenv("B", "${A}")
203+
with pytest.raises(ValueError, match="did not converge"):
204+
loader.load_preset(str(bad))
205+
206+
186207
def test_bad_env_rejected(tmp_path: Path) -> None:
187208
"""Preset jobs reject unsupported pixi environment names."""
188209
bad = tmp_path / "bad.toml"

0 commit comments

Comments
 (0)