Add skip_resize for model.predict (#1605)
* first working version without test

* add

* add tests

* move to _get_pipeline

* fix

* fix test

* Way to fix bug with validation frequency (#1601)

* Way to fix bug with validation frequency

* Fixed test: the state of the net was rewritten

* Added validation of the latest epoch and of epochs from save_ckpt_epoch_list

* Added one more test case to check whether the latest non-divisible epoch has validation results in metrics

* Following the SRP recommendation...

* Which inference time exactly

* Fixed incorrect keyword in writing function

* Missing brackets around epoch+1 in valid run check function.

* Final fixes hopefully :)

* Fixed trainer to add scalars only in main process

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>

* add images and update auto-padding responsibility

* add example of visualization w/o resizing

* add docstring

* remove unwanted prints

* add explicit auto_padding

---------

Co-authored-by: hakuryuu96 <marchenkophilip@gmail.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
3 people authored Nov 13, 2023
1 parent f8686cd commit a07f906
Showing 14 changed files with 381 additions and 69 deletions.
30 changes: 30 additions & 0 deletions documentation/source/ModelPredictions.md
@@ -217,3 +217,33 @@ model.predict(...)
```

This allows the model to run on the GPU, significantly speeding up the object detection process. Note that using a GPU requires having the necessary drivers and compatible hardware installed.

## Skipping Image Resizing
Skipping image resizing in object detection can have a significant impact on the results. Typically, models are trained on images of a certain size, with (640, 640) being a common dimension.

By default, the `model.predict(...)` method resizes input images to the training size. However, there's an option to bypass this resizing step, which offers several benefits:

- **Speed Improvement for Smaller Images**: If your original image is smaller than the typical training size, avoiding resizing can speed up the prediction process.
- **Enhanced Detection of Small Objects in High-Resolution Images**: For high-resolution images containing numerous small objects, processing the images in their original size can improve the model's ability to recall these objects. This comes at the expense of speed but can be beneficial for detailed analysis.

To apply this approach, simply use the `skip_image_resizing` parameter in the `model.predict(...)` method as shown below:

```python
predictions = model.predict(image, skip_image_resizing=True)
```
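For context, here is a fuller end-to-end sketch of the same workflow, mirroring the prediction example script updated in this commit. The model name, pretrained weights, and image path are illustrative assumptions rather than values taken from the documentation above:

```python
import torch

from super_gradients.training import models

# Illustrative model choice; any detection model exposing `predict` works the same way.
model = models.get("yolo_nas_s", pretrained_weights="coco")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Process the image at its original resolution instead of resizing it to the training size.
predictions = model.predict("path/to/high_resolution_image.jpg", skip_image_resizing=True)
predictions.show()
predictions.save(output_folder="predictions")  # Writes the annotated images to ./predictions
```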

### Example

The following images illustrate the difference in detection results with and without resizing.

#### Original Image
![Original Image](images/detection_example_beach_raw_image.jpeg)
*This is the raw image before any processing.*

#### Image Processed with Standard Resizing (640x640)
![Resized Image](images/detection_example_beach_resized_predictions.jpg)
*This image shows the detection results after resizing the image to the model's trained size of 640x640.*

#### Image Processed in Original Size
![Original Size Image](images/detection_example_beach_raw_image_prediction.jpg)
*Here, the image is processed in its original size, demonstrating how the model performs without resizing. Notice the differences in object detection and details compared to the resized version.*
(3 binary files, presumably the example images referenced above, cannot be displayed in the diff view.)
8 changes: 3 additions & 5 deletions src/super_gradients/examples/predict/detection_predict.py
@@ -9,11 +9,9 @@
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

IMAGES = [
"../../../../documentation/source/images/examples/countryside.jpg",
"../../../../documentation/source/images/examples/street_busy.jpg",
"https://cdn-attachments.timesofmalta.com/cc1eceadde40d2940bc5dd20692901371622153217-1301777007-4d978a6f-620x348.jpg",
"https://images.pexels.com/photos/7968254/pexels-photo-7968254.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2",
]

predictions = model.predict(IMAGES)
predictions = model.predict(IMAGES, skip_image_resizing=True)
predictions.show()
predictions.save(output_folder="") # Save in working directory
predictions.save(output_folder="2") # Save in working directory
@@ -30,15 +30,19 @@ def set_dataset_processing_params(self, class_names: Optional[List[str]] = None,
self._image_processor = image_processor or self._image_processor

@lru_cache(maxsize=1)
def _get_pipeline(self, fuse_model: bool = True) -> ClassificationPipeline:
def _get_pipeline(self, fuse_model: bool = True, skip_image_resizing: bool = False) -> ClassificationPipeline:
"""Instantiate the prediction pipeline of this model.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
if None in (self._class_names, self._image_processor):
raise RuntimeError(
"You must set the dataset processing parameters before calling predict.\n" "Please call `model.set_dataset_processing_params(...)` first."
)

if skip_image_resizing:
raise ValueError("`skip_image_resizing` is not supported for classification models.")

pipeline = ClassificationPipeline(
model=self,
image_processor=self._image_processor,
@@ -47,19 +51,21 @@ def _get_pipeline(self, fuse_model: bool = True) -> ClassificationPipeline:
)
return pipeline

def predict(self, images: ImageSource, batch_size: int = 32, fuse_model: bool = True) -> ImagesClassificationPrediction:
def predict(self, images: ImageSource, batch_size: int = 32, fuse_model: bool = True, skip_image_resizing: bool = False) -> ImagesClassificationPrediction:
"""Predict an image or a list of images.
:param images: Images to predict.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(fuse_model=fuse_model)
pipeline = self._get_pipeline(fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, fuse_model: bool = True) -> None:
def predict_webcam(self, fuse_model: bool = True, skip_image_resizing: bool = False) -> None:
"""Predict using webcam.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(fuse_model=fuse_model)
pipeline = self._get_pipeline(fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
pipeline.predict_webcam()
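As a quick illustration of the classification behavior introduced above: passing `skip_image_resizing=True` to a classification model's `predict` raises a `ValueError`. A minimal sketch, assuming an illustrative model name:

```python
from super_gradients.training import models

# Illustrative classification model; the check applies to any model using ClassificationPipeline.
model = models.get("resnet18", pretrained_weights="imagenet")

try:
    model.predict("some_image.jpg", skip_image_resizing=True)
except ValueError as err:
    # Expected: "`skip_image_resizing` is not supported for classification models."
    print(err)
```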
@@ -21,7 +21,7 @@
import super_gradients.common.factories.detection_modules_factory as det_factory
from super_gradients.training.utils.predict import ImagesDetectionPrediction
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from super_gradients.training.processing.processing import Processing
from super_gradients.training.processing.processing import Processing, ComposeProcessing, DetectionAutoPadding
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
from super_gradients.training.utils.media.image import ImageSource

@@ -157,13 +157,16 @@ def get_processing_params(self) -> Optional[Processing]:
return self._image_processor

@lru_cache(maxsize=1)
def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> DetectionPipeline:
def _get_pipeline(
self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True, skip_image_resizing: bool = False
) -> DetectionPipeline:
"""Instantiate the prediction pipeline of this model.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
raise RuntimeError(
@@ -172,9 +175,18 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non

iou = iou or self._default_nms_iou
conf = conf or self._default_nms_conf

# Ensure that the image size is divisible by 32.
if isinstance(self._image_processor, ComposeProcessing) and skip_image_resizing:
image_processor = self._image_processor.get_equivalent_compose_without_resizing(
auto_padding=DetectionAutoPadding(shape_multiple=(32, 32), pad_value=0)
)
else:
image_processor = self._image_processor

pipeline = DetectionPipeline(
model=self,
image_processor=self._image_processor,
image_processor=image_processor,
post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
class_names=self._class_names,
fuse_model=fuse_model,
@@ -188,6 +200,7 @@ def predict(
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
skip_image_resizing: bool = False,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.
@@ -197,19 +210,21 @@ def predict(
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True, skip_image_resizing: bool = False):
"""Predict using webcam.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
pipeline.predict_webcam()

def train(self, mode: bool = True):
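The `DetectionAutoPadding(shape_multiple=(32, 32), pad_value=0)` step that replaces resizing above is responsible for making arbitrary input shapes divisible by 32. A minimal sketch of that pad-to-multiple idea, written independently of the library's actual implementation (padding side and layout are assumptions):

```python
import numpy as np

def pad_to_multiple(image: np.ndarray, multiple: int = 32, pad_value: int = 0) -> np.ndarray:
    """Bottom/right-pad an HWC image so its height and width are divisible by `multiple`."""
    h, w = image.shape[:2]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=pad_value)

padded = pad_to_multiple(np.zeros((723, 1281, 3), dtype=np.uint8))
print(padded.shape)  # (736, 1312, 3)
```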
@@ -20,7 +20,7 @@
from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import PPYOLOEHead
from super_gradients.training.models.sg_module import SgModule
from super_gradients.training.pipelines.pipelines import DetectionPipeline
from super_gradients.training.processing.processing import Processing
from super_gradients.training.processing.processing import Processing, ComposeProcessing, DetectionAutoPadding
from super_gradients.training.utils import HpmStruct
from super_gradients.training.utils.media.image import ImageSource
from super_gradients.training.utils.predict import ImagesDetectionPrediction
@@ -150,13 +150,16 @@ def set_dataset_processing_params(
self._default_nms_conf = conf or self._default_nms_conf

@lru_cache(maxsize=1)
def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> DetectionPipeline:
def _get_pipeline(
self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True, skip_image_resizing: bool = False
) -> DetectionPipeline:
"""Instantiate the prediction pipeline of this model.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
raise RuntimeError(
@@ -166,11 +169,20 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
iou = iou or self._default_nms_iou
conf = conf or self._default_nms_conf

# Ensure that the image size is divisible by 32.
if isinstance(self._image_processor, ComposeProcessing) and skip_image_resizing:
image_processor = self._image_processor.get_equivalent_compose_without_resizing(
auto_padding=DetectionAutoPadding(shape_multiple=(32, 32), pad_value=0)
)
else:
image_processor = self._image_processor

pipeline = DetectionPipeline(
model=self,
image_processor=self._image_processor,
image_processor=image_processor,
post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
class_names=self._class_names,
fuse_model=fuse_model,
)
return pipeline

@@ -181,6 +193,7 @@ def predict(
conf: Optional[float] = None,
batch_size: int = 32,
fuse_model: bool = True,
skip_image_resizing: bool = False,
) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.
@@ -190,19 +203,21 @@ def predict(
If None, the default value associated to the training is used.
:param batch_size: Maximum number of images to process at the same time.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
return pipeline(images, batch_size=batch_size) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True, skip_image_resizing: bool = False):
"""Predict using webcam.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, prediction are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
:param skip_image_resizing: If True, the image processor will not resize the images.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model, skip_image_resizing=skip_image_resizing)
pipeline.predict_webcam()

def train(self, mode: bool = True):
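Since `skip_image_resizing` is added alongside the existing `iou`, `conf`, `batch_size`, and `fuse_model` arguments, the options compose freely. A hedged usage sketch for the PP-YOLOE variant of `predict` (model name, image paths, and thresholds are illustrative):

```python
from super_gradients.training import models

model = models.get("ppyoloe_s", pretrained_weights="coco")  # illustrative PP-YOLOE variant

predictions = model.predict(
    ["scene_1.jpg", "scene_2.jpg"],
    iou=0.5,                   # NMS IoU threshold
    conf=0.4,                  # predictions below this confidence are discarded
    batch_size=8,
    skip_image_resizing=True,  # keep the original resolution; pad to a multiple of 32 instead
)
predictions.show()
```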