
Remove BatchT/OutputT generics, introduce ReconstructionLoss protocol#380

Closed
ocg-goodfire wants to merge 2 commits into dev from
refactor/de-genericize-component-model

Conversation

@ocg-goodfire
Collaborator

Description

This PR was authored by Claude (Opus 4.6) at Oli's request, implementing the de-genericize plan from PR #363.

Removes BatchT/OutputT type parameters from ComponentModel and all downstream code. Model batch input and output types are now Any. Introduces a ReconstructionLoss protocol with concrete implementations (recon_loss_mse, recon_loss_kl) that callers pass explicitly, replacing the old output_loss_type string config.
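The ReconstructionLoss protocol and its two implementations could look roughly like the following sketch. The PR only names the protocol and the two functions; the argument names, signatures, and reduction choices below are assumptions, not the actual contents of spd/models/batch_and_loss_fns.py.

```python
from typing import Protocol

import torch.nn.functional as F
from torch import Tensor


class ReconstructionLoss(Protocol):
    """Callable computing a scalar loss between model output and target."""

    def __call__(self, pred: Tensor, target: Tensor) -> Tensor: ...


def recon_loss_mse(pred: Tensor, target: Tensor) -> Tensor:
    # Mean squared error over all elements.
    return F.mse_loss(pred, target)


def recon_loss_kl(pred: Tensor, target: Tensor) -> Tensor:
    # KL divergence between the two softmax distributions over the last
    # dim, as typically used for logit matching in language models.
    return F.kl_div(
        F.log_softmax(pred, dim=-1),
        F.log_softmax(target, dim=-1),
        log_target=True,
        reduction="batchmean",
    )
```

Because both functions satisfy the protocol structurally, either can be passed wherever a `ReconstructionLoss` is expected, with no string dispatch.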

Key changes:

  • Remove BatchT/OutputT generics from ComponentModel and all downstream code
  • Add ReconstructionLoss protocol in spd/models/batch_and_loss_fns.py with recon_loss_mse and recon_loss_kl
  • Rename pretrained_model_output_attr -> extract_tensor_output with regex-based accessor parsing (supports .logits, [0], .output[0], etc.)
  • Remove output_loss_type config field (added to deprecated keys with migration)
  • Remove extract_batch_data utility — callers handle tuple extraction inline
  • Add lm_collate_fn for LM data loading in spd/data.py
  • Update all 24 YAML configs, 11 metric files, 4 experiment scripts, and all tests
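The lm_collate_fn itself is not shown in this PR summary; a minimal sketch of what such a collate function could look like, assuming tokenized examples carry an input_ids field (both the field name and the stacking behavior are assumptions, not taken from the PR):

```python
from typing import Any

import torch
from torch import Tensor


def lm_collate_fn(examples: list[dict[str, Any]]) -> Tensor:
    """Collate tokenized LM examples into a single input_ids batch.

    Hypothetical sketch: stacks equal-length token sequences along a
    new batch dimension. The real function in spd/data.py may handle
    padding, attention masks, or other fields.
    """
    return torch.stack(
        [torch.as_tensor(ex["input_ids"]) for ex in examples], dim=0
    )
```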

Related Issue

Implements the de-genericize portion of #363

Motivation and Context

The BatchT/OutputT generics on ComponentModel added type complexity without meaningful safety, since model I/O is inherently heterogeneous across experiment types. The output_loss_type string-based dispatch was a code smell; passing a concrete ReconstructionLoss callable is simpler and more extensible. The rename of pretrained_model_output_attr to extract_tensor_output, together with the regex-based accessor parser, is more general than the old attribute-only approach.
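A regex-based accessor parser of the kind described could be sketched as follows. The real extract_tensor_output implementation is not shown in the PR, so the regex and error handling here are illustrative only:

```python
import re
from typing import Any

# Matches either an attribute access (".name") or an integer index ("[0]").
_ACCESSOR_RE = re.compile(r"\.(\w+)|\[(\d+)\]")


def extract_tensor_output(raw_output: Any, accessor: str) -> Any:
    """Walk an accessor string like '.logits', '[0]', or '.output[0]'
    through a model's raw output. Illustrative sketch, not the actual
    implementation from this PR.
    """
    out = raw_output
    pos = 0
    for m in _ACCESSOR_RE.finditer(accessor):
        if m.start() != pos:  # reject gaps between accessor parts
            raise ValueError(f"Invalid accessor: {accessor!r}")
        pos = m.end()
        attr, idx = m.group(1), m.group(2)
        out = getattr(out, attr) if attr is not None else out[int(idx)]
    if pos != len(accessor):  # reject trailing junk
        raise ValueError(f"Invalid accessor: {accessor!r}")
    return out
```

Chained accessors like ".output[0]" fall out naturally, which is what makes this strictly more general than a single attribute name.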

How Has This Been Tested?

  • make check passes (basedpyright + ruff lint + ruff format, 0 errors)
  • make test passes (359 passed, 16 skipped, 0 failures)

Does this PR introduce a breaking change?

Yes — configs with output_loss_type or pretrained_model_output_attr will use the migration path in configs.py (deprecated key handling + value conversion). The output_loss_type field is removed from Config; callers must pass reconstruction_loss explicitly to optimize(). The extract_batch_data utility is removed.

@ocg-goodfire changed the base branch from feature/autointerp-improvements-2 to main on February 11, 2026 19:11
…ructionLoss protocol

De-genericize ComponentModel by removing BatchT and OutputT type parameters.
Model batch input and output types are now Any. Introduces a ReconstructionLoss
protocol with concrete implementations (recon_loss_mse, recon_loss_kl) that
callers pass explicitly instead of the old output_loss_type string config.

Key changes:
- Remove BatchT/OutputT generics from ComponentModel and all downstream code
- Add ReconstructionLoss protocol in spd/models/batch_and_loss_fns.py
- Rename pretrained_model_output_attr -> extract_tensor_output with regex-based parsing
- Remove output_loss_type config field (added to deprecated keys with migration)
- Remove extract_batch_data utility (callers handle tuple extraction inline)
- Add lm_collate_fn for LM data loading
- Update all 24 YAML configs, 11 metric files, 4 experiment scripts, and all tests

Co-Authored-By: Claude Opus 4.6 <[email protected]>
@ocg-goodfire force-pushed the refactor/de-genericize-component-model branch from 5924d7c to 255053a on February 11, 2026 22:07
@ocg-goodfire changed the base branch from main to dev on February 11, 2026 22:08
@ocg-goodfire force-pushed the refactor/de-genericize-component-model branch from 0fdf73f to af97570 on February 11, 2026 22:28
