Feature/mem #327

Draft

danbraunai-goodfire wants to merge 166 commits into main from feature/mem

Conversation

@danbraunai-goodfire
Collaborator

Description

Related Issue

Motivation and Context

How Has This Been Tested?

Does this PR introduce a breaking change?

Lucius Bushnaq and others added 30 commits December 3, 2025 13:55
- Add new 'mem' experiment under spd/experiments/mem/
  - MemTransformer: LLaMA-style transformer with RMSNorm, RoPE, SwiGLU
  - MemDataset: generates fixed facts (3-token input -> 1-token label)
  - train_mem.py: training script with cross-entropy at final position
  - mem_decomposition.py: SPD decomposition script
  - mem_config.yaml: example decomposition config

- Add 'mem' output_loss_type to compute KL only at final sequence position
  - Updated calc_sum_recon_loss_lm in general_utils.py
  - Updated all metric files to support 'mem' loss type
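
The 'mem' loss type above can be sketched as follows. This is a minimal illustration of computing the loss only at the final sequence position, not the actual `calc_sum_recon_loss_lm` implementation; the function name and tensor shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def loss_at_final_position(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy computed only at the final sequence position.

    logits: (batch, seq_len, vocab) model outputs.
    labels: (batch,) the single-token label for each fixed-fact input.
    """
    final_logits = logits[:, -1, :]  # keep only the last position: (batch, vocab)
    return F.cross_entropy(final_logits, labels)
```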

- Register 'mem' task in spd_types.py, configs.py, and registry.py
- Generate inputs with deduplication to prevent collisions
- Use numpy for efficient duplicate detection via sorting
- Add validation to ensure n_facts doesn't exceed possible unique inputs
- Regenerate additional inputs if needed after deduplication
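
A minimal sketch of the deduplication loop described above, assuming a flat table of token IDs; the function name and signature are hypothetical, not the actual MemDataset code. `np.unique(axis=0)` sorts rows lexicographically so duplicates become adjacent and are dropped:

```python
import numpy as np

def generate_unique_inputs(n_facts: int, vocab_size: int, seq_len: int = 3,
                           seed: int = 0) -> np.ndarray:
    """Generate n_facts unique token sequences, deduplicating via sorting."""
    # Validate that n_facts doesn't exceed the number of possible unique inputs.
    assert n_facts <= vocab_size ** seq_len, "n_facts exceeds possible unique inputs"
    rng = np.random.default_rng(seed)
    inputs = rng.integers(0, vocab_size, size=(n_facts, seq_len))
    while True:
        # Sorting-based dedup: unique rows, duplicates collapsed.
        inputs = np.unique(inputs, axis=0)
        n_missing = n_facts - len(inputs)
        if n_missing == 0:
            break
        # Regenerate replacements for any collisions and re-check.
        extra = rng.integers(0, vocab_size, size=(n_missing, seq_len))
        inputs = np.concatenate([inputs, extra])
    return inputs
```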
Script to analyze which memorized inputs each component activates on:
- Loads SPD decomposition and recreates the training dataset
- Computes causal importance for all facts
- Shows facts where each component has CI above threshold (default 0.1)
- Provides summary statistics per component

Usage: python -m spd.experiments.mem.analyze_decomposition <spd_run_path>
- Components now displayed in order of mean causal importance (descending)
- Added rank number to output for easy reference
- Shows mean CI in component headers
- New --output_file flag to write results to a text file
- If not specified, prints to stdout as before
- Progress messages still print to terminal during computation
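
The core of this analysis might look like the following sketch; the function name and the (n_facts, n_components) CI layout are assumptions for illustration, not the script's actual interface:

```python
import torch

def facts_above_threshold(ci: torch.Tensor, threshold: float = 0.1) -> dict[int, list[int]]:
    """For each component, list fact indices whose causal importance exceeds
    the threshold. ci: (n_facts, n_components).

    Components are visited in descending order of mean CI, matching the
    display order described above.
    """
    result: dict[int, list[int]] = {}
    order = ci.mean(dim=0).argsort(descending=True)
    for comp in order.tolist():
        mask = ci[:, comp] > threshold
        result[comp] = torch.nonzero(mask).squeeze(-1).tolist()
    return result
```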
Architecture changes:
- Replace RMSNorm with LayerNorm
- Replace RoPE with learned position embeddings
- Replace SwiGLU (gate_proj + up_proj) with single up_proj + GELU
- Add bias to linear layers (except unembed)
- Keep pre-norm architecture (GPT-2 style)

Config changes:
- Remove block.mlp.gate_proj from target_module_patterns (no longer exists)
- New config option 'use_layer_norm' (default: True)
- When False, removes all LayerNorm from:
  - GPTBlock (ln1, ln2)
  - MemTransformer (ln_f)
- Makes the model easier to interpret
A new section shows, for each fact, which components activate above the
threshold, sorted by causal importance. Components are displayed as C{idx}(CI_value).
- Add expand, d_model_new, d_mlp_new fields to MemTaskConfig
- Implement expand_model() in models.py to pad weights with zeros
- Call expand_model in mem_decomposition.py when expand=True

This allows decomposing a model in a larger parameter space without
changing its actual behavior (weights are zero-padded).
- analyze_decomposition.py: per-fact component analysis
- mem_config.yaml: config updates
- mem_dataset.py: unique input generation
- train_mem.py: training fixes
- general_utils.py: mem loss type support
- Add global exception handlers for RequestValidationError, HTTPException, and Exception
- Add request/response logging middleware
- Replace silent JSONResponse error returns with HTTPException raises
- Ensures all errors log tracebacks and are visible in server logs

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
danbraunai-goodfire and others added 30 commits January 16, 2026 09:44
* Allow building new graph in interventions tab

* Add 'Generating' message

* Prevent multiple manual graphs with the same nodes

* Remove unused graph name everywhere

* Address some PR review comments

* More PR fixes

* Remove comments
Creates per-layer scatter plots showing normalized component activation
values for datapoints where CI exceeds a threshold. Components are ranked
by median activation on the x-axis.

Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Add second set of plots ordered by firing frequency (from harvest data)
- Transform y-values to |normalized - 0.5| for frequency plots
- Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/
- Include run_id in plot titles
- Use pre-calculated firing_counts from token_stats.pt

Co-Authored-By: Claude Opus 4.5 <[email protected]>
…nticity plots) (#343)

* Add script to plot component activations vs component rank

Creates per-layer scatter plots showing normalized component activation
values for datapoints where CI exceeds a threshold. Components are ranked
by median activation on the x-axis.

Usage: python scripts/plot_component_activations.py <run_id> --ci-threshold 0.1

Co-Authored-By: Claude Opus 4.5 <[email protected]>

* Add frequency-ordered plots with abs distance from midpoint

- Add second set of plots ordered by firing frequency (from harvest data)
- Transform y-values to |normalized - 0.5| for frequency plots
- Organize outputs: scripts/outputs/<run-id>/component-act-scatter/order-by-{median,freq}/
- Include run_id in plot titles
- Use pre-calculated firing_counts from token_stats.pt

Co-Authored-By: Claude Opus 4.5 <[email protected]>

---------

Co-authored-by: Claude Opus 4.5 <[email protected]>
…ormat

- Resolve merge conflict in spd/models/component_model.py (keep methods from both branches)
- Update all YAML configs to use new ci_config discriminated union format:
  - Old: ci_fn_type + ci_fn_hidden_dims
  - New: ci_config with mode, fn_type, hidden_dims
- Update test_grid_search.py inline configs to include ci_config field
- Includes all changes from dev/app branch (app improvements, renames, etc.)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
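
The config migration described above might look like this in a YAML file; the field names (mode, fn_type, hidden_dims) come from the commit message, while the values shown are purely illustrative:

```yaml
# Old flat format:
# ci_fn_type: mlp
# ci_fn_hidden_dims: [8]

# New nested ci_config discriminated-union format:
ci_config:
  mode: layerwise
  fn_type: mlp
  hidden_dims: [8]
```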
Resolve conflict in spd/models/component_model.py by keeping methods from both branches:
- HEAD: _calc_layerwise_causal_importances, _calc_global_causal_importances
- dev/app: get_all_component_acts

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Resolves merge conflicts in PromptCardHeader.svelte and ss_llama_simple_mlp-2L-wide.yaml

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Fix incorrect tensor assertions (tensor.all() <= 1.0 → (tensor <= 1.0).all())
- Fix non-deterministic layer ordering (use sorted keys)
- Add config validation for CI fn_type/mode compatibility
- Add checkpoint compatibility validation between CI modes
- Extract _get_module_input_dim() helper to reduce duplication
- Improve type annotation for global_ci_fn
- Add 17 comprehensive tests for global CI functionality

Co-Authored-By: Claude Opus 4.5 <[email protected]>
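
The assertion bug fixed above is easy to reproduce: `tensor.all()` reduces to a 0-d boolean tensor first, and comparing that boolean against `1.0` is vacuously true, so the buggy form never fails:

```python
import torch

t = torch.tensor([0.5, 1.5])  # second element violates the intended bound

# Buggy: t.all() is tensor(True) (all elements nonzero), and True <= 1.0
# always holds, so the check passes even though 1.5 > 1.0.
buggy = t.all() <= 1.0

# Correct: compare elementwise first, then reduce.
correct = (t <= 1.0).all()
```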
- Remove broken RENAMED_CONFIG_KEYS entries (gate_type, gate_hidden_dims)
- Remove redundant _validate_ci_config() - type system already enforces this
- Simplify checkpoint validation error messages to be accurate
- Remove redundant comments in GlobalSharedMLPCiFn and _calc_global_causal_importances
- Add test for binomial sampling with global CI

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Remove accidentally committed large PNG output files (48 files, ~100MB).
Add scripts/outputs/ to .gitignore to prevent future accidents.

Note: Files remain in git history on this branch. For complete removal,
run git-filter-repo on the main repo after merge.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Add usage examples to docstring
- Move script from scripts/ to spd/scripts/
- Change output directory to use Path(__file__).parent / "out"

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Replace dual tracking of global/layerwise CI with single unified interface:

- Add wrapper classes (GlobalCiFnWrapper, LayerwiseCiFnWrapper) to
  components.py
- Replace is_global_ci flag and separate ci_fns/global_ci_fn attributes
  with single ci_fn attribute in ComponentModel
- Remove _calc_layerwise_causal_importances and _calc_global_causal_importances
  methods - logic now in wrapper forward() methods
- Simplify parameter collection in run_spd.py and gradient logging in
  logging_utils.py
- Update tests to use new wrapper-based assertions

This eliminates the boolean flag anti-pattern and reduces 5 CI-related
attributes to 1, making the code cleaner and more maintainable.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
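
The wrapper pattern described above can be sketched as follows. The class names are from the commit message, but the internals (per-layer dict interface, concatenate-then-split for the global case) are assumptions for illustration; the point is that both wrappers share one forward() signature, so ComponentModel needs only a single `ci_fn` attribute:

```python
import torch
import torch.nn as nn

class LayerwiseCiFnWrapper(nn.Module):
    """One CI function per layer; returns per-layer importances."""
    def __init__(self, ci_fns: dict[str, nn.Module]):
        super().__init__()
        self.ci_fns = nn.ModuleDict(ci_fns)

    def forward(self, acts: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {name: fn(acts[name]) for name, fn in self.ci_fns.items()}

class GlobalCiFnWrapper(nn.Module):
    """A single shared CI function applied to concatenated activations."""
    def __init__(self, ci_fn: nn.Module, layer_names: list[str], n_components: int):
        super().__init__()
        self.ci_fn = ci_fn
        self.layer_names = layer_names
        self.n_components = n_components

    def forward(self, acts: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Concatenate per-layer activations, run the shared fn, split back out.
        flat = torch.cat([acts[name] for name in self.layer_names], dim=-1)
        chunks = self.ci_fn(flat).split(self.n_components, dim=-1)
        return dict(zip(self.layer_names, chunks))
```

Callers then dispatch through `ci_fn(acts)` uniformly, with no `is_global_ci` flag.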
Rename methods and variables to be explicit about layerwise vs global:
- _create_ci_fn → _create_layerwise_ci_fn
- _create_ci_fns → _create_layerwise_ci_fns
- has_ci_fns → has_layerwise_ci_fns
- ci_fns (local var) → layerwise_ci_fns

Co-Authored-By: Claude Opus 4.5 <[email protected]>
For consistency with config classes and wrapper class definitions,
put layerwise CI cases before global CI cases in match statements.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Change from old flat format (ci_fn_type, ci_fn_hidden_dims) to new
nested ci_config format.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Update all 6 canonical runs in registry.py with new run IDs
- Update clustering test to use resid_mlp2 run with new format

New canonical runs:
- tms_5-2: s-38e1a3e2
- tms_5-2-id: s-a1c0e9e2
- tms_40-10: s-7387fc20
- tms_40-10-id: s-2a2b5a57
- resid_mlp1: s-62fce8c4
- resid_mlp2: s-a9ad193d

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Streaming dataset loading only works for 'lm' tasks, not resid_mlp.

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Resolve conflicts:
- database.py: Remove deprecated pnorm_2 field (align with main)
- configs.py: Keep migration logic for pnorm_1->pnorm, add logging warning for beta default
This feature was no longer needed. Removes the expand_model function,
related config fields (expand, d_model_new, d_mlp_new), and their usage
in mem_decomposition.py.