Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions spd/app/backend/routers/pretrain_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,20 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None)

n_layer = target_model_config.get("n_layer")
n_embd = target_model_config.get("n_embd")
n_intermediate = target_model_config.get("n_intermediate")
n_head = target_model_config.get("n_head")
n_kv = target_model_config.get("n_key_value_heads")
vocab = target_model_config.get("vocab_size")
ctx = target_model_config.get("n_ctx")

if n_layer is not None:
parts.append(f"{n_layer}L")
dims = []
if n_embd is not None:
dims.append(f"d={n_embd}")
if n_intermediate is not None:
dims.append(f"ff={n_intermediate}")
if dims:
parts.append(" ".join(dims))
parts.append(f"d={n_embd}")
heads = []
if n_head is not None:
heads.append(f"{n_head}h")
if n_kv is not None and n_kv != n_head:
heads.append(f"{n_kv}kv")
if heads:
parts.append("/".join(heads))
meta = []
if vocab is not None:
meta.append(f"vocab={vocab}")
if ctx is not None:
meta.append(f"ctx={ctx}")
if meta:
parts.append(" ".join(meta))

return " · ".join(parts)

Expand Down
105 changes: 104 additions & 1 deletion spd/app/backend/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from spd.app.backend.state import RunState
from spd.app.backend.utils import log_errors
from spd.autointerp.repo import InterpRepo
from spd.autointerp.schemas import AUTOINTERP_DATA_DIR
from spd.configs import LMTaskConfig
from spd.dataset_attributions.repo import AttributionRepo
from spd.dataset_attributions.repo import AttributionRepo, get_attributions_dir
from spd.harvest.repo import HarvestRepo
from spd.log import logger
from spd.models.component_model import ComponentModel, SPDRunInfo
Expand Down Expand Up @@ -44,6 +45,19 @@ class LoadedRun(BaseModel):
dataset_search_enabled: bool


class DiscoveredRun(BaseModel):
    """Summary of a locally discovered SPD run that has autointerp labels."""

    run_id: str  # wandb run id; taken from the directory name under AUTOINTERP_DATA_DIR
    n_labels: int  # number of autointerp interpretations recorded for the run
    has_harvest: bool  # a harvest repo could be opened for the run
    has_detection: bool  # "detection" is among the available score types
    has_fuzzing: bool  # "fuzzing" is among the available score types
    has_intruder: bool  # the most recent harvest has intruder scores
    has_dataset_attributions: bool  # a "da-" attribution subdirectory exists
    model_type: str | None  # best-effort arch info; None when lookup failed
    arch_summary: str | None  # best-effort arch summary; None when lookup failed
    created_at: str | None  # wandb creation timestamp; populated after a batch fetch


router = APIRouter(prefix="/api", tags=["runs"])

DEVICE = get_device()
Expand Down Expand Up @@ -180,3 +194,92 @@ def health_check() -> dict[str, str]:
def whoami() -> dict[str, str]:
"""Return the current backend user."""
return {"user": getpass.getuser()}


def _try_get_arch_info(run_id: str) -> tuple[str | None, str | None]:
    """Best-effort lookup of (model_type, arch_summary) for a run.

    May download the run's config from wandb. Any failure is logged at
    debug level and reported as ``(None, None)``.
    """
    # Imported lazily to avoid a circular import between routers.
    from spd.app.backend.routers.pretrain_info import (
        _get_pretrain_info,
        _load_spd_config_lightweight,
    )

    model_type: str | None = None
    summary: str | None = None
    try:
        cfg = _load_spd_config_lightweight(f"goodfire/spd/{run_id}")
        pretrain = _get_pretrain_info(cfg)
        model_type, summary = pretrain.model_type, pretrain.summary
    except Exception:
        logger.debug(f"[discover] Failed to get arch info for {run_id}", exc_info=True)
    return model_type, summary


def _fetch_wandb_created_dates(run_ids: list[str]) -> dict[str, str]:
    """Fetch ``created_at`` timestamps from wandb for a batch of runs.

    Best-effort: runs whose lookup fails are omitted from the result.

    Args:
        run_ids: wandb run ids within the ``goodfire/spd`` project.

    Returns:
        Mapping of run_id -> creation timestamp string for every run
        that could be resolved.
    """
    import wandb

    api = wandb.Api()
    dates: dict[str, str] = {}
    for run_id in run_ids:
        try:
            r = api.run(f"goodfire/spd/{run_id}")
            dates[run_id] = r.created_at
        except Exception:
            # Include the traceback, matching how other best-effort
            # discovery failures are logged in this module.
            logger.debug(f"[discover] Failed to fetch wandb date for {run_id}", exc_info=True)
    return dates


@router.get("/runs/discover")
@log_errors
def discover_runs() -> list[DiscoveredRun]:
    """Scan SPD_OUT_DIR for all runs that have autointerp labels.

    Returns runs sorted by wandb creation date (newest first). Includes arch
    info from local cache.
    """
    if not AUTOINTERP_DATA_DIR.exists():
        return []

    runs: list[DiscoveredRun] = []
    for run_dir in AUTOINTERP_DATA_DIR.iterdir():
        if not run_dir.is_dir():
            continue
        run_id = run_dir.name

        interp = InterpRepo.open(run_id)
        if interp is None:
            continue

        n_labels = interp.get_interpretation_count()
        if n_labels == 0:
            # Only surface runs that actually have autointerp labels.
            continue

        score_types = interp.get_available_score_types()

        harvest = HarvestRepo.open_most_recent(run_id)
        has_intruder = bool(harvest.get_scores("intruder")) if harvest is not None else False

        # Compute the attributions dir once rather than re-deriving it
        # for the exists() check and again for the iterdir() scan.
        attrs_dir = get_attributions_dir(run_id)
        has_ds_attrs = attrs_dir.exists() and any(
            d.is_dir() and d.name.startswith("da-") for d in attrs_dir.iterdir()
        )

        model_type, arch_summary = _try_get_arch_info(run_id)

        runs.append(
            DiscoveredRun(
                run_id=run_id,
                n_labels=n_labels,
                has_harvest=harvest is not None,
                has_detection="detection" in score_types,
                has_fuzzing="fuzzing" in score_types,
                has_intruder=has_intruder,
                has_dataset_attributions=has_ds_attrs,
                model_type=model_type,
                arch_summary=arch_summary,
                created_at=None,  # filled in below after the batched wandb fetch
            )
        )

    # Resolve creation dates in one batch after the scan, then attach them.
    dates = _fetch_wandb_created_dates([r.run_id for r in runs])
    for r in runs:
        r.created_at = dates.get(r.run_id)
    # Timestamps sort lexicographically; undated runs fall back to "" and
    # therefore land at the end of the newest-first ordering.
    runs.sort(key=lambda r: r.created_at or "", reverse=True)
    return runs
Loading
Loading