We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:
```python
from torch.hub import load_state_dict_from_url
from torchvision.models import resnet50

model = resnet50()
state_dict = load_state_dict_from_url(
    "https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt",
    map_location="cpu",
)

# Keep only the backbone weights and strip the "backbone." prefix.
# str.lstrip removes a set of characters, not a prefix, so slice instead.
prefix = "backbone."
new_state_dict = {}
for key, value in state_dict["state_dict"].items():
    if key.startswith(prefix):
        new_state_dict[key[len(prefix):]] = value

# strict=False because the checkpoint has no classification head (fc).
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
assert set(missing_keys) == {"fc.weight", "fc.bias"}
```
Maybe we can leave `load_state_dict_from_url` outside the function and make the function just take a state dict as input and return the new state dict as output.
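A minimal sketch of what such a helper could look like, with the download kept outside the function. The name `backbone_state_dict_from_checkpoint` and the `prefix` parameter are hypothetical, not an existing lightly API; the sketch assumes the backbone weights are stored under keys starting with `backbone.`, as in the snippet above.

```python
from typing import Dict, Mapping, TypeVar

# Value type of the state dict (torch.Tensor in practice); kept generic so the
# helper itself does not depend on torch.
V = TypeVar("V")


def backbone_state_dict_from_checkpoint(
    state_dict: Mapping[str, V],
    prefix: str = "backbone.",
) -> Dict[str, V]:
    """Return the backbone weights from a benchmark checkpoint state dict.

    Only entries whose key starts with ``prefix`` are kept, and the prefix is
    removed so the keys match a plain backbone model such as
    ``torchvision.models.resnet50()``.

    Args:
        state_dict: The checkpoint's state dict, e.g. ``checkpoint["state_dict"]``.
        prefix: Key prefix under which the backbone weights are stored
            (hypothetical default, matching the keys used in this issue).

    Returns:
        A new dict mapping the stripped keys to the original values.
    """
    return {
        # Slice instead of str.lstrip: lstrip removes a character set, not a prefix.
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    # Toy example standing in for checkpoint["state_dict"].
    demo = {
        "backbone.conv1.weight": 0,
        "backbone.bn1.bias": 1,
        "projection_head.0.weight": 2,
    }
    print(backbone_state_dict_from_checkpoint(demo))
    # {'conv1.weight': 0, 'bn1.bias': 1}
```

With the checkpoint downloaded as in the snippet above, it would be used as `model.load_state_dict(backbone_state_dict_from_checkpoint(state_dict["state_dict"]), strict=False)`. Slicing off the prefix matters: `str.lstrip` strips any leading characters from the set `b, a, c, k, o, n, e, .`, so `"backbone.conv1.weight".lstrip("backbone.")` returns `"v1.weight"`.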
TODO
- [ ] Add function
- [ ] Document function