Provide Benchmark Backbone Checkpoint Weights #1621

Open
4 tasks
guarin opened this issue Jul 31, 2024 · 0 comments
guarin commented Jul 31, 2024

Right now we only upload checkpoint weights for the full SSL models. Those are awkward to work with because you have to rewrite the keys in the state dict if you only want to load the backbone. We should also upload backbone-only weights to make them easier to use.

Tasks

  • Save the backbone state dict after pretraining (here). This can be done with torch.save(model.backbone.state_dict(), path), where path points to a file in the same directory as the full checkpoint. Do the same for the ViT benchmarks.
  • Save backbone state dict for existing pretrained models and upload them to S3. This can be done with
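import torch

# `url` is the URL of an existing full checkpoint; `path` is where the backbone weights are written.
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")

# Keep only the backbone entries and strip the leading 'backbone.' prefix from their keys.
# Slicing removes just the prefix; str.replace would also alter any later 'backbone.' in a key.
prefix = "backbone."
backbone_state_dict = {
    k[len(prefix):]: v
    for k, v in checkpoint["state_dict"].items()
    if k.startswith(prefix)
}
torch.save(backbone_state_dict, path)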
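
For the first task, a minimal sketch of saving the backbone next to the full checkpoint and loading it back without touching any keys could look like the following. ToySSLModel, save_backbone, and the "_backbone.pt" file name are hypothetical stand-ins for illustration, not the benchmark code or an agreed naming scheme.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


class ToySSLModel(nn.Module):
    """Hypothetical stand-in for a full SSL model: backbone plus projection head."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.projection_head = nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection_head(self.backbone(x))


def save_backbone(model: nn.Module, full_checkpoint_path: Path) -> Path:
    """Save only the backbone weights in the same directory as the full checkpoint."""
    # Assumed naming scheme: "<checkpoint stem>_backbone.pt".
    backbone_path = full_checkpoint_path.with_name(
        f"{full_checkpoint_path.stem}_backbone.pt"
    )
    torch.save(model.backbone.state_dict(), backbone_path)
    return backbone_path


with tempfile.TemporaryDirectory() as tmp_dir:
    model = ToySSLModel()
    full_path = Path(tmp_dir) / "last.ckpt"
    # Mimic a full checkpoint that nests the weights under "state_dict".
    torch.save({"state_dict": model.state_dict()}, full_path)
    backbone_path = save_backbone(model, full_path)

    # A fresh backbone can load the file directly, with no key renaming.
    fresh_backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
    state_dict = torch.load(backbone_path, map_location="cpu")
    result = fresh_backbone.load_state_dict(state_dict)
    loaded_keys = sorted(state_dict.keys())
```

Because the saved keys carry no 'backbone.' prefix, load_state_dict should report no missing or unexpected keys, which is the usability gain this issue is after.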