Predict on fused model (#998)
* predict on fused model

* working version

* fix

* update

* update

* add benchmarl

* add reset

* torch.from_numpy

* fix
Louis-Dupont authored May 16, 2023
1 parent b4608f6 commit 363c34a
Showing 5 changed files with 91 additions and 26 deletions.
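For orientation, here is a hedged sketch of how the `fuse_model` flag introduced by this PR is meant to be used from the public `predict` API. The model name, weights, and image path are placeholders, not taken from the diff:

```python
# Usage sketch for the new `fuse_model` argument (placeholder model name and image path).
from super_gradients.training import models

model = models.get("yolox_n", pretrained_weights="coco")  # any SG detector exposing .predict()

# Default behaviour: a fused copy of the model is built lazily on the first batch,
# trading extra memory for faster inference. The pipeline is cached, so repeated
# predict() calls with the same arguments reuse the fused copy.
predictions = model.predict("path/to/image.jpg", iou=0.65, conf=0.5, fuse_model=True)
predictions.show()
```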
15 changes: 13 additions & 2 deletions src/super_gradients/modules/qarepvgg_block.py
@@ -316,5 +316,16 @@ def fuse_block_residual_branches(self):
def from_repvgg(self, src: RepVGGBlock):
raise NotImplementedError

def prep_model_for_conversion(self, input_size: Optional[Union[tuple, list]] = None, **kwargs):
self.partial_fusion()
def prep_model_for_conversion(self, input_size: Optional[Union[tuple, list]] = None, full_fusion: bool = True, **kwargs):
"""Prepare the QARepVGGBlock for conversion.
:WARNING: the default `full_fusion=True` will make the block non-trainable.
:param full_fusion: If True, performs full fusion, converting the block into a non-trainable, fully fused block.
If False, performs partial fusion, slower for inference but still trainable.
"""

if full_fusion:
self.full_fusion()
else:
self.partial_fusion()
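As a rough illustration of the new `full_fusion` switch on `prep_model_for_conversion`, the sketch below compares the two fusion modes on a standalone block. The constructor arguments are assumed for illustration and may not match the block's exact signature:

```python
# Sketch only: exercising both fusion modes of QARepVGGBlock (constructor args are assumed).
import copy

import torch
from super_gradients.modules.qarepvgg_block import QARepVGGBlock

block = QARepVGGBlock(in_channels=64, out_channels=64)
block.eval()
x = torch.randn(1, 64, 32, 32)

# Partial fusion: faster than the unfused block, but the block stays trainable.
partially_fused = copy.deepcopy(block)
partially_fused.prep_model_for_conversion(full_fusion=False)

# Full fusion (the new default): fully collapsed, fastest for inference, non-trainable.
fully_fused = copy.deepcopy(block)
fully_fused.prep_model_for_conversion(full_fusion=True)

# The two fusion modes should be numerically equivalent (both re-parameterize the same block).
with torch.no_grad():
    diff = (partially_fused(x) - fully_fused(x)).abs().max()
print(f"max abs difference between fusion modes: {diff.item():.2e}")
```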
@@ -6,7 +6,9 @@
* each module defines out_channels property on construction
"""
from typing import Union, Optional, List
from functools import lru_cache

import torch
from torch import nn
from omegaconf import DictConfig

@@ -136,12 +138,14 @@ def set_dataset_processing_params(
self._default_nms_iou = iou or self._default_nms_iou
self._default_nms_conf = conf or self._default_nms_conf

def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
@lru_cache(maxsize=1)
def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> DetectionPipeline:
"""Instantiate the prediction pipeline of this model.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
raise RuntimeError(
@@ -150,32 +154,39 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non

iou = iou or self._default_nms_iou
conf = conf or self._default_nms_conf

pipeline = DetectionPipeline(
model=self,
image_processor=self._image_processor,
post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
class_names=self._class_names,
fuse_model=fuse_model,
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline.predict_webcam()

def train(self, mode: bool = True):
self._get_pipeline.cache_clear()
torch.cuda.empty_cache()
return super().train(mode)
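The `@lru_cache(maxsize=1)` on `_get_pipeline`, together with the new `train()` override, means repeated `predict()` calls reuse one (possibly fused) pipeline, and the cache is dropped as soon as the model returns to training. A standalone sketch of that caching pattern, using a toy class rather than the SG model:

```python
# Illustrative toy class, not SG code: shows the lru_cache + cache_clear pattern above.
from functools import lru_cache

import torch
from torch import nn


class CachedPipelineModel(nn.Module):
    @lru_cache(maxsize=1)
    def _get_pipeline(self, iou: float = 0.65, conf: float = 0.5, fuse_model: bool = True):
        # Expensive to build (and, with fuse_model=True, holds a fused copy of the model),
        # so it is built once and reused until the arguments change or the cache is cleared.
        print("building pipeline")
        return {"iou": iou, "conf": conf, "fuse_model": fuse_model}

    def predict(self, iou: float = 0.65, conf: float = 0.5, fuse_model: bool = True):
        return self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)

    def train(self, mode: bool = True):
        # Going back to training invalidates the cached (fused) pipeline.
        self._get_pipeline.cache_clear()
        torch.cuda.empty_cache()
        return super().train(mode)


model = CachedPipelineModel()
model.predict()  # prints "building pipeline"
model.predict()  # cache hit, nothing printed
model.train()    # clears the cache
model.predict()  # pipeline is rebuilt
```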
@@ -1,5 +1,7 @@
from typing import Union, Optional, List
from functools import lru_cache

import torch
from torch import Tensor

from super_gradients.common.decorators.factory_decorator import resolve_param
@@ -59,12 +61,14 @@ def set_dataset_processing_params(
self._default_nms_iou = iou or self._default_nms_iou
self._default_nms_conf = conf or self._default_nms_conf

def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
@lru_cache(maxsize=1)
def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> DetectionPipeline:
"""Instantiate the prediction pipeline of this model.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
raise RuntimeError(
@@ -82,27 +86,34 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline.predict_webcam()

def train(self, mode: bool = True):
self._get_pipeline.cache_clear()
torch.cuda.empty_cache()
return super().train(mode)

def forward(self, x: Tensor):
features = self.backbone(x)
features = self.neck(features)
22 changes: 17 additions & 5 deletions src/super_gradients/training/models/detection_models/yolo_base.py
@@ -1,5 +1,6 @@
import math
from typing import Union, Type, List, Tuple, Optional
from functools import lru_cache

import torch
import torch.nn as nn
@@ -18,6 +19,7 @@
from super_gradients.training.processing.processing import Processing
from super_gradients.training.utils.media.image import ImageSource


COCO_DETECTION_80_CLASSES_BBOX_ANCHORS = Anchors(
[[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], strides=[8, 16, 32]
) # output strides of all yolo outputs
@@ -447,12 +449,14 @@ def set_dataset_processing_params(
self._default_nms_iou = iou or self._default_nms_iou
self._default_nms_conf = conf or self._default_nms_conf

def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None) -> DetectionPipeline:
@lru_cache(maxsize=1)
def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> DetectionPipeline:
"""Instantiate the prediction pipeline of this model.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
if None in (self._class_names, self._image_processor, self._default_nms_iou, self._default_nms_conf):
raise RuntimeError(
@@ -467,30 +471,38 @@ def _get_pipeline(self, iou: Optional[float] = None, conf: Optional[float] = Non
image_processor=self._image_processor,
post_prediction_callback=self.get_post_prediction_callback(iou=iou, conf=conf),
class_names=self._class_names,
fuse_model=fuse_model,
)
return pipeline

def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None) -> ImagesDetectionPrediction:
def predict(self, images: ImageSource, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True) -> ImagesDetectionPrediction:
"""Predict an image or a list of images.
:param images: Images to predict.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
return pipeline(images) # type: ignore

def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None):
def predict_webcam(self, iou: Optional[float] = None, conf: Optional[float] = None, fuse_model: bool = True):
"""Predict using webcam.
:param iou: (Optional) IoU threshold for the nms algorithm. If None, the default value associated to the training is used.
:param conf: (Optional) Below the confidence threshold, predictions are discarded.
If None, the default value associated to the training is used.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""
pipeline = self._get_pipeline(iou=iou, conf=conf)
pipeline = self._get_pipeline(iou=iou, conf=conf, fuse_model=fuse_model)
pipeline.predict_webcam()

def train(self, mode: bool = True):
self._get_pipeline.cache_clear()
torch.cuda.empty_cache()
return super().train(mode)

def forward(self, x):
out = self._backbone(x)
out = self._head(out)
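The same flag reaches `predict_webcam`, and because the fused copy increases memory usage (as the docstrings above note), it can be switched off on memory-constrained machines. Another hedged usage sketch with a placeholder model name:

```python
# Placeholder model name; any SG detector with predict()/predict_webcam() behaves the same way.
from super_gradients.training import models

model = models.get("yolox_n", pretrained_weights="coco")

# Skip the fused copy to keep memory usage down, at some cost in inference speed.
model.predict("path/to/image.jpg", fuse_model=False).show()

# The flag is forwarded to the webcam pipeline as well.
model.predict_webcam(conf=0.4, fuse_model=False)
```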
36 changes: 28 additions & 8 deletions src/super_gradients/training/pipelines/pipelines.py
@@ -1,3 +1,4 @@
import copy
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union, Iterable
from contextlib import contextmanager
@@ -29,14 +30,13 @@

@contextmanager
def eval_mode(model: SgModule) -> None:
"""Set a model in evaluation mode and deactivate gradient computation, undo at the end.
"""Set a model in evaluation mode, undo at the end.
:param model: The model to set in evaluation mode.
"""
_starting_mode = model.training
model.eval()
with torch.no_grad():
yield
yield
model.train(mode=_starting_mode)
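With `torch.no_grad()` moved out of `eval_mode`, the context manager now only toggles train/eval state and restores it on exit; gradient (and autocast) contexts are stacked explicitly at the call site. A minimal self-contained sketch of the same pattern:

```python
# Self-contained sketch of the eval_mode pattern shown in this hunk.
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def eval_mode(model: nn.Module):
    """Temporarily put `model` in eval mode, restoring its previous mode on exit."""
    was_training = model.training
    model.eval()
    yield
    model.train(mode=was_training)


model = nn.Linear(4, 2).train()
x = torch.randn(1, 4)

# Gradient handling is now explicit at the call site, as in the prediction code below.
with eval_mode(model), torch.no_grad():
    _ = model(x)

assert model.training  # the original training mode was restored
```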


@@ -47,10 +47,17 @@ class Pipeline(ABC):
:param model: The model used for making predictions.
:param image_processor: A single image processor or a list of image processors for preprocessing and postprocessing the images.
:param device: The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""

def __init__(self, model: SgModule, image_processor: Union[Processing, List[Processing]], class_names: List[str], device: Optional[str] = None):
super().__init__()
def __init__(
self,
model: SgModule,
image_processor: Union[Processing, List[Processing]],
class_names: List[str],
device: Optional[str] = None,
fuse_model: bool = True,
):
self.device = device or next(model.parameters()).device
self.model = model.to(self.device)
self.class_names = class_names
@@ -59,6 +66,15 @@ def __init__(self, model: SgModule, image_processor: Union[Processing, List[Proc
image_processor = ComposeProcessing(image_processor)
self.image_processor = image_processor

self.fuse_model = fuse_model # If True, the model will be fused in the first forward pass, to make sure it gets the right input_size

def _fuse_model(self, input_example: torch.Tensor):
logger.info("Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `fuse_model=False`")
self.model = copy.deepcopy(self.model)
self.model.eval()
self.model.prep_model_for_conversion(input_size=input_example.shape[-2:])
self.fuse_model = False

def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
"""Predict an image or a list of images.
@@ -153,8 +169,10 @@ def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray])
processing_metadatas.append(processing_metadata)

# Predict
with eval_mode(self.model):
torch_inputs = torch.Tensor(np.array(preprocessed_images)).to(self.device)
with eval_mode(self.model), torch.no_grad(), torch.cuda.amp.autocast():
torch_inputs = torch.from_numpy(np.array(preprocessed_images)).to(self.device)
if self.fuse_model:
self._fuse_model(torch_inputs)
model_output = self.model(torch_inputs)
predictions = self._decode_model_output(model_output, model_input=torch_inputs)
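Fusion is deferred to the first batch because `prep_model_for_conversion` needs the input's spatial size; `_fuse_model` deep-copies the model so the original stays trainable, fuses the copy, and then flips `self.fuse_model` off so later batches skip the step. A simplified, self-contained sketch of that lazy-fusion flow (the class below is illustrative, not SG's pipeline):

```python
# Toy pipeline showing fuse-on-first-forward; not the SG implementation.
import copy

import numpy as np
import torch
from torch import nn


class LazyFusingPipeline:
    def __init__(self, model: nn.Module, fuse_model: bool = True):
        self.model = model
        self.fuse_model = fuse_model  # fused lazily, once the input size is known

    def _fuse_model(self, input_example: torch.Tensor) -> None:
        fused = copy.deepcopy(self.model)  # keep the original model trainable
        fused.eval()
        if hasattr(fused, "prep_model_for_conversion"):
            fused.prep_model_for_conversion(input_size=input_example.shape[-2:])
        self.model = fused
        self.fuse_model = False  # only fuse once

    def __call__(self, batch: np.ndarray) -> torch.Tensor:
        inputs = torch.from_numpy(batch)
        with torch.no_grad():
            if self.fuse_model:
                self._fuse_model(inputs)
            return self.model(inputs)


pipeline = LazyFusingPipeline(nn.Conv2d(3, 8, kernel_size=3, padding=1))
out = pipeline(np.random.rand(1, 3, 64, 64).astype(np.float32))  # fusion happens here
```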

@@ -221,6 +239,7 @@ class DetectionPipeline(Pipeline):
:param post_prediction_callback: Callback function to process raw predictions from the model.
:param image_processor: Single image processor or a list of image processors for preprocessing and postprocessing the images.
:param device: The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
:param fuse_model: If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
"""

def __init__(
@@ -230,8 +249,9 @@ def __init__(
post_prediction_callback: DetectionPostPredictionCallback,
device: Optional[str] = None,
image_processor: Optional[Processing] = None,
fuse_model: bool = True,
):
super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names)
super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names, fuse_model=fuse_model)
self.post_prediction_callback = post_prediction_callback

def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
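For completeness, a hedged sketch of constructing a `DetectionPipeline` directly with the new argument. In practice `_get_pipeline` above wires the class names, processor, and callback from the model's stored dataset-processing params; the direct construction here is only illustrative:

```python
# Assumed construction sketch; normally _get_pipeline() builds this for you.
from super_gradients.training import models
from super_gradients.training.pipelines.pipelines import DetectionPipeline

model = models.get("yolox_n", pretrained_weights="coco")  # placeholder model

pipeline = DetectionPipeline(
    model=model,
    class_names=model._class_names,
    image_processor=model._image_processor,
    post_prediction_callback=model.get_post_prediction_callback(iou=0.65, conf=0.5),
    fuse_model=True,  # the fused copy is created on the first batch
)
predictions = pipeline("path/to/image.jpg")
```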
