Skip to content

Commit

Permalink
Added support for file:/// in pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Sep 20, 2023
1 parent 635357b commit 433f33a
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,10 +1540,12 @@ def __init__(self, desc):
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: name for the pretrained weights (i.e. imagenet)
:return: None
:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: name for the pretrained weights (i.e. imagenet)
:return: None
"""
from super_gradients.common.object_names import Models

Expand All @@ -1560,10 +1562,17 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
"By downloading the pre-trained weight files you agree to comply with these terms."
)

unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
# Basically this check allows setting pretrained weights from a local path using the file:///path/to/weights scheme,
# which is a valid URI scheme for local files.
# Supporting local files and the file URI scheme allows us to modify pretrained weights dicts in unit tests
if url.startswith("file://") or os.path.exists(url):
pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
else:
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)

Expand Down Expand Up @@ -1597,19 +1606,19 @@ def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkp
"""
Tries to load preprocessing params from the checkpoint to the model.
The function does not crash, and raises a warning if the loading fails.
:param model: Instance of nn.Module
:param model: Instance of nn.Module
:param checkpoint: Entire checkpoint dict (not state_dict with model weights)
:return: True if the loading was successful, False otherwise.
:return: True if the loading was successful, False otherwise.
"""
model = unwrap_model(model)
preprocessing_params_in_checkpoint = "processing_params" in checkpoint.keys()
checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
model_has_predict = isinstance(model, HasPredict)
logger.debug(
f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {preprocessing_params_in_checkpoint}. "
f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
)

if model_has_predict and preprocessing_params_in_checkpoint:
if model_has_predict and checkpoint_has_preprocessing_params:
try:
model.set_dataset_processing_params(**checkpoint["processing_params"])
logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
Expand Down

0 comments on commit 433f33a

Please sign in to comment.