Skip to content

Commit

Permalink
add (#1298)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Louis-Dupont committed Jul 20, 2023
1 parent 30b922a commit 8bfdedf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 9 additions & 5 deletions src/super_gradients/training/models/model_factory.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,26 +51,30 @@ def get_architecture(
pretrained_weights_path = None
is_remote = False
if not isinstance(model_name, str):
raise ValueError("Parameter model_name is expected to be a string.")
raise ValueError(f"Input parameter `model_name` should be a string. Got {model_name} of type {type(model_name)}.")

architecture = get_param(ARCHITECTURES, model_name)
if model_name not in ARCHITECTURES.keys() and architecture is None:
if client_enabled:
logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
logger.info(f'The requested model "{model_name}" was not found in SuperGradients. Trying to load a model from the Platform...')
deci_client = DeciClient()

_arch_params = deci_client.get_model_arch_params(model_name)
if _arch_params is None:
raise ValueError(
f'The required model "{model_name}", was not found in SuperGradients and remote deci-lab. '
f"See docs or all_architectures.py for supported model names."
f'The requested model "{model_name}" was not found in the Platform. See docs or all_architectures.py for supported model names.'
)
else:
logger.info(f'The requested model "{model_name}" is available in the platform and will now be downloaded...')

if download_required_code: # Some extra code might be required to instantiate the arch params.
deci_client.download_and_load_model_additional_code(model_name, target_path=str(Path.cwd()))
logger.debug(f'Additional code for model "{model_name}" has been downloaded from the platform.')

_arch_params = hydra.utils.instantiate(_arch_params)
if download_platform_weights:
pretrained_weights_path = deci_client.get_model_weights(model_name)
logger.info("The model weights were downloaded from the platform.")
else:
pretrained_weights_path = None
model_name = _arch_params["model_name"]
Expand All @@ -80,7 +84,7 @@ def get_architecture(
arch_params, is_remote = _arch_params, True
else:
raise UnknownTypeException(
message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
message=f'The requested model "{model_name}" was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
unknown_type=model_name,
choices=list(ARCHITECTURES.keys()),
)
Expand Down
2 changes: 1 addition & 1 deletion src/super_gradients/training/utils/checkpoint_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -300,7 +300,6 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):


def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):

"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
Expand Down Expand Up @@ -335,6 +334,7 @@ def _load_weights(architecture, model, pretrained_state_dict):
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = _yolox_ckpt_solver if "yolox" in architecture else None
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
Expand Down

0 comments on commit 8bfdedf

Please sign in to comment.