
Performant way to initialize an ensemble of models #909

Open
@zou3519

Description


Right now, to initialize an ensemble of, say, 350 models, we first create 350 separate models and then combine their states with combine_state_for_ensemble. This leaves performance on the table: the fastest approach would be to initialize the combined state in one go.
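For reference, the current two-step pattern looks roughly like the sketch below. It uses torch.func.stack_module_state, which the PyTorch docs point to as the torch.func replacement for functorch's combine_state_for_ensemble; the nn.Linear(4, 3) member model and the batch shape are hypothetical stand-ins. Note that all 350 modules are fully constructed and initialized before their parameters are stacked, which is the overhead this issue is about.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

NUM_MODELS = 350  # ensemble size used in the issue

# Step 1: build NUM_MODELS independent modules (hypothetical nn.Linear members).
models = [nn.Linear(4, 3) for _ in range(NUM_MODELS)]

# Step 2: stack their parameters/buffers along a new leading ensemble dimension.
params, buffers = stack_module_state(models)

# Template module on the meta device: supplies the forward pass, not the weights.
base = copy.deepcopy(models[0]).to("meta")

def call_one(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(8, 4)
# Run every ensemble member on the same minibatch.
out = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
print(out.shape)  # torch.Size([350, 8, 3])
```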

This might not be too difficult to do. An idea from a discussion with @Chillee:

  • we could have torch.empty, torch.tensor, etc. automatically return a repeated tensor; this could be a flag to vmap
  • then we could simply use vmap with randomness="different" to initialize the whole ensemble in one call
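Until such a flag exists, the second bullet can be approximated today by writing the initializer with out-of-place random factories. The sketch below assumes a single nn.Linear layer; the layer sizes and the init_linear_params helper are hypothetical. Under vmap(randomness="different"), torch.rand called inside the vmapped function returns a batched tensor with an independent draw per ensemble member, so the stacked state is created in one go. The uniform range 1/sqrt(in_features) mirrors nn.Linear's default initialization. This is not the proposed mechanism: nn.init fills a fresh torch.empty tensor in place, and that tensor is not batched by vmap, which is the gap the flag would close.

```python
import math

import torch
from torch import nn
from torch.func import functional_call, vmap

NUM_MODELS = 350
IN_FEATURES, OUT_FEATURES = 4, 3  # hypothetical layer sizes
BOUND = 1.0 / math.sqrt(IN_FEATURES)  # nn.Linear's default init range (weight and bias)

def init_linear_params(_member_index):
    # Hypothetical helper. With randomness="different", torch.rand inside vmap
    # is batched, so each ensemble member gets its own independent draw.
    # Values are uniform in [-BOUND, BOUND), like nn.Linear's default init.
    weight = (torch.rand(OUT_FEATURES, IN_FEATURES) * 2 - 1) * BOUND
    bias = (torch.rand(OUT_FEATURES) * 2 - 1) * BOUND
    return {"weight": weight, "bias": bias}

# The dummy input only tells vmap how many ensemble members to create.
params = vmap(init_linear_params, randomness="different")(torch.arange(NUM_MODELS))

# Template module on the meta device: supplies the forward pass, not the weights.
base = nn.Linear(IN_FEATURES, OUT_FEATURES, device="meta")

def call_one(member_params, x):
    return functional_call(base, member_params, (x,))

x = torch.randn(8, IN_FEATURES)
out = vmap(call_one, in_dims=(0, None))(params, x)
print(out.shape)  # torch.Size([350, 8, 3])
```

The design choice here is to avoid in-place initialization entirely, so no per-model nn.Module is ever constructed; the cost is re-implementing each layer's init scheme by hand.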

Labels: actionable (It is clear what should be done for this issue)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions