
Fixed loading preprocessing params from pretrained weights #1473

Merged · 5 commits · Sep 20, 2023
12 changes: 7 additions & 5 deletions src/super_gradients/training/models/model_factory.py
@@ -153,6 +153,13 @@ def instantiate_model(
net = architecture_cls(arch_params=arch_params)

if pretrained_weights:
# The logic is as follows: first, we initialize the preprocessing params using the default hard-coded params.
# If the pretrained checkpoint contains preprocessing params, they will be loaded and will override the ones from
# this step inside load_pretrained_weights_local/load_pretrained_weights.
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

if is_remote and pretrained_weights_path:
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
else:
@@ -162,11 +169,6 @@
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

# STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

_add_model_name_attribute(net, model_name)

return net
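A minimal sketch (not part of the diff) of the precedence this hunk establishes: defaults from `get_pretrained_processing_params` are applied first, and a checkpoint that carries its own `processing_params` overrides them later, inside `load_pretrained_weights`/`load_pretrained_weights_local`. The dict values below are hypothetical stand-ins:

```python
# Hypothetical illustration of "defaults first, checkpoint overrides later".
default_params = {"conf": 0.25, "image_size": 640}  # stand-in for get_pretrained_processing_params(...)
checkpoint = {"net": {}, "processing_params": {"conf": 0.5}}  # checkpoint shipping its own params

effective = dict(default_params)  # step 1: defaults are set unconditionally
if "processing_params" in checkpoint:  # step 2: checkpoint params win when present
    effective.update(checkpoint["processing_params"])

assert effective == {"conf": 0.5, "image_size": 640}
```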
83 changes: 57 additions & 26 deletions src/super_gradients/training/utils/checkpoint_utils.py
@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

if (isinstance(net, HasPredict)) and load_processing_params:
if "processing_params" not in checkpoint.keys():
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
try:
net.set_dataset_processing_params(**checkpoint["processing_params"])
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
"predict make sure to call set_dataset_processing_params."
)
_maybe_load_preprocessing_params(net, checkpoint)

if load_weights_only or load_backbone:
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: name of the pretrained weights (i.e. imagenet)
:return: None

:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: name of the pretrained weights (i.e. imagenet)

:return: None
"""
from super_gradients.common.object_names import Models

@@ -1569,19 +1562,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"By downloading the pre-trained weight files you agree to comply with these terms."
)

unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)

# This check allows setting pretrained weights from a local path using the file:///path/to/weights scheme,
# which is a valid URI scheme for local files.
# Supporting local files and file URIs lets us modify pretrained weights dicts in unit tests.
if url.startswith("file://") or os.path.exists(url):
pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
else:
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)
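The new `file://` branch can be exercised as below — a sketch assuming a checkpoint file already exists at the given local path; the `_local_test` key and `local_test` name are hypothetical, following the pattern used by the unit test in this PR:

```python
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES

# Register a hypothetical pretrained-weights entry pointing at a local file;
# load_pretrained_weights then takes the torch.load branch instead of downloading.
MODEL_URLS[Models.YOLO_NAS_S + "_local_test"] = "file:///tmp/yolo_nas_s_coco.pth"
PRETRAINED_NUM_CLASSES["local_test"] = 80

model = models.get(Models.YOLO_NAS_S, pretrained_weights="local_test")
```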


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
@@ -1598,3 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):

pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")


def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
"""
Tries to load preprocessing params from the checkpoint into the model.
The function does not crash; it logs a warning if the loading fails.
:param model: Instance of nn.Module
:param checkpoint: Entire checkpoint dict (not state_dict with model weights)
:return: True if the loading was successful, False otherwise.
"""
model = unwrap_model(model)
checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
model_has_predict = isinstance(model, HasPredict)
logger.debug(
f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
)

if model_has_predict and checkpoint_has_preprocessing_params:
try:
model.set_dataset_processing_params(**checkpoint["processing_params"])
logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
return True
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
"predict, make sure to call set_dataset_processing_params."
)
return False
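A sketch of the checkpoint layout the helper consumes, mirroring the unit test below; importing the private helper directly is for illustration only:

```python
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params
from super_gradients.training.utils.checkpoint_utils import _maybe_load_preprocessing_params

model = models.get(Models.YOLO_NAS_S, num_classes=80)
checkpoint = {
    "net": model.state_dict(),  # model weights
    "processing_params": default_yolo_nas_coco_processing_params(),  # kwargs for set_dataset_processing_params
}

assert _maybe_load_preprocessing_params(model, checkpoint) is True  # HasPredict model + params present
assert _maybe_load_preprocessing_params(model, {"net": {}}) is False  # no params stored -> nothing loaded
```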
29 changes: 26 additions & 3 deletions tests/unit_tests/pretrained_models_unit_test.py
@@ -1,12 +1,19 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch

import super_gradients
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy
import os
import shutil
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params


class PretrainedModelsUnitTest(unittest.TestCase):
@@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

def test_pretrained_models_load_preprocessing_params(self):
"""
Test that checks whether preprocessing params from a pretrained model load correctly.
"""
state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
with tempfile.TemporaryDirectory() as td:
checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
torch.save(state, checkpoint_path)

MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
PRETRAINED_NUM_CLASSES["test"] = 80

model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
# .predict() would fail if the model had no preprocessing params
self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")