Skip to content

Commit

Permalink
Fix model.get (#1287)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix test

* add docstring

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
Louis-Dupont and BloodAxe committed Aug 7, 2023
1 parent 3dd5919 commit 9b0964f
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def instantiate_model(
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
:param num_classes: Number of classes (defines the net's structure).
If None is given, will try to derive from pretrained_weight's corresponding dataset.
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenet")
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenet").
Add a `platform/` prefix if the weights are stored in the platform.
Please note that in this case, `num_classes` is expected to be the checkpoint's number of classes, and not the number of classes
that you want to use - you will need to replace the head afterward if you want to work with a different number of classes.
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.
Expand Down Expand Up @@ -132,11 +135,15 @@ def instantiate_model(
if pretrained_weights is None and num_classes is None:
raise ValueError("num_classes or pretrained_weights must be passed to determine net's structure.")

if pretrained_weights and pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif pretrained_weights and pretrained_weights is None:
raise ValueError(f"Unknown pretrained_weights - couldn't find pretrained weights in {PRETRAINED_NUM_CLASSES.keys()} or platform.")
if pretrained_weights:
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys():
num_classes_new_head = core_utils.get_param(arch_params, "num_classes", PRETRAINED_NUM_CLASSES[pretrained_weights])
arch_params.num_classes = PRETRAINED_NUM_CLASSES[pretrained_weights]
elif not download_platform_weights:
raise ValueError(
f'`pretrained_weights="{pretrained_weights}"` is not a valid and was not found in that platform. '
f'Valid pretrained weights are: "{PRETRAINED_NUM_CLASSES.keys()}"'
)

# Most of the SG models work with a single params names "arch_params" of type HpmStruct, but a few take
# **kwargs instead
Expand All @@ -151,7 +158,7 @@ def instantiate_model(
else:
load_pretrained_weights(net, model_name, pretrained_weights)

if num_classes_new_head != arch_params.num_classes:
if pretrained_weights in PRETRAINED_NUM_CLASSES.keys() and num_classes_new_head != arch_params.num_classes:
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

Expand Down

0 comments on commit 9b0964f

Please sign in to comment.