From 5783d0b2e4e0b5d6b44e5a560732340c9e46165d Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Wed, 19 Jul 2023 16:43:58 +0300 Subject: [PATCH] add --- .../training/models/model_factory.py | 14 +++++++++----- .../training/utils/checkpoint_utils.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index 2164c94974..b657ce4295 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -51,26 +51,30 @@ def get_architecture( pretrained_weights_path = None is_remote = False if not isinstance(model_name, str): - raise ValueError("Parameter model_name is expected to be a string.") + raise ValueError(f"Input parameter `model_name` should be a string. Got {model_name} of type {type(model_name)}.") architecture = get_param(ARCHITECTURES, model_name) if model_name not in ARCHITECTURES.keys() and architecture is None: if client_enabled: - logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab') + logger.info(f'The requested model "{model_name}" was not found in SuperGradients. Trying to load a model from the Platform...') deci_client = DeciClient() _arch_params = deci_client.get_model_arch_params(model_name) if _arch_params is None: raise ValueError( - f'The required model "{model_name}", was not found in SuperGradients and remote deci-lab. ' - f"See docs or all_architectures.py for supported model names." + f'The requested model "{model_name}" was not found in the Platform. See docs or all_architectures.py for supported model names.' ) + else: + logger.info(f'The requested model "{model_name}" is available in the platform and will now be downloaded...') if download_required_code: # Some extra code might be required to instantiate the arch params. deci_client.download_and_load_model_additional_code(model_name, target_path=str(Path.cwd())) + logger.debug(f'Additional code for model "{model_name}" has been downloaded from the platform.') + _arch_params = hydra.utils.instantiate(_arch_params) if download_platform_weights: pretrained_weights_path = deci_client.get_model_weights(model_name) + logger.info("The model weights were downloaded from the platform.") else: pretrained_weights_path = None model_name = _arch_params["model_name"] @@ -80,7 +84,7 @@ def get_architecture( arch_params, is_remote = _arch_params, True else: raise UnknownTypeException( - message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.', + message=f'The requested model "{model_name}" was not found in SuperGradients. See docs or all_architectures.py for supported model names.', unknown_type=model_name, choices=list(ARCHITECTURES.keys()), ) diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py index edcbb15a33..b4e8eb67f3 100644 --- a/src/super_gradients/training/utils/checkpoint_utils.py +++ b/src/super_gradients/training/utils/checkpoint_utils.py @@ -300,7 +300,6 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val): def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str): - """ Loads pretrained weights from the MODEL_URLS dictionary to model :param architecture: name of the model's architecture @@ -335,6 +334,7 @@ def _load_weights(architecture, model, pretrained_state_dict): pretrained_state_dict["net"] = pretrained_state_dict["ema_net"] solver = _yolox_ckpt_solver if "yolox" in architecture else None adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver) + logger.info(f"Successfully loaded pretrained weights for architecture {architecture}") def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):