diff --git a/.gitignore b/.gitignore index 4780cbd03..b5601daf4 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,6 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* + +.claude/worktrees \ No newline at end of file diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..bf8ee501a 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -13,40 +13,41 @@ from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs", "mean_squared_attr"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + mean_squared_attr: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -54,27 +55,17 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. - "output" is a pseudo-layer used as-is in storage. - """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" +def _storage_key(canonical_layer: str, component_idx: int) -> str: + return f"{canonical_layer}:{component_idx}" def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" if not storage.has_source(component_key): raise HTTPException( status_code=404, @@ -83,7 +74,6 @@ def _require_source(storage: DatasetAttributionStorage, component_key: str) -> N def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" if not storage.has_target(component_key): raise HTTPException( status_code=404, @@ -92,43 +82,85 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" return loaded.topology.get_unembed_weight() def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] + entries: list[StorageEntry], loaded: DepLoadedRun ) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + else None, ) for e in entries ] +def _get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, + is_source: bool, + is_target: bool, + w_unembed: Float[Tensor, "d_model vocab"] | None, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ) + if is_target + else [], + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ) + if is_target + else [], + positive_targets=_to_api_entries( + storage.get_top_targets( + component_key, + k, + "positive", + metric, + w_unembed=w_unembed, + include_outputs=w_unembed is not None, + ), + loaded, + ) + if is_source + else [], + negative_targets=_to_api_entries( + storage.get_top_targets( + component_key, + k, + "negative", + metric, + w_unembed=w_unembed, + include_outputs=w_unembed is not None, + ), + loaded, + ) + if is_source + else [], + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() @@ -137,8 +169,6 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -150,12 +180,11 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all 3 metrics.""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) + component_key = _storage_key(layer, component_idx) - # Component can be both a source and a target, so we need to check both is_source = storage.has_source(component_key) is_target = storage.has_target(component_key) @@ -167,41 +196,13 @@ def get_component_attributions( w_unembed = _get_w_unembed(loaded) if is_source else None - return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) - if is_target - else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) - if is_target - else [], - positive_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric, is_source, is_target, w_unembed + ) + for metric in ATTR_METRICS + } ) @@ -213,16 +214,16 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) + target_key = _storage_key(layer, component_idx) _require_target(storage, target_key) w_unembed = _get_w_unembed(loaded) if layer == "output" else None return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) + storage.get_top_sources(target_key, k, sign, metric, w_unembed=w_unembed), loaded ) @@ -234,16 +235,16 @@ def get_attribution_targets( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) + source_key = _storage_key(layer, component_idx) _require_source(storage, source_key) w_unembed = _get_w_unembed(loaded) return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) + storage.get_top_targets(source_key, k, sign, metric, w_unembed=w_unembed), loaded ) @@ -255,14 +256,14 @@ def get_attribution_between( target_layer: str, target_idx: int, loaded: DepLoadedRun, + metric: AttrMetric = "attr", ) -> float: - """Get attribution strength from source component to target component.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) + source_key = _storage_key(source_layer, source_idx) + target_key = _storage_key(target_layer, target_idx) _require_source(storage, source_key) _require_target(storage, target_key) w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) + return storage.get_attribution(source_key, target_key, metric, w_unembed=w_unembed) diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index aa4728bd8..f174ee635 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -87,7 +87,7 @@
diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte index a0d663208..640135c76 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte @@ -130,6 +130,16 @@ const currentNodeKey = $derived(`${layer}:${seqIdx}:${cIdx}`); const N_EDGES_TO_DISPLAY = 20; + function resolveTokenStr(nodeKey: string): string | null { + const parts = nodeKey.split(":"); + if (parts.length !== 3) return null; + const [layer, seqStr, cIdx] = parts; + const seqIdx = parseInt(seqStr); + if (layer === "embed") return tokens[seqIdx] ?? null; + if (layer === "output") return outputProbs[`${seqIdx}:${cIdx}`]?.token ?? null; + return null; + } + function getTopEdgeAttributions( edges: EdgeData[], isPositive: boolean, @@ -144,6 +154,7 @@ key: getKey(e), value: e.val, normalizedMagnitude: Math.abs(e.val) / maxAbsVal, + tokenStr: resolveTokenStr(getKey(e)), })); } @@ -249,8 +260,6 @@ {outgoingNegative} pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} diff --git a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte index cd86c9af1..1d8799d63 100644 --- a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte +++ b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte @@ -9,6 +9,7 @@ import { COMPONENT_CARD_CONSTANTS } from "../../lib/componentCardConstants"; import type { EdgeAttribution } from "../../lib/promptAttributionsTypes"; import type { DatasetAttributions } from "../../lib/useComponentData.svelte"; + import type { AttrMetric, DatasetAttributionEntry } from "../../lib/api/datasetAttributions"; import EdgeAttributionGrid from "./EdgeAttributionGrid.svelte"; type Props = { @@ -17,6 +18,7 @@ }; let { attributions, onComponentClick }: Props = $props(); + let selectedMetric = $state("attr"); function handleClick(key: string) { if (onComponentClick) { @@ -25,37 +27,124 @@ } function toEdgeAttribution( - entries: { component_key: string; value: number }[], + entries: DatasetAttributionEntry[], maxAbsValue: number, ): EdgeAttribution[] { return entries.map((e) => ({ key: e.component_key, value: e.value, normalizedMagnitude: Math.abs(e.value) / (maxAbsValue || 1), + tokenStr: e.token_str, })); } - const maxSourceVal = $derived( - Math.max(attributions.positive_sources[0]?.value ?? 0, Math.abs(attributions.negative_sources[0]?.value ?? 0)), - ); - const maxTargetVal = $derived( - Math.max(attributions.positive_targets[0]?.value ?? 0, Math.abs(attributions.negative_targets[0]?.value ?? 0)), - ); + function maxAbs(...vals: number[]): number { + return Math.max(...vals.map(Math.abs)); + } + + // attr: signed + const attrMaxSource = $derived(maxAbs(attributions.attr.positive_sources[0]?.value ?? 0, attributions.attr.negative_sources[0]?.value ?? 0)); + const attrMaxTarget = $derived(maxAbs(attributions.attr.positive_targets[0]?.value ?? 0, attributions.attr.negative_targets[0]?.value ?? 0)); - const positiveSources = $derived(toEdgeAttribution(attributions.positive_sources, maxSourceVal)); - const negativeSources = $derived(toEdgeAttribution(attributions.negative_sources, maxSourceVal)); - const positiveTargets = $derived(toEdgeAttribution(attributions.positive_targets, maxTargetVal)); - const negativeTargets = $derived(toEdgeAttribution(attributions.negative_targets, maxTargetVal)); + // attr_abs: signed + const absMaxSource = $derived(maxAbs(attributions.attr_abs.positive_sources[0]?.value ?? 0, attributions.attr_abs.negative_sources[0]?.value ?? 0)); + const absMaxTarget = $derived(maxAbs(attributions.attr_abs.positive_targets[0]?.value ?? 0, attributions.attr_abs.negative_targets[0]?.value ?? 0)); + + // mean_squared_attr: unsigned (positive only) + const rmsMaxSource = $derived(attributions.mean_squared_attr.positive_sources[0]?.value ?? 0); + const rmsMaxTarget = $derived(attributions.mean_squared_attr.positive_targets[0]?.value ?? 0); - +
+
+ + + +
+ + {#if selectedMetric === "attr"} + + {:else if selectedMetric === "attr_abs"} + + {:else} + + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index 844cb7c04..c90bfc33e 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@