@favyen2 favyen2 commented Dec 2, 2025

Model component API improvements.

  • Allow intermediate outputs (passed between model components) to be arbitrary, but provide standard types that can be used in most scenarios. This makes the expected input and output of each model component clearer.
  • Make the encoder components (after the initial feature extractor) and decoder components (before the final predictor) share the same API.
  • Create a new SampleMetadata class that replaces the metadata dict.
  • Create a new ModelContext class that bundles the input dicts with their SampleMetadata objects, and pass it to models instead of just the inputs. This way models can use location/timestamp info when it is useful.
  • Add abstract base classes for FeatureExtractor, IntermediateComponent, and Predictor so that the signature of the forward pass is clear for all of these model components.
  • Improve the documentation and type checking for task-specific outputs (the output from Predictor that gets split up per example and eventually passed to Task.process_output).
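
To make the shape of the new API easier to picture, here is a minimal, dependency-free sketch of how the pieces listed above might fit together. It is not the code from this PR: the field names (`window_name`, `bounds`, `time_range`, `inputs`, `metadatas`), the `forward` signatures, and the toy subclasses are all assumptions made for illustration, and plain Python lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional


@dataclass
class SampleMetadata:
    """Hypothetical typed replacement for the old metadata dict."""
    window_name: str
    bounds: tuple  # assumed (x_min, y_min, x_max, y_max)
    time_range: Optional[tuple] = None  # assumed (start, end) datetimes


@dataclass
class ModelContext:
    """Bundles per-example input dicts with their metadata (assumed fields)."""
    inputs: list
    metadatas: list


class FeatureExtractor(ABC):
    @abstractmethod
    def forward(self, context: ModelContext) -> Any: ...


class IntermediateComponent(ABC):
    # Encoder and decoder components share this one signature.
    @abstractmethod
    def forward(self, intermediates: Any, context: ModelContext) -> Any: ...


class Predictor(ABC):
    @abstractmethod
    def forward(self, intermediates: Any, context: ModelContext) -> list: ...


class MeanExtractor(FeatureExtractor):
    """Toy extractor: one scalar feature per example."""
    def forward(self, context):
        return [sum(inp["image"]) / len(inp["image"]) for inp in context.inputs]


class Scale(IntermediateComponent):
    """Toy intermediate component: rescales each feature."""
    def __init__(self, factor):
        self.factor = factor

    def forward(self, intermediates, context):
        return [x * self.factor for x in intermediates]


class ThresholdPredictor(Predictor):
    """Toy predictor: emits one output per example and reads the metadata."""
    def forward(self, intermediates, context):
        return [
            {"window": meta.window_name, "positive": value > 1.0}
            for value, meta in zip(intermediates, context.metadatas)
        ]


def run_model(extractor, components, predictor, context):
    """Chain the components, passing the same context to every stage."""
    x = extractor.forward(context)
    for component in components:
        x = component.forward(x, context)
    return predictor.forward(x, context)


context = ModelContext(
    inputs=[{"image": [1.0, 3.0]}, {"image": [4.0, 4.0]}],
    metadatas=[
        SampleMetadata("a", (0, 0, 2, 1)),
        SampleMetadata("b", (2, 0, 4, 1), (datetime(2025, 1, 1), datetime(2025, 2, 1))),
    ],
)
outputs = run_model(MeanExtractor(), [Scale(0.5)], ThresholdPredictor(), context)
```

Because every stage receives the same context object, a component that needs location or timestamp information can read it from the metadata without changing the signature of the stages around it.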

Potential config-breaking changes:

  • I removed the collapse option from PickFeatures. It was previously used to turn a list of BCHW feature maps into a single BCHW tensor, since many task heads expected one BCHW tensor as input. Those heads now expect a FeatureMaps instead, so collapse is no longer needed.
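
As a rough illustration of why collapse becomes unnecessary, the sketch below models FeatureMaps as a thin wrapper around a list of feature maps. The `feature_maps` field and the helper functions `pick_features` and `single_map` are hypothetical names, not the PR's API, and strings stand in for BCHW tensors.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FeatureMaps:
    """Assumed container: one BCHW feature map per resolution."""
    feature_maps: list


def pick_features(fmaps: FeatureMaps, indexes: list) -> FeatureMaps:
    """Select a subset of maps. The result is still a FeatureMaps,
    so there is no separate 'collapse to one tensor' mode."""
    return FeatureMaps([fmaps.feature_maps[i] for i in indexes])


def single_map(fmaps: FeatureMaps) -> Any:
    """What a head that needs exactly one map might do internally."""
    if len(fmaps.feature_maps) != 1:
        raise ValueError("expected exactly one feature map")
    return fmaps.feature_maps[0]


# Strings stand in for tensors at different resolutions.
pyramid = FeatureMaps(["map_1/4", "map_1/8", "map_1/16"])
picked = pick_features(pyramid, [0])
head_input = single_map(picked)
```

In this arrangement the head, not the upstream component, decides how to unpack its input, which is why a config that previously set collapse can simply drop the option.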

@favyen2 favyen2 marked this pull request as draft December 2, 2025 21:28
@favyen2 favyen2 marked this pull request as ready for review December 2, 2025 22:29
@favyen2 favyen2 force-pushed the favyen/20251202-model-component-api branch 3 times, most recently from 4f29db4 to 6cdec06 Compare December 4, 2025 17:38
@favyen2 favyen2 force-pushed the favyen/20251202-model-component-api branch from 6cdec06 to fde61b9 Compare December 4, 2025 17:44
@favyen2 favyen2 requested a review from yawenzzzz December 4, 2025 18:54
@yawenzzzz yawenzzzz left a comment

LGTM!

favyen2 commented Dec 8, 2025

I tested this in olmoearth_projects with the AWF model, and prediction through olmoearth_run worked. I just needed to remove the PickFeatures component, which is no longer needed (and no longer supports the collapse argument, as mentioned in the PR). It's possible there are bugs with other models, but we can fix them as they arise.

@favyen2 favyen2 merged commit 8144af0 into master Dec 8, 2025
4 checks passed