Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,27 @@
- Filter: Only plots datapoints where CI > threshold

Usage:
python -m spd.scripts.plot_component_activations s-7884efcc
python -m spd.scripts.plot_component_activations s-7884efcc --ci-threshold 0.0
python -m spd.scripts.plot_component_activations.plot_component_activations \
wandb:goodfire/spd/runs/<run_id>
"""

import argparse
from collections import defaultdict
from pathlib import Path

import fire
import matplotlib.pyplot as plt
import numpy as np

from spd.harvest.repo import HarvestRepo
from spd.harvest.schemas import ComponentData
from spd.log import logger
from spd.spd_types import ModelPath
from spd.utils.wandb_utils import parse_wandb_run_path

# Directory containing this script; the plot output tree ("out/<run_id>/...") is created beneath it.
SCRIPT_DIR = Path(__file__).parent

def extract_activations(

def _extract_activations(
components: list[ComponentData],
ci_threshold: float,
) -> tuple[dict[str, dict[str, list[float]]], dict[str, dict[str, list[float]]]]:
Expand All @@ -47,7 +52,7 @@ def extract_activations(
return dict(all_activations), dict(filtered_activations)


def normalize_per_component(
def _normalize_per_component(
all_activations: dict[str, list[float]],
filtered_activations: dict[str, list[float]],
) -> dict[str, np.ndarray]:
Expand All @@ -67,14 +72,14 @@ def normalize_per_component(
return normalized


def _order_by_median(normalized: dict[str, np.ndarray]) -> list[str]:
    """Order component keys by median of their normalized activations (descending).

    Args:
        normalized: Mapping from component key to that component's normalized
            activation values.

    Returns:
        The keys of ``normalized``, highest median first. Keys with equal
        medians keep their original dict order, because the sort is stable
        even with ``reverse=True``.
    """
    medians = [(key, np.median(acts)) for key, acts in normalized.items()]
    medians.sort(key=lambda x: x[1], reverse=True)
    return [key for key, _ in medians]


# Pre-rename public name, kept as an alias so callers using either spelling resolve.
order_by_median = _order_by_median


def order_by_frequency(
def _order_by_frequency(
normalized: dict[str, np.ndarray], firing_counts: dict[str, int]
) -> list[str]:
"""Order component keys by pre-calculated firing counts (descending)."""
Expand All @@ -83,14 +88,14 @@ def order_by_frequency(
return [key for key, _ in freqs]


def create_layer_scatter_plot(
def _create_layer_scatter_plot(
normalized_by_key: dict[str, np.ndarray],
ordered_keys: list[str],
layer_name: str,
run_id: str,
output_path: Path,
x_label: str = "Component Rank (by median activation)",
y_label: str = "Normalized Component Activation",
x_label: str,
y_label: str,
) -> None:
"""Create scatter plot for a single layer."""
x_vals = []
Expand Down Expand Up @@ -123,86 +128,87 @@ def create_layer_scatter_plot(
plt.close(fig)


def plot_component_activations(
    wandb_path: ModelPath,
    ci_threshold: float = 0.1,
) -> None:
    """Create per-layer scatter plots of normalized component activations for one run.

    Two sets of plots are written under ``SCRIPT_DIR / "out" / <run_id>``:

    * ``order-by-median/``: components ranked by the median of their
      normalized activations.
    * ``order-by-freq/``: components ranked by firing count, plotting
      ``|normalized activation - 0.5|``.

    Each layer produces one ``<layer_name with "." replaced by "_">.png`` per set.

    Args:
        wandb_path: Run path handed to ``parse_wandb_run_path``; only the run id
            it yields is used (entity and project are ignored).
        ci_threshold: Threshold forwarded to ``_extract_activations``; only
            datapoints with CI above it are plotted.

    Raises:
        AssertionError: If no harvest data or token stats exist for the run, or
            if no datapoint exceeds ``ci_threshold``.
    """
    _entity, _project, run_id = parse_wandb_run_path(str(wandb_path))

    out_dir = SCRIPT_DIR / "out" / run_id
    out_dir_median = out_dir / "order-by-median"
    out_dir_freq = out_dir / "order-by-freq"
    out_dir_median.mkdir(parents=True, exist_ok=True)
    out_dir_freq.mkdir(parents=True, exist_ok=True)

    repo = HarvestRepo.open(run_id)
    assert repo is not None, f"No harvest data for {run_id}"

    logger.info(f"Loading components for run {run_id}...")
    components = repo.get_all_components()
    logger.info(f"Loaded {len(components)} components")

    logger.info("Loading firing counts...")
    token_stats = repo.get_token_stats()
    assert token_stats is not None, f"No token stats found for run {run_id}"
    # strict=True makes a length mismatch between keys and counts fail loudly.
    firing_counts = {
        key: int(count)
        for key, count in zip(token_stats.component_keys, token_stats.firing_counts, strict=True)
    }

    logger.info("Extracting activations...")
    all_by_layer, filtered_by_layer = _extract_activations(components, ci_threshold)

    n_layers = len(filtered_by_layer)
    n_total = sum(sum(len(v) for v in layer.values()) for layer in filtered_by_layer.values())
    logger.info(f"Found {n_total} datapoints across {n_layers} layers with CI > {ci_threshold}")

    assert n_total > 0, "No datapoints found above threshold. Try lowering ci_threshold."

    logger.info(f"Creating per-layer plots (ordered by median) in {out_dir_median}/...")
    for layer_name in sorted(all_by_layer):
        all_acts = all_by_layer[layer_name]
        # A layer may have no datapoint above the threshold; treat it as empty.
        filtered_acts = filtered_by_layer.get(layer_name, {})
        normalized = _normalize_per_component(all_acts, filtered_acts)
        if not normalized:
            continue
        ordered_keys = _order_by_median(normalized)
        # Layer names contain dots; swap them out to get a flat file name.
        safe_name = layer_name.replace(".", "_")
        output_path = out_dir_median / f"{safe_name}.png"
        _create_layer_scatter_plot(
            normalized,
            ordered_keys,
            layer_name,
            run_id,
            output_path,
            x_label="Component Rank (by median activation)",
            y_label="Normalized Component Activation",
        )
        logger.info(f" Saved {output_path}")

    logger.info(f"Creating per-layer plots (ordered by frequency) in {out_dir_freq}/...")
    for layer_name in sorted(all_by_layer):
        all_acts = all_by_layer[layer_name]
        filtered_acts = filtered_by_layer.get(layer_name, {})
        normalized = _normalize_per_component(all_acts, filtered_acts)
        if not normalized:
            continue
        # Plot distance from the 0.5 midpoint rather than the raw normalized value.
        abs_from_midpoint = {key: np.abs(acts - 0.5) for key, acts in normalized.items()}
        ordered_keys = _order_by_frequency(abs_from_midpoint, firing_counts)
        safe_name = layer_name.replace(".", "_")
        output_path = out_dir_freq / f"{safe_name}.png"
        _create_layer_scatter_plot(
            abs_from_midpoint,
            ordered_keys,
            layer_name,
            run_id,
            output_path,
            x_label="Component Rank (by firing frequency)",
            y_label="|Normalized Component Activation - 0.5|",
        )
        logger.info(f" Saved {output_path}")

    logger.info(f"All plots saved to {out_dir}")


if __name__ == "__main__":
    # Fire derives the CLI from the function signature: wandb_path is the
    # positional argument and ci_threshold an optional flag.
    fire.Fire(plot_component_activations)