Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
6d76160
add worktrees to ignore
ocg-goodfire Feb 23, 2026
3bcaddd
Rewrite dataset attribution storage: dict-of-dicts, canonical names, …
ocg-goodfire Feb 23, 2026
9d9a4a3
Fix alive_targets iteration: use torch.where for indices, not bool to…
ocg-goodfire Feb 23, 2026
3c0ba4b
Fix KeyError for embed source: CI dict doesn't include embedding layer
ocg-goodfire Feb 23, 2026
d736f42
Fix scatter_add OOB: use embedding num_embeddings instead of tokenize…
ocg-goodfire Feb 23, 2026
8271905
Split run.py into run_worker.py and run_merge.py
ocg-goodfire Feb 23, 2026
7650495
Correct attr_abs via backprop through |target|, reorganise method sig…
ocg-goodfire Feb 23, 2026
ccf713f
Add merge_mem config (default 200G) to prevent merge OOM
ocg-goodfire Feb 23, 2026
8139fb1
Add 3-metric selection to dataset attributions in app
ocg-goodfire Feb 23, 2026
1bf9877
Allow bare s-prefixed run IDs everywhere (e.g. "s-17805b61")
ocg-goodfire Feb 23, 2026
627df2b
Fix AttributionRepo.open skipping valid subruns due to old-format dirs
ocg-goodfire Feb 23, 2026
73a7ba0
Fix 3s lag on attribution metric toggle: O(V) linear scan per pill
ocg-goodfire Feb 23, 2026
5798178
Ship token strings from backend instead of resolving vocab IDs in fro…
ocg-goodfire Feb 23, 2026
538742c
Hide negative attribution column for non-signed metrics
ocg-goodfire Feb 23, 2026
3b8bf4e
Narrow frontend types: SignedAttributions vs UnsignedAttributions
ocg-goodfire Feb 23, 2026
bcbb212
Update dataset_attributions CLAUDE.md for new storage format and 3 me…
ocg-goodfire Feb 23, 2026
fb3c5a0
Remove stray worktree files from PR
ocg-goodfire Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,6 @@ cython_debug/
#.idea/

**/*.db
**/*.db*
**/*.db*

.claude/worktrees
175 changes: 88 additions & 87 deletions spd/app/backend/routers/dataset_attributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,68 +13,59 @@

from spd.app.backend.dependencies import DepLoadedRun
from spd.app.backend.utils import log_errors
from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage
from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry
from spd.dataset_attributions.storage import DatasetAttributionStorage

ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs", "mean_squared_attr"]

class DatasetAttributionEntry(BaseModel):
"""A single entry in attribution results."""

class DatasetAttributionEntry(BaseModel):
    """A single attribution entry returned by the API.

    Mirrors a storage entry, plus an optional decoded token string for
    vocab-indexed layers.
    """

    component_key: str  # "layer:component_idx" key identifying the component
    layer: str  # layer name the component belongs to
    component_idx: int  # index of the component within its layer
    value: float  # attribution strength for the selected metric/sign
    # Decoded token text when the layer is vocab-indexed; None otherwise.
    token_str: str | None = None


class DatasetAttributionMetadata(BaseModel):
"""Metadata about dataset attributions availability."""

available: bool
n_batches_processed: int | None
n_tokens_processed: int | None
n_component_layer_keys: int | None
vocab_size: int | None
d_model: int | None
ci_threshold: float | None


class ComponentAttributions(BaseModel):
    """All attribution data for a single component (sources and targets, positive and negative)."""

    positive_sources: list[DatasetAttributionEntry]  # components attributing TO this one, positive sign
    negative_sources: list[DatasetAttributionEntry]  # components attributing TO this one, negative sign
    positive_targets: list[DatasetAttributionEntry]  # components this one attributes TO, positive sign
    negative_targets: list[DatasetAttributionEntry]  # components this one attributes TO, negative sign


class AllMetricAttributions(BaseModel):
    """Component attributions grouped by attribution metric.

    Field names match the entries of ATTR_METRICS so responses can be built
    by metric-name keyword expansion.
    """

    attr: ComponentAttributions  # raw signed attribution
    attr_abs: ComponentAttributions  # attribution of absolute values
    mean_squared_attr: ComponentAttributions  # mean squared attribution


router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"])

NOT_AVAILABLE_MSG = (
"Dataset attributions not available. Run: spd-attributions <wandb_path> --n_batches N"
)


def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str:
"""Translate canonical layer + idx to concrete storage key.

"embed" maps to the concrete embedding path (e.g. "wte") in storage.
"output" is a pseudo-layer used as-is in storage.
"""
if canonical_layer == "output":
return f"output:{component_idx}"
concrete = loaded.topology.canon_to_target(canonical_layer)
return f"{concrete}:{component_idx}"
def _storage_key(canonical_layer: str, component_idx: int) -> str:
return f"{canonical_layer}:{component_idx}"


def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage:
"""Get storage or raise 404."""
if loaded.attributions is None:
raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG)
return loaded.attributions.get_attributions()


def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None:
"""Validate component exists as a source or raise 404."""
if not storage.has_source(component_key):
raise HTTPException(
status_code=404,
Expand All @@ -83,7 +74,6 @@ def _require_source(storage: DatasetAttributionStorage, component_key: str) -> N


def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None:
"""Validate component exists as a target or raise 404."""
if not storage.has_target(component_key):
raise HTTPException(
status_code=404,
Expand All @@ -92,43 +82,85 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N


def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]:
"""Get the unembedding matrix from the loaded model."""
return loaded.topology.get_unembed_weight()


def _to_api_entries(
loaded: DepLoadedRun, entries: list[StorageEntry]
entries: list[StorageEntry], loaded: DepLoadedRun
) -> list[DatasetAttributionEntry]:
"""Convert storage entries to API response format with canonical keys."""

def _canonicalize_layer(layer: str) -> str:
if layer == "output":
return layer
return loaded.topology.target_to_canon(layer)

return [
DatasetAttributionEntry(
component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}",
layer=_canonicalize_layer(e.layer),
component_key=e.component_key,
layer=e.layer,
component_idx=e.component_idx,
value=e.value,
token_str=loaded.tokenizer.decode([e.component_idx])
if e.layer in ("embed", "output")
else None,
)
for e in entries
]


def _get_component_attributions_for_metric(
    storage: DatasetAttributionStorage,
    loaded: DepLoadedRun,
    component_key: str,
    k: int,
    metric: AttrMetric,
    is_source: bool,
    is_target: bool,
    w_unembed: Float[Tensor, "d_model vocab"] | None,
) -> ComponentAttributions:
    """Collect top-k source and target attributions for one component and metric.

    Args:
        storage: Attribution storage to query.
        loaded: Loaded run, used to convert storage entries to API entries.
        component_key: "layer:idx" key of the component in storage.
        k: Number of entries to return per list.
        metric: Which attribution metric to rank by.
        is_source: Whether the component exists as a source (has targets).
        is_target: Whether the component exists as a target (has sources).
        w_unembed: Unembedding matrix for projecting to output logits, or None
            to skip output pseudo-layer targets.

    Returns:
        ComponentAttributions with empty lists for the roles the component
        does not play.
    """

    def top_sources(sign: str) -> list[DatasetAttributionEntry]:
        # Sources attribute TO this component, so they exist only if it is a target.
        if not is_target:
            return []
        return _to_api_entries(storage.get_top_sources(component_key, k, sign, metric), loaded)

    def top_targets(sign: str) -> list[DatasetAttributionEntry]:
        # Targets receive attribution FROM this component, so they exist only if it is a source.
        if not is_source:
            return []
        return _to_api_entries(
            storage.get_top_targets(
                component_key,
                k,
                sign,
                metric,
                w_unembed=w_unembed,
                # Output pseudo-layer entries are only meaningful when we can
                # project through the unembedding.
                include_outputs=w_unembed is not None,
            ),
            loaded,
        )

    return ComponentAttributions(
        positive_sources=top_sources("positive"),
        negative_sources=top_sources("negative"),
        positive_targets=top_targets("positive"),
        negative_targets=top_targets("negative"),
    )


@router.get("/metadata")
@log_errors
def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata:
"""Get metadata about dataset attributions availability."""
if loaded.attributions is None:
return DatasetAttributionMetadata(
available=False,
n_batches_processed=None,
n_tokens_processed=None,
n_component_layer_keys=None,
vocab_size=None,
d_model=None,
ci_threshold=None,
)
storage = loaded.attributions.get_attributions()
Expand All @@ -137,8 +169,6 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata
n_batches_processed=storage.n_batches_processed,
n_tokens_processed=storage.n_tokens_processed,
n_component_layer_keys=storage.n_components,
vocab_size=storage.vocab_size,
d_model=storage.d_model,
ci_threshold=storage.ci_threshold,
)

Expand All @@ -150,12 +180,11 @@ def get_component_attributions(
component_idx: int,
loaded: DepLoadedRun,
k: Annotated[int, Query(ge=1)] = 10,
) -> ComponentAttributions:
"""Get all attribution data for a component (sources and targets, positive and negative)."""
) -> AllMetricAttributions:
"""Get all attribution data for a component across all 3 metrics."""
storage = _require_storage(loaded)
component_key = _to_concrete_key(layer, component_idx, loaded)
component_key = _storage_key(layer, component_idx)

# Component can be both a source and a target, so we need to check both
is_source = storage.has_source(component_key)
is_target = storage.has_target(component_key)

Expand All @@ -167,41 +196,13 @@ def get_component_attributions(

w_unembed = _get_w_unembed(loaded) if is_source else None

return ComponentAttributions(
positive_sources=_to_api_entries(
loaded, storage.get_top_sources(component_key, k, "positive")
)
if is_target
else [],
negative_sources=_to_api_entries(
loaded, storage.get_top_sources(component_key, k, "negative")
)
if is_target
else [],
positive_targets=_to_api_entries(
loaded,
storage.get_top_targets(
component_key,
k,
"positive",
w_unembed=w_unembed,
include_outputs=w_unembed is not None,
),
)
if is_source
else [],
negative_targets=_to_api_entries(
loaded,
storage.get_top_targets(
component_key,
k,
"negative",
w_unembed=w_unembed,
include_outputs=w_unembed is not None,
),
)
if is_source
else [],
return AllMetricAttributions(
**{
metric: _get_component_attributions_for_metric(
storage, loaded, component_key, k, metric, is_source, is_target, w_unembed
)
for metric in ATTR_METRICS
}
)


Expand All @@ -213,16 +214,16 @@ def get_attribution_sources(
loaded: DepLoadedRun,
k: Annotated[int, Query(ge=1)] = 10,
sign: Literal["positive", "negative"] = "positive",
metric: AttrMetric = "attr",
) -> list[DatasetAttributionEntry]:
"""Get top-k source components that attribute TO this target over the dataset."""
storage = _require_storage(loaded)
target_key = _to_concrete_key(layer, component_idx, loaded)
target_key = _storage_key(layer, component_idx)
_require_target(storage, target_key)

w_unembed = _get_w_unembed(loaded) if layer == "output" else None

return _to_api_entries(
loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed)
storage.get_top_sources(target_key, k, sign, metric, w_unembed=w_unembed), loaded
)


Expand All @@ -234,16 +235,16 @@ def get_attribution_targets(
loaded: DepLoadedRun,
k: Annotated[int, Query(ge=1)] = 10,
sign: Literal["positive", "negative"] = "positive",
metric: AttrMetric = "attr",
) -> list[DatasetAttributionEntry]:
"""Get top-k target components this source attributes TO over the dataset."""
storage = _require_storage(loaded)
source_key = _to_concrete_key(layer, component_idx, loaded)
source_key = _storage_key(layer, component_idx)
_require_source(storage, source_key)

w_unembed = _get_w_unembed(loaded)

return _to_api_entries(
loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed)
storage.get_top_targets(source_key, k, sign, metric, w_unembed=w_unembed), loaded
)


Expand All @@ -255,14 +256,14 @@ def get_attribution_between(
target_layer: str,
target_idx: int,
loaded: DepLoadedRun,
metric: AttrMetric = "attr",
) -> float:
"""Get attribution strength from source component to target component."""
storage = _require_storage(loaded)
source_key = _to_concrete_key(source_layer, source_idx, loaded)
target_key = _to_concrete_key(target_layer, target_idx, loaded)
source_key = _storage_key(source_layer, source_idx)
target_key = _storage_key(target_layer, target_idx)
_require_source(storage, source_key)
_require_target(storage, target_key)

w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None

return storage.get_attribution(source_key, target_key, w_unembed=w_unembed)
return storage.get_attribution(source_key, target_key, metric, w_unembed=w_unembed)
2 changes: 1 addition & 1 deletion spd/app/frontend/src/components/RunSelector.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
<form class="custom-form" onsubmit={handleCustomSubmit}>
<input
type="text"
placeholder="e.g. goodfire/spd/runs/33n6xjjt"
placeholder="e.g. s-17805b61 or goodfire/spd/runs/33n6xjjt"
bind:value={customPath}
disabled={isLoading}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@
const currentNodeKey = $derived(`${layer}:${seqIdx}:${cIdx}`);
const N_EDGES_TO_DISPLAY = 20;

function resolveTokenStr(nodeKey: string): string | null {
const parts = nodeKey.split(":");
if (parts.length !== 3) return null;
const [layer, seqStr, cIdx] = parts;
const seqIdx = parseInt(seqStr);
if (layer === "embed") return tokens[seqIdx] ?? null;
if (layer === "output") return outputProbs[`${seqIdx}:${cIdx}`]?.token ?? null;
return null;
}

function getTopEdgeAttributions(
edges: EdgeData[],
isPositive: boolean,
Expand All @@ -144,6 +154,7 @@
key: getKey(e),
value: e.val,
normalizedMagnitude: Math.abs(e.val) / maxAbsVal,
tokenStr: resolveTokenStr(getKey(e)),
}));
}

Expand Down Expand Up @@ -249,8 +260,6 @@
{outgoingNegative}
pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE}
onClick={handleEdgeNodeClick}
{tokens}
{outputProbs}
/>
{/if}

Expand Down
Loading