Draft
Conversation
- Add new 'mem' experiment under spd/experiments/mem/ - MemTransformer: LLaMA-style transformer with RMSNorm, RoPE, SwiGLU - MemDataset: generates fixed facts (3-token input -> 1-token label) - train_mem.py: training script with cross-entropy at final position - mem_decomposition.py: SPD decomposition script - mem_config.yaml: example decomposition config - Add 'mem' output_loss_type to compute KL only at final sequence position - Updated calc_sum_recon_loss_lm in general_utils.py - Updated all metric files to support 'mem' loss type - Register 'mem' task in spd_types.py, configs.py, and registry.py
- Generate inputs with deduplication to prevent collisions - Use numpy for efficient duplicate detection via sorting - Add validation to ensure n_facts doesn't exceed possible unique inputs - Regenerate additional inputs if needed after deduplication
Script to analyze which memorized inputs each component activates on: - Loads SPD decomposition and recreates the training dataset - Computes causal importance for all facts - Shows facts where each component has CI above threshold (default 0.1) - Provides summary statistics per component Usage: python -m spd.experiments.mem.analyze_decomposition <spd_run_path>
- Components now displayed in order of mean causal importance (descending) - Added rank number to output for easy reference - Shows mean CI in component headers
- New --output_file flag to write results to a text file - If not specified, prints to stdout as before - Progress messages still print to terminal during computation
Architecture changes: - Replace RMSNorm with LayerNorm - Replace RoPE with learned position embeddings - Replace SwiGLU (gate_proj + up_proj) with single up_proj + GELU - Add bias to linear layers (except unembed) - Keep pre-norm architecture (GPT-2 style) Config changes: - Remove block.mlp.gate_proj from target_module_patterns (no longer exists)
- New config option 'use_layer_norm' (default: True) - When False, removes all LayerNorm from: - GPTBlock (ln1, ln2) - MemTransformer (ln_f) - Useful for easier model interpretability
New section shows for each fact which components activate above threshold,
sorted by causal importance. Components are displayed as C{idx}(CI_value).
- Add expand, d_model_new, d_mlp_new fields to MemTaskConfig - Implement expand_model() in models.py to pad weights with zeros - Call expand_model in mem_decomposition.py when expand=True This allows decomposing a model in a larger parameter space without changing its actual behavior (weights are zero-padded).
- analyze_decomposition.py: per-fact component analysis - mem_config.yaml: config updates - mem_dataset.py: unique input generation - train_mem.py: training fixes - general_utils.py: mem loss type support
- Add global exception handlers for RequestValidationError, HTTPException, and Exception - Add request/response logging middleware - Replace silent JSONResponse error returns with HTTPException raises - Ensures all errors log tracebacks and are visible in server logs 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
* Allow building new graph in interventions tab * Add 'Generating' message * Prevent multiple manual graphs with the same nodes * Remove unused graph name everywhere * Address some PR review comments * More PR fixes * Remove comments
Creates per-layer scatter plots showing normalized component activation values for datapoints where CI exceeds a threshold. Components are ranked by median activation on the x-axis. Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1 Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Add second set of plots ordered by firing frequency (from harvest data)
- Transform y-values to |normalized - 0.5| for frequency plots
- Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/
- Include run_id in plot titles
- Use pre-calculated firing_counts from token_stats.pt
Co-Authored-By: Claude Opus 4.5 <[email protected]>
…nticity plots) (#343) * Add script to plot component activations vs component rank Creates per-layer scatter plots showing normalized component activation values for datapoints where CI exceeds a threshold. Components are ranked by median activation on the x-axis. Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1 Co-Authored-By: Claude Opus 4.5 <[email protected]> * Add frequency-ordered plots with abs distance from midpoint - Add second set of plots ordered by firing frequency (from harvest data) - Transform y-values to |normalized - 0.5| for frequency plots - Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/ - Include run_id in plot titles - Use pre-calculated firing_counts from token_stats.pt Co-Authored-By: Claude Opus 4.5 <[email protected]> --------- Co-authored-by: Claude Opus 4.5 <[email protected]>
…ormat - Resolve merge conflict in spd/models/component_model.py (keep methods from both branches) - Update all YAML configs to use new ci_config discriminated union format: - Old: ci_fn_type + ci_fn_hidden_dims - New: ci_config with mode, fn_type, hidden_dims - Update test_grid_search.py inline configs to include ci_config field - Includes all changes from dev/app branch (app improvements, renames, etc.) Co-Authored-By: Claude Opus 4.5 <[email protected]>
Resolve conflict in spd/models/component_model.py by keeping methods from both branches: - HEAD: _calc_layerwise_causal_importances, _calc_global_causal_importances - dev/app: get_all_component_acts Co-Authored-By: Claude Opus 4.5 <[email protected]>
Resolves merge conflicts in PromptCardHeader.svelte and ss_llama_simple_mlp-2L-wide.yaml Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Fix incorrect tensor assertions (tensor.all() <= 1.0 → (tensor <= 1.0).all()) - Fix non-deterministic layer ordering (use sorted keys) - Add config validation for CI fn_type/mode compatibility - Add checkpoint compatibility validation between CI modes - Extract _get_module_input_dim() helper to reduce duplication - Improve type annotation for global_ci_fn - Add 17 comprehensive tests for global CI functionality Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Remove broken RENAMED_CONFIG_KEYS entries (gate_type, gate_hidden_dims) - Remove redundant _validate_ci_config() - type system already enforces this - Simplify checkpoint validation error messages to be accurate - Remove redundant comments in GlobalSharedMLPCiFn and _calc_global_causal_importances - Add test for binomial sampling with global CI Co-Authored-By: Claude Opus 4.5 <[email protected]>
Remove accidentally committed large PNG output files (48 files, ~100MB). Add scripts/outputs/ to .gitignore to prevent future accidents. Note: Files remain in git history on this branch. For complete removal, run git-filter-repo on the main repo after merge. Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Add usage examples to docstring - Move script from scripts/ to spd/scripts/ - Change output directory to use Path(__file__).parent / "out" Co-Authored-By: Claude Opus 4.5 <[email protected]>
Replace dual tracking of global/layerwise CI with single unified interface: - Add wrapper classes (GlobalCiFnWrapper, LayerwiseCiFnWrapper) to components.py - Replace is_global_ci flag and separate ci_fns/global_ci_fn attributes with single ci_fn attribute in ComponentModel - Remove _calc_layerwise_causal_importances and _calc_global_causal_importances methods - logic now in wrapper forward() methods - Simplify parameter collection in run_spd.py and gradient logging in logging_utils.py - Update tests to use new wrapper-based assertions This eliminates the boolean flag anti-pattern and reduces 5 CI-related attributes to 1, making the code cleaner and more maintainable. Co-Authored-By: Claude Opus 4.5 <[email protected]>
Rename methods and variables to be explicit about layerwise vs global: - _create_ci_fn → _create_layerwise_ci_fn - _create_ci_fns → _create_layerwise_ci_fns - has_ci_fns → has_layerwise_ci_fns - ci_fns (local var) → layerwise_ci_fns Co-Authored-By: Claude Opus 4.5 <[email protected]>
For consistency with config classes and wrapper class definitions, put layerwise CI cases before global CI cases in match statements. Co-Authored-By: Claude Opus 4.5 <[email protected]>
Change from old flat format (ci_fn_type, ci_fn_hidden_dims) to new nested ci_config format. Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Update all 6 canonical runs in registry.py with new run IDs - Update clustering test to use resid_mlp2 run with new format New canonical runs: - tms_5-2: s-38e1a3e2 - tms_5-2-id: s-a1c0e9e2 - tms_40-10: s-7387fc20 - tms_40-10-id: s-2a2b5a57 - resid_mlp1: s-62fce8c4 - resid_mlp2: s-a9ad193d Co-Authored-By: Claude Opus 4.5 <[email protected]>
Streaming dataset loading only works for 'lm' tasks, not resid_mlp. Co-Authored-By: Claude Opus 4.5 <[email protected]>
Resolve conflicts: - database.py: Remove deprecated pnorm_2 field (align with main) - configs.py: Keep migration logic for pnorm_1->pnorm, add logging warning for beta default
This feature was no longer needed. Removes the expand_model function, related config fields (expand, d_model_new, d_mlp_new), and their usage in mem_decomposition.py.
Co-Authored-By: Claude Opus 4.5 <[email protected]>
…re/mem # Conflicts: # spd/spd_types.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Related Issue
Motivation and Context
How Has This Been Tested?
Does this PR introduce a breaking change?