Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix model.get #1287

Merged
merged 9 commits into from
Aug 7, 2023
21 changes: 14 additions & 7 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def instantiate_model(
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
:param num_classes: Number of classes (defines the net's structure).
If None is given, will try to derrive from pretrained_weight's corresponding dataset.
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent")
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent").
Add `platform/` prefix if the weights are stored in the platform -
Please note that in this case, `num_classes` is expected to be the checkpoints number of classes, and not the number of class
that you want to use - you will need to replace the head afterward if you want to work with a different number of classes.
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.

Expand Down Expand Up @@ -132,11 +135,15 @@ def instantiate_model(
if pretrained_weights is None and num_classes is None:
raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

if pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif pretrained_weights and pretrained_weights is None:
raise ValueError(f"Unknown pretrained_weights - couldn't find pretrained weights in {PRETRAINED_NUM_CLASSES.keys()} or platform.")
if pretrained_weights:
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif not download_platform_weights:
raise ValueError(
f'`pretrained_weights="{pretrained_weights}"` is not a valid and was not found in that platform. '
f'Valid pretrained weights are: "{PRETRAINED_NUM_CLASSES.keys()}"'
)

# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take
# **kwargs instead
Expand All @@ -151,7 +158,7 @@ def instantiate_model(
else:
load_pretrained_weights(net, model_name, pretrained_weights)

if num_classes_new_head != arch_params.num_classes:
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys() and num_classes_new_head != arch_params.num_classes:
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

Expand Down