diff --git a/CLAUDE_CHECKLIST.md b/CLAUDE_CHECKLIST.md new file mode 100644 index 000000000..9c86c0ca3 --- /dev/null +++ b/CLAUDE_CHECKLIST.md @@ -0,0 +1,136 @@ +# CLAUDE_CHECKLIST.md - Pre-Submission Checklist + +Use this checklist before submitting any code changes to ensure your contribution meets SPD repository standards. + +As you work through this checklist, you might notice something and then get distracted when fixing it. You need to restart the checklist again after your fixes. You might therefore want to keep a running list of changes to make, then make them, then start the checklist again for all of them. + +## Code Style & Formatting + +### Naming +- [ ] **Files & modules**: `snake_case.py` +- [ ] **Functions & variables**: `snake_case` +- [ ] **Classes**: `PascalCase` +- [ ] **Constants**: `UPPERCASE_WITH_UNDERSCORES` +- [ ] **Private functions**: Prefixed with `_` +- [ ] **Abbreviations**: Uppercase (e.g., `CI`, `L0`, `KL`) + +### Type Annotations +- [ ] **Used jaxtyping for tensors** - `Float[Tensor, "... C d_in"]` format (runtime checking not yet enabled) +- [ ] **Used PEP 604 unions** - `str | None` NOT `Optional[str]` +- [ ] **Used lowercase generics** - `dict`, `list`, `tuple` NOT `Dict`, `List`, `Tuple` +- [ ] **Avoided redundant annotations** - Don't write `my_thing: Thing = Thing()` or `name: str = "John"` +- [ ] **Type checking passes with no errors** - Run `make type` successfully and fix all issues (uses basedpyright, NOT mypy) + +### Comments & Documentation +- [ ] **No obvious comments** - If code is self-explanatory, no comment needed. 
(Temp comments during development are fine if cleaned up before committing) +- [ ] **Complex logic explained** - Comments focus on "why" not "what" +- [ ] **Google-style docstrings** - Used `Args:`, `Returns:`, `Raises:` sections where needed +- [ ] **Non-obvious information only** - Docstrings don't repeat what's obvious from signature + +### Formatting +- [ ] **Ruff formatting applied** - Run `make format` before committing + +## Code Quality + +### Error Handling (Fail Fast) +- [ ] **Liberal assertions** - Assert all assumptions about data/state +- [ ] **Clear error messages** - Assertions include descriptive messages +- [ ] **Explicit error types** - Use `ValueError`, `NotImplementedError`, `RuntimeError` appropriately +- [ ] **Fail immediately** - Code fails when wrong, doesn't recover silently +- [ ] **Use try-except only for expected errors** - Assertions for invariants/assumptions. Try-except only when errors are expected and handled (e.g., path resolution, file not found) + +### Tensor Operations +- [ ] **Used einops by default** - Preferred over raw einsum for clarity +- [ ] **Asserted shapes liberally** - Verify tensor dimensions +- [ ] **Documented complex operations** - Explain non-obvious tensor manipulations + +### Design Patterns +- [ ] **Followed existing patterns** - Match architecture style of surrounding code (ABC for interfaces, Protocol for metrics, Pydantic for configs) +- [ ] **Metrics decoupled** - Each metric in its own file within `spd/metrics/` directory. 
Figures in `spd/figures.py` +- [ ] **Used Pydantic for configs** - Configs are frozen (`frozen=True`) and forbid extras (`extra="forbid"`) +- [ ] **Config paths handled correctly** - If handling paths in configs, support both relative paths and `wandb:` prefix format +- [ ] **New experiments registered** - If adding new experiment, added to `spd/registry.py` with proper structure +- [ ] **Experiment structure followed** - Experiments have `models.py`, `configs.py`, `{task}_decomposition.py` in flat structure + +## Testing + +- [ ] **Tests written** - Unit tests for new functionality. Regression tests for bug fixes. +- [ ] **Tests run successfully** - Run `make test` (or `make test-all` if relevant) +- [ ] **Test files named correctly** - `test_*.py` format +- [ ] **Test functions named correctly** - `def test_*():` with descriptive names +- [ ] **Slow tests marked** - Used `@pytest.mark.slow` for slow tests +- [ ] **Focus on unit tests** - Not production code (no deployment). Integration tests often too much overhead for research code. Interactive use catches issues at low cost. Add integration tests only if testing complex interactions that can't be validated in units. 
+ +## Pre-Commit Checks + +- [ ] **Ran `make check`** - Full pre-commit suite passes (format + type check) +- [ ] **No type errors** - basedpyright reports no issues +- [ ] **No lint errors** - ruff reports no issues + +## Git & Version Control + +### Before Committing +- [ ] **Checked existing patterns** - If adding new files (docs, configs, tests, etc.), looked at similar existing files for formatting/structure conventions to follow +- [ ] **Reviewed every line of the diff** - Understand every change being committed +- [ ] **Only relevant files staged** - Don't commit unrelated changes or all files +- [ ] **No secrets committed** - No `.env`, `credentials.json`, or similar files +- [ ] **Used correct branch name** - Format: `refactor/X`, `feature/Y`, or `fix/Z` + +### Commit Message +- [ ] **Explains "what" and "why"** - Not just describing the diff +- [ ] **Clear and descriptive** - Focused on relevant changes +- [ ] **Explains purpose** - Why this change is needed + +### Committing +- [ ] **NOT using `--no-verify`** - Almost never appropriate. Pre-commit checks exist for a reason. 
+- [ ] **Pre-commit hooks run** - Automatically runs `make format` and `make type` +- [ ] **All hooks passed** - No failures from pre-commit checks + +## Pull Request (if creating) + +### PR Content +- [ ] **Analyzed all changes** - Reviewed git diff and git status before creating PR +- [ ] **Title is clear** - Concise summary of changes +- [ ] **Used PR template** - Filled out all sections in `.github/pull_request_template.md`: + - Description - What changed + - Related Issue - "Closes #XX" format if applicable + - Motivation and Context - Why needed + - Testing - How tested + - Breaking Changes - Listed if any + +### PR Quality +- [ ] **All CI checks pass** - GitHub Actions successful +- [ ] **Merged latest from main** - Branch is up to date +- [ ] **Only relevant files** - No unrelated changes included +- [ ] **Self-reviewed** - Went through diff yourself first + +## Cluster Usage (if applicable) + +If running experiments on the cluster: +- [ ] **NOT exceeding 8 GPUs total** - Including all sweeps/evals combined +- [ ] **Monitored jobs** - Used `squeue` to check current usage +- [ ] **Used appropriate resources** - GPU vs CPU flags set correctly + +## Final Self-Review + +- [ ] **Restarted checklist after any changes** - If you made ANY changes while going through this checklist, you MUST restart from the beginning. Did you restart? If not, STOP and restart now. 
+- [ ] **Code is simple** - Straightforward for researchers with varying experience +- [ ] **No over-engineering** - Only made changes directly requested or clearly necessary +- [ ] **No unnecessary features** - Didn't add extra functionality beyond the task +- [ ] **No premature abstraction** - Didn't create helpers/utilities for one-time operations +- [ ] **No backwards-compatibility hacks** - Removed unused code completely instead of commenting +- [ ] **Followed fail-fast principle** - Code fails immediately when assumptions violated +- [ ] **Type safety maintained** - All functions properly typed +- [ ] **Tests are sufficient** - Core functionality tested, not over-tested + +## Common Mistakes to Avoid + +- ❌ Forgetting to remove obvious comments like `# get dataloader` +- ❌ Committing without running `make check` +- ❌ Using `--no-verify` flag +- ❌ Recovering silently from errors instead of failing +- ❌ Adding type annotations to obvious assignments like `name: str = "John"` +- ❌ Committing all files instead of only relevant changes +- ❌ Using more than 8 GPUs on cluster (total across all jobs) +- ❌ Failing to consult CLAUDE_COMPREHENSIVE.md for clarification in cases where the checklist is unclear. +- ❌ Starting this checklist, noticing an issue, fixing it, and then forgetting to start the checklist **from the start** again. diff --git a/CLAUDE_COMPREHENSIVE.md b/CLAUDE_COMPREHENSIVE.md new file mode 100644 index 000000000..cc97df976 --- /dev/null +++ b/CLAUDE_COMPREHENSIVE.md @@ -0,0 +1,669 @@ +# CLAUDE_COMPREHENSIVE.md - Complete Development Guide for SPD + +This guide covers everything needed to understand, develop, and contribute to the SPD (Stochastic Parameter Decomposition) codebase. + +## 1. Introduction + +For AI assistants and developers. 
Covers: +- Environment setup and project structure +- Development philosophy and coding standards +- Architecture patterns and design principles +- Common workflows and usage patterns +- Testing, deployment, and collaboration practices + +### How to Use This Guide + +**Two Documents:** +- **CLAUDE_COMPREHENSIVE.md** (this document) - Complete reference for understanding the codebase, architecture, and development practices. Read this to learn how the project works. +- **CLAUDE_CHECKLIST.md** - Pre-submission checklist for verifying your code changes meet SPD standards. Use this before committing to ensure your work follows all conventions. + +**Workflow:** Read the comprehensive guide to understand context, then use the checklist to verify your changes before submission. + +## 2. Environment Setup & Quick Start + +**IMPORTANT**: Always activate the virtual environment before running Python or git operations: +```bash +source .venv/bin/activate +``` + +**Installation:** +```bash +make install-dev # Install with dev dependencies and pre-commit hooks +make install # Install package only (pip install -e .) +``` + +**Environment:** +- `.env` file with WandB credentials (see `.env.example`) +- WandB for experiment tracking and model storage +- Runs generate timestamped output directories (configs, models, plots) + +## 3. Project Overview + +SPD is a research framework for analyzing neural network components through sparse parameter decomposition. 
Supports experimental domains: +- **TMS** (Toy Model of Superposition) +- **ResidualMLP** (residual MLP analysis) +- **Language Models** +- **Identity Insertion** + +### Available Experiments + +Defined in `spd/registry.py`: + +- `tms_5-2`, `tms_5-2-id` - TMS: 5 features, 2 hidden dims (id = fixed identity in-between) +- `tms_40-10`, `tms_40-10-id` - TMS: 40 features, 10 hidden dims +- `resid_mlp1`, `resid_mlp2`, `resid_mlp3` - ResidualMLP: 1-3 layers +- `ss_emb` - Language models (from HuggingFace) + +### Research Papers + +**Stochastic Parameter Decomposition (SPD)** +- [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) +- Introduces core SPD framework, stochastic masking, and optimization techniques +- Note: Development has continued beyond the paper implementation + +**Attribution-based Parameter Decomposition (APD)** +- [`papers/Attribution_based_Parameter_Decomposition/apd_paper.md`](papers/Attribution_based_Parameter_Decomposition/apd_paper.md) +- Precursor to SPD, first linear parameter decomposition +- High-level conceptual insights and theoretical foundations + +### Key Data Flow + +1. Experiments load pretrained target models via WandB or local paths +2. Target models are wrapped in ComponentModel with specified target modules +3. SPD optimization runs via `spd.run_spd.optimize()` with config-driven loss combination +4. Results include component masks, causal importance scores, and visualizations + +### Component Analysis + +- Components = sparse decompositions of model parameters +- Stochastic masking enables differentiable sparsity +- Causal importance quantifies contributions +- Loss terms balance faithfulness, reconstruction, sparsity + +## 4. Development Philosophy & Principles + +### Core Principles (TLDR) + +1. **Simplicity First** - Code for researchers with varying experience. Prioritize simple, straightforward code. + +2. 
**Type Safety** - Use types, einops, jaxtyping, liberal assertions, Pydantic validation, strict pyright. + +3. **Fail Fast** - Code fails immediately when wrong, not silently. Liberal assertions, clear errors, explicit types. + +4. **Documentation** - Comments for complex logic only. Skip obvious comments. + +5. **Modularity** - Registry pattern, abstract interfaces, protocols. Decouple metrics from core. + +6. **Reproducibility** - Centralized configs, seed management, WandB tracking. + +7. **Performance** - Distributed training, parallel testing, optimized CI/CD. + +8. **Maintainability** - Consistent naming, clear architecture, comprehensive tooling. + +## 5. Development Workflow & Commands + +**Package Manager:** uv (NOT pip/poetry) + +### Make Targets + +```bash +make install # Install package only +make install-dev # Install with dev deps and pre-commit hooks +make check # Run full pre-commit suite (format + type check) +make format # Ruff lint + format +make type # BasedPyright type checking +make test # Run tests (excluding slow tests) +make test-all # Run all tests including slow ones +make coverage # Generate coverage reports +``` + +### Pre-commit Hooks + +Automatically run `make format` and `make type` before commits (install with `make install-dev`) + +### CI/CD Pipeline (GitHub Actions) + +1. Checkout code +2. Set up Python 3.13 via uv +3. Install dependencies with CPU-only PyTorch +4. Run basedpyright type checking +5. Run ruff lint and format +6. Run pytest with parallel execution (max 4 workers) + +**Special CI install:** +```bash +make install-ci # Uses CPU wheels, unsafe-best-match index strategy +``` + +## 6. 
Code Style & Formatting + +### Naming Conventions + +- **Files & modules**: `snake_case.py` (e.g., `component_model.py`) +- **Functions & variables**: `snake_case` (e.g., `create_data_loader()`) +- **Classes**: `PascalCase` (e.g., `ComponentModel`) +- **Constants**: `UPPERCASE_WITH_UNDERSCORES` (e.g., `REPO_ROOT`) +- **Private functions**: Prefix with underscore (e.g., `_infer_backend()`) +- **Abbreviations**: Uppercase in variables (e.g., `CI`, `L0`, `KL`) + +### Formatting Rules + +- **Line length**: 100 characters (strict, enforced by ruff) +- **Formatter**: ruff (configured in pyproject.toml) +- **Import organization**: stdlib → third-party → local +- **Import sorting**: Handled by ruff/isort + +**Ruff Configuration:** +- Enabled rules: pycodestyle (E), Pyflakes (F), pyupgrade (UP), flake8-bugbear (B), flake8-simplify (SIM), isort (I) +- Ignored: F722 (jaxtyping incompatibility), E731 (lambda functions allowed), E501 (long lines) + +## 7. Type Annotations + +### Core Principles + +- Use **jaxtyping** for tensor shapes: `Float[Tensor, "... C d_in"]` (runtime checking not yet enabled) +- Use **PEP 604 union syntax**: `str | None` (NOT `Optional[str]`) +- Use **lowercase generic types**: `dict`, `list`, `tuple` (NOT `Dict`, `List`, `Tuple`) +- **Don't annotate when redundant**: `my_thing = Thing()` not `my_thing: Thing = Thing()`, or `name = "John"` not `name: str = "John"` + +### Examples + +```python +# Good - jaxtyping with explicit dimensions +def forward(self, x: Float[Tensor, "... C d_in"]) -> Float[Tensor, "... C d_out"]: + return einops.einsum(x, self.W, "... C d_in, C d_in d_out -> ... 
C d_out") + self.b + +# Good - PEP 604 union syntax +def load_model(path: str | Path) -> Model | None: + pass + +# Bad - old style +from typing import Optional, Dict +def load_model(path: Optional[str]) -> Dict[str, Any]: + pass +``` + +### Type Checking + +- Uses **basedpyright** (NOT mypy) - forked pyright for better performance +- Strict mode enabled: `strictListInference`, `strictDictionaryInference`, `strictSetInference` +- Reports: `MissingTypeArgument`, `UnknownParameterType`, `IncompatibleMethodOverride`, `ImportCycles` +- Excluded: `wandb` directory, third-party code, frontend +- Run with `make type` + +## 8. Documentation & Comments + +### Philosophy: Don't Write Obvious Comments + +Your first instinct should be: **"If I couldn't write any comments, how would I write this code?"** + +If code is self-explanatory, skip the comment. Only comment to explain complex logic, focusing on **"why" not "what"**. + +If you find it helps you develop, you can write whatever comments you like when developing, so long as you remember to come back and fix them later. + +### Bad (Obvious): +```python +# get dataloader +dataloader = get_dataloader(config) +``` + +### Good (Explains Complex Logic): +```python +# We need to mask out future positions for causal attention +# Upper triangular matrix excludes the diagonal (hence k=1) +causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) +``` + +### Docstring Format + +Use **Google-style** with `Args:`, `Returns:`, `Raises:` sections. Single-line for simple functions, multi-line for complex. Focus on non-obvious information. + +```python +def tokenize_and_concatenate(dataset: Dataset, tokenizer: PreTrainedTokenizer, ...) -> Dataset: + """Tokenize and concatenate a dataset of text. + + Args: + dataset: HuggingFace dataset to tokenize + tokenizer: Pretrained tokenizer to use + ... + + Returns: + Tokenized and concatenated dataset + """ +``` + +## 9. 
Architecture & Design Patterns + +### Core Pattern: Wrapper + Registry + Config + +1. **ComponentModel**: Wraps PyTorch models and injects components +2. **Registry** (`registry.py`): Centralized experiment configuration +3. **Config System** (Pydantic): Type-safe config loading/validation + +### Design Principle: Decouple Metrics from Core + +Metric and figure code encapsulated in `spd/metrics.py` and `spd/figures.py`. + +### Key Design Patterns + +**1. Abstract Base Classes for Interfaces** +```python +class LoadableModule(nn.Module, ABC): + @classmethod + @abstractmethod + def from_pretrained(cls, _path: ModelPath) -> "LoadableModule": + raise NotImplementedError("Subclasses must implement from_pretrained method.") +``` + +**2. Protocol-Based Design** +```python +class Metric(Protocol): + slow: ClassVar[bool] = False + metric_section: ClassVar[str] + + def update(...) -> None: ... + def compute(self) -> Any: ... +``` + +**3. Dataclass-Based Configuration** +```python +@dataclass +class ExperimentConfig: + task_name: TaskName + decomp_script: Path + config_path: Path + expected_runtime: int + canonical_run: str | None = None +``` + +**4. Pydantic for Runtime Validation** +```python +class BaseConfig(BaseModel): + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + + @classmethod + def from_file(cls, path: Path | str) -> Self: + """Load config from path to a JSON or YAML file.""" +``` + +### Core Architecture Components + +- `spd/run_spd.py` - Main SPD optimization logic +- `spd/configs.py` - Pydantic config classes +- `spd/registry.py` - Centralized experiment registry +- `spd/models/component_model.py` - ComponentModel wrapper +- `spd/models/components.py` - Component types (Linear, Embedding, etc.) +- `spd/losses.py` - Loss functions (faithfulness, reconstruction, importance minimality) +- `spd/metrics.py` - Metrics (CI-L0, KL divergence, etc.) +- `spd/figures.py` - Figures (CI histograms, Identity plots, etc.) + +## 10. 
Project Structure + +``` +spd/ +├── spd/ # Main package +│ ├── models/ # Core model classes +│ ├── metrics/ # Metric implementations +│ ├── utils/ # Utilities (distributed, logging, data) +│ ├── experiments/ # Experiment implementations +│ │ ├── tms/ # Toy Model of Superposition +│ │ ├── resid_mlp/ # Residual MLP +│ │ ├── lm/ # Language models +│ │ └── ih/ # Identity insertion +│ ├── app/ # Streamlit application +│ │ ├── backend/ +│ │ └── frontend/ +│ ├── scripts/ # CLI entry points +│ └── [core modules] +├── tests/ # Test suite +│ ├── metrics/ # Metric tests +│ ├── scripts_run/ # Integration tests +│ └── [unit tests] +├── papers/ # Research papers (markdown) +├── typings/ # Type stubs +└── [configuration files] +``` + +### Organizational Principles + +- **Flat within experiments**: Each has `models.py`, `configs.py`, `{task}_decomposition.py`, `train_*.py`, `*_config.yaml`, `plotting.py` +- **Centralized registry**: `spd/registry.py` manages experiment configs +- **Clear separation**: Core logic vs metrics vs experiments +- **Modular metrics**: Each metric in its own file + +## 11. Configuration System + +### Multi-layered Configuration + +1. **YAML config files** define experiment parameters +2. **Pydantic config classes** provide type safety and validation +3. **Environment variables** can override runtime settings +4. **Nested config objects** for task-specific configs + +### Key Conventions + +- Paths: relative to repo root or `"wandb:"` prefix for WandB paths +- Configs **immutable** (`frozen=True`) and **forbid extra fields** (`extra="forbid"`) +- `ModelPath` type validates and normalizes paths automatically +- Pydantic validators handle deprecated keys and path resolution + +### Example Config + +```yaml +wandb_project: spd +seed: 0 +C: 1200 +n_mask_samples: 1 +ci_fn_type: "shared_mlp" +ci_fn_hidden_dims: [1000] +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 0.004 + pnorm: 2.0 +``` + + +## 12. 
Error Handling & Fail Fast + +### Fail-Fast Philosophy (Negative Space Programming) + +Code should fail immediately when assumptions are violated, preventing bugs from propagating. + +### Assertions + +**If there's an assumption you're making while writing code, assert it:** +- If you were right, then it won't matter. If you were wrong, then the code **should** fail + +```python +assert component_params, "component_params is empty" +assert x.shape[-1] == 1, "Last dimension should be 1 after the final layer" +assert cfg.coeff is not None, "All loss metric configs must have a coeff" +``` + +### Explicit Error Types + +```python +raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") +raise NotImplementedError("Subclasses must implement from_pretrained method.") +raise RuntimeError("Embedding modules not supported for identity insertion") +``` + +### Try-Except for Expected Errors + +```python +try: + return path.relative_to(REPO_ROOT) +except ValueError: + # If the path is not relative to REPO_ROOT, return the original path + return path +``` + +## 13. Tensor Operations + +### Use Einops for Clarity + +- Try to use **einops** by default for clarity over raw einsum +- **Assert shapes liberally** +- **Document complex tensor manipulations** + +**Example:** +```python +# Preferred - clear dimensions +result = einops.einsum(x, self.W, "... C d_in, C d_in d_out -> ... C d_out") + self.b + +# Also good - assert shapes +assert x.shape[-1] == d_in, f"Expected last dim to be {d_in}, got {x.shape[-1]}" +``` + +## 14. Testing Strategy + +### Testing Philosophy + +Tests ensure code works as expected, not for production (no deployment). Focus on unit tests for core functionality. Don't worry about integration/end-to-end tests - too much overhead for research code. Interactive use catches issues at low cost. 
+ +**Framework:** pytest with pytest-xdist for parallel execution + +### Test Organization + +- **Test files**: `test_*.py` +- **Test functions**: `def test_*():` with descriptive names +- **Tests mirror source structure**: `tests/metrics/`, `tests/scripts_run/` +- **Fixtures centralized** in `conftest.py` and `metrics/fixtures.py` + +### Test Markers + +- `@pytest.mark.slow` - Excluded by default, run with `make test-all` +- `@pytest.mark.requires_wandb` - Tests requiring WandB access + +## 15. Logging + +Use `spd.log.logger` with special methods: `.info()`, `.warning()`, `.error()` (standard), `.values()` (dict of metrics), `.section()` (visual separator), `.set_format()` (swap formatter). + +```python +from spd.log import logger +logger.values({"loss": 0.42}, msg="Training metrics") +logger.section("Evaluation Phase") +``` + +**Config:** Console (INFO), File (WARNING → `logs/logs.log`), named "spd" + +## 16. Common Usage Patterns + +### Running SPD Experiments + +Use `spd-run` command: + +```bash +spd-run --experiments tms_5-2 # Specific experiment +spd-run --experiments tms_5-2,resid_mlp1 # Multiple experiments +spd-run # All experiments +``` + +Or run directly: +```bash +uv run spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_5-2_config.yaml +``` + +Outputs: losses and figure paths for analysis. + +### Metrics and Figures + +Defined in `spd/metrics/` and `spd/figures.py` as dictionaries of functions. Select and parameterize in experiment configs for easy extension without modifying core framework. 
+ +### Running Sweeps + +Run hyperparameter sweeps using WandB on the GPU cluster: + +```bash +spd-run --experiments <experiments> --sweep --n-agents <n_agents> [--cpu] [--job_suffix <suffix>] +``` + +**Examples:** +```bash +spd-run --experiments tms_5-2 --sweep --n-agents 4 # Run TMS 5-2 sweep with 4 GPU agents +spd-run --experiments resid_mlp2 --sweep --n-agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents +spd-run --sweep --n-agents 10 # Sweep all experiments with 10 agents +spd-run --experiments tms_5-2 --sweep custom.yaml --n-agents 2 # Use custom sweep params file +``` + +**How it works:** Creates WandB sweep from `spd/scripts/sweep_params.yaml` (or custom), deploys SLURM agents (GPU by default, `--cpu` for CPU), git snapshot for consistency. + +**Sweep parameters:** Load from `sweep_params.yaml` or custom file. Supports global and experiment-specific configs: + +```yaml +# Global parameters applied to all experiments +global: + seed: + values: [0, 1, 2] + lr: + values: [0.001, 0.01] + +# Experiment-specific parameters (override global) +tms_5-2: + seed: + values: [100, 200] # Overrides global seed + task_config: + feature_probability: + values: [0.05, 0.1] +``` + +**Logs:** Agent logs are found in `~/slurm_logs/slurm-<job_id>_<task_id>.out` + +### Evaluation Runs + +Run with default hyperparameters: + +```bash +spd-run # All experiments +spd-run --experiments tms_5-2-id,resid_mlp2,resid_mlp3 # Specific experiments +``` + +Multiple experiments without `--sweep` creates W&B report with aggregated visualizations. + +### Additional Options + +```bash +spd-run --project my-project # Use custom W&B project +spd-run --job_suffix test # Add suffix to SLURM job names +spd-run --no-create_report # Skip W&B report creation +``` + +### Cluster Usage Guidelines + +**IMPORTANT:** +- **DO NOT use more than 8 GPUs at one time** +- This includes not setting off multiple sweeps/evals that total >8 GPUs +- Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` + +## 17. 
Distributed Training + +### DistributedState Management + +```python +@dataclass(frozen=True, slots=True) +class DistributedState: + rank: int + world_size: int + local_rank: int + backend: Literal["nccl", "gloo"] +``` + +### Conventions + +- **MPI-based** rank initialization +- **NCCL backend** for GPU, **gloo** for CPU +- Utilities in `spd/utils/distributed.py`: gradient sync, metric averaging, device detection +- `torch.nn.parallel.DistributedDataParallel` for multi-GPU training + +## 18. Git & Pull Request Workflow + +### Branch Naming + +- `refactor/X` - Refactoring work +- `feature/Y` - New features +- `fix/Z` - Bug fixes + +### Using GitHub CLI + +- To view issues and PRs: `gh issue view 28` or `gh pr view 30` +- Use the PR template defined in `.github/pull_request_template.md` +- Important: You should almost never use --no-verify. The pre-commit checks are there for a reason. + +### PR Checklist + +- Review every line of the diff +- All CI checks pass +- Merge latest changes from main branch +- Use "Closes #XX" format for issue linking +- Only commit files that include relevant changes, don't commit all files + +### Commit Messages + +Explain "what" and "why". Clear, descriptive, focused on relevant changes. Explain purpose, not just the diff. + +### PR Template Sections + +1. Description - What changed +2. Related Issue - Use "Closes #XX" format +3. Motivation and Context - Why needed +4. Testing - How tested +5. Breaking Changes + +## 19. 
Key Dependencies & Tools + +### Core Stack + +- **PyTorch** (>=2.6) +- **Transformers** - HuggingFace models and tokenizers +- **WandB** (>=0.20.1) - Optional, disable with `wandb_project=None` +- **Pydantic** (<2.12) +- **jaxtyping** - Type annotations for tensors +- **einops** - Tensor operations (preferred over einsum) +- **Fire** - CLI argument parsing + +### Development Tooling + +- **ruff** - Linter and formatter (NOT black + flake8 + isort) +- **basedpyright** - Type checker (NOT mypy) +- **pytest + pytest-xdist** - Testing with parallelization +- **uv** - Package manager (NOT pip/poetry) +- **pre-commit** - Git hooks + +### Additional Libraries + +- **datasets** (>=2.21.0) - HuggingFace data loading +- **streamlit** - Web UI +- **python-dotenv** - Environment variables +- **torchvision** (>=0.23,<0.24) + +## 20. Quick Reference + +### Key Principles Summary + +1. **Simplicity** - Code for researchers with varying experience +2. **Type Safety** - jaxtyping, Pydantic, strict basedpyright +3. **Fail Fast** - Liberal assertions, explicit errors +4. **Minimal Comments** - Complex logic only +5. **Modularity** - Registry pattern, interfaces, protocols +6. **Decouple Metrics** - Separate from core +7. **Reproducibility** - Centralized configs, seeds, WandB +8. **Research Testing** - Unit tests, minimal integration +9. **Clear Architecture** - Wrapper + Registry + Config +10. 
**Consistent Style** - 100 char, snake_case, PEP 604 + +### Common Commands Cheatsheet + +```bash +# Setup +source .venv/bin/activate +make install-dev + +# Development +make check # Format + type check +make format # Ruff lint and format +make type # Type check only +make test # Run tests (fast) +make test-all # Run all tests + +# Running experiments +spd-run --experiments tms_5-2 +spd-run --experiments tms_5-2 --sweep --n-agents 4 + +# Git/GitHub +gh issue view 28 +gh pr view 30 +git checkout -b feature/my-feature + +# Monitoring cluster +squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me +``` + +### File Locations Reference + +- **Core SPD**: `spd/run_spd.py`, `spd/configs.py`, `spd/registry.py` +- **Models**: `spd/models/component_model.py`, `spd/models/components.py` +- **Metrics**: `spd/metrics.py`, `spd/figures.py` +- **Experiments**: `spd/experiments/{tms,resid_mlp,lm,ih}/` +- **Tests**: `tests/`, `tests/metrics/`, `tests/scripts_run/` +- **Configs**: `spd/experiments/*/\*_config.yaml` +- **Papers**: `papers/Stochastic_Parameter_Decomposition/`, `papers/Attribution_based_Parameter_Decomposition/` diff --git a/spd/configs.py b/spd/configs.py index 33de31af2..7ab3f9d71 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -91,6 +91,10 @@ class PGDReconLayerwiseLossConfig(PGDConfig): classname: Literal["PGDReconLayerwiseLoss"] = "PGDReconLayerwiseLoss" +class PGDCEDiffConfig(PGDConfig): + classname: Literal["PGDCEDiff"] = "PGDCEDiff" + + class PGDMultiBatchConfig(LossMetricConfig): init: PGDInitStrategy step_size: float @@ -194,6 +198,7 @@ class UVPlotsConfig(BaseConfig): | PermutedCIPlotsConfig | UVPlotsConfig | StochasticReconSubsetCEAndKLConfig + | PGDCEDiffConfig | PGDMultiBatchReconLossConfig | PGDMultiBatchReconSubsetLossConfig ) diff --git a/spd/eval.py b/spd/eval.py index e7dcf8210..1d7591361 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -24,6 +24,7 @@ ImportanceMinimalityLossConfig, MetricConfigType, PermutedCIPlotsConfig, + 
PGDCEDiffConfig, PGDMultiBatchReconLossConfig, PGDMultiBatchReconSubsetLossConfig, PGDReconLayerwiseLossConfig, @@ -49,6 +50,7 @@ from spd.metrics.identity_ci_error import IdentityCIError from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss from spd.metrics.permuted_ci_plots import PermutedCIPlots +from spd.metrics.pgd_ce_diff import PGDCEDiff from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss from spd.metrics.pgd_masked_recon_loss import PGDReconLoss from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss @@ -233,6 +235,13 @@ def init_metric( output_loss_type=run_config.output_loss_type, pgd_config=cfg, ) + case PGDCEDiffConfig(): + metric = PGDCEDiff( + model=model, + device=device, + use_delta_component=run_config.use_delta_component, + pgd_config=cfg, + ) case StochasticReconSubsetCEAndKLConfig(): metric = StochasticReconSubsetCEAndKL( model=model, diff --git a/spd/metrics/__init__.py b/spd/metrics/__init__.py index c77578ed0..600603472 100644 --- a/spd/metrics/__init__.py +++ b/spd/metrics/__init__.py @@ -18,6 +18,8 @@ from .importance_minimality_loss import ImportanceMinimalityLoss as ImportanceMinimalityLoss from .importance_minimality_loss import importance_minimality_loss as importance_minimality_loss from .permuted_ci_plots import PermutedCIPlots as PermutedCIPlots +from .pgd_ce_diff import PGDCEDiff as PGDCEDiff +from .pgd_ce_diff import pgd_ce_diff_loss_update as pgd_ce_diff_loss_update from .pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss as PGDReconLayerwiseLoss from .pgd_masked_recon_layerwise_loss import ( pgd_recon_layerwise_loss as pgd_recon_layerwise_loss, diff --git a/spd/metrics/pgd_ce_diff.py b/spd/metrics/pgd_ce_diff.py new file mode 100644 index 000000000..b59f483f9 --- /dev/null +++ b/spd/metrics/pgd_ce_diff.py @@ -0,0 +1,218 @@ +from functools import partial +from typing import Any, ClassVar, Literal, override + +import einops +import torch +import 
torch.nn.functional as F +from jaxtyping import Float, Int +from torch import Tensor +from torch.distributed import ReduceOp + +from spd.configs import PGDConfig +from spd.metrics.base import Metric +from spd.models.component_model import CIOutputs, ComponentModel +from spd.models.components import RoutingMasks, make_mask_infos +from spd.utils.component_utils import RoutingType, sample_uniform_k_subset_routing_masks +from spd.utils.distributed_utils import all_reduce + + +def pgd_ce_diff_loss_update( + model: ComponentModel, + batch: Int[Tensor, "..."], + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + target_out: Float[Tensor, "... vocab"], + routing: RoutingType, + pgd_config: PGDConfig, +) -> tuple[Float[Tensor, ""], int]: + """PGD optimization for CE difference metric. + + Optimizes adversarial stochastic masks to maximize CE loss against true labels. + Only applicable for language model tasks with 3D outputs (batch, seq_len, vocab). 
+ """ + if target_out.ndim != 3: + return torch.tensor(0.0, device=batch.device), 0 + + assert batch.ndim == 2, "Batch must be 2D (batch, seq_len)" + + C = model.C + batch_dims = next(iter(ci.values())).shape[:-1] + n_layers = len(ci) + C2 = C if weight_deltas is None else C + 1 + + masked_batch = batch.clone() + masked_batch[:, 0] = -100 + flat_masked_batch = masked_batch.flatten() + + match routing: + case "all": + routing_masks = "all" + case "uniform_k-stochastic": + routing_masks = sample_uniform_k_subset_routing_masks( + mask_shape=batch_dims, + module_names=model.target_module_paths, + device=batch.device, + ) + + match pgd_config.mask_scope: + case "unique_per_datapoint": + adv_source_shape = torch.Size([n_layers, *batch_dims, C2]) + case "shared_across_batch": + singleton_batch_dims = [1 for _ in batch_dims] + adv_source_shape = torch.Size([n_layers, *singleton_batch_dims, C2]) + + adv_sources: Float[Tensor, "n_layers *batch_dims C2"] | Float[Tensor, "n_layers *1 C2"] = ( + _get_pgd_init_tensor(pgd_config.init, adv_source_shape, batch.device).requires_grad_(True) + ) + + fwd_pass = partial( + _forward_with_adv_sources_ce, + model=model, + batch=batch, + adv_sources=adv_sources, + ci=ci, + weight_deltas=weight_deltas, + routing_masks=routing_masks, + flat_masked_batch=flat_masked_batch, + batch_dims=batch_dims, + ) + + for _ in range(pgd_config.n_steps): + assert adv_sources.grad is None + with torch.enable_grad(): + ce_loss = fwd_pass() + (adv_sources_grads,) = torch.autograd.grad(ce_loss, adv_sources) + adv_sources_grads = all_reduce(adv_sources_grads, op=ReduceOp.SUM) + with torch.no_grad(): + adv_sources.add_(pgd_config.step_size * adv_sources_grads.sign()) + adv_sources.clamp_(0.0, 1.0) + + final_ce_loss = fwd_pass() + + flat_target_logits = einops.rearrange(target_out, "b seq_len vocab -> (b seq_len) vocab") + target_ce_loss = F.cross_entropy( + flat_target_logits[:-1], flat_masked_batch[1:], ignore_index=-100, reduction="sum" + ) + + n_positions = 
batch.shape[0] * batch.shape[1] + ce_diff = final_ce_loss - target_ce_loss + + return ce_diff, n_positions + + +def _forward_with_adv_sources_ce( + model: ComponentModel, + batch: Int[Tensor, "..."], + adv_sources: Float[Tensor, "n_layers *batch_dim_or_ones C2"], + ci: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + routing_masks: RoutingMasks, + flat_masked_batch: Int[Tensor, "..."], + batch_dims: tuple[int, ...], +) -> Float[Tensor, ""]: + expanded_adv_sources = adv_sources.expand(-1, *batch_dims, -1) + adv_sources_components: Float[Tensor, "n_layers *batch_dims C"] + match weight_deltas: + case None: + weight_deltas_and_masks = None + adv_sources_components = expanded_adv_sources + case dict(): + weight_deltas_and_masks = { + k: (weight_deltas[k], expanded_adv_sources[i, ..., -1]) + for i, k in enumerate(weight_deltas) + } + adv_sources_components = expanded_adv_sources[..., :-1] + + mask_infos = make_mask_infos( + component_masks=_interpolate_component_mask(ci, adv_sources_components), + weight_deltas_and_masks=weight_deltas_and_masks, + routing_masks=routing_masks, + ) + out = model(batch, mask_infos=mask_infos) + + flat_logits = einops.rearrange(out, "b seq_len vocab -> (b seq_len) vocab") + ce_loss = F.cross_entropy( + flat_logits[:-1], flat_masked_batch[1:], ignore_index=-100, reduction="sum" + ) + + return ce_loss + + +def _get_pgd_init_tensor( + init: Literal["random", "ones", "zeroes"], + shape: tuple[int, ...], + device: torch.device | str, +) -> Float[Tensor, "... 
shape"]: + match init: + case "random": + return torch.rand(shape, device=device) + case "ones": + return torch.ones(shape, device=device) + case "zeroes": + return torch.zeros(shape, device=device) + + +def _interpolate_component_mask( + ci: dict[str, Float[Tensor, "*batch_dims C"]], + adv_sources_components: Float[Tensor, "n_layers *batch_dims C"], +) -> dict[str, Float[Tensor, "*batch_dims C"]]: + """Set the mask value to ci + (1 - ci) * adv_sources_components.""" + assert torch.all(adv_sources_components <= 1.0) and torch.all(adv_sources_components >= 0.0) + assert adv_sources_components.shape[0] == len(ci) + assert all(ci[k].shape[-1] == adv_sources_components.shape[-1] for k in ci) + component_masks: dict[str, Float[Tensor, "*batch_dims C"]] = {} + for i, module_name in enumerate(ci): + scaled_noise_to_add = (1 - ci[module_name]) * adv_sources_components[i] + component_masks[module_name] = ci[module_name] + scaled_noise_to_add + return component_masks + + +class PGDCEDiff(Metric): + """CE difference metric using adversarially-optimized PGD masks. + + This metric uses PGD to find masks that maximize cross-entropy loss against true labels, + then reports the CE difference from the target model. + """ + + metric_section: ClassVar[str] = "ce_kl" + + def __init__( + self, + model: ComponentModel, + device: str, + pgd_config: PGDConfig, + use_delta_component: bool, + ) -> None: + self.model = model + self.pgd_config: PGDConfig = pgd_config + self.use_delta_component: bool = use_delta_component + self.sum_ce_diff = torch.tensor(0.0, device=device) + self.n_positions = torch.tensor(0, device=device) + + @override + def update( + self, + *, + batch: Int[Tensor, "..."], + target_out: Float[Tensor, "... 
vocab"], + ci: CIOutputs, + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + **_: Any, + ) -> None: + ce_diff, n_positions = pgd_ce_diff_loss_update( + model=self.model, + batch=batch, + ci=ci.lower_leaky, + weight_deltas=weight_deltas if self.use_delta_component else None, + target_out=target_out, + routing="all", + pgd_config=self.pgd_config, + ) + self.sum_ce_diff += ce_diff + self.n_positions += n_positions + + @override + def compute(self) -> dict[str, float]: + sum_ce_diff = all_reduce(self.sum_ce_diff, op=ReduceOp.SUM) + n_positions = all_reduce(self.n_positions, op=ReduceOp.SUM) + return {"ce_difference_pgd_masked": (sum_ce_diff / n_positions).item()}