Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export API fix: Raise meaningful exception if model has no preprocessing metadata but preprocessing=True #1413

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@

logger = get_logger(__name__)

__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult"]
__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult", "ModelHasNoPreprocessingParamsException"]


class ModelHasNoPreprocessingParamsException(Exception):
"""
Exception that is raised when model does not have preprocessing parameters.
"""

pass


class AbstractObjectDetectionDecodingModule(nn.Module):
Expand Down Expand Up @@ -295,7 +303,18 @@ def export(
if isinstance(preprocessing, nn.Module):
preprocessing_module = preprocessing
elif preprocessing is True:
preprocessing_module = model.get_preprocessing_callback()
try:
preprocessing_module = model.get_preprocessing_callback()
except ModelHasNoPreprocessingParamsException:
raise ValueError(
"It looks like your model does not have dataset preprocessing params properly set.\n"
"This may happen if you instantiated model from scratch and not trained it yet. \n"
"Here are what you can do to fix this:\n"
"1. Manually fill up dataset processing params via model.set_dataset_processing_params(...).\n"
"2. Train your model first and then export it. Trainer will set_dataset_processing_params(...) for you.\n"
'3. Instantiate a model using pretrained weights: models.get(..., pretrained_weights="coco") \n'
"4. Disable preprocessing by passing model.export(..., preprocessing=False). \n"
)
if isinstance(preprocessing_module, nn.Sequential):
preprocessing_module = nn.Sequential(CastTensorTo(model_type), *iter(preprocessing_module))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from super_gradients.common.object_names import Models
from super_gradients.common.registry.registry import register_model
from super_gradients.module_interfaces import AbstractObjectDetectionDecodingModule, ExportableObjectDetectionModel, HasPredict
from super_gradients.module_interfaces.exportable_detector import ModelHasNoPreprocessingParamsException
from super_gradients.modules import RepVGGBlock
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
Expand Down Expand Up @@ -104,6 +105,8 @@ def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredic

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, HasPredict
from super_gradients.module_interfaces.exportable_detector import ModelHasNoPreprocessingParamsException
from super_gradients.modules import CrossModelSkipConnection, Conv
from super_gradients.training.models.classification_models.regnet import AnyNetX, Stage
from super_gradients.training.models.detection_models.csp_darknet53 import GroupedConvBlock, CSPDarknet53, get_yolo_type_params, SPP
Expand Down Expand Up @@ -502,6 +503,8 @@ def get_processing_params(self) -> Optional[Processing]:

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Models
from super_gradients.common.registry import register_model
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
Expand Down Expand Up @@ -72,6 +72,8 @@ def get_decoding_module(self, num_pre_nms_predictions: int, **kwargs) -> Abstrac

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down