Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding more logs to let the user know when pretrained_weights is being used and/or downloaded #1298

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,30 @@ def get_architecture(
pretrained_weights_path = None
is_remote = False
if not isinstance(model_name, str):
raise ValueError("Parameter model_name is expected to be a string.")
raise ValueError(f"Input parameter `model_name` should be a string. Got {model_name} of type {type(model_name)}.")

architecture = get_param(ARCHITECTURES, model_name)
if model_name not in ARCHITECTURES.keys() and architecture is None:
if client_enabled:
logger.info(f'The required model, "{model_name}", was not found in SuperGradients. Trying to load a model from remote deci-lab')
logger.info(f'The requested model "{model_name}" was not found in SuperGradients. Trying to load a model from the Platform...')
deci_client = DeciClient()

_arch_params = deci_client.get_model_arch_params(model_name)
if _arch_params is None:
raise ValueError(
f'The required model "{model_name}", was not found in SuperGradients and remote deci-lab. '
f"See docs or all_architectures.py for supported model names."
f'The requested model "{model_name}" was not found in the Platform. See docs or all_architectures.py for supported model names.'
)
else:
logger.info(f'The requested model "{model_name}" is available in the platform and will now be downloaded...')

if download_required_code: # Some extra code might be required to instantiate the arch params.
deci_client.download_and_load_model_additional_code(model_name, target_path=str(Path.cwd()))
logger.debug(f'Additional code for model "{model_name}" has been downloaded from the platform.')

_arch_params = hydra.utils.instantiate(_arch_params)
if download_platform_weights:
pretrained_weights_path = deci_client.get_model_weights(model_name)
logger.info("The model weights were downloaded from the platform.")
else:
pretrained_weights_path = None
model_name = _arch_params["model_name"]
Expand All @@ -80,7 +84,7 @@ def get_architecture(
arch_params, is_remote = _arch_params, True
else:
raise UnknownTypeException(
message=f'The required model, "{model_name}", was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
message=f'The requested model "{model_name}" was not found in SuperGradients. See docs or all_architectures.py for supported model names.',
unknown_type=model_name,
choices=list(ARCHITECTURES.keys()),
)
Expand Down
2 changes: 1 addition & 1 deletion src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def _yolox_ckpt_solver(ckpt_key, ckpt_val, model_key, model_val):


def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):

"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
Expand Down Expand Up @@ -335,6 +334,7 @@ def _load_weights(architecture, model, pretrained_state_dict):
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = _yolox_ckpt_solver if "yolox" in architecture else None
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
Expand Down