
fix: Fixed load_pretrained_params in PyTorch when ignoring keys #902

Merged (11 commits, Apr 28, 2022)

Conversation

frgfm (Collaborator) commented on Apr 27, 2022

Following up on #874, this PR introduces the following modifications:

  • renamed "pop_entrys" to "ignore_keys"
  • updated load_pretrained_params to avoid non-strict loading of wrongly sized state_dict
  • fixed the classification models' ignore_keys mechanism (the keys to ignore were hardcoded in the factory function, but the factory builds models of different sizes, so the Linear layers end up with different names)
  • added dedicated unittest for ignore_keys
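The ignore_keys mechanism described above can be sketched roughly as follows. This is a minimal illustration, not doctr's actual implementation: the helper `filter_state_dict` and the parameter names are hypothetical, and plain dicts stand in for PyTorch state_dicts.

```python
# Illustrative sketch (hypothetical helper, not doctr's code): drop the
# classification-head entries from a checkpoint before loading, so a model
# built with a different num_classes can still reuse the backbone weights.

def filter_state_dict(state_dict, ignore_keys=None):
    """Return a copy of state_dict without the entries listed in ignore_keys."""
    ignore_keys = set(ignore_keys or [])
    unknown = ignore_keys - state_dict.keys()
    if unknown:
        # Fail loudly instead of silently falling back to a non-strict load
        raise ValueError(f"unknown keys to ignore: {sorted(unknown)}")
    return {k: v for k, v in state_dict.items() if k not in ignore_keys}

# Plain dicts stand in for torch state_dicts; names are illustrative
checkpoint = {
    "features.0.weight": [1.0, 2.0, 3.0],  # backbone entry: keep
    "head.weight": [[0.1] * 512] * 10,     # sized for the default num_classes
    "head.bias": [0.0] * 10,
}
backbone_only = filter_state_dict(checkpoint, ignore_keys=["head.weight", "head.bias"])
```

The remaining entries can then be loaded strictly into the resized model, with the head left at its fresh initialization.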

Any feedback is welcome!

# test pretrained model with different num_classes
model = classification.__dict__[arch_name](pretrained=True, num_classes=108).eval()
_test_classification(model, input_shape, output_size=(108,))
# Check that you can pretrained everything up until the last layer
Contributor:

maybe better: Test whether a pretrained model can be initialized down to the last layer ? :)

Collaborator (author):

Not sure I understand?

Compared to previous features, my idea is that we can already assume:

  • models are built correctly
  • pretrained with default output classes is working

What needs testing here is that the checkpoint can be loaded on a model with a different num_classes. So perhaps we could check that the bias of a given layer is indeed the one from the state_dict, but I don't see much more to test :)

(and I try to keep unittests from running slower)
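The check suggested above (that a loaded layer's bias really is the one from the checkpoint) could be sketched like this. A toy illustration only: `load_partial` and the parameter names are hypothetical, and plain dicts stand in for torch tensors and state_dicts.

```python
# Hypothetical sketch of the suggested unittest: after a partial load that
# skips the head, a backbone parameter should match the checkpoint, while
# the re-initialized head keeps its own (differently sized) values.

def load_partial(model_params, checkpoint, ignore_keys):
    """Copy checkpoint entries into model_params, skipping ignore_keys."""
    for name, value in checkpoint.items():
        if name not in ignore_keys:
            model_params[name] = value
    return model_params

# Plain dicts stand in for state_dicts; names are illustrative
model = {"features.0.bias": [0.0, 0.0], "head.bias": [0.0] * 108}
ckpt = {"features.0.bias": [0.5, -0.5], "head.bias": [0.1] * 10}
load_partial(model, ckpt, ignore_keys={"head.bias"})

# Backbone bias now comes from the checkpoint; the 108-class head is untouched
assert model["features.0.bias"] == [0.5, -0.5]
assert len(model["head.bias"]) == 108
```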

Contributor:

@frgfm sorry, I just meant to change this comment:

# Check that you can pretrained everything up until the last layer
maybe better / clearer:
# Test whether a pretrained model can be initialized down to the last layer 

😅

Contributor:

@frgfm and about the unittests, I think my tests from https://github.com/mindee/doctr/pull/892/files are currently the biggest slowdown, but I think we need to test each model

felixdittrich92 (Contributor):

@frgfm Thanks for refactoring this, it is more dynamic now ;)
LGTM, only a missing import and maybe the comment? :)

charlesmindee (Collaborator) left a comment:

Thanks for the refactor! Only a small missing import, it seems.

charlesmindee (Collaborator) left a comment:

Thanks!

@charlesmindee charlesmindee merged commit 01f4b3a into mindee:main Apr 28, 2022
@frgfm frgfm deleted the typo-utils branch April 28, 2022 16:19
@frgfm frgfm added the type: bug, module: models, ext: tests, and framework: pytorch labels on May 2, 2022
@frgfm frgfm added this to the 0.5.2 milestone May 2, 2022
@frgfm frgfm mentioned this pull request Jun 28, 2022
@felixdittrich92 felixdittrich92 mentioned this pull request Sep 26, 2022
@felixdittrich92 felixdittrich92 modified the milestones: 0.5.2, 0.6.0 Sep 26, 2022