diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py index 2872423c9..1cf0d2258 100644 --- a/spd/app/backend/routers/pretrain_info.py +++ b/spd/app/backend/routers/pretrain_info.py @@ -128,21 +128,13 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) n_layer = target_model_config.get("n_layer") n_embd = target_model_config.get("n_embd") - n_intermediate = target_model_config.get("n_intermediate") n_head = target_model_config.get("n_head") n_kv = target_model_config.get("n_key_value_heads") - vocab = target_model_config.get("vocab_size") - ctx = target_model_config.get("n_ctx") if n_layer is not None: parts.append(f"{n_layer}L") - dims = [] if n_embd is not None: - dims.append(f"d={n_embd}") - if n_intermediate is not None: - dims.append(f"ff={n_intermediate}") - if dims: - parts.append(" ".join(dims)) + parts.append(f"d={n_embd}") heads = [] if n_head is not None: heads.append(f"{n_head}h") @@ -150,13 +142,6 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) heads.append(f"{n_kv}kv") if heads: parts.append("/".join(heads)) - meta = [] - if vocab is not None: - meta.append(f"vocab={vocab}") - if ctx is not None: - meta.append(f"ctx={ctx}") - if meta: - parts.append(" ".join(meta)) return " ยท ".join(parts) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 0989cea54..27dac0865 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -13,8 +13,9 @@ from spd.app.backend.state import RunState from spd.app.backend.utils import log_errors from spd.autointerp.repo import InterpRepo +from spd.autointerp.schemas import AUTOINTERP_DATA_DIR from spd.configs import LMTaskConfig -from spd.dataset_attributions.repo import AttributionRepo +from spd.dataset_attributions.repo import AttributionRepo, get_attributions_dir from spd.harvest.repo import HarvestRepo from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -44,6 +45,19 @@ class LoadedRun(BaseModel): dataset_search_enabled: bool +class DiscoveredRun(BaseModel): + run_id: str + n_labels: int + has_harvest: bool + has_detection: bool + has_fuzzing: bool + has_intruder: bool + has_dataset_attributions: bool + model_type: str | None + arch_summary: str | None + created_at: str | None + + router = APIRouter(prefix="/api", tags=["runs"]) DEVICE = get_device() @@ -180,3 +194,92 @@ def health_check() -> dict[str, str]: def whoami() -> dict[str, str]: """Return the current backend user.""" return {"user": getpass.getuser()} + + +def _try_get_arch_info(run_id: str) -> tuple[str | None, str | None]: + """Get model type and arch summary for a run. Downloads config from wandb if needed.""" + from spd.app.backend.routers.pretrain_info import ( + _get_pretrain_info, + _load_spd_config_lightweight, + ) + + try: + spd_config = _load_spd_config_lightweight(f"goodfire/spd/{run_id}") + info = _get_pretrain_info(spd_config) + return info.model_type, info.summary + except Exception: + logger.debug(f"[discover] Failed to get arch info for {run_id}", exc_info=True) + return None, None + + +def _fetch_wandb_created_dates(run_ids: list[str]) -> dict[str, str]: + """Fetch created_at timestamps from wandb for a batch of runs.""" + import wandb + + api = wandb.Api() + dates: dict[str, str] = {} + for run_id in run_ids: + try: + r = api.run(f"goodfire/spd/{run_id}") + dates[run_id] = r.created_at + except Exception: + logger.debug(f"[discover] Failed to fetch wandb date for {run_id}") + return dates + + +@router.get("/runs/discover") +@log_errors +def discover_runs() -> list[DiscoveredRun]: + """Scan SPD_OUT_DIR for all runs that have autointerp labels. + + Returns runs sorted by wandb creation date (newest first). Includes arch + info from local cache. + """ + if not AUTOINTERP_DATA_DIR.exists(): + return [] + + runs: list[DiscoveredRun] = [] + for run_dir in AUTOINTERP_DATA_DIR.iterdir(): + if not run_dir.is_dir(): + continue + run_id = run_dir.name + + interp = InterpRepo.open(run_id) + if interp is None: + continue + + n_labels = interp.get_interpretation_count() + if n_labels == 0: + continue + + score_types = interp.get_available_score_types() + + harvest = HarvestRepo.open_most_recent(run_id) + has_intruder = bool(harvest.get_scores("intruder")) if harvest is not None else False + + has_ds_attrs = get_attributions_dir(run_id).exists() and any( + d.is_dir() and d.name.startswith("da-") for d in get_attributions_dir(run_id).iterdir() + ) + + model_type, arch_summary = _try_get_arch_info(run_id) + + runs.append( + DiscoveredRun( + run_id=run_id, + n_labels=n_labels, + has_harvest=harvest is not None, + has_detection="detection" in score_types, + has_fuzzing="fuzzing" in score_types, + has_intruder=has_intruder, + has_dataset_attributions=has_ds_attrs, + model_type=model_type, + arch_summary=arch_summary, + created_at=None, + ) + ) + + dates = _fetch_wandb_created_dates([r.run_id for r in runs]) + for r in runs: + r.created_at = dates.get(r.run_id) + runs.sort(key=lambda r: r.created_at or "", reverse=True) + return runs diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index aa4728bd8..e70d182f3 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -1,7 +1,7 @@
@@ -59,25 +130,56 @@ {/if} -
- {#each CANONICAL_RUNS as entry (entry.wandbRunId)} - {@const info = archInfo[entry.wandbRunId]} - - {/each} +
+ {#if mergedRuns === null} +
+ {#each Array(12) as _} +
+
+
+
+
+ {/each} +
+ {:else} + {#each mergedRuns as run (run.runId)} + {@const d = run.discovered} + {@const pills = d ? presentPills(d) : []} + {@const arch = d?.arch_summary ?? null} + + {/each} + {/if}
@@ -144,7 +246,7 @@ } .selector-content { - max-width: 720px; + max-width: 960px; width: 100%; transition: opacity var(--transition-slow); } @@ -158,84 +260,159 @@ font-size: var(--text-3xl); font-weight: 600; color: var(--text-primary); - margin: 0 0 var(--space-2) 0; + margin: 0 0 var(--space-4) 0; text-align: center; font-family: var(--font-sans); } - .runs-grid { - display: grid; - grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); - gap: var(--space-3); + .runs-list { + height: 480px; + overflow-y: auto; + border: 1px solid var(--border-default); + border-radius: var(--radius-md); margin-bottom: var(--space-6); } - .run-card { + /* Loading skeleton */ + .loading-skeleton { display: flex; flex-direction: column; - align-items: flex-start; - gap: var(--space-1); - padding: var(--space-3); - background: var(--bg-surface); - border: 1px solid var(--border-default); - border-radius: var(--radius-md); + } + + .skeleton-row { + display: flex; + align-items: center; + gap: var(--space-3); + padding: 10px var(--space-2); + } + + .skeleton-block { + height: 10px; + border-radius: 4px; + background: var(--border-default); + animation: pulse 1.5s ease-in-out infinite; + } + + .skeleton-block.id { + width: 80px; + flex-shrink: 0; + } + + .skeleton-block.model { + flex: 1; + } + + .skeleton-block.data { + width: 100px; + flex-shrink: 0; + } + + @keyframes pulse { + 0%, + 100% { + opacity: 0.2; + } + 50% { + opacity: 0.45; + } + } + + /* Run rows */ + .run-row { + display: flex; + align-items: center; + gap: var(--space-3); + padding: 7px var(--space-2); + width: 100%; + background: transparent; + border: none; + border-radius: var(--radius-sm); cursor: pointer; text-align: left; - transition: - border-color var(--transition-normal), - background var(--transition-normal); + transition: background var(--transition-normal); } - .run-card:hover:not(:disabled) { - border-color: var(--accent-primary); + .run-row:hover:not(:disabled) { background: var(--bg-elevated); } - .run-card:disabled { + .run-row:disabled { opacity: 0.5; cursor: not-allowed; } - .run-model { - font-size: var(--text-sm); - font-weight: 600; - color: var(--text-primary); - font-family: var(--font-sans); + .col-id { + display: flex; + flex-direction: column; + gap: 1px; + flex-shrink: 0; + width: 140px; } .run-id { font-size: var(--text-xs); font-family: var(--font-mono); color: var(--accent-primary); + font-weight: 500; } .run-notes { - font-size: var(--text-xs); + font-size: 10px; color: var(--text-muted); font-family: var(--font-sans); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + line-height: 1.2; + } + + .col-model { + display: flex; + flex-direction: column; + gap: 1px; + flex: 1; + min-width: 0; } .run-arch { font-size: 10px; font-family: var(--font-mono); - color: var(--text-secondary, var(--text-muted)); - background: var(--bg-inset, var(--bg-base)); - padding: 1px 4px; - border-radius: 3px; + color: var(--text-muted); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; line-height: 1.3; } - .run-arch.loading { - opacity: 0.5; - font-style: italic; - font-family: var(--font-sans); - background: none; + .col-data { + display: flex; + align-items: center; + gap: var(--space-2); + flex-shrink: 0; } - .run-cluster-mappings { - font-size: var(--text-xs); + .label-count { + font-size: 11px; color: var(--text-muted); + font-family: var(--font-mono); + white-space: nowrap; + } + + .pills { + display: flex; + gap: 3px; + } + + .pill { + font-size: 9px; font-family: var(--font-sans); + font-weight: 500; + padding: 1px 6px; + border-radius: 9999px; + line-height: 1.4; + white-space: nowrap; + background: color-mix(in srgb, var(--accent-primary) 15%, transparent); + color: var(--accent-primary); } .divider { diff --git a/spd/app/frontend/src/lib/api/discover.ts b/spd/app/frontend/src/lib/api/discover.ts new file mode 100644 index 000000000..5f6e40d6b --- /dev/null +++ b/spd/app/frontend/src/lib/api/discover.ts @@ -0,0 +1,22 @@ +/** + * API client for /api/runs/discover endpoint. + */ + +import { fetchJson } from "./index"; + +export type DiscoveredRun = { + run_id: string; + n_labels: number; + has_harvest: boolean; + has_detection: boolean; + has_fuzzing: boolean; + has_intruder: boolean; + has_dataset_attributions: boolean; + model_type: string | null; + arch_summary: string | null; + created_at: string | null; +}; + +export async function discoverRuns(): Promise { + return fetchJson("/api/runs/discover"); +} diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 773663636..057e4e9ed 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -53,3 +53,4 @@ export * from "./dataset"; export * from "./clusters"; export * from "./dataSources"; export * from "./pretrainInfo"; +export * from "./discover";