Merged
46 commits
e17f510
clustering (#43) squash
mivanit Oct 10, 2025
26c2957
timing cluster_ss.py
mivanit Oct 10, 2025
55be0d8
Revert "timing cluster_ss.py"
mivanit Oct 10, 2025
00db8dd
[clustering] `cluster_ss.py` speedup ci (#199)
mivanit Oct 10, 2025
13db0be
add num_nonsingleton_groups stat from PR170
mivanit Oct 10, 2025
8e9ddd0
Merge branch 'main' into clustering/main
mivanit Oct 13, 2025
e79ccb8
Merge branch 'main' into clustering/main
mivanit Oct 13, 2025
0c74b5d
switch BaseModel to BaseConfig, get rid of old save/read logic (#209)
mivanit Oct 13, 2025
3f07f42
switch to new run
mivanit Oct 13, 2025
37175cb
Merge branch 'main' into clustering/main
mivanit Oct 14, 2025
ed667c8
Merge branch 'main' into clustering/main
mivanit Oct 14, 2025
0cdd754
Merge branch 'main' into clustering/main
mivanit Oct 16, 2025
c06cffe
wip sigmoid issues
mivanit Oct 16, 2025
fddb323
that worked...?
mivanit Oct 16, 2025
8a56e12
dont assert positive coacts?
mivanit Oct 16, 2025
1f18690
Merge branch 'main' into clustering/main
mivanit Oct 16, 2025
d18470b
get rid of long-running merge pair sampler on GPU test
mivanit Oct 16, 2025
35b7423
Merge branch 'main' into clustering/main
mivanit Oct 20, 2025
8dfea5e
[clustering] Refactor to two-stage process (#203)
danbraunai-goodfire Oct 20, 2025
d80ba3f
[clustering] distance computation (#213)
mivanit Oct 21, 2025
06ab8cc
Merge branch 'main' into clustering/main
mivanit Oct 21, 2025
a24361a
Merge branch 'main' into clustering/main
mivanit Oct 23, 2025
6c70327
Merge branch 'main' into clustering/main
mivanit Oct 24, 2025
26c6520
[clustering] config refactor (#227)
mivanit Oct 24, 2025
1f0725c
deps
mivanit Oct 24, 2025
d8cc0e7
Merge branch 'main' into clustering/main
mivanit Oct 27, 2025
315e953
Merge branch 'main' into clustering/main
mivanit Oct 29, 2025
06deb8c
uv lock
mivanit Oct 29, 2025
ccc960a
Merge branch 'main' into clustering/main
mivanit Nov 11, 2025
a67cbf0
fix merge?
mivanit Nov 11, 2025
e05f412
Add get_cluster_mapping.py and temp backward compatibility fixes
danbraunai-goodfire Dec 16, 2025
b029b3a
Merge branch 'main' into clustering/main
danbraunai-goodfire Dec 22, 2025
91d6c7c
Add more clustering configs
danbraunai-goodfire Dec 22, 2025
2e56354
spd-cluster -> spd-clustering for consistency
danbraunai-goodfire Dec 22, 2025
e0f9b5f
Remove muutils
danbraunai-goodfire Dec 22, 2025
565f478
Delete various clustering tests
danbraunai-goodfire Dec 22, 2025
1230557
Misc clean
danbraunai-goodfire Dec 23, 2025
ae029bf
some refactoring of slurm utils
danbraunai-goodfire Dec 23, 2025
83daa0a
Refactor slurm utils
danbraunai-goodfire Dec 23, 2025
41f0669
Address AI comments on refactor
danbraunai-goodfire Dec 23, 2025
179114f
Remove run_locally out of slurm.py
danbraunai-goodfire Dec 23, 2025
9ddecd7
Don't assume .env exists
danbraunai-goodfire Dec 23, 2025
4018b00
misc cleanups
danbraunai-goodfire Dec 23, 2025
7b5cb6d
Fix test mocks
danbraunai-goodfire Dec 23, 2025
ce216db
Stop harvesting if dataset is depleted
danbraunai-goodfire Dec 23, 2025
c2aea4e
Add clustering CLAUDE.md
danbraunai-goodfire Dec 23, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
spd/scripts/sweep_params.yaml
docs/coverage/**
notebooks/**

**/out/
neuronpedia_outputs/
37 changes: 27 additions & 10 deletions .vscode/launch.json
@@ -243,19 +243,36 @@
}
},
{
"name": "lm streamlit",
"name": "run_clustering example",
"type": "debugpy",
"request": "launch",
"module": "streamlit",
"program": "${workspaceFolder}/spd/clustering/scripts/run_clustering.py",
"args": [
"run",
"${workspaceFolder}/spd/experiments/lm/streamlit_v1/app.py",
"--server.port",
"2000",
"--",
"--model_path",
"wandb:goodfire/spd/runs/ioprgffh"
]
"--config",
"${workspaceFolder}/spd/clustering/configs/crc/example.yaml",
],
"python": "${command:python.interpreterPath}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "clustering pipeline",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/spd/clustering/scripts/run_pipeline.py",
"args": [
"--config",
"${workspaceFolder}/spd/clustering/configs/pipeline_config.yaml",
],
"python": "${command:python.interpreterPath}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
}
]
}
11 changes: 10 additions & 1 deletion Makefile
@@ -76,6 +76,15 @@ coverage:
uv run python -m coverage report -m > $(COVERAGE_DIR)/coverage.txt
uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/


.PHONY: clean
clean:
@echo "Cleaning Python cache and build artifacts..."
find . -type d -name "__pycache__" -exec rm -rf {} +
find . -type d -name "*.egg-info" -exec rm -rf {} +
rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage


.PHONY: app
app:
@uv run python spd/app/run_app.py
@@ -86,4 +95,4 @@ install-app:

.PHONY: check-app
check-app:
(cd spd/app/frontend && npm run format && npm run check && npm run lint)
(cd spd/app/frontend && npm run format && npm run check && npm run lint)
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
# see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 )
"datasets>=2.21.0",
"simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev",
"scipy>=1.14.1",
"fastapi",
"uvicorn",
"openrouter>=0.1.1",
@@ -48,6 +49,7 @@ dev = [
[project.scripts]
spd-run = "spd.scripts.run_cli:cli"
spd-local = "spd.scripts.run_local:cli"
spd-clustering = "spd.clustering.scripts.run_pipeline:cli"
spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli"
spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli"

105 changes: 28 additions & 77 deletions spd/autointerp/scripts/run_slurm.py
@@ -7,33 +7,9 @@
spd-autointerp <wandb_path> --budget_usd 100
"""

import subprocess
from datetime import datetime
from pathlib import Path

from spd.autointerp.interpret import OpenRouterModelName
from spd.log import logger
from spd.settings import REPO_ROOT


def _generate_job_id() -> str:
return datetime.now().strftime("%Y%m%d_%H%M%S")


def _submit_slurm_job(script_content: str, script_path: Path) -> str:
"""Write script and submit to SLURM, returning job ID."""
with open(script_path, "w") as f:
f.write(script_content)
script_path.chmod(0o755)

result = subprocess.run(
["sbatch", str(script_path)], capture_output=True, text=True, check=False
)
if result.returncode != 0:
raise RuntimeError(f"Failed to submit SLURM job: {result.stderr}")

job_id = result.stdout.strip().split()[-1]
return job_id
from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job


def launch_interpret_job(
@@ -52,14 +28,7 @@ def launch_interpret_job(
time: Job time limit.
max_examples_per_component: Maximum number of activation examples per component.
"""
job_id = _generate_job_id()
slurm_logs_dir = Path.home() / "slurm_logs"
slurm_logs_dir.mkdir(exist_ok=True)

sbatch_scripts_dir = Path.home() / "sbatch_scripts"
sbatch_scripts_dir.mkdir(exist_ok=True)

job_name = f"interpret-{job_id}"
job_name = "interpret"

cmd_parts = [
"python -m spd.autointerp.scripts.run_interpret",
@@ -69,56 +38,38 @@
]
interpret_cmd = " \\\n ".join(cmd_parts)

script_content = f"""\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={partition}
#SBATCH --nodes=1
#SBATCH --gres=gpu:0
#SBATCH --cpus-per-task=4
#SBATCH --time={time}
#SBATCH --output={slurm_logs_dir}/slurm-%j.out

set -euo pipefail

echo "=== Interpret ==="
echo "WANDB_PATH: {wandb_path}"
echo "MODEL: {model.value}"
echo "SLURM_JOB_ID: $SLURM_JOB_ID"
echo "================="

cd {REPO_ROOT}
source .venv/bin/activate

# OPENROUTER_API_KEY should be in .env or environment
if [ -f .env ]; then
set -a
source .env
set +a
fi

{interpret_cmd}

echo "Interpret complete!"
"""

script_path = sbatch_scripts_dir / f"interpret_{job_id}.sh"
slurm_job_id = _submit_slurm_job(script_content, script_path)

# Rename to include SLURM job ID
final_script_path = sbatch_scripts_dir / f"interpret_{slurm_job_id}.sh"
script_path.rename(final_script_path)
# Build full command with echoes
full_command = "\n".join(
[
'echo "=== Interpret ==="',
f'echo "WANDB_PATH: {wandb_path}"',
f'echo "MODEL: {model.value}"',
'echo "SLURM_JOB_ID: $SLURM_JOB_ID"',
'echo "================="',
"",
interpret_cmd,
"",
'echo "Interpret complete!"',
]
)

# Create empty log file for tailing
(slurm_logs_dir / f"slurm-{slurm_job_id}.out").touch()
config = SlurmConfig(
job_name=job_name,
partition=partition,
n_gpus=0, # CPU-only job
time=time,
snapshot_branch=None, # Autointerp doesn't use git snapshots
)
script_content = generate_script(config, full_command)
result = submit_slurm_job(script_content, "interpret")

logger.section("Interpret job submitted!")
logger.values(
{
"Job ID": slurm_job_id,
"Job ID": result.job_id,
"WandB path": wandb_path,
"Model": model.value,
"Log": f"~/slurm_logs/slurm-{slurm_job_id}.out",
"Script": str(final_script_path),
"Log": result.log_pattern,
"Script": str(result.script_path),
}
)
9 changes: 8 additions & 1 deletion spd/base_config.py
@@ -15,6 +15,8 @@ class BaseConfig(BaseModel):

model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True)

# TODO: add a "config_type" field, which is set to the class name, so that when loading a config we can check whether the config type matches the expected class

@classmethod
def from_file(cls, path: Path | str) -> Self:
"""Load config from path to a JSON or YAML file."""
@@ -29,7 +31,12 @@ def from_file(cls, path: Path | str) -> Self:
case _:
raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}")

return cls.model_validate(data)
try:
cfg = cls.model_validate(data)
except Exception as e:
e.add_note(f"Error validating config {cls=} from path `{path.as_posix()}`\n{data = }")
raise e
return cfg

def to_file(self, path: Path | str) -> None:
"""Save config to file (format inferred from extension)."""
118 changes: 118 additions & 0 deletions spd/clustering/CLAUDE.md
@@ -0,0 +1,118 @@
# Clustering Module

Hierarchical clustering of SPD components based on coactivation patterns. Runs ensemble clustering experiments to discover stable groups of components that behave similarly.

## Usage

**`spd-clustering` / `run_pipeline.py`**: Launches an ensemble of clustering runs with different seeds, then invokes `calc_distances` to compute pairwise distances between the results. Use this for ensemble experiments.

**`run_clustering.py`**: Runs a single clustering run. Useful for testing or when you only need one clustering result.

```bash
# Run clustering pipeline via SLURM (ensemble of runs + distance calculation)
spd-clustering --config spd/clustering/configs/pipeline_config.yaml

# Run locally instead of SLURM
spd-clustering --config spd/clustering/configs/pipeline_config.yaml --local

# Single clustering run (usually called by pipeline)
python -m spd.clustering.scripts.run_clustering --config <clustering_run_config.json>
```

## Data Storage

```
/mnt/polished-lake/spd/clustering/
├── cluster/<run_id>/ # Single clustering run outputs
│ ├── clustering_run_config.json
│ └── history.zip # MergeHistory (group assignments per iteration)
└── ensemble/<pipeline_run_id>/ # Pipeline/ensemble outputs
├── pipeline_config.yaml
├── ensemble_meta.json # Component labels, iteration stats
├── ensemble_merge_array.npz # Normalized merge array
├── distances_<method>.npz # Distance matrices
└── plots/
└── distances_<method>.png # Distance distribution visualization
```

## Architecture

### Pipeline (`scripts/run_pipeline.py`)

Entry point via `spd-clustering`. Submits clustering runs as SLURM job array, then calculates distances between results. Key steps:
1. Creates `ExecutionStamp` for pipeline
2. Generates commands for each clustering run (with different dataset seeds)
3. Submits clustering array job to SLURM
4. Submits distance calculation jobs (depend on clustering completion)
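Step 2 can be sketched as a simple command generator. This is a hypothetical illustration only: the `--dataset_seed` flag and the `make_run_commands` helper are assumptions, not the actual pipeline code.

```python
# Hypothetical sketch of step 2: one run_clustering command per ensemble
# member, each with its own dataset seed. Flag names are illustrative.
def make_run_commands(config_path: str, n_runs: int, base_seed: int = 0) -> list[str]:
    return [
        f"python -m spd.clustering.scripts.run_clustering"
        f" --config {config_path} --dataset_seed {base_seed + i}"
        for i in range(n_runs)
    ]

cmds = make_run_commands("spd/clustering/configs/crc/example.json", n_runs=3)
```

The real pipeline wraps these commands in a SLURM array job so all ensemble members run in parallel.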

### Single Run (`scripts/run_clustering.py`)

Performs one clustering run:
1. Load decomposed model from WandB
2. Compute component activations on dataset batch
3. Run merge iteration (greedy MDL-based clustering)
4. Save `MergeHistory` with group assignments per iteration

### Merge Algorithm (`merge.py`)

Greedy hierarchical clustering using MDL (Minimum Description Length) cost:
- Computes coactivation matrix from component activations
- Iteratively merges pairs with lowest cost (via `compute_merge_costs`)
- Supports stochastic merge pair selection (`merge_pair_sampling_method`)
- Tracks full merge history for analysis
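The loop above can be illustrated with a toy greedy agglomeration. This is a sketch only: it merges the most-coactive pair at each step, whereas the real algorithm uses the MDL cost from `compute_merge_costs` and supports stochastic pair sampling.

```python
import numpy as np

def greedy_merge(coact: np.ndarray, n_iters: int) -> list[np.ndarray]:
    """Toy greedy agglomeration over a coactivation matrix.

    At each step, merge the pair of active groups with the highest
    coactivation (a stand-in for "lowest cost") and record the new
    component -> group assignment.
    """
    k = coact.shape[0]
    groups = np.arange(k)            # component -> group id
    history = [groups.copy()]
    C = coact.astype(float).copy()
    active = list(range(k))
    for _ in range(n_iters):
        if len(active) < 2:
            break
        # Find the pair of active groups with maximal coactivation
        best, best_pair = -np.inf, None
        for ai, i in enumerate(active):
            for j in active[ai + 1:]:
                if C[i, j] > best:
                    best, best_pair = C[i, j], (i, j)
        i, j = best_pair
        groups[groups == j] = i      # merge group j into group i
        C[i, :] += C[j, :]           # accumulate coactivation into group i
        C[:, i] += C[:, j]
        active.remove(j)
        history.append(groups.copy())
    return history
```

The returned `history` plays the role of a `MergeHistory`: one group-assignment vector per iteration.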

### Distance Calculation (`scripts/calc_distances.py`)

Computes pairwise distances between clustering runs in an ensemble:
- Normalizes component labels across runs (handles dead components)
- Supports multiple distance methods: `perm_invariant_hamming`, `matching_dist`
- Runs in parallel using multiprocessing
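As a from-scratch illustration of what a permutation-invariant Hamming distance computes (not the repo's implementation): the distance between two labelings is the fraction of components whose labels disagree after one run's labels are optimally relabeled, which reduces to a Hungarian matching on the label confusion matrix. `scipy` is already a project dependency.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def perm_invariant_hamming(a: np.ndarray, b: np.ndarray) -> float:
    """Fraction of positions where `a` and `b` disagree, minimized over
    all relabelings of `a` (cluster labels are arbitrary)."""
    n = a.size
    k = int(max(a.max(), b.max())) + 1
    # counts[i, j] = number of positions with label i in `a` and j in `b`
    counts = np.zeros((k, k), dtype=np.int64)
    np.add.at(counts, (a, b), 1)
    # Hungarian matching maximizes total agreement between matched labels
    rows, cols = linear_sum_assignment(counts, maximize=True)
    return 1.0 - counts[rows, cols].sum() / n

# Identical partitions with swapped labels are distance 0
a = np.array([0, 0, 1, 1, 2, 2])
b = np.array([1, 1, 0, 0, 2, 2])
print(perm_invariant_hamming(a, b))  # 0.0
```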

## Key Types

### Configs

```python
ClusteringPipelineConfig # Pipeline settings (n_runs, distances_methods, SLURM config)
ClusteringRunConfig # Single run settings (model_path, batch_size, merge_config)
MergeConfig # Merge algorithm params (alpha, iters, activation_threshold)
```

### Data Structures

```python
MergeHistory # Full merge history: group assignments at each iteration
MergeHistoryEnsemble # Collection of histories for distance analysis
GroupMerge # Current group assignments (component -> group mapping)
```

### Type Aliases (`consts.py`)

```python
ActivationsTensor # Float[Tensor, "samples n_components"]
ClusterCoactivationShaped # Float[Tensor, "k_groups k_groups"]
MergesArray # Int[np.ndarray, "n_ens n_iters n_components"]
DistancesArray # Float[np.ndarray, "n_iters n_ens n_ens"]
```
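A minimal shape check for the two array aliases, with placeholder sizes:

```python
import numpy as np

# Toy sizes: 4 ensemble members, 3 recorded iterations, 6 components
n_ens, n_iters, n_components = 4, 3, 6

# MergesArray: group assignment of every component, per iteration, per run
merges = np.zeros((n_ens, n_iters, n_components), dtype=np.int64)
# DistancesArray: pairwise run-to-run distance matrix at each iteration
distances = np.zeros((n_iters, n_ens, n_ens), dtype=np.float64)

print(merges.shape, distances[0].shape)  # (4, 3, 6) (4, 4)
```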

## Math Submodule (`math/`)

- `merge_matrix.py` - `GroupMerge` class for tracking group assignments
- `merge_distances.py` - Distance computation between clustering results
- `perm_invariant_hamming.py` - Permutation-invariant Hamming distance
- `matching_dist.py` - Optimal matching distance via Hungarian algorithm
- `merge_pair_samplers.py` - Strategies for selecting which pair to merge

## Config Files

Configs live in `spd/clustering/configs/`:
- Pipeline configs: `*.yaml` files with `ClusteringPipelineConfig`
- Run configs: `crc/*.json` files with `ClusteringRunConfig`

Example pipeline config:
```yaml
clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json"
n_runs: 10
distances_methods: ["perm_invariant_hamming"]
wandb_project: "spd"
```
Empty file added spd/clustering/__init__.py
Empty file.