Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spd/dataset_attributions/scripts/run_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def submit_attributions(
time: Job time limit.
job_suffix: Optional suffix for SLURM job names (e.g., "1h" -> "spd-attr-1h").
"""
run_id = f"attr-{secrets.token_hex(4)}"
snapshot_branch, commit_hash = create_git_snapshot(run_id)
launch_id = f"attr-{secrets.token_hex(4)}"
snapshot_branch, commit_hash = create_git_snapshot(snapshot_id=launch_id)
logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})")

suffix = f"-{job_suffix}" if job_suffix else ""
Expand Down
77 changes: 17 additions & 60 deletions spd/experiments/ih/ih_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,43 @@
import json
"""Induction head decomposition script."""

from pathlib import Path

import fire
import wandb

from spd.configs import Config, IHTaskConfig
from spd.configs import IHTaskConfig
from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer
from spd.log import logger
from spd.run_spd import optimize
from spd.run_spd import run_experiment
from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset
from spd.utils.distributed_utils import get_device
from spd.utils.general_utils import (
save_pre_run_info,
set_seed,
)
from spd.utils.run_utils import ExecutionStamp
from spd.utils.wandb_utils import init_wandb
from spd.utils.general_utils import set_seed
from spd.utils.run_utils import parse_config, parse_sweep_params


def main(
config_path: Path | str | None = None,
config_json: str | None = None,
evals_id: str | None = None,
sweep_id: str | None = None,
launch_id: str | None = None,
sweep_params_json: str | None = None,
run_id: str | None = None,
) -> None:
assert (config_path is not None) != (config_json is not None), (
"Need exactly one of config_path and config_json"
)
if config_path is not None:
config = Config.from_file(config_path)
else:
assert config_json is not None
config = Config(**json.loads(config_json.removeprefix("json:")))

sweep_params = (
None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:"))
)
config = parse_config(config_path, config_json)

device = get_device()
logger.info(f"Using device: {device}")

execution_stamp = ExecutionStamp.create(run_type="spd", create_snapshot=False)
out_dir = execution_stamp.out_dir
logger.info(f"Run ID: {execution_stamp.run_id}")
logger.info(f"Output directory: {out_dir}")

if config.wandb_project:
tags = ["ih"]
if evals_id:
tags.append(evals_id)
if sweep_id:
tags.append(sweep_id)
init_wandb(
config=config,
project=config.wandb_project,
run_id=execution_stamp.run_id,
name=config.wandb_run_name,
tags=tags,
)
set_seed(config.seed)

task_config = config.task_config
assert isinstance(task_config, IHTaskConfig)

set_seed(config.seed)
logger.info(config)

assert config.pretrained_model_path, "pretrained_model_path must be set"
target_run_info = InductionModelTargetRunInfo.from_path(config.pretrained_model_path)
target_model = InductionTransformer.from_run_info(target_run_info)
target_model = target_model.to(device)
target_model.eval()

save_pre_run_info(
save_to_wandb=config.wandb_project is not None,
out_dir=out_dir,
spd_config=config,
sweep_params=sweep_params,
target_model=target_model,
train_config=target_run_info.config,
task_name=config.task_config.task_name,
)

prefix_window = task_config.prefix_window or target_model.config.seq_len - 3

dataset = InductionDataset(
Expand All @@ -93,19 +49,20 @@ def main(
train_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False)
eval_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False)

optimize(
run_experiment(
target_model=target_model,
config=config,
device=device,
train_loader=train_loader,
eval_loader=eval_loader,
n_eval_steps=config.n_eval_steps,
out_dir=out_dir,
experiment_tag="ih",
run_id=run_id,
launch_id=launch_id,
evals_id=evals_id,
sweep_params=parse_sweep_params(sweep_params_json),
target_model_train_config=target_run_info.config,
)

if config.wandb_project:
wandb.finish()


if __name__ == "__main__":
fire.Fire(main)
72 changes: 13 additions & 59 deletions spd/experiments/lm/lm_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""Language Model decomposition script."""

import json
from pathlib import Path

import fire
import wandb

from spd.configs import Config, LMTaskConfig
from spd.configs import LMTaskConfig
from spd.data import DatasetConfig, create_data_loader
from spd.log import logger
from spd.pretrain.run_info import PretrainRunInfo
from spd.run_spd import optimize
from spd.run_spd import run_experiment
from spd.utils.distributed_utils import (
DistributedState,
ensure_cached_and_call,
Expand All @@ -19,54 +17,27 @@
is_main_process,
with_distributed_cleanup,
)
from spd.utils.general_utils import resolve_class, save_pre_run_info, set_seed
from spd.utils.run_utils import setup_decomposition_run
from spd.utils.wandb_utils import init_wandb
from spd.utils.general_utils import resolve_class, set_seed
from spd.utils.run_utils import parse_config, parse_sweep_params


@with_distributed_cleanup
def main(
config_path: Path | str | None = None,
config_json: str | None = None,
evals_id: str | None = None,
sweep_id: str | None = None,
launch_id: str | None = None,
sweep_params_json: str | None = None,
run_id: str | None = None,
) -> None:
assert (config_path is not None) != (config_json is not None), (
"Need exactly one of config_path and config_json"
)
if config_path is not None:
config = Config.from_file(config_path)
else:
assert config_json is not None
config = Config(**json.loads(config_json.removeprefix("json:")))
config = parse_config(config_path, config_json)

dist_state = init_distributed()
logger.info(f"Distributed state: {dist_state}")

sweep_params = (
None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:"))
)

# Use the same seed across all ranks for deterministic data loading
set_seed(config.seed)

if is_main_process():
out_dir, run_id, tags = setup_decomposition_run(
experiment_tag="lm", evals_id=evals_id, sweep_id=sweep_id
)
if config.wandb_project:
init_wandb(
config=config,
project=config.wandb_project,
run_id=run_id,
name=config.wandb_run_name,
tags=tags,
)
logger.info(config)
else:
out_dir = None

device = get_device()
assert isinstance(config.task_config, LMTaskConfig), "task_config not LMTaskConfig"

Expand Down Expand Up @@ -96,18 +67,6 @@ def main(
)
target_model.eval()

if is_main_process():
assert out_dir is not None
save_pre_run_info(
save_to_wandb=config.wandb_project is not None,
out_dir=out_dir,
spd_config=config,
sweep_params=sweep_params,
target_model=None,
train_config=None,
task_name=None,
)

# --- Load Data --- #
if is_main_process():
logger.info("Loading dataset...")
Expand Down Expand Up @@ -171,24 +130,19 @@ def main(
dist_state=dist_state,
)

if is_main_process():
logger.info("Starting optimization...")

optimize(
run_experiment(
target_model=target_model,
config=config,
device=device,
train_loader=train_loader,
eval_loader=eval_loader,
n_eval_steps=config.n_eval_steps,
out_dir=out_dir,
experiment_tag="lm",
run_id=run_id,
launch_id=launch_id,
evals_id=evals_id,
sweep_params=parse_sweep_params(sweep_params_json),
)

if is_main_process():
logger.info("Optimization finished.")
if config.wandb_project:
wandb.finish()


if __name__ == "__main__":
fire.Fire(main)
70 changes: 19 additions & 51 deletions spd/experiments/resid_mlp/resid_mlp_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,36 @@
"""Residual MLP decomposition script."""

import json
from pathlib import Path

import fire
import wandb

from spd.configs import Config, ResidMLPTaskConfig
from spd.configs import ResidMLPTaskConfig
from spd.experiments.resid_mlp.models import (
ResidMLP,
ResidMLPTargetRunInfo,
)
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset
from spd.log import logger
from spd.run_spd import optimize
from spd.run_spd import run_experiment
from spd.settings import SPD_OUT_DIR
from spd.utils.data_utils import DatasetGeneratedDataLoader
from spd.utils.distributed_utils import get_device
from spd.utils.general_utils import save_pre_run_info, set_seed
from spd.utils.run_utils import save_file, setup_decomposition_run
from spd.utils.wandb_utils import init_wandb
from spd.utils.general_utils import set_seed
from spd.utils.run_utils import generate_run_id, parse_config, parse_sweep_params, save_file


def main(
config_path: Path | str | None = None,
config_json: str | None = None,
evals_id: str | None = None,
sweep_id: str | None = None,
launch_id: str | None = None,
sweep_params_json: str | None = None,
run_id: str | None = None,
) -> None:
assert (config_path is not None) != (config_json is not None), (
"Need exactly one of config_path and config_json"
)
if config_path is not None:
config = Config.from_file(config_path)
else:
assert config_json is not None
config = Config(**json.loads(config_json.removeprefix("json:")))

sweep_params = (
None if sweep_params_json is None else json.loads(sweep_params_json.removeprefix("json:"))
)
config = parse_config(config_path, config_json)

set_seed(config.seed)

out_dir, run_id, tags = setup_decomposition_run(
experiment_tag="resid_mlp", evals_id=evals_id, sweep_id=sweep_id
)
if config.wandb_project:
init_wandb(
config=config,
project=config.wandb_project,
run_id=run_id,
name=config.wandb_run_name,
tags=tags,
)
logger.info(config)

device = get_device()
logger.info(f"Using device: {device}")
assert isinstance(config.task_config, ResidMLPTaskConfig)
Expand All @@ -66,18 +41,11 @@ def main(
target_model = target_model.to(device)
target_model.eval()

save_pre_run_info(
save_to_wandb=config.wandb_project is not None,
out_dir=out_dir,
spd_config=config,
sweep_params=sweep_params,
target_model=target_model,
train_config=target_run_info.config,
task_name=config.task_config.task_name,
)
# Domain-specific: save label coefficients to out_dir
run_id = run_id or generate_run_id("spd")
out_dir = SPD_OUT_DIR / "spd" / run_id
out_dir.mkdir(parents=True, exist_ok=True)
save_file(target_run_info.label_coeffs.detach().cpu().tolist(), out_dir / "label_coeffs.json")
if config.wandb_project:
wandb.save(str(out_dir / "label_coeffs.json"), base_path=out_dir, policy="now")

synced_inputs = target_run_info.config.synced_inputs
dataset = ResidMLPDataset(
Expand All @@ -98,21 +66,21 @@ def main(
dataset, batch_size=config.eval_batch_size, shuffle=False
)

# TODO: Below not needed when TMS supports config.n_eval_steps
assert config.n_eval_steps is not None, "n_eval_steps must be set"
optimize(
run_experiment(
target_model=target_model,
config=config,
device=device,
train_loader=train_loader,
eval_loader=eval_loader,
n_eval_steps=config.n_eval_steps,
out_dir=out_dir,
experiment_tag="resid_mlp",
run_id=run_id,
launch_id=launch_id,
evals_id=evals_id,
sweep_params=parse_sweep_params(sweep_params_json),
target_model_train_config=target_run_info.config,
)

if config.wandb_project:
wandb.finish()


if __name__ == "__main__":
fire.Fire(main)
Loading