Skip to content

Commit

Permalink
Added warning message for dataset license (#1846)
Browse files Browse the repository at this point in the history
* added warning message

* added warning icon

---------

Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
  • Loading branch information
shaydeci and ofrimasad committed Feb 19, 2024
1 parent cf0c90a commit 2b7aa61
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
11 changes: 10 additions & 1 deletion src/super_gradients/training/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
"yolo_nas_pose_l_coco_pose": "https://sghub.deci.ai/models/yolo_nas_pose_l_coco_pose.pth",
}


PRETRAINED_NUM_CLASSES = {
"imagenet": 1000,
"imagenet21k": 21843,
Expand All @@ -71,3 +70,13 @@
"coco_pose": 17,
"cifar10": 10,
}

DATASET_LICENSES = {
"imagenet": "https://www.image-net.org/download.php",
"imagenet21k": "https://github.com/Alibaba-MIIL/ImageNet21K",
"coco": "https://cocodataset.org/#termsofuse",
"coco_segmentation_subclass": "https://cocodataset.org/#termsofuse",
"coco_pose": "https://cocodataset.org/#termsofuse",
"cityscapes": "https://www.cs.toronto.edu/~kriz/cifar.html",
"objects365": "https://www.objects365.org/download.html",
}
14 changes: 10 additions & 4 deletions src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from super_gradients.common.data_types import StrictLoad
from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.training.pretrained_models import MODEL_URLS, DATASET_LICENSES
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import unwrap_model
Expand All @@ -24,7 +24,6 @@
except (ModuleNotFoundError, ImportError, NameError):
from torch.hub import _download_url_to_file as download_url_to_file


logger = get_logger(__name__)


Expand Down Expand Up @@ -52,8 +51,8 @@ def transfer_weights(model: nn.Module, model_state_dict: Mapping[str, Tensor]) -
percentage_of_checkpoint = transfered_weights / len(model_state_dict)
percentage_of_model = transfered_weights / len(model.state_dict())
logger.debug(
f"Transfered {transfered_weights} ({(100*percentage_of_checkpoint):.2f}%) weights from the checkpoint. "
f"{(100*percentage_of_model):.2f}% of the model layers were initialized using checkpoint."
f"Transfered {transfered_weights} ({(100 * percentage_of_checkpoint):.2f}%) weights from the checkpoint. "
f"{(100 * percentage_of_model):.2f}% of the model layers were initialized using checkpoint."
)


Expand Down Expand Up @@ -1562,6 +1561,13 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
if model_url_key not in MODEL_URLS.keys():
raise MissingPretrainedWeightsException(model_url_key)

if pretrained_weights in DATASET_LICENSES:
logger.warning(
f":warning: The pre-trained models provided by SuperGradients may have their own licenses or terms and "
"conditions derived from the dataset used for pre-training.\n It is your responsibility to determine whether you "
"have permission to use the models for your use case.\n The model you have requested was pre-trained on the "
f"{pretrained_weights} dataset, published under the following terms: {DATASET_LICENSES[pretrained_weights]}"
)
url = MODEL_URLS[model_url_key]

if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
Expand Down

0 comments on commit 2b7aa61

Please sign in to comment.