Skip to content

Commit 1371f83

Browse files
authored
Add join sweep method (#173)
Allows adding extra agents to a sweep.
1 parent 3e94705 commit 1371f83

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Join an existing Weights and Biases sweep, as a new agent."""
2+
import argparse
3+
4+
from sparse_autoencoder.train.sweep import sweep
5+
6+
7+
def parse_arguments() -> argparse.Namespace:
8+
"""Parse command line arguments.
9+
10+
Returns:
11+
argparse.Namespace: Parsed command line arguments.
12+
"""
13+
parser = argparse.ArgumentParser(description="Join an existing W&B sweep.")
14+
parser.add_argument(
15+
"--id", type=str, default=None, help="Sweep ID for the existing sweep.", required=True
16+
)
17+
return parser.parse_args()
18+
19+
20+
if __name__ == "__main__":
21+
args = parse_arguments()
22+
23+
sweep(sweep_id=args.id)

sparse_autoencoder/train/sweep.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,28 @@ def train() -> None:
347347
sys.exit(1)
348348

349349

350-
def sweep(sweep_config: SweepConfig) -> None:
351-
"""Main function to run the training pipeline with wandb hyperparameter sweep."""
352-
sweep_id = wandb.sweep(sweep_config.to_dict(), project="sparse-autoencoder")
350+
def sweep(sweep_config: SweepConfig | None = None, sweep_id: str | None = None) -> None:
351+
"""Run the training pipeline with wandb hyperparameter sweep.
352+
353+
Warning:
354+
Either sweep_config or sweep_id must be specified, but not both.
355+
356+
Args:
357+
sweep_config: The sweep configuration.
358+
sweep_id: The sweep id for an existing sweep.
359+
360+
Raises:
361+
ValueError: If neither sweep_config nor sweep_id is specified.
362+
"""
363+
if sweep_id is not None:
364+
wandb.agent(sweep_id, train, project="sparse-autoencoder")
365+
366+
elif sweep_config is not None:
367+
sweep_id = wandb.sweep(sweep_config.to_dict(), project="sparse-autoencoder")
368+
wandb.agent(sweep_id, train)
369+
370+
else:
371+
error_message = "Either sweep_config or sweep_id must be specified."
372+
raise ValueError(error_message)
353373

354-
wandb.agent(sweep_id, train)
355374
wandb.finish()

0 commit comments

Comments
 (0)