diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 80d7949c4..31f30c03a 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -50,8 +50,8 @@ def submit_attributions( time: Job time limit. job_suffix: Optional suffix for SLURM job names (e.g., "1h" -> "spd-attr-1h"). """ - run_id = f"attr-{secrets.token_hex(4)}" - snapshot_branch, commit_hash = create_git_snapshot(run_id) + launch_id = f"attr-{secrets.token_hex(4)}" + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id) logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") suffix = f"-{job_suffix}" if job_suffix else "" diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 1b0b268fc..dfdc33f88 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -1,87 +1,43 @@ -import json +"""Induction head decomposition script.""" + from pathlib import Path import fire -import wandb -from spd.configs import Config, IHTaskConfig +from spd.configs import IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger -from spd.run_spd import optimize +from spd.run_spd import run_experiment from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import ( - save_pre_run_info, - set_seed, -) -from spd.utils.run_utils import ExecutionStamp -from spd.utils.wandb_utils import init_wandb +from spd.utils.general_utils import set_seed +from spd.utils.run_utils import parse_config, parse_sweep_params def main( config_path: Path | str | None = None, config_json: str | None = None, evals_id: str | None = None, - sweep_id: str | None = None, + launch_id: str | None = None, sweep_params_json: str | None = None, + run_id: str | None = None, ) -> None: - assert (config_path is not None) != (config_json is not None), ( - "Need exactly one of config_path and config_json" - ) - if config_path is not None: - config = Config.from_file(config_path) - else: - assert config_json is not None - config = Config(**json.loads(config_json.removeprefix("json:"))) - - sweep_params = ( - None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) - ) + config = parse_config(config_path, config_json) device = get_device() logger.info(f"Using device: {device}") - execution_stamp = ExecutionStamp.create(run_type="spd", create_snapshot=False) - out_dir = execution_stamp.out_dir - logger.info(f"Run ID: {execution_stamp.run_id}") - logger.info(f"Output directory: {out_dir}") - - if config.wandb_project: - tags = ["ih"] - if evals_id: - tags.append(evals_id) - if sweep_id: - tags.append(sweep_id) - init_wandb( - config=config, - project=config.wandb_project, - run_id=execution_stamp.run_id, - name=config.wandb_run_name, - tags=tags, - ) + set_seed(config.seed) task_config = config.task_config assert isinstance(task_config, IHTaskConfig) - set_seed(config.seed) - logger.info(config) - assert config.pretrained_model_path, "pretrained_model_path must be set" target_run_info = InductionModelTargetRunInfo.from_path(config.pretrained_model_path) target_model = InductionTransformer.from_run_info(target_run_info) target_model = target_model.to(device) target_model.eval() - save_pre_run_info( - save_to_wandb=config.wandb_project is not None, - out_dir=out_dir, - spd_config=config, - sweep_params=sweep_params, - target_model=target_model, - train_config=target_run_info.config, - task_name=config.task_config.task_name, - ) - prefix_window = task_config.prefix_window or target_model.config.seq_len - 3 dataset = InductionDataset( @@ -93,19 +49,20 @@ def main( train_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) eval_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) - optimize( + run_experiment( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - out_dir=out_dir, + experiment_tag="ih", + run_id=run_id, + launch_id=launch_id, + evals_id=evals_id, + sweep_params=parse_sweep_params(sweep_params_json), + target_model_train_config=target_run_info.config, ) - if config.wandb_project: - wandb.finish() - if __name__ == "__main__": fire.Fire(main) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 82433d985..e06d01481 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -1,16 +1,14 @@ """Language Model decomposition script.""" -import json from pathlib import Path import fire -import wandb -from spd.configs import Config, LMTaskConfig +from spd.configs import LMTaskConfig from spd.data import DatasetConfig, create_data_loader from spd.log import logger from spd.pretrain.run_info import PretrainRunInfo -from spd.run_spd import optimize +from spd.run_spd import run_experiment from spd.utils.distributed_utils import ( DistributedState, ensure_cached_and_call, @@ -19,9 +17,8 @@ is_main_process, with_distributed_cleanup, ) -from spd.utils.general_utils import resolve_class, save_pre_run_info, set_seed -from spd.utils.run_utils import setup_decomposition_run -from spd.utils.wandb_utils import init_wandb +from spd.utils.general_utils import resolve_class, set_seed +from spd.utils.run_utils import parse_config, parse_sweep_params @with_distributed_cleanup @@ -29,44 +26,18 @@ def main( config_path: Path | str | None = None, config_json: str | None = None, evals_id: str | None = None, - sweep_id: str | None = None, + launch_id: str | None = None, sweep_params_json: str | None = None, + run_id: str | None = None, ) -> None: - assert (config_path is not None) != (config_json is not None), ( - "Need exactly one of config_path and config_json" - ) - if config_path is not None: - config = Config.from_file(config_path) - else: - assert config_json is not None - config = Config(**json.loads(config_json.removeprefix("json:"))) + config = parse_config(config_path, config_json) dist_state = init_distributed() logger.info(f"Distributed state: {dist_state}") - sweep_params = ( - None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) - ) - # Use the same seed across all ranks for deterministic data loading set_seed(config.seed) - if is_main_process(): - out_dir, run_id, tags = setup_decomposition_run( - experiment_tag="lm", evals_id=evals_id, sweep_id=sweep_id - ) - if config.wandb_project: - init_wandb( - config=config, - project=config.wandb_project, - run_id=run_id, - name=config.wandb_run_name, - tags=tags, - ) - logger.info(config) - else: - out_dir = None - device = get_device() assert isinstance(config.task_config, LMTaskConfig), "task_config not LMTaskConfig" @@ -96,18 +67,6 @@ def main( ) target_model.eval() - if is_main_process(): - assert out_dir is not None - save_pre_run_info( - save_to_wandb=config.wandb_project is not None, - out_dir=out_dir, - spd_config=config, - sweep_params=sweep_params, - target_model=None, - train_config=None, - task_name=None, - ) - # --- Load Data --- # if is_main_process(): logger.info("Loading dataset...") @@ -171,24 +130,19 @@ def main( dist_state=dist_state, ) - if is_main_process(): - logger.info("Starting optimization...") - - optimize( + run_experiment( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - out_dir=out_dir, + experiment_tag="lm", + run_id=run_id, + launch_id=launch_id, + evals_id=evals_id, + sweep_params=parse_sweep_params(sweep_params_json), ) - if is_main_process(): - logger.info("Optimization finished.") - if config.wandb_project: - wandb.finish() - if __name__ == "__main__": fire.Fire(main) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index e5e5aed49..f47233a2e 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -1,61 +1,36 @@ """Residual MLP decomposition script.""" -import json from pathlib import Path import fire -import wandb -from spd.configs import Config, ResidMLPTaskConfig +from spd.configs import ResidMLPTaskConfig from spd.experiments.resid_mlp.models import ( ResidMLP, ResidMLPTargetRunInfo, ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger -from spd.run_spd import optimize +from spd.run_spd import run_experiment +from spd.settings import SPD_OUT_DIR from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import save_pre_run_info, set_seed -from spd.utils.run_utils import save_file, setup_decomposition_run -from spd.utils.wandb_utils import init_wandb +from spd.utils.general_utils import set_seed +from spd.utils.run_utils import generate_run_id, parse_config, parse_sweep_params, save_file def main( config_path: Path | str | None = None, config_json: str | None = None, evals_id: str | None = None, - sweep_id: str | None = None, + launch_id: str | None = None, sweep_params_json: str | None = None, + run_id: str | None = None, ) -> None: - assert (config_path is not None) != (config_json is not None), ( - "Need exactly one of config_path and config_json" - ) - if config_path is not None: - config = Config.from_file(config_path) - else: - assert config_json is not None - config = Config(**json.loads(config_json.removeprefix("json:"))) - - sweep_params = ( - None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) - ) + config = parse_config(config_path, config_json) set_seed(config.seed) - out_dir, run_id, tags = setup_decomposition_run( - experiment_tag="resid_mlp", evals_id=evals_id, sweep_id=sweep_id - ) - if config.wandb_project: - init_wandb( - config=config, - project=config.wandb_project, - run_id=run_id, - name=config.wandb_run_name, - tags=tags, - ) - logger.info(config) - device = get_device() logger.info(f"Using device: {device}") assert isinstance(config.task_config, ResidMLPTaskConfig) @@ -66,18 +41,11 @@ def main( target_model = target_model.to(device) target_model.eval() - save_pre_run_info( - save_to_wandb=config.wandb_project is not None, - out_dir=out_dir, - spd_config=config, - sweep_params=sweep_params, - target_model=target_model, - train_config=target_run_info.config, - task_name=config.task_config.task_name, - ) + # Domain-specific: save label coefficients to out_dir + run_id = run_id or generate_run_id("spd") + out_dir = SPD_OUT_DIR / "spd" / run_id + out_dir.mkdir(parents=True, exist_ok=True) save_file(target_run_info.label_coeffs.detach().cpu().tolist(), out_dir / "label_coeffs.json") - if config.wandb_project: - wandb.save(str(out_dir / "label_coeffs.json"), base_path=out_dir, policy="now") synced_inputs = target_run_info.config.synced_inputs dataset = ResidMLPDataset( @@ -98,21 +66,21 @@ def main( dataset, batch_size=config.eval_batch_size, shuffle=False ) - # TODO: Below not needed when TMS supports config.n_eval_steps assert config.n_eval_steps is not None, "n_eval_steps must be set" - optimize( + run_experiment( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - out_dir=out_dir, + experiment_tag="resid_mlp", + run_id=run_id, + launch_id=launch_id, + evals_id=evals_id, + sweep_params=parse_sweep_params(sweep_params_json), + target_model_train_config=target_run_info.config, ) - if config.wandb_project: - wandb.finish() - if __name__ == "__main__": fire.Fire(main) diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 553329c9c..85249d8d7 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -4,61 +4,35 @@ the losses of the "correct" solution during training. """ -import json from pathlib import Path import fire -import wandb -from spd.configs import Config, TMSTaskConfig +from spd.configs import TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger -from spd.run_spd import optimize +from spd.run_spd import run_experiment from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import save_pre_run_info, set_seed -from spd.utils.run_utils import setup_decomposition_run -from spd.utils.wandb_utils import init_wandb +from spd.utils.general_utils import set_seed +from spd.utils.run_utils import parse_config, parse_sweep_params def main( config_path: Path | str | None = None, config_json: str | None = None, evals_id: str | None = None, - sweep_id: str | None = None, + launch_id: str | None = None, sweep_params_json: str | None = None, + run_id: str | None = None, ) -> None: - assert (config_path is not None) != (config_json is not None), ( - "Need exactly one of config_path and config_json" - ) - if config_path is not None: - config = Config.from_file(config_path) - else: - assert config_json is not None - config = Config(**json.loads(config_json.removeprefix("json:"))) - - sweep_params = ( - None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:")) - ) + config = parse_config(config_path, config_json) device = get_device() logger.info(f"Using device: {device}") set_seed(config.seed) - out_dir, run_id, tags = setup_decomposition_run( - experiment_tag="tms", evals_id=evals_id, sweep_id=sweep_id - ) - if config.wandb_project: - init_wandb( - config=config, - project=config.wandb_project, - run_id=run_id, - name=config.wandb_run_name, - tags=tags, - ) - logger.info(config) - task_config = config.task_config assert isinstance(task_config, TMSTaskConfig) @@ -68,16 +42,6 @@ def main( target_model = target_model.to(device) target_model.eval() - save_pre_run_info( - save_to_wandb=config.wandb_project is not None, - out_dir=out_dir, - spd_config=config, - sweep_params=sweep_params, - target_model=target_model, - train_config=target_model.config, - task_name=config.task_config.task_name, - ) - synced_inputs = target_run_info.config.synced_inputs dataset = SparseFeatureDataset( n_features=target_model.config.n_features, @@ -96,20 +60,21 @@ def main( if target_model.config.tied_weights: tied_weights = [("linear1", "linear2")] - optimize( + run_experiment( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - out_dir=out_dir, + experiment_tag="tms", + run_id=run_id, + launch_id=launch_id, + evals_id=evals_id, + sweep_params=parse_sweep_params(sweep_params_json), + target_model_train_config=target_model.config, tied_weights=tied_weights, ) - if config.wandb_project: - wandb.finish() - if __name__ == "__main__": fire.Fire(main) diff --git a/spd/harvest/scripts/run_slurm.py b/spd/harvest/scripts/run_slurm.py index 6350aba89..801cf7cb4 100644 --- a/spd/harvest/scripts/run_slurm.py +++ b/spd/harvest/scripts/run_slurm.py @@ -56,8 +56,8 @@ def harvest( time: Job time limit for worker jobs. job_suffix: Optional suffix for SLURM job names (e.g., "v2" -> "spd-harvest-v2"). """ - run_id = f"harvest-{secrets.token_hex(4)}" - snapshot_branch, commit_hash = create_git_snapshot(run_id) + launch_id = f"harvest-{secrets.token_hex(4)}" + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id) logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") suffix = f"-{job_suffix}" if job_suffix else "" diff --git a/spd/run_spd.py b/spd/run_spd.py index 6b1075bfe..623315689 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -1,10 +1,11 @@ """Run SPD on a model.""" import gc +import os from collections import defaultdict from collections.abc import Iterator from pathlib import Path -from typing import cast +from typing import Any, cast import torch import torch.nn as nn @@ -18,6 +19,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from spd.base_config import BaseConfig from spd.configs import ( Config, LossMetricConfigType, @@ -33,6 +35,7 @@ from spd.losses import compute_total_loss from spd.metrics import faithfulness_loss from spd.models.component_model import ComponentModel, OutputWithCache +from spd.settings import SPD_OUT_DIR from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( avg_metrics_across_ranks, @@ -45,11 +48,12 @@ dict_safe_update_, extract_batch_data, get_scheduled_value, + save_pre_run_info, ) from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns -from spd.utils.run_utils import save_file -from spd.utils.wandb_utils import try_wandb +from spd.utils.run_utils import generate_run_id, save_file +from spd.utils.wandb_utils import init_wandb, try_wandb def run_faithfulness_warmup( @@ -376,3 +380,68 @@ def create_pgd_data_iter() -> ( if is_main_process(): logger.info("Finished training loop.") + + +def run_experiment( + target_model: nn.Module, + config: Config, + device: str, + train_loader: DataLoader[Int[Tensor, "..."]] + | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + eval_loader: DataLoader[Int[Tensor, "..."]] + | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + experiment_tag: str, + run_id: str | None = None, + launch_id: str | None = None, + evals_id: str | None = None, + sweep_params: dict[str, Any] | None = None, + target_model_train_config: BaseConfig | None = None, + tied_weights: list[tuple[str, str]] | None = None, +) -> None: + """Run a full SPD experiment: setup, optimize, cleanup. + + All ranks call this function. Only the main process does wandb/logging setup. + """ + if is_main_process(): + run_id = run_id or generate_run_id("spd") + out_dir = SPD_OUT_DIR / "spd" / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Run ID: {run_id}") + logger.info(f"Output directory: {out_dir}") + + tags = [str(i) for i in [experiment_tag, evals_id, launch_id] if i is not None] + slurm_array_job_id = os.getenv("SLURM_ARRAY_JOB_ID") + if slurm_array_job_id is not None: + tags.append(f"slurm-array-job-id_{slurm_array_job_id}") + + if config.wandb_project: + init_wandb(config, config.wandb_project, run_id, config.wandb_run_name, tags) + + logger.info(config) + + save_pre_run_info( + save_to_wandb=config.wandb_project is not None, + out_dir=out_dir, + spd_config=config, + sweep_params=sweep_params, + target_model=target_model if target_model_train_config is not None else None, + train_config=target_model_train_config, + task_name=getattr(config.task_config, "task_name", None), + ) + else: + out_dir = None + + optimize( + target_model=target_model, + config=config, + device=device, + train_loader=train_loader, + eval_loader=eval_loader, + n_eval_steps=config.n_eval_steps, + out_dir=out_dir, + tied_weights=tied_weights, + ) + + if is_main_process() and config.wandb_project: + wandb.finish() diff --git a/spd/scripts/run.py b/spd/scripts/run.py index 922b0a4d4..bcd85f978 100644 --- a/spd/scripts/run.py +++ b/spd/scripts/run.py @@ -25,9 +25,14 @@ create_slurm_array_script, ) from spd.utils.git_utils import create_git_snapshot -from spd.utils.run_utils import apply_nested_updates, generate_grid_combinations +from spd.utils.run_utils import apply_nested_updates, generate_grid_combinations, generate_run_id from spd.utils.slurm import submit_slurm_job -from spd.utils.wandb_utils import ReportCfg, create_view_and_report, generate_wandb_run_name +from spd.utils.wandb_utils import ( + ReportCfg, + create_view_and_report, + generate_wandb_run_name, + get_wandb_run_url, +) def launch_slurm_run( @@ -57,8 +62,8 @@ def launch_slurm_run( project: W&B project name """ - run_id = _generate_run_id() - logger.info(f"Run ID: {run_id}") + launch_id = _generate_launch_id() + logger.info(f"Launch ID: {launch_id}") experiments_list = _get_experiments(experiments) logger.info(f"Experiments: {', '.join(experiments_list)}") @@ -76,24 +81,25 @@ def launch_slurm_run( sweep_params=sweep_params, ) - snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id) logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") - _wandb_setup( - create_report=create_report, - report_title=report_title, - project=project, - run_id=run_id, - experiments_list=experiments_list, - snapshot_branch=snapshot_branch, - commit_hash=commit_hash, - ) + if len(training_jobs) > 1: + _create_wandb_views_and_report( + create_report=create_report, + report_title=report_title, + project=project, + launch_id=launch_id, + experiments_list=experiments_list, + snapshot_branch=snapshot_branch, + commit_hash=commit_hash, + ) slurm_job_name = f"spd-{job_suffix or get_max_expected_runtime(experiments_list)}" array_script_content = create_slurm_array_script( slurm_job_name=slurm_job_name, - run_id=run_id, + launch_id=launch_id, training_jobs=training_jobs, sweep_params=sweep_params, snapshot_branch=snapshot_branch, @@ -105,26 +111,33 @@ def launch_slurm_run( # Submit script (handles file writing, submission, renaming, and log file creation) result = submit_slurm_job( array_script_content, - f"run_array_{run_id}", + f"launch_array_{launch_id}", is_array=True, n_array_tasks=len(training_jobs), ) logger.section("Job submitted successfully!") - logger.values( - { - "Array Job ID": result.job_id, - "Total training jobs": len(training_jobs), - "Max concurrent tasks": n_agents, - "View logs in": result.log_pattern, - "Script": str(result.script_path), - } - ) - - -def _generate_run_id() -> str: - """Generate a unique run ID based on timestamp.""" - return f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + summary: dict[str, str | int | None] = { + "Array Job ID": result.job_id, + "Total training jobs": len(training_jobs), + "Max concurrent tasks": n_agents, + "View logs in": result.log_pattern, + "Script": str(result.script_path), + } + if len(training_jobs) <= 10: + urls = [get_wandb_run_url(project, job.run_id) for job in training_jobs] + summary["WandB run URLs"] = ( + urls[0] if len(urls) == 1 else "\n" + "\n".join(f" - {u}" for u in urls) + ) + logger.values(summary) + + +def _generate_launch_id() -> str: + """Generate a unique launch ID based on timestamp. + + Prefixed with 'launch-' to prevent Python Fire from parsing the numeric timestamp as an int. + """ + return f"launch-{datetime.now().strftime('%Y%m%d_%H%M%S')}" def _create_training_jobs( @@ -161,6 +174,7 @@ def _create_training_jobs( experiment=experiment, script_path=exp_config.decomp_script, config=config_with_overrides, + run_id=generate_run_id("spd"), ) ) task_breakdown[experiment] = "1 job" @@ -185,6 +199,7 @@ def _create_training_jobs( experiment=experiment, script_path=exp_config.decomp_script, config=config_with_overrides, + run_id=generate_run_id("spd"), ) ) @@ -310,11 +325,11 @@ def _resolve_sweep_params_path(sweep_params_file: str) -> Path: return REPO_ROOT / sweep_params_file -def _wandb_setup( +def _create_wandb_views_and_report( create_report: bool, report_title: str | None, project: str, - run_id: str, + launch_id: str, experiments_list: list[str], snapshot_branch: str, commit_hash: str, @@ -339,7 +354,7 @@ def _wandb_setup( create_view_and_report( project=project, - run_id=run_id, + launch_id=launch_id, experiments=experiments_list, report_cfg=report_cfg, ) diff --git a/spd/utils/compute_utils.py b/spd/utils/compute_utils.py index 01be254d9..7c4f0d55f 100644 --- a/spd/utils/compute_utils.py +++ b/spd/utils/compute_utils.py @@ -28,6 +28,7 @@ class TrainingJob: experiment: str script_path: Path config: Config + run_id: str # Pre-generated unique run identifier (e.g. "s-a1b2c3d4") def _choose_master_port(run_id_local: str, idx: int) -> int: @@ -43,7 +44,7 @@ def _choose_master_port(run_id_local: str, idx: int) -> int: def _build_script_args( - run_id: str, + launch_id: str, job: TrainingJob, sweep_params: dict[str, Any] | None, ) -> str: @@ -51,8 +52,9 @@ def _build_script_args( json_tagged_config = f"json:{json.dumps(job.config.model_dump(mode='json'))}" args = ( f"--config_json {shlex.quote(json_tagged_config)} " - f"--sweep_id {run_id} " - f"--evals_id {job.experiment}" + f"--launch_id {launch_id} " + f"--evals_id {job.experiment} " + f"--run_id {job.run_id}" ) if sweep_params is not None: json_tagged_sweep_params = f"json:{json.dumps(sweep_params)}" @@ -61,7 +63,7 @@ def _build_script_args( def get_command( - run_id: str, + launch_id: str, job: TrainingJob, job_idx: int, n_gpus: int | None, @@ -71,7 +73,7 @@ def get_command( """Build the command to run a training job. Args: - run_id: Unique identifier for the run. + launch_id: Launch identifier for this group of jobs. job: The training job to run. job_idx: Index of the job in the run. n_gpus: Number of GPUs. None or 1 means single GPU/CPU. 2-8 means single-node DDP. @@ -79,8 +81,8 @@ def get_command( sweep_params: Optional sweep parameters to pass to the job. snapshot_branch: Git branch to checkout (used for multi-node workspace setup). """ - port = _choose_master_port(run_id, job_idx) - script_args = _build_script_args(run_id, job, sweep_params) + port = _choose_master_port(launch_id, job_idx) + script_args = _build_script_args(launch_id, job, sweep_params) match n_gpus: case None | 1: @@ -120,7 +122,7 @@ def get_command( def create_slurm_array_script( slurm_job_name: str, - run_id: str, + launch_id: str, training_jobs: list[TrainingJob], sweep_params: dict[str, Any] | None, snapshot_branch: str, @@ -135,7 +137,7 @@ def create_slurm_array_script( Args: slurm_job_name: Name for the SLURM job array - run_id: Unique identifier for the run. + launch_id: Launch identifier for this group of jobs. training_jobs: List of training jobs to execute. sweep_params: Optional sweep parameters to pass to the jobs. snapshot_branch: Git branch to checkout. @@ -148,7 +150,7 @@ def create_slurm_array_script( commands: list[str] = [] for i, training_job in enumerate(training_jobs): cmd = get_command( - run_id, + launch_id, training_job, i, n_gpus, diff --git a/spd/utils/git_utils.py b/spd/utils/git_utils.py index df438930a..e5d533069 100644 --- a/spd/utils/git_utils.py +++ b/spd/utils/git_utils.py @@ -41,13 +41,17 @@ def repo_current_commit_hash() -> str: return commit_hash -def create_git_snapshot(run_id: str) -> tuple[str, str]: +def create_git_snapshot(snapshot_id: str) -> tuple[str, str]: """Create a git snapshot branch with current changes. Creates a timestamped branch containing all current changes (staged and unstaged). Uses a temporary worktree to avoid affecting the current working directory. Will push the snapshot branch to origin if possible, but will continue without error if push permissions are lacking. + Args: + snapshot_id: Identifier used in the branch name and commit message (e.g. a launch_id + or run_id). + Returns: (branch_name, commit_hash) where commit_hash is the HEAD of the snapshot branch (this will be the new snapshot commit if changes existed, otherwise the base commit). @@ -56,11 +60,11 @@ def create_git_snapshot(run_id: str) -> tuple[str, str]: subprocess.CalledProcessError: If git commands fail (except for push) """ # prefix branch name - snapshot_branch: str = f"snapshot/{run_id}" + snapshot_branch: str = f"snapshot/{snapshot_id}" # Create temporary worktree path with tempfile.TemporaryDirectory() as temp_dir: - worktree_path = Path(temp_dir) / f"spd-snapshot-{run_id}" + worktree_path = Path(temp_dir) / f"spd-snapshot-{snapshot_id}" try: # Create worktree with new branch @@ -97,7 +101,7 @@ def create_git_snapshot(run_id: str) -> tuple[str, str]: # Commit changes if any exist if diff_result.returncode != 0: # Non-zero means there are changes subprocess.run( - ["git", "commit", "-m", f"run id {run_id}", "--no-verify"], + ["git", "commit", "-m", f"snapshot {snapshot_id}", "--no-verify"], cwd=worktree_path, check=True, capture_output=True, diff --git a/spd/utils/run_utils.py b/spd/utils/run_utils.py index 080d2545a..813c30e50 100644 --- a/spd/utils/run_utils.py +++ b/spd/utils/run_utils.py @@ -13,6 +13,7 @@ import torch import yaml +from spd.configs import Config from spd.log import logger from spd.settings import SPD_OUT_DIR from spd.utils.git_utils import ( @@ -324,22 +325,39 @@ def generate_grid_combinations(parameters: dict[str, Any]) -> list[dict[str, Any } +def generate_run_id(run_type: RunType) -> str: + """Generate a unique run identifier. + + Format: `{type_abbr}-{random_hex}` + """ + type_abbr = RUN_TYPE_ABBREVIATIONS[run_type] + return f"{type_abbr}-{secrets.token_hex(4)}" + + +def parse_config(config_path: Path | str | None, config_json: str | None) -> Config: + """Parse a Config from either a file path or a JSON string. Exactly one must be provided.""" + assert (config_path is not None) != (config_json is not None), ( + "Need exactly one of config_path and config_json" + ) + if config_path is not None: + return Config.from_file(config_path) + assert config_json is not None + return Config(**json.loads(config_json.removeprefix("json:"))) + + +def parse_sweep_params(sweep_params_json: str | None) -> dict[str, Any] | None: + """Parse sweep parameters from a JSON string, or return None if not provided.""" + if sweep_params_json is None: + return None + return json.loads(sweep_params_json.removeprefix("json:")) + + class ExecutionStamp(NamedTuple): run_id: str snapshot_branch: str commit_hash: str run_type: RunType - @staticmethod - def _generate_run_id(run_type: RunType) -> str: - """Generate a unique run identifier, - - Format: `{type_abbr}-{random_hex}` - """ - type_abbr: str = RUN_TYPE_ABBREVIATIONS[run_type] - random_hex: str = secrets.token_hex(4) - return f"{type_abbr}-{random_hex}" - @classmethod def create( cls, @@ -347,12 +365,12 @@ def create( create_snapshot: bool, ) -> "ExecutionStamp": """Create an execution stamp, possibly including a git snapshot branch.""" - run_id = ExecutionStamp._generate_run_id(run_type) + run_id = generate_run_id(run_type) snapshot_branch: str commit_hash: str if create_snapshot: - snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) + snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=run_id) logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") else: snapshot_branch = repo_current_branch() @@ -380,38 +398,6 @@ def out_dir(self) -> Path: return run_dir -def setup_decomposition_run( - experiment_tag: str, - evals_id: str | None = None, - sweep_id: str | None = None, -) -> tuple[Path, str, list[str]]: - """Set up run infrastructure for a decomposition experiment. - - Creates execution stamp, logs run info, and builds W&B tags. - Should only be called on main process for distributed training. - - Args: - experiment_tag: Tag for the experiment type (e.g., "lm", "tms", "resid_mlp") - evals_id: Optional evaluation identifier to add as W&B tag - sweep_id: Optional sweep identifier to add as W&B tag - - Returns: - Tuple of (output directory, run_id, tags for W&B). - """ - execution_stamp = ExecutionStamp.create(run_type="spd", create_snapshot=False) - out_dir = execution_stamp.out_dir - logger.info(f"Run ID: {execution_stamp.run_id}") - logger.info(f"Output directory: {out_dir}") - - tags = [i for i in [experiment_tag, evals_id, sweep_id] if i is not None] - slurm_array_job_id = os.getenv("SLURM_ARRAY_JOB_ID", None) - if slurm_array_job_id is not None: - logger.info(f"Running on slurm array job id: {slurm_array_job_id}") - tags.append(f"slurm-array-job-id_{slurm_array_job_id}") - - return out_dir, execution_stamp.run_id, tags - - _NO_ARG_PARSSED_SENTINEL = object() diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 04cb6727c..62ed7de9c 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -69,6 +69,23 @@ } +def get_wandb_entity() -> str: + """Get the WandB entity from env var or the authenticated user's default entity.""" + load_dotenv(override=True) + entity = os.getenv("WANDB_ENTITY") + if entity is None: + entity = wandb.Api().default_entity + assert entity is not None, ( + "Could not determine WandB entity. Set WANDB_ENTITY in .env or log in with `wandb login`." + ) + return entity + + +def get_wandb_run_url(project: str, run_id: str) -> str: + """Get the direct WandB URL for a run.""" + return f"https://wandb.ai/{get_wandb_entity()}/{project}/runs/{run_id}" + + def _parse_metric_config_key(key: str) -> tuple[str, str, str] | None: """Parse a metric config key into (list_field, classname, param). @@ -290,12 +307,10 @@ def init_wandb( name: The name of the wandb run. tags: Optional list of tags to add to the run. """ - load_dotenv(override=True) - wandb.init( id=run_id, project=project, - entity=os.getenv("WANDB_ENTITY"), + entity=get_wandb_entity(), name=name, tags=tags, ) @@ -328,7 +343,7 @@ def ensure_project_exists(project: str) -> None: logger.info(f"Project '{project}' created successfully") -def create_workspace_view(run_id: str, experiment_name: str, project: str) -> str: +def create_workspace_view(launch_id: str, experiment_name: str, project: str) -> str: """Create a wandb workspace view for an experiment.""" # Use experiment-specific template if available template_url: str = WORKSPACE_TEMPLATES.get(experiment_name, WORKSPACE_TEMPLATES["default"]) @@ -338,12 +353,11 @@ def create_workspace_view(run_id: str, experiment_name: str, project: str) -> st workspace.project = project # Update the workspace name - workspace.name = f"{experiment_name} - {run_id}" + workspace.name = f"{experiment_name} - {launch_id}" - # Filter for runs that have BOTH the run_id AND experiment name tags - # Create filter using the same pattern as in run_grid_search.py + # Filter for runs that have BOTH the launch_id AND experiment name tags workspace.runset_settings.filters = [ - ws.Tags("tags").isin([run_id]), + ws.Tags("tags").isin([launch_id]), ws.Tags("tags").isin([experiment_name]), ] @@ -355,7 +369,7 @@ def create_workspace_view(run_id: str, experiment_name: str, project: str) -> st def create_wandb_report( report_title: str, - run_id: str, + launch_id: str, branch_name: str, commit_hash: str | None, experiments: list[str], @@ -363,7 +377,7 @@ def create_wandb_report( project: str, report_total_width: int = 24, ) -> str: - """Create a W&B report for the run.""" + """Create a W&B report for the launch.""" report = wr.Report( project=project, title=report_title, @@ -379,8 +393,10 @@ def create_wandb_report( for experiment in experiments: task_name: str = EXPERIMENT_REGISTRY[experiment].task_name - # Use run_id and experiment name tags for filtering - combined_filter = f'(Tags("tags") in ["{run_id}"]) and (Tags("tags") in ["{experiment}"])' + # Use launch_id and experiment name tags for filtering + combined_filter = ( + f'(Tags("tags") in ["{launch_id}"]) and (Tags("tags") in ["{experiment}"])' + ) # Create runset for this specific experiment runset = wr.Runset( @@ -545,7 +561,7 @@ class ReportCfg: def create_view_and_report( project: str, - run_id: str, + launch_id: str, experiments: list[str], report_cfg: ReportCfg | None, ) -> None: @@ -553,7 +569,7 @@ def create_view_and_report( Args: project: W&B project name - run_id: Unique run identifier + launch_id: Launch identifier for this group of jobs experiments: List of experiment names to create views for report_cfg: How to set up a wandb view, and optionally a report for the run, if at all. """ @@ -564,15 +580,15 @@ def create_view_and_report( logger.section("Creating workspace views...") workspace_urls: dict[str, str] = {} for experiment in experiments: - workspace_url = create_workspace_view(run_id, experiment, project) + workspace_url = create_workspace_view(launch_id, experiment, project) workspace_urls[experiment] = workspace_url # Create report if requested report_url: str | None = None if report_cfg is not None and len(experiments) > 1: report_url = create_wandb_report( - report_title=report_cfg.report_title or f"SPD Run Report - {run_id}", - run_id=run_id, + report_title=report_cfg.report_title or f"SPD Launch Report - {launch_id}", + launch_id=launch_id, branch_name=report_cfg.branch, commit_hash=report_cfg.commit_hash, experiments=experiments, diff --git a/tests/scripts_run/test_main.py b/tests/scripts_run/test_main.py index a6b1ea69e..912eeafa4 100644 --- a/tests/scripts_run/test_main.py +++ b/tests/scripts_run/test_main.py @@ -28,10 +28,10 @@ def test_invalid_experiment_name(self): @patch("spd.scripts.run.submit_slurm_job") @patch("spd.scripts.run.create_slurm_array_script") @patch("spd.scripts.run.create_git_snapshot") - @patch("spd.scripts.run._wandb_setup") + @patch("spd.scripts.run._create_wandb_views_and_report") def test_sweep_creates_slurm_array( self, - mock_wandb_setup, + mock_create_wandb_views_and_report, mock_create_git_snapshot, mock_create_slurm_array_script, mock_submit_slurm_job,