Add function to load model from pretrained checkpoint #1475

Closed
2 tasks
guarin opened this issue Jan 12, 2024 · 1 comment

Comments

@guarin
Contributor

guarin commented Jan 12, 2024

We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:

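```python
from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50

model = resnet50()
# Download the Lightning checkpoint and load its tensors onto the CPU.
checkpoint = load_state_dict_from_url(
    "https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt",
    map_location="cpu",
)
# Keep only the backbone weights and slice off the "backbone." prefix
# (str.lstrip would strip a set of characters, not a prefix).
prefix = "backbone."
new_state_dict = {}
for key, value in checkpoint["state_dict"].items():
    if key.startswith(prefix):
        new_state_dict[key[len(prefix):]] = value
# strict=False because the checkpoint contains no classification head (fc).
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
assert set(missing_keys) == {"fc.weight", "fc.bias"}
```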

Maybe we can leave the `load_state_dict_from_url` call outside the function and make the function just take a state dict as input and return the new state dict as output.
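
A minimal sketch of that standalone function, assuming (as in the example) that the checkpoint stores backbone weights under a `backbone.` prefix. The name `backbone_state_dict_from_checkpoint` and the toy module below are hypothetical illustrations, not an existing lightly API. The download stays outside the function, and a small in-memory module stands in for the real checkpoint:

```python
from typing import Dict

import torch
from torch import nn


def backbone_state_dict_from_checkpoint(
    state_dict: Dict[str, torch.Tensor], prefix: str = "backbone."
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: extract backbone weights from a module's state dict.

    Keeps only entries whose key starts with ``prefix`` and removes that prefix
    so the result can be loaded directly into the bare backbone.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy module standing in for a pretraining LightningModule (assumption).
class ToyPretrainModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
        self.projection_head = nn.Linear(8, 2)


pretrained = ToyPretrainModule()
# Mimics the layout of a Lightning .ckpt file: weights live under "state_dict".
checkpoint = {"state_dict": pretrained.state_dict()}

backbone = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
new_state_dict = backbone_state_dict_from_checkpoint(checkpoint["state_dict"])
# strict=True (the default): keys must match exactly, so any leftover prefix would fail.
result = backbone.load_state_dict(new_state_dict)
```

Slicing by `len(prefix)` removes the exact prefix; on Python 3.9+ `key.removeprefix(prefix)` would do the same.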

TODO

  • Add function
  • Document function
guarin added this to Backlog Apr 8, 2024
@guarin
Contributor Author

guarin commented Aug 16, 2024

Closed in favor of #1621

guarin closed this as completed Aug 16, 2024
github-project-automation bot moved this to Done in Backlog Aug 16, 2024
Labels
None yet
Projects
Status: Done
Development

No branches or pull requests

1 participant