Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger #1455

Merged
merged 3 commits into from
Mar 17, 2023

Conversation

KSGulin
Copy link
Contributor

@KSGulin KSGulin commented Mar 17, 2023

When loading a pre-trained torchvision model, an error will occur if the number of classes in the target dataset doesn't match the number of classes in the pre-trained model. e.g. when using a smaller subset of the original dataset. This PR fixes that issue by ignoring the classification head in the loaded model dict. Note that in some cases (such as inceptionet) it will fail, as for some models the classification head naming doesn't follow the standard naming pattern.

Test plan
sparseml.image_classification.train --checkpoint-path verizon_dense.pt --arch-key densenet121 --dataset-path /network/datasets/imagenette-160/imagenette-160 --pretrained True

@KSGulin KSGulin requested review from corey-nm and a team March 17, 2023 16:08
@KSGulin KSGulin self-assigned this Mar 17, 2023
@KSGulin KSGulin requested review from tdg5 and abhinavnmagic and removed request for a team March 17, 2023 16:09
rahul-tuli
rahul-tuli previously approved these changes Mar 17, 2023
dbogunowicz
dbogunowicz previously approved these changes Mar 17, 2023
bfineran
bfineran previously approved these changes Mar 17, 2023
@KSGulin KSGulin dismissed stale reviews from bfineran, dbogunowicz, and rahul-tuli via 9c156cd March 17, 2023 19:11
@KSGulin KSGulin force-pushed the ic_class_fix branch 2 times, most recently from 0f66a06 to 2fd4854 Compare March 17, 2023 19:17
@bfineran bfineran changed the title [Bug-fix] Don't override num_classes for pre-trained torchvision models [torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger Mar 17, 2023
@bfineran bfineran merged commit 7ee620b into main Mar 17, 2023
@bfineran bfineran deleted the ic_class_fix branch March 17, 2023 19:54
@bfineran bfineran restored the ic_class_fix branch March 17, 2023 19:55
bfineran added a commit that referenced this pull request Mar 17, 2023
…asks + use PythonLogger default logger (#1455)

* Remove cf from native torchvision models

* * do not pass default logger to PythonLogger
* comments

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
bfineran added a commit that referenced this pull request Mar 17, 2023
…asks + use PythonLogger default logger (#1455)

* Remove cf from native torchvision models

* * do not pass default logger to PythonLogger
* comments

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
bfineran added a commit that referenced this pull request Mar 17, 2023
…ansfer learning tasks + use PythonLogger default logger #1455 (#1460)

* [torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger (#1455)

* Remove cf from native torchvision models

* * do not pass default logger to PythonLogger
* comments

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>

* [torchvision] add ignore error tensors back to optional checkpoint load (#1459)

---------

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Co-authored-by: Damian <damian@neuralmagic.com>
bfineran added a commit that referenced this pull request Mar 17, 2023
…ansfer learning tasks + use PythonLogger default logger #1455 (#1461)

* [torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger (#1455)

* Remove cf from native torchvision models

* * do not pass default logger to PythonLogger
* comments

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>

* [torchvision] add ignore error tensors back to optional checkpoint load (#1459)

---------

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Co-authored-by: Damian <damian@neuralmagic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants