Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed loading preprocessing params from pretrained weights #1473

Merged
merged 5 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def instantiate_model(
net = architecture_cls(arch_params=arch_params)

if pretrained_weights:
# The logic is as follows - first we initialize the preprocessing params using default hard-coded params
# If pretrained checkpoint contains preprocessing params, new params will be loaded and override the ones from
# this step in load_pretrained_weights_local/load_pretrained_weights
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

if is_remote and pretrained_weights_path:
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
else:
Expand All @@ -162,11 +169,6 @@ def instantiate_model(
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

# STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

_add_model_name_attribute(net, model_name)

return net
Expand Down
58 changes: 40 additions & 18 deletions src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

if (isinstance(net, HasPredict)) and load_processing_params:
if "processing_params" not in checkpoint.keys():
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
try:
net.set_dataset_processing_params(**checkpoint["processing_params"])
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
"predict make sure to call set_dataset_processing_params."
)
_maybe_load_preprocessing_params(net, checkpoint)

if load_weights_only or load_backbone:
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
Expand Down Expand Up @@ -1574,14 +1565,7 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)


def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
Expand All @@ -1598,3 +1582,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre

pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def _load_weights(architecture, model, pretrained_state_dict):
    """
    Load a pretrained state dict into the model using an architecture-specific checkpoint solver.

    :param architecture:          Architecture name; a name containing "yolox" selects the YoloX-specific solver.
    :param model:                 Model instance to load the weights into.
    :param pretrained_state_dict: Checkpoint dict; if it carries an "ema_net" entry, the EMA weights
                                  take precedence over the regular "net" entry.
    """
    # Prefer EMA weights when present - they generally score better than the raw training weights.
    if "ema_net" in pretrained_state_dict:
        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]

    if "yolox" in architecture:
        solver = YoloXCheckpointSolver()
    else:
        solver = DefaultCheckpointSolver()

    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")


def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
    """
    Try to load preprocessing params from the checkpoint into the model.
    The function does not crash, and logs a warning if the loading fails.

    :param model:      Instance of nn.Module (may be wrapped, e.g. in DDP/EMA; it is unwrapped first).
    :param checkpoint: Entire checkpoint dict (not the state_dict with model weights).
    :return:           True if the loading was successful, False otherwise.
    """
    model = unwrap_model(model)
    preprocessing_params_in_checkpoint = "processing_params" in checkpoint.keys()
    model_has_predict = isinstance(model, HasPredict)
    logger.debug(
        f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {preprocessing_params_in_checkpoint}. "
        f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
    )

    if model_has_predict and preprocessing_params_in_checkpoint:
        try:
            model.set_dataset_processing_params(**checkpoint["processing_params"])
            logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
            return True
        except Exception as e:
            # Best-effort by design: a failure here must not abort weight loading.
            # NOTE: trailing space before the closing quote is required - the two literals
            # are implicitly concatenated into a single message.
            logger.warning(
                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
                "predict make sure to call set_dataset_processing_params."
            )
    return False