Remove BatchT/OutputT generics, introduce ReconstructionLoss protocol #380
Closed
ocg-goodfire wants to merge 2 commits into dev
…ructionLoss protocol
De-genericize ComponentModel by removing BatchT and OutputT type parameters. Model batch input and output types are now Any. Introduces a ReconstructionLoss protocol with concrete implementations (recon_loss_mse, recon_loss_kl) that callers pass explicitly instead of the old output_loss_type string config.
Key changes:
- Remove BatchT/OutputT generics from ComponentModel and all downstream code
- Add ReconstructionLoss protocol in spd/models/batch_and_loss_fns.py
- Rename pretrained_model_output_attr -> extract_tensor_output with regex-based parsing
- Remove output_loss_type config field (added to deprecated keys with migration)
- Remove extract_batch_data utility (callers handle tuple extraction inline)
- Add lm_collate_fn for LM data loading
- Update all 24 YAML configs, 11 metric files, 4 experiment scripts, and all tests
Co-Authored-By: Claude Opus 4.6 <[email protected]>
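
The regex-based accessor parsing named in the commit message can be sketched roughly as below. This is an illustration only, not the PR's implementation: parse_accessor and apply_accessor are hypothetical names, and the grammar (dotted attribute names plus integer indices) is an assumption inferred from accessor strings such as .logits, [0], and .output[0].

```python
import re
from types import SimpleNamespace
from typing import Any

# One accessor step: ".name" (attribute lookup) or "[3]" (integer index).
# Assumed grammar; the real parser may accept more forms.
_STEP = re.compile(r"\.([A-Za-z_]\w*)|\[(-?\d+)\]")


def parse_accessor(spec: str) -> list[tuple[str, Any]]:
    """Split an accessor string such as ".output[0]" into ordered steps."""
    steps: list[tuple[str, Any]] = []
    pos = 0
    while pos < len(spec):
        m = _STEP.match(spec, pos)
        if m is None:
            raise ValueError(f"invalid accessor {spec!r} at position {pos}")
        if m.group(1) is not None:
            steps.append(("attr", m.group(1)))
        else:
            steps.append(("index", int(m.group(2))))
        pos = m.end()
    return steps


def apply_accessor(obj: Any, spec: str) -> Any:
    """Walk the parsed steps over a model output to reach the wanted value."""
    for kind, key in parse_accessor(spec):
        obj = getattr(obj, key) if kind == "attr" else obj[key]
    return obj


# Stand-in for a model output object (hypothetical shape).
fake_output = SimpleNamespace(logits=[1.0, 2.0], output=("first", "second"))
print(apply_accessor(fake_output, ".logits"))
print(apply_accessor(fake_output, ".output[0]"))
print(apply_accessor(("a", "b"), "[0]"))
```

Parsing the string once into explicit steps keeps malformed accessors (for example a bare logits without the leading dot) failing early with a clear error, rather than at the first forward pass.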
5924d7c to 255053a
0fdf73f to af97570
Description
This PR was authored by Claude (Opus 4.6) at Oli's request, implementing the de-genericize plan from PR #363.
Removes BatchT/OutputT type parameters from ComponentModel and all downstream code. Model batch input and output types are now Any. Introduces a ReconstructionLoss protocol with concrete implementations (recon_loss_mse, recon_loss_kl) that callers pass explicitly, replacing the old output_loss_type string config.
Key changes:
- Remove BatchT/OutputT generics from ComponentModel and all downstream code
- Add ReconstructionLoss protocol in spd/models/batch_and_loss_fns.py with recon_loss_mse and recon_loss_kl
- Rename pretrained_model_output_attr → extract_tensor_output with regex-based accessor parsing (supports .logits, [0], .output[0], etc.)
- Remove output_loss_type config field (added to deprecated keys with migration)
- Remove extract_batch_data utility; callers handle tuple extraction inline
- Add lm_collate_fn for LM data loading in spd/data.py
Related Issue
Implements the de-genericize portion of #363
Motivation and Context
The BatchT/OutputT generics on ComponentModel added type complexity without meaningful safety, since model I/O is inherently heterogeneous across experiment types. The output_loss_type string-based dispatch was a code smell; passing a concrete ReconstructionLoss callable is simpler and more extensible. The extract_tensor_output rename + regex parser is more general than the old pretrained_model_output_attr approach.
How Has This Been Tested?
- make check passes (basedpyright + ruff lint + ruff format, 0 errors)
- make test passes (359 passed, 16 skipped, 0 failures)
Does this PR introduce a breaking change?
Yes. Configs with output_loss_type or pretrained_model_output_attr will use the migration path in configs.py (deprecated key handling + value conversion). The output_loss_type field is removed from Config; callers must pass reconstruction_loss explicitly to optimize(). The extract_batch_data utility is removed.
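
To make the breaking change concrete, here is a rough sketch of moving from a string-dispatched loss to an explicitly passed callable. Everything beyond the names ReconstructionLoss, recon_loss_mse, recon_loss_kl, and output_loss_type is an assumption: the protocol signature, the legacy values "mse" and "kl", the KL direction, and the helper migrate_legacy_config are hypothetical, and plain Python floats stand in for tensors so the example runs without extra dependencies.

```python
import math
from typing import Any, Optional, Protocol, Sequence


class ReconstructionLoss(Protocol):
    """Hypothetical signature: compare a prediction against a target."""

    def __call__(self, pred: Sequence[float], target: Sequence[float]) -> float: ...


def recon_loss_mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over paired elements."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def _log_softmax(logits: Sequence[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def recon_loss_kl(pred: Sequence[float], target: Sequence[float]) -> float:
    """KL(softmax(target) || softmax(pred)); the direction is an assumption."""
    log_p, log_q = _log_softmax(target), _log_softmax(pred)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Assumed legacy string values; the real migration lives in configs.py.
_LEGACY_LOSSES: dict[str, ReconstructionLoss] = {
    "mse": recon_loss_mse,
    "kl": recon_loss_kl,
}


def migrate_legacy_config(
    raw: dict[str, Any],
) -> tuple[dict[str, Any], Optional[ReconstructionLoss]]:
    """Drop the deprecated key and return the callable it used to select."""
    cfg = dict(raw)  # leave the caller's dict untouched
    legacy = cfg.pop("output_loss_type", None)
    return cfg, (_LEGACY_LOSSES[legacy] if legacy is not None else None)


# The caller now hands the loss over explicitly instead of naming it in config.
cfg, loss = migrate_legacy_config({"output_loss_type": "mse", "lr": 0.1})
assert loss is not None
print(cfg, loss([1.0, 2.0], [1.0, 4.0]))
```

Because the protocol is structural, any callable with a matching signature satisfies it, so a new loss can be supplied by the caller without touching the config schema.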