Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Jul 18, 2023
1 parent c1587c5 commit d7ba737
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def instantiate_model(
if pretrained_weights is None and num_classes is None:
raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

if pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
is_valid_local_pretrained_weights = pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys()
if is_valid_local_pretrained_weights:
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif pretrained_weights and pretrained_weights is None:
elif not download_platform_weights:
raise ValueError(f"Unknown pretrained_weights - couldn't find pretrained weights in {PRETRAINED_NUM_CLASSES.keys()} or platform.")

# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take
Expand All @@ -141,7 +142,7 @@ def instantiate_model(
else:
net = architecture_cls(arch_params=arch_params)

if pretrained_weights:
if is_valid_local_pretrained_weights:
if is_remote and pretrained_weights_path:
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
else:
Expand Down

0 comments on commit d7ba737

Please sign in to comment.