diff --git a/.gitignore b/.gitignore index 67655d902..b84f21934 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ spd/scripts/sweep_params.yaml docs/coverage/** +notebooks/** **/out/ neuronpedia_outputs/ diff --git a/.vscode/launch.json b/.vscode/launch.json index 32225ae4d..da7153838 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -243,19 +243,36 @@ } }, { - "name": "lm streamlit", + "name": "run_clustering example", "type": "debugpy", "request": "launch", - "module": "streamlit", + "program": "${workspaceFolder}/spd/clustering/scripts/run_clustering.py", "args": [ - "run", - "${workspaceFolder}/spd/experiments/lm/streamlit_v1/app.py", - "--server.port", - "2000", - "--", - "--model_path", - "wandb:goodfire/spd/runs/ioprgffh" - ] + "--config", + "${workspaceFolder}/spd/clustering/configs/crc/example.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + }, + { + "name": "clustering pipeline", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/clustering/scripts/run_pipeline.py", + "args": [ + "--config", + "${workspaceFolder}/spd/clustering/configs/pipeline_config.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } } ] } \ No newline at end of file diff --git a/Makefile b/Makefile index e3a148703..76f779159 100644 --- a/Makefile +++ b/Makefile @@ -76,6 +76,15 @@ coverage: uv run python -m coverage report -m > $(COVERAGE_DIR)/coverage.txt uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/ + +.PHONY: clean +clean: + @echo "Cleaning Python cache and build artifacts..." + find . -type d -name "__pycache__" -exec rm -rf {} + + find . 
-type d -name "*.egg-info" -exec rm -rf {} + + rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage + + .PHONY: app app: @uv run python spd/app/run_app.py @@ -86,4 +95,4 @@ install-app: .PHONY: check-app check-app: - (cd spd/app/frontend && npm run format && npm run check && npm run lint) \ No newline at end of file + (cd spd/app/frontend && npm run format && npm run check && npm run lint) diff --git a/pyproject.toml b/pyproject.toml index 1503c8455..c70987204 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ # see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 ) "datasets>=2.21.0", "simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev", + "scipy>=1.14.1", "fastapi", "uvicorn", "openrouter>=0.1.1", @@ -48,6 +49,7 @@ dev = [ [project.scripts] spd-run = "spd.scripts.run_cli:cli" spd-local = "spd.scripts.run_local:cli" +spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" diff --git a/spd/autointerp/scripts/run_slurm.py b/spd/autointerp/scripts/run_slurm.py index da93a95ad..1a61f70d1 100644 --- a/spd/autointerp/scripts/run_slurm.py +++ b/spd/autointerp/scripts/run_slurm.py @@ -7,33 +7,9 @@ spd-autointerp --budget_usd 100 """ -import subprocess -from datetime import datetime -from pathlib import Path - from spd.autointerp.interpret import OpenRouterModelName from spd.log import logger -from spd.settings import REPO_ROOT - - -def _generate_job_id() -> str: - return datetime.now().strftime("%Y%m%d_%H%M%S") - - -def _submit_slurm_job(script_content: str, script_path: Path) -> str: - """Write script and submit to SLURM, returning job ID.""" - with open(script_path, "w") as f: - f.write(script_content) - script_path.chmod(0o755) - - result = subprocess.run( - 
["sbatch", str(script_path)], capture_output=True, text=True, check=False - ) - if result.returncode != 0: - raise RuntimeError(f"Failed to submit SLURM job: {result.stderr}") - - job_id = result.stdout.strip().split()[-1] - return job_id +from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job def launch_interpret_job( @@ -52,14 +28,7 @@ def launch_interpret_job( time: Job time limit. max_examples_per_component: Maximum number of activation examples per component. """ - job_id = _generate_job_id() - slurm_logs_dir = Path.home() / "slurm_logs" - slurm_logs_dir.mkdir(exist_ok=True) - - sbatch_scripts_dir = Path.home() / "sbatch_scripts" - sbatch_scripts_dir.mkdir(exist_ok=True) - - job_name = f"interpret-{job_id}" + job_name = "interpret" cmd_parts = [ "python -m spd.autointerp.scripts.run_interpret", @@ -69,56 +38,38 @@ def launch_interpret_job( ] interpret_cmd = " \\\n ".join(cmd_parts) - script_content = f"""\ -#!/bin/bash -#SBATCH --job-name={job_name} -#SBATCH --partition={partition} -#SBATCH --nodes=1 -#SBATCH --gres=gpu:0 -#SBATCH --cpus-per-task=4 -#SBATCH --time={time} -#SBATCH --output={slurm_logs_dir}/slurm-%j.out - -set -euo pipefail - -echo "=== Interpret ===" -echo "WANDB_PATH: {wandb_path}" -echo "MODEL: {model.value}" -echo "SLURM_JOB_ID: $SLURM_JOB_ID" -echo "=================" - -cd {REPO_ROOT} -source .venv/bin/activate - -# OPENROUTER_API_KEY should be in .env or environment -if [ -f .env ]; then - set -a - source .env - set +a -fi - -{interpret_cmd} - -echo "Interpret complete!" 
-""" - - script_path = sbatch_scripts_dir / f"interpret_{job_id}.sh" - slurm_job_id = _submit_slurm_job(script_content, script_path) - - # Rename to include SLURM job ID - final_script_path = sbatch_scripts_dir / f"interpret_{slurm_job_id}.sh" - script_path.rename(final_script_path) + # Build full command with echoes + full_command = "\n".join( + [ + 'echo "=== Interpret ==="', + f'echo "WANDB_PATH: {wandb_path}"', + f'echo "MODEL: {model.value}"', + 'echo "SLURM_JOB_ID: $SLURM_JOB_ID"', + 'echo "================="', + "", + interpret_cmd, + "", + 'echo "Interpret complete!"', + ] + ) - # Create empty log file for tailing - (slurm_logs_dir / f"slurm-{slurm_job_id}.out").touch() + config = SlurmConfig( + job_name=job_name, + partition=partition, + n_gpus=0, # CPU-only job + time=time, + snapshot_branch=None, # Autointerp doesn't use git snapshots + ) + script_content = generate_script(config, full_command) + result = submit_slurm_job(script_content, "interpret") logger.section("Interpret job submitted!") logger.values( { - "Job ID": slurm_job_id, + "Job ID": result.job_id, "WandB path": wandb_path, "Model": model.value, - "Log": f"~/slurm_logs/slurm-{slurm_job_id}.out", - "Script": str(final_script_path), + "Log": result.log_pattern, + "Script": str(result.script_path), } ) diff --git a/spd/base_config.py b/spd/base_config.py index c9b488e19..4c37da906 100644 --- a/spd/base_config.py +++ b/spd/base_config.py @@ -15,6 +15,8 @@ class BaseConfig(BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + # TODO: add a "config_type" field, which is set to the class name, so that when loading a config we can check whether the config type matches the expected class + @classmethod def from_file(cls, path: Path | str) -> Self: """Load config from path to a JSON or YAML file.""" @@ -29,7 +31,12 @@ def from_file(cls, path: Path | str) -> Self: case _: raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") - return 
cls.model_validate(data) + try: + cfg = cls.model_validate(data) + except Exception as e: + e.add_note(f"Error validating config {cls=} from path `{path.as_posix()}`\n{data = }") + raise e + return cfg def to_file(self, path: Path | str) -> None: """Save config to file (format inferred from extension).""" diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md new file mode 100644 index 000000000..c05ad9321 --- /dev/null +++ b/spd/clustering/CLAUDE.md @@ -0,0 +1,118 @@ +# Clustering Module + +Hierarchical clustering of SPD components based on coactivation patterns. Runs ensemble clustering experiments to discover stable groups of components that behave similarly. + +## Usage + +**`spd-clustering` / `run_pipeline.py`**: Runs multiple clustering runs (ensemble) with different seeds, then runs `calc_distances` to compute pairwise distances between results. Use this for ensemble experiments. + +**`run_clustering.py`**: Runs a single clustering run. Useful for testing or when you only need one clustering result. 
+ +```bash +# Run clustering pipeline via SLURM (ensemble of runs + distance calculation) +spd-clustering --config spd/clustering/configs/pipeline_config.yaml + +# Run locally instead of SLURM +spd-clustering --config spd/clustering/configs/pipeline_config.yaml --local + +# Single clustering run (usually called by pipeline) +python -m spd.clustering.scripts.run_clustering --config +``` + +## Data Storage + +``` +/mnt/polished-lake/spd/clustering/ +├── cluster// # Single clustering run outputs +│ ├── clustering_run_config.json +│ └── history.zip # MergeHistory (group assignments per iteration) +└── ensemble// # Pipeline/ensemble outputs + ├── pipeline_config.yaml + ├── ensemble_meta.json # Component labels, iteration stats + ├── ensemble_merge_array.npz # Normalized merge array + ├── distances_.npz # Distance matrices + └── plots/ + └── distances_.png # Distance distribution visualization +``` + +## Architecture + +### Pipeline (`scripts/run_pipeline.py`) + +Entry point via `spd-clustering`. Submits clustering runs as SLURM job array, then calculates distances between results. Key steps: +1. Creates `ExecutionStamp` for pipeline +2. Generates commands for each clustering run (with different dataset seeds) +3. Submits clustering array job to SLURM +4. Submits distance calculation jobs (depend on clustering completion) + +### Single Run (`scripts/run_clustering.py`) + +Performs one clustering run: +1. Load decomposed model from WandB +2. Compute component activations on dataset batch +3. Run merge iteration (greedy MDL-based clustering) +4. 
Save `MergeHistory` with group assignments per iteration + +### Merge Algorithm (`merge.py`) + +Greedy hierarchical clustering using MDL (Minimum Description Length) cost: +- Computes coactivation matrix from component activations +- Iteratively merges pairs with lowest cost (via `compute_merge_costs`) +- Supports stochastic merge pair selection (`merge_pair_sampling_method`) +- Tracks full merge history for analysis + +### Distance Calculation (`scripts/calc_distances.py`) + +Computes pairwise distances between clustering runs in an ensemble: +- Normalizes component labels across runs (handles dead components) +- Supports multiple distance methods: `perm_invariant_hamming`, `matching_dist` +- Runs in parallel using multiprocessing + +## Key Types + +### Configs + +```python +ClusteringPipelineConfig # Pipeline settings (n_runs, distances_methods, SLURM config) +ClusteringRunConfig # Single run settings (model_path, batch_size, merge_config) +MergeConfig # Merge algorithm params (alpha, iters, activation_threshold) +``` + +### Data Structures + +```python +MergeHistory # Full merge history: group assignments at each iteration +MergeHistoryEnsemble # Collection of histories for distance analysis +GroupMerge # Current group assignments (component -> group mapping) +``` + +### Type Aliases (`consts.py`) + +```python +ActivationsTensor # Float[Tensor, "samples n_components"] +ClusterCoactivationShaped # Float[Tensor, "k_groups k_groups"] +MergesArray # Int[np.ndarray, "n_ens n_iters n_components"] +DistancesArray # Float[np.ndarray, "n_iters n_ens n_ens"] +``` + +## Math Submodule (`math/`) + +- `merge_matrix.py` - `GroupMerge` class for tracking group assignments +- `merge_distances.py` - Distance computation between clustering results +- `perm_invariant_hamming.py` - Permutation-invariant Hamming distance +- `matching_dist.py` - Optimal matching distance via Hungarian algorithm +- `merge_pair_samplers.py` - Strategies for selecting which pair to merge + +## Config 
Files + +Configs live in `spd/clustering/configs/`: +- Pipeline configs: `*.yaml` files with `ClusteringPipelineConfig` +- Run configs: `crc/*.json` files with `ClusteringRunConfig` + +Example pipeline config: +```yaml +clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json" +n_runs: 10 +distances_methods: ["perm_invariant_hamming"] +wandb_project: "spd" +``` diff --git a/spd/clustering/__init__.py b/spd/clustering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py new file mode 100644 index 000000000..cd6a2b742 --- /dev/null +++ b/spd/clustering/activations.py @@ -0,0 +1,267 @@ +from dataclasses import dataclass +from functools import cached_property +from typing import Literal, NamedTuple + +import torch +from jaxtyping import Bool, Float, Float16, Int +from torch import Tensor + +from spd.clustering.consts import ( + ActivationsTensor, + BoolActivationsTensor, + ClusterCoactivationShaped, + ComponentLabels, +) +from spd.clustering.util import ModuleFilterFunc +from spd.models.component_model import ComponentModel, OutputWithCache + + +def component_activations( + model: ComponentModel, + device: torch.device | str, + batch: Int[Tensor, "batch_size n_ctx"], +) -> dict[str, ActivationsTensor]: + """Get the component activations over a **single** batch.""" + causal_importances: dict[str, ActivationsTensor] + with torch.no_grad(): + model_output: OutputWithCache = model( + batch.to(device), + cache_type="input", + ) + + # TODO: !!!IMPORTANT!!! 
unclear what the right thing from CIOutputs is + causal_importances = model.calc_causal_importances( + pre_weight_acts=model_output.cache, + sampling="continuous", + detach_inputs=False, + ).upper_leaky + + return causal_importances + + +def compute_coactivatons( + activations: ActivationsTensor | BoolActivationsTensor, +) -> ClusterCoactivationShaped: + """Compute the coactivations matrix from the activations.""" + # TODO: this works for both boolean and continuous activations, + # but we could do better by just using OR for boolean activations + # and maybe even some bitshift hacks. but for now, we convert to float16 + activations_f16: Float16[Tensor, "samples C"] = activations.to(torch.float16) + return activations_f16.T @ activations_f16 + + +class FilteredActivations(NamedTuple): + activations: ActivationsTensor + "activations after filtering dead components" + + labels: ComponentLabels + "list of length c with labels for each preserved component" + + dead_components_labels: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + @property + def n_alive(self) -> int: + """Number of alive components after filtering.""" + n_alive: int = len(self.labels) + assert n_alive == self.activations.shape[1], ( + f"{n_alive = } != {self.activations.shape[1] = }" + ) + return n_alive + + @property + def n_dead(self) -> int: + """Number of dead components after filtering.""" + return len(self.dead_components_labels) if self.dead_components_labels else 0 + + +def filter_dead_components( + activations: ActivationsTensor, + labels: ComponentLabels, + filter_dead_threshold: float = 0.01, +) -> FilteredActivations: + """Filter out dead components based on a threshold + + if `filter_dead_threshold` is 0, no filtering is applied. + activations and labels are returned as is, `dead_components_labels` is `None`. 
+ + otherwise, components whose **maximum** activations across all samples is below the threshold + are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`. + `dead_components_labels` will also be `None` if no components were below the threshold. + """ + dead_components_lst: ComponentLabels | None = None + if filter_dead_threshold > 0: + dead_components_lst = ComponentLabels(list()) + max_act: Float[Tensor, " c"] = activations.max(dim=0).values + dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold + + if dead_components.any(): + activations = activations[:, ~dead_components] + alive_labels: list[tuple[str, bool]] = [ + (lbl, bool(keep.item())) + for lbl, keep in zip(labels, ~dead_components, strict=False) + ] + # re-assign labels only if we are filtering + labels = ComponentLabels([label for label, keep in alive_labels if keep]) + dead_components_lst = ComponentLabels( + [label for label, keep in alive_labels if not keep] + ) + + return FilteredActivations( + activations=activations, + labels=labels, + dead_components_labels=dead_components_lst if dead_components_lst else None, + ) + + +@dataclass(frozen=True) +class ProcessedActivations: + """Processed activations after filtering and concatenation""" + + activations_raw: dict[str, ActivationsTensor] + "activations after filtering, but prior to concatenation" + + activations: ActivationsTensor + "activations after filtering and concatenation" + + labels: ComponentLabels + "list of length c with labels for each preserved component, format `{module_name}:{component_index}`" + + dead_components_lst: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + def validate(self) -> None: + """Validate the processed activations""" + # getting this property will also perform a variety of other checks + assert self.n_components_alive > 0 + + @property + def n_components_original(self) -> int: + """Total number 
of components before filtering. equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`"""
+        return sum(act.shape[1] for act in self.activations_raw.values())
+
+    @property
+    def n_components_alive(self) -> int:
+        """Number of alive components after filtering. equal to the length of `labels`"""
+        n_alive: int = len(self.labels)
+        assert n_alive + self.n_components_dead == self.n_components_original, (
+            f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })"
+        )
+        assert n_alive == self.activations.shape[1], (
+            f"{n_alive = } != {self.activations.shape[1] = }"
+        )
+
+        return n_alive
+
+    @property
+    def n_components_dead(self) -> int:
+        """Number of dead components after filtering. equal to the length of `dead_components_lst` if it is not None, or 0 otherwise"""
+        return len(self.dead_components_lst) if self.dead_components_lst else 0
+
+    @cached_property
+    def label_index(self) -> dict[str, int | None]:
+        """Create a mapping from label to alive index (`None` if dead)"""
+        return {
+            **{label: i for i, label in enumerate(self.labels)},
+            **(
+                {label: None for label in self.dead_components_lst}
+                if self.dead_components_lst
+                else {}
+            ),
+        }
+
+    def get_label_index(self, label: str) -> int | None:
+        """Get the index of a label in the activations, or None if it is dead"""
+        return self.label_index[label]
+
+    def get_label_index_alive(self, label: str) -> int:
+        """Get the index of a label in the activations, or raise if it is dead"""
+        idx: int | None = self.get_label_index(label)
+        if idx is None:
+            raise ValueError(f"Label '{label}' is dead and has no index in the activations.")
+        return idx
+
+    @property
+    def module_keys(self) -> list[str]:
+        """Get the module keys from the activations_raw"""
+        return list(self.activations_raw.keys())
+
+    def get_module_indices(self, module_key: str) -> list[int | None]:
+        """given a module key, return a list len "num components in that module", with 
int index in alive components, or None if dead"""
+        num_components: int = self.activations_raw[module_key].shape[1]
+        return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)]
+
+
+def process_activations(
+    activations: dict[
+        str,  # module name to
+        Float[Tensor, "samples C"]  # (sample x component gate activations)
+        | Float[Tensor, " n_sample n_ctx C"],  # (sample x seq index x component gate activations)
+    ],
+    filter_dead_threshold: float = 0.01,
+    seq_mode: Literal["concat", "seq_mean", None] = None,
+    filter_modules: ModuleFilterFunc | None = None,
+) -> ProcessedActivations:
+    """get back a dict of coactivations, slices, and concatenated activations
+
+    Args:
+        activations: Dictionary of activations by module
+        filter_dead_threshold: Threshold for filtering dead components
+        seq_mode: How to handle sequence dimension
+        filter_modules: Function to filter modules
+
+    """
+
+    # reshape -- special cases for llms
+    # ============================================================
+    activations_: dict[str, ActivationsTensor]
+    if seq_mode == "concat":
+        # Concatenate the sequence dimension into the sample dimension
+        activations_ = {
+            key: act.reshape(act.shape[0] * act.shape[1], act.shape[2])
+            for key, act in activations.items()
+        }
+    elif seq_mode == "seq_mean":
+        # Take the mean over the sequence dimension
+        activations_ = {
+            key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items()
+        }
+    else:
+        # Use the activations as they are
+        activations_ = activations
+
+    # put the labelled activations into one big matrix and filter them
+    # ============================================================
+
+    # filter activations for only the modules we want
+    if filter_modules is not None:
+        activations_ = {key: act for key, act in activations_.items() if filter_modules(key)}
+
+    # compute the labels and total component count
+    total_c: int = 0
+    labels: 
ComponentLabels = ComponentLabels(list()) + for key, act in activations_.items(): + c: int = act.shape[-1] + labels.extend([f"{key}:{i}" for i in range(c)]) + total_c += c + + # concat the activations + act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1) + + # filter dead components + filtered_components: FilteredActivations = filter_dead_components( + activations=act_concat, + labels=labels, + filter_dead_threshold=filter_dead_threshold, + ) + + assert filtered_components.n_alive + filtered_components.n_dead == total_c, ( + f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })" + ) + + return ProcessedActivations( + activations_raw=activations_, + activations=filtered_components.activations, + labels=filtered_components.labels, + dead_components_lst=filtered_components.dead_components_labels, + ) diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py new file mode 100644 index 000000000..95d72f9bd --- /dev/null +++ b/spd/clustering/clustering_run_config.py @@ -0,0 +1,136 @@ +"""ClusteringRunConfig""" + +import base64 +import hashlib +import json +from pathlib import Path +from typing import Any + +from pydantic import Field, PositiveInt, field_validator, model_validator + +from spd.base_config import BaseConfig +from spd.clustering.merge_config import MergeConfig +from spd.registry import EXPERIMENT_REGISTRY +from spd.settings import SPD_CACHE_DIR + + +class LoggingIntervals(BaseConfig): + """Intervals in which to log each type of output.""" + + stat: PositiveInt = Field( + default=1, description="Logging statistics (e.g., k_groups, merge_pair_cost, mdl_loss)" + ) + tensor: PositiveInt = Field( + default=100, description="Logging tensors (e.g., wandb_log_tensor, fraction calculations)" + ) + plot: PositiveInt = Field( + default=100, description="Generating plots (e.g., plot_merge_iteration)" + ) + artifact: PositiveInt = Field( + default=100, 
description="Creating artifacts (e.g., merge_history)" + ) + + +class ClusteringRunConfig(BaseConfig): + """Configuration for a single clustering run. + + This config specifies the clustering algorithm parameters and data processing settings. + Deployment concerns (where to save, WandB settings, ensemble configuration) are handled + by ClusteringSubmitConfig. + """ + + # TODO: Handle both wandb strings and local file paths + model_path: str = Field( + description="WandB path to the decomposed model (format: wandb:entity/project/run_id)" + ) + + batch_size: PositiveInt = Field(..., description="Batch size for processing") + dataset_seed: int = Field(0, description="Seed for dataset generation/loading") + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering", + description="Base directory to save clustering runs", + ) + ensemble_id: str | None = Field( + default=None, + description="Ensemble identifier for WandB grouping", + ) + merge_config: MergeConfig = Field(description="Merge algorithm configuration") + logging_intervals: LoggingIntervals = Field( + default_factory=LoggingIntervals, + description="Logging intervals", + ) + + wandb_project: str | None = Field( + default=None, + description="WandB project name (None to disable WandB logging)", + ) + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + dataset_streaming: bool = Field( + default=False, + description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) + + @model_validator(mode="before") + def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: + experiment_key: str | None = values.get("experiment_key") + if experiment_key: + model_path_given: str | None = values.get("model_path") + model_path_from_experiment: str | None = EXPERIMENT_REGISTRY[ + experiment_key + ].canonical_run + assert model_path_from_experiment is not None, ( + f"Experiment '{experiment_key}' has no canonical_run defined in the EXPERIMENT_REGISTRY" + ) + if model_path_given and model_path_given != model_path_from_experiment: + raise ValueError( + f"Both experiment_key '{experiment_key}' and model_path '{model_path_given}' given in config data, but they disagree: {model_path_from_experiment=}" + ) + + values["model_path"] = model_path_from_experiment + del values["experiment_key"] + + return values + + @field_validator("model_path") + def validate_model_path(cls, v: str) -> str: + """Validate that model_path is a proper WandB path.""" + if not v.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {v}") + return v + + @property + def wandb_decomp_model(self) -> str: + """Extract the WandB run ID of the source decomposition.""" + parts = self.model_path.replace("wandb:", "").split("/") + if len(parts) >= 3: + return parts[-1] if parts[-1] != "runs" else parts[-2] + raise ValueError(f"Invalid wandb path format: {self.model_path}") + + def model_dump_with_properties(self) -> dict[str, Any]: + """Serialize config including computed properties for WandB logging.""" + base_dump: dict[str, Any] = self.model_dump(mode="json") + + # Add computed properties + base_dump.update( + { + "wandb_decomp_model": self.wandb_decomp_model, + } + ) + + return base_dump + + def stable_hash_b64(self) -> str: + """Generate a stable, deterministic base64-encoded hash of this config. + + Uses SHA256 hash of the JSON representation with sorted keys for determinism. 
+ Returns URL-safe base64 encoding without padding. + + Returns: + URL-safe base64-encoded hash (without padding) + """ + config_dict: dict[str, Any] = self.model_dump(mode="json") + config_json: str = json.dumps(config_dict, indent=2, sort_keys=True) + hash_digest: bytes = hashlib.sha256(config_json.encode()).digest() + # Use base64 URL-safe encoding and strip padding for filesystem safety + hash_b64: str = base64.urlsafe_b64encode(hash_digest).decode().rstrip("=") + return hash_b64 diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py new file mode 100644 index 000000000..f1b3425d1 --- /dev/null +++ b/spd/clustering/compute_costs.py @@ -0,0 +1,189 @@ +import math + +import torch +from jaxtyping import Bool, Float +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, MergePair +from spd.clustering.math.merge_matrix import GroupMerge + + +def compute_mdl_cost( + acts: Float[Tensor, " k_groups"], + merges: GroupMerge, + alpha: float = 1.0, +) -> float: + r"""Compute MDL costs for merge matrices + + $$ + MDL = \sum_{i \in \N_k} s_i ( \log(k) + \alpha r(P_i) ) + $$ + + where: + - $s_i$ activation of component $i$, $s_j$ activation of component $j$ + - $r(P_i)$ rank of component $i$, $r(P_j)$ rank of component $j$ + - $k$ is the total number of components + """ + + k_groups: int = acts.shape[0] + assert k_groups == merges.k_groups, "Merges must match activation vector shape" + + return ( + (acts * (math.log2(k_groups) + alpha * merges.components_per_group.to(device=acts.device))) + .sum() + .item() + ) + + +def compute_merge_costs( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + alpha: float = 1.0, +) -> ClusterCoactivationShaped: + r"""Compute MDL costs for merge matrices + + $$ + F(P_i, P_j) + = \alpha |s_i| r(P_i) + \alpha |s_j| r(P_j) + - s_i s_j ( \alpha r(P_i) + \alpha r(P_j) + c ) + = \alpha ( + |s_i| r(P_i) + + |s_j| r(P_j) + - s_i s_j ( r(P_i) + r(P_j) + c/\alpha ) + ) + $$ + + new 
version from nathu 2025-08-11 16:48 + + $$ + (s_\Sigma - s_i - s_j) log((c-1)/c) + + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) ) + $$ + where: + - $s_\Sigma$ average activation of all components + - $s_i$ activation of component $i$, $s_j$ activation of component $j$ + - $s_{i,j}$ activation of the merged component $i,j$ + - $r(P_i)$ rank of component $i$, $r(P_j)$ rank of component $j$ + - $r(P_{i,j})$ rank of the merged component $i,j$ + + """ + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + assert merges.k_groups == k_groups, "Merges must match coactivation matrix shape" + + device: torch.device = coact.device + ranks: Float[Tensor, " k_groups"] = merges.components_per_group.to(device=device).float() + s_diag: Float[Tensor, " k_groups"] = torch.diag(coact).to(device=device) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * ranks.view(1, -1) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * (ranks.view(1, -1) + 1/alpha) + term_si_rpi: Float[Tensor, " k_groups"] = s_diag * ranks + # dbg_auto(term_si_rpi) + rank_sum: ClusterCoactivationShaped = ranks.view(-1, 1) + ranks.view(1, -1) + # TODO: use dynamic rank computation + # return alpha * ( + # term_si_rpj # |s_i| r(P_j) + # + term_si_rpj.T # |s_j| r(P_i) + # - coact * ( # s_i s_j + # rank_sum # r(P_i) + r(P_j) + # + (rank_cost(merges.k_groups) / alpha) # c / alpha + # ) + # ) + + coact_OR: ClusterCoactivationShaped = s_diag.view(-1, 1) + s_diag.view(1, -1) - coact + + # reduce penalty for sending dictionary by 1 + # (s_\Sigma - s_i - s_j) log((c-1)/c) + # delta of cost for sending index, in expectation + # + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + # delta of cost for sending ranks, in expectation + # + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) + + s_other: ClusterCoactivationShaped = ( + s_diag.sum() - s_diag.view(-1, 1) - 
s_diag.view(1, -1) + ) * math.log2((k_groups - 1) / k_groups) + + bits_local: ClusterCoactivationShaped = ( + coact_OR * math.log2(k_groups - 1) + - s_diag.view(-1, 1) * math.log2(k_groups) + - s_diag.view(1, -1) * math.log2(k_groups) + ) + + penalty: ClusterCoactivationShaped = ( + coact_OR * rank_sum # s_{i,j} r(P_{i,j}) + - term_si_rpi.view(-1, 1) # s_i r(P_i) + - term_si_rpi.view(1, -1) # s_j r(P_j) + ) + + output: ClusterCoactivationShaped = s_other + bits_local + alpha * penalty + return output + + +def recompute_coacts_merge_pair( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + merge_pair: MergePair, + activation_mask: Bool[Tensor, "samples k_groups"], +) -> tuple[ + GroupMerge, + Float[Tensor, "k_groups-1 k_groups-1"], + Bool[Tensor, "samples k_groups"], +]: + # check shape + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + + # activations of the new merged group + activation_mask_grp: Bool[Tensor, " samples"] = ( + activation_mask[:, merge_pair[0]] + activation_mask[:, merge_pair[1]] + ) + + # coactivations with the new merged group + coact_with_merge: Float[Tensor, " k_groups"] = ( + activation_mask_grp.float() @ activation_mask.float() + ) + new_group_idx: int = min(merge_pair) + remove_idx: int = max(merge_pair) + new_group_self_coact: float = activation_mask_grp.float().sum().item() + + # assemble the merge pair + merge_new: GroupMerge = merges.merge_groups( + merge_pair[0], + merge_pair[1], + ) + # TODO: we don't use this index for anything, and could reconstruct it from the merge pair if needed. 
get rid of it + # `merge_groups` will set `old_to_new_idx` to be an actual dict for `merge_new` + old_to_new_idx: dict[int | None, int | None] = merge_new.old_to_new_idx # pyright: ignore[reportAssignmentType] + assert old_to_new_idx[None] == new_group_idx, ( + "New group index should be the minimum of the merge pair" + ) + assert old_to_new_idx[new_group_idx] is None + assert old_to_new_idx[remove_idx] is None + # TODO: check that the rest are in order? probably not necessary + + # reindex coactivations + coact_temp: ClusterCoactivationShaped = coact.clone() + # add in the similarities with the new group + coact_temp[new_group_idx, :] = coact_with_merge + coact_temp[:, new_group_idx] = coact_with_merge + # delete the old group + mask: Bool[Tensor, " k_groups"] = torch.ones( + coact_temp.shape[0], dtype=torch.bool, device=coact_temp.device + ) + mask[remove_idx] = False + coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact_temp[mask, :][:, mask] + # add in the self-coactivation of the new group + coact_new[new_group_idx, new_group_idx] = new_group_self_coact + + # reindex mask + activation_mask_new: Float[Tensor, "samples ..."] = activation_mask.clone() + # add in the new group + activation_mask_new[:, new_group_idx] = activation_mask_grp + # remove the old group + activation_mask_new = activation_mask_new[:, mask] + + return ( + merge_new, + coact_new, + activation_mask_new, + ) diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md new file mode 100644 index 000000000..e1ac41f47 --- /dev/null +++ b/spd/clustering/configs/README.md @@ -0,0 +1 @@ +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `clustering_run_config_path` field in the pipeline configs. 
\ No newline at end of file diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml new file mode 100644 index 000000000..9345307d2 --- /dev/null +++ b/spd/clustering/configs/crc/example.yaml @@ -0,0 +1,23 @@ +model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model +batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble +dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) + +merge_config: + activation_threshold: 0.01 # set to null to use scalar activations for cost calculation + alpha: 1.0 # rank penalty term + iters: 10 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? + merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' + merge_pair_sampling_kwargs: + threshold: 0.05 # For range sampler: fraction of the range of costs to sample from + filter_dead_threshold: 0.001 # Threshold for filtering dead components + module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules + +wandb_project: spd-cluster +wandb_entity: goodfire +logging_intervals: + stat: 1 # for k_groups, merge_pair_cost, mdl_loss + tensor: 100 # for wandb_log_tensor and fraction_* calculations + plot: 100 # for calling the plotting callback + artifact: 100 # for calling the artifact callback \ No newline at end of file diff --git a/spd/clustering/configs/crc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json new file mode 100644 index 000000000..1e13ce23e --- /dev/null +++ b/spd/clustering/configs/crc/resid_mlp1.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 5, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0, + "module_name_filter": null + }, + "experiment_key": "resid_mlp1", + "batch_size": 128, + "wandb_project": "spd-cluster", + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 5 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json new file mode 100644 index 000000000..edc4849e2 --- /dev/null +++ b/spd/clustering/configs/crc/resid_mlp2.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 100, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.01, + "module_name_filter": null + }, + "experiment_key": "resid_mlp2", + "batch_size": 1024, + "wandb_project": "spd-cluster", + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 50 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json new file mode 100644 index 000000000..e1647b6e4 --- /dev/null +++ b/spd/clustering/configs/crc/simplestories_dev.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + 
"activation_threshold": 0.1, + "alpha": 1.0, + "iters": 100, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", + "batch_size": 32, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json b/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json new file mode 100644 index 000000000..eba5723d2 --- /dev/null +++ b/spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 0.1, + "iters": 1000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/5cr21lbs", + "batch_size": 2048, + "wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } \ No newline at end of file diff --git a/spd/clustering/configs/crc/ss_llama_simple_mlp.json b/spd/clustering/configs/crc/ss_llama_simple_mlp.json new file mode 100644 index 000000000..6cf534ec5 --- /dev/null +++ b/spd/clustering/configs/crc/ss_llama_simple_mlp.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1, + "iters": 1000, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/vjbol27n", + "batch_size": 512, + "wandb_project": null, + "logging_intervals": { + "stat": 10, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } + } \ No newline at end of file diff --git a/spd/clustering/configs/crc/test-resid_mlp1.json 
b/spd/clustering/configs/crc/test-resid_mlp1.json new file mode 100644 index 000000000..4b3a26ff8 --- /dev/null +++ b/spd/clustering/configs/crc/test-resid_mlp1.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.5, + "alpha": 1, + "iters": 16, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "experiment_key": "resid_mlp1", + "batch_size": 128, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 10, + "artifact": 10 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json new file mode 100644 index 000000000..911f71529 --- /dev/null +++ b/spd/clustering/configs/crc/test-simplestories.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.9, + "alpha": 1.0, + "iters": 5, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.9, + "module_name_filter": "model.layers.0" + }, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", + "batch_size": 1, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 2, + "plot": 3, + "artifact": 4 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml new file mode 100644 index 000000000..1868b5887 --- /dev/null +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -0,0 +1,9 @@ +n_runs: 2 +distances_methods: ["matching_dist"] +# base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: "spd-cluster" # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false +clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git 
a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml new file mode 100644 index 000000000..93554aebb --- /dev/null +++ b/spd/clustering/configs/pipeline_config.yaml @@ -0,0 +1,9 @@ +clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp-1L.json" +n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +base_output_dir: "/mnt/polished-lake/spd/clustering" +slurm_job_name_prefix: "spd" +slurm_partition: "h200-reserved" +wandb_project: "spd" +wandb_entity: "goodfire" +create_git_snapshot: true \ No newline at end of file diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py new file mode 100644 index 000000000..8a9647dc8 --- /dev/null +++ b/spd/clustering/consts.py @@ -0,0 +1,48 @@ +"""Constants and shared abstractions for clustering pipeline.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Literal, NewType + +import numpy as np +from jaxtyping import Bool, Float, Int +from torch import Tensor + +# Merge arrays and distances (numpy-based for storage/analysis) +MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] +MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] +DistancesMethod = Literal["perm_invariant_hamming", "matching_dist", "matching_dist_vec"] +DistancesArray = Float[np.ndarray, "n_iters n_ens n_ens"] + +# Component and label types (NewType for stronger type safety) +ComponentLabel = NewType("ComponentLabel", str) # Format: "module_name:component_index" +ComponentLabels = NewType("ComponentLabels", list[str]) +BatchId = NewType("BatchId", str) + +# Path types +WandBPath = NewType("WandBPath", str) # Format: "wandb:entity/project/run_id" + +# Merge types +MergePair = NewType("MergePair", tuple[int, int]) + +# Tensor type aliases (torch-based for computation - TypeAlias for jaxtyping compatibility) +ActivationsTensor = Float[Tensor, "samples n_components"] +BoolActivationsTensor = Bool[Tensor, "samples n_components"] 
+ClusterCoactivationShaped = Float[Tensor, "k_groups k_groups"] +GroupIdxsTensor = Int[Tensor, " n_components"] +BatchTensor = Int[Tensor, "batch_size seq_len"] + + +class SaveableObject(ABC): + """Abstract base class for objects that can be saved to and loaded from disk.""" + + @abstractmethod + def save(self, path: Path) -> None: + """Save the object to disk at the given path.""" + ... + + @classmethod + @abstractmethod + def read(cls, path: Path) -> "SaveableObject": + """Load the object from disk at the given path.""" + ... diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py new file mode 100644 index 000000000..7513a64ee --- /dev/null +++ b/spd/clustering/dataset.py @@ -0,0 +1,135 @@ +"""Dataset loading utilities for clustering runs. + +Each clustering run loads its own dataset batch, seeded by the run index. +""" + +from typing import Any + +from spd.clustering.consts import BatchTensor +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.experiments.resid_mlp.configs import ResidMLPTaskConfig +from spd.experiments.resid_mlp.models import ResidMLP +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName + + +def load_dataset( + model_path: str, + task_name: TaskName, + batch_size: int, + seed: int, + **kwargs: Any, +) -> BatchTensor: + """Load a single batch for clustering. + + Each run gets its own dataset batch, seeded by index in ensemble. 
+ + Args: + model_path: Path to decomposed model + task_name: Task type + batch_size: Batch size + seed: Random seed for dataset + + Returns: + Single batch of data + """ + match task_name: + case "lm": + return _load_lm_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case "resid_mlp": + return _load_resid_mlp_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case _: + raise ValueError(f"Unsupported task: {task_name}") + + +def _load_lm_batch( + model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None +) -> BatchTensor: + """Load a batch for language model task.""" + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + + assert isinstance(cfg.task_config, LMTaskConfig), ( + f"Expected task_config to be of type LMTaskConfig, but got {type(cfg.task_config) = }" + ) + + try: + pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] + assert pretrained_model_name is not None + except Exception as e: + raise AttributeError("Could not find 'pretrained_model_name' in the SPD Run config") from e + + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + streaming=False, + ), + **(config_kwargs or {}), + } + + dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=cfg.tokenizer_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + seed=seed, # Use run-specific seed + column_name=cfg.task_config.column_name, + **config_kwargs_, + ) + + dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=seed, # Use run-specific seed + ) + + # Get first batch + batch = next(iter(dataloader)) + return batch["input_ids"] + + +def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: + """Load a batch for ResidMLP task.""" + from 
spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path) + + assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( + f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" + ) + assert isinstance(component_model.target_model, ResidMLP), ( + f"Expected target_model to be of type ResidMLP, but got {type(component_model.target_model) = }" + ) + + # Create dataset with run-specific seed + dataset = ResidMLPDataset( + n_features=component_model.target_model.config.n_features, + feature_probability=cfg.task_config.feature_probability, + device="cpu", + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=seed, # Use run-specific seed + label_coeffs=None, + data_generation_type=cfg.task_config.data_generation_type, + ) + + # Generate batch + dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) + batch, _ = next(iter(dataloader)) + return batch diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py new file mode 100644 index 000000000..c54fe408b --- /dev/null +++ b/spd/clustering/ensemble_registry.py @@ -0,0 +1,87 @@ +"""Ensemble registry for tracking which clustering runs belong to which pipeline ensemble. + +Uses SQLite to maintain a mapping of (pipeline_run_id, idx, clustering_run_id). 
+""" + +import sqlite3 +from contextlib import contextmanager + +from spd.settings import SPD_CACHE_DIR + +# SQLite database path +_ENSEMBLE_REGISTRY_DB = SPD_CACHE_DIR / "clustering_ensemble_registry.db" + + +@contextmanager +def _get_connection(): + """Context manager for SQLite connection, ensures table exists.""" + _ENSEMBLE_REGISTRY_DB.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(_ENSEMBLE_REGISTRY_DB) + + try: + # Create table if not exists + conn.execute(""" + CREATE TABLE IF NOT EXISTS ensemble_runs ( + pipeline_run_id TEXT NOT NULL, + idx INTEGER NOT NULL, + clustering_run_id TEXT NOT NULL, + PRIMARY KEY (pipeline_run_id, idx) + ) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_pipeline_run_id + ON ensemble_runs (pipeline_run_id) + """) + conn.commit() + + yield conn + finally: + conn.close() + + +def register_clustering_run(pipeline_run_id: str, clustering_run_id: str) -> int: + """Register a clustering run as part of a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + idx: Index of this run in the ensemble. If -1, auto-assigns the next available index. 
+ clustering_run_id: The individual clustering run ID + + Returns: + The auto-assigned index of this run within the pipeline ensemble + """ + with _get_connection() as conn: + # Use BEGIN IMMEDIATE for thread-safe auto-increment + conn.execute("BEGIN IMMEDIATE") + + # Auto-assign next available index, we rely on atomicity of the transaction here + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx: int = cursor.fetchone()[0] + + conn.execute( + "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", + (pipeline_run_id, assigned_idx, clustering_run_id), + ) + conn.commit() + + return assigned_idx + + +def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: + """Get all clustering runs for a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + + Returns: + List of (idx, clustering_run_id) tuples, sorted by idx + """ + with _get_connection() as conn: + cursor = conn.execute( + "SELECT idx, clustering_run_id FROM ensemble_runs WHERE pipeline_run_id = ?
ORDER BY idx", + (pipeline_run_id,), + ) + return cursor.fetchall() diff --git a/spd/clustering/math/__init__.py b/spd/clustering/math/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/math/matching_dist.py b/spd/clustering/math/matching_dist.py new file mode 100644 index 000000000..1991e9ba0 --- /dev/null +++ b/spd/clustering/math/matching_dist.py @@ -0,0 +1,47 @@ +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +_DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def matching_dist( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + s_ensemble, _n_components = X.shape + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + + dists: Float[Tensor, "s s"] = torch.full((s_ensemble, s_ensemble), torch.nan) + + for i in range(s_ensemble): + for j in range(i + 1, s_ensemble): + dist_mat = matches[i].float() - matches[j].float() + dists[i, j] = torch.tril(dist_mat, diagonal=-1).abs().sum() + + return dists + + +def matching_dist_vec( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + diffs: Bool[Tensor, "s s n n"] = matches[:, None, :, :] ^ matches[None, :, :, :] + + dists_int: torch.Tensor = diffs.sum(dim=(-1, -2)) + dists: Float[Tensor, "s s"] = dists_int.to(torch.float32) + return dists + + +def matching_dist_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist(torch.tensor(X, device=device)).cpu().numpy() + + +def matching_dist_vec_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist_vec(torch.tensor(X, device=device)).cpu().numpy() diff --git a/spd/clustering/math/merge_distances.py b/spd/clustering/math/merge_distances.py new file mode 100644 index 000000000..ff4ebdc4b --- /dev/null +++ 
b/spd/clustering/math/merge_distances.py @@ -0,0 +1,57 @@ +from collections.abc import Callable, Iterable +from multiprocessing import Pool +from typing import TypeVar + +import numpy as np +from jaxtyping import Float, Int + +from spd.clustering.consts import ( + DistancesArray, + DistancesMethod, + MergesArray, + MergesAtIterArray, +) +from spd.clustering.math.matching_dist import matching_dist_np, matching_dist_vec_np +from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix + +_T = TypeVar("_T") +_R = TypeVar("_R") + + +def _run_parallel(func: Callable[[_T], _R], iterable: Iterable[_T]) -> list[_R]: + """Run a function in parallel over an iterable using multiprocessing.""" + items = list(iterable) + with Pool() as pool: + return pool.map(func, items) + + +DISTANCES_METHODS: dict[DistancesMethod, Callable[[MergesAtIterArray], DistancesArray]] = { + "perm_invariant_hamming": perm_invariant_hamming_matrix, + "matching_dist": matching_dist_np, +} + +# pyright: reportUnnecessaryComparison=false, reportUnreachable=false + + +def compute_distances( + normalized_merge_array: MergesArray, + method: DistancesMethod = "perm_invariant_hamming", +) -> DistancesArray: + n_iters: int = normalized_merge_array.shape[1] + merges_array_list: list[Int[np.ndarray, "n_ens n_components"]] + distances_list: list[Float[np.ndarray, "n_ens n_ens"]] + match method: + case "perm_invariant_hamming": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = _run_parallel(perm_invariant_hamming_matrix, merges_array_list) + return np.stack(distances_list, axis=0) + case "matching_dist": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = _run_parallel(matching_dist_np, merges_array_list) + return np.stack(distances_list, axis=0) + case "matching_dist_vec": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = 
_run_parallel(matching_dist_vec_np, merges_array_list) + return np.stack(distances_list, axis=0) + case _: + raise ValueError(f"Unknown distance method: {method}") diff --git a/spd/clustering/math/merge_matrix.py b/spd/clustering/math/merge_matrix.py new file mode 100644 index 000000000..0b9e65086 --- /dev/null +++ b/spd/clustering/math/merge_matrix.py @@ -0,0 +1,288 @@ +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Int +from torch import Tensor + +from spd.clustering.consts import GroupIdxsTensor + + +def _array_summary(arr: Tensor) -> str: + """Get a brief summary string for a tensor.""" + return f"shape={tuple(arr.shape)}, dtype={arr.dtype}" + + +# pyright: reportUnnecessaryTypeIgnoreComment=false + + +@dataclass(kw_only=True, slots=True) +class GroupMerge: + """Canonical component-to-group assignment. + + `group_idxs` is a length-`n_components` integer tensor; entry `c` + gives the group index (0 to `k_groups-1`) that contains component `c`. + """ + + group_idxs: GroupIdxsTensor + k_groups: int + old_to_new_idx: dict[int | None, int | None] | None = None + + def summary(self) -> dict[str, int | str | None]: + return dict( + group_idxs=_array_summary(self.group_idxs), + k_groups=self.k_groups, + old_to_new_idx=f"len={len(self.old_to_new_idx)}" + if self.old_to_new_idx is not None + else None, + ) + + @property + def _n_components(self) -> int: + return int(self.group_idxs.shape[0]) + + @property + def components_per_group(self) -> Int[Tensor, " k_groups"]: + return torch.bincount(self.group_idxs, minlength=self.k_groups) + + def components_in_group_mask(self, group_idx: int) -> Bool[Tensor, " n_components"]: + """Returns a boolean mask for components in the specified group.""" + if group_idx < 0 or group_idx >= self.k_groups: + raise ValueError("group index out of range") + return self.group_idxs == group_idx + + def components_in_group(self, group_idx: int) -> list[int]: + """Returns a list of component indices in the specified 
group.""" + indices: Int[Tensor, " n_matches"] = ( + (self.group_idxs == group_idx).nonzero(as_tuple=False).squeeze(-1) + ) + return indices.tolist() + + def validate(self, *, require_nonempty: bool = True) -> None: + v_min: int = int(self.group_idxs.min().item()) + v_max: int = int(self.group_idxs.max().item()) + if v_min < 0 or v_max >= self.k_groups: + raise ValueError("group indices out of range") + + if require_nonempty: + has_empty_groups: bool = bool(self.components_per_group.eq(0).any().item()) + if has_empty_groups: + raise ValueError("one or more groups are empty") + + def to_matrix( + self, device: torch.device | None = None + ) -> Bool[Tensor, "k_groups n_components"]: + if device is None: + device = self.group_idxs.device + mat: Bool[Tensor, "k_groups n_components"] = torch.zeros( + (self.k_groups, self._n_components), dtype=torch.bool, device=device + ) + idxs: Int[Tensor, " n_components"] = torch.arange( + self._n_components, device=device, dtype=torch.int + ) + mat[self.group_idxs.to(dtype=torch.int), idxs] = True + return mat + + @classmethod + def from_matrix(cls, mat: Bool[Tensor, "k_groups n_components"]) -> "GroupMerge": + if mat.dtype is not torch.bool: + raise TypeError("mat must have dtype bool") + if not mat.sum(dim=0).eq(1).all(): + raise ValueError("each column must contain exactly one True") + group_idxs: GroupIdxsTensor = mat.argmax(dim=0).to(torch.int64) + inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=int(mat.shape[0])) + inst.validate(require_nonempty=False) + return inst + + @classmethod + def random( + cls, + n_components: int, + k_groups: int, + *, + ensure_groups_nonempty: bool = False, + device: torch.device | str = "cpu", + ) -> "GroupMerge": + if ensure_groups_nonempty and n_components < k_groups: + raise ValueError("n_components must be >= k_groups when ensure_groups_nonempty is True") + + group_idxs: GroupIdxsTensor + + if ensure_groups_nonempty: + base: Int[Tensor, " k_groups"] = torch.arange(k_groups, 
device=device) + if n_components > k_groups: + extra: Int[Tensor, " n_extra"] = torch.randint( + 0, k_groups, (n_components - k_groups,), device=device + ) + group_idxs = torch.cat((base, extra)) + group_idxs = group_idxs[torch.randperm(n_components, device=device)] + else: + group_idxs = base + else: + group_idxs = torch.randint(0, k_groups, (n_components,), device=device) + inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups) + inst.validate(require_nonempty=ensure_groups_nonempty) + return inst + + @classmethod + def identity(cls, n_components: int) -> "GroupMerge": + """Creates a GroupMerge where each component is its own group.""" + return cls( + group_idxs=torch.arange(n_components, dtype=torch.int64), + k_groups=n_components, + ) + + def merge_groups(self, group_a: int, group_b: int) -> "GroupMerge": + """Merges two groups into one, returning a new GroupMerge.""" + if group_a < 0 or group_b < 0 or group_a >= self.k_groups or group_b >= self.k_groups: + raise ValueError("group indices out of range") + if group_a == group_b: + raise ValueError("Cannot merge a group with itself") + + # make sure group_a is the smaller index + if group_a > group_b: + group_a, group_b = group_b, group_a + + # make a copy + new_idxs: GroupIdxsTensor = self.group_idxs.clone() + # wherever its currently b, change it to a + new_idxs[new_idxs == group_b] = group_a + # wherever i currently above b, change it to i-1 + new_idxs[new_idxs > group_b] -= 1 + # create a new GroupMerge instance + merged: GroupMerge = GroupMerge(group_idxs=new_idxs, k_groups=self.k_groups - 1) + + # create a mapping from old to new group indices + # `None` as a key is for the new group that contains both a and b + # values of a and b are mapped to `None` since they are merged + old_to_new_idx: dict[int | None, int | None] = dict() + for i in range(self.k_groups): + if i in {group_a, group_b}: + old_to_new_idx[i] = None + elif i <= group_b: + old_to_new_idx[i] = i + else: + old_to_new_idx[i] = i - 1 
+ old_to_new_idx[None] = group_a # the new group index for the merged group + + # HACK: store the mapping in the instance for later use + merged.old_to_new_idx = old_to_new_idx # type: ignore[assignment] + + # validate the new instance + # merged.validate(require_nonempty=True) + return merged + + def all_downstream_merged(self) -> "BatchedGroupMerge": + downstream: list[GroupMerge] = [] + idxs: list[tuple[int, int]] = [] + for i in range(self.k_groups): + for j in range(i + 1, self.k_groups): + downstream.append(self.merge_groups(i, j)) + idxs.append((i, j)) + + return BatchedGroupMerge.from_list(merge_matrices=downstream) + + +@dataclass(slots=True) +class BatchedGroupMerge: + """Batch of merge matrices. + + `group_idxs` has shape `(batch, n_components)`; each row holds the + group index for every component in that matrix. + """ + + group_idxs: Int[Tensor, "batch n_components"] + k_groups: Int[Tensor, " batch"] + + def summary(self) -> dict[str, int | str | None]: + return dict( + group_idxs=_array_summary(self.group_idxs), + k_groups=_array_summary(self.k_groups), + # TODO: re-add metadata (which pairs merged at each step) + # meta=f"len={len(self.meta)}" if self.meta is not None else None, + ) + + @classmethod + def init_empty(cls, batch_size: int, n_components: int) -> "BatchedGroupMerge": + """Initialize an empty BatchedGroupMerge with the given batch size and number of components.""" + return cls( + group_idxs=torch.full((batch_size, n_components), -1, dtype=torch.int16), + k_groups=torch.zeros(batch_size, dtype=torch.int16), + ) + + @property + def _batch_size(self) -> int: + return int(self.group_idxs.shape[0]) + + @property + def _n_components(self) -> int: + return int(self.group_idxs.shape[1]) + + @property + def k_groups_unique(self) -> int: + """Returns the number of groups across all matrices, throws exception if they differ.""" + k_groups_set: set[int] = set(self.k_groups.tolist()) + if len(k_groups_set) != 1: + raise ValueError("All matrices must 
have the same number of groups") + return k_groups_set.pop() + + def to_matrix( + self, device: torch.device | None = None + ) -> Bool[Tensor, "batch k_groups n_components"]: + if device is None: + device = self.group_idxs.device + k_groups_u: int = self.k_groups_unique + mat = torch.nn.functional.one_hot(self.group_idxs, num_classes=k_groups_u) + return mat.permute(0, 2, 1).to(device=device, dtype=torch.bool) + + @classmethod + def from_matrix(cls, mat: Bool[Tensor, "batch k_groups n_components"]) -> "BatchedGroupMerge": + if mat.dtype is not torch.bool: + raise TypeError("mat must have dtype bool") + if not mat.sum(dim=1).eq(1).all(): + raise ValueError("each column must have exactly one True per matrix") + group_idxs = mat.argmax(dim=1).to(torch.int64) + batch_size: int = int(mat.shape[0]) + inst = cls( + group_idxs=group_idxs, + k_groups=torch.full((batch_size,), int(mat.shape[1]), dtype=torch.int64), + ) + # inst.validate(require_nonempty=False) + return inst + + @classmethod + def from_list( + cls, + merge_matrices: list[GroupMerge], + ) -> "BatchedGroupMerge": + group_idxs: Int[Tensor, "batch n_components"] = torch.stack( + [mm.group_idxs for mm in merge_matrices], dim=0 + ) + k_groups: Int[Tensor, " batch"] = torch.tensor( + [mm.k_groups for mm in merge_matrices], dtype=torch.int64 + ) + inst: BatchedGroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups) + # inst.validate(require_nonempty=False) + return inst + + def __getitem__(self, idx: int) -> GroupMerge: + if not (0 <= idx < self._batch_size): + raise IndexError("index out of range") + group_idxs: GroupIdxsTensor = self.group_idxs[idx] + k_groups: int = int(self.k_groups[idx].item()) + return GroupMerge(group_idxs=group_idxs, k_groups=k_groups) + + def __setitem__(self, idx: int, value: GroupMerge) -> None: + if not (0 <= idx < self._batch_size): + raise IndexError("index out of range") + if value._n_components != self._n_components: + raise ValueError("value must have the same number of components 
as the batch") + self.group_idxs[idx] = value.group_idxs + self.k_groups[idx] = value.k_groups + + def __iter__(self): + """Iterate over the GroupMerge instances in the batch.""" + for i in range(self._batch_size): + yield self[i] + + def __len__(self) -> int: + return self._batch_size diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py new file mode 100644 index 000000000..24c050d36 --- /dev/null +++ b/spd/clustering/math/merge_pair_samplers.py @@ -0,0 +1,121 @@ +import random +from typing import Any, Literal, Protocol + +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, MergePair + +MergePairSamplerKey = Literal["range", "mcmc"] + + +class MergePairSamplerConfigurable(Protocol): + def __call__( + self, + costs: ClusterCoactivationShaped, + **kwargs: Any, + ) -> MergePair: ... + + +class MergePairSampler(Protocol): + def __call__( + self, + costs: ClusterCoactivationShaped, + ) -> MergePair: ... + + +def range_sampler( + costs: ClusterCoactivationShaped, + threshold: float = 0.05, + **kwargs: Any, +) -> MergePair: + """Sample a merge pair using threshold-based range selection. + + Considers all pairs with costs below a threshold defined as a fraction + of the range of non-diagonal costs, then randomly selects one. 
+ + Args: + costs: Cost matrix for all possible merges + threshold: Fraction of cost range to consider (0=min only, 1=all pairs) + + Returns: + Tuple of (group_i, group_j) indices to merge + """ + assert not kwargs + k_groups: int = costs.shape[0] + assert costs.shape[1] == k_groups, "Cost matrix must be square" + + # Find the range of non-diagonal costs + non_diag_costs: Float[Tensor, " k_groups_squared_minus_k"] = costs[ + ~torch.eye(k_groups, dtype=torch.bool, device=costs.device) + ] + min_cost: float = float(non_diag_costs.min().item()) + max_cost: float = float(non_diag_costs.max().item()) + + # Calculate threshold cost + max_considered_cost: float = (max_cost - min_cost) * threshold + min_cost + + # Find all pairs below threshold + considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack( + torch.where(costs <= max_considered_cost), dim=1 + ) + # Remove diagonal entries (i == j) + considered_idxs = considered_idxs[considered_idxs[:, 0] != considered_idxs[:, 1]] + + # Randomly select one of the considered pairs + selected_idx: int = random.randint(0, considered_idxs.shape[0] - 1) + pair_tuple: tuple[int, int] = tuple(considered_idxs[selected_idx].tolist()) # type: ignore[assignment] + return MergePair(pair_tuple) + + +def mcmc_sampler( + costs: ClusterCoactivationShaped, + temperature: float = 1.0, + **kwargs: Any, +) -> MergePair: + """Sample a merge pair using MCMC with probability proportional to exp(-cost/temperature).
+ + Args: + costs: Cost matrix for all possible merges + temperature: Temperature parameter for softmax (higher = more uniform sampling) + + Returns: + Tuple of (group_i, group_j) indices to merge + """ + assert not kwargs + k_groups: int = costs.shape[0] + assert costs.shape[1] == k_groups, "Cost matrix must be square" + + # Create mask for valid pairs (non-diagonal) + valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( + k_groups, dtype=torch.bool, device=costs.device + ) + + # Compute probabilities: exp(-cost/temperature) + # Use stable softmax computation to avoid overflow + costs_masked: ClusterCoactivationShaped = costs.clone() + costs_masked[~valid_mask] = float("inf") # Set diagonal to inf so exp gives 0 + + # Subtract min for numerical stability + min_cost: float = float(costs_masked[valid_mask].min()) + probs: ClusterCoactivationShaped = ( + torch.exp((min_cost - costs_masked) / temperature) * valid_mask + ) # Zero out diagonal + probs_flatten: Float[Tensor, " k_groups_squared"] = probs.flatten() + probs_flatten = probs_flatten / probs_flatten.sum() + + # Sample from multinomial distribution + idx: int = int(torch.multinomial(probs_flatten, 1).item()) + row: int = idx // k_groups + col: int = idx % k_groups + + return MergePair((row, col)) + + +MERGE_PAIR_SAMPLERS: dict[MergePairSamplerKey, MergePairSamplerConfigurable] = { + "range": range_sampler, + "mcmc": mcmc_sampler, +} diff --git a/spd/clustering/math/perm_invariant_hamming.py b/spd/clustering/math/perm_invariant_hamming.py new file mode 100644 index 000000000..e70d3c7c0 --- /dev/null +++ b/spd/clustering/math/perm_invariant_hamming.py @@ -0,0 +1,70 @@ +import warnings + +import numpy as np +from jaxtyping import Float, Int +from scipy.optimize import linear_sum_assignment + + +def perm_invariant_hamming_matrix( + X: Int[np.ndarray, "n_ens n_components"], +) -> Float[np.ndarray, "n_ens n_ens"]: + """Compute all pairwise permutation-invariant Hamming
distances. + + The strictly lower-triangular entries are filled with distances; + the diagonal and upper triangle are left as `np.nan`. + + # Parameters: + - `X : Int[np.ndarray, "n_ens n_components"]` + Matrix where each of the `n_ens` rows is a label vector of length `n_components`. + + # Returns: + - `Float[np.ndarray, "n_ens n_ens"]` + Distance matrix `D` with `D[i, j]` defined only for `i > j`; + all other positions are `np.nan`. + + # Usage: + ```python + >>> X = np.array([[0, 0, 1], + ... [1, 1, 0], + ... [0, 1, 0]]) + >>> D = perm_invariant_hamming_matrix(X) + >>> D + array([[nan, nan, nan], + [ 0., nan, nan], + [ 1., 1., nan]]) + ``` + """ + n_ens: int + n_components: int + n_ens, n_components = X.shape + D: Float[np.ndarray, "n_ens n_ens"] = np.full((n_ens, n_ens), np.nan, dtype=float) + + # Pre-compute max label in each row once. + row_max: Int[np.ndarray, " n_ens"] = X.max(axis=1) + + for i in range(1, n_ens): + a: Int[np.ndarray, " n_components"] = X[i] + for j in range(i): + b: Int[np.ndarray, " n_components"] = X[j] + + k_lbls: int = int(max(row_max[i], row_max[j]) + 1) + + # Handle case where all labels are -1 (no valid clustering) + if k_lbls <= 0: + warnings.warn( + f"All labels are -1 at rows {i} and {j}.
Setting distance to 0.", + UserWarning, + stacklevel=2, + ) + D[i, j] = 0.0 + continue + + C: Int[np.ndarray, "k_lbls k_lbls"] = np.zeros((k_lbls, k_lbls), dtype=int) + np.add.at(C, (a, b), 1) + + row_ind, col_ind = linear_sum_assignment(-C) + matches: int = int(C[row_ind, col_ind].sum()) + + D[i, j] = n_components - matches # int is fine; array is float because of NaN + + return D diff --git a/spd/clustering/math/semilog.py b/spd/clustering/math/semilog.py new file mode 100644 index 000000000..a17ba63b5 --- /dev/null +++ b/spd/clustering/math/semilog.py @@ -0,0 +1,13 @@ +import math + + +def semilog( + value: float, + epsilon: float = 1e-3, +) -> float: + if abs(value) < epsilon: + return value + else: + sign: int = 1 if value >= 0 else -1 + # log1p here is safe, since we know the value is not close to zero + return sign * epsilon * math.log1p(abs(value) / epsilon) diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py new file mode 100644 index 000000000..dba55c878 --- /dev/null +++ b/spd/clustering/merge.py @@ -0,0 +1,188 @@ +""" +Merge iteration with logging support. + +This wraps the pure merge computation and adds WandB/plotting callbacks.
+""" + +import warnings +from typing import Protocol + +import torch +from jaxtyping import Bool, Float +from torch import Tensor +from tqdm import tqdm + +from spd.clustering.compute_costs import ( + compute_mdl_cost, + compute_merge_costs, + recompute_coacts_merge_pair, +) +from spd.clustering.consts import ( + ActivationsTensor, + BoolActivationsTensor, + ClusterCoactivationShaped, + ComponentLabels, + MergePair, +) +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory + + +class LogCallback(Protocol): + def __call__( + self, + current_coact: ClusterCoactivationShaped, + component_labels: ComponentLabels, + current_merge: GroupMerge, + costs: ClusterCoactivationShaped, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: Float[Tensor, " k_groups"], + ) -> None: ... + + +def merge_iteration( + merge_config: MergeConfig, + activations: ActivationsTensor, + component_labels: ComponentLabels, + log_callback: LogCallback | None = None, +) -> MergeHistory: + """ + Merge iteration with optional logging/plotting callbacks. + + This wraps the pure computation with logging capabilities while maintaining + the same core algorithm logic. 
+ """ + + # compute coactivations + # -------------------------------------------------- + activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = ( + activations > merge_config.activation_threshold + if merge_config.activation_threshold is not None + else activations + ) + coact: Float[Tensor, "c c"] = activation_mask_orig.float().T @ activation_mask_orig.float() + + # check shapes + c_components: int = coact.shape[0] + assert coact.shape[1] == c_components, "Coactivation matrix must be square" + + # determine number of iterations based on config and number of components + num_iters: int = merge_config.get_num_iters(c_components) + + # initialize vars + # -------------------------------------------------- + # start with an identity merge + current_merge: GroupMerge = GroupMerge.identity(n_components=c_components) + + # initialize variables for the merge process + k_groups: int = c_components + current_coact: ClusterCoactivationShaped = coact.clone() + current_act_mask: Bool[Tensor, "samples k_groups"] = activation_mask_orig.clone() + + # variables we keep track of + merge_history: MergeHistory = MergeHistory.from_config( + merge_config=merge_config, + labels=component_labels, + ) + + # merge iteration + # ================================================== + pbar: tqdm[int] = tqdm( + range(num_iters), + unit="iter", + total=num_iters, + ) + for iter_idx in pbar: + # compute costs, figure out what to merge + # -------------------------------------------------- + # HACK: this is messy + costs: ClusterCoactivationShaped = compute_merge_costs( + coact=current_coact / current_act_mask.shape[0], + merges=current_merge, + alpha=merge_config.alpha, + ) + + merge_pair: MergePair = merge_config.merge_pair_sample(costs) + + # merge the pair + # -------------------------------------------------- + # we do this *before* logging, so we can see how the sampled pair cost compares + # to the costs of all the other possible pairs + current_merge, current_coact, 
current_act_mask = recompute_coacts_merge_pair( + coact=current_coact, + merges=current_merge, + merge_pair=merge_pair, + activation_mask=current_act_mask, + ) + + # metrics and logging + # -------------------------------------------------- + # Store in history + merge_history.add_iteration( + idx=iter_idx, + selected_pair=merge_pair, + current_merge=current_merge, + ) + + # Compute metrics for logging + # the MDL loss computed here is the *cost of the current merge*, a single scalar value + # rather than the *delta in cost from merging a specific pair* (which is what `costs` matrix contains) + diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) + mdl_loss: float = compute_mdl_cost( + acts=diag_acts, + merges=current_merge, + alpha=merge_config.alpha, + ) + mdl_loss_norm: float = mdl_loss / current_act_mask.shape[0] + # this is the cost for the selected pair + merge_pair_cost: float = float(costs[merge_pair].item()) + + # Update progress bar + pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") + + if log_callback is not None: + log_callback( + iter_idx=iter_idx, + current_coact=current_coact, + component_labels=component_labels, + current_merge=current_merge, + costs=costs, + merge_history=merge_history, + k_groups=k_groups, + merge_pair_cost=merge_pair_cost, + mdl_loss=mdl_loss, + mdl_loss_norm=mdl_loss_norm, + diag_acts=diag_acts, + ) + + # iterate and sanity checks + # -------------------------------------------------- + k_groups -= 1 + assert current_coact.shape[0] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_coact.shape[1] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_act_mask.shape[1] == k_groups, ( + "Activation mask shape should match number of groups" + ) + + # early stopping failsafe + # -------------------------------------------------- + if k_groups <= 3: + warnings.warn( + f"Stopping early at 
iteration {iter_idx} as only {k_groups} groups left", + stacklevel=2, + ) + break + + # finish up + # ================================================== + return merge_history diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py new file mode 100644 index 000000000..f471879b2 --- /dev/null +++ b/spd/clustering/merge_config.py @@ -0,0 +1,114 @@ +import functools +import hashlib +from typing import Any, Literal + +from pydantic import ( + Field, + PositiveInt, +) + +from spd.base_config import BaseConfig +from spd.clustering.consts import ClusterCoactivationShaped, MergePair +from spd.clustering.math.merge_pair_samplers import ( + MERGE_PAIR_SAMPLERS, + MergePairSampler, + MergePairSamplerKey, +) +from spd.clustering.util import ModuleFilterFunc, ModuleFilterSource +from spd.spd_types import Probability + +MergeConfigKey = Literal[ + "activation_threshold", + "alpha", + "iters", + "merge_pair_sampling_method", + "merge_pair_sampling_kwargs", + "filter_dead_threshold", +] + + +def _to_module_filter( + filter_modules: ModuleFilterSource, +) -> ModuleFilterFunc: + """Convert the filter_modules argument to a callable.""" + if filter_modules is None: + return lambda _: True + elif isinstance(filter_modules, str): + return lambda module_name: module_name.startswith(filter_modules) + elif isinstance(filter_modules, set): + return lambda module_name: module_name in filter_modules + elif callable(filter_modules): + return filter_modules + else: + raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}") # pyright: ignore[reportUnreachable] + + +class MergeConfig(BaseConfig): + activation_threshold: Probability | None = Field( + default=0.01, + description="Threshold for considering a component active in a group. If None, use raw scalar causal importances", + ) + alpha: float = Field( + default=1.0, + description="rank weight factor. 
Higher values mean a higher penalty on 'sending' the component weights", + ) + iters: PositiveInt | None = Field( + default=100, + description="max number of iterations to run the merge algorithm for. If `None`, set to number of components (after filtering) minus one.", + ) + merge_pair_sampling_method: MergePairSamplerKey = Field( + default="range", + description="Method for sampling merge pairs. Options: 'range', 'mcmc'.", + ) + merge_pair_sampling_kwargs: dict[str, Any] = Field( + default_factory=lambda: {"threshold": 0.05}, + description="Keyword arguments for the merge pair sampling method.", + ) + filter_dead_threshold: float = Field( + default=0.001, + description="Threshold for filtering out dead components. If a component's activation is below this threshold, it is considered dead and not included in the merge.", + ) + module_name_filter: ModuleFilterSource = Field( + default=None, + description="Filter for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.", + ) + + @property + def merge_pair_sample_func(self) -> MergePairSampler: + return functools.partial( + MERGE_PAIR_SAMPLERS[self.merge_pair_sampling_method], + **self.merge_pair_sampling_kwargs, + ) + + def merge_pair_sample( + self, + costs: ClusterCoactivationShaped, + ) -> MergePair: + """do merge sampling based on the configured method and kwargs + + has signature `MergePairSampler = Callable[[ClusterCoactivationShaped], MergePair]` + """ + return self.merge_pair_sample_func(costs=costs) + + @property + def filter_modules(self) -> ModuleFilterFunc: + """Get the module filter function based on the provided source.""" + return _to_module_filter(self.module_name_filter) + + def get_num_iters(self, n_components: int) -> PositiveInt: + """Get the number of iterations to run the merge algorithm for. 
+ + Args: + n_components: Number of components (after filtering) + + Returns: + Number of iterations to run + """ + if self.iters is None: + return n_components - 1 + else: + return self.iters + + @property + def stable_hash(self) -> str: + return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6] diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py new file mode 100644 index 000000000..98dca398b --- /dev/null +++ b/spd/clustering/merge_history.py @@ -0,0 +1,457 @@ +import io +import json +import zipfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any, override + +import numpy as np +import torch +from jaxtyping import Float, Int + +from spd.clustering.consts import ( + ComponentLabels, + DistancesArray, + DistancesMethod, + MergePair, + MergesArray, + SaveableObject, +) +from spd.clustering.math.merge_distances import compute_distances +from spd.clustering.math.merge_matrix import BatchedGroupMerge, GroupMerge +from spd.clustering.merge_config import MergeConfig + + +@dataclass(frozen=True) +class IterationInfo: + """Information about a single merge iteration.""" + + idx: int + selected_pair: list[int] + merges: GroupMerge + + +def _zip_save_arr(zf: zipfile.ZipFile, name: str, arr: np.ndarray) -> None: + """Save a numpy array to a zip file.""" + buf: io.BytesIO = io.BytesIO() + np.save(buf, arr) + zf.writestr(name, buf.getvalue()) + + +def _zip_save_arr_dict(zf: zipfile.ZipFile, data: dict[str, np.ndarray]) -> None: + """Save a dictionary of numpy arrays to a zip file, {key}.npy used as path""" + key: str + arr: np.ndarray + for key, arr in data.items(): + _zip_save_arr(zf, f"{key}.npy", arr) + + +@dataclass(kw_only=True) +class MergeHistory(SaveableObject): + """Track merge iteration history""" + + merges: BatchedGroupMerge + selected_pairs: Int[np.ndarray, " n_iters 2"] + labels: ComponentLabels + merge_config: MergeConfig + n_iters_current: int + + meta: dict[str, Any] | None = None + + 
@property + def c_components(self) -> int: + return len(self.labels) + + @classmethod + def from_config( + cls, + merge_config: MergeConfig, + labels: ComponentLabels, + ) -> "MergeHistory": + n_components: int = len(labels) + n_iters_target: int = merge_config.get_num_iters(n_components) + return MergeHistory( + labels=labels, + n_iters_current=0, + selected_pairs=np.full((n_iters_target, 2), -1, dtype=np.int16), + merges=BatchedGroupMerge.init_empty( + batch_size=n_iters_target, n_components=n_components + ), + merge_config=merge_config, + ) + + def summary(self) -> dict[str, str | int | None | dict[str, int | str | None]]: + return dict( + c_components=self.c_components, + n_iters_current=self.n_iters_current, + total_iters=len(self.merges.k_groups), + len_labels=len(self.labels), + # wandb_url=self.wandb_url, + merge_config=self.merge_config.model_dump(mode="json"), + merges_summary=self.merges.summary(), + ) + + @override + def __str__(self) -> str: + out: list[str] = [f" {key} = {value}" for key, value in self.summary().items()] + return "MergeHistory(\n" + "\n".join(out) + "\n)" + + @override + def __repr__(self) -> str: + return self.__str__() + + def add_iteration( + self, + idx: int, + selected_pair: MergePair, + current_merge: GroupMerge, + ) -> None: + """Add data for one iteration.""" + self.selected_pairs[idx] = np.array(selected_pair, dtype=np.int16) + self.merges[idx] = current_merge + + assert self.n_iters_current == idx + self.n_iters_current += 1 + + def __getitem__(self, idx: int) -> IterationInfo: + """Get data for a specific iteration.""" + if idx < 0 or idx >= self.n_iters_current: + raise IndexError( + f"Index {idx} out of range for history with {self.n_iters_current} iterations" + ) + + return IterationInfo( + idx=idx, + selected_pair=self.selected_pairs[idx].tolist(), + merges=self.merges[idx], + ) + + def __len__(self) -> int: + """Get the number of iterations in the history.""" + return self.n_iters_current + + def latest(self) -> 
IterationInfo: + """Get the latest values.""" + if self.n_iters_current == 0: + raise ValueError("No history available") + latest_idx: int = self.n_iters_current - 1 + return self[latest_idx] + + def get_unique_clusters(self, iteration: int) -> list[int]: + """Get unique cluster IDs at a given iteration. + + Args: + iteration: Iteration index (negative indexes from end) + + Returns: + List of unique cluster IDs + """ + if iteration < 0: + iteration = self.n_iters_current + iteration + assert 0 <= iteration < self.n_iters_current, ( + f"Invalid iteration: {iteration = }, {self.n_iters_current = }" + ) + merge: GroupMerge = self.merges[iteration] + return torch.unique(merge.group_idxs).tolist() + + def get_cluster_component_labels(self, iteration: int, cluster_id: int) -> ComponentLabels: + """Get component labels for a specific cluster at a given iteration. + + Args: + iteration: Iteration index (negative indexes from end) + cluster_id: Cluster ID to query + + Returns: + List of component labels in the cluster + """ + if iteration < 0: + iteration = self.n_iters_current + iteration + assert 0 <= iteration < self.n_iters_current, ( + f"Invalid iteration: {iteration = }, {self.n_iters_current = }" + ) + merge: GroupMerge = self.merges[iteration] + component_indices: list[int] = merge.components_in_group(cluster_id) + return ComponentLabels([self.labels[idx] for idx in component_indices]) + + def get_cluster_components_info(self, iteration: int, cluster_id: int) -> list[dict[str, Any]]: + """Get detailed component information for a cluster. 
+ + Args: + iteration: Iteration index (negative indexes from end) + cluster_id: Cluster ID to query + + Returns: + List of dicts with keys: module, index, label + """ + component_labels: list[str] = self.get_cluster_component_labels(iteration, cluster_id) + result: list[dict[str, Any]] = [] + for label in component_labels: + module: str + idx_str: str + module, idx_str = label.rsplit(":", 1) + result.append({"module": module, "index": int(idx_str), "label": label}) + return result + + # Convenience properties for sweep analysis + @property + def total_iterations(self) -> int: + """Total number of iterations performed.""" + return self.n_iters_current + + @property + def final_k_groups(self) -> int: + """Final number of groups after merging.""" + if self.n_iters_current == 0: + return self.c_components + return int(self.merges.k_groups[self.n_iters_current - 1].item()) + + @property + def initial_k_groups(self) -> int: + """Initial number of groups before merging.""" + if self.n_iters_current == 0: + return self.c_components + return int(self.merges.k_groups[0].item()) + + @override + def save(self, path: Path) -> None: + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "w") as zf: + # save arrays + _zip_save_arr_dict( + zf=zf, + data={ + "merge.group_idxs": self.merges.group_idxs.cpu().numpy(), + "merge.k_groups": self.merges.k_groups.cpu().numpy(), + "selected_pairs": self.selected_pairs, + }, + ) + # Save labels + zf.writestr("labels.txt", "\n".join(self.labels)) + # Save metadata + zf.writestr( + "metadata.json", + json.dumps( + dict( + merge_config=self.merge_config.model_dump(mode="json"), + c_components=self.c_components, + n_iters_current=self.n_iters_current, + labels=self.labels, + ) + ), + ) + + @override + @classmethod + def read(cls, path: Path) -> "MergeHistory": + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "r") as zf: + group_idxs: np.ndarray = np.load(io.BytesIO(zf.read("merge.group_idxs.npy"))) + k_groups: np.ndarray = 
np.load(io.BytesIO(zf.read("merge.k_groups.npy"))) + selected_pairs: np.ndarray = np.load(io.BytesIO(zf.read("selected_pairs.npy"))) + merges: BatchedGroupMerge = BatchedGroupMerge( + group_idxs=torch.from_numpy(group_idxs), + k_groups=torch.from_numpy(k_groups), + ) + labels_raw: list[str] = zf.read("labels.txt").decode("utf-8").splitlines() + labels: ComponentLabels = ComponentLabels(labels_raw) + metadata: dict[str, Any] = json.loads(zf.read("metadata.json").decode("utf-8")) + merge_config: MergeConfig = MergeConfig.model_validate(metadata["merge_config"]) + + metadata["origin_path"] = path + + return cls( + merges=merges, + selected_pairs=selected_pairs, + labels=labels, + merge_config=merge_config, + n_iters_current=metadata["n_iters_current"], + meta=metadata, + ) + + +@dataclass +class MergeHistoryEnsemble: + data: list[MergeHistory] + + def __iter__(self): + return iter(self.data) + + def __getitem__(self, idx: int) -> MergeHistory: + return self.data[idx] + + def _validate_configs_match(self) -> None: + """Ensure all histories have the same merge config.""" + if not self.data: + return + first_config: MergeConfig = self.data[0].merge_config + for history in self.data[1:]: + if history.merge_config != first_config: + raise ValueError("All histories must have the same merge config") + + @property + def config(self) -> MergeConfig: + """Get the merge config used in the ensemble.""" + self._validate_configs_match() + return self.data[0].merge_config + + @property + def n_iters_min(self) -> int: + """Minimum number of iterations across all histories in the ensemble.""" + return min(len(history.merges.k_groups) for history in self.data) + + @property + def n_iters_max(self) -> int: + """Maximum number of iterations across all histories in the ensemble.""" + return max(len(history.merges.k_groups) for history in self.data) + + @property + def n_iters_range(self) -> tuple[int, int]: + """Range of iterations (min, max) across all histories in the ensemble.""" + 
iter_counts = [len(history.merges.k_groups) for history in self.data] + return (min(iter_counts), max(iter_counts)) + + @property + def n_ensemble(self) -> int: + """Number of ensemble members.""" + return len(self.data) + + @property + def c_components(self) -> int: + """Number of components in each history.""" + c_components: int = self.data[0].c_components + assert all(history.c_components == c_components for history in self.data), ( + "All histories must have the same number of components" + ) + return c_components + + @property + def shape(self) -> tuple[int, int, int]: + """Shape of the ensemble data.""" + return (self.n_ensemble, self.n_iters_min, self.c_components) + + @property + def merges_array(self) -> MergesArray: + n_ens: int = self.n_ensemble + n_iters: int = self.n_iters_min + c_components: int = self.c_components + + output: MergesArray = np.full( + (n_ens, n_iters, c_components), + fill_value=-1, + dtype=np.int16, + # if you have more than 32k components, change this to np.int32 + # if you have more than 2.1b components, rethink your life choices + ) + for i_ens, history in enumerate(self.data): + for i_iter, merge in enumerate(history.merges): + output[i_ens, i_iter] = merge.group_idxs + + return output + + def normalized(self) -> tuple[MergesArray, dict[str, Any]]: + """Normalize the component labels across all histories. + + if different histories see different batches, then they might have different dead + components, and are hence not directly comparable. 
So, we find the union of all + component labels across all histories, and then any component missing from a history + is put into it's own group in that history + """ + + unique_labels_set: set[str] = set() + for history in self.data: + unique_labels_set.update(history.labels) + + unique_labels_list: list[str] = sorted(unique_labels_set) + unique_labels: ComponentLabels = ComponentLabels(unique_labels_list) + c_components: int = len(unique_labels) + component_label_idxs: dict[str, int] = { + label: idx for idx, label in enumerate(unique_labels) + } + + try: + merges_array: MergesArray = np.full( + (self.n_ensemble, self.n_iters_min, c_components), + fill_value=-1, + dtype=np.int16, + ) + except Exception as e: + err_msg = ( + f"failed to create merge array, probably due to issues with getting shape.\n" + f"{self = }\n" + f"{self.data = }\n" + ) + raise RuntimeError(err_msg) from e + + overlap_stats: Float[np.ndarray, " n_ens"] = np.full( + self.n_ensemble, + fill_value=float("nan"), + dtype=np.float32, + ) + i_ens: int + history: MergeHistory + for i_ens, history in enumerate(self.data): + hist_c_labels: list[str] = history.labels + hist_n_components: int = len(hist_c_labels) + overlap_stats[i_ens] = hist_n_components / c_components + # map from old component indices to new component indices + i_comp_old: int + comp_label: str + for i_comp_old, comp_label in enumerate(hist_c_labels): + i_comp_new: int = component_label_idxs[comp_label] + merges_array[i_ens, :, i_comp_new] = history.merges.group_idxs[ + : self.n_iters_min, i_comp_old + ] + + # assert np.max(merges_array[i_ens]) == hist_n_components - 1, ( + # f"Max component index in history {i_ens} should be {hist_n_components - 1}, " + # f"but got {np.max(merges_array[i_ens])}" + # ) + + # put each missing label into its own group + hist_missing_labels: set[str] = unique_labels_set - set(hist_c_labels) + assert len(hist_missing_labels) == c_components - hist_n_components + idx_missing: int + missing_label: str + 
for idx_missing, missing_label in enumerate(hist_missing_labels): + i_comp_new_relabel: int = component_label_idxs[missing_label] + merges_array[i_ens, :, i_comp_new_relabel] = np.full( + self.n_iters_min, + fill_value=idx_missing + hist_n_components, + dtype=np.int16, + ) + + # TODO: double check this + # Convert any Path objects to strings for JSON serialization + history_metadatas: list[dict[str, Any] | None] = [] + for history in self.data: + if history.meta is not None: + meta_copy = history.meta.copy() + # Convert Path objects to strings + for key, value in meta_copy.items(): + if isinstance(value, Path): + meta_copy[key] = str(value) + history_metadatas.append(meta_copy) + else: + history_metadatas.append(None) + + return ( + # TODO: dataclass this + merges_array, + dict( + component_labels=unique_labels, + n_ensemble=self.n_ensemble, + n_iters_min=self.n_iters_min, + n_iters_max=self.n_iters_max, + n_iters_range=self.n_iters_range, + c_components=c_components, + config=self.config.model_dump(mode="json"), + history_metadatas=history_metadatas, + ), + ) + + def get_distances(self, method: DistancesMethod = "perm_invariant_hamming") -> DistancesArray: + merges_array: MergesArray = self.merges_array + return compute_distances( + normalized_merge_array=merges_array, + method=method, + ) diff --git a/spd/clustering/plotting/__init__.py b/spd/clustering/plotting/__init__.py new file mode 100644 index 000000000..b048d1d24 --- /dev/null +++ b/spd/clustering/plotting/__init__.py @@ -0,0 +1 @@ +"""Plotting utilities for clustering module.""" diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py new file mode 100644 index 000000000..81a147f37 --- /dev/null +++ b/spd/clustering/plotting/activations.py @@ -0,0 +1,386 @@ +"""Plotting functions for activation visualizations.""" + +from collections.abc import Sequence +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import 
torch +import wandb +import wandb.sdk.wandb_run +from jaxtyping import Float, Int +from torch import Tensor + +from spd.clustering.activations import ProcessedActivations, compute_coactivatons +from spd.clustering.consts import ActivationsTensor, ClusterCoactivationShaped, ComponentLabels + + +def plot_activations( + processed_activations: ProcessedActivations, + save_dir: Path | None, + n_samples_max: int, + figure_prefix: str = "activations", + figsize_raw: tuple[int, int] = (12, 4), + figsize_concat: tuple[int, int] = (12, 2), + figsize_coact: tuple[int, int] = (8, 6), + hist_scales: tuple[str, str] = ("lin", "log"), + hist_bins: int = 100, + do_sorted_samples: bool = False, + wandb_run: wandb.sdk.wandb_run.Run | None = None, +) -> None: + """Plot activation visualizations including raw, concatenated, sorted, and coactivations. + + Args: + activations: Dictionary of raw activations by module + act_concat: Concatenated activations tensor + coact: Coactivation matrix + labels: Component labels + save_dir: The directory to save the plots to (None to skip saving to disk) + figure_prefix: Prefix for PDF filenames + figsize_raw: Figure size for raw activations + figsize_concat: Figure size for concatenated activations + figsize_coact: Figure size for coactivations + hist_scales: Tuple of (x_scale, y_scale) where each is "lin" or "log" + hist_bins: Number of bins for histograms + """ + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=True) + + act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw + act_concat: ActivationsTensor = processed_activations.activations + coact: ClusterCoactivationShaped = compute_coactivatons(act_concat) + labels: ComponentLabels = ComponentLabels(processed_activations.labels) + n_samples: int = act_concat.shape[0] + + # trim the activations if n_samples_max is specified + # clone here so we don't modify the original tensor + act_concat = act_concat[:n_samples_max].clone() + # we don't use the stuff 
in this dict again, so we can modify it in-place + for key in act_dict: + act_dict[key] = act_dict[key][:n_samples_max] + + # Update n_samples to reflect the truncated size + n_samples = act_concat.shape[0] + + # Raw activations + axs_act: Sequence[plt.Axes] + _fig1: plt.Figure + _fig1, axs_act = plt.subplots(len(act_dict), 1, figsize=figsize_raw) + if len(act_dict) == 1: + assert isinstance(axs_act, plt.Axes) + axs_act = [axs_act] + for i, (key, act) in enumerate(act_dict.items()): + act_raw_data: np.ndarray = act.T.cpu().numpy() + axs_act[i].matshow( + act_raw_data, aspect="auto", vmin=act_raw_data.min(), vmax=act_raw_data.max() + ) + axs_act[i].set_ylabel(f"components\n{key}") + axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") + + if save_dir is not None: + fig1_fname = save_dir / f"{figure_prefix}_raw.pdf" + _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/raw": wandb.Image(_fig1)}, step=0) + + # Close figure to free memory + plt.close(_fig1) + + # Concatenated activations + fig2: plt.Figure + ax2: plt.Axes + fig2, ax2 = plt.subplots(figsize=figsize_concat) + act_data: np.ndarray = act_concat.T.cpu().numpy() + im2 = ax2.matshow(act_data, aspect="auto", vmin=act_data.min(), vmax=act_data.max()) + ax2.set_title("Concatenated Activations") + + # Add component labeling on y-axis + add_component_labeling(ax2, labels, axis="y") + + plt.colorbar(im2) + + if save_dir is not None: + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.pdf" + fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated": wandb.Image(fig2)}, step=0) + + # Close figure to free memory + plt.close(fig2) + + # Concatenated activations, sorted samples + if do_sorted_samples: + # TODO: move sample sorting logic to its own function, see + # 
https://github.com/goodfire-ai/spd/pull/172/files#r2387275601 + fig3: plt.Figure + ax3: plt.Axes + fig3, ax3 = plt.subplots(figsize=figsize_concat) + + # Compute gram matrix (sample similarity) and sort samples using greedy ordering + gram_matrix: Float[Tensor, "samples samples"] = act_concat @ act_concat.T + + # Normalize gram matrix to get cosine similarity + norms: Float[Tensor, "samples 1"] = torch.norm(act_concat, dim=1, keepdim=True) + norms = torch.where(norms > 1e-8, norms, torch.ones_like(norms)) + similarity_matrix: Float[Tensor, "samples samples"] = gram_matrix / (norms @ norms.T) + + # Greedy ordering: start with sample most similar to all others + avg_similarity: Float[Tensor, " samples"] = similarity_matrix.mean(dim=1) + start_idx: int = int(torch.argmax(avg_similarity).item()) + + # Build ordering greedily + ordered_indices: list[int] = [start_idx] + remaining: set[int] = set(range(n_samples)) + remaining.remove(start_idx) + + # Greedily add the nearest unvisited sample + current_idx: int = start_idx + while remaining: + # Find the unvisited sample most similar to current + best_similarity: float = -1 + best_idx: int = -1 + for idx in remaining: + sim: float = similarity_matrix[current_idx, idx].item() + if sim > best_similarity: + best_similarity = sim + best_idx = idx + + ordered_indices.append(best_idx) + remaining.remove(best_idx) + current_idx = best_idx + + sorted_indices: Int[Tensor, " samples"] = torch.tensor( + ordered_indices, dtype=torch.long, device=act_concat.device + ) + act_concat_sorted: ActivationsTensor = act_concat[sorted_indices] + + # Handle log10 properly - add small epsilon to avoid log(0) + act_sorted_data: np.ndarray = act_concat_sorted.T.cpu().numpy() + act_sorted_log: np.ndarray = np.log10(act_sorted_data + 1e-10) + im3 = ax3.matshow( + act_sorted_log, aspect="auto", vmin=act_sorted_log.min(), vmax=act_sorted_log.max() + ) + ax3.set_title("Concatenated Activations $\\log_{10}$, Sorted Samples") + + # Add component labeling 
on y-axis + add_component_labeling(ax3, labels, axis="y") + + plt.colorbar(im3) + + if save_dir is not None: + fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.pdf" + fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated_sorted": wandb.Image(fig3)}, step=0) + + # Close figure to free memory + plt.close(fig3) + + # Coactivations + fig4: plt.Figure + ax4: plt.Axes + fig4, ax4 = plt.subplots(figsize=figsize_coact) + coact_data: np.ndarray = coact.cpu().numpy() + im4 = ax4.matshow(coact_data, aspect="auto", vmin=coact_data.min(), vmax=coact_data.max()) + ax4.set_title("Coactivations") + + # Add component labeling on both axes + add_component_labeling(ax4, labels, axis="x") + add_component_labeling(ax4, labels, axis="y") + + plt.colorbar(im4) + + if save_dir is not None: + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.pdf" + fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/coactivations": wandb.Image(fig4)}, step=0) + + # Close figure to free memory + plt.close(fig4) + + # log coactivations + fig4_log: plt.Figure + ax4_log: plt.Axes + fig4_log, ax4_log = plt.subplots(figsize=figsize_coact) + # assert np.all(coact_data >= 0) # TODO: why are coacts negative? 
:/
+    coact_log_data: np.ndarray = np.log10(coact_data - coact_data.min() + 1e-6)
+    im4_log = ax4_log.matshow(
+        coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max()
+    )
+    ax4_log.set_title("Coactivations $\\log_{10}$")
+    # Add component labeling on both axes
+    add_component_labeling(ax4_log, labels, axis="x")
+    add_component_labeling(ax4_log, labels, axis="y")
+    plt.colorbar(im4_log)
+    if save_dir is not None:
+        fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.pdf"
+        fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300)
+
+    # Log to WandB if available
+    if wandb_run is not None:
+        wandb_run.log({"plots/activations/coactivations_log": wandb.Image(fig4_log)}, step=0)
+
+    # Close figure to free memory
+    plt.close(fig4_log)
+
+    # Activation histograms
+    fig5: plt.Figure
+    ax5a: plt.Axes
+    ax5b: plt.Axes
+    ax5c: plt.Axes
+    fig5, (ax5a, ax5b, ax5c) = plt.subplots(1, 3, figsize=(15, 4))
+
+    x_scale: str
+    y_scale: str
+    x_scale, y_scale = hist_scales
+
+    # Histogram 1: All activations
+    all_activations: Float[Tensor, " samples*n_components"] = act_concat.flatten()
+    all_vals: np.ndarray = all_activations.cpu().numpy()
+    hist_counts: np.ndarray
+    bin_edges: np.ndarray
+    hist_counts, bin_edges = np.histogram(all_vals, bins=hist_bins)
+    bin_centers: np.ndarray = (bin_edges[:-1] + bin_edges[1:]) / 2
+    ax5a.plot(bin_centers, hist_counts, color="blue", linewidth=2)
+    ax5a.set_title("All Activations")
+    ax5a.set_xlabel("Activation Value")
+    ax5a.set_ylabel("Count")
+    if x_scale == "log":
+        ax5a.set_xscale("log")
+    if y_scale == "log":
+        ax5a.set_yscale("log")
+    ax5a.grid(True, alpha=0.3)
+
+    # Histogram 2: Activations per component
+    n_components: int = act_concat.shape[1]
+
+    # Common bin edges for all component histograms
+    all_min: float = float(all_vals.min())
+    all_max: float = float(all_vals.max())
+    common_bins: np.ndarray = np.linspace(all_min, all_max, hist_bins)
+    common_centers: np.ndarray = 
(common_bins[:-1] + common_bins[1:]) / 2 + + # Get unique label prefixes and assign colors + label_prefixes: list[str] = [label.split(":")[0] for label in labels] + unique_prefixes: list[str] = list(dict.fromkeys(label_prefixes)) # Preserve order + colors: Sequence[tuple[int, int, int]] = mpl.colormaps["tab10"]( + np.linspace(0, 1, len(unique_prefixes)) + ) # pyright: ignore[reportAssignmentType] + prefix_colors: dict[str, tuple[int, int, int]] = { + prefix: colors[i] for i, prefix in enumerate(unique_prefixes) + } + + for comp_idx in range(n_components): + component_activations: Float[Tensor, " n_samples"] = act_concat[:, comp_idx] + comp_vals: np.ndarray = component_activations.cpu().numpy() + hist_counts, _ = np.histogram(comp_vals, bins=common_bins, density=True) + + # Get color based on label prefix + prefix: str = label_prefixes[comp_idx] + color: tuple[int, int, int] = prefix_colors[prefix] + + ax5b.plot(common_centers, hist_counts, color=color, alpha=0.1, linewidth=1) + + ax5b.set_title(f"Per Component ({n_components} components)") + ax5b.set_xlabel("Activation Value") + ax5b.set_ylabel("Density") + if x_scale == "log": + ax5b.set_xscale("log") + if y_scale == "log": + ax5b.set_yscale("log") + ax5b.grid(True, alpha=0.3) + + # Histogram 3: Activations per sample + for sample_idx in range(n_samples): + sample_activations: Float[Tensor, " n_components"] = act_concat[sample_idx, :] + sample_vals: np.ndarray = sample_activations.cpu().numpy() + hist_counts, _ = np.histogram(sample_vals, bins=common_bins, density=True) + ax5c.plot(common_centers, hist_counts, color="blue", alpha=0.1, linewidth=1) + + ax5c.set_title(f"Per Sample ({n_samples} samples)") + ax5c.set_xlabel("Activation Value") + ax5c.set_ylabel("Density") + if x_scale == "log": + ax5c.set_xscale("log") + if y_scale == "log": + ax5c.set_yscale("log") + ax5c.grid(True, alpha=0.3) + + plt.tight_layout() + + if save_dir is not None: + fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.pdf" + 
fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/histograms": wandb.Image(fig5)}, step=0) + + # Close figure to free memory + plt.close(fig5) + + +def add_component_labeling( + ax: plt.Axes, component_labels: ComponentLabels, axis: str = "x" +) -> None: + """Add component labeling using major/minor ticks to show module boundaries. + + Args: + ax: Matplotlib axis to modify + component_labels: List of component labels in format "module:index" + axis: Which axis to label ('x' or 'y') + """ + if not component_labels: + return + + # Extract module information + module_changes: list[int] = [] + current_module: str = component_labels[0].split(":")[0] + module_labels: list[str] = [] + + for i, label in enumerate(component_labels): + module: str = label.split(":")[0] + if module != current_module: + module_changes.append(i) + module_labels.append(current_module) + current_module = module + module_labels.append(current_module) + + # Set up major and minor ticks + # Minor ticks: every 10 components + minor_ticks: list[int] = list(range(0, len(component_labels), 10)) + + # Major ticks: module boundaries (start of each module) + major_ticks: list[int] = [0] + module_changes + major_labels: list[str] = module_labels + + if axis == "x": + ax.set_xticks(minor_ticks, minor=True) + ax.set_xticks(major_ticks) + ax.set_xticklabels(major_labels) + ax.set_xlim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="x", which="minor", length=2, width=0.5) + ax.tick_params(axis="x", which="major", length=6, width=1.5) + for x in major_ticks: + ax.axvline(x - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) + else: + ax.set_yticks(minor_ticks, minor=True) + ax.set_yticks(major_ticks) + ax.set_yticklabels(major_labels) + ax.set_ylim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="y", which="minor", length=2, width=0.5) + 
ax.tick_params(axis="y", which="major", length=6, width=1.5) + for y in major_ticks: + ax.axhline(y - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py new file mode 100644 index 000000000..b213e1724 --- /dev/null +++ b/spd/clustering/plotting/merge.py @@ -0,0 +1,359 @@ +"""Plotting functions for merge visualizations.""" + +from typing import Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, ComponentLabels, DistancesArray +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.merge_history import MergeHistory +from spd.clustering.util import format_scientific_latex + +DEFAULT_PLOT_CONFIG: dict[str, Any] = dict( + figsize=(16, 10), + tick_spacing=5, + save_pdf=False, + figure_prefix="merge_iteration", +) + + +def plot_merge_matrix( + merge_matrix: Bool[Tensor, "k_groups n_components"], + show: bool = True, + figsize: tuple[int, int] = (10, 3), + show_row_sums: bool | None = None, + ax: "plt.Axes | None" = None, + component_labels: ComponentLabels | None = None, +) -> None: + import matplotlib.pyplot as plt + + k_groups: int + k_groups, _ = merge_matrix.shape + group_sizes: Int[Tensor, " k_groups"] = merge_matrix.sum(dim=1) + + if show_row_sums is None: + show_row_sums = k_groups <= 20 + + ax_lbl: plt.Axes | None = None + if ax is not None: + show_row_sums = False # don't show row sums if we have an ax to plot on + ax_mat = ax + assert not show_row_sums + else: + if show_row_sums: + _fig, (ax_mat, ax_lbl) = plt.subplots( + 1, 2, figsize=figsize, gridspec_kw={"width_ratios": [10, 1]} + ) + else: + _fig, ax_mat = plt.subplots(figsize=figsize) + + ax_mat.matshow(merge_matrix.cpu(), aspect="auto", cmap="Blues", interpolation="nearest") + ax_mat.set_xlabel("Components") + 
ax_mat.set_ylabel("Groups") + ax_mat.set_title("Merge Matrix") + + # Add component labeling if component labels are provided + if component_labels is not None: + # Import the function here to avoid circular imports + from spd.clustering.plotting.activations import add_component_labeling + + add_component_labeling(ax_mat, component_labels, axis="x") + + if show_row_sums: + assert ax_lbl is not None + ax_lbl.set_xlim(0, 1) + ax_lbl.set_ylim(-0.5, k_groups - 0.5) + ax_lbl.invert_yaxis() + ax_lbl.set_title("Row Sums") + ax_lbl.axis("off") + for i, size in enumerate(group_sizes): + ax_lbl.text(0.5, i, str(size.item()), va="center", ha="center", fontsize=12) + + plt.tight_layout() + if show: + plt.show() + + +def plot_merge_iteration( + current_merge: GroupMerge, + current_coact: ClusterCoactivationShaped, + costs: ClusterCoactivationShaped, + # pair_cost: float, + iteration: int, + component_labels: ComponentLabels | None = None, + plot_config: dict[str, Any] | None = None, + nan_diag: bool = True, + show: bool = False, +) -> plt.Figure: + """Plot merge iteration results with merge tree, coactivations, and costs. + + Args: + current_merge: Current merge state + current_coact: Current coactivation matrix + costs: Current cost matrix + pair_cost: Cost of selected merge pair + iteration: Current iteration number + component_labels: Component labels for axis labeling + plot_config: Plot configuration settings + nan_diag: Whether to set diagonal to NaN for visualization + show: Whether to display the plot (default: False) + + Returns: + The matplotlib figure object + + Note: + Caller is responsible for closing the returned figure with plt.close(fig) + to prevent memory leaks. 
+ """ + plot_config_: dict[str, Any] = { + **DEFAULT_PLOT_CONFIG, + **(plot_config or {}), + } + axs: list[plt.Axes] + fig, axs = plt.subplots( + 1, 3, figsize=plot_config_["figsize"], sharey=True, gridspec_kw={"width_ratios": [2, 1, 1]} + ) + + # Merge plot + plot_merge_matrix( + current_merge.to_matrix(), + ax=axs[0], + show=False, + component_labels=component_labels, + ) + + axs[0].set_title("Merge") + + # Coactivations plot + coact_min: float = current_coact.min().item() + coact_max: float = current_coact.max().item() + if nan_diag: + current_coact = current_coact.clone() + current_coact.fill_diagonal_(np.nan) + axs[1].matshow(current_coact.cpu().numpy(), aspect="equal") + coact_min_str: str = format_scientific_latex(coact_min) + coact_max_str: str = format_scientific_latex(coact_max) + axs[1].set_title(f"Coactivations\n[{coact_min_str}, {coact_max_str}]") + + # Setup ticks for coactivations + k_groups: int = current_coact.shape[0] + minor_ticks: list[int] = list(range(0, k_groups, plot_config_["tick_spacing"])) + axs[1].set_yticks(minor_ticks) + axs[1].set_xticks(minor_ticks) + axs[1].set_xticklabels([]) # Remove x-axis tick labels but keep ticks + + # Costs plot + costs_min: float = costs.min().item() + costs_max: float = costs.max().item() + if nan_diag: + costs = costs.clone() + costs.fill_diagonal_(np.nan) + axs[2].matshow(costs.cpu().numpy(), aspect="equal") + costs_min_str: str = format_scientific_latex(costs_min) + costs_max_str: str = format_scientific_latex(costs_max) + axs[2].set_title(f"Costs\n[{costs_min_str}, {costs_max_str}]") + + # Setup ticks for costs + axs[2].set_yticks(minor_ticks) + axs[2].set_xticks(minor_ticks) + axs[2].set_xticklabels([]) # Remove x-axis tick labels but keep ticks + + # fig.suptitle(f"Iteration {iteration} with cost {pair_cost:.4f}") + fig.suptitle(f"Iteration {iteration}") + plt.tight_layout() + + if plot_config_["save_pdf"]: + fig.savefig( + f"{plot_config_['figure_prefix']}_iter_{iteration:03d}.pdf", + 
bbox_inches="tight", + dpi=300, + ) + + if show: + plt.show() + + return fig + + +def plot_dists_distribution( + distances: DistancesArray, + mode: Literal["points", "dist"] = "points", + label: str | None = None, + ax: plt.Axes | None = None, + kwargs_fig: dict[str, Any] | None = None, + kwargs_plot: dict[str, Any] | None = None, + use_symlog: bool = True, + linthresh: float = 1.0, +) -> plt.Axes: + n_iters: int = distances.shape[0] + n_ens: int = distances.shape[1] + assert distances.shape[2] == n_ens, "Distances must be square" + + # Ensure ax and kwargs_fig are not both provided + if ax is not None and kwargs_fig is not None: + raise ValueError("Cannot provide both ax and kwargs_fig") + + dists_flat: Float[np.ndarray, " n_iters n_ens*n_ens"] = distances.reshape( + distances.shape[0], -1 + ) + + # Create figure if ax not provided + if ax is None: + _fig, ax_ = plt.subplots( # pyright: ignore[reportCallIssue] + 1, + 1, + **dict( + figsize=(8, 5), # pyright: ignore[reportArgumentType] + **(kwargs_fig or {}), + ), + ) + else: + ax_ = ax + + if mode == "points": + # Original points mode + n_samples: int = dists_flat.shape[1] + for i in range(n_iters): + ax_.plot( + np.full((n_samples), i), + dists_flat[i], + **dict( # pyright: ignore[reportArgumentType] + marker="o", + linestyle="", + color="blue", + alpha=min(1, 10 / (n_ens * n_ens)), + markersize=5, + markeredgewidth=0, + **(kwargs_plot or {}), + ), + ) + elif mode == "dist": + # Distribution statistics mode + # Generate a random color for this plot + color: Float[np.ndarray, " 3"] = np.random.rand(3) + + # Calculate statistics for each iteration + mins: list[float] = [] + maxs: list[float] = [] + means: list[float] = [] + medians: list[float] = [] + q1s: list[float] = [] + q3s: list[float] = [] + + for i in range(n_iters): + # Filter out NaN values (diagonal and upper triangle) + valid_dists: Float[np.ndarray, " n_valid"] = dists_flat[i][~np.isnan(dists_flat[i])] + if len(valid_dists) > 0: + 
mins.append(np.min(valid_dists)) + maxs.append(np.max(valid_dists)) + means.append(float(np.mean(valid_dists))) + medians.append(float(np.median(valid_dists))) + q1s.append(float(np.percentile(valid_dists, 25))) + q3s.append(float(np.percentile(valid_dists, 75))) + else: + # Handle case with no valid distances + mins.append(np.nan) + maxs.append(np.nan) + means.append(np.nan) + medians.append(np.nan) + q1s.append(np.nan) + q3s.append(np.nan) + + iterations: Int[np.ndarray, " n_iters"] = np.arange(n_iters) + + # Plot statistics + ax_.plot(iterations, mins, "-", color=color, alpha=0.5) + ax_.plot(iterations, maxs, "-", color=color, alpha=0.5) + ax_.plot(iterations, means, "-", color=color, linewidth=2, label=label) + ax_.plot(iterations, medians, "--", color=color, linewidth=2) + ax_.plot(iterations, q1s, ":", color=color, alpha=0.7) + ax_.plot(iterations, q3s, ":", color=color, alpha=0.7) + + # Shade between quartiles + ax_.fill_between(iterations, q1s, q3s, color=color, alpha=0.2) + + ax_.set_xlabel("Iteration #") + ax_.set_ylabel("distance") + ax_.set_title("Distribution of pairwise distances between group merges in an ensemble") + + if use_symlog: + from matplotlib.ticker import FuncFormatter + + ax_.set_yscale("symlog", linthresh=linthresh, linscale=0.2) + + # Custom formatter for y-axis ticks + def custom_format(y: float, _pos: int) -> str: + if abs(y) < linthresh: + # Show exact values in the linear range + return f"{y:.1f}" + elif abs(y) == 1: + return "1" + elif abs(y) == 10: + return "10" + else: + # Use scientific notation for larger values + exponent = int(np.log10(abs(y))) + return f"$10^{{{exponent}}}$" + + ax_.yaxis.set_major_formatter(FuncFormatter(custom_format)) + + # Add a visual indicator for the linear region (0 to linthresh) + ax_.axhspan(0, linthresh, alpha=0.05, color="gray", zorder=-10) + # Add subtle lines at linthresh boundaries + ax_.axhline(linthresh, color="gray", linestyle="--", linewidth=0.5, alpha=0.3) + if linthresh > 0: + 
ax_.axhline(0, color="gray", linestyle="-", linewidth=0.5, alpha=0.3) + + return ax_ + + +def plot_merge_history_cluster_sizes( + history: MergeHistory, + figsize: tuple[int, int] = (10, 5), + fmt: str = "png", + file_prefix: str | None = None, +) -> plt.Figure: + """Plot cluster sizes over iterations. + + Note: + Caller is responsible for closing the returned figure with plt.close(fig) + to prevent memory leaks. + """ + k_groups_t: Int[Tensor, " n_iters"] = history.merges.k_groups + valid_mask: Bool[Tensor, " n_iters"] = k_groups_t.ne(-1) + has_data: bool = bool(valid_mask.any().item()) + if not has_data: + raise ValueError("No populated iterations in history.k_groups") + + group_idxs_all: Int[Tensor, " n_iters n_components"] = history.merges.group_idxs[valid_mask] + k_groups_all: Int[Tensor, " n_iters"] = k_groups_t[valid_mask] + max_k: int = int(k_groups_all.max().item()) + + counts_list: list[Int[Tensor, " max_k"]] = [ + torch.bincount(row[row.ge(0)], minlength=max_k) # per-iteration cluster sizes + for row in group_idxs_all + ] + counts: Int[Tensor, " n_iters max_k"] = torch.stack(counts_list, dim=0) + + mask_pos: Bool[Tensor, " n_iters max_k"] = counts.gt(0) + it_idx_t, grp_idx_t = torch.nonzero(mask_pos, as_tuple=True) + xs_t: Float[Tensor, " n_points"] = it_idx_t.to(torch.float32) + sizes_t: Float[Tensor, " n_points"] = counts[it_idx_t, grp_idx_t].to(torch.float32) + + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + xs_t.cpu().numpy(), sizes_t.cpu().numpy(), "bo", markersize=3, alpha=0.15, markeredgewidth=0 + ) + ax.set_xlabel("Iteration") + ax.set_ylabel("Cluster size") + ax.set_yscale("log") + ax.set_title("Distribution of cluster sizes over time") + + if file_prefix is not None: + fig.savefig(f"{file_prefix}_cluster_sizes.{fmt}", bbox_inches="tight", dpi=300) + + return fig diff --git a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py new file mode 100644 index 000000000..f12ac75b7 --- /dev/null +++ 
b/spd/clustering/scripts/calc_distances.py @@ -0,0 +1,142 @@ +"""Calculate distances between clustering runs in an ensemble. + +Output structure: + SPD_CACHE_DIR/ensemble/{pipeline_run_id}/ + ├── pipeline_config.yaml # Created by run_pipeline.py + ├── ensemble_meta.json # Ensemble metadata + ├── ensemble_merge_array.npz # Normalized merge array + ├── distances_.npz # Distance array for each method + └── plots/ + └── distances_.png # Distance distribution plot +""" + +import argparse +import json +import multiprocessing + +import numpy as np +import torch +from matplotlib import pyplot as plt +from matplotlib.axes import Axes + +from spd.clustering.consts import DistancesArray, DistancesMethod +from spd.clustering.ensemble_registry import get_clustering_runs +from spd.clustering.math.merge_distances import compute_distances +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.plotting.merge import plot_dists_distribution +from spd.clustering.scripts.run_clustering import ClusteringRunStorage +from spd.log import logger +from spd.settings import SPD_CACHE_DIR +from spd.utils.run_utils import ExecutionStamp + +# Set spawn method for CUDA compatibility with multiprocessing +# Must be done before any CUDA operations +if torch.cuda.is_available(): + try: # noqa: SIM105 + multiprocessing.set_start_method("spawn") + except RuntimeError: + # Already set, ignore + pass + + +def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: + """Calculate distances between clustering runs in an ensemble. 
+ + Args: + pipeline_run_id: Pipeline run ID to query from registry + distances_method: Method for calculating distances + """ + logger.info(f"Calculating distances for pipeline run: {pipeline_run_id}") + + # Query registry for clustering runs + clustering_runs = get_clustering_runs(pipeline_run_id) + if not clustering_runs: + raise ValueError(f"No clustering runs found for pipeline {pipeline_run_id}") + + logger.info(f"Found {len(clustering_runs)} clustering runs") + + # Load histories from individual clustering run directories + histories: list[MergeHistory] = [] + for idx, clustering_run_id in clustering_runs: + history_path = ClusteringRunStorage( + ExecutionStamp( + run_id=clustering_run_id, + snapshot_branch="", + commit_hash="", + run_type="cluster", + ) + ).history_path + + # SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" + if not history_path.exists(): + raise FileNotFoundError( + f"History not found for run {clustering_run_id}: {history_path}" + ) + histories.append(MergeHistory.read(history_path)) + logger.info(f"Loaded history for run {idx}: {clustering_run_id}") + + # Compute normalized ensemble + ensemble: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories) + merge_array, merge_meta = ensemble.normalized() + + # Get pipeline output directory + pipeline_dir = SPD_CACHE_DIR / "ensemble" / pipeline_run_id + + # Save ensemble metadata and merge array + ensemble_meta_path = pipeline_dir / "ensemble_meta.json" + ensemble_meta_path.write_text(json.dumps(merge_meta, indent=2)) + logger.info(f"Saved ensemble metadata to {ensemble_meta_path}") + + ensemble_array_path = pipeline_dir / "ensemble_merge_array.npz" + np.savez_compressed(ensemble_array_path, merge_array=merge_array) + logger.info(f"Saved ensemble merge array to {ensemble_array_path}") + + # Compute distances + logger.info(f"Computing distances using method: {distances_method}") + distances: DistancesArray = compute_distances( + normalized_merge_array=merge_array, + 
method=distances_method, + ) + + distances_path = pipeline_dir / f"distances_{distances_method}.npz" + np.savez_compressed(distances_path, distances=distances) + logger.info(f"Distances computed and saved: shape={distances.shape}, path={distances_path}") + + # Create and save distances distribution plot + ax: Axes = plot_dists_distribution( + distances=distances, mode="points", label=f"{distances_method} distances" + ) + plt.title(f"Distance Distribution ({distances_method})") + + # Only add legend if there are labeled artists + handles, _labels = ax.get_legend_handles_labels() + if handles: + plt.legend() + + plots_dir = pipeline_dir / "plots" + plots_dir.mkdir(parents=True, exist_ok=True) + fig_path = plots_dir / f"distances_{distances_method}.png" + plt.savefig(fig_path) + plt.close() + logger.info(f"Saved distances distribution plot to {fig_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Calculate distances between clustering runs") + parser.add_argument( + "--pipeline-run-id", + type=str, + required=True, + help="Pipeline run ID to query from registry", + ) + parser.add_argument( + "--distances-method", + choices=DistancesMethod.__args__, + default="perm_invariant_hamming", + help="Method for calculating distances", + ) + args = parser.parse_args() + main( + pipeline_run_id=args.pipeline_run_id, + distances_method=args.distances_method, + ) diff --git a/spd/clustering/scripts/get_cluster_mapping.py b/spd/clustering/scripts/get_cluster_mapping.py new file mode 100644 index 000000000..c2212e5cd --- /dev/null +++ b/spd/clustering/scripts/get_cluster_mapping.py @@ -0,0 +1,163 @@ +"""Extract cluster mapping from an ensemble at a specific iteration. 
+ +Usage: + python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 + python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 --run-idx 0 + python -m spd.clustering.scripts.get_cluster_mapping /path/to/ensemble --iteration 299 --notes "some notes" + +Output format: + { + "ensemble_id": "e-5f228e5f", + "notes": "", + "spd_run": "spd/goodfire/5cr21lbs", + "clusters": {"h.0.mlp.down_proj:1": 0, "h.0.mlp.down_proj:2": null, ...} + } + + Note: Singleton clusters (clusters with only one member) have null values. +""" + +import json +import sys +from pathlib import Path + +import fire +import numpy as np +import yaml + +from spd.settings import REPO_ROOT +from spd.utils.wandb_utils import parse_wandb_run_path + + +def get_cluster_mapping( + ensemble_dir: str | Path, + n_iterations: int, + run_idx: int = 0, +) -> dict[str, int | None]: + """Get mapping from component labels to cluster indices at a specific iteration. + + Args: + ensemble_dir: Path to ensemble directory containing ensemble_merge_array.npz + and ensemble_meta.json + n_iterations: Number of iterations to extract clusters from + run_idx: Run index within the ensemble (default 0) + + Returns: + Mapping from component label (e.g. "h.0.mlp.down_proj:42") to cluster index, + or None for singleton clusters (clusters with only one member). 
+ """ + ensemble_dir = Path(ensemble_dir) + + merge_array_path = ensemble_dir / "ensemble_merge_array.npz" + meta_path = ensemble_dir / "ensemble_meta.json" + + assert merge_array_path.exists(), f"Merge array not found: {merge_array_path}" + assert meta_path.exists(), f"Metadata not found: {meta_path}" + + merge_data = np.load(merge_array_path) + merge_array = merge_data["merge_array"] # shape: (n_runs, n_iterations, n_components) + + with open(meta_path) as f: + meta = json.load(f) + + component_labels: list[str] = meta["component_labels"] + n_runs, n_iterations_stored, n_components = merge_array.shape + + assert 0 <= run_idx < n_runs, f"run_idx {run_idx} out of bounds [0, {n_runs})" + assert 0 <= n_iterations < n_iterations_stored, ( + f"n_iterations {n_iterations} out of bounds [0, {n_iterations_stored})" + ) + assert len(component_labels) == n_components, ( + f"Label count mismatch: {len(component_labels)} labels vs {n_components} components" + ) + + assignments = merge_array[run_idx, n_iterations, :] + + # Count members per cluster to identify singletons + cluster_ids, counts = np.unique(assignments, return_counts=True) + singleton_clusters = set(cluster_ids[counts == 1]) + + return { + label: None if cluster_id in singleton_clusters else int(cluster_id) + for label, cluster_id in zip(component_labels, assignments, strict=True) + } + + +def get_spd_run_path(ensemble_dir: Path) -> str: + """Extract the SPD run path from the ensemble's pipeline config. + + Follows pipeline_config.yaml -> clustering_run_config_path -> model_path, + then parses the wandb path. 
+ + Returns: + Formatted path like "spd/goodfire/5cr21lbs" + """ + pipeline_config_path = ensemble_dir / "pipeline_config.yaml" + assert pipeline_config_path.exists(), f"Pipeline config not found: {pipeline_config_path}" + + with open(pipeline_config_path) as f: + pipeline_config = yaml.safe_load(f) + + clustering_run_config_path = REPO_ROOT / pipeline_config["clustering_run_config_path"] + assert clustering_run_config_path.exists(), ( + f"Clustering run config not found: {clustering_run_config_path}" + ) + + with open(clustering_run_config_path) as f: + clustering_run_config = json.load(f) + + model_path = clustering_run_config["model_path"] + entity, project, run_id = parse_wandb_run_path(model_path) + + return f"{entity}/{project}/{run_id}" + + +def main( + ensemble_dir: str, + n_iterations: int, + run_idx: int = 0, + notes: str = "", + output: str | None = None, +) -> None: + """Extract cluster mapping with metadata and output as JSON. + + Args: + ensemble_dir: Path to ensemble directory + n_iterations: Number of iterations to extract clusters from + run_idx: Run index within the ensemble (default 0) + notes: Optional notes to include in the output + output: Optional output file path. 
If not provided, writes to + {ensemble_dir}/cluster_mapping_{ensemble_id}.json + """ + ensemble_path = Path(ensemble_dir) + + clusters = get_cluster_mapping( + ensemble_dir=ensemble_dir, + n_iterations=n_iterations, + run_idx=run_idx, + ) + + ensemble_id = ensemble_path.name + spd_run = get_spd_run_path(ensemble_path) + + result = { + "ensemble_id": ensemble_id, + "notes": notes, + "spd_run": spd_run, + "n_iterations": n_iterations, + "run_idx": run_idx, + "clusters": clusters, + } + + json_str = json.dumps(result, indent=2) + + if output is None: + out_path = ensemble_path / f"cluster_mapping_{ensemble_id}.json" + else: + out_path = Path(output) + + out_path.write_text(json_str) + print(f"Wrote mapping ({len(clusters)} components) to {out_path}", file=sys.stderr) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py new file mode 100644 index 000000000..2dbafe1fa --- /dev/null +++ b/spd/clustering/scripts/run_clustering.py @@ -0,0 +1,443 @@ +"""Perform a single clustering run. + +This can be run as a standalone script, or called via `spd-clustering` +(i.e. clustering/scripts/run_pipeline.py). If called via spd-clustering, the ensemble-key is passed +in to identify the run within the pipeline ensemble. 

Output structure:
    <run_id>/  # from execution stamp (run_type="cluster")
    ├── clustering_run_config.json
    └── history.zip  # saved MergeHistory (zip archive, not npz -- see ClusteringRunStorage._HISTORY)
"""

import argparse
import gc
import os
import tempfile
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import torch
import wandb
from jaxtyping import Float, Int
from matplotlib.figure import Figure
from torch import Tensor
from wandb.sdk.wandb_run import Run

from spd.clustering.activations import (
    ProcessedActivations,
    component_activations,
    process_activations,
)
from spd.clustering.clustering_run_config import ClusteringRunConfig
from spd.clustering.consts import (
    ActivationsTensor,
    BatchTensor,
    ClusterCoactivationShaped,
    ComponentLabels,
)
from spd.clustering.dataset import load_dataset
from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run
from spd.clustering.math.merge_matrix import GroupMerge
from spd.clustering.math.semilog import semilog
from spd.clustering.merge import merge_iteration
from spd.clustering.merge_history import MergeHistory
from spd.clustering.plotting.activations import plot_activations
from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration
from spd.clustering.storage import StorageBase
from spd.clustering.wandb_tensor_info import wandb_log_tensor
from spd.log import logger
from spd.models.component_model import ComponentModel, SPDRunInfo
from spd.spd_types import TaskName
from spd.utils.distributed_utils import get_device
from spd.utils.general_utils import replace_pydantic_model
from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str

# Silence the wandb console banner for array-job logs.
os.environ["WANDB_QUIET"] = "true"


class ClusteringRunStorage(StorageBase):
    """Storage paths for a single clustering run.

    All paths are relative to ExecutionStamp.out_dir.
+ """ + + # Relative path constants + _CONFIG = "clustering_run_config.json" + # we are saving a zip file with things in it besides npy files -- hence, `.zip` and not `.npz` + _HISTORY = "history.zip" + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + super().__init__(execution_stamp) + self.config_path: Path = self.base_dir / self._CONFIG + self.history_path: Path = self.base_dir / self._HISTORY + + +LogCallback = Callable[ + [ + ClusterCoactivationShaped, + ComponentLabels, + GroupMerge, + ClusterCoactivationShaped, + MergeHistory, + int, + int, + float, + float, + float, + Float[Tensor, " k_groups"], + ], + None, +] + + +def _log_merge_history_plots(run: Run, history: MergeHistory) -> None: + """Log merge history plots to WandB.""" + fig_cs: Figure = plot_merge_history_cluster_sizes(history=history) + run.log( + {"plots/merge_history_cluster_sizes": wandb.Image(fig_cs)}, + step=history.n_iters_current, + ) + plt.close(fig_cs) + + +def _save_merge_history_artifact( + run: Run, + history_path: Path, + history: MergeHistory, +) -> None: + """Save merge history as WandB artifact.""" + artifact: wandb.Artifact = wandb.Artifact( + name="merge_history", + type="merge_history", + description="Merge history", + metadata={"n_iters_current": history.n_iters_current, "filename": str(history_path)}, + ) + artifact.add_file(str(history_path)) + run.log_artifact(artifact) + + +def _log_callback( + run: Run, + run_config: ClusteringRunConfig, + current_coact: ClusterCoactivationShaped, + component_labels: ComponentLabels, + current_merge: GroupMerge, + costs: ClusterCoactivationShaped, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: Float[Tensor, " k_groups"], +) -> None: + """Callback for logging during merge iteration.""" + if iter_idx % run_config.logging_intervals.stat == 0: + run.log( + { + "k_groups": int(k_groups), + "merge_pair_cost": merge_pair_cost, + 
"merge_pair_cost_semilog[1e-3]": semilog(merge_pair_cost, epsilon=1e-3), + "mdl_loss": float(mdl_loss), + "mdl_loss_norm": float(mdl_loss_norm), + }, + step=iter_idx, + ) + + if iter_idx % run_config.logging_intervals.tensor == 0: + group_sizes: Int[Tensor, " k_groups"] = current_merge.components_per_group + + tensor_data: dict[str, Tensor] = { + "coactivation": current_coact, + "costs": costs, + "group_sizes": group_sizes, + "group_activations": diag_acts, + "group_activations_over_sizes": ( + diag_acts / group_sizes.to(device=diag_acts.device).float() + ), + } + + fraction_singleton_groups: float = (group_sizes == 1).float().mean().item() + if fraction_singleton_groups > 0: + tensor_data["group_sizes.log1p"] = torch.log1p(group_sizes.float()) + + fraction_zero_coacts: float = (current_coact == 0).float().mean().item() + if fraction_zero_coacts > 0: + tensor_data["coactivation.log1p"] = torch.log1p(current_coact.float()) + + wandb_log_tensor(run, tensor_data, name="iters", step=iter_idx) + + run.log( + { + "fraction_singleton_groups": float(fraction_singleton_groups), + "num_nonsingleton_groups": int((group_sizes > 1).sum().item()), + "fraction_zero_coacts": float(fraction_zero_coacts), + }, + step=iter_idx, + ) + + if iter_idx > 0 and iter_idx % run_config.logging_intervals.artifact == 0: + with tempfile.NamedTemporaryFile() as tmp_file: + file: Path = Path(tmp_file.name) + merge_history.save(file) + artifact: wandb.Artifact = wandb.Artifact( + name=f"merge_hist_iter.iter_{iter_idx}", + type="merge_hist_iter", + description=f"Group indices at iteration {iter_idx}", + metadata={ + "iteration": iter_idx, + "config": merge_history.merge_config.model_dump(mode="json"), + }, + ) + artifact.add_file(str(file)) + run.log_artifact(artifact) + + if iter_idx % run_config.logging_intervals.plot == 0: + fig: Figure = plot_merge_iteration( + current_merge=current_merge, + current_coact=current_coact, + costs=costs, + iteration=iter_idx, + component_labels=component_labels, + 
show=False, + ) + run.log({"plots/merges": wandb.Image(fig)}, step=iter_idx) + plt.close(fig) + + +def main(run_config: ClusteringRunConfig) -> Path: + """A single clustering run. + + Args: + run_config: Runtime parameters for this clustering run + + Returns: + Path to saved merge history file + """ + # Create ExecutionStamp and storage + # don't create git snapshot -- if we are part of an ensemble, the snapshot should be created by the pipeline + execution_stamp = ExecutionStamp.create( + run_type="cluster", + create_snapshot=False, + ) + storage = ClusteringRunStorage(execution_stamp) + clustering_run_id = execution_stamp.run_id + logger.info(f"Clustering run ID: {clustering_run_id}") + + # Register with ensemble if this is part of a pipeline + assigned_idx: int | None + if run_config.ensemble_id: + assigned_idx = register_clustering_run( + pipeline_run_id=run_config.ensemble_id, + clustering_run_id=clustering_run_id, + ) + + logger.info( + f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" + ) + # IMPORTANT: set dataset seed based on assigned index + run_config = replace_pydantic_model( + run_config, + {"dataset_seed": run_config.dataset_seed + assigned_idx}, + ) + else: + assigned_idx = None + + # save config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + + # start + logger.info("Starting clustering run") + logger.info(f"Output directory: {storage.base_dir}") + device = get_device() + + spd_run = SPDRunInfo.from_path(run_config.model_path) + task_name: TaskName = spd_run.config.task_config.task_name + + # 1. 
Load dataset + logger.info(f"Loading dataset (seed={run_config.dataset_seed})") + load_dataset_kwargs: dict[str, Any] = dict() + if run_config.dataset_streaming: + logger.info("Using streaming dataset loading") + load_dataset_kwargs["config_kwargs"] = dict(streaming=True) + assert task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." + ) + batch: BatchTensor = load_dataset( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + **load_dataset_kwargs, + ) + batch = batch.to(device) + + # 2. Setup WandB for this run + wandb_run: Run | None = None + if run_config.wandb_project is not None: + wandb_run = wandb.init( + id=clustering_run_id, + entity=run_config.wandb_entity, + project=run_config.wandb_project, + group=run_config.ensemble_id, + config=run_config.model_dump(mode="json"), + tags=[ + "clustering", + f"task:{task_name}", + f"model:{run_config.wandb_decomp_model}", + f"ensemble_id:{run_config.ensemble_id}", + f"assigned_idx:{assigned_idx}", + ], + ) + # logger.info(f"WandB run: {wandb_run.url}") + + # 3. Load model + logger.info("Loading model") + model = ComponentModel.from_run_info(spd_run).to(device) + + # 4. Compute activations + logger.info("Computing activations") + activations_dict: ( + dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] + ) = component_activations( + model=model, + batch=batch, + device=device, + ) + + # 5. Process activations + logger.info("Processing activations") + processed_activations: ProcessedActivations = process_activations( + activations=activations_dict, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + seq_mode="concat" if task_name == "lm" else None, + filter_modules=run_config.merge_config.filter_modules, + ) + + # 6. 
Log activations (if WandB enabled) + if wandb_run is not None: + logger.info("Plotting activations") + plot_activations( + processed_activations=processed_activations, + save_dir=None, # Don't save to disk, only WandB + n_samples_max=256, + wandb_run=wandb_run, + ) + wandb_log_tensor( + wandb_run, + processed_activations.activations, + "activations", + 0, + single=True, + ) + + # Clean up memory + activations: ActivationsTensor = processed_activations.activations + component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) + del processed_activations + del activations_dict + del model + del batch + gc.collect() + + # 7. Run merge iteration + logger.info("Starting merging") + log_callback: LogCallback | None = ( + partial(_log_callback, run=wandb_run, run_config=run_config) + if wandb_run is not None + else None + ) + + history: MergeHistory = merge_iteration( + merge_config=run_config.merge_config, + activations=activations, + component_labels=component_labels, + log_callback=log_callback, + ) + + # 8. Save merge history + + history.save(storage.history_path) + logger.info(f"History saved to {storage.history_path}") + + # 9. Log to WandB + if wandb_run is not None: + _log_merge_history_plots(wandb_run, history) + _save_merge_history_artifact(wandb_run, storage.history_path, history) + wandb_run.finish() + logger.info("WandB run finished") + + return storage.history_path + + +def cli() -> None: + """CLI for running a single clustering run.""" + parser = argparse.ArgumentParser(description="Run clustering on a single dataset") + parser.add_argument( + "--config", + type=Path, + required=True, + help="Path to ClusteringRunConfig file", + ) + parser.add_argument( + "--pipeline-run-id", + type=str, + default=None, + help="Pipeline run ID (ensemble identifier). 
If provided with --idx-in-ensemble, registers run.", + ) + parser.add_argument( + "--idx-in-ensemble", + type=int, + default=None, + help="Index of this run in the ensemble", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset)", + ) + + args: argparse.Namespace = parser.parse_args() + + # Load base config + run_config = ClusteringRunConfig.from_file(args.config) + + # Override config values from CLI + overrides: dict[str, Any] = { + "dataset_streaming": args.dataset_streaming, + } + + # Handle ensemble-related overrides + if args.pipeline_run_id is not None: + overrides["ensemble_id"] = args.pipeline_run_id + + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + + run_config = replace_pydantic_model(run_config, overrides) + + # Run clustering + main(run_config) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py new file mode 100644 index 000000000..50c21a5e1 --- /dev/null +++ b/spd/clustering/scripts/run_pipeline.py @@ -0,0 +1,483 @@ +"""Submit clustering runs to SLURM as separate jobs in a SLURM array. + +This script submits independent clustering runs as a SLURM job array, +where each run gets its own dataset (seeded), WandB run, and merge history output. + +Also submits a job to calculate distances between the clustering runs, which will run after +the clustering runs (the SLURM job depends on the previous array job). 

Output structure (only pipeline_config.yaml is saved to directly in this script. The files under
<run_id>/ are saved by run_clustering.py / calc_distances.py, which are called in SLURM jobs
deployed by this script.):
    <run_id>/  # from execution stamp
    ├── pipeline_config.yaml         # Saved in this script
    ├── clustering_run_config.json   # copy of the file pointed to by the pipeline config
    ├── ensemble_meta.json           # (Saved by calc_distances.py) Ensemble metadata
    ├── ensemble_merge_array.npz     # (Saved by calc_distances.py) Normalized merge array
    ├── distances_<method>.npz       # (Saved by calc_distances.py) Distance array for each method
    └── distances_<method>.png       # (Saved by calc_distances.py) Distance distribution plot
"""

import argparse
import os
import shlex
from pathlib import Path
from typing import Any

import wandb_workspaces.workspaces as ws
from pydantic import Field, PositiveInt, field_validator, model_validator

from spd.base_config import BaseConfig
from spd.clustering.clustering_run_config import ClusteringRunConfig
from spd.clustering.consts import DistancesMethod
from spd.clustering.storage import StorageBase
from spd.log import logger
from spd.settings import SPD_CACHE_DIR
from spd.utils.general_utils import replace_pydantic_model
from spd.utils.run_utils import (
    _NO_ARG_PARSSED_SENTINEL,
    ExecutionStamp,
    read_noneable_str,
    run_locally,
)
from spd.utils.slurm import (
    SlurmArrayConfig,
    SlurmConfig,
    generate_array_script,
    generate_script,
    submit_slurm_job,
)

# Silence the wandb console banner for array-job logs.
os.environ["WANDB_QUIET"] = "true"


class ClusteringPipelineStorage(StorageBase):
    """Storage paths for clustering pipeline (ensemble).

    All paths are relative to ExecutionStamp.out_dir.
+ """ + + # Relative path constants + _PIPELINE_CONFIG = "pipeline_config.yaml" + _RUN_IDS = "run_ids.json" + _ENSEMBLE_META = "ensemble_meta.json" + _ENSEMBLE_MERGE_ARRAY = "ensemble_merge_array.npz" + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + super().__init__(execution_stamp) + self.pipeline_config_path: Path = self.base_dir / self._PIPELINE_CONFIG + self.run_ids_path: Path = self.base_dir / self._RUN_IDS + self.ensemble_meta_path: Path = self.base_dir / self._ENSEMBLE_META + self.ensemble_merge_array_path: Path = self.base_dir / self._ENSEMBLE_MERGE_ARRAY + + def distances_path(self, method: DistancesMethod) -> Path: + return self.base_dir / f"distances_{method}.npz" + + +class ClusteringPipelineConfig(BaseConfig): + """Configuration for submitting an ensemble of clustering runs to SLURM.""" + + clustering_run_config_path: Path = Field( + description="Path to ClusteringRunConfig file.", + ) + n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") + distances_methods: list[DistancesMethod] = Field( + description="List of method(s) to use for calculating distances" + ) + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering_pipeline", + description="Base directory for outputs of clustering ensemble pipeline runs.", + ) + slurm_job_name_prefix: str | None = Field( + default=None, description="Prefix for SLURM job names" + ) + slurm_partition: str | None = Field(default=None, description="SLURM partition to use") + wandb_project: str | None = Field( + default=None, + description="Weights & Biases project name (set to None to disable WandB logging)", + ) + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + create_git_snapshot: bool = Field( + default=False, description="Create a git snapshot for the run" + ) + + @model_validator(mode="after") + def validate_crc(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of clustering_run_config_path 
points to a valid `ClusteringRunConfig`.""" + assert self.clustering_run_config_path.exists(), ( + f"clustering_run_config_path does not exist: {self.clustering_run_config_path}" + ) + # Try to load ClusteringRunConfig + assert ClusteringRunConfig.from_file(self.clustering_run_config_path) + + return self + + @field_validator("distances_methods") + @classmethod + def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesMethod]: + """Validate that distances_methods is non-empty and contains valid methods.""" + assert all(method in DistancesMethod.__args__ for method in v), ( + f"Invalid distances_methods: {v}" + ) + + return v + + +def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: + """Create WandB workspace view for clustering runs. + + TODO: Use a template workspace which actually shows some panels + TODO: since the run_id here is the same as the wandb id, can we take advantage of that? + + Args: + ensemble_id: Unique identifier for this ensemble + project: WandB project name + entity: WandB entity (team/user) name + + Returns: + URL to workspace view + """ + workspace = ws.Workspace(entity=entity, project=project) + workspace.name = f"Clustering - {ensemble_id}" + + workspace.runset_settings.filters = [ + ws.Tags("tags").isin([f"ensemble_id:{ensemble_id}"]), + ] + + try: + workspace.save_as_new_view() + return workspace.url + except Exception as e: + logger.warning( + f"Failed to create WandB workspace view: {workspace=}, {workspace.name=}, {ensemble_id=}, {project=}, {entity=}, {e}" + ) + raise e + + +def generate_clustering_commands( + pipeline_config: ClusteringPipelineConfig, + pipeline_run_id: str, + dataset_streaming: bool = False, +) -> list[str]: + """Generate commands for each clustering run. 
+ + Args: + pipeline_config: Pipeline configuration + pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + dataset_streaming: Whether to use dataset streaming + + Returns: + List of shell-safe command strings + """ + commands: list[str] = [] + + for idx in range(pipeline_config.n_runs): + cmd_parts = [ + "python", + "spd/clustering/scripts/run_clustering.py", + "--config", + pipeline_config.clustering_run_config_path.as_posix(), + "--pipeline-run-id", + pipeline_run_id, + "--idx-in-ensemble", + str(idx), + "--wandb-project", + str(pipeline_config.wandb_project), + "--wandb-entity", + pipeline_config.wandb_entity, + ] + if dataset_streaming: + cmd_parts.append("--dataset-streaming") + + commands.append(shlex.join(cmd_parts)) + + return commands + + +def generate_calc_distances_commands( + pipeline_run_id: str, distances_methods: list[DistancesMethod] +) -> list[str]: + """Generate commands for calculating distances. + + Args: + pipeline_run_id: Pipeline run ID (will query registry for clustering runs) + distances_methods: List of methods for calculating distances + + Returns: + List of shell-safe command strings, one per method + """ + commands: list[str] = [] + for method in distances_methods: + commands.append( + shlex.join( + [ + "python", + "spd/clustering/scripts/calc_distances.py", + "--pipeline-run-id", + pipeline_run_id, + "--distances-method", + method, + ] + ) + ) + return commands + + +def main( + pipeline_config: ClusteringPipelineConfig, + local: bool = False, + local_clustering_parallel: bool = False, + local_calc_distances_parallel: bool = False, + dataset_streaming: bool = False, + track_resources_calc_distances: bool = False, +) -> None: + """Submit clustering runs to SLURM. + + Args: + pipeline_config_path: Path to ClusteringPipelineConfig file + n_runs: Number of clustering runs in the ensemble. Will override value in the config file. 
+ """ + # setup + # ========================================================================================== + + logger.set_format("console", "terse") + + if local_clustering_parallel or local_calc_distances_parallel or track_resources_calc_distances: + assert local, ( + "local_clustering_parallel, local_calc_distances_parallel, track_resources_calc_distances " + "can only be set when running locally\n" + f"{local_clustering_parallel=}, {local_calc_distances_parallel=}, {track_resources_calc_distances=}, {local=}" + ) + + # Create ExecutionStamp for pipeline + execution_stamp: ExecutionStamp = ExecutionStamp.create( + run_type="ensemble", + create_snapshot=pipeline_config.create_git_snapshot, + ) + pipeline_run_id: str = execution_stamp.run_id + logger.info(f"Pipeline run ID: {pipeline_run_id}") + + # Initialize storage + storage = ClusteringPipelineStorage(execution_stamp) + logger.info(f"Pipeline output directory: {storage.base_dir}") + + # Save pipeline config + pipeline_config.to_file(storage.pipeline_config_path) + logger.info(f"Pipeline config saved to {storage.pipeline_config_path}") + + # Create WandB workspace if requested + if pipeline_config.wandb_project is not None: + workspace_url = create_clustering_workspace_view( + ensemble_id=pipeline_run_id, + project=pipeline_config.wandb_project, + entity=pipeline_config.wandb_entity, + ) + logger.info(f"WandB workspace: {workspace_url}") + + # Generate commands for clustering runs + clustering_commands = generate_clustering_commands( + pipeline_config=pipeline_config, + pipeline_run_id=pipeline_run_id, + dataset_streaming=dataset_streaming, + ) + + # Generate commands for calculating distances + calc_distances_commands = generate_calc_distances_commands( + pipeline_run_id=pipeline_run_id, + distances_methods=pipeline_config.distances_methods, + ) + + # Submit to SLURM + if local: + # submit clustering array job + run_locally( + commands=clustering_commands, + parallel=local_clustering_parallel, + ) + + # 
submit calc_distances jobs in parallel + logger.info("Calculating distances...") + run_locally( + commands=calc_distances_commands, + parallel=local_calc_distances_parallel, + track_resources=track_resources_calc_distances, + ) + + logger.section("complete!") + + # Build distances plot paths dict + distances_plots = { + f"distances via {method}": str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + + logger.values( + { + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + **distances_plots, + } + ) + + else: + assert pipeline_config.slurm_job_name_prefix is not None, ( + "must specify slurm_job_name_prefix if not running locally" + ) + assert pipeline_config.slurm_partition is not None, ( + "must specify slurm_partition if not running locally" + ) + + # Submit clustering array job + clustering_config = SlurmArrayConfig( + job_name=f"{pipeline_config.slurm_job_name_prefix}_cluster", + partition=pipeline_config.slurm_partition, + n_gpus=1, # Always 1 GPU per run + snapshot_branch=execution_stamp.snapshot_branch, + max_concurrent_tasks=pipeline_config.n_runs, # Run all concurrently + ) + clustering_script = generate_array_script(clustering_config, clustering_commands) + clustering_result = submit_slurm_job( + clustering_script, + "clustering", + is_array=True, + n_array_tasks=len(clustering_commands), + ) + array_job_id = clustering_result.job_id + + # Submit calc_distances jobs (one per method) with dependency on array job + calc_distances_job_ids: list[str] = [] + calc_distances_logs: list[str] = [] + + for method, cmd in zip( + pipeline_config.distances_methods, calc_distances_commands, strict=True + ): + dist_config = SlurmConfig( + job_name=f"{pipeline_config.slurm_job_name_prefix}_dist_{method}", + partition=pipeline_config.slurm_partition, + n_gpus=1, + snapshot_branch=execution_stamp.snapshot_branch, + 
dependency_job_id=array_job_id, + ) + dist_script = generate_script(dist_config, cmd) + dist_result = submit_slurm_job(dist_script, f"calc_distances_{method}") + calc_distances_job_ids.append(dist_result.job_id) + calc_distances_logs.append(dist_result.log_pattern) + + logger.section("Jobs submitted successfully!") + + # Build distances plot paths dict + distances_plots = { + method: str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + + logger.values( + { + "Clustering Array Job ID": array_job_id, + "Calc Distances Job IDs": ", ".join(calc_distances_job_ids), + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + "Clustering logs": clustering_result.log_pattern, + "Calc Distances logs": ", ".join(calc_distances_logs), + } + ) + logger.info("Distances plots will be saved to:") + for method, path in distances_plots.items(): + logger.info(f" {method}: {path}") + + +def cli(): + """CLI for spd-clustering command.""" + parser = argparse.ArgumentParser( + prog="spd-clustering", + description="Submit clustering runs to SLURM. 
Arguments specified here will override the " + "corresponding value in the config file.", + ) + + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to pipeline config file", + ) + parser.add_argument( + "--n-runs", + type=int, + help="Number of clustering runs in the ensemble (overrides value in config file)", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--distances-methods", + type=str, + default=None, + help="Comma-separated list of distance methods (e.g., 'perm_invariant_hamming,matching_dist')", + ) + parser.add_argument( + "--local", + action=argparse.BooleanOptionalAction, + default=False, + help="Run locally instead of submitting to SLURM (required if slurm_job_name_prefix and slurm_partition are None in config)", + ) + parser.add_argument( + "--local-clustering-parallel", + action="store_true", + help="If running locally, whether to run clustering runs in parallel", + ) + parser.add_argument( + "--local-calc-distances-parallel", + action="store_true", + help="If running locally, whether to run distance calculations in parallel", + ) + parser.add_argument( + "--track-resources-calc-distances", + action="store_true", + help="If running locally, whether to track resource usage during distance calculations", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) + + args = parser.parse_args() + + pipeline_config = ClusteringPipelineConfig.from_file(args.config) + overrides: dict[str, Any] = {} + + if args.n_runs is not None: + overrides["n_runs"] = args.n_runs + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + if args.distances_methods is not None: + # Parse comma-separated list of distance methods + methods = [method.strip() for method in args.distances_methods.split(",")] + overrides["distances_methods"] = methods + + pipeline_config = replace_pydantic_model(pipeline_config, overrides) + + main( + pipeline_config=pipeline_config, + local=args.local, + dataset_streaming=args.dataset_streaming, + local_clustering_parallel=args.local_clustering_parallel, + local_calc_distances_parallel=args.local_calc_distances_parallel, + track_resources_calc_distances=args.track_resources_calc_distances, + ) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/storage.py b/spd/clustering/storage.py new file mode 100644 index 000000000..dc3d8765a --- /dev/null +++ b/spd/clustering/storage.py @@ -0,0 +1,19 @@ +"""Minimal storage base class for clustering - just path management.""" + +from pathlib import Path + +from spd.utils.run_utils import ExecutionStamp + + +class StorageBase: + """Base class for storage - provides ExecutionStamp and base directory. + + Subclasses define path constants (relative to base_dir) and set absolute paths in __init__. + Caller handles all actual saving and WandB uploading. 
+ """ + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + """Initialize storage with execution stamp.""" + self.execution_stamp: ExecutionStamp = execution_stamp + self.base_dir: Path = execution_stamp.out_dir + self.plots_dir: Path = self.base_dir / "plots" diff --git a/spd/clustering/util.py b/spd/clustering/util.py new file mode 100644 index 000000000..bd11e2fd4 --- /dev/null +++ b/spd/clustering/util.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + + +def format_scientific_latex(value: float) -> str: + """Format a number in LaTeX scientific notation style.""" + if value == 0: + return r"$0$" + + import math + + exponent: int = int(math.floor(math.log10(abs(value)))) + mantissa: float = value / (10**exponent) + + return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + + +ModuleFilterSource = str | Callable[[str], bool] | set[str] | None +ModuleFilterFunc = Callable[[str], bool] diff --git a/spd/clustering/wandb_tensor_info.py b/spd/clustering/wandb_tensor_info.py new file mode 100644 index 000000000..0120c77e0 --- /dev/null +++ b/spd/clustering/wandb_tensor_info.py @@ -0,0 +1,203 @@ +"""Minimal WandB tensor logging utilities.""" + +import warnings +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import wandb +import wandb.sdk.wandb_run +from torch import Tensor + + +def _array_info(arr: Tensor | np.ndarray) -> dict[str, Any]: + """Get basic statistics about an array or tensor.""" + arr_np = arr.detach().cpu().numpy() if isinstance(arr, Tensor) else arr + + if arr_np.size == 0: + return { + "status": "empty", + "size": 0, + "shape": arr_np.shape, + "dtype": str(arr.dtype) if isinstance(arr, Tensor) else str(arr_np.dtype), + "has_nans": False, + "nan_percent": None, + "mean": None, + "median": None, + "std": None, + "min": None, + "max": None, + } + + has_nans = bool(np.isnan(arr_np).any()) + nan_count = int(np.isnan(arr_np).sum()) + nan_percent = 100.0 * nan_count / arr_np.size if arr_np.size > 0 else 0.0 + 
+ # Compute stats ignoring NaNs + return { + "status": "ok", + "size": arr_np.size, + "shape": arr_np.shape, + "dtype": str(arr.dtype) if isinstance(arr, Tensor) else str(arr_np.dtype), + "has_nans": has_nans, + "nan_percent": nan_percent, + "mean": float(np.nanmean(arr_np)), + "median": float(np.nanmedian(arr_np)), + "std": float(np.nanstd(arr_np)), + "min": float(np.nanmin(arr_np)), + "max": float(np.nanmax(arr_np)), + } + + +def wandb_log_tensor( + run: wandb.sdk.wandb_run.Run, + data: Tensor | dict[str, Tensor], + name: str, + step: int, + single: bool = False, +) -> None: + """Log tensor(s) with stats to WandB as metrics and histograms. + + Args: + run: Current WandB run (None if WandB disabled) + data: Either a Tensor or dict[str, Tensor] + name: Name for logging + step: WandB step + single: True if this tensor is only logged once (component activations) + """ + try: + if isinstance(data, dict): + # Handle dict of tensors + for key, tensor in data.items(): + full_name: str = f"{name}.{key}" + _log_one(run, tensor, full_name, step, single=single) + else: + # Handle single tensor + _log_one(run, data, name, step, single=single) + except Exception as e: + warnings.warn(f"Failed to log tensor {name}: {e}") # noqa: B028 + raise e + + +def _create_histogram( + info: dict[str, Any], tensor: Tensor, name: str, logy: bool = True +) -> plt.Figure: + """Create matplotlib histogram with stats markers.""" + # sanity check + if info["status"] != "ok" or info["size"] == 0: + fig: plt.Figure + ax: plt.Axes + fig, ax = plt.subplots(figsize=(8, 6)) + ax.text(0.5, 0.5, f"{info['status']}", ha="center", va="center") + ax.set_title(f"{name} - {info['status']}") + return fig + + # make basic hist + values: np.ndarray = tensor.flatten().detach().cpu().numpy() + if info["has_nans"]: + values = values[~np.isnan(values)] + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.hist(values, bins=50, alpha=0.7, edgecolor="black", linewidth=0.5) + + # Add stat lines + mean_val: float = 
info["mean"] or float("nan") + median_val: float = info["median"] or float("nan") + std_val: float = info["std"] or float("nan") + + if info["mean"] is not None: + ax.axvline( + mean_val, + color="red", + linestyle="-", + linewidth=2, + label="$\\mu$", + ) + ax.axvline( + median_val, + color="blue", + linestyle="-", + linewidth=2, + label="$\\tilde{x}$", + ) + if std_val: + ax.axvline( + mean_val + std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu+\\sigma$", + ) + ax.axvline( + mean_val - std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu-\\sigma$", + ) + + # Build informative title with tensor stats + shape_str: str = str(tuple(info["shape"])) if "shape" in info else "unknown" + dtype_str: str = str(info.get("dtype", "unknown")).replace("torch.", "") + + title_line1: str = f"{name}" + title_line2: str = f"shape={shape_str}, dtype={dtype_str}" + title_line3: str = ( + f"range=[{info['min']:.3g}, {info['max']:.3g}], " + f"$\\mu$={mean_val:.3g}, $\\tilde{{x}}$={median_val:.3g}, $\\sigma$={std_val:.3g}" + ) + + # Combine into multi-line title + full_title: str = f"{title_line1}\n{title_line2}\n{title_line3}" + ax.set_title(full_title, fontsize=10) + ax.set_xlabel("Value") + ax.set_ylabel("Count") + ax.legend() + ax.grid(True, alpha=0.3) + if logy: + ax.set_yscale("log") + + plt.tight_layout() + return fig + + +def _log_one( + run: wandb.sdk.wandb_run.Run, + tensor_: Tensor, + name: str, + step: int, + single: bool = False, + # use_log_counts: bool = True, +) -> None: + """Log a single tensor.""" + info: dict[str, Any] = _array_info(tensor_) + + if single: + # For single-use logging, log a single histogram as a figure + hist_fig: plt.Figure = _create_histogram(info=info, tensor=tensor_, name=name) + histogram_key: str = f"single_hists/{name}" + run.log({histogram_key: wandb.Image(hist_fig)}, step=step) + plt.close(hist_fig) # Close figure to free memory + else: + # Log numeric stats as 
metrics (viewable like loss) using dict comprehension + stats_to_log: dict[str, float | wandb.Histogram] = { + f"tensor_metrics/{name}/{key}": info[key] + for key in ["mean", "std", "median", "min", "max"] + if key in info and info[key] is not None + } + + # For regular logging, use wandb.Histogram directly + hist_key: str = f"tensor_histograms/{name}" + stats_to_log[hist_key] = wandb.Histogram(tensor_.flatten().cpu().numpy()) # pyright: ignore[reportArgumentType] + + # Add nan_percent if present + nan_percent: float | None = info["nan_percent"] + if nan_percent is None: + nan_percent = float("nan") + if nan_percent > 0: + stats_to_log[f"tensor_metrics/{name}/nan_percent"] = nan_percent + + if stats_to_log: + run.log(stats_to_log, step=step) diff --git a/spd/experiments/ih/configs.py b/spd/experiments/ih/configs.py index 4eabb5f5b..dd23b7110 100644 --- a/spd/experiments/ih/configs.py +++ b/spd/experiments/ih/configs.py @@ -33,7 +33,7 @@ class InductionHeadsTrainConfig(BaseConfig): class IHTaskConfig(BaseConfig): - task_name: Literal["induction_head"] + task_name: Literal["ih"] prefix_window: PositiveInt | None = Field( default=None, description="Number of tokens to use as a prefix window for the induction head. 
If none, uses the full sequence length.", diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml index ecffb368f..c68aa0f48 100644 --- a/spd/experiments/ih/ih_config.yaml +++ b/spd/experiments/ih/ih_config.yaml @@ -11,7 +11,7 @@ lr: 1e-3 n_eval_steps: 200 task_config: - task_name: induction_head + task_name: ih n_mask_samples: 1 module_info: diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index f386299a1..c9387b592 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -169,8 +169,12 @@ def harvest( ) train_iter = iter(train_loader) - for _ in tqdm.tqdm(range(config.n_batches), desc="Harvesting"): - batch = extract_batch_data(next(train_iter)).to(device) + for batch_idx in tqdm.tqdm(range(config.n_batches), desc="Harvesting"): + try: + batch = extract_batch_data(next(train_iter)).to(device) + except StopIteration: + print(f"Dataset exhausted at batch {batch_idx}/{config.n_batches}. Finishing early.") + break with torch.no_grad(): out = model(batch, cache_type="input") @@ -240,7 +244,15 @@ def _harvest_worker( batches_processed = 0 last_log_time = time.time() for batch_idx in range(n_batches): - batch_data = extract_batch_data(next(train_iter)) + try: + batch_data = extract_batch_data(next(train_iter)) + except StopIteration: + print( + f"[Worker {rank}] Dataset exhausted at batch {batch_idx}/{n_batches}. 
" + f"Finishing early.", + flush=True, + ) + break if batch_idx % world_size != rank: continue diff --git a/spd/harvest/scripts/run_slurm.py b/spd/harvest/scripts/run_slurm.py index 8f07a896a..7f31e401d 100644 --- a/spd/harvest/scripts/run_slurm.py +++ b/spd/harvest/scripts/run_slurm.py @@ -7,32 +7,9 @@ spd-harvest --n_batches 8000 --n_gpus 8 """ -import subprocess -from datetime import datetime -from pathlib import Path - from spd.log import logger -from spd.settings import DEFAULT_PARTITION_NAME, REPO_ROOT - - -def _generate_job_id() -> str: - return datetime.now().strftime("%Y%m%d_%H%M%S") - - -def _submit_slurm_job(script_content: str, script_path: Path) -> str: - """Write script and submit to SLURM, returning job ID.""" - with open(script_path, "w") as f: - f.write(script_content) - script_path.chmod(0o755) - - result = subprocess.run( - ["sbatch", str(script_path)], capture_output=True, text=True, check=False - ) - if result.returncode != 0: - raise RuntimeError(f"Failed to submit SLURM job: {result.stderr}") - - job_id = result.stdout.strip().split()[-1] - return job_id +from spd.settings import DEFAULT_PARTITION_NAME +from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job def harvest( @@ -61,15 +38,8 @@ def harvest( partition: SLURM partition name. time: Job time limit. 
""" - job_id = _generate_job_id() - slurm_logs_dir = Path.home() / "slurm_logs" - slurm_logs_dir.mkdir(exist_ok=True) - - sbatch_scripts_dir = Path.home() / "sbatch_scripts" - sbatch_scripts_dir.mkdir(exist_ok=True) - - gres = f"gpu:{n_gpus}" if n_gpus else "gpu:1" - job_name = f"harvest-{job_id}" + actual_n_gpus = n_gpus if n_gpus else 1 + job_name = "harvest" # Build the harvest command with all args cmd_parts = [ @@ -87,50 +57,40 @@ def harvest( harvest_cmd = " \\\n ".join(cmd_parts) - script_content = f"""\ -#!/bin/bash -#SBATCH --job-name={job_name} -#SBATCH --partition={partition} -#SBATCH --nodes=1 -#SBATCH --gres={gres} -#SBATCH --time={time} -#SBATCH --output={slurm_logs_dir}/slurm-%j.out - -set -euo pipefail - -echo "=== Harvest ===" -echo "WANDB_PATH: {wandb_path}" -echo "N_BATCHES: {n_batches}" -echo "N_GPUS: {n_gpus or 1}" -echo "SLURM_JOB_ID: $SLURM_JOB_ID" -echo "===============" - -cd {REPO_ROOT} -source .venv/bin/activate - -{harvest_cmd} - -echo "Harvest complete!" -""" - - script_path = sbatch_scripts_dir / f"harvest_{job_id}.sh" - slurm_job_id = _submit_slurm_job(script_content, script_path) - - # Rename to include SLURM job ID - final_script_path = sbatch_scripts_dir / f"harvest_{slurm_job_id}.sh" - script_path.rename(final_script_path) + # Build full command with echoes + full_command = "\n".join( + [ + 'echo "=== Harvest ==="', + f'echo "WANDB_PATH: {wandb_path}"', + f'echo "N_BATCHES: {n_batches}"', + f'echo "N_GPUS: {actual_n_gpus}"', + 'echo "SLURM_JOB_ID: $SLURM_JOB_ID"', + 'echo "==============="', + "", + harvest_cmd, + "", + 'echo "Harvest complete!"', + ] + ) - # Create empty log file for tailing - (slurm_logs_dir / f"slurm-{slurm_job_id}.out").touch() + config = SlurmConfig( + job_name=job_name, + partition=partition, + n_gpus=actual_n_gpus, + time=time, + snapshot_branch=None, # Harvest doesn't use git snapshots + ) + script_content = generate_script(config, full_command) + result = submit_slurm_job(script_content, "harvest") 
logger.section("Harvest job submitted!") logger.values( { - "Job ID": slurm_job_id, + "Job ID": result.job_id, "WandB path": wandb_path, "N batches": n_batches, - "N GPUs": n_gpus or 1, - "Log": f"~/slurm_logs/slurm-{slurm_job_id}.out", - "Script": str(final_script_path), + "N GPUs": actual_n_gpus, + "Log": result.log_pattern, + "Script": str(result.script_path), } ) diff --git a/spd/identity_insertion.py b/spd/identity_insertion.py index 5ed9cd3c6..6995693cc 100644 --- a/spd/identity_insertion.py +++ b/spd/identity_insertion.py @@ -58,7 +58,6 @@ def insert_identity_operations_( if unmatched: raise ValueError(f"Identity patterns did not match any modules: {sorted(unmatched)}") - # Add identity layers and hooks for module_path in identity_module_paths: module = target_model.get_submodule(module_path) @@ -72,5 +71,5 @@ def insert_identity_operations_( case _: raise ValueError(f"Module {module} not supported. type: {type(module)}") - module.pre_identity = Identity(d_in) # type: ignore + module.pre_identity = Identity(d_in) module.register_forward_pre_hook(pre_id_hook, with_kwargs=True) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 5c28e70ba..5df4ea239 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -97,7 +97,7 @@ def create_eval_data_loader(self) -> Iterator[Any]: "tms": self._create_tms_data_loader, "resid_mlp": self._create_resid_mlp_data_loader, "lm": self._create_lm_data_loader, - "induction_head": self._create_ih_data_loader, + "ih": self._create_ih_data_loader, } if task_name not in data_loader_fns: diff --git a/spd/scripts/run.py b/spd/scripts/run.py index 4b7acc67d..82a58c7a9 100644 --- a/spd/scripts/run.py +++ b/spd/scripts/run.py @@ -23,10 +23,10 @@ GPUS_PER_NODE, TrainingJob, create_slurm_array_script, - submit_slurm_array, ) from spd.utils.git_utils import create_git_snapshot from spd.utils.run_utils import apply_nested_updates, 
generate_grid_combinations, generate_run_name +from spd.utils.slurm import submit_slurm_job from spd.utils.wandb_utils import ReportCfg, create_view_and_report @@ -76,7 +76,7 @@ def launch_slurm_run( sweep_params=sweep_params, ) - snapshot_branch, commit_hash = create_git_snapshot(branch_name_prefix="run") + snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") _wandb_setup( @@ -91,48 +91,33 @@ def launch_slurm_run( slurm_job_name = f"spd-{job_suffix or get_max_expected_runtime(experiments_list)}" - slurm_logs_dir = Path.home() / "slurm_logs" - slurm_logs_dir.mkdir(exist_ok=True) - array_script_content = create_slurm_array_script( slurm_job_name=slurm_job_name, run_id=run_id, training_jobs=training_jobs, sweep_params=sweep_params, - slurm_logs_dir=slurm_logs_dir, snapshot_branch=snapshot_branch, n_gpus=n_gpus, partition=partition, max_concurrent_tasks=n_agents, ) - # Save script to permanent location for debugging - sbatch_scripts_dir = Path.home() / "sbatch_scripts" - sbatch_scripts_dir.mkdir(exist_ok=True) - - array_script_path = sbatch_scripts_dir / f"run_array_{run_id}.sh" - with open(array_script_path, "w") as f: - f.write(array_script_content) - array_script_path.chmod(0o755) - array_job_id = submit_slurm_array(array_script_path) - - # Rename script to include job ID for easier correlation with logs - final_script_path = sbatch_scripts_dir / f"{array_job_id}.sh" - array_script_path.rename(final_script_path) - - # Quality of life: create empty log files for each job so you can tail - # them before waiting for the job to start - for i in range(len(training_jobs)): - (slurm_logs_dir / f"slurm-{array_job_id}_{i + 1}.out").touch() + # Submit script (handles file writing, submission, renaming, and log file creation) + result = submit_slurm_job( + array_script_content, + f"run_array_{run_id}", + is_array=True, + n_array_tasks=len(training_jobs), + ) logger.section("Job 
submitted successfully!") logger.values( { - "Array Job ID": array_job_id, + "Array Job ID": result.job_id, "Total training jobs": len(training_jobs), "Max concurrent tasks": n_agents, - "View logs in": f"~/slurm_logs/slurm-{array_job_id}_*.out", - "Script": str(final_script_path), + "View logs in": result.log_pattern, + "Script": str(result.script_path), } ) diff --git a/spd/settings.py b/spd/settings.py index f59206c65..47ef8c921 100644 --- a/spd/settings.py +++ b/spd/settings.py @@ -9,6 +9,10 @@ default_cache_dir = str(Path.home() / "spd_cache") SPD_CACHE_DIR = Path(os.environ.get("SPD_CACHE_DIR", default_cache_dir)) +# SLURM directories +SLURM_LOGS_DIR = Path.home() / "slurm_logs" +SBATCH_SCRIPTS_DIR = Path.home() / "sbatch_scripts" + # this is the gpu-enabled partition on the cluster # Not sure why we call it "default" instead of "gpu" or "compute" but keeping the convention here for consistency DEFAULT_PARTITION_NAME = "h200-reserved" diff --git a/spd/spd_types.py b/spd/spd_types.py index 1dc5d3032..012dc554a 100644 --- a/spd/spd_types.py +++ b/spd/spd_types.py @@ -1,7 +1,8 @@ from pathlib import Path from typing import Annotated, Literal -from pydantic import BeforeValidator, Field, PlainSerializer +from annotated_types import Ge, Le +from pydantic import BeforeValidator, PlainSerializer from spd.settings import REPO_ROOT @@ -45,7 +46,7 @@ def validate_path(v: str | Path) -> str | Path: ] -Probability = Annotated[float, Field(strict=True, ge=0, le=1)] +Probability = Annotated[float, Ge(0), Le(1)] TaskName = Literal["tms", "resid_mlp", "lm", "ih"] diff --git a/spd/utils/compute_utils.py b/spd/utils/compute_utils.py index ab669e345..f08e14521 100644 --- a/spd/utils/compute_utils.py +++ b/spd/utils/compute_utils.py @@ -2,14 +2,13 @@ import json import shlex -import subprocess from dataclasses import dataclass from hashlib import sha256 from pathlib import Path from typing import Any from spd.configs import Config -from spd.settings import REPO_ROOT +from 
spd.utils.slurm import SlurmArrayConfig, generate_array_script CUDA_FLAGS = { "NCCL_DEBUG": "WARN", @@ -124,7 +123,6 @@ def create_slurm_array_script( run_id: str, training_jobs: list[TrainingJob], sweep_params: dict[str, Any] | None, - slurm_logs_dir: Path, snapshot_branch: str, n_gpus: int | None, partition: str, @@ -132,6 +130,9 @@ def create_slurm_array_script( ) -> str: """Create a SLURM job array script with git snapshot for consistent code. + This is a thin wrapper around slurm.generate_array_script that handles + TrainingJob -> command string conversion and multi-node DDP setup. + Args: slurm_job_name: Name for the SLURM job array run_id: Unique identifier for the run. @@ -143,110 +144,31 @@ def create_slurm_array_script( partition: SLURM partition to use. max_concurrent_tasks: Maximum number of array tasks to run concurrently. If None, no limit. """ - n_jobs = len(training_jobs) - - # Create array range (SLURM arrays are 1-indexed) - if max_concurrent_tasks is not None: - array_range = f"1-{n_jobs}%{max_concurrent_tasks}" - else: - array_range = f"1-{n_jobs}" - - # Create case statement for commands (SLURM is 1-indexed, but we pass 0-indexed to get_command) - case_block_lines = [] + # Convert TrainingJobs to command strings + commands: list[str] = [] for i, training_job in enumerate(training_jobs): - command = get_command(run_id, training_job, i, n_gpus, sweep_params) - case_block_lines.append(f"{i + 1})") - if command.env_vars is not None: - for k, v in command.env_vars.items(): - case_block_lines.append(f" export {k}={v}") - case_block_lines.append(f" {command.command}") - case_block_lines.append(" ;;") - case_block = "\n".join(case_block_lines) - - # Compute SLURM resource allocation + cmd = get_command(run_id, training_job, i, n_gpus, sweep_params) + commands.append(cmd.command) + + # Compute SLURM resource allocation for multi-node DDP if n_gpus is None or n_gpus == 1: n_nodes = 1 - gpus_per_task = 1 + gpus_per_node = 1 elif n_gpus <= GPUS_PER_NODE: 
n_nodes = 1 - gpus_per_task = n_gpus + gpus_per_node = n_gpus else: n_nodes = n_gpus // GPUS_PER_NODE - gpus_per_task = GPUS_PER_NODE - - script_content = f"""\ -#!/bin/bash -#SBATCH --nodes={n_nodes} -#SBATCH --ntasks={n_nodes} -#SBATCH --gres=gpu:{gpus_per_task} - -#SBATCH --partition={partition} -#SBATCH --time=72:00:00 -#SBATCH --job-name={slurm_job_name} -#SBATCH --output={slurm_logs_dir}/slurm-%A_%a.out -#SBATCH --array={array_range} - -# Create job-specific working directory on shared filesystem (for multi-node access) -WORK_DIR="$HOME/slurm_workspaces/{slurm_job_name}-${{SLURM_ARRAY_JOB_ID}}_${{SLURM_ARRAY_TASK_ID}}" -mkdir -p "$WORK_DIR" - -# Clean up the workspace when the script exits -trap 'rm -rf "$WORK_DIR"' EXIT - -# Clone the repository to the job-specific directory -git clone {REPO_ROOT} "$WORK_DIR" - -# Change to the cloned repository directory -cd "$WORK_DIR" - -# Copy the .env file from the original repository for WandB authentication -cp {REPO_ROOT}/.env .env - -# Checkout the snapshot branch to ensure consistent code -git checkout "{snapshot_branch}" - -# Ensure that dependencies are using the snapshot branch. SLURM might inherit the -# parent environment, so we need to deactivate and unset the virtual environment. -echo "Deactivating virtual environment" -deactivate 2>/dev/null || true -unset VIRTUAL_ENV - -# echo "Syncing dependencies" -uv sync --no-dev --link-mode copy -q - - -echo "Activating virtual environment" -source .venv/bin/activate - -echo "Debug: SLURM_NODEID=$SLURM_NODEID" -echo "Debug: SLURM_PROCID=$SLURM_PROCID" -echo "Debug: SLURM_JOB_NODELIST=$SLURM_JOB_NODELIST" -echo "Debug: Master node=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" - -echo "Running..." -# Execute the appropriate command based on array task ID -case $SLURM_ARRAY_TASK_ID in -{case_block} -esac -""" - - return script_content - - -def submit_slurm_array(script_path: Path) -> str: - """Submit a SLURM job array and return the array job ID. 
- - Args: - script_path: Path to SLURM batch script - - Returns: - Array job ID from submitted job array - """ - result = subprocess.run( - ["sbatch", str(script_path)], capture_output=True, text=True, check=False + gpus_per_node = GPUS_PER_NODE + + config = SlurmArrayConfig( + job_name=slurm_job_name, + partition=partition, + n_gpus=gpus_per_node, + n_nodes=n_nodes, + snapshot_branch=snapshot_branch, + max_concurrent_tasks=max_concurrent_tasks, ) - if result.returncode != 0: - raise RuntimeError(f"Failed to submit SLURM job array: {result.stderr}") - # Extract job ID from sbatch output (format: "Submitted batch job 12345") - job_id = result.stdout.strip().split()[-1] - return job_id + + # CUDA_FLAGS are always set for training jobs + return generate_array_script(config, commands, env=CUDA_FLAGS) diff --git a/spd/utils/git_utils.py b/spd/utils/git_utils.py index d21bf240e..b9c0cf370 100644 --- a/spd/utils/git_utils.py +++ b/spd/utils/git_utils.py @@ -1,6 +1,5 @@ """Git utilities for creating code snapshots.""" -import datetime import subprocess import tempfile from pathlib import Path @@ -30,7 +29,32 @@ def repo_current_branch() -> str: return result.stdout.strip() -def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: +def repo_is_clean(catch_except_as_false: bool = False) -> bool: + """Return True if the current git repository has no uncommitted or untracked changes. + + # TODO: this may error in CI environments: https://github.com/goodfire-ai/spd/actions/runs/18560369066/job/52907611203 + `fatal: detected dubious ownership in repository at '/__w/spd/spd'` + + for now, if `catch_except_as_false` is True, we catch any exceptions and return False. 
+ + """ + try: + status: str = subprocess.check_output(["git", "status", "--porcelain"], text=True).strip() + return status == "" + except Exception as e: + if catch_except_as_false: + return False + else: + raise e + + +def repo_current_commit_hash() -> str: + """Return the current commit hash of the active HEAD.""" + commit_hash: str = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + return commit_hash + + +def create_git_snapshot(run_id: str) -> tuple[str, str]: """Create a git snapshot branch with current changes. Creates a timestamped branch containing all current changes (staged and unstaged). Uses a @@ -44,13 +68,12 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: Raises: subprocess.CalledProcessError: If git commands fail (except for push) """ - # Generate timestamped branch name - timestamp_utc = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") - snapshot_branch = f"{branch_name_prefix}-{timestamp_utc}" + # prefix branch name + snapshot_branch: str = f"snapshot/{run_id}" # Create temporary worktree path with tempfile.TemporaryDirectory() as temp_dir: - worktree_path = Path(temp_dir) / f"spd-snapshot-{timestamp_utc}" + worktree_path = Path(temp_dir) / f"spd-snapshot-{run_id}" try: # Create worktree with new branch @@ -87,7 +110,7 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: # Commit changes if any exist if diff_result.returncode != 0: # Non-zero means there are changes subprocess.run( - ["git", "commit", "-m", f"Sweep snapshot {timestamp_utc}", "--no-verify"], + ["git", "commit", "-m", f"run id {run_id}", "--no-verify"], cwd=worktree_path, check=True, capture_output=True, diff --git a/spd/utils/run_utils.py b/spd/utils/run_utils.py index ac207a922..3adc18de7 100644 --- a/spd/utils/run_utils.py +++ b/spd/utils/run_utils.py @@ -3,16 +3,26 @@ import copy import itertools import json +import os import secrets import string +import subprocess +import tempfile from pathlib 
import Path -from typing import Any +from typing import Any, Final, Literal, NamedTuple import torch import wandb import yaml +from spd.log import logger from spd.settings import DEFAULT_PROJECT_NAME, SPD_CACHE_DIR +from spd.utils.git_utils import ( + create_git_snapshot, + repo_current_branch, + repo_current_commit_hash, + repo_is_clean, +) # Fields that use discriminated union merging: field_name -> discriminator_field _DISCRIMINATED_LIST_FIELDS: dict[str, str] = { @@ -37,6 +47,7 @@ def get_local_run_id() -> str: return f"local-{random_suffix}" +# TODO: avoid using this function? def get_output_dir(use_wandb_id: bool = True) -> Path: """Get the output directory for a run. @@ -443,3 +454,192 @@ def generate_run_name(params: dict[str, Any]) -> str: parts.append(f"{param}-{value}") return "-".join(parts) + + +RunType = Literal["spd", "cluster", "ensemble"] + +RUN_TYPE_ABBREVIATIONS: Final[dict[RunType, str]] = { + "spd": "s", + "cluster": "c", + "ensemble": "e", +} + + +class ExecutionStamp(NamedTuple): + run_id: str + snapshot_branch: str + commit_hash: str + run_type: RunType + + @staticmethod + def _generate_run_id(run_type: RunType) -> str: + """Generate a unique run identifier, + + Format: `{type_abbr}-{random_hex}` + """ + type_abbr: str = RUN_TYPE_ABBREVIATIONS[run_type] + random_hex: str = secrets.token_hex(4) + return f"{type_abbr}-{random_hex}" + + @classmethod + def create( + cls, + run_type: RunType, + create_snapshot: bool, + ) -> "ExecutionStamp": + """create an execution stamp, possibly including a git snapshot branch""" + + run_id: str = ExecutionStamp._generate_run_id(run_type) + snapshot_branch: str + commit_hash: str + + if create_snapshot: + snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) + logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") + else: + snapshot_branch = repo_current_branch() + if repo_is_clean(catch_except_as_false=True): + commit_hash = repo_current_commit_hash() + 
logger.info(f"Using current branch: {snapshot_branch} ({commit_hash[:8]})") + else: + commit_hash = "none" + logger.info( + f"Using current branch: {snapshot_branch} (unpushed changes, no commit hash)" + ) + + return ExecutionStamp( + run_id=run_id, + snapshot_branch=snapshot_branch, + commit_hash=commit_hash, + run_type=run_type, + ) + + @property + def out_dir(self) -> Path: + """Get the output directory for this execution stamp.""" + run_dir = SPD_CACHE_DIR / self.run_type / self.run_id + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +_NO_ARG_PARSSED_SENTINEL = object() + + +def read_noneable_str(value: str) -> str | None: + """Read a string that may be 'None' and convert to None.""" + if value == "None": + return None + return value + + +def run_locally( + commands: list[str], + parallel: bool = False, + track_resources: bool = False, +) -> dict[str, dict[str, float]] | None: + """Run commands locally instead of via SLURM. + + Useful for testing and for --local mode in clustering pipeline. + + Args: + commands: List of shell commands to run + parallel: If True, run all commands in parallel. If False, run sequentially. + track_resources: If True, track and return resource usage via /usr/bin/time + + Returns: + If track_resources is True, dict mapping commands to resource metrics. + Metrics include: K (avg memory KB), M (max memory KB), P (CPU %), + S (system CPU sec), U (user CPU sec), e (wall time sec). + Otherwise None. 
+ """ + n_commands = len(commands) + resources: dict[str, dict[str, float]] = {} + resource_files: list[Path] = [] + + # Wrap commands with /usr/bin/time if resource tracking is requested + if track_resources: + wrapped_commands: list[str] = [] + for cmd in commands: + # Create a unique temp file for resource tracking output + fd, resource_file_path = tempfile.mkstemp(suffix=".resources") + os.close(fd) # Close fd, we just need the path for /usr/bin/time -o + resource_file = Path(resource_file_path) + resource_files.append(resource_file) + # Use /usr/bin/time to track comprehensive resource usage + # K=avg total mem, M=max resident, P=CPU%, S=system time, U=user time, e=wall time + wrapped_cmd = ( + f'/usr/bin/time -f "K:%K M:%M P:%P S:%S U:%U e:%e" -o {resource_file} {cmd}' + ) + wrapped_commands.append(wrapped_cmd) + commands_to_run = wrapped_commands + else: + commands_to_run = commands + + try: + if not parallel: + logger.section(f"LOCAL EXECUTION: Running {n_commands} tasks serially") + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Running: {commands[i - 1]}") + subprocess.run(cmd, shell=True, check=True) + logger.section("LOCAL EXECUTION COMPLETE") + else: + logger.section(f"LOCAL EXECUTION: Starting {n_commands} tasks in parallel") + procs: list[subprocess.Popen[bytes]] = [] + + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Starting: {commands[i - 1]}") + proc = subprocess.Popen(cmd, shell=True) + procs.append(proc) + + logger.section("WAITING FOR ALL TASKS TO COMPLETE") + for proc, cmd in zip(procs, commands, strict=True): # noqa: B007 + proc.wait() + if proc.returncode != 0: + logger.error(f"Process {proc.pid} failed with exit code {proc.returncode}") + logger.section("LOCAL EXECUTION COMPLETE") + + # Read resource usage results + if track_resources: + for cmd, resource_file in zip(commands, resource_files, strict=True): + if resource_file.exists(): + # Parse format: "K:123 M:456 P:78% 
S:1.23 U:4.56 e:7.89" + output = resource_file.read_text().strip() + metrics: dict[str, float] = {} + + for part in output.split(): + if ":" in part: + key, value = part.split(":", 1) + # Remove % sign from CPU percentage + value = value.rstrip("%") + try: + metrics[key] = float(value) + except ValueError: + logger.warning(f"Could not parse {key}:{value} for command: {cmd}") + + resources[cmd] = metrics + else: + logger.warning(f"Resource file not found for: {cmd}") + + # Log comprehensive resource usage table + logger.section("RESOURCE USAGE RESULTS") + for cmd, metrics in resources.items(): + logger.info(f"Command: {cmd}") + logger.info( + f" Time: {metrics.get('e', 0):.2f}s wall, " + f"{metrics.get('U', 0):.2f}s user, " + f"{metrics.get('S', 0):.2f}s system" + ) + logger.info( + f" Memory: {metrics.get('M', 0) / 1024:.1f} MB peak, " + f"{metrics.get('K', 0) / 1024:.1f} MB avg" + ) + logger.info(f" CPU: {metrics.get('P', 0):.1f}%") + + finally: + # Clean up temp files + if track_resources: + for resource_file in resource_files: + if resource_file.exists(): + resource_file.unlink() + + return resources if track_resources else None diff --git a/spd/utils/slurm.py b/spd/utils/slurm.py new file mode 100644 index 000000000..16b9f9c30 --- /dev/null +++ b/spd/utils/slurm.py @@ -0,0 +1,355 @@ +"""Unified SLURM job submission utilities. + +This module provides a single source of truth for generating and submitting SLURM jobs. +It handles: +- SBATCH header generation +- Workspace creation with cleanup +- Git snapshot checkout (optional) +- Virtual environment activation +- Job submission with script renaming and log file creation + +For SPD-specific training jobs with multi-node DDP, see compute_utils.py which +uses this module internally. 
+""" + +import subprocess +import tempfile +from dataclasses import dataclass +from pathlib import Path + +from spd.settings import REPO_ROOT, SBATCH_SCRIPTS_DIR, SLURM_LOGS_DIR + + +@dataclass +class SlurmConfig: + """Configuration for a SLURM job. + + Attributes: + job_name: Name for the SLURM job (appears in squeue) + partition: SLURM partition to submit to + n_gpus: Number of GPUs per node (0 for CPU-only jobs) + n_nodes: Number of nodes (default 1) + n_tasks: Number of tasks (defaults to n_nodes if not set, for multi-node DDP) + time: Time limit in HH:MM:SS format + cpus_per_task: CPUs per task (for CPU-bound jobs like autointerp) + snapshot_branch: Git branch to checkout. If None, just cd to REPO_ROOT without cloning. + dependency_job_id: If set, job waits for this job to complete (afterok dependency) + """ + + job_name: str + partition: str + n_gpus: int = 1 + n_nodes: int = 1 + n_tasks: int | None = None + time: str = "72:00:00" + cpus_per_task: int | None = None + snapshot_branch: str | None = None + dependency_job_id: str | None = None + + +@dataclass +class SlurmArrayConfig(SlurmConfig): + """Configuration for a SLURM job array. + + Attributes: + max_concurrent_tasks: Maximum number of array tasks to run concurrently. + If None, no limit (all tasks can run at once). + """ + + max_concurrent_tasks: int | None = None + + +@dataclass +class SubmitResult: + """Result of submitting a SLURM job. + + Attributes: + job_id: The SLURM job ID (string, e.g., "12345") + script_path: Path where the script was saved (renamed to include job ID) + log_pattern: Human-readable log path pattern for display + """ + + job_id: str + script_path: Path + log_pattern: str + + +def generate_script(config: SlurmConfig, command: str, env: dict[str, str] | None = None) -> str: + """Generate a single SLURM job script. 
+ + Args: + config: SLURM job configuration + command: The shell command to run + env: Optional environment variables to export at the start of the script + + Returns: + Complete SLURM script content as a string + """ + header = _generate_sbatch_header(config, is_array=False) + setup = _generate_setup_section(config, is_array=False) + env_exports = _generate_env_exports(env) + + return f"""\ +#!/bin/bash +{header} + +set -euo pipefail +{env_exports} +{setup} + +{command} +""" + + +def generate_array_script( + config: SlurmArrayConfig, + commands: list[str], + env: dict[str, str] | None = None, +) -> str: + """Generate a SLURM job array script. + + Each command in the list becomes one array task. Commands are executed via + a case statement based on SLURM_ARRAY_TASK_ID. + + Args: + config: SLURM array job configuration + commands: List of shell commands, one per array task + env: Optional environment variables to export at the start of the script + + Returns: + Complete SLURM array script content as a string + + Raises: + ValueError: If commands list is empty + """ + if not commands: + raise ValueError("Cannot generate array script with empty commands list") + + n_jobs = len(commands) + + # Build array range (SLURM arrays are 1-indexed) + if config.max_concurrent_tasks is not None: + array_range = f"1-{n_jobs}%{config.max_concurrent_tasks}" + else: + array_range = f"1-{n_jobs}" + + header = _generate_sbatch_header(config, is_array=True, array_range=array_range) + setup = _generate_setup_section(config, is_array=True) + env_exports = _generate_env_exports(env) + case_block = _generate_case_block(commands) + + return f"""\ +#!/bin/bash +{header} + +set -euo pipefail +{env_exports} +{setup} + +# Execute the appropriate command based on array task ID +case $SLURM_ARRAY_TASK_ID in +{case_block} +esac +""" + + +def submit_slurm_job( + script_content: str, + script_name_prefix: str, + is_array: bool = False, + n_array_tasks: int | None = None, +) -> SubmitResult: + 
"""Write script to disk, submit to SLURM, and set up logging. + + This function: + 1. Writes script to SBATCH_SCRIPTS_DIR with a unique temporary name + 2. Submits via sbatch + 3. Renames script to include the SLURM job ID + 4. Creates empty log file(s) for tailing + + Args: + script_content: The SLURM script content + script_name_prefix: Prefix for script filename (e.g., "harvest", "clustering") + is_array: Whether this is an array job (affects log file creation) + n_array_tasks: Number of array tasks (required if is_array=True) + + Returns: + SubmitResult with job ID, script path, and log pattern + """ + SBATCH_SCRIPTS_DIR.mkdir(exist_ok=True) + SLURM_LOGS_DIR.mkdir(exist_ok=True) + + # Write script to a unique temporary file (safe for concurrent submissions) + with tempfile.NamedTemporaryFile( + mode="w", + dir=SBATCH_SCRIPTS_DIR, + prefix=f"{script_name_prefix}_", + suffix=".sh", + delete=False, + ) as f: + f.write(script_content) + temp_script_path = Path(f.name) + temp_script_path.chmod(0o755) + + # Submit via sbatch + job_id = _submit_script(temp_script_path) + + # Rename script to include job ID + final_script_path = SBATCH_SCRIPTS_DIR / f"{script_name_prefix}_{job_id}.sh" + temp_script_path.rename(final_script_path) + + # Create empty log file(s) for tailing + if is_array: + assert n_array_tasks is not None, "n_array_tasks required for array jobs" + for i in range(1, n_array_tasks + 1): + (SLURM_LOGS_DIR / f"slurm-{job_id}_{i}.out").touch() + log_pattern = str(SLURM_LOGS_DIR / f"slurm-{job_id}_*.out") + else: + (SLURM_LOGS_DIR / f"slurm-{job_id}.out").touch() + log_pattern = str(SLURM_LOGS_DIR / f"slurm-{job_id}.out") + + return SubmitResult( + job_id=job_id, + script_path=final_script_path, + log_pattern=log_pattern, + ) + + +# ============================================================================= +# Internal helpers +# ============================================================================= + + +def _generate_sbatch_header( + config: 
SlurmConfig, + is_array: bool = False, + array_range: str | None = None, +) -> str: + """Generate the #SBATCH directive block. + + Handles: + - --job-name, --partition, --nodes, --gres, --time, --output + - --ntasks (for multi-node DDP) + - --cpus-per-task (for CPU-bound jobs) + - --array (for array jobs) + - --dependency (if dependency_job_id is set) + """ + n_tasks = config.n_tasks if config.n_tasks is not None else config.n_nodes + + # Use %A_%a for array jobs, %j for single jobs + log_pattern = "%A_%a" if is_array else "%j" + + lines = [ + f"#SBATCH --job-name={config.job_name}", + f"#SBATCH --partition={config.partition}", + f"#SBATCH --nodes={config.n_nodes}", + f"#SBATCH --ntasks={n_tasks}", + f"#SBATCH --gres=gpu:{config.n_gpus}", + f"#SBATCH --time={config.time}", + f"#SBATCH --output={SLURM_LOGS_DIR}/slurm-{log_pattern}.out", + ] + + if config.cpus_per_task is not None: + lines.append(f"#SBATCH --cpus-per-task={config.cpus_per_task}") + + if is_array and array_range: + lines.append(f"#SBATCH --array={array_range}") + + if config.dependency_job_id: + lines.append(f"#SBATCH --dependency=afterok:{config.dependency_job_id}") + + return "\n".join(lines) + + +def _generate_setup_section(config: SlurmConfig, is_array: bool) -> str: + """Generate workspace creation and git/venv setup. 
+ + If snapshot_branch is set: + - Create workspace dir with trap for cleanup + - Clone repo to workspace + - Copy .env file + - Checkout snapshot branch + - uv sync and activate venv + + If snapshot_branch is None: + - Just cd to REPO_ROOT + - Activate existing venv + """ + # Workspace directory naming + if is_array: + workspace_suffix = "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}" + else: + workspace_suffix = "$SLURM_JOB_ID" + + if config.snapshot_branch is not None: + # Full git snapshot setup + return f"""\ +# Create job-specific working directory +WORK_DIR="$HOME/slurm_workspaces/{config.job_name}-{workspace_suffix}" +mkdir -p "$WORK_DIR" + +# Clean up the workspace when the script exits +trap 'rm -rf "$WORK_DIR"' EXIT + +# Clone the repository to the job-specific directory +git clone "{REPO_ROOT}" "$WORK_DIR" + +# Change to the cloned repository directory +cd "$WORK_DIR" + +# Copy the .env file from the original repository for WandB authentication (if it exists) +[ -f "{REPO_ROOT}/.env" ] && cp "{REPO_ROOT}/.env" .env + +# Checkout the snapshot branch to ensure consistent code +git checkout "{config.snapshot_branch}" + +# Ensure that dependencies are using the snapshot branch +deactivate 2>/dev/null || true +unset VIRTUAL_ENV +uv sync --no-dev --link-mode copy -q +source .venv/bin/activate""" + else: + # Simple setup without git clone + return f"""\ +cd "{REPO_ROOT}" +source .venv/bin/activate""" + + +def _generate_env_exports(env: dict[str, str] | None) -> str: + """Generate export statements for environment variables. + + Returns empty string if env is None or empty, otherwise returns + export statements with a leading newline for proper formatting. + """ + if not env: + return "" + exports = "\n".join(f"export {k}={v}" for k, v in env.items()) + return f"\n{exports}" + + +def _generate_case_block(commands: list[str]) -> str: + """Generate bash case statement for array jobs. + + SLURM arrays are 1-indexed, so command[0] goes in case 1). 
+ """ + lines = [] + for i, cmd in enumerate(commands): + lines.append(f" {i + 1})") + lines.append(f" {cmd}") + lines.append(" ;;") + return "\n".join(lines) + + +def _submit_script(script_path: Path) -> str: + """Submit script via sbatch and return job ID. + + Raises RuntimeError if sbatch fails. + """ + result = subprocess.run( + ["sbatch", str(script_path)], capture_output=True, text=True, check=False + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to submit SLURM job: {result.stderr}") + # Extract job ID from sbatch output (format: "Submitted batch job 12345") + job_id = result.stdout.strip().split()[-1] + return job_id diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 2edc86dd9..8c4639351 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -10,9 +10,9 @@ import wandb_workspaces.reports.v2 as wr import wandb_workspaces.workspaces as ws from dotenv import load_dotenv -from pydantic import BaseModel from wandb.apis.public import File, Run +from spd.base_config import BaseConfig from spd.log import logger from spd.registry import EXPERIMENT_REGISTRY from spd.settings import REPO_ROOT @@ -172,7 +172,7 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: return path -def init_wandb[T_config: BaseModel]( +def init_wandb[T_config: BaseConfig]( config: T_config, project: str, name: str | None = None, tags: list[str] | None = None ) -> T_config: """Initialize Weights & Biases and return a config updated with sweep hyperparameters. 
@@ -188,6 +188,7 @@ def init_wandb[T_config: BaseModel]( """ load_dotenv(override=True) + # TODO: pass run id from ExecutionStamp wandb.init( project=project, entity=os.getenv("WANDB_ENTITY"), diff --git a/tests/clustering/math/test_perm_invariant_hamming.py b/tests/clustering/math/test_perm_invariant_hamming.py new file mode 100644 index 000000000..7d2bf4740 --- /dev/null +++ b/tests/clustering/math/test_perm_invariant_hamming.py @@ -0,0 +1,123 @@ +from itertools import permutations + +import numpy as np +import pytest + +from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix + +# pyright complains about the types when calling perm_invariant_hamming +# pyright: reportCallIssue=false + + +def brute_force_min_hamming(a: np.ndarray, b: np.ndarray) -> int: + """Exhaustive check for small k.""" + k = int(max(a.max(), b.max()) + 1) + best = len(a) + for perm in permutations(range(k)): + mapping = np.array(perm) + best = min(best, int((mapping[a] != b).sum())) + return best + + +def test_identity() -> None: + """a == b should give distance 0.""" + a = np.array([0, 1, 2, 1, 0]) + b = a.copy() + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + # Distance between row 1 and row 0 should be 0 + assert D[1, 0] == 0 + + +def test_all_one_group() -> None: + """All rows belong to one group in both arrays (possibly different labels).""" + a = np.zeros(10, dtype=int) + b = np.ones(10, dtype=int) # different label but identical grouping + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def test_permuted_labels() -> None: + a = np.array([0, 2, 1, 1, 0]) + b = np.array([1, 0, 0, 2, 1]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 1 + + +def test_swap_two_labels() -> None: + a = np.array([0, 0, 1, 1]) + b = np.array([1, 1, 0, 0]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def test_random_small_bruteforce() -> None: + rng = 
np.random.default_rng(0) + for _ in range(50): + n = 7 + k = 3 + a = rng.integers(0, k, size=n) + b = rng.integers(0, k, size=n) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + d_alg = D[1, 0] + d_true = brute_force_min_hamming(a, b) + assert d_alg == d_true + + +def test_shape_mismatch() -> None: + a = np.array([0, 1, 2]) + b = np.array([0, 1]) + with pytest.raises((ValueError, IndexError)): + # This should fail when trying to create the matrix due to shape mismatch + X = np.array([a, b]) + perm_invariant_hamming_matrix(X) + + +def test_matrix_multiple_pairs() -> None: + """Test the matrix function with multiple label vectors.""" + a = np.array([0, 0, 1, 1]) + b = np.array([2, 2, 3, 3]) # Should be distance 0 (perfect mapping) + c = np.array([0, 1, 0, 1]) # Should be distance 2 from both a and b + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + assert D[1, 0] == 0 # a and b should have distance 0 + assert D[2, 0] == 2 # a and c should have distance 2 + assert D[2, 1] == 2 # b and c should have distance 2 + + +def test_matrix_upper_triangle_nan() -> None: + """Test that upper triangle and diagonal are NaN.""" + a = np.array([0, 1, 0]) + b = np.array([1, 0, 1]) + c = np.array([0, 0, 1]) + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + # Diagonal should be NaN + assert np.isnan(D[0, 0]) + assert np.isnan(D[1, 1]) + assert np.isnan(D[2, 2]) + + # Upper triangle should be NaN + assert np.isnan(D[0, 1]) + assert np.isnan(D[0, 2]) + assert np.isnan(D[1, 2]) + + # Lower triangle should have actual distances + assert not np.isnan(D[1, 0]) + assert not np.isnan(D[2, 0]) + assert not np.isnan(D[2, 1]) + + +def test_unused_labels() -> None: + """Test when arrays don't use all labels 0..k-1.""" + a = np.array([0, 0, 3, 3]) # skips 1, 2 + b = np.array([1, 1, 2, 2]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 diff --git a/tests/clustering/test_calc_distances.py 
b/tests/clustering/test_calc_distances.py new file mode 100644 index 000000000..b06350f4b --- /dev/null +++ b/tests/clustering/test_calc_distances.py @@ -0,0 +1,31 @@ +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble + + +def test_merge_history_normalization_happy_path(): + """Test that the normalization part of calc_distances.py works without errors""" + + # Create test merge histories + config = MergeConfig( + iters=3, + alpha=1.0, + activation_threshold=None, + ) + + histories = [] + for _idx in range(2): + history = MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels([f"comp{j}" for j in range(4)]), + ) + histories.append(history) + + # Test ensemble creation + ensemble = MergeHistoryEnsemble(data=histories) + assert len(ensemble.data) == 2 + + # Test normalization + normalized_array, metadata = ensemble.normalized() + assert normalized_array is not None + assert metadata is not None diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py new file mode 100644 index 000000000..c903af801 --- /dev/null +++ b/tests/clustering/test_ensemble_registry.py @@ -0,0 +1,110 @@ +"""Tests for ensemble_registry module.""" + +import tempfile +from pathlib import Path +from typing import Any + +import pytest + +from spd.clustering.ensemble_registry import ( + get_clustering_runs, + register_clustering_run, +) + + +@pytest.fixture +def _temp_registry_db(monkeypatch: Any): # pyright: ignore[reportUnusedFunction] + """Create a temporary registry database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + temp_db_path = Path(tmpdir) / "test_registry.db" + monkeypatch.setattr("spd.clustering.ensemble_registry._ENSEMBLE_REGISTRY_DB", temp_db_path) + yield temp_db_path + + +class TestRegisterClusteringRun: + """Test register_clustering_run() function.""" + + def 
test_register_single_run(self, _temp_registry_db: Any): + """Test registering a single run.""" + pipeline_id = "pipeline_001" + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, run_id) + + # First index should be 0 + assert assigned_idx == 0 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_register_multiple_runs(self, _temp_registry_db: Any): + """Test registering multiple runs sequentially.""" + pipeline_id = "pipeline_002" + + idx0 = register_clustering_run(pipeline_id, "run_001") + idx1 = register_clustering_run(pipeline_id, "run_002") + idx2 = register_clustering_run(pipeline_id, "run_003") + + # Should auto-assign 0, 1, 2 + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def test_different_pipelines_independent(self, _temp_registry_db: Any): + """Test that different pipelines have independent index sequences.""" + pipeline_a = "pipeline_a" + pipeline_b = "pipeline_b" + + # Both should start at 0 when auto-assigning + idx_a0 = register_clustering_run(pipeline_a, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, "run_b1") + + assert idx_a0 == 0 + assert idx_b0 == 0 + + # Both should increment independently + idx_a1 = register_clustering_run(pipeline_a, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, "run_b2") + + assert idx_a1 == 1 + assert idx_b1 == 1 + + # Verify in database + runs_a = get_clustering_runs(pipeline_a) + runs_b = get_clustering_runs(pipeline_b) + + assert runs_a == [(0, "run_a1"), (1, "run_a2")] + assert runs_b == [(0, "run_b1"), (1, "run_b2")] + + +class TestGetClusteringRuns: + """Test get_clustering_runs() function.""" + + def test_get_empty_pipeline(self, _temp_registry_db: Any): + """Test getting runs from a pipeline that doesn't exist.""" + runs = get_clustering_runs("nonexistent_pipeline") + 
assert runs == [] + + def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): + """Test that runs are returned sorted by index.""" + pipeline_id = "pipeline_sort" + + # Register runs (indices will be auto-assigned in order) + register_clustering_run(pipeline_id, "run_000") + register_clustering_run(pipeline_id, "run_001") + register_clustering_run(pipeline_id, "run_002") + register_clustering_run(pipeline_id, "run_003") + + # Should be returned in sorted order + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_000"), + (1, "run_001"), + (2, "run_002"), + (3, "run_003"), + ] diff --git a/tests/clustering/test_filter_dead_components.py b/tests/clustering/test_filter_dead_components.py new file mode 100644 index 000000000..654631f37 --- /dev/null +++ b/tests/clustering/test_filter_dead_components.py @@ -0,0 +1,131 @@ +"""Tests for filter_dead_components function in activations.py""" + +import pytest +import torch +from torch import Tensor + +from spd.clustering.activations import FilteredActivations, filter_dead_components +from spd.clustering.consts import ComponentLabels + + +@pytest.mark.parametrize( + "max_values,threshold,expected_alive_indices", + [ + # No filtering when threshold is 0 + ([0.1, 0.2, 0.3], 0.0, [0, 1, 2]), + # Filter all when all below threshold + ([0.005, 0.003, 0.004], 0.01, []), + # Filter some components + ([0.0, 0.02, 0.0, 0.03, 0.0], 0.01, [1, 3]), + # Boundary cases: at threshold is kept + ([0.009, 0.01, 0.011], 0.01, [1, 2]), + # High threshold filters everything + ([0.1, 0.2, 0.3], 2.0, []), + # Negative threshold filters nothing + ([0.1, 0.2, 0.3], -0.01, [0, 1, 2]), + # Single component above threshold + ([0.5], 0.01, [0]), + ], +) +def test_filter_dead_components_thresholds( + max_values: list[float], + threshold: float, + expected_alive_indices: list[int], +) -> None: + """Test filtering with various max values and thresholds.""" + n_steps: int = 10 + n_components: int = len(max_values) + + activations: 
Tensor + labels: ComponentLabels + if n_components == 0: + activations = torch.zeros(n_steps, 0) + labels = ComponentLabels([]) + else: + activations = torch.zeros(n_steps, n_components) + # Set max values in first row + for i, val in enumerate(max_values): + activations[0, i] = val + labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=labels, filter_dead_threshold=threshold + ) + + assert result.labels == [f"comp_{i}" for i in expected_alive_indices] + assert result.n_alive == len(expected_alive_indices) + assert result.n_dead == n_components - len(expected_alive_indices) + assert result.activations.shape == (n_steps, len(expected_alive_indices)) + + # Check dead components labels + if threshold <= 0 or all(v >= threshold for v in max_values): + # No filtering occurred + assert result.dead_components_labels is None or result.dead_components_labels == [] + else: + dead_indices: list[int] = [ + i for i in range(n_components) if i not in expected_alive_indices + ] + expected_dead: list[str] = [f"comp_{i}" for i in dead_indices] + assert result.dead_components_labels is not None + assert set(result.dead_components_labels) == set(expected_dead) + + +@pytest.mark.parametrize( + "step_locations,threshold", + [ + # Max at different steps + ([0, 5, 9], 0.01), + # All at same step + ([0, 0, 0], 0.01), + # Random steps + ([3, 7, 1, 8], 0.05), + ], +) +def test_max_across_steps(step_locations: list[int], threshold: float) -> None: + """Verify that filter_dead_components correctly finds the maximum activation + across ALL time steps for each component, not just looking at a single step. 
+ + This test creates components where the maximum activation occurs at different + time steps, ensuring the function scans the entire temporal dimension.""" + n_steps: int = 10 + n_components: int = len(step_locations) + activations: Tensor = torch.zeros(n_steps, n_components) + + # Set values above threshold at specified steps + for i, step in enumerate(step_locations): + activations[step, i] = threshold + 0.01 + + labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=labels, filter_dead_threshold=threshold + ) + + # All components should be alive since their max is above threshold + assert result.n_alive == n_components + assert result.n_dead == 0 + assert result.labels == labels + + +@pytest.mark.parametrize("threshold", [0.001, 0.01, 0.1, 0.5]) +def test_linear_gradient_thresholds(threshold: float) -> None: + """Test with linearly spaced activation values.""" + n_steps: int = 10 + n_components: int = 10 + activations: Tensor = torch.zeros(n_steps, n_components) + + # Create linearly spaced max values: 0, 0.1, 0.2, ..., 0.9 + for i in range(n_components): + activations[0, i] = i * 0.1 + + labels: list[str] = [f"comp_{i}" for i in range(n_components)] + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=ComponentLabels(labels), filter_dead_threshold=threshold + ) + + # Count how many components should be alive + expected_alive: int = sum(i * 0.1 >= threshold for i in range(n_components)) + + assert result.n_alive == expected_alive + assert result.n_dead == n_components - expected_alive diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py new file mode 100644 index 000000000..63f4e88f7 --- /dev/null +++ b/tests/clustering/test_merge_config.py @@ -0,0 +1,179 @@ +"""Tests for MergeConfig with new sampling system.""" + +import pytest +import torch + +from 
spd.clustering.merge_config import MergeConfig + + +class TestMergeConfigSampling: + """Test MergeConfig integration with sampling system.""" + + def test_default_config(self): + """Test default MergeConfig uses range sampler.""" + config = MergeConfig() + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.05} + + def test_range_sampler_config(self): + """Test MergeConfig with range sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.1} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + + def test_mcmc_sampler_config(self): + """Test MergeConfig with MCMC sampler.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0} + ) + + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 2.0} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + + def test_invalid_sampler_method(self): + """Test that invalid sampler method raises error.""" + from pydantic import ValidationError + + # Pydantic validates at construction time + with pytest.raises(ValidationError): + _config = MergeConfig(merge_pair_sampling_method="invalid") # pyright: ignore[reportArgumentType] + + def test_config_with_all_parameters(self): + """Test MergeConfig with all parameters set.""" + config = MergeConfig( + 
activation_threshold=0.01, + alpha=1.5, + iters=200, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.5}, + filter_dead_threshold=0.001, + module_name_filter="model.layers", + ) + + assert config.activation_threshold == 0.01 + assert config.alpha == 1.5 + assert config.iters == 200 + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} + assert config.filter_dead_threshold == 0.001 + assert config.module_name_filter == "model.layers" + + def test_config_serialization(self): + """Test that config can be serialized and deserialized.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.5} + ) + + # Serialize to dict + config_dict = config.model_dump() + assert config_dict["merge_pair_sampling_method"] == "mcmc" + assert config_dict["merge_pair_sampling_kwargs"] == {"temperature": 1.5} + + # Deserialize from dict + config2 = MergeConfig(**config_dict) + assert config2.merge_pair_sampling_method == "mcmc" + assert config2.merge_pair_sampling_kwargs == {"temperature": 1.5} + + def test_config_json_serialization(self): + """Test JSON serialization of config.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + + # Serialize to JSON string + json_str = config.model_dump_json() + assert "range" in json_str + assert "0.2" in json_str + + # Parse back from JSON + import json + + config_dict = json.loads(json_str) + config2 = MergeConfig(**config_dict) + + assert config2.merge_pair_sampling_method == "range" + assert config2.merge_pair_sampling_kwargs == {"threshold": 0.2} + + def test_stable_hash_changes_with_sampling_params(self): + """Test that stable_hash changes when sampling parameters change.""" + config1 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + config2 = MergeConfig( + 
merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + config3 = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0} + ) + + # Different configs should have different hashes + assert config1.stable_hash != config2.stable_hash + assert config1.stable_hash != config3.stable_hash + assert config2.stable_hash != config3.stable_hash + + # Same config should have same hash + config4 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + assert config1.stable_hash == config4.stable_hash + + def test_empty_kwargs(self): + """Test that empty kwargs dict works.""" + config = MergeConfig(merge_pair_sampling_method="range", merge_pair_sampling_kwargs={}) + + # Should work with default parameters of the sampler + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Range sampler has default threshold=0.05 + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] + + def test_extra_kwargs_filtered(self): + """Test that only valid kwargs are used by sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.3} + ) + + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Should work with config's method + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py new file mode 100644 index 000000000..8492300de --- /dev/null +++ b/tests/clustering/test_merge_integration.py @@ -0,0 +1,151 @@ +"""Integration tests for the merge system with new samplers.""" + +import torch + +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge import merge_iteration +from 
spd.clustering.merge_config import MergeConfig + + +class TestMergeIntegration: + """Test the full merge iteration with different samplers.""" + + def test_merge_with_range_sampler(self): + """Test merge iteration with range sampler.""" + # Create test data + n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with range sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # Check results + assert history is not None + assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # After iterations, should have fewer groups (merges reduce count) + # Exact count depends on early stopping conditions + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 # Should stop before going below 2 + + def test_merge_with_mcmc_sampler(self): + """Test merge iteration with MCMC sampler.""" + # Create test data + n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with MCMC sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 1.0}, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # Check results + assert history is not None + 
assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # Should have fewer groups after iterations + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 + + def test_merge_comparison_samplers(self): + """Compare behavior of different samplers with same data.""" + # Create test data with clear structure + n_samples = 100 + n_components = 8 + activations = torch.rand(n_samples, n_components) + + # Make some components more active to create cost structure + activations[:, 0] *= 2 # Component 0 is very active + activations[:, 1] *= 0.1 # Component 1 is rarely active + + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Run with range sampler (threshold=0 for deterministic minimum selection) + config_range = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum + ) + + history_range = merge_iteration( + activations=activations.clone(), + merge_config=config_range, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Run with MCMC sampler (low temperature for near-deterministic) + config_mcmc = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp + ) + + history_mcmc = merge_iteration( + activations=activations.clone(), + merge_config=config_mcmc, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Both should reduce groups from initial count + assert history_range.merges.k_groups[-1].item() < n_components + assert history_mcmc.merges.k_groups[-1].item() < n_components + assert history_range.merges.k_groups[-1].item() >= 2 + assert history_mcmc.merges.k_groups[-1].item() >= 2 + + def 
test_merge_with_small_components(self): + """Test merge with very few components.""" + # Edge case: only 3 components + n_samples = 50 + n_components = 3 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=1, # Just one merge + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 2.0}, + ) + + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # First entry is after first merge, so should be 3 - 1 = 2 + assert history.merges.k_groups[0].item() == 2 + # Early stopping may occur at 2 groups, so final count could be 2 or 3 + assert history.merges.k_groups[-1].item() >= 2 + assert history.merges.k_groups[-1].item() <= 3 diff --git a/tests/clustering/test_merge_pair_samplers.py b/tests/clustering/test_merge_pair_samplers.py new file mode 100644 index 000000000..66c59cb66 --- /dev/null +++ b/tests/clustering/test_merge_pair_samplers.py @@ -0,0 +1,257 @@ +"""Tests for merge pair sampling functionality.""" + +import pytest +import torch + +from spd.clustering.math.merge_pair_samplers import ( + MERGE_PAIR_SAMPLERS, + mcmc_sampler, + range_sampler, +) + + +class TestRangeSampler: + """Test range-based merge pair sampling.""" + + def test_range_sampler_basic(self): + """Test basic functionality of range sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 # Make symmetric + costs.fill_diagonal_(float("inf")) # No self-merges + + # Test with different thresholds + pair_low = range_sampler(costs, threshold=0.0) + pair_mid = range_sampler(costs, threshold=0.5) + pair_high = range_sampler(costs, threshold=1.0) + + # All should return valid pairs + assert pair_low[0] != pair_low[1] + assert pair_mid[0] != pair_mid[1] + assert pair_high[0] != pair_high[1] + + # All indices should be in valid range + for pair 
in [pair_low, pair_mid, pair_high]: + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + def test_range_sampler_threshold_zero(self): + """Test that threshold=0 always selects minimum cost pair.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find the true minimum + min_val = float("inf") + _min_pair = None + for i in range(k): + for j in range(k): + if i != j and costs[i, j] < min_val: + min_val = costs[i, j].item() + _min_pair = (i, j) + + # Sample multiple times with threshold=0 + for _ in range(10): + pair = range_sampler(costs, threshold=0.0) + # Should always get the minimum (or its symmetric equivalent) + assert costs[pair[0], pair[1]] == min_val or costs[pair[1], pair[0]] == min_val + + def test_range_sampler_threshold_one(self): + """Test that threshold=1 can select any non-diagonal pair.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times to check we get different pairs + pairs_seen = set() + for _ in range(100): + pair = range_sampler(costs, threshold=1.0) + # Normalize pair order for comparison + normalized = tuple(sorted(pair)) + pairs_seen.add(normalized) + + # With threshold=1, we should see multiple different pairs + assert len(pairs_seen) > 1 + + def test_range_sampler_small_matrix(self): + """Test range sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 1.0], [1.0, float("inf")]]) + + pair = range_sampler(costs, threshold=0.5) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + +class TestMCMCSampler: + """Test MCMC-based merge pair sampling.""" + + def test_mcmc_sampler_basic(self): + """Test basic functionality of MCMC sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test with different temperatures + pair_low_temp = mcmc_sampler(costs, temperature=0.1) + pair_mid_temp = mcmc_sampler(costs, 
temperature=1.0) + pair_high_temp = mcmc_sampler(costs, temperature=10.0) + + # All should return valid pairs + for pair in [pair_low_temp, pair_mid_temp, pair_high_temp]: + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + def test_mcmc_sampler_low_temperature(self): + """Test that low temperature favors low-cost pairs.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find minimum cost + min_val = float("inf") + for i in range(k): + for j in range(k): + if i != j: + min_val = min(min_val, costs[i, j].item()) + + # Sample many times with very low temperature + low_cost_count = 0 + n_samples = 100 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=0.01) + cost = costs[pair[0], pair[1]].item() + # Check if it's close to minimum + if abs(cost - min_val) < 0.5: # Within 0.5 of minimum + low_cost_count += 1 + + # Most samples should be near minimum with low temperature + assert low_cost_count > n_samples * 0.7 + + def test_mcmc_sampler_high_temperature(self): + """Test that high temperature gives more uniform sampling.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times with high temperature + pairs_count = {} + n_samples = 1000 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=100.0) + # Normalize pair order for counting + normalized = tuple(sorted(pair)) + pairs_count[normalized] = pairs_count.get(normalized, 0) + 1 + + # With high temperature, distribution should be relatively uniform + # There are k*(k-1)/2 unique pairs + expected_count = n_samples / (k * (k - 1) / 2) + for count in pairs_count.values(): + # Each pair count should be within reasonable range of expected + assert expected_count * 0.3 < count < expected_count * 1.7 + + def test_mcmc_sampler_small_matrix(self): + """Test MCMC sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 
1.0], [1.0, float("inf")]]) + + pair = mcmc_sampler(costs, temperature=1.0) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + def test_mcmc_sampler_extreme_costs(self): + """Test MCMC sampler with extreme cost differences.""" + k = 3 + # Create matrix with one very low cost and rest high + costs = torch.full((k, k), 1000.0) + costs[0, 1] = costs[1, 0] = 1.0 # One low-cost pair + costs.fill_diagonal_(float("inf")) + + # With low temperature, should almost always select the low-cost pair + low_cost_selected = 0 + for _ in range(100): + pair = mcmc_sampler(costs, temperature=0.1) + if set(pair) == {0, 1}: + low_cost_selected += 1 + + assert low_cost_selected > 95 # Should almost always select (0,1) + + +class TestSamplerRegistry: + """Test the sampler registry.""" + + def test_registry_contains_samplers(self): + """Test that registry contains expected samplers.""" + assert "range" in MERGE_PAIR_SAMPLERS + assert "mcmc" in MERGE_PAIR_SAMPLERS + assert MERGE_PAIR_SAMPLERS["range"] is range_sampler + assert MERGE_PAIR_SAMPLERS["mcmc"] is mcmc_sampler + + def test_registry_samplers_callable(self): + """Test that all registry samplers are callable with correct signature.""" + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + for name, sampler in MERGE_PAIR_SAMPLERS.items(): + # Should be callable + assert callable(sampler) + + # Test with default kwargs + if name == "range": + pair = sampler(costs, threshold=0.5) + elif name == "mcmc": + pair = sampler(costs, temperature=1.0) + else: + pytest.fail(f"Unknown sampler {name}") + + # Should return valid pair + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + +class TestSamplerIntegration: + """Integration tests for samplers with edge cases.""" + + def test_samplers_deterministic_with_seed(self): + """Test that samplers are deterministic with fixed seed.""" + k 
= 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test range sampler + torch.manual_seed(42) + pair1 = range_sampler(costs, threshold=0.5) + torch.manual_seed(42) + pair2 = range_sampler(costs, threshold=0.5) + # Can't guarantee exact match due to Python's random module + # but both should be valid + assert pair1[0] != pair1[1] + assert pair2[0] != pair2[1] + + # Test MCMC sampler + torch.manual_seed(42) + pair1 = mcmc_sampler(costs, temperature=1.0) + torch.manual_seed(42) + pair2 = mcmc_sampler(costs, temperature=1.0) + assert pair1 == pair2 # Should be deterministic with same seed + + def test_samplers_all_infinite_costs(self): + """Test samplers handle all-infinite costs gracefully.""" + k = 3 + costs = torch.full((k, k), float("inf")) + + # This is an edge case - no valid pairs exist + # Samplers should handle this without crashing + # (though the result may not be meaningful) + with pytest.raises((ValueError, RuntimeError, IndexError)): + range_sampler(costs, threshold=0.5) diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py new file mode 100644 index 000000000..ca6bad6ee --- /dev/null +++ b/tests/clustering/test_pipeline_config.py @@ -0,0 +1,137 @@ +"""Tests for ClusteringPipelineConfig and ClusteringRunConfig with inline config support.""" + +from pathlib import Path + +import pydantic_core +import pytest + +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.merge_config import MergeConfig +from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig +from spd.settings import REPO_ROOT + + +class TestClusteringRunConfigStableHash: + """Test ClusteringRunConfig.stable_hash_b64() method.""" + + def test_stable_hash_b64(self): + """Test that stable_hash_b64 is deterministic, unique, and URL-safe.""" + # Create 4 configs: 2 identical, 2 different + config1 = ClusteringRunConfig( + 
model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config3 = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different model_path + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config4 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig( + activation_threshold=0.2 + ), # Different merge_config to test nested fields + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + hash3 = config3.stable_hash_b64() + hash4 = config4.stable_hash_b64() + + # Identical configs produce identical hashes + assert hash1 == hash2 + + # Different configs produce different hashes + assert hash1 != hash3 + assert hash1 != hash4 + assert hash3 != hash4 + + # Hashes are strings + assert isinstance(hash1, str) + assert len(hash1) > 0 + + # Hashes are URL-safe base64 (no padding, URL-safe chars only) + assert "=" not in hash1 + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in hash1) + + +class TestClusteringPipelineConfigValidation: + """Test ClusteringPipelineConfig validation logic.""" + + def test_error_when_path_does_not_exist(self): + """Test that error is raised when clustering_run_config_path does not exist.""" + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + ClusteringPipelineConfig( + clustering_run_config_path=Path("nonexistent/path.json"), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + def test_valid_config_with_existing_path(self): + """Test that config is valid when path points to existing 
file.""" + expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") + + config = ClusteringPipelineConfig( + clustering_run_config_path=expected_path, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.clustering_run_config_path == expected_path + + +def _get_config_files(path: Path): + """Helper to get all config files.""" + pipeline_config_files = ( + list(path.glob("*.yaml")) + list(path.glob("*.yml")) + list(path.glob("*.json")) + ) + assert len(pipeline_config_files) > 0, f"No pipeline files found in {path}" + return pipeline_config_files + + +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline(self, config_file: Path): + """Test that each pipeline config file is valid.""" + print(config_file) + _config = ClusteringPipelineConfig.from_file(config_file) + crc_path = _config.clustering_run_config_path + print(f"{crc_path = }") + assert crc_path.exists() + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline_clustering_run(self, config_file: Path): + """Test that each clustering run config file is valid.""" + print(config_file) + _config = ClusteringRunConfig.from_file(config_file) + assert isinstance(_config, ClusteringRunConfig) diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py new file mode 100644 index 000000000..5e2cbbd1c --- /dev/null +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -0,0 +1,38 @@ +import tempfile +from pathlib import Path + +import pytest + +from 
spd.clustering.clustering_run_config import ClusteringRunConfig, LoggingIntervals +from spd.clustering.merge_config import MergeConfig +from spd.clustering.scripts.run_clustering import main + + +@pytest.mark.slow +def test_run_clustering_happy_path(): + """Test that run_clustering.py runs without errors.""" + with tempfile.TemporaryDirectory() as temp_dir: + config = ClusteringRunConfig( + model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run + batch_size=4, + dataset_seed=0, + base_output_dir=Path(temp_dir), + ensemble_id=None, + merge_config=MergeConfig( + activation_threshold=0.01, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.05}, + ), + wandb_project=None, + wandb_entity="goodfire", + logging_intervals=LoggingIntervals( + stat=1, + tensor=100, + plot=100, + artifact=100, + ), + dataset_streaming=True, # tests in CI very slow without this, see https://github.com/goodfire-ai/spd/pull/199 + ) + main(config) diff --git a/tests/scripts_run/test_main.py b/tests/scripts_run/test_main.py index 485869394..8dc520a22 100644 --- a/tests/scripts_run/test_main.py +++ b/tests/scripts_run/test_main.py @@ -25,7 +25,7 @@ def test_invalid_experiment_name(self): with pytest.raises(ValueError, match=f"Invalid experiments.*{fake_exp_name}"): _get_experiments(f"{fake_exp_name},tms_5-2") - @patch("spd.scripts.run.submit_slurm_array") + @patch("spd.scripts.run.submit_slurm_job") @patch("spd.scripts.run.create_slurm_array_script") @patch("spd.scripts.run.create_git_snapshot") @patch("spd.scripts.run._wandb_setup") @@ -34,14 +34,21 @@ def test_sweep_creates_slurm_array( mock_wandb_setup, mock_create_git_snapshot, mock_create_slurm_array_script, - mock_submit_slurm_array, + mock_submit_slurm_job, ): """Test that sweep runs create SLURM array jobs with sweep params.""" + from pathlib import Path + from spd.scripts.run_cli import main + from spd.utils.slurm import SubmitResult mock_create_git_snapshot.return_value = 
("test-branch", "12345678") mock_create_slurm_array_script.return_value = "#!/bin/bash\necho test" - mock_submit_slurm_array.return_value = "12345" + mock_submit_slurm_job.return_value = SubmitResult( + job_id="12345", + script_path=Path("/tmp/test.sh"), + log_pattern="~/slurm_logs/slurm-12345_*.out", + ) main( experiments="tms_5-2", diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 10171c7f2..94c1b0644 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -96,7 +96,7 @@ def test_ih_transformer_decomposition_happy_path() -> None: tokenizer_name=None, # Task Specific task_config=IHTaskConfig( - task_name="induction_head", + task_name="ih", ), ) diff --git a/uv.lock b/uv.lock index c5ff0c0af..e2ae3dd31 100644 --- a/uv.lock +++ b/uv.lock @@ -157,11 +157,11 @@ wheels = [ [[package]] name = "cachetools" -version = "6.2.2" +version = "6.2.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fb/44/ca1675be2a83aeee1886ab745b28cda92093066590233cc501890eb8417a/cachetools-6.2.2.tar.gz", hash = "sha256:8e6d266b25e539df852251cfd6f990b4bc3a141db73b939058d809ebd2590fc6", size = 31571, upload-time = "2025-11-13T17:42:51.465Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/1d/ede8680603f6016887c062a2cf4fc8fdba905866a3ab8831aa8aa651320c/cachetools-6.2.4.tar.gz", hash = "sha256:82c5c05585e70b6ba2d3ae09ea60b79548872185d2f24ae1f2709d37299fd607", size = 31731, upload-time = "2025-12-15T18:24:53.744Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/46/eb6eca305c77a4489affe1c5d8f4cae82f285d9addd8de4ec084a7184221/cachetools-6.2.2-py3-none-any.whl", hash = "sha256:6c09c98183bf58560c97b2abfcedcbaf6a896a490f534b031b661d3723b45ace", size = 11503, upload-time = "2025-11-13T17:42:50.232Z" }, + { url = "https://files.pythonhosted.org/packages/2c/fc/1d7b80d0eb7b714984ce40efc78859c022cd930e402f599d8ca9e39c78a4/cachetools-6.2.4-py3-none-any.whl", hash = 
"sha256:69a7a52634fed8b8bf6e24a050fb60bff1c9bd8f6d24572b99c32d4e71e62a51", size = 11551, upload-time = "2025-12-15T18:24:52.332Z" }, ] [[package]] @@ -295,37 +295,37 @@ wheels = [ [[package]] name = "coverage" -version = "7.12.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/26/4a96807b193b011588099c3b5c89fbb05294e5b90e71018e065465f34eb6/coverage-7.12.0.tar.gz", hash = "sha256:fc11e0a4e372cb5f282f16ef90d4a585034050ccda536451901abfb19a57f40c", size = 819341, upload-time = "2025-11-18T13:34:20.766Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/14/771700b4048774e48d2c54ed0c674273702713c9ee7acdfede40c2666747/coverage-7.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47324fffca8d8eae7e185b5bb20c14645f23350f870c1649003618ea91a78941", size = 217725, upload-time = "2025-11-18T13:32:49.22Z" }, - { url = "https://files.pythonhosted.org/packages/17/a7/3aa4144d3bcb719bf67b22d2d51c2d577bf801498c13cb08f64173e80497/coverage-7.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ccf3b2ede91decd2fb53ec73c1f949c3e034129d1e0b07798ff1d02ea0c8fa4a", size = 218098, upload-time = "2025-11-18T13:32:50.78Z" }, - { url = "https://files.pythonhosted.org/packages/fc/9c/b846bbc774ff81091a12a10203e70562c91ae71badda00c5ae5b613527b1/coverage-7.12.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b365adc70a6936c6b0582dc38746b33b2454148c02349345412c6e743efb646d", size = 249093, upload-time = "2025-11-18T13:32:52.554Z" }, - { url = "https://files.pythonhosted.org/packages/76/b6/67d7c0e1f400b32c883e9342de4a8c2ae7c1a0b57c5de87622b7262e2309/coverage-7.12.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bc13baf85cd8a4cfcf4a35c7bc9d795837ad809775f782f697bf630b7e200211", size = 251686, upload-time = "2025-11-18T13:32:54.862Z" }, - { url = 
"https://files.pythonhosted.org/packages/cc/75/b095bd4b39d49c3be4bffbb3135fea18a99a431c52dd7513637c0762fecb/coverage-7.12.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:099d11698385d572ceafb3288a5b80fe1fc58bf665b3f9d362389de488361d3d", size = 252930, upload-time = "2025-11-18T13:32:56.417Z" }, - { url = "https://files.pythonhosted.org/packages/6e/f3/466f63015c7c80550bead3093aacabf5380c1220a2a93c35d374cae8f762/coverage-7.12.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:473dc45d69694069adb7680c405fb1e81f60b2aff42c81e2f2c3feaf544d878c", size = 249296, upload-time = "2025-11-18T13:32:58.074Z" }, - { url = "https://files.pythonhosted.org/packages/27/86/eba2209bf2b7e28c68698fc13437519a295b2d228ba9e0ec91673e09fa92/coverage-7.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:583f9adbefd278e9de33c33d6846aa8f5d164fa49b47144180a0e037f0688bb9", size = 251068, upload-time = "2025-11-18T13:32:59.646Z" }, - { url = "https://files.pythonhosted.org/packages/ec/55/ca8ae7dbba962a3351f18940b359b94c6bafdd7757945fdc79ec9e452dc7/coverage-7.12.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2089cc445f2dc0af6f801f0d1355c025b76c24481935303cf1af28f636688f0", size = 249034, upload-time = "2025-11-18T13:33:01.481Z" }, - { url = "https://files.pythonhosted.org/packages/7a/d7/39136149325cad92d420b023b5fd900dabdd1c3a0d1d5f148ef4a8cedef5/coverage-7.12.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:950411f1eb5d579999c5f66c62a40961f126fc71e5e14419f004471957b51508", size = 248853, upload-time = "2025-11-18T13:33:02.935Z" }, - { url = "https://files.pythonhosted.org/packages/fe/b6/76e1add8b87ef60e00643b0b7f8f7bb73d4bf5249a3be19ebefc5793dd25/coverage-7.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b1aab7302a87bafebfe76b12af681b56ff446dc6f32ed178ff9c092ca776e6bc", size = 250619, upload-time = "2025-11-18T13:33:04.336Z" }, - { url = 
"https://files.pythonhosted.org/packages/95/87/924c6dc64f9203f7a3c1832a6a0eee5a8335dbe5f1bdadcc278d6f1b4d74/coverage-7.12.0-cp313-cp313-win32.whl", hash = "sha256:d7e0d0303c13b54db495eb636bc2465b2fb8475d4c8bcec8fe4b5ca454dfbae8", size = 220261, upload-time = "2025-11-18T13:33:06.493Z" }, - { url = "https://files.pythonhosted.org/packages/91/77/dd4aff9af16ff776bf355a24d87eeb48fc6acde54c907cc1ea89b14a8804/coverage-7.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:ce61969812d6a98a981d147d9ac583a36ac7db7766f2e64a9d4d059c2fe29d07", size = 221072, upload-time = "2025-11-18T13:33:07.926Z" }, - { url = "https://files.pythonhosted.org/packages/70/49/5c9dc46205fef31b1b226a6e16513193715290584317fd4df91cdaf28b22/coverage-7.12.0-cp313-cp313-win_arm64.whl", hash = "sha256:bcec6f47e4cb8a4c2dc91ce507f6eefc6a1b10f58df32cdc61dff65455031dfc", size = 219702, upload-time = "2025-11-18T13:33:09.631Z" }, - { url = "https://files.pythonhosted.org/packages/9b/62/f87922641c7198667994dd472a91e1d9b829c95d6c29529ceb52132436ad/coverage-7.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:459443346509476170d553035e4a3eed7b860f4fe5242f02de1010501956ce87", size = 218420, upload-time = "2025-11-18T13:33:11.153Z" }, - { url = "https://files.pythonhosted.org/packages/85/dd/1cc13b2395ef15dbb27d7370a2509b4aee77890a464fb35d72d428f84871/coverage-7.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:04a79245ab2b7a61688958f7a855275997134bc84f4a03bc240cf64ff132abf6", size = 218773, upload-time = "2025-11-18T13:33:12.569Z" }, - { url = "https://files.pythonhosted.org/packages/74/40/35773cc4bb1e9d4658d4fb669eb4195b3151bef3bbd6f866aba5cd5dac82/coverage-7.12.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:09a86acaaa8455f13d6a99221d9654df249b33937b4e212b4e5a822065f12aa7", size = 260078, upload-time = "2025-11-18T13:33:14.037Z" }, - { url = 
"https://files.pythonhosted.org/packages/ec/ee/231bb1a6ffc2905e396557585ebc6bdc559e7c66708376d245a1f1d330fc/coverage-7.12.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:907e0df1b71ba77463687a74149c6122c3f6aac56c2510a5d906b2f368208560", size = 262144, upload-time = "2025-11-18T13:33:15.601Z" }, - { url = "https://files.pythonhosted.org/packages/28/be/32f4aa9f3bf0b56f3971001b56508352c7753915345d45fab4296a986f01/coverage-7.12.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9b57e2d0ddd5f0582bae5437c04ee71c46cd908e7bc5d4d0391f9a41e812dd12", size = 264574, upload-time = "2025-11-18T13:33:17.354Z" }, - { url = "https://files.pythonhosted.org/packages/68/7c/00489fcbc2245d13ab12189b977e0cf06ff3351cb98bc6beba8bd68c5902/coverage-7.12.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:58c1c6aa677f3a1411fe6fb28ec3a942e4f665df036a3608816e0847fad23296", size = 259298, upload-time = "2025-11-18T13:33:18.958Z" }, - { url = "https://files.pythonhosted.org/packages/96/b4/f0760d65d56c3bea95b449e02570d4abd2549dc784bf39a2d4721a2d8ceb/coverage-7.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4c589361263ab2953e3c4cd2a94db94c4ad4a8e572776ecfbad2389c626e4507", size = 262150, upload-time = "2025-11-18T13:33:20.644Z" }, - { url = "https://files.pythonhosted.org/packages/c5/71/9a9314df00f9326d78c1e5a910f520d599205907432d90d1c1b7a97aa4b1/coverage-7.12.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:91b810a163ccad2e43b1faa11d70d3cf4b6f3d83f9fd5f2df82a32d47b648e0d", size = 259763, upload-time = "2025-11-18T13:33:22.189Z" }, - { url = "https://files.pythonhosted.org/packages/10/34/01a0aceed13fbdf925876b9a15d50862eb8845454301fe3cdd1df08b2182/coverage-7.12.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:40c867af715f22592e0d0fb533a33a71ec9e0f73a6945f722a0c85c8c1cbe3a2", size = 258653, upload-time = "2025-11-18T13:33:24.239Z" }, - { url 
= "https://files.pythonhosted.org/packages/8d/04/81d8fd64928acf1574bbb0181f66901c6c1c6279c8ccf5f84259d2c68ae9/coverage-7.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:68b0d0a2d84f333de875666259dadf28cc67858bc8fd8b3f1eae84d3c2bec455", size = 260856, upload-time = "2025-11-18T13:33:26.365Z" }, - { url = "https://files.pythonhosted.org/packages/f2/76/fa2a37bfaeaf1f766a2d2360a25a5297d4fb567098112f6517475eee120b/coverage-7.12.0-cp313-cp313t-win32.whl", hash = "sha256:73f9e7fbd51a221818fd11b7090eaa835a353ddd59c236c57b2199486b116c6d", size = 220936, upload-time = "2025-11-18T13:33:28.165Z" }, - { url = "https://files.pythonhosted.org/packages/f9/52/60f64d932d555102611c366afb0eb434b34266b1d9266fc2fe18ab641c47/coverage-7.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:24cff9d1f5743f67db7ba46ff284018a6e9aeb649b67aa1e70c396aa1b7cb23c", size = 222001, upload-time = "2025-11-18T13:33:29.656Z" }, - { url = "https://files.pythonhosted.org/packages/77/df/c303164154a5a3aea7472bf323b7c857fed93b26618ed9fc5c2955566bb0/coverage-7.12.0-cp313-cp313t-win_arm64.whl", hash = "sha256:c87395744f5c77c866d0f5a43d97cc39e17c7f1cb0115e54a2fe67ca75c5d14d", size = 220273, upload-time = "2025-11-18T13:33:31.415Z" }, - { url = "https://files.pythonhosted.org/packages/ce/a3/43b749004e3c09452e39bb56347a008f0a0668aad37324a99b5c8ca91d9e/coverage-7.12.0-py3-none-any.whl", hash = "sha256:159d50c0b12e060b15ed3d39f87ed43d4f7f7ad40b8a534f4dd331adbb51104a", size = 209503, upload-time = "2025-11-18T13:34:18.892Z" }, +version = "7.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/45/2c665ca77ec32ad67e25c77daf1cee28ee4558f3bc571cdbaf88a00b9f23/coverage-7.13.0.tar.gz", hash = "sha256:a394aa27f2d7ff9bc04cf703817773a59ad6dfbd577032e690f961d2460ee936", size = 820905, upload-time = "2025-12-08T13:14:38.055Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/7c/cc/bce226595eb3bf7d13ccffe154c3c487a22222d87ff018525ab4dd2e9542/coverage-7.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:28ee1c96109974af104028a8ef57cec21447d42d0e937c0275329272e370ebcf", size = 218297, upload-time = "2025-12-08T13:13:10.977Z" }, + { url = "https://files.pythonhosted.org/packages/3b/9f/73c4d34600aae03447dff3d7ad1d0ac649856bfb87d1ca7d681cfc913f9e/coverage-7.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d1e97353dcc5587b85986cda4ff3ec98081d7e84dd95e8b2a6d59820f0545f8a", size = 218673, upload-time = "2025-12-08T13:13:12.562Z" }, + { url = "https://files.pythonhosted.org/packages/63/ab/8fa097db361a1e8586535ae5073559e6229596b3489ec3ef2f5b38df8cb2/coverage-7.13.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:99acd4dfdfeb58e1937629eb1ab6ab0899b131f183ee5f23e0b5da5cba2fec74", size = 249652, upload-time = "2025-12-08T13:13:13.909Z" }, + { url = "https://files.pythonhosted.org/packages/90/3a/9bfd4de2ff191feb37ef9465855ca56a6f2f30a3bca172e474130731ac3d/coverage-7.13.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ff45e0cd8451e293b63ced93161e189780baf444119391b3e7d25315060368a6", size = 252251, upload-time = "2025-12-08T13:13:15.553Z" }, + { url = "https://files.pythonhosted.org/packages/df/61/b5d8105f016e1b5874af0d7c67542da780ccd4a5f2244a433d3e20ceb1ad/coverage-7.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4f72a85316d8e13234cafe0a9f81b40418ad7a082792fa4165bd7d45d96066b", size = 253492, upload-time = "2025-12-08T13:13:16.849Z" }, + { url = "https://files.pythonhosted.org/packages/f3/b8/0fad449981803cc47a4694768b99823fb23632150743f9c83af329bb6090/coverage-7.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:11c21557d0e0a5a38632cbbaca5f008723b26a89d70db6315523df6df77d6232", size = 249850, upload-time = 
"2025-12-08T13:13:18.142Z" }, + { url = "https://files.pythonhosted.org/packages/9a/e9/8d68337c3125014d918cf4327d5257553a710a2995a6a6de2ac77e5aa429/coverage-7.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76541dc8d53715fb4f7a3a06b34b0dc6846e3c69bc6204c55653a85dd6220971", size = 251633, upload-time = "2025-12-08T13:13:19.56Z" }, + { url = "https://files.pythonhosted.org/packages/55/14/d4112ab26b3a1bc4b3c1295d8452dcf399ed25be4cf649002fb3e64b2d93/coverage-7.13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6e9e451dee940a86789134b6b0ffbe31c454ade3b849bb8a9d2cca2541a8e91d", size = 249586, upload-time = "2025-12-08T13:13:20.883Z" }, + { url = "https://files.pythonhosted.org/packages/2c/a9/22b0000186db663b0d82f86c2f1028099ae9ac202491685051e2a11a5218/coverage-7.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:5c67dace46f361125e6b9cace8fe0b729ed8479f47e70c89b838d319375c8137", size = 249412, upload-time = "2025-12-08T13:13:22.22Z" }, + { url = "https://files.pythonhosted.org/packages/a1/2e/42d8e0d9e7527fba439acdc6ed24a2b97613b1dc85849b1dd935c2cffef0/coverage-7.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f59883c643cb19630500f57016f76cfdcd6845ca8c5b5ea1f6e17f74c8e5f511", size = 251191, upload-time = "2025-12-08T13:13:23.899Z" }, + { url = "https://files.pythonhosted.org/packages/a4/af/8c7af92b1377fd8860536aadd58745119252aaaa71a5213e5a8e8007a9f5/coverage-7.13.0-cp313-cp313-win32.whl", hash = "sha256:58632b187be6f0be500f553be41e277712baa278147ecb7559983c6d9faf7ae1", size = 220829, upload-time = "2025-12-08T13:13:25.182Z" }, + { url = "https://files.pythonhosted.org/packages/58/f9/725e8bf16f343d33cbe076c75dc8370262e194ff10072c0608b8e5cf33a3/coverage-7.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:73419b89f812f498aca53f757dd834919b48ce4799f9d5cad33ca0ae442bdb1a", size = 221640, upload-time = "2025-12-08T13:13:26.836Z" }, + { url = 
"https://files.pythonhosted.org/packages/8a/ff/e98311000aa6933cc79274e2b6b94a2fe0fe3434fca778eba82003675496/coverage-7.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:eb76670874fdd6091eedcc856128ee48c41a9bbbb9c3f1c7c3cf169290e3ffd6", size = 220269, upload-time = "2025-12-08T13:13:28.116Z" }, + { url = "https://files.pythonhosted.org/packages/cf/cf/bbaa2e1275b300343ea865f7d424cc0a2e2a1df6925a070b2b2d5d765330/coverage-7.13.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6e63ccc6e0ad8986386461c3c4b737540f20426e7ec932f42e030320896c311a", size = 218990, upload-time = "2025-12-08T13:13:29.463Z" }, + { url = "https://files.pythonhosted.org/packages/21/1d/82f0b3323b3d149d7672e7744c116e9c170f4957e0c42572f0366dbb4477/coverage-7.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:494f5459ffa1bd45e18558cd98710c36c0b8fbfa82a5eabcbe671d80ecffbfe8", size = 219340, upload-time = "2025-12-08T13:13:31.524Z" }, + { url = "https://files.pythonhosted.org/packages/fb/e3/fe3fd4702a3832a255f4d43013eacb0ef5fc155a5960ea9269d8696db28b/coverage-7.13.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:06cac81bf10f74034e055e903f5f946e3e26fc51c09fc9f584e4a1605d977053", size = 260638, upload-time = "2025-12-08T13:13:32.965Z" }, + { url = "https://files.pythonhosted.org/packages/ad/01/63186cb000307f2b4da463f72af9b85d380236965574c78e7e27680a2593/coverage-7.13.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f2ffc92b46ed6e6760f1d47a71e56b5664781bc68986dbd1836b2b70c0ce2071", size = 262705, upload-time = "2025-12-08T13:13:34.378Z" }, + { url = "https://files.pythonhosted.org/packages/7c/a1/c0dacef0cc865f2455d59eed3548573ce47ed603205ffd0735d1d78b5906/coverage-7.13.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0602f701057c6823e5db1b74530ce85f17c3c5be5c85fc042ac939cbd909426e", size = 265125, upload-time = "2025-12-08T13:13:35.73Z" }, + { url = 
"https://files.pythonhosted.org/packages/ef/92/82b99223628b61300bd382c205795533bed021505eab6dd86e11fb5d7925/coverage-7.13.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:25dc33618d45456ccb1d37bce44bc78cf269909aa14c4db2e03d63146a8a1493", size = 259844, upload-time = "2025-12-08T13:13:37.69Z" }, + { url = "https://files.pythonhosted.org/packages/cf/2c/89b0291ae4e6cd59ef042708e1c438e2290f8c31959a20055d8768349ee2/coverage-7.13.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:71936a8b3b977ddd0b694c28c6a34f4fff2e9dd201969a4ff5d5fc7742d614b0", size = 262700, upload-time = "2025-12-08T13:13:39.525Z" }, + { url = "https://files.pythonhosted.org/packages/bf/f9/a5f992efae1996245e796bae34ceb942b05db275e4b34222a9a40b9fbd3b/coverage-7.13.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:936bc20503ce24770c71938d1369461f0c5320830800933bc3956e2a4ded930e", size = 260321, upload-time = "2025-12-08T13:13:41.172Z" }, + { url = "https://files.pythonhosted.org/packages/4c/89/a29f5d98c64fedbe32e2ac3c227fbf78edc01cc7572eee17d61024d89889/coverage-7.13.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:af0a583efaacc52ae2521f8d7910aff65cdb093091d76291ac5820d5e947fc1c", size = 259222, upload-time = "2025-12-08T13:13:43.282Z" }, + { url = "https://files.pythonhosted.org/packages/b3/c3/940fe447aae302a6701ee51e53af7e08b86ff6eed7631e5740c157ee22b9/coverage-7.13.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f1c23e24a7000da892a312fb17e33c5f94f8b001de44b7cf8ba2e36fbd15859e", size = 261411, upload-time = "2025-12-08T13:13:44.72Z" }, + { url = "https://files.pythonhosted.org/packages/eb/31/12a4aec689cb942a89129587860ed4d0fd522d5fda81237147fde554b8ae/coverage-7.13.0-cp313-cp313t-win32.whl", hash = "sha256:5f8a0297355e652001015e93be345ee54393e45dc3050af4a0475c5a2b767d46", size = 221505, upload-time = "2025-12-08T13:13:46.332Z" }, + { url = 
"https://files.pythonhosted.org/packages/65/8c/3b5fe3259d863572d2b0827642c50c3855d26b3aefe80bdc9eba1f0af3b0/coverage-7.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6abb3a4c52f05e08460bd9acf04fec027f8718ecaa0d09c40ffbc3fbd70ecc39", size = 222569, upload-time = "2025-12-08T13:13:47.79Z" }, + { url = "https://files.pythonhosted.org/packages/b0/39/f71fa8316a96ac72fc3908839df651e8eccee650001a17f2c78cdb355624/coverage-7.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:3ad968d1e3aa6ce5be295ab5fe3ae1bf5bb4769d0f98a80a0252d543a2ef2e9e", size = 220841, upload-time = "2025-12-08T13:13:49.243Z" }, + { url = "https://files.pythonhosted.org/packages/8d/4c/1968f32fb9a2604645827e11ff84a31e59d532e01995f904723b4f5328b3/coverage-7.13.0-py3-none-any.whl", hash = "sha256:850d2998f380b1e266459ca5b47bc9e7daf9af1d070f66317972f382d46f1904", size = 210068, upload-time = "2025-12-08T13:14:36.236Z" }, ] [[package]] @@ -339,7 +339,7 @@ wheels = [ [[package]] name = "datasets" -version = "4.4.1" +version = "4.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dill" }, @@ -357,22 +357,22 @@ dependencies = [ { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/93/bf/0dae295d6d1ba0b1a200a9dd216838464b5bbd05da01407cb1330b377445/datasets-4.4.1.tar.gz", hash = "sha256:80322699aa8c0bbbdb7caa87906da689c3c2e29523cff698775c67f28fdab1fc", size = 585341, upload-time = "2025-11-05T16:00:38.162Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/54/9359803da96bc65439a28fbb014dc2c90b7d4d8034a93b72362b0d40191f/datasets-4.4.2.tar.gz", hash = "sha256:9de16e415c4ba4713eac0493f7c7dc74f3aa21599297f00cc6ddab409cb7b24b", size = 586474, upload-time = "2025-12-19T15:03:09.129Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/5e/6f8d874366788ad5d549e9ba258037d974dda6e004843be1bda794571701/datasets-4.4.1-py3-none-any.whl", hash = "sha256:c1163de5211e42546079ab355cc0250c7e6db16eb209ac5ac6252f801f596c44", size 
= 511591, upload-time = "2025-11-05T16:00:36.365Z" }, + { url = "https://files.pythonhosted.org/packages/7b/b5/fefa518c809de7bced5cddb7c21c010da66fa2ae494bda96844a280cc6ce/datasets-4.4.2-py3-none-any.whl", hash = "sha256:6f5ef3417504d9cd663c71c1b90b9a494ff4c2076a2cd6a6e40ceee6ad95befc", size = 512268, upload-time = "2025-12-19T15:03:07.087Z" }, ] [[package]] name = "debugpy" -version = "1.8.17" +version = "1.8.19" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/15/ad/71e708ff4ca377c4230530d6a7aa7992592648c122a2cd2b321cf8b35a76/debugpy-1.8.17.tar.gz", hash = "sha256:fd723b47a8c08892b1a16b2c6239a8b96637c62a59b94bb5dab4bac592a58a8e", size = 1644129, upload-time = "2025-09-17T16:33:20.633Z" } +sdist = { url = "https://files.pythonhosted.org/packages/73/75/9e12d4d42349b817cd545b89247696c67917aab907012ae5b64bbfea3199/debugpy-1.8.19.tar.gz", hash = "sha256:eea7e5987445ab0b5ed258093722d5ecb8bb72217c5c9b1e21f64efe23ddebdb", size = 1644590, upload-time = "2025-12-15T21:53:28.044Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/50/76/597e5cb97d026274ba297af8d89138dfd9e695767ba0e0895edb20963f40/debugpy-1.8.17-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:857c1dd5d70042502aef1c6d1c2801211f3ea7e56f75e9c335f434afb403e464", size = 2538386, upload-time = "2025-09-17T16:33:54.594Z" }, - { url = "https://files.pythonhosted.org/packages/5f/60/ce5c34fcdfec493701f9d1532dba95b21b2f6394147234dce21160bd923f/debugpy-1.8.17-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:3bea3b0b12f3946e098cce9b43c3c46e317b567f79570c3f43f0b96d00788088", size = 4292100, upload-time = "2025-09-17T16:33:56.353Z" }, - { url = "https://files.pythonhosted.org/packages/e8/95/7873cf2146577ef71d2a20bf553f12df865922a6f87b9e8ee1df04f01785/debugpy-1.8.17-cp313-cp313-win32.whl", hash = "sha256:e34ee844c2f17b18556b5bbe59e1e2ff4e86a00282d2a46edab73fd7f18f4a83", size = 5277002, upload-time = "2025-09-17T16:33:58.231Z" }, - { url = 
"https://files.pythonhosted.org/packages/46/11/18c79a1cee5ff539a94ec4aa290c1c069a5580fd5cfd2fb2e282f8e905da/debugpy-1.8.17-cp313-cp313-win_amd64.whl", hash = "sha256:6c5cd6f009ad4fca8e33e5238210dc1e5f42db07d4b6ab21ac7ffa904a196420", size = 5319047, upload-time = "2025-09-17T16:34:00.586Z" }, - { url = "https://files.pythonhosted.org/packages/b0/d0/89247ec250369fc76db477720a26b2fce7ba079ff1380e4ab4529d2fe233/debugpy-1.8.17-py2.py3-none-any.whl", hash = "sha256:60c7dca6571efe660ccb7a9508d73ca14b8796c4ed484c2002abba714226cfef", size = 5283210, upload-time = "2025-09-17T16:34:25.835Z" }, + { url = "https://files.pythonhosted.org/packages/71/3d/388035a31a59c26f1ecc8d86af607d0c42e20ef80074147cd07b180c4349/debugpy-1.8.19-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:91e35db2672a0abaf325f4868fcac9c1674a0d9ad9bb8a8c849c03a5ebba3e6d", size = 2538859, upload-time = "2025-12-15T21:53:50.478Z" }, + { url = "https://files.pythonhosted.org/packages/4a/19/c93a0772d0962294f083dbdb113af1a7427bb632d36e5314297068f55db7/debugpy-1.8.19-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:85016a73ab84dea1c1f1dcd88ec692993bcbe4532d1b49ecb5f3c688ae50c606", size = 4292575, upload-time = "2025-12-15T21:53:51.821Z" }, + { url = "https://files.pythonhosted.org/packages/5c/56/09e48ab796b0a77e3d7dc250f95251832b8bf6838c9632f6100c98bdf426/debugpy-1.8.19-cp313-cp313-win32.whl", hash = "sha256:b605f17e89ba0ecee994391194285fada89cee111cfcd29d6f2ee11cbdc40976", size = 5286209, upload-time = "2025-12-15T21:53:53.602Z" }, + { url = "https://files.pythonhosted.org/packages/fb/4e/931480b9552c7d0feebe40c73725dd7703dcc578ba9efc14fe0e6d31cfd1/debugpy-1.8.19-cp313-cp313-win_amd64.whl", hash = "sha256:c30639998a9f9cd9699b4b621942c0179a6527f083c72351f95c6ab1728d5b73", size = 5328206, upload-time = "2025-12-15T21:53:55.433Z" }, + { url = "https://files.pythonhosted.org/packages/25/3e/e27078370414ef35fafad2c06d182110073daaeb5d3bf734b0b1eeefe452/debugpy-1.8.19-py2.py3-none-any.whl", hash = 
"sha256:360ffd231a780abbc414ba0f005dad409e71c78637efe8f2bd75837132a41d38", size = 5292321, upload-time = "2025-12-15T21:54:16.024Z" }, ] [[package]] @@ -431,7 +431,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.123.10" +version = "0.127.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -439,18 +439,18 @@ dependencies = [ { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/ff/e01087de891010089f1620c916c0c13130f3898177955c13e2b02d22ec4a/fastapi-0.123.10.tar.gz", hash = "sha256:624d384d7cda7c096449c889fc776a0571948ba14c3c929fa8e9a78cd0b0a6a8", size = 356360, upload-time = "2025-12-05T21:27:46.237Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/02/2cbbecf6551e0c1a06f9b9765eb8f7ae126362fbba43babbb11b0e3b7db3/fastapi-0.127.0.tar.gz", hash = "sha256:5a9246e03dcd1fdb19f1396db30894867c1d630f5107dc167dcbc5ed1ea7d259", size = 369269, upload-time = "2025-12-21T16:47:16.393Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/f0/7cb92c4a720def85240fd63fbbcf147ce19e7a731c8e1032376bb5a486ac/fastapi-0.123.10-py3-none-any.whl", hash = "sha256:0503b7b7bc71bc98f7c90c9117d21fdf6147c0d74703011b87936becc86985c1", size = 111774, upload-time = "2025-12-05T21:27:44.78Z" }, + { url = "https://files.pythonhosted.org/packages/8a/fa/6a27e2ef789eb03060abb43b952a7f0bd39e6feaa3805362b48785bcedc5/fastapi-0.127.0-py3-none-any.whl", hash = "sha256:725aa2bb904e2eff8031557cf4b9b77459bfedd63cae8427634744fd199f6a49", size = 112055, upload-time = "2025-12-21T16:47:14.757Z" }, ] [[package]] name = "filelock" -version = "3.20.0" +version = "3.20.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, 
upload-time = "2025-10-08T18:03:50.056Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/23/ce7a1126827cedeb958fc043d61745754464eb56c5937c35bbf2b8e26f34/filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c", size = 19476, upload-time = "2025-12-15T23:54:28.027Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = "sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7f/a1a97644e39e7316d850784c642093c99df1290a460df4ede27659056834/filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a", size = 16666, upload-time = "2025-12-15T23:54:26.874Z" }, ] [[package]] @@ -484,19 +484,19 @@ wheels = [ [[package]] name = "fonttools" -version = "4.61.0" +version = "4.61.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/33/f9/0e84d593c0e12244150280a630999835a64f2852276161b62a0f98318de0/fonttools-4.61.0.tar.gz", hash = "sha256:ec520a1f0c7758d7a858a00f090c1745f6cde6a7c5e76fb70ea4044a15f712e7", size = 3561884, upload-time = "2025-11-28T17:05:49.491Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/ca/cf17b88a8df95691275a3d77dc0a5ad9907f328ae53acbe6795da1b2f5ed/fonttools-4.61.1.tar.gz", hash = "sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69", size = 3565756, upload-time = "2025-12-12T17:31:24.246Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/45/334f0d7f181e5473cfb757e1b60f4e60e7fc64f28d406e5d364a952718c0/fonttools-4.61.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba774b8cbd8754f54b8eb58124e8bd45f736b2743325ab1a5229698942b9b433", size = 2841801, upload-time = 
"2025-11-28T17:05:01.621Z" }, - { url = "https://files.pythonhosted.org/packages/cc/63/97b9c78e1f79bc741d4efe6e51f13872d8edb2b36e1b9fb2bab0d4491bb7/fonttools-4.61.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c84b430616ed73ce46e9cafd0bf0800e366a3e02fb7e1ad7c1e214dbe3862b1f", size = 2379024, upload-time = "2025-11-28T17:05:03.668Z" }, - { url = "https://files.pythonhosted.org/packages/4e/80/c87bc524a90dbeb2a390eea23eae448286983da59b7e02c67fa0ca96a8c5/fonttools-4.61.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b2b734d8391afe3c682320840c8191de9bd24e7eb85768dd4dc06ed1b63dbb1b", size = 4923706, upload-time = "2025-11-28T17:05:05.494Z" }, - { url = "https://files.pythonhosted.org/packages/6d/f6/a3b0374811a1de8c3f9207ec88f61ad1bb96f938ed89babae26c065c2e46/fonttools-4.61.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a5c5fff72bf31b0e558ed085e4fd7ed96eb85881404ecc39ed2a779e7cf724eb", size = 4979751, upload-time = "2025-11-28T17:05:07.665Z" }, - { url = "https://files.pythonhosted.org/packages/a5/3b/30f63b4308b449091573285f9d27619563a84f399946bca3eadc9554afbe/fonttools-4.61.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:14a290c5c93fcab76b7f451e6a4b7721b712d90b3b5ed6908f1abcf794e90d6d", size = 4921113, upload-time = "2025-11-28T17:05:09.551Z" }, - { url = "https://files.pythonhosted.org/packages/41/6c/58e6e9b7d9d8bf2d7010bd7bb493060b39b02a12d1cda64a8bfb116ce760/fonttools-4.61.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:13e3e20a5463bfeb77b3557d04b30bd6a96a6bb5c15c7b2e7908903e69d437a0", size = 5063183, upload-time = "2025-11-28T17:05:11.677Z" }, - { url = "https://files.pythonhosted.org/packages/3f/e3/52c790ab2b07492df059947a1fd7778e105aac5848c0473029a4d20481a2/fonttools-4.61.0-cp313-cp313-win32.whl", hash = "sha256:6781e7a4bb010be1cd69a29927b0305c86b843395f2613bdabe115f7d6ea7f34", size = 2263159, upload-time = 
"2025-11-28T17:05:13.292Z" }, - { url = "https://files.pythonhosted.org/packages/e9/1f/116013b200fbeba871046554d5d2a45fefa69a05c40e9cdfd0d4fff53edc/fonttools-4.61.0-cp313-cp313-win_amd64.whl", hash = "sha256:c53b47834ae41e8e4829171cc44fec0fdf125545a15f6da41776b926b9645a9a", size = 2313530, upload-time = "2025-11-28T17:05:14.848Z" }, - { url = "https://files.pythonhosted.org/packages/0c/14/634f7daea5ffe6a5f7a0322ba8e1a0e23c9257b80aa91458107896d1dfc7/fonttools-4.61.0-py3-none-any.whl", hash = "sha256:276f14c560e6f98d24ef7f5f44438e55ff5a67f78fa85236b218462c9f5d0635", size = 1144485, upload-time = "2025-11-28T17:05:47.573Z" }, + { url = "https://files.pythonhosted.org/packages/4b/cf/00ba28b0990982530addb8dc3e9e6f2fa9cb5c20df2abdda7baa755e8fe1/fonttools-4.61.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c", size = 2846454, upload-time = "2025-12-12T17:30:24.938Z" }, + { url = "https://files.pythonhosted.org/packages/5a/ca/468c9a8446a2103ae645d14fee3f610567b7042aba85031c1c65e3ef7471/fonttools-4.61.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e", size = 2398191, upload-time = "2025-12-12T17:30:27.343Z" }, + { url = "https://files.pythonhosted.org/packages/a3/4b/d67eedaed19def5967fade3297fed8161b25ba94699efc124b14fb68cdbc/fonttools-4.61.1-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5", size = 4928410, upload-time = "2025-12-12T17:30:29.771Z" }, + { url = "https://files.pythonhosted.org/packages/b0/8d/6fb3494dfe61a46258cd93d979cf4725ded4eb46c2a4ca35e4490d84daea/fonttools-4.61.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd", size = 4984460, upload-time = 
"2025-12-12T17:30:32.073Z" }, + { url = "https://files.pythonhosted.org/packages/f7/f1/a47f1d30b3dc00d75e7af762652d4cbc3dff5c2697a0dbd5203c81afd9c3/fonttools-4.61.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3", size = 4925800, upload-time = "2025-12-12T17:30:34.339Z" }, + { url = "https://files.pythonhosted.org/packages/a7/01/e6ae64a0981076e8a66906fab01539799546181e32a37a0257b77e4aa88b/fonttools-4.61.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d", size = 5067859, upload-time = "2025-12-12T17:30:36.593Z" }, + { url = "https://files.pythonhosted.org/packages/73/aa/28e40b8d6809a9b5075350a86779163f074d2b617c15d22343fce81918db/fonttools-4.61.1-cp313-cp313-win32.whl", hash = "sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c", size = 2267821, upload-time = "2025-12-12T17:30:38.478Z" }, + { url = "https://files.pythonhosted.org/packages/1a/59/453c06d1d83dc0951b69ef692d6b9f1846680342927df54e9a1ca91c6f90/fonttools-4.61.1-cp313-cp313-win_amd64.whl", hash = "sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b", size = 2318169, upload-time = "2025-12-12T17:30:40.951Z" }, + { url = "https://files.pythonhosted.org/packages/c7/4e/ce75a57ff3aebf6fc1f4e9d508b8e5810618a33d900ad6c19eb30b290b97/fonttools-4.61.1-py3-none-any.whl", hash = "sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371", size = 1148996, upload-time = "2025-12-12T17:31:21.03Z" }, ] [[package]] @@ -751,14 +751,14 @@ wheels = [ [[package]] name = "jaxtyping" -version = "0.3.3" +version = "0.3.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "wadler-lindig" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/1e/827f9e17b26e21c7d4d934fd1a214284ad05663afedd37c21ed105db366b/jaxtyping-0.3.3.tar.gz", hash = 
"sha256:8003cfd16ba2ad9b47fdda1d982a575299a81ddfc7997ad0e917c87a0897ea86", size = 45484, upload-time = "2025-10-01T13:46:51.933Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/fe/90f8884647073ea77a0fc5ec27e1605288591c89e60e53a6dd1d849ca6fc/jaxtyping-0.3.4.tar.gz", hash = "sha256:b4aac576a1b6c62a363f76f543f21c7cd4c7bb8714816c2c875f28b7abcdb770", size = 45665, upload-time = "2025-12-15T19:26:06.056Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/97/88264b1af140f66ba7ca6eb2f3a108be233ee278bb3f1d5c750243e7458a/jaxtyping-0.3.3-py3-none-any.whl", hash = "sha256:a1c2f0f4351a8deda84b0e3b5c5a50894a1cdae2b82d841279fce4393aff4a7c", size = 55926, upload-time = "2025-10-01T13:46:50.621Z" }, + { url = "https://files.pythonhosted.org/packages/cb/92/f30138ffa65d51791f85ec057eb76ae8eab125ed20f8f337918df3f1c775/jaxtyping-0.3.4-py3-none-any.whl", hash = "sha256:70e438db2f361575d04cccea50f77f9c8fe92f8b2086dc0ce89e5f1658bebaab", size = 56017, upload-time = "2025-12-15T19:26:04.835Z" }, ] [[package]] @@ -814,7 +814,7 @@ wheels = [ [[package]] name = "jupyter-client" -version = "8.6.3" +version = "8.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jupyter-core" }, @@ -823,9 +823,9 @@ dependencies = [ { name = "tornado" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019, upload-time = "2024-09-17T10:44:17.613Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/27/d10de45e8ad4ce872372c4a3a37b7b35b6b064f6f023a5c14ffcced4d59d/jupyter_client-8.7.0.tar.gz", hash = "sha256:3357212d9cbe01209e59190f67a3a7e1f387a4f4e88d1e0433ad84d7b262531d", size = 344691, upload-time = "2025-12-09T18:37:01.953Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105, upload-time = "2024-09-17T10:44:15.218Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f5/fddaec430367be9d62a7ed125530e133bfd4a1c0350fe221149ee0f2b526/jupyter_client-8.7.0-py3-none-any.whl", hash = "sha256:3671a94fd25e62f5f2f554f5e95389c2294d89822378a5f2dd24353e1494a9e0", size = 106215, upload-time = "2025-12-09T18:37:00.024Z" }, ] [[package]] @@ -906,7 +906,7 @@ wheels = [ [[package]] name = "matplotlib" -version = "3.10.7" +version = "3.10.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "contourpy" }, @@ -919,22 +919,22 @@ dependencies = [ { name = "pyparsing" }, { name = "python-dateutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/e2/d2d5295be2f44c678ebaf3544ba32d20c1f9ef08c49fe47f496180e1db15/matplotlib-3.10.7.tar.gz", hash = "sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7", size = 34804865, upload-time = "2025-10-09T00:28:00.669Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/76/d3c6e3a13fe484ebe7718d14e269c9569c4eb0020a968a327acb3b9a8fe6/matplotlib-3.10.8.tar.gz", hash = "sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3", size = 34806269, upload-time = "2025-12-10T22:56:51.155Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/9c/207547916a02c78f6bdd83448d9b21afbc42f6379ed887ecf610984f3b4e/matplotlib-3.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f", size = 8273212, upload-time = "2025-10-09T00:26:56.752Z" }, - { url = "https://files.pythonhosted.org/packages/bc/d0/b3d3338d467d3fc937f0bb7f256711395cae6f78e22cef0656159950adf0/matplotlib-3.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c", size = 8128713, upload-time = "2025-10-09T00:26:59.001Z" }, - { url = "https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1", size = 8698527, upload-time = "2025-10-09T00:27:00.69Z" }, - { url = "https://files.pythonhosted.org/packages/d0/7f/ccdca06f4c2e6c7989270ed7829b8679466682f4cfc0f8c9986241c023b6/matplotlib-3.10.7-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632", size = 9529690, upload-time = "2025-10-09T00:27:02.664Z" }, - { url = "https://files.pythonhosted.org/packages/b8/95/b80fc2c1f269f21ff3d193ca697358e24408c33ce2b106a7438a45407b63/matplotlib-3.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84", size = 9593732, upload-time = "2025-10-09T00:27:04.653Z" }, - { url = "https://files.pythonhosted.org/packages/e1/b6/23064a96308b9aeceeffa65e96bcde459a2ea4934d311dee20afde7407a0/matplotlib-3.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815", size = 8122727, upload-time = "2025-10-09T00:27:06.814Z" }, - { url = "https://files.pythonhosted.org/packages/b3/a6/2faaf48133b82cf3607759027f82b5c702aa99cdfcefb7f93d6ccf26a424/matplotlib-3.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7", size = 7992958, upload-time = "2025-10-09T00:27:08.567Z" }, - { url = "https://files.pythonhosted.org/packages/4a/f0/b018fed0b599bd48d84c08794cb242227fe3341952da102ee9d9682db574/matplotlib-3.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = 
"sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355", size = 8316849, upload-time = "2025-10-09T00:27:10.254Z" }, - { url = "https://files.pythonhosted.org/packages/b0/b7/bb4f23856197659f275e11a2a164e36e65e9b48ea3e93c4ec25b4f163198/matplotlib-3.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b", size = 8178225, upload-time = "2025-10-09T00:27:12.241Z" }, - { url = "https://files.pythonhosted.org/packages/62/56/0600609893ff277e6f3ab3c0cef4eafa6e61006c058e84286c467223d4d5/matplotlib-3.10.7-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67", size = 8711708, upload-time = "2025-10-09T00:27:13.879Z" }, - { url = "https://files.pythonhosted.org/packages/d8/1a/6bfecb0cafe94d6658f2f1af22c43b76cf7a1c2f0dc34ef84cbb6809617e/matplotlib-3.10.7-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67", size = 9541409, upload-time = "2025-10-09T00:27:15.684Z" }, - { url = "https://files.pythonhosted.org/packages/08/50/95122a407d7f2e446fd865e2388a232a23f2b81934960ea802f3171518e4/matplotlib-3.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84", size = 9594054, upload-time = "2025-10-09T00:27:17.547Z" }, - { url = "https://files.pythonhosted.org/packages/13/76/75b194a43b81583478a81e78a07da8d9ca6ddf50dd0a2ccabf258059481d/matplotlib-3.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2", size = 8200100, upload-time = "2025-10-09T00:27:20.039Z" }, - { url = "https://files.pythonhosted.org/packages/f5/9e/6aefebdc9f8235c12bdeeda44cc0383d89c1e41da2c400caf3ee2073a3ce/matplotlib-3.10.7-cp313-cp313t-win_arm64.whl", hash = 
"sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf", size = 8042131, upload-time = "2025-10-09T00:27:21.608Z" }, + { url = "https://files.pythonhosted.org/packages/3d/b9/15fd5541ef4f5b9a17eefd379356cf12175fe577424e7b1d80676516031a/matplotlib-3.10.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6", size = 8261076, upload-time = "2025-12-10T22:55:44.648Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a0/2ba3473c1b66b9c74dc7107c67e9008cb1782edbe896d4c899d39ae9cf78/matplotlib-3.10.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1", size = 8148794, upload-time = "2025-12-10T22:55:46.252Z" }, + { url = "https://files.pythonhosted.org/packages/75/97/a471f1c3eb1fd6f6c24a31a5858f443891d5127e63a7788678d14e249aea/matplotlib-3.10.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486", size = 8718474, upload-time = "2025-12-10T22:55:47.864Z" }, + { url = "https://files.pythonhosted.org/packages/01/be/cd478f4b66f48256f42927d0acbcd63a26a893136456cd079c0cc24fbabf/matplotlib-3.10.8-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce", size = 9549637, upload-time = "2025-12-10T22:55:50.048Z" }, + { url = "https://files.pythonhosted.org/packages/5d/7c/8dc289776eae5109e268c4fb92baf870678dc048a25d4ac903683b86d5bf/matplotlib-3.10.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6", size = 9613678, upload-time = "2025-12-10T22:55:52.21Z" }, + { url = "https://files.pythonhosted.org/packages/64/40/37612487cc8a437d4dd261b32ca21fe2d79510fe74af74e1f42becb1bdb8/matplotlib-3.10.8-cp313-cp313-win_amd64.whl", hash = 
"sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149", size = 8142686, upload-time = "2025-12-10T22:55:54.253Z" }, + { url = "https://files.pythonhosted.org/packages/66/52/8d8a8730e968185514680c2a6625943f70269509c3dcfc0dcf7d75928cb8/matplotlib-3.10.8-cp313-cp313-win_arm64.whl", hash = "sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645", size = 8012917, upload-time = "2025-12-10T22:55:56.268Z" }, + { url = "https://files.pythonhosted.org/packages/b5/27/51fe26e1062f298af5ef66343d8ef460e090a27fea73036c76c35821df04/matplotlib-3.10.8-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077", size = 8305679, upload-time = "2025-12-10T22:55:57.856Z" }, + { url = "https://files.pythonhosted.org/packages/2c/1e/4de865bc591ac8e3062e835f42dd7fe7a93168d519557837f0e37513f629/matplotlib-3.10.8-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22", size = 8198336, upload-time = "2025-12-10T22:55:59.371Z" }, + { url = "https://files.pythonhosted.org/packages/c6/cb/2f7b6e75fb4dce87ef91f60cac4f6e34f4c145ab036a22318ec837971300/matplotlib-3.10.8-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39", size = 8731653, upload-time = "2025-12-10T22:56:01.032Z" }, + { url = "https://files.pythonhosted.org/packages/46/b3/bd9c57d6ba670a37ab31fb87ec3e8691b947134b201f881665b28cc039ff/matplotlib-3.10.8-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565", size = 9561356, upload-time = "2025-12-10T22:56:02.95Z" }, + { url = "https://files.pythonhosted.org/packages/c0/3d/8b94a481456dfc9dfe6e39e93b5ab376e50998cddfd23f4ae3b431708f16/matplotlib-3.10.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a", size = 9614000, upload-time = "2025-12-10T22:56:05.411Z" }, + { url = "https://files.pythonhosted.org/packages/bd/cd/bc06149fe5585ba800b189a6a654a75f1f127e8aab02fd2be10df7fa500c/matplotlib-3.10.8-cp313-cp313t-win_amd64.whl", hash = "sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958", size = 8220043, upload-time = "2025-12-10T22:56:07.551Z" }, + { url = "https://files.pythonhosted.org/packages/e3/de/b22cf255abec916562cc04eef457c13e58a1990048de0c0c3604d082355e/matplotlib-3.10.8-cp313-cp313t-win_arm64.whl", hash = "sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5", size = 8062075, upload-time = "2025-12-10T22:56:09.178Z" }, ] [[package]] @@ -1022,11 +1022,11 @@ wheels = [ [[package]] name = "narwhals" -version = "2.13.0" +version = "2.14.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/89/ea/f82ef99ced4d03c33bb314c9b84a08a0a86c448aaa11ffd6256b99538aa5/narwhals-2.13.0.tar.gz", hash = "sha256:ee94c97f4cf7cfeebbeca8d274784df8b3d7fd3f955ce418af998d405576fdd9", size = 594555, upload-time = "2025-12-01T13:54:05.329Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/84/897fe7b6406d436ef312e57e5a1a13b4a5e7e36d1844e8d934ce8880e3d3/narwhals-2.14.0.tar.gz", hash = "sha256:98be155c3599db4d5c211e565c3190c398c87e7bf5b3cdb157dece67641946e0", size = 600648, upload-time = "2025-12-16T11:29:13.458Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/0d/1861d1599571974b15b025e12b142d8e6b42ad66c8a07a89cb0fc21f1e03/narwhals-2.13.0-py3-none-any.whl", hash = "sha256:9b795523c179ca78204e3be53726da374168f906e38de2ff174c2363baaaf481", size = 426407, upload-time = "2025-12-01T13:54:03.861Z" }, + { url = "https://files.pythonhosted.org/packages/79/3e/b8ecc67e178919671695f64374a7ba916cf0adbf86efedc6054f38b5b8ae/narwhals-2.14.0-py3-none-any.whl", hash = 
"sha256:b56796c9a00179bd757d15282c540024e1d5c910b19b8c9944d836566c030acf", size = 430788, upload-time = "2025-12-16T11:29:11.699Z" }, ] [[package]] @@ -1040,66 +1040,65 @@ wheels = [ [[package]] name = "networkx" -version = "3.6" +version = "3.6.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/fc/7b6fd4d22c8c4dc5704430140d8b3f520531d4fe7328b8f8d03f5a7950e8/networkx-3.6.tar.gz", hash = "sha256:285276002ad1f7f7da0f7b42f004bcba70d381e936559166363707fdad3d72ad", size = 2511464, upload-time = "2025-11-24T03:03:47.158Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c7/d64168da60332c17d24c0d2f08bdf3987e8d1ae9d84b5bbd0eec2eb26a55/networkx-3.6-py3-none-any.whl", hash = "sha256:cdb395b105806062473d3be36458d8f1459a4e4b98e236a66c3a48996e07684f", size = 2063713, upload-time = "2025-11-24T03:03:45.21Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] [[package]] name = "nodeenv" -version = "1.9.1" +version = "1.10.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, ] [[package]] name = "nodejs-wheel-binaries" -version = "24.11.1" +version = "24.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e4/89/da307731fdbb05a5f640b26de5b8ac0dc463fef059162accfc89e32f73bc/nodejs_wheel_binaries-24.11.1.tar.gz", hash = "sha256:413dfffeadfb91edb4d8256545dea797c237bba9b3faefea973cde92d96bb922", size = 8059, upload-time = "2025-11-18T18:21:58.207Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/35/d806c2ca66072e36dc340ccdbeb2af7e4f1b5bcc33f1481f00ceed476708/nodejs_wheel_binaries-24.12.0.tar.gz", hash = "sha256:f1b50aa25375e264697dec04b232474906b997c2630c8f499f4caf3692938435", size = 8058, upload-time = "2025-12-11T21:12:26.856Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/5f/be5a4112e678143d4c15264d918f9a2dc086905c6426eb44515cf391a958/nodejs_wheel_binaries-24.11.1-py2.py3-none-macosx_13_0_arm64.whl", hash = "sha256:0e14874c3579def458245cdbc3239e37610702b0aa0975c1dc55e2cb80e42102", size = 55114309, upload-time = "2025-11-18T18:21:21.697Z" }, - { url = 
"https://files.pythonhosted.org/packages/fa/1c/2e9d6af2ea32b65928c42b3e5baa7a306870711d93c3536cb25fc090a80d/nodejs_wheel_binaries-24.11.1-py2.py3-none-macosx_13_0_x86_64.whl", hash = "sha256:c2741525c9874b69b3e5a6d6c9179a6fe484ea0c3d5e7b7c01121c8e5d78b7e2", size = 55285957, upload-time = "2025-11-18T18:21:27.177Z" }, - { url = "https://files.pythonhosted.org/packages/d0/79/35696d7ba41b1bd35ef8682f13d46ba38c826c59e58b86b267458eb53d87/nodejs_wheel_binaries-24.11.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5ef598101b0fb1c2bf643abb76dfbf6f76f1686198ed17ae46009049ee83c546", size = 59645875, upload-time = "2025-11-18T18:21:33.004Z" }, - { url = "https://files.pythonhosted.org/packages/b4/98/2a9694adee0af72bc602a046b0632a0c89e26586090c558b1c9199b187cc/nodejs_wheel_binaries-24.11.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:cde41d5e4705266688a8d8071debf4f8a6fcea264c61292782672ee75a6905f9", size = 60140941, upload-time = "2025-11-18T18:21:37.228Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d6/573e5e2cba9d934f5f89d0beab00c3315e2e6604eb4df0fcd1d80c5a07a8/nodejs_wheel_binaries-24.11.1-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:78bc5bb889313b565df8969bb7423849a9c7fc218bf735ff0ce176b56b3e96f0", size = 61644243, upload-time = "2025-11-18T18:21:43.325Z" }, - { url = "https://files.pythonhosted.org/packages/c7/e6/643234d5e94067df8ce8d7bba10f3804106668f7a1050aeb10fdd226ead4/nodejs_wheel_binaries-24.11.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c79a7e43869ccecab1cae8183778249cceb14ca2de67b5650b223385682c6239", size = 62225657, upload-time = "2025-11-18T18:21:47.708Z" }, - { url = "https://files.pythonhosted.org/packages/4d/1c/2fb05127102a80225cab7a75c0e9edf88a0a1b79f912e1e36c7c1aaa8f4e/nodejs_wheel_binaries-24.11.1-py2.py3-none-win_amd64.whl", hash = "sha256:10197b1c9c04d79403501766f76508b0dac101ab34371ef8a46fcf51773497d0", size = 41322308, upload-time = "2025-11-18T18:21:51.347Z" }, - { url = 
"https://files.pythonhosted.org/packages/ad/b7/bc0cdbc2cc3a66fcac82c79912e135a0110b37b790a14c477f18e18d90cd/nodejs_wheel_binaries-24.11.1-py2.py3-none-win_arm64.whl", hash = "sha256:376b9ea1c4bc1207878975dfeb604f7aa5668c260c6154dcd2af9d42f7734116", size = 39026497, upload-time = "2025-11-18T18:21:54.634Z" }, + { url = "https://files.pythonhosted.org/packages/c3/3b/9d6f044319cd5b1e98f07c41e2465b58cadc1c9c04a74c891578f3be6cb5/nodejs_wheel_binaries-24.12.0-py2.py3-none-macosx_13_0_arm64.whl", hash = "sha256:7564ddea0a87eff34e9b3ef71764cc2a476a8f09a5cccfddc4691148b0a47338", size = 55125859, upload-time = "2025-12-11T21:11:58.132Z" }, + { url = "https://files.pythonhosted.org/packages/48/a5/f5722bf15c014e2f476d7c76bce3d55c341d19122d8a5d86454db32a61a4/nodejs_wheel_binaries-24.12.0-py2.py3-none-macosx_13_0_x86_64.whl", hash = "sha256:8ff929c4669e64613ceb07f5bbd758d528c3563820c75d5de3249eb452c0c0ab", size = 55309035, upload-time = "2025-12-11T21:12:01.754Z" }, + { url = "https://files.pythonhosted.org/packages/a9/61/68d39a6f1b5df67805969fd2829ba7e80696c9af19537856ec912050a2be/nodejs_wheel_binaries-24.12.0-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6ebacefa8891bc456ad3655e6bce0af7e20ba08662f79d9109986faeb703fd6f", size = 59661017, upload-time = "2025-12-11T21:12:05.268Z" }, + { url = "https://files.pythonhosted.org/packages/16/a1/31aad16f55a5e44ca7ea62d1367fc69f4b6e1dba67f58a0a41d0ed854540/nodejs_wheel_binaries-24.12.0-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:3292649a03682ccbfa47f7b04d3e4240e8c46ef04dc941b708f20e4e6a764f75", size = 60159770, upload-time = "2025-12-11T21:12:08.696Z" }, + { url = "https://files.pythonhosted.org/packages/c4/5e/b7c569aa1862690ca4d4daf3a64cafa1ea6ce667a9e3ae3918c56e127d9b/nodejs_wheel_binaries-24.12.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7fb83df312955ea355ba7f8cbd7055c477249a131d3cb43b60e4aeb8f8c730b1", size = 61653561, upload-time = "2025-12-11T21:12:12.575Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/87/567f58d7ba69ff0208be849b37be0f2c2e99c69e49334edd45ff44f00043/nodejs_wheel_binaries-24.12.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2473c819448fedd7b036dde236b09f3c8bbf39fbbd0c1068790a0498800f498b", size = 62238331, upload-time = "2025-12-11T21:12:16.143Z" }, + { url = "https://files.pythonhosted.org/packages/6a/9d/c6492188ce8de90093c6755a4a63bb6b2b4efb17094cb4f9a9a49c73ed3b/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_amd64.whl", hash = "sha256:2090d59f75a68079fabc9b86b14df8238b9aecb9577966dc142ce2a23a32e9bb", size = 41342076, upload-time = "2025-12-11T21:12:20.618Z" }, + { url = "https://files.pythonhosted.org/packages/df/af/cd3290a647df567645353feed451ef4feaf5844496ced69c4dcb84295ff4/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_arm64.whl", hash = "sha256:d0c2273b667dd7e3f55e369c0085957b702144b1b04bfceb7ce2411e58333757", size = 39048104, upload-time = "2025-12-11T21:12:23.495Z" }, ] [[package]] name = "numpy" -version = "2.3.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/65/21b3bc86aac7b8f2862db1e808f1ea22b028e30a225a34a5ede9bf8678f2/numpy-2.3.5.tar.gz", hash = "sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0", size = 20584950, upload-time = "2025-11-16T22:52:42.067Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/db/69/9cde09f36da4b5a505341180a3f2e6fadc352fd4d2b7096ce9778db83f1a/numpy-2.3.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d0f23b44f57077c1ede8c5f26b30f706498b4862d3ff0a7298b8411dd2f043ff", size = 16728251, upload-time = "2025-11-16T22:50:19.013Z" }, - { url = "https://files.pythonhosted.org/packages/79/fb/f505c95ceddd7027347b067689db71ca80bd5ecc926f913f1a23e65cf09b/numpy-2.3.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa5bc7c5d59d831d9773d1170acac7893ce3a5e130540605770ade83280e7188", size = 12254652, upload-time = "2025-11-16T22:50:21.487Z" }, - { url = 
"https://files.pythonhosted.org/packages/78/da/8c7738060ca9c31b30e9301ee0cf6c5ffdbf889d9593285a1cead337f9a5/numpy-2.3.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ccc933afd4d20aad3c00bcef049cb40049f7f196e0397f1109dba6fed63267b0", size = 5083172, upload-time = "2025-11-16T22:50:24.562Z" }, - { url = "https://files.pythonhosted.org/packages/a4/b4/ee5bb2537fb9430fd2ef30a616c3672b991a4129bb1c7dcc42aa0abbe5d7/numpy-2.3.5-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:afaffc4393205524af9dfa400fa250143a6c3bc646c08c9f5e25a9f4b4d6a903", size = 6622990, upload-time = "2025-11-16T22:50:26.47Z" }, - { url = "https://files.pythonhosted.org/packages/95/03/dc0723a013c7d7c19de5ef29e932c3081df1c14ba582b8b86b5de9db7f0f/numpy-2.3.5-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c75442b2209b8470d6d5d8b1c25714270686f14c749028d2199c54e29f20b4d", size = 14248902, upload-time = "2025-11-16T22:50:28.861Z" }, - { url = "https://files.pythonhosted.org/packages/f5/10/ca162f45a102738958dcec8023062dad0cbc17d1ab99d68c4e4a6c45fb2b/numpy-2.3.5-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11e06aa0af8c0f05104d56450d6093ee639e15f24ecf62d417329d06e522e017", size = 16597430, upload-time = "2025-11-16T22:50:31.56Z" }, - { url = "https://files.pythonhosted.org/packages/2a/51/c1e29be863588db58175175f057286900b4b3327a1351e706d5e0f8dd679/numpy-2.3.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ed89927b86296067b4f81f108a2271d8926467a8868e554eaf370fc27fa3ccaf", size = 16024551, upload-time = "2025-11-16T22:50:34.242Z" }, - { url = "https://files.pythonhosted.org/packages/83/68/8236589d4dbb87253d28259d04d9b814ec0ecce7cb1c7fed29729f4c3a78/numpy-2.3.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:51c55fe3451421f3a6ef9a9c1439e82101c57a2c9eab9feb196a62b1a10b58ce", size = 18533275, upload-time = "2025-11-16T22:50:37.651Z" }, - { url = 
"https://files.pythonhosted.org/packages/40/56/2932d75b6f13465239e3b7b7e511be27f1b8161ca2510854f0b6e521c395/numpy-2.3.5-cp313-cp313-win32.whl", hash = "sha256:1978155dd49972084bd6ef388d66ab70f0c323ddee6f693d539376498720fb7e", size = 6277637, upload-time = "2025-11-16T22:50:40.11Z" }, - { url = "https://files.pythonhosted.org/packages/0c/88/e2eaa6cffb115b85ed7c7c87775cb8bcf0816816bc98ca8dbfa2ee33fe6e/numpy-2.3.5-cp313-cp313-win_amd64.whl", hash = "sha256:00dc4e846108a382c5869e77c6ed514394bdeb3403461d25a829711041217d5b", size = 12779090, upload-time = "2025-11-16T22:50:42.503Z" }, - { url = "https://files.pythonhosted.org/packages/8f/88/3f41e13a44ebd4034ee17baa384acac29ba6a4fcc2aca95f6f08ca0447d1/numpy-2.3.5-cp313-cp313-win_arm64.whl", hash = "sha256:0472f11f6ec23a74a906a00b48a4dcf3849209696dff7c189714511268d103ae", size = 10194710, upload-time = "2025-11-16T22:50:44.971Z" }, - { url = "https://files.pythonhosted.org/packages/13/cb/71744144e13389d577f867f745b7df2d8489463654a918eea2eeb166dfc9/numpy-2.3.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:414802f3b97f3c1eef41e530aaba3b3c1620649871d8cb38c6eaff034c2e16bd", size = 16827292, upload-time = "2025-11-16T22:50:47.715Z" }, - { url = "https://files.pythonhosted.org/packages/71/80/ba9dc6f2a4398e7f42b708a7fdc841bb638d353be255655498edbf9a15a8/numpy-2.3.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5ee6609ac3604fa7780e30a03e5e241a7956f8e2fcfe547d51e3afa5247ac47f", size = 12378897, upload-time = "2025-11-16T22:50:51.327Z" }, - { url = "https://files.pythonhosted.org/packages/2e/6d/db2151b9f64264bcceccd51741aa39b50150de9b602d98ecfe7e0c4bff39/numpy-2.3.5-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:86d835afea1eaa143012a2d7a3f45a3adce2d7adc8b4961f0b362214d800846a", size = 5207391, upload-time = "2025-11-16T22:50:54.542Z" }, - { url = "https://files.pythonhosted.org/packages/80/ae/429bacace5ccad48a14c4ae5332f6aa8ab9f69524193511d60ccdfdc65fa/numpy-2.3.5-cp313-cp313t-macosx_14_0_x86_64.whl", hash = 
"sha256:30bc11310e8153ca664b14c5f1b73e94bd0503681fcf136a163de856f3a50139", size = 6721275, upload-time = "2025-11-16T22:50:56.794Z" }, - { url = "https://files.pythonhosted.org/packages/74/5b/1919abf32d8722646a38cd527bc3771eb229a32724ee6ba340ead9b92249/numpy-2.3.5-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1062fde1dcf469571705945b0f221b73928f34a20c904ffb45db101907c3454e", size = 14306855, upload-time = "2025-11-16T22:50:59.208Z" }, - { url = "https://files.pythonhosted.org/packages/a5/87/6831980559434973bebc30cd9c1f21e541a0f2b0c280d43d3afd909b66d0/numpy-2.3.5-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce581db493ea1a96c0556360ede6607496e8bf9b3a8efa66e06477267bc831e9", size = 16657359, upload-time = "2025-11-16T22:51:01.991Z" }, - { url = "https://files.pythonhosted.org/packages/dd/91/c797f544491ee99fd00495f12ebb7802c440c1915811d72ac5b4479a3356/numpy-2.3.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:cc8920d2ec5fa99875b670bb86ddeb21e295cb07aa331810d9e486e0b969d946", size = 16093374, upload-time = "2025-11-16T22:51:05.291Z" }, - { url = "https://files.pythonhosted.org/packages/74/a6/54da03253afcbe7a72785ec4da9c69fb7a17710141ff9ac5fcb2e32dbe64/numpy-2.3.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9ee2197ef8c4f0dfe405d835f3b6a14f5fee7782b5de51ba06fb65fc9b36e9f1", size = 18594587, upload-time = "2025-11-16T22:51:08.585Z" }, - { url = "https://files.pythonhosted.org/packages/80/e9/aff53abbdd41b0ecca94285f325aff42357c6b5abc482a3fcb4994290b18/numpy-2.3.5-cp313-cp313t-win32.whl", hash = "sha256:70b37199913c1bd300ff6e2693316c6f869c7ee16378faf10e4f5e3275b299c3", size = 6405940, upload-time = "2025-11-16T22:51:11.541Z" }, - { url = "https://files.pythonhosted.org/packages/d5/81/50613fec9d4de5480de18d4f8ef59ad7e344d497edbef3cfd80f24f98461/numpy-2.3.5-cp313-cp313t-win_amd64.whl", hash = "sha256:b501b5fa195cc9e24fe102f21ec0a44dffc231d2af79950b451e0d99cea02234", size = 12920341, 
upload-time = "2025-11-16T22:51:14.312Z" }, - { url = "https://files.pythonhosted.org/packages/bb/ab/08fd63b9a74303947f34f0bd7c5903b9c5532c2d287bead5bdf4c556c486/numpy-2.3.5-cp313-cp313t-win_arm64.whl", hash = "sha256:a80afd79f45f3c4a7d341f13acbe058d1ca8ac017c165d3fa0d3de6bc1a079d7", size = 10262507, upload-time = "2025-11-16T22:51:16.846Z" }, +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/7a/6a3d14e205d292b738db449d0de649b373a59edb0d0b4493821d0a3e8718/numpy-2.4.0.tar.gz", hash = "sha256:6e504f7b16118198f138ef31ba24d985b124c2c469fe8467007cf30fd992f934", size = 20685720, upload-time = "2025-12-20T16:18:19.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/0d/853fd96372eda07c824d24adf02e8bc92bb3731b43a9b2a39161c3667cc4/numpy-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a152d86a3ae00ba5f47b3acf3b827509fd0b6cb7d3259665e63dafbad22a75ea", size = 16649088, upload-time = "2025-12-20T16:16:31.421Z" }, + { url = "https://files.pythonhosted.org/packages/e3/37/cc636f1f2a9f585434e20a3e6e63422f70bfe4f7f6698e941db52ea1ac9a/numpy-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:39b19251dec4de8ff8496cd0806cbe27bf0684f765abb1f4809554de93785f2d", size = 12364065, upload-time = "2025-12-20T16:16:33.491Z" }, + { url = "https://files.pythonhosted.org/packages/ed/69/0b78f37ca3690969beee54103ce5f6021709134e8020767e93ba691a72f1/numpy-2.4.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:009bd0ea12d3c784b6639a8457537016ce5172109e585338e11334f6a7bb88ee", size = 5192640, upload-time = "2025-12-20T16:16:35.636Z" }, + { url = "https://files.pythonhosted.org/packages/1d/2a/08569f8252abf590294dbb09a430543ec8f8cc710383abfb3e75cc73aeda/numpy-2.4.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5fe44e277225fd3dff6882d86d3d447205d43532c3627313d17e754fb3905a0e", size = 6541556, upload-time = "2025-12-20T16:16:37.276Z" }, + { url = 
"https://files.pythonhosted.org/packages/93/e9/a949885a4e177493d61519377952186b6cbfdf1d6002764c664ba28349b5/numpy-2.4.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2", size = 14396562, upload-time = "2025-12-20T16:16:38.953Z" }, + { url = "https://files.pythonhosted.org/packages/99/98/9d4ad53b0e9ef901c2ef1d550d2136f5ac42d3fd2988390a6def32e23e48/numpy-2.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8cfa5f29a695cb7438965e6c3e8d06e0416060cf0d709c1b1c1653a939bf5c2a", size = 16351719, upload-time = "2025-12-20T16:16:41.503Z" }, + { url = "https://files.pythonhosted.org/packages/28/de/5f3711a38341d6e8dd619f6353251a0cdd07f3d6d101a8fd46f4ef87f895/numpy-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba0cb30acd3ef11c94dc27fbfba68940652492bc107075e7ffe23057f9425681", size = 16176053, upload-time = "2025-12-20T16:16:44.552Z" }, + { url = "https://files.pythonhosted.org/packages/2a/5b/2a3753dc43916501b4183532e7ace862e13211042bceafa253afb5c71272/numpy-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60e8c196cd82cbbd4f130b5290007e13e6de3eca79f0d4d38014769d96a7c475", size = 18277859, upload-time = "2025-12-20T16:16:47.174Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c5/a18bcdd07a941db3076ef489d036ab16d2bfc2eae0cf27e5a26e29189434/numpy-2.4.0-cp313-cp313-win32.whl", hash = "sha256:5f48cb3e88fbc294dc90e215d86fbaf1c852c63dbdb6c3a3e63f45c4b57f7344", size = 5953849, upload-time = "2025-12-20T16:16:49.554Z" }, + { url = "https://files.pythonhosted.org/packages/4f/f1/719010ff8061da6e8a26e1980cf090412d4f5f8060b31f0c45d77dd67a01/numpy-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:a899699294f28f7be8992853c0c60741f16ff199205e2e6cdca155762cbaa59d", size = 12302840, upload-time = "2025-12-20T16:16:51.227Z" }, + { url = 
"https://files.pythonhosted.org/packages/f5/5a/b3d259083ed8b4d335270c76966cb6cf14a5d1b69e1a608994ac57a659e6/numpy-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:9198f447e1dc5647d07c9a6bbe2063cc0132728cc7175b39dbc796da5b54920d", size = 10308509, upload-time = "2025-12-20T16:16:53.313Z" }, + { url = "https://files.pythonhosted.org/packages/31/01/95edcffd1bb6c0633df4e808130545c4f07383ab629ac7e316fb44fff677/numpy-2.4.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74623f2ab5cc3f7c886add4f735d1031a1d2be4a4ae63c0546cfd74e7a31ddf6", size = 12491815, upload-time = "2025-12-20T16:16:55.496Z" }, + { url = "https://files.pythonhosted.org/packages/59/ea/5644b8baa92cc1c7163b4b4458c8679852733fa74ca49c942cfa82ded4e0/numpy-2.4.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0804a8e4ab070d1d35496e65ffd3cf8114c136a2b81f61dfab0de4b218aacfd5", size = 5320321, upload-time = "2025-12-20T16:16:57.468Z" }, + { url = "https://files.pythonhosted.org/packages/26/4e/e10938106d70bc21319bd6a86ae726da37edc802ce35a3a71ecdf1fdfe7f/numpy-2.4.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:02a2038eb27f9443a8b266a66911e926566b5a6ffd1a689b588f7f35b81e7dc3", size = 6641635, upload-time = "2025-12-20T16:16:59.379Z" }, + { url = "https://files.pythonhosted.org/packages/b3/8d/a8828e3eaf5c0b4ab116924df82f24ce3416fa38d0674d8f708ddc6c8aac/numpy-2.4.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1889b3a3f47a7b5bee16bc25a2145bd7cb91897f815ce3499db64c7458b6d91d", size = 14456053, upload-time = "2025-12-20T16:17:01.768Z" }, + { url = "https://files.pythonhosted.org/packages/68/a1/17d97609d87d4520aa5ae2dcfb32305654550ac6a35effb946d303e594ce/numpy-2.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85eef4cb5625c47ee6425c58a3502555e10f45ee973da878ac8248ad58c136f3", size = 16401702, upload-time = "2025-12-20T16:17:04.235Z" }, + { url = 
"https://files.pythonhosted.org/packages/18/32/0f13c1b2d22bea1118356b8b963195446f3af124ed7a5adfa8fdecb1b6ca/numpy-2.4.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6dc8b7e2f4eb184b37655195f421836cfae6f58197b67e3ffc501f1333d993fa", size = 16242493, upload-time = "2025-12-20T16:17:06.856Z" }, + { url = "https://files.pythonhosted.org/packages/ae/23/48f21e3d309fbc137c068a1475358cbd3a901b3987dcfc97a029ab3068e2/numpy-2.4.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:44aba2f0cafd287871a495fb3163408b0bd25bbce135c6f621534a07f4f7875c", size = 18324222, upload-time = "2025-12-20T16:17:09.392Z" }, + { url = "https://files.pythonhosted.org/packages/ac/52/41f3d71296a3dcaa4f456aaa3c6fc8e745b43d0552b6bde56571bb4b4a0f/numpy-2.4.0-cp313-cp313t-win32.whl", hash = "sha256:20c115517513831860c573996e395707aa9fb691eb179200125c250e895fcd93", size = 6076216, upload-time = "2025-12-20T16:17:11.437Z" }, + { url = "https://files.pythonhosted.org/packages/35/ff/46fbfe60ab0710d2a2b16995f708750307d30eccbb4c38371ea9e986866e/numpy-2.4.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b48e35f4ab6f6a7597c46e301126ceba4c44cd3280e3750f85db48b082624fa4", size = 12444263, upload-time = "2025-12-20T16:17:13.182Z" }, + { url = "https://files.pythonhosted.org/packages/a3/e3/9189ab319c01d2ed556c932ccf55064c5d75bb5850d1df7a482ce0badead/numpy-2.4.0-cp313-cp313t-win_arm64.whl", hash = "sha256:4d1cfce39e511069b11e67cd0bd78ceff31443b7c9e5c04db73c7a19f572967c", size = 10378265, upload-time = "2025-12-20T16:17:15.211Z" }, ] [[package]] @@ -1358,7 +1357,7 @@ wheels = [ [[package]] name = "pre-commit" -version = "4.5.0" +version = "4.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cfgv" }, @@ -1367,9 +1366,9 @@ dependencies = [ { name = "pyyaml" }, { name = "virtualenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/9b/6a4ffb4ed980519da959e1cf3122fc6cb41211daa58dbae1c73c0e519a37/pre_commit-4.5.0.tar.gz", hash = 
"sha256:dc5a065e932b19fc1d4c653c6939068fe54325af8e741e74e88db4d28a4dd66b", size = 198428, upload-time = "2025-11-22T21:02:42.304Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/c4/b2d28e9d2edf4f1713eb3c29307f1a63f3d67cf09bdda29715a36a68921a/pre_commit-4.5.0-py2.py3-none-any.whl", hash = "sha256:25e2ce09595174d9c97860a95609f9f852c0614ba602de3561e267547f2335e1", size = 226429, upload-time = "2025-11-22T21:02:40.836Z" }, + { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, ] [[package]] @@ -1583,7 +1582,7 @@ wheels = [ [[package]] name = "pytest" -version = "9.0.1" +version = "9.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -1592,9 +1591,9 @@ dependencies = [ { name = "pluggy" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/56/f013048ac4bc4c1d9be45afd4ab209ea62822fb1598f40687e6bf45dcea4/pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8", size = 1564125, upload-time = "2025-11-12T13:05:09.333Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/0b/8b/6300fb80f858cda1c51ffa17075df5d846757081d11ab4aa35cef9e6258b/pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad", size = 373668, upload-time = "2025-11-12T13:05:07.379Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] [[package]] @@ -1808,28 +1807,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.8" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ed/d9/f7a0c4b3a2bf2556cd5d99b05372c29980249ef71e8e32669ba77428c82c/ruff-0.14.8.tar.gz", hash = "sha256:774ed0dd87d6ce925e3b8496feb3a00ac564bea52b9feb551ecd17e0a23d1eed", size = 5765385, upload-time = "2025-12-04T15:06:17.669Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/b8/9537b52010134b1d2b72870cc3f92d5fb759394094741b09ceccae183fbe/ruff-0.14.8-py3-none-linux_armv6l.whl", hash = "sha256:ec071e9c82eca417f6111fd39f7043acb53cd3fde9b1f95bbed745962e345afb", size = 13441540, upload-time = "2025-12-04T15:06:14.896Z" }, - { url = "https://files.pythonhosted.org/packages/24/00/99031684efb025829713682012b6dd37279b1f695ed1b01725f85fd94b38/ruff-0.14.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8cdb162a7159f4ca36ce980a18c43d8f036966e7f73f866ac8f493b75e0c27e9", size = 13669384, upload-time = "2025-12-04T15:06:51.809Z" }, - { url = "https://files.pythonhosted.org/packages/72/64/3eb5949169fc19c50c04f28ece2c189d3b6edd57e5b533649dae6ca484fe/ruff-0.14.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2e2fcbefe91f9fad0916850edf0854530c15bd1926b6b779de47e9ab619ea38f", size = 12806917, upload-time = "2025-12-04T15:06:08.925Z" }, - { url = 
"https://files.pythonhosted.org/packages/c4/08/5250babb0b1b11910f470370ec0cbc67470231f7cdc033cee57d4976f941/ruff-0.14.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9d70721066a296f45786ec31916dc287b44040f553da21564de0ab4d45a869b", size = 13256112, upload-time = "2025-12-04T15:06:23.498Z" }, - { url = "https://files.pythonhosted.org/packages/78/4c/6c588e97a8e8c2d4b522c31a579e1df2b4d003eddfbe23d1f262b1a431ff/ruff-0.14.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2c87e09b3cd9d126fc67a9ecd3b5b1d3ded2b9c7fce3f16e315346b9d05cfb52", size = 13227559, upload-time = "2025-12-04T15:06:33.432Z" }, - { url = "https://files.pythonhosted.org/packages/23/ce/5f78cea13eda8eceac71b5f6fa6e9223df9b87bb2c1891c166d1f0dce9f1/ruff-0.14.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d62cb310c4fbcb9ee4ac023fe17f984ae1e12b8a4a02e3d21489f9a2a5f730c", size = 13896379, upload-time = "2025-12-04T15:06:02.687Z" }, - { url = "https://files.pythonhosted.org/packages/cf/79/13de4517c4dadce9218a20035b21212a4c180e009507731f0d3b3f5df85a/ruff-0.14.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1af35c2d62633d4da0521178e8a2641c636d2a7153da0bac1b30cfd4ccd91344", size = 15372786, upload-time = "2025-12-04T15:06:29.828Z" }, - { url = "https://files.pythonhosted.org/packages/00/06/33df72b3bb42be8a1c3815fd4fae83fa2945fc725a25d87ba3e42d1cc108/ruff-0.14.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:25add4575ffecc53d60eed3f24b1e934493631b48ebbc6ebaf9d8517924aca4b", size = 14990029, upload-time = "2025-12-04T15:06:36.812Z" }, - { url = "https://files.pythonhosted.org/packages/64/61/0f34927bd90925880394de0e081ce1afab66d7b3525336f5771dcf0cb46c/ruff-0.14.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4c943d847b7f02f7db4201a0600ea7d244d8a404fbb639b439e987edcf2baf9a", size = 14407037, upload-time = "2025-12-04T15:06:39.979Z" }, - { url = 
"https://files.pythonhosted.org/packages/96/bc/058fe0aefc0fbf0d19614cb6d1a3e2c048f7dc77ca64957f33b12cfdc5ef/ruff-0.14.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb6e8bf7b4f627548daa1b69283dac5a296bfe9ce856703b03130732e20ddfe2", size = 14102390, upload-time = "2025-12-04T15:06:46.372Z" }, - { url = "https://files.pythonhosted.org/packages/af/a4/e4f77b02b804546f4c17e8b37a524c27012dd6ff05855d2243b49a7d3cb9/ruff-0.14.8-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:7aaf2974f378e6b01d1e257c6948207aec6a9b5ba53fab23d0182efb887a0e4a", size = 14230793, upload-time = "2025-12-04T15:06:20.497Z" }, - { url = "https://files.pythonhosted.org/packages/3f/52/bb8c02373f79552e8d087cedaffad76b8892033d2876c2498a2582f09dcf/ruff-0.14.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e5758ca513c43ad8a4ef13f0f081f80f08008f410790f3611a21a92421ab045b", size = 13160039, upload-time = "2025-12-04T15:06:49.06Z" }, - { url = "https://files.pythonhosted.org/packages/1f/ad/b69d6962e477842e25c0b11622548df746290cc6d76f9e0f4ed7456c2c31/ruff-0.14.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:f74f7ba163b6e85a8d81a590363bf71618847e5078d90827749bfda1d88c9cdf", size = 13205158, upload-time = "2025-12-04T15:06:54.574Z" }, - { url = "https://files.pythonhosted.org/packages/06/63/54f23da1315c0b3dfc1bc03fbc34e10378918a20c0b0f086418734e57e74/ruff-0.14.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:eed28f6fafcc9591994c42254f5a5c5ca40e69a30721d2ab18bb0bb3baac3ab6", size = 13469550, upload-time = "2025-12-04T15:05:59.209Z" }, - { url = "https://files.pythonhosted.org/packages/70/7d/a4d7b1961e4903bc37fffb7ddcfaa7beb250f67d97cfd1ee1d5cddb1ec90/ruff-0.14.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:21d48fa744c9d1cb8d71eb0a740c4dd02751a5de9db9a730a8ef75ca34cf138e", size = 14211332, upload-time = "2025-12-04T15:06:06.027Z" }, - { url = 
"https://files.pythonhosted.org/packages/5d/93/2a5063341fa17054e5c86582136e9895db773e3c2ffb770dde50a09f35f0/ruff-0.14.8-py3-none-win32.whl", hash = "sha256:15f04cb45c051159baebb0f0037f404f1dc2f15a927418f29730f411a79bc4e7", size = 13151890, upload-time = "2025-12-04T15:06:11.668Z" }, - { url = "https://files.pythonhosted.org/packages/02/1c/65c61a0859c0add13a3e1cbb6024b42de587456a43006ca2d4fd3d1618fe/ruff-0.14.8-py3-none-win_amd64.whl", hash = "sha256:9eeb0b24242b5bbff3011409a739929f497f3fb5fe3b5698aba5e77e8c833097", size = 14537826, upload-time = "2025-12-04T15:06:26.409Z" }, - { url = "https://files.pythonhosted.org/packages/6d/63/8b41cea3afd7f58eb64ac9251668ee0073789a3bc9ac6f816c8c6fef986d/ruff-0.14.8-py3-none-win_arm64.whl", hash = "sha256:965a582c93c63fe715fd3e3f8aa37c4b776777203d8e1d8aa3cc0c14424a4b99", size = 13634522, upload-time = "2025-12-04T15:06:43.212Z" }, +version = "0.14.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/08/52232a877978dd8f9cf2aeddce3e611b40a63287dfca29b6b8da791f5e8d/ruff-0.14.10.tar.gz", hash = "sha256:9a2e830f075d1a42cd28420d7809ace390832a490ed0966fe373ba288e77aaf4", size = 5859763, upload-time = "2025-12-18T19:28:57.98Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/01/933704d69f3f05ee16ef11406b78881733c186fe14b6a46b05cfcaf6d3b2/ruff-0.14.10-py3-none-linux_armv6l.whl", hash = "sha256:7a3ce585f2ade3e1f29ec1b92df13e3da262178df8c8bdf876f48fa0e8316c49", size = 13527080, upload-time = "2025-12-18T19:29:25.642Z" }, + { url = "https://files.pythonhosted.org/packages/df/58/a0349197a7dfa603ffb7f5b0470391efa79ddc327c1e29c4851e85b09cc5/ruff-0.14.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:674f9be9372907f7257c51f1d4fc902cb7cf014b9980152b802794317941f08f", size = 13797320, upload-time = "2025-12-18T19:29:02.571Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/82/36be59f00a6082e38c23536df4e71cdbc6af8d7c707eade97fcad5c98235/ruff-0.14.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d85713d522348837ef9df8efca33ccb8bd6fcfc86a2cde3ccb4bc9d28a18003d", size = 12918434, upload-time = "2025-12-18T19:28:51.202Z" }, + { url = "https://files.pythonhosted.org/packages/a6/00/45c62a7f7e34da92a25804f813ebe05c88aa9e0c25e5cb5a7d23dd7450e3/ruff-0.14.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6987ebe0501ae4f4308d7d24e2d0fe3d7a98430f5adfd0f1fead050a740a3a77", size = 13371961, upload-time = "2025-12-18T19:29:04.991Z" }, + { url = "https://files.pythonhosted.org/packages/40/31/a5906d60f0405f7e57045a70f2d57084a93ca7425f22e1d66904769d1628/ruff-0.14.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:16a01dfb7b9e4eee556fbfd5392806b1b8550c9b4a9f6acd3dbe6812b193c70a", size = 13275629, upload-time = "2025-12-18T19:29:21.381Z" }, + { url = "https://files.pythonhosted.org/packages/3e/60/61c0087df21894cf9d928dc04bcd4fb10e8b2e8dca7b1a276ba2155b2002/ruff-0.14.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7165d31a925b7a294465fa81be8c12a0e9b60fb02bf177e79067c867e71f8b1f", size = 14029234, upload-time = "2025-12-18T19:29:00.132Z" }, + { url = "https://files.pythonhosted.org/packages/44/84/77d911bee3b92348b6e5dab5a0c898d87084ea03ac5dc708f46d88407def/ruff-0.14.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c561695675b972effb0c0a45db233f2c816ff3da8dcfbe7dfc7eed625f218935", size = 15449890, upload-time = "2025-12-18T19:28:53.573Z" }, + { url = "https://files.pythonhosted.org/packages/e9/36/480206eaefa24a7ec321582dda580443a8f0671fdbf6b1c80e9c3e93a16a/ruff-0.14.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bb98fcbbc61725968893682fd4df8966a34611239c9fd07a1f6a07e7103d08e", size = 15123172, upload-time = "2025-12-18T19:29:23.453Z" }, + { url = 
"https://files.pythonhosted.org/packages/5c/38/68e414156015ba80cef5473d57919d27dfb62ec804b96180bafdeaf0e090/ruff-0.14.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f24b47993a9d8cb858429e97bdf8544c78029f09b520af615c1d261bf827001d", size = 14460260, upload-time = "2025-12-18T19:29:27.808Z" }, + { url = "https://files.pythonhosted.org/packages/b3/19/9e050c0dca8aba824d67cc0db69fb459c28d8cd3f6855b1405b3f29cc91d/ruff-0.14.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59aabd2e2c4fd614d2862e7939c34a532c04f1084476d6833dddef4afab87e9f", size = 14229978, upload-time = "2025-12-18T19:29:11.32Z" }, + { url = "https://files.pythonhosted.org/packages/51/eb/e8dd1dd6e05b9e695aa9dd420f4577debdd0f87a5ff2fedda33c09e9be8c/ruff-0.14.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:213db2b2e44be8625002dbea33bb9c60c66ea2c07c084a00d55732689d697a7f", size = 14338036, upload-time = "2025-12-18T19:29:09.184Z" }, + { url = "https://files.pythonhosted.org/packages/6a/12/f3e3a505db7c19303b70af370d137795fcfec136d670d5de5391e295c134/ruff-0.14.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b914c40ab64865a17a9a5b67911d14df72346a634527240039eb3bd650e5979d", size = 13264051, upload-time = "2025-12-18T19:29:13.431Z" }, + { url = "https://files.pythonhosted.org/packages/08/64/8c3a47eaccfef8ac20e0484e68e0772013eb85802f8a9f7603ca751eb166/ruff-0.14.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1484983559f026788e3a5c07c81ef7d1e97c1c78ed03041a18f75df104c45405", size = 13283998, upload-time = "2025-12-18T19:29:06.994Z" }, + { url = "https://files.pythonhosted.org/packages/12/84/534a5506f4074e5cc0529e5cd96cfc01bb480e460c7edf5af70d2bcae55e/ruff-0.14.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c70427132db492d25f982fffc8d6c7535cc2fd2c83fc8888f05caaa248521e60", size = 13601891, upload-time = "2025-12-18T19:28:55.811Z" }, + { url = 
"https://files.pythonhosted.org/packages/0d/1e/14c916087d8598917dbad9b2921d340f7884824ad6e9c55de948a93b106d/ruff-0.14.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5bcf45b681e9f1ee6445d317ce1fa9d6cba9a6049542d1c3d5b5958986be8830", size = 14336660, upload-time = "2025-12-18T19:29:16.531Z" }, + { url = "https://files.pythonhosted.org/packages/f2/1c/d7b67ab43f30013b47c12b42d1acd354c195351a3f7a1d67f59e54227ede/ruff-0.14.10-py3-none-win32.whl", hash = "sha256:104c49fc7ab73f3f3a758039adea978869a918f31b73280db175b43a2d9b51d6", size = 13196187, upload-time = "2025-12-18T19:29:19.006Z" }, + { url = "https://files.pythonhosted.org/packages/fb/9c/896c862e13886fae2af961bef3e6312db9ebc6adc2b156fe95e615dee8c1/ruff-0.14.10-py3-none-win_amd64.whl", hash = "sha256:466297bd73638c6bdf06485683e812db1c00c7ac96d4ddd0294a338c62fdc154", size = 14661283, upload-time = "2025-12-18T19:29:30.16Z" }, + { url = "https://files.pythonhosted.org/packages/74/31/b0e29d572670dca3674eeee78e418f20bdf97fa8aa9ea71380885e175ca0/ruff-0.14.10-py3-none-win_arm64.whl", hash = "sha256:e51d046cf6dda98a4633b8a8a771451107413b0f07183b2bef03f075599e44e6", size = 13729839, upload-time = "2025-12-18T19:28:48.636Z" }, ] [[package]] @@ -1854,17 +1853,48 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, ] +[[package]] +name = "scipy" +version = "1.16.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, + { url = "https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, + { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, + { url = "https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, + { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, + { url = 
"https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 36021756, upload-time = "2025-10-28T17:34:13.482Z" }, + { url = "https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, + { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, + { url = "https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, + { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, + { url = 
"https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, + { url = "https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, + { url = "https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, + { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, + { url = "https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, +] + [[package]] name = "sentry-sdk" -version = "2.47.0" +version = "2.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4a/2a/d225cbf87b6c8ecce5664db7bcecb82c317e448e3b24a2dcdaacb18ca9a7/sentry_sdk-2.47.0.tar.gz", hash = "sha256:8218891d5e41b4ea8d61d2aed62ed10c80e39d9f2959d6f939efbf056857e050", size = 381895, upload-time = "2025-12-03T14:06:36.846Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/f0/0e9dc590513d5e742d7799e2038df3a05167cba084c6ca4f3cdd75b55164/sentry_sdk-2.48.0.tar.gz", hash = "sha256:5213190977ff7fdff8a58b722fb807f8d5524a80488626ebeda1b5676c0c1473", size = 384828, upload-time = "2025-12-16T14:55:41.722Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bd/ac/d6286ea0d49e7b58847faf67b00e56bb4ba3d525281e2ac306e1f1f353da/sentry_sdk-2.47.0-py2.py3-none-any.whl", hash = "sha256:d72f8c61025b7d1d9e52510d03a6247b280094a327dd900d987717a4fce93412", size = 411088, upload-time = "2025-12-03T14:06:35.374Z" }, + { url = "https://files.pythonhosted.org/packages/4d/19/8d77f9992e5cbfcaa9133c3bf63b4fbbb051248802e1e803fed5c552fbb2/sentry_sdk-2.48.0-py2.py3-none-any.whl", hash = "sha256:6b12ac256769d41825d9b7518444e57fa35b5642df4c7c5e322af4d2c8721172", size = 414555, upload-time = "2025-12-16T14:55:40.152Z" }, ] [[package]] @@ 
-1879,7 +1909,7 @@ wheels = [ [[package]] name = "simple-stories-train" version = "0.0.1" -source = { git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev#72c4c863150d6c475e44afe13af2d2b18e34492a" } +source = { git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev#16f574d274849c90fe15e5c4df231f3885e44a7d" } dependencies = [ { name = "datasets" }, { name = "fire" }, @@ -1931,6 +1961,7 @@ dependencies = [ { name = "openrouter" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "scipy" }, { name = "simple-stories-train" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -1969,6 +2000,7 @@ requires-dist = [ { name = "openrouter", specifier = ">=0.1.1" }, { name = "pydantic", specifier = "<2.12" }, { name = "python-dotenv" }, + { name = "scipy", specifier = ">=1.14.1" }, { name = "simple-stories-train", git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -1995,14 +2027,17 @@ dev = [ [[package]] name = "sqlite-web" -version = "0.6.5" +version = "0.6.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, { name = "peewee" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/e4/170abaeeb86cdfa8d73fd6f7ffc0ddc68008c0e82e2e102b4568e61d5f2a/sqlite-web-0.6.5.tar.gz", hash = "sha256:fc0b2abb65c45424b9b055b6181fd6fe29edd60099208819abccab7afa10ed8a", size = 807898, upload-time = "2025-10-13T22:55:03.051Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/1c/257ae3bd866fc04ec44b82a8fae8d470326a1afb58627a2f53ce46ee3a9d/sqlite_web-0.6.6.tar.gz", hash = "sha256:a02927d3b46d11424c620a90627c4413f9d8e63d808b3849d2f90f5882308ccc", size = 768634, upload-time = "2025-12-15T20:50:24.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/10/f62fa7f36f7838570255ccda72ab5cb9558b2ae0a7ea0dad24e9466cc806/sqlite_web-0.6.6-py3-none-any.whl", 
hash = "sha256:19c6a2be821452cf25ff4f403daeabc7cedbfaff0850b291775a36856a5de113", size = 774699, upload-time = "2025-12-15T20:50:12.501Z" }, +] [[package]] name = "stack-data" @@ -2032,7 +2067,7 @@ wheels = [ [[package]] name = "streamlit" -version = "1.52.1" +version = "1.52.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "altair" }, @@ -2054,9 +2089,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "watchdog", marker = "sys_platform != 'darwin'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/ec/66ca1578587bbaaea5757a7f074d3f762fff8d782ffb4497a83c5722aa44/streamlit-1.52.1.tar.gz", hash = "sha256:b036a71866b893c97fdebaa2a2ebd21ebf2af7daea4b3abe783a57b26f55b3ca", size = 8582829, upload-time = "2025-12-05T18:55:42.006Z" } +sdist = { url = "https://files.pythonhosted.org/packages/43/20/434aaceccc6e1912671d869926103051330437adba72d538d787a07727ef/streamlit-1.52.2.tar.gz", hash = "sha256:64a4dda8bc5cdd37bfd490e93bb53da35aaef946fcfc283a7980dacdf165108b", size = 8584178, upload-time = "2025-12-17T17:07:59.642Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/d4/cdafd4cc940937410f465ca7a77dd34237182c2ddece624e08db959496f8/streamlit-1.52.1-py3-none-any.whl", hash = "sha256:97fee2c3421d350fd65548e45a20f506ec1b651d78f95ecacbc0c2f9f838081c", size = 9024748, upload-time = "2025-12-05T18:55:39.713Z" }, + { url = "https://files.pythonhosted.org/packages/c0/95/6b7873f0267973ebd55ba9cd33a690b35a116f2779901ef6185a0e21864d/streamlit-1.52.2-py3-none-any.whl", hash = "sha256:a16bb4fbc9781e173ce9dfbd8ffb189c174f148f9ca4fb8fa56423e84e193fc8", size = 9025937, upload-time = "2025-12-17T17:07:57.67Z" }, ] [[package]] @@ -2221,21 +2256,21 @@ wheels = [ [[package]] name = "tornado" -version = "6.5.2" +version = "6.5.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" } +sdist = { url = "https://files.pythonhosted.org/packages/37/1d/0a336abf618272d53f62ebe274f712e213f5a03c0b2339575430b8362ef2/tornado-6.5.4.tar.gz", hash = "sha256:a22fa9047405d03260b483980635f0b041989d8bcc9a313f8fe18b411d84b1d7", size = 513632, upload-time = "2025-12-15T19:21:03.836Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" }, - { url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" }, - { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" }, - { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" }, - { url = 
"https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" }, - { url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" }, - { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" }, - { url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" }, - { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" }, - { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" }, - { url = 
"https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a9/e94a9d5224107d7ce3cc1fab8d5dc97f5ea351ccc6322ee4fb661da94e35/tornado-6.5.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d6241c1a16b1c9e4cc28148b1cda97dd1c6cb4fb7068ac1bedc610768dff0ba9", size = 443909, upload-time = "2025-12-15T19:20:48.382Z" }, + { url = "https://files.pythonhosted.org/packages/db/7e/f7b8d8c4453f305a51f80dbb49014257bb7d28ccb4bbb8dd328ea995ecad/tornado-6.5.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2d50f63dda1d2cac3ae1fa23d254e16b5e38153758470e9956cbc3d813d40843", size = 442163, upload-time = "2025-12-15T19:20:49.791Z" }, + { url = "https://files.pythonhosted.org/packages/ba/b5/206f82d51e1bfa940ba366a8d2f83904b15942c45a78dd978b599870ab44/tornado-6.5.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1cf66105dc6acb5af613c054955b8137e34a03698aa53272dbda4afe252be17", size = 445746, upload-time = "2025-12-15T19:20:51.491Z" }, + { url = "https://files.pythonhosted.org/packages/8e/9d/1a3338e0bd30ada6ad4356c13a0a6c35fbc859063fa7eddb309183364ac1/tornado-6.5.4-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50ff0a58b0dc97939d29da29cd624da010e7f804746621c78d14b80238669335", size = 445083, upload-time = "2025-12-15T19:20:52.778Z" }, + { url = "https://files.pythonhosted.org/packages/50/d4/e51d52047e7eb9a582da59f32125d17c0482d065afd5d3bc435ff2120dc5/tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fb5e04efa54cf0baabdd10061eb4148e0be137166146fff835745f59ab9f7f", size = 445315, upload-time = "2025-12-15T19:20:53.996Z" }, + { url = 
"https://files.pythonhosted.org/packages/27/07/2273972f69ca63dbc139694a3fc4684edec3ea3f9efabf77ed32483b875c/tornado-6.5.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9c86b1643b33a4cd415f8d0fe53045f913bf07b4a3ef646b735a6a86047dda84", size = 446003, upload-time = "2025-12-15T19:20:56.101Z" }, + { url = "https://files.pythonhosted.org/packages/d1/83/41c52e47502bf7260044413b6770d1a48dda2f0246f95ee1384a3cd9c44a/tornado-6.5.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:6eb82872335a53dd063a4f10917b3efd28270b56a33db69009606a0312660a6f", size = 445412, upload-time = "2025-12-15T19:20:57.398Z" }, + { url = "https://files.pythonhosted.org/packages/10/c7/bc96917f06cbee182d44735d4ecde9c432e25b84f4c2086143013e7b9e52/tornado-6.5.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6076d5dda368c9328ff41ab5d9dd3608e695e8225d1cd0fd1e006f05da3635a8", size = 445392, upload-time = "2025-12-15T19:20:58.692Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/d7592328d037d36f2d2462f4bc1fbb383eec9278bc786c1b111cbbd44cfa/tornado-6.5.4-cp39-abi3-win32.whl", hash = "sha256:1768110f2411d5cd281bac0a090f707223ce77fd110424361092859e089b38d1", size = 446481, upload-time = "2025-12-15T19:21:00.008Z" }, + { url = "https://files.pythonhosted.org/packages/d6/6d/c69be695a0a64fd37a97db12355a035a6d90f79067a3cf936ec2b1dc38cd/tornado-6.5.4-cp39-abi3-win_amd64.whl", hash = "sha256:fa07d31e0cd85c60713f2b995da613588aa03e1303d75705dca6af8babc18ddc", size = 446886, upload-time = "2025-12-15T19:21:01.287Z" }, + { url = "https://files.pythonhosted.org/packages/50/49/8dc3fd90902f70084bd2cd059d576ddb4f8bb44c2c7c0e33a11422acb17e/tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1", size = 445910, upload-time = "2025-12-15T19:21:02.571Z" }, ] [[package]] @@ -2315,33 +2350,33 @@ wheels = [ [[package]] name = "tzdata" -version = "2025.2" +version = "2025.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf398f3b8a477eeb205cf3b6f32e7ec3a6bac0629ca975/tzdata-2025.3.tar.gz", hash = "sha256:de39c2ca5dc7b0344f2eba86f49d614019d29f060fc4ebc8a417896a620b56a7", size = 196772, upload-time = "2025-12-13T17:45:35.667Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, + { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, ] [[package]] name = "urllib3" -version = "2.6.0" +version = "2.6.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1c/43/554c2569b62f49350597348fc3ac70f786e3c32e7f19d266e19817812dd3/urllib3-2.6.0.tar.gz", hash = "sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1", size = 432585, upload-time = "2025-12-05T15:08:47.885Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/24/a2a2ed9addd907787d7aa0355ba36a6cadf1768b934c652ea78acbd59dcd/urllib3-2.6.2.tar.gz", hash = "sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797", size = 432930, upload-time = "2025-12-11T15:56:40.252Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/56/1a/9ffe814d317c5224166b23e7c47f606d6e473712a2fad0f704ea9b99f246/urllib3-2.6.0-py3-none-any.whl", hash = "sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f", size = 131083, upload-time = "2025-12-05T15:08:45.983Z" }, + { url = "https://files.pythonhosted.org/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd", size = 131182, upload-time = "2025-12-11T15:56:38.584Z" }, ] [[package]] name = "uvicorn" -version = "0.38.0" +version = "0.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/f06b84e2697fef4688ca63bdb2fdf113ca0a3be33f94488f2cadb690b0cf/uvicorn-0.38.0.tar.gz", hash = "sha256:fd97093bdd120a2609fc0d3afe931d4d4ad688b6e75f0f929fde1bc36fe0e91d", size = 80605, upload-time = "2025-10-18T13:46:44.63Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/d1/8f3c683c9561a4e6689dd3b1d345c815f10f86acd044ee1fb9a4dcd0b8c5/uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea", size = 81761, upload-time = "2025-12-21T14:16:22.45Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" }, ] [[package]]