Skip to content

Commit

Permalink
Properly handle case when model does not have preprocessing params set (
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Aug 24, 2023
1 parent 8cdb2aa commit fc6ce60
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
23 changes: 21 additions & 2 deletions src/super_gradients/module_interfaces/exportable_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@

logger = get_logger(__name__)

__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult"]
__all__ = ["ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule", "ModelExportResult", "ModelHasNoPreprocessingParamsException"]


class ModelHasNoPreprocessingParamsException(Exception):
"""
Exception that is raised when model does not have preprocessing parameters.
"""

pass


class AbstractObjectDetectionDecodingModule(nn.Module):
Expand Down Expand Up @@ -295,7 +303,18 @@ def export(
if isinstance(preprocessing, nn.Module):
preprocessing_module = preprocessing
elif preprocessing is True:
preprocessing_module = model.get_preprocessing_callback()
try:
preprocessing_module = model.get_preprocessing_callback()
except ModelHasNoPreprocessingParamsException:
raise ValueError(
"It looks like your model does not have dataset preprocessing params properly set.\n"
"This may happen if you instantiated model from scratch and not trained it yet. \n"
"Here are what you can do to fix this:\n"
"1. Manually fill up dataset processing params via model.set_dataset_processing_params(...).\n"
"2. Train your model first and then export it. Trainer will set_dataset_processing_params(...) for you.\n"
'3. Instantiate a model using pretrained weights: models.get(..., pretrained_weights="coco") \n'
"4. Disable preprocessing by passing model.export(..., preprocessing=False). \n"
)
if isinstance(preprocessing_module, nn.Sequential):
preprocessing_module = nn.Sequential(CastTensorTo(model_type), *iter(preprocessing_module))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from super_gradients.common.object_names import Models
from super_gradients.common.registry.registry import register_model
from super_gradients.module_interfaces import AbstractObjectDetectionDecodingModule, ExportableObjectDetectionModel, HasPredict
from super_gradients.module_interfaces.exportable_detector import ModelHasNoPreprocessingParamsException
from super_gradients.modules import RepVGGBlock
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.csp_resnet import CSPResNetBackbone
Expand Down Expand Up @@ -104,6 +105,8 @@ def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredic

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, HasPredict
from super_gradients.module_interfaces.exportable_detector import ModelHasNoPreprocessingParamsException
from super_gradients.modules import CrossModelSkipConnection, Conv
from super_gradients.training.models.classification_models.regnet import AnyNetX, Stage
from super_gradients.training.models.detection_models.csp_darknet53 import GroupedConvBlock, CSPDarknet53, get_yolo_type_params, SPP
Expand Down Expand Up @@ -502,6 +503,8 @@ def get_processing_params(self) -> Optional[Processing]:

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Models
from super_gradients.common.registry import register_model
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from super_gradients.module_interfaces import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException
from super_gradients.training.models.arch_params_factory import get_arch_params
from super_gradients.training.models.detection_models.customizable_detector import CustomizableDetector
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
Expand Down Expand Up @@ -72,6 +72,8 @@ def get_decoding_module(self, num_pre_nms_predictions: int, **kwargs) -> Abstrac

def get_preprocessing_callback(self, **kwargs):
processing = self.get_processing_params()
if processing is None:
raise ModelHasNoPreprocessingParamsException()
preprocessing_module = processing.get_equivalent_photometric_module()
return preprocessing_module

Expand Down

0 comments on commit fc6ce60

Please sign in to comment.