
Makes ComponentModel generic over batch and output types#363

Draft
danbraunai-goodfire wants to merge 20 commits into main from feature/generic-shapes

Conversation

@danbraunai-goodfire
Collaborator

@danbraunai-goodfire danbraunai-goodfire commented Feb 6, 2026

Description

TODO (either here or afterwards):

  • Currently in spd we use n_ctx=512 for the dataset config, not the n_ctx=513 we use for pretraining. The spd code then calculates CE losses using the fancy ignore_index=-100 trick. Once we sort out this PR, I think we should make our dataloaders return batch and labels, so we can set n_ctx for the dataset to 513 just like pretraining and drop the ignore_index workaround. We should also rename max_seq_len in the spd task_config to n_ctx.

Related Issue

Motivation and Context

How Has This Been Tested?

Does this PR introduce a breaking change?

claude-spd1 and others added 6 commits January 29, 2026 06:16
- Add `BatchT` and `OutputT` type parameters to ComponentModel and related functions
- Replace `pretrained_model_output_attr` with `run_batch` and `reconstruction_loss` callables
- Remove `extract_batch_data` and `AliveComponentsTracker` utilities
- This reverts commit e712573.
@danbraunai-goodfire danbraunai-goodfire changed the title from "Make model input/output generic for easier adaptation to SPD" to "Makes ComponentModel generic over batch and output types" on Feb 6, 2026
@ocg-goodfire
Collaborator

@danbraunai-goodfire I think I definitely went overboard with the abstractions, so thanks for reining this in lol. I do have some suggestions in #375. lmk what you think

@danbraunai-goodfire
Collaborator Author

@claude I'd like you to do a thorough review of this PR. It's big. I'm worried that there are inconsistencies or that we've made some bad architectural choices. I want critical comments; feel comfortable giving me the hard truths.

@claude
Contributor

claude bot commented Feb 11, 2026

Claude finished @danbraunai-goodfire's task


PR Review: ComponentModel Generic Types

I'll do a thorough architectural review of this PR focusing on the generic type system changes and potential inconsistencies.

Todo List:

  • Set up environment and understand scope
  • Review core ComponentModel generic changes
  • Analyze the new batch_and_loss_fns.py module
  • Check consistency across experiment configs
  • Review metrics changes for generic compatibility
  • Look for architectural issues and over-abstractions
  • Provide critical feedback with specific recommendations

Critical Review Findings

After analyzing this substantial PR, I have several major concerns about the architectural decisions and implementation:

🚨 Major Issues

1. Excessive Generic Abstraction
The generic system adds significant complexity with questionable benefits:

class ComponentModel[BatchT, OutputT](nn.Module):
    def __init__(self, ..., run_batch: RunBatch[BatchT, OutputT], ...): ...
  • Problem: This creates a complex type system where BatchT and OutputT are threaded through dozens of files but offer limited type safety improvements
  • Evidence: Most usage sites cast to Any anyway: ComponentModel[Any, Any] in component_model.py:124,150,165
  • Alternative: The previous approach with concrete tensor types was more maintainable

2. Protocol vs Interface Confusion
The new RunBatch and ReconstructionLoss protocols are unnecessarily abstract:

class RunBatch[BatchT, OutputT](Protocol):
    def __call__(self, model: nn.Module, batch: BatchT) -> OutputT: ...
  • Problem: These protocols don't provide meaningful abstraction - they're just function wrappers
  • Evidence: make_run_batch() returns simple lambda functions, not complex objects
  • Simpler solution: Direct function parameters would be clearer
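To make the "direct function parameters" suggestion concrete, here is a hedged sketch of what that could look like. It uses a plain Callable alias with stand-in types (no torch dependency) rather than the PR's actual Protocol; all names are illustrative.

```python
from typing import Any, Callable

# Stand-in for nn.Module so the sketch runs without torch: any callable
# that maps a batch to an output.
Model = Callable[[Any], Any]

# A plain alias carries the same call signature the RunBatch protocol
# describes, with no separate Protocol class needed.
RunBatchFn = Callable[[Model, Any], Any]

def run_batch_passthrough(model: Model, batch: Any) -> Any:
    # Equivalent to `lambda model, batch: model(batch)`.
    return model(batch)

double = lambda batch: [x * 2 for x in batch]
fn: RunBatchFn = run_batch_passthrough
print(fn(double, [1, 2, 3]))  # [2, 4, 6]
```

A function accepting `run_batch: RunBatchFn` type-checks the same way as one accepting the protocol, without the extra class definition.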

3. Inconsistent Configuration Migration
The config changes are inconsistent and create migration complexity:

  • Some experiments now require output_extract (LM models) while others don't (TMS)
  • The output_extract field is marked as optional but is required for certain model types
  • No clear migration path for existing configs

🔧 Implementation Issues

4. Metrics Update Cascade
All metrics needed updates to support generics, creating a massive diff with minimal benefit:

# Before: Simple and clear
def stochastic_recon_loss(model, batch, target_out, ...): ...

# After: Generic but not meaningfully safer
def stochastic_recon_loss[BatchT, OutputT](
    model: ComponentModel[BatchT, OutputT], batch: BatchT, target_out: OutputT, ...
): ...

5. Loss Function Protocol Mismatch
The ReconstructionLoss protocol forces all loss functions to return (loss, n_examples) tuples, but this pattern isn't universally applicable and feels forced.
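For illustration, a protocol with that forced return shape might look roughly like the following; the exact signature is assumed, not copied from the diff, and plain lists stand in for tensors.

```python
from typing import Protocol

class ReconstructionLoss(Protocol):
    # Every implementation must return (loss, n_examples), even when an
    # example count is not meaningful for the particular loss.
    def __call__(
        self, out: list[float], target_out: list[float]
    ) -> tuple[float, int]: ...

def mse_like_loss(out: list[float], target_out: list[float]) -> tuple[float, int]:
    # Mean squared error, plus the example count the protocol demands.
    n = len(out)
    loss = sum((o - t) ** 2 for o, t in zip(out, target_out)) / n
    return loss, n

loss_fn: ReconstructionLoss = mse_like_loss
print(loss_fn([1.0, 2.0], [1.0, 4.0]))  # (2.0, 2)
```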

6. Over-Engineering in batch_and_loss_fns.py
The new module adds indirection without clear benefits:

  • run_batch_passthrough is just lambda model, batch: model(batch)
  • make_run_batch could be replaced with direct function construction
  • The abstraction doesn't hide meaningful complexity

📝 Code Quality Issues

7. Type Erasure Everywhere
Despite the generic system, most actual usage erases types:

  • from_pretrained returns ComponentModel[Any, Any]
  • from_run_info returns ComponentModel[Any, Any]
  • This suggests the generics aren't providing real value

8. Config Complexity Increase
The addition of OutputExtractConfig discriminated unions adds config complexity:

output_extract:
  type: attr
  attr: logits

This could be simplified to: output_extract: logits or output_extract: [1] for indexing.

🎯 Specific Recommendations

Immediate Actions:

  1. Reconsider the generic system - The complexity-to-benefit ratio is poor. Consider reverting to concrete types.
  2. Simplify batch running - Replace protocols with simple function parameters
  3. Streamline configs - Make output_extract required where needed, optional where not
  4. Reduce metric changes - Many metric changes seem unnecessary if you simplified the generic approach

Alternative Architecture:

# Simpler, clearer approach:
class ComponentModel(nn.Module):
    def __init__(self, target_model, output_extractor=None, recon_loss_fn=None): ...
    
def optimize(model, train_loader, eval_loader, output_extractor, recon_loss_fn): ...

Configuration Simplification:

# Instead of discriminated union:
output_extract: "logits"  # attr extraction
# or
output_extract: 1  # index extraction

🤔 Questions for Discussion

  1. What specific problem does the generic system solve? The PR description doesn't clearly articulate the benefits that justify this complexity.

  2. Why protocols over simple functions? The RunBatch and ReconstructionLoss protocols seem like over-engineering.

  3. Migration strategy? How will existing users adapt their configs and custom metrics?

  4. Performance impact? Do the additional generic type checks and protocol calls affect training speed?

🎨 Positive Aspects

  • The code is well-typed where the generics are used
  • The batch_and_loss_fns.py module has good separation of concerns
  • Config validation is maintained throughout

🏁 Summary

This PR feels like a solution in search of a problem. The generic system adds substantial complexity across 96 files but provides limited concrete benefits. The architecture would be cleaner and more maintainable with:

  1. Concrete tensor types instead of generics
  2. Direct function parameters instead of protocols
  3. Simplified configuration schema
  4. Focused changes rather than system-wide refactoring

I recommend significantly simplifying this approach or providing a clearer justification for the added complexity.


@danbraunai-goodfire
Collaborator Author

@claude I've removed OutputT as I noticed that it was always a Tensor. Could you please review this PR now? Again, I'm interested in critical comments.

@claude
Contributor

claude bot commented Feb 11, 2026

Claude finished @danbraunai-goodfire's task


PR Review: ComponentModel Generic Types

I'll provide a thorough critical review of this substantial PR focusing on the generic type system changes and architectural decisions.

Todo List:

  • Set up environment and understand scope
  • Review core ComponentModel generic changes
  • Analyze the new batch_and_loss_fns.py module
  • Check consistency across experiment configs
  • Review metrics changes for generic compatibility
  • Look for architectural issues and over-abstractions
  • Provide critical feedback with specific recommendations

Critical Review Findings

After thoroughly analyzing this substantial PR, I have several significant concerns about the architectural decisions:

🚨 Major Architectural Issues

1. Questionable Generic Value Proposition
The ComponentModel[BatchT] generic adds complexity across 96 files but provides limited concrete benefits:

# From component_model.py:124,165 - Type erasure everywhere
def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any]":
def from_pretrained(cls, path: ModelPath) -> "ComponentModel[Any]":

Problem: The main factory methods return ComponentModel[Any], completely erasing the generic type information. This suggests the generics aren't providing meaningful type safety.

2. Protocol Over-Engineering
The new RunBatch[BatchT] and ReconstructionLoss protocols in batch_and_loss_fns.py are unnecessarily abstract:

class RunBatch[BatchT](Protocol):
    def __call__(self, model: nn.Module, batch: BatchT) -> Tensor: ...

def run_batch_passthrough(model: nn.Module, batch: Any) -> Tensor:
    return runtime_cast(Tensor, model(batch))

Problem: run_batch_passthrough is literally just lambda model, batch: model(batch) with a runtime cast. The protocol abstraction hides no meaningful complexity.

3. Configuration Complexity Explosion
The discriminated union for output_extract adds unnecessary config complexity:

# LM configs now require:
output_extract:
  type: index  # or "attr"
  index: 0     # or attr: "logits"

Problem: This could be simplified to output_extract: 0 or output_extract: "logits" with type-based dispatch.

🔧 Implementation Inconsistencies

4. Inconsistent Config Migration

  • All LM experiments now require output_extract config
  • TMS/ResidualMLP experiments don't need it (use passthrough)
  • No clear migration strategy for existing users

5. Metric Cascade Updates
Every metric needed generic type annotations, creating massive diff with questionable benefits:

# Before: Clear and simple
def stochastic_recon_loss(model, batch, target_out, ...): ...

# After: Generic but no safer than before
def stochastic_recon_loss[BatchT](
    model: ComponentModel[BatchT], batch: BatchT, target_out: Tensor, ...
): ...

6. Type Safety Illusion
The runtime_cast(Tensor, ...) calls throughout suggest the type system isn't actually preventing runtime type errors - you're just moving them to different locations.

📝 Code Quality Concerns

7. Abstraction Without Benefit
make_run_batch() factory creates simple lambdas that could be constructed directly:

# Could be simplified to:
def get_output_extractor(config):
    if config.output_extract is None:
        return lambda x: x
    elif config.output_extract.type == "index":
        return lambda x: x[config.output_extract.index]
    # etc.

8. Generic Threading Complexity
The BatchT generic propagates through the entire codebase (96 files!) but most usage sites don't benefit from the type information.

🎯 Specific Recommendations

Immediate Simplifications:

  1. Consider removing generics entirely - The type erasure pattern suggests limited value
  2. Replace protocols with simple functions - No need for RunBatch protocol abstraction
  3. Simplify config schema - Use simpler output_extract format
  4. Reduce metric signature changes - Many appear unnecessary

Alternative Architecture:

class ComponentModel(nn.Module):
    def __init__(
        self, 
        target_model, 
        output_extractor=None,  # Simple callable
        # ... other params
    ): ...

def optimize(
    model, 
    train_loader, 
    eval_loader,
    output_extractor=lambda x: x,  # Simple default
    reconstruction_loss=mse_loss,  # Simple function
): ...

🤔 Critical Questions

  1. What problem does BatchT solve? The PR description doesn't articulate the benefits that justify touching 96 files.

  2. Why protocols over functions? RunBatch and ReconstructionLoss protocols seem like premature abstraction.

  3. Type safety claims? The prevalence of runtime_cast and Any types suggests limited actual type safety gains.

🏁 Summary Assessment

This PR feels like "generics for the sake of generics" rather than solving a concrete problem. The architecture would be significantly cleaner with:

  1. Concrete types instead of generics - Less complexity, same safety
  2. Direct function parameters - No protocol indirection
  3. Simplified configuration - Easier to use and migrate
  4. Focused changes - Target specific issues rather than system-wide refactoring

The current approach adds substantial complexity with minimal demonstrated benefits. I'd recommend either:

  • A much simpler approach focused on the specific problem you're trying to solve
  • A clear articulation of why this complexity is necessary

Verdict: The cost-benefit ratio is poor. This feels like over-engineering that will make the codebase harder to understand and maintain.


@ocg-goodfire
Collaborator

@danbraunai-goodfire the reason for abstracting OutputT was to enable the use case where a model outputs a dict and has a more advanced recon loss based on those dict elements
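As a hedged sketch of that use case (all names invented, plain lists standing in for tensors): with an OutputT type parameter, the recon loss can be typed against a dict-shaped model output and computed from specific named elements, which a plain-Tensor interface could not express.

```python
# Illustrative only: a model output that is a dict of named activations,
# with a recon loss summing squared errors over every named element.
DictOutput = dict[str, list[float]]

def dict_recon_loss(out: DictOutput, target: DictOutput) -> float:
    # Per-key access is the part a Tensor-only OutputT cannot express.
    return sum(
        (o - t) ** 2
        for key in out
        for o, t in zip(out[key], target[key])
    )

out = {"logits": [1.0, 2.0], "aux": [0.5]}
target = {"logits": [1.0, 1.0], "aux": [0.0]}
print(dict_recon_loss(out, target))  # 1.25
```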

