Description
Right now, to initialize an ensemble of, e.g., 350 models, we first create 350 separate models and then combine their states with `combine_state_for_ensemble`. This leaves some performance on the table; the fastest thing we could do is initialize the combined state in one go.
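For reference, here is a minimal sketch of the current pattern. It stacks the parameters by hand with `torch.stack` as a stand-in for what `combine_state_for_ensemble` does with the model states; the `nn.Linear` model and its sizes are assumptions made only for illustration.

```python
import torch
import torch.nn as nn

num_models = 350
in_features, out_features = 8, 4  # illustrative sizes, not from the issue

# Step 1: build every ensemble member separately (350 constructor calls,
# each running its own parameter initialization).
models = [nn.Linear(in_features, out_features) for _ in range(num_models)]

# Step 2: stack each parameter across members so that dim 0 indexes the model.
# This hand-rolled stacking stands in for combine_state_for_ensemble.
per_model_params = [dict(m.named_parameters()) for m in models]
stacked_params = {
    name: torch.stack([params[name].detach() for params in per_model_params])
    for name in per_model_params[0]
}

print({name: tuple(p.shape) for name, p in stacked_params.items()})
```

The cost being discussed is the Python loop in step 1 plus the extra copy in step 2.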
This might not be too difficult to do. Idea from discussion with @Chillee is:
- We could have `torch.empty`, `torch.tensor`, etc. automatically return a repeated tensor. This could be a flag in `vmap`.
- Then we would just be able to use `vmap` with `randomness="different"` to initialize a model.
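As a rough sketch of what one-shot initialization could look like, the snippet below uses what is available today: random factory functions under `vmap` with `randomness="different"`. It does not implement the proposed flag for `torch.empty` / `torch.tensor`. The helper `init_linear` is hypothetical, and the uniform bound of `1 / sqrt(in_features)` is an assumption meant to resemble the default `nn.Linear` initialization.

```python
import math
import torch
from torch.func import vmap

num_models = 350
in_features, out_features = 8, 4  # illustrative sizes, not from the issue

def init_linear(_):
    # Hypothetical helper: draws the weight and bias for ONE ensemble member.
    # Assumed init: uniform in [-bound, bound] with bound = 1 / sqrt(fan_in).
    bound = 1.0 / math.sqrt(in_features)
    weight = (torch.rand(out_features, in_features) * 2 - 1) * bound
    bias = (torch.rand(out_features) * 2 - 1) * bound
    return weight, bias

# vmap needs a batched input to know the ensemble size; its values are unused.
dummy = torch.empty(num_models)

# randomness="different" makes each member draw its own random values, so the
# outputs come back already stacked along dim 0.
weights, biases = vmap(init_linear, randomness="different")(dummy)

print(tuple(weights.shape), tuple(biases.shape))
```

This only works because the initializer is written directly in terms of random factory calls; running an unmodified `nn.Module` constructor this way is what the proposed flag would be needed for.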