Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix model.get #1287

Merged
merged 9 commits into from
Aug 7, 2023
16 changes: 10 additions & 6 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,15 @@ def instantiate_model(
if pretrained_weights is None and num_classes is None:
raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

if pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif pretrained_weights and pretrained_weights is None:
raise ValueError(f"Unknown pretrained_weights - couldn't find pretrained weights in {PRETRAINED_NUM_CLASSES.keys()} or platform.")
if pretrained_weights:
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif not download_platform_weights:
raise ValueError(
f'`pretrained_weights="{pretrained_weights}"` is not a valid and was not found in that platform. '
f'Valid pretrained weights are: "{PRETRAINED_NUM_CLASSES.keys()}"'
)

# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take
# **kwargs instead
Expand All @@ -151,7 +155,7 @@ def instantiate_model(
else:
load_pretrained_weights(net, model_name, pretrained_weights)

if num_classes_new_head != arch_params.num_classes:
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys() and num_classes_new_head != arch_params.num_classes:
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

Expand Down