Skip to content

Commit

Permalink
[cherry-pick-1.4][torchvision][Bug-fix] ignore state dict error on tr…
Browse files Browse the repository at this point in the history
…ansfer learning tasks + use PythonLogger default logger #1455 (#1460)

* [torchvision][Bug-fix] ignore state dict error on transfer learning tasks + use PythonLogger default logger (#1455)

* Remove cf from native torchvision models

* do not pass default logger to PythonLogger
* comments

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>

* [torchvision] add ignore error tensors back to optional checkpoint load (#1459)

---------

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Co-authored-by: Damian <damian@neuralmagic.com>
  • Loading branch information
3 people authored Mar 17, 2023
1 parent 84deda6 commit cfbfedf
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def collate_fn(batch):

if utils.is_main_process():
loggers = [
PythonLogger(logger=_LOGGER),
PythonLogger(),
]
try:
loggers.append(TensorBoardLogger(log_path=args.output_dir))
Expand Down Expand Up @@ -757,11 +757,36 @@ def _create_model(
model, arch_key = model
elif arch_key in torchvision.models.__dict__:
# fall back to torchvision
model = torchvision.models.__dict__[arch_key](
pretrained=pretrained, num_classes=num_classes
)
# load initial, untrained model with correct number of classes
model = torchvision.models.__dict__[arch_key](num_classes=num_classes)
if pretrained is not None:
# in transfer learning cases, final FC layer may not match dimensions
# load base pretrained model and load state dict with strict=False
pretrained_model = torchvision.models.__dict__[arch_key](
pretrained=pretrained
)
if (
getattr(pretrained_model, "classifier", None)
and pretrained_model.classifier.out_features != num_classes
):
del pretrained_model.classifier
model.load_state_dict(pretrained_model.state_dict(), strict=False)
if checkpoint_path is not None:
load_model(checkpoint_path, model, strict=True)
load_model(
checkpoint_path,
model,
strict=True,
ignore_error_tensors=[
"classifier.fc.weight",
"classifier.fc.bias",
"classifier.1.weight",
"classifier.1.bias",
"fc.weight",
"fc.bias",
"classifier.weight",
"classifier.bias",
],
)
else:
raise ValueError(
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
Expand Down

0 comments on commit cfbfedf

Please sign in to comment.